File size: 2,071 Bytes
a1edf95
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
import onnx
from typing import Optional, Union
from pathlib import Path
import os
import logging

logger = logging.getLogger(__name__)


# Exclusive upper bound, in bytes, on an in-memory (single protobuf) ONNX model:
# 2**31 bytes == 2 GiB. check_and_save_model switches to external-data saving
# once model.ByteSize() reaches this value.
# See https://github.com/onnx/onnx/pull/6556
MAXIMUM_PROTOBUF = 2**31  # 2 GiB


def strict_check_model(model_or_path: Union[onnx.ModelProto, str, Path]):
    """Validate an ONNX model with ``onnx.checker.check_model(..., full_check=True)``.

    Failures whose message contains "No Op registered for" (presumably operators
    the checker has no schema for, such as custom-domain ops -- confirm) are
    tolerated: they are logged as a warning rather than raised. Every other
    exception propagates unchanged.

    Args:
        model_or_path: An in-memory ``onnx.ModelProto`` or a path to a saved
            model file. ONNX requires the path form for models larger than
            the 2 GiB protobuf limit.
    """
    try:
        onnx.checker.check_model(model_or_path, full_check=True)
    except Exception as e:
        if "No Op registered for" not in str(e):
            # Not the tolerated case: re-raise the original exception as-is.
            raise
        # check_model did not run to completion, so the model was only
        # partially validated; say so instead of passing silently.
        logger.warning("ONNX model check skipped for unregistered operator: %s", e)


def _prepare_save_path(save_path: Union[str, Path]):
    """Normalize ``save_path`` and delete outputs a previous save may have left.

    Args:
        save_path: Destination of the ONNX model file.

    Returns:
        A ``(save_path, external_file_name)`` tuple: the POSIX-style string form
        of ``save_path``, and the bare file name (``<model file name>_data``)
        used for external tensor data stored next to the model.
    """
    # path/to/model.onnx
    save_path = Path(save_path).as_posix()
    # model.onnx_data
    external_file_name = os.path.basename(save_path) + "_data"
    # path/to/model.onnx_data
    external_path = os.path.join(os.path.dirname(save_path), external_file_name)

    # Remove stale outputs so an old external-data file is never left beside
    # a freshly written model that no longer references it.
    if save_path.endswith(".onnx") and os.path.isfile(save_path):
        os.remove(save_path)
    if os.path.isfile(external_path):
        os.remove(external_path)

    return save_path, external_file_name


def check_and_save_model(model: onnx.ModelProto, save_path: Optional[Union[str, Path]]):
    """Validate ``model`` with ``strict_check_model`` and optionally save it.

    - Models whose ``ByteSize()`` is below ``MAXIMUM_PROTOBUF`` are checked in
      memory first and then, if ``save_path`` is given, saved as one file.
    - Larger models cannot be checked in memory. When ``save_path`` is given,
      they are saved with all tensors in a single external data file
      (``<model file name>_data``, next to the model) and then checked from
      disk by path.
    - Without ``save_path`` a large model is neither checked nor saved; this
      is only logged.

    Any existing ``<model file name>_data`` file, and an existing ``.onnx``
    file at ``save_path``, are deleted before saving.

    Args:
        model: The ONNX model to validate and save.
        save_path: Destination model file. ``None`` or an empty string means
            "do not save".

    Raises:
        Exception: Whatever ``strict_check_model`` raises (all checker errors
            except tolerated "No Op registered for" failures), and any error
            from file removal or ``onnx.save``.
    """
    if model.ByteSize() < MAXIMUM_PROTOBUF:
        strict_check_model(model)
        if save_path:
            save_path, _ = _prepare_save_path(save_path)
            onnx.save(
                model,
                save_path,
                convert_attribute=True,
            )
    elif save_path:
        save_path, external_file_name = _prepare_save_path(save_path)
        onnx.save(
            model,
            save_path,
            save_as_external_data=True,
            all_tensors_to_one_file=True,
            location=external_file_name,
            convert_attribute=True,
        )
        # Models over the protobuf limit can only be checked by path, once
        # their weights have been written out as external data.
        strict_check_model(save_path)
    else:
        logger.info(
            "Merged ONNX model exceeds 2GB, the model will not be checked without `save_path` given."
        )