|
import onnx |
|
from typing import Optional, Union |
|
from pathlib import Path |
|
import os |
|
import logging |
|
|
|
# Logger for this module; used to report when a model check is skipped.
logger = logging.getLogger(__name__)

# Serialized-size threshold in bytes (2 GiB == 2**31) compared against
# ``ModelProto.ByteSize()``: models at or above it are treated as "exceeds 2GB"
# and are handled through ONNX external data instead of a single protobuf file.
MAXIMUM_PROTOBUF = 2**31
|
|
|
|
|
def strict_check_model(model_or_path: Union[onnx.ModelProto, str, Path]) -> None:
    """Run ``onnx.checker.check_model`` with ``full_check=True``.

    Failures whose message contains ``"No Op registered for"`` are tolerated:
    they are logged at debug level and the function returns normally. Such
    errors presumably come from operators outside the registered ONNX schemas
    (e.g. custom/contrib ops) -- TODO confirm against callers.

    Args:
        model_or_path: An in-memory ``onnx.ModelProto`` or a path to a model file.
            A path must be used for models that are too large to check in memory.

    Raises:
        Exception: Any error from the checker whose message does not contain
            ``"No Op registered for"`` is re-raised unchanged.
    """
    try:
        onnx.checker.check_model(model_or_path, full_check=True)
    except Exception as e:
        if "No Op registered for" not in str(e):
            # Bare ``raise`` re-raises the active exception without adding a
            # new traceback entry for this line.
            raise
        # Deliberate best-effort: keep going, but leave a trace of the skip.
        logger.debug("Ignoring ONNX checker failure for unregistered op: %s", e)
|
|
|
|
|
def _prepare_save_path(save_path: Union[str, Path]):
    """Normalize ``save_path`` and delete files a previous save may have left.

    Args:
        save_path: Destination path of the ONNX model file.

    Returns:
        A ``(posix_path, external_file_name)`` tuple. ``posix_path`` is
        ``save_path`` as a POSIX-style string. ``external_file_name`` is its
        basename with ``"_data"`` appended (``model.onnx`` -> ``model.onnx_data``),
        relative to the model's directory.
    """
    path = Path(save_path).as_posix()
    external_file_name = os.path.basename(path) + "_data"
    external_path = os.path.join(os.path.dirname(path), external_file_name)

    # Only an existing file with an ``.onnx`` suffix is deleted up front.
    if path.endswith(".onnx") and os.path.isfile(path):
        os.remove(path)
    # Remove a stale external-data file so it is neither left orphaned next to a
    # single-file save nor reused by an external-data save (onnx presumably
    # appends to an existing data file -- TODO confirm for the onnx version used).
    if os.path.isfile(external_path):
        os.remove(external_path)

    return path, external_file_name


def check_and_save_model(model: onnx.ModelProto, save_path: Optional[Union[str, Path]]):
    """Check ``model`` with ``strict_check_model`` and optionally save it.

    If ``model.ByteSize()`` is below ``MAXIMUM_PROTOBUF``, the model is checked in
    memory and, if ``save_path`` is given, saved as a single file.

    Larger models cannot be checked as in-memory protos. When ``save_path`` is
    given, they are saved with all tensors in one ``<basename>_data`` external
    file next to ``save_path`` and then checked from that path. Without a
    ``save_path``, a large model is neither saved nor checked, and an info
    message is logged.

    Args:
        model: The ONNX model to check and (optionally) save.
        save_path: Destination file. ``None`` or an empty string means
            "do not save". An existing ``.onnx`` file at this path and an
            existing ``<basename>_data`` file beside it are deleted before writing.

    Raises:
        Exception: Errors from ``strict_check_model`` (except tolerated
            "No Op registered for" failures), ``onnx.save`` or ``os.remove``.
    """
    if model.ByteSize() < MAXIMUM_PROTOBUF:
        strict_check_model(model)
        if save_path:
            path, _ = _prepare_save_path(save_path)
            onnx.save(
                model,
                path,
                convert_attribute=True,
            )
    elif save_path:
        path, external_file_name = _prepare_save_path(save_path)
        onnx.save(
            model,
            path,
            save_as_external_data=True,
            all_tensors_to_one_file=True,
            location=external_file_name,
            convert_attribute=True,
        )
        # The in-memory proto is too large to check, but the checker accepts a
        # file path, so validate the copy just written to disk.
        strict_check_model(path)
    else:
        logger.info(
            "Merged ONNX model exceeds 2GB, the model will not be checked without `save_path` given."
        )
|
|