|
from typing import Any, Dict, List, Tuple, Union |
|
import os |
|
import platform |
|
import traceback |
|
from enum import Enum |
|
from pathlib import Path |
|
import numpy as np |
|
from onnxruntime import ( |
|
GraphOptimizationLevel, |
|
InferenceSession, |
|
SessionOptions, |
|
get_available_providers, |
|
get_device, |
|
) |
|
|
|
|
|
class EP(Enum):
    """Names of the ONNX Runtime execution providers this module refers to.

    Each member's value is the provider identifier string that is compared
    against the available-provider list and passed on when a session is built.
    """

    # Plain CPU execution.
    CPU_EP = "CPUExecutionProvider"
    # NVIDIA CUDA execution.
    CUDA_EP = "CUDAExecutionProvider"
    # DirectML execution.
    DIRECTML_EP = "DmlExecutionProvider"
|
|
|
|
|
class OrtInferSession:
    """Wrapper around an ONNX Runtime ``InferenceSession``.

    Execution providers are chosen from the configuration combined with what
    the installed onnxruntime build reports as available. The underlying
    session is released after every ``__call__`` and rebuilt lazily the next
    time it is needed (by ``__call__`` or by any of the metadata getters).
    """

    def __init__(self, config: Dict[str, Any]):
        """Validate the model path, select providers and create the session.

        Args:
            config: Settings dictionary. Keys read by this class:
                ``model_path``, ``use_cuda``, ``use_dml``,
                ``intra_op_num_threads`` and ``inter_op_num_threads``.

        Raises:
            ValueError: ``model_path`` is missing or ``None``.
            FileNotFoundError: ``model_path`` does not exist.
            FileExistsError: ``model_path`` exists but is not a regular file.
        """
        self.model_path = config.get("model_path", None)
        self._verify_model(self.model_path)
        self.config = config
        self.cfg_use_cuda = config.get("use_cuda", None)
        self.cfg_use_dml = config.get("use_dml", None)

        self.had_providers: List[str] = get_available_providers()
        self.EP_list = self._get_ep_list()

        self.sess_opt = self._init_sess_opts(self.config)
        self.session = None
        self._ensure_session()

    def _ensure_session(self) -> None:
        """Create the inference session if none is currently held.

        ``__call__`` drops the session when it finishes, so every method that
        touches ``self.session`` goes through this helper first.
        """
        if self.session is not None:
            return

        self.session = InferenceSession(
            self.model_path,
            sess_options=self.sess_opt,
            providers=self.EP_list,
        )
        self._verify_providers()

    @staticmethod
    def _init_sess_opts(config: Dict[str, Any]) -> "SessionOptions":
        """Build the ``SessionOptions`` used for every session of this wrapper.

        Args:
            config: Settings dictionary; only ``intra_op_num_threads`` and
                ``inter_op_num_threads`` are read here.

        Returns:
            The configured ``SessionOptions`` instance.
        """
        sess_opt = SessionOptions()
        # Per the onnxruntime docs, 4 is the "fatal" severity level.
        sess_opt.log_severity_level = 4
        sess_opt.enable_cpu_mem_arena = False
        sess_opt.graph_optimization_level = GraphOptimizationLevel.ORT_ENABLE_ALL

        # os.cpu_count() returns None when the count cannot be determined;
        # fall back to 1 so the range checks below never compare against None.
        cpu_nums = os.cpu_count() or 1

        intra_op_num_threads = OrtInferSession._pick_thread_count(
            config, "intra_op_num_threads", cpu_nums
        )
        if intra_op_num_threads is not None:
            sess_opt.intra_op_num_threads = intra_op_num_threads

        inter_op_num_threads = OrtInferSession._pick_thread_count(
            config, "inter_op_num_threads", cpu_nums
        )
        if inter_op_num_threads is not None:
            sess_opt.inter_op_num_threads = inter_op_num_threads

        return sess_opt

    @staticmethod
    def _pick_thread_count(
        config: Dict[str, Any], key: str, cpu_nums: int
    ) -> Union[int, None]:
        """Return the thread count stored under ``key`` if it is usable.

        A value is usable when it lies in ``[1, cpu_nums]``. A missing key,
        ``None`` and ``-1`` all mean "not configured"; those and out-of-range
        values yield ``None`` so the caller keeps the onnxruntime default.
        """
        value = config.get(key, -1)
        if value is None or value == -1:
            return None
        if 1 <= value <= cpu_nums:
            return value
        return None

    def _get_ep_list(self) -> List[Tuple[str, Dict[str, Any]]]:
        """Assemble the ordered ``(provider name, options)`` list.

        The CPU provider is always present and always last. CUDA is put in
        front of it when ``_check_cuda`` succeeds, and DirectML is put in
        front of everything when ``_check_dml`` succeeds.

        Side effects:
            Sets ``self.use_cuda`` and ``self.use_directml``.
        """
        cpu_provider_opts = {
            "arena_extend_strategy": "kSameAsRequested",
        }
        EP_list = [(EP.CPU_EP.value, cpu_provider_opts)]

        cuda_provider_opts = {
            "device_id": 0,
            "arena_extend_strategy": "kNextPowerOfTwo",
            "cudnn_conv_algo_search": "EXHAUSTIVE",
            "do_copy_in_default_stream": True,
        }
        self.use_cuda = self._check_cuda()
        if self.use_cuda:
            EP_list.insert(0, (EP.CUDA_EP.value, cuda_provider_opts))

        self.use_directml = self._check_dml()
        if self.use_directml:
            # NOTE(review): the DirectML entry reuses the CUDA (or CPU) option
            # dict unchanged; confirm these keys are meaningful for that
            # provider.
            directml_options = (
                cuda_provider_opts if self.use_cuda else cpu_provider_opts
            )
            EP_list.insert(0, (EP.DIRECTML_EP.value, directml_options))
        return EP_list

    def _check_cuda(self) -> bool:
        """Return True when CUDA is requested and actually usable.

        Usable means ``get_device()`` reports ``"GPU"`` and the CUDA provider
        is among the available providers.
        """
        if not self.cfg_use_cuda:
            return False

        return get_device() == "GPU" and EP.CUDA_EP.value in self.had_providers

    def _check_dml(self) -> bool:
        """Return True when DirectML is requested and actually usable.

        Requires Windows with a numeric release of at least 10 and the
        DirectML provider among the available providers.
        """
        if not self.cfg_use_dml:
            return False

        if platform.system() != "Windows":
            return False

        # Non-numeric release names (e.g. server editions) cannot be compared
        # against 10, so they are treated as unsupported instead of crashing.
        major_version = self._windows_major_version(platform.release())
        if major_version is None or major_version < 10:
            return False

        return EP.DIRECTML_EP.value in self.had_providers

    @staticmethod
    def _windows_major_version(release: str) -> Union[int, None]:
        """Parse the leading integer of a ``platform.release()`` string.

        Returns ``None`` when the part before the first dot is not an integer,
        which happens for release names such as ``"Vista"`` or
        ``"2012ServerR2"``.
        """
        try:
            return int(release.split(".")[0])
        except ValueError:
            return None

    def _verify_providers(self):
        """Read the providers the freshly created session is using.

        NOTE(review): the first active provider is fetched but never compared
        with the requested one (``self.use_cuda`` / ``self.use_directml``);
        presumably a warning was intended here. Behavior is left as-is.
        """
        session_providers = self.session.get_providers()
        first_provider = session_providers[0]

    def __call__(self, input_content: np.ndarray) -> List[np.ndarray]:
        """Run the model on ``input_content`` and return all outputs.

        Only the first model input is fed: the input names are zipped with a
        one-element list. The session is released afterwards, whether or not
        the run succeeded.

        Args:
            input_content: Array bound to the model's first input.

        Returns:
            The list returned by ``InferenceSession.run`` for every output
            name of the model.

        Raises:
            ONNXRuntimeError: Any failure while creating or running the
                session; the message is the formatted traceback.
        """
        # Bound before the try block so the cleanup in ``finally`` cannot fail
        # (and hide the real error) when session creation itself raises.
        input_dict = None
        try:
            self._ensure_session()
            input_dict = dict(zip(self.get_input_names(), [input_content]))
            return self.session.run(self.get_output_names(), input_dict)
        except Exception as e:
            error_info = traceback.format_exc()
            raise ONNXRuntimeError(error_info) from e
        finally:
            # Drop the references so the input and the session can be freed
            # even while a raised exception still holds this frame.
            del input_dict
            self.session = None

    def get_input_names(self) -> List[str]:
        """Return the names of the model inputs, in session order."""
        self._ensure_session()
        return [v.name for v in self.session.get_inputs()]

    def get_output_names(self) -> List[str]:
        """Return the names of the model outputs, in session order."""
        self._ensure_session()
        return [v.name for v in self.session.get_outputs()]

    def get_character_list(self, key: str = "character") -> List[str]:
        """Return the custom metadata entry ``key`` split into lines.

        Raises:
            KeyError: ``key`` is not present in the model's custom metadata.
        """
        self._ensure_session()
        meta_dict = self.session.get_modelmeta().custom_metadata_map
        return meta_dict[key].splitlines()

    def have_key(self, key: str = "character") -> bool:
        """Return True when ``key`` exists in the model's custom metadata."""
        self._ensure_session()
        meta_dict = self.session.get_modelmeta().custom_metadata_map
        return key in meta_dict

    @staticmethod
    def _verify_model(model_path: Union[str, Path, None]):
        """Check that ``model_path`` points at an existing regular file.

        Raises:
            ValueError: ``model_path`` is ``None``.
            FileNotFoundError: The path does not exist.
            FileExistsError: The path exists but is not a file.
        """
        if model_path is None:
            raise ValueError("model_path is None!")

        model_path = Path(model_path)
        if not model_path.exists():
            raise FileNotFoundError(f"{model_path} does not exists.")

        if not model_path.is_file():
            raise FileExistsError(f"{model_path} is not a file.")
|
|
|
|
|
class ONNXRuntimeError(Exception):
    """Raised when creating or running the ONNX Runtime session fails.

    The message carries the formatted traceback of the underlying error.
    """