from typing import Any, Dict, List, Optional, Tuple, Union
import math
import time
from pathlib import Path

import cv2
import numpy as np

from anyocr.infer_engine import OrtInferSession


class CTCLabelDecode:
    """Decode CTC network output (per-timestep class probabilities) into text.

    The character table is built from an explicit list and/or a dictionary
    file; a space is appended at the end and a ``blank`` token is inserted
    at index 0 (the CTC blank).
    """

    def __init__(
        self,
        character: Optional[List[str]] = None,
        character_path: Union[str, Path, None] = None,
    ):
        self.character = self.get_character(character, character_path)
        # char -> index lookup (kept for external users; decode itself
        # only needs index -> char).
        self.dict = {char: i for i, char in enumerate(self.character)}

    def __call__(
        self, preds: np.ndarray, return_word_box: bool = False, **kwargs
    ) -> List[Tuple[str, float]]:
        """Greedy-decode a batch of predictions.

        Args:
            preds: network output, shape (batch, timesteps, num_classes)
                — implied by the axis=2 argmax/max below.
            return_word_box: also return per-word grouping info; requires
                ``wh_ratio_list`` and ``max_wh_ratio`` in ``kwargs``.

        Returns:
            List of (text, confidence) tuples, or (text, confidence, info)
            when ``return_word_box`` is True.
        """
        preds_idx = preds.argmax(axis=2)
        preds_prob = preds.max(axis=2)
        text = self.decode(
            preds_idx, preds_prob, return_word_box, is_remove_duplicate=True
        )
        if return_word_box:
            for rec_idx, rec in enumerate(text):
                # Rescale the stored timestep count by this crop's
                # width/height ratio relative to the batch maximum, since
                # every crop was padded to the same width before inference.
                wh_ratio = kwargs["wh_ratio_list"][rec_idx]
                max_wh_ratio = kwargs["max_wh_ratio"]
                rec[2][0] = rec[2][0] * (wh_ratio / max_wh_ratio)
        return text

    def get_character(
        self,
        character: Optional[List[str]] = None,
        character_path: Union[str, Path, None] = None,
    ) -> List[str]:
        """Build the character table; ``character_path`` wins when both given.

        Raises:
            ValueError: if neither source yields a character list.
        """
        if character is None and character_path is None:
            raise ValueError("character must not be None")

        character_list = None
        if character:
            character_list = character

        if character_path:
            character_list = self.read_character_file(character_path)

        if character_list is None:
            raise ValueError("character must not be None")

        # Append a literal space, then prepend the CTC blank at index 0.
        character_list = self.insert_special_char(
            character_list, " ", len(character_list)
        )
        character_list = self.insert_special_char(character_list, "blank", 0)
        return character_list

    @staticmethod
    def read_character_file(character_path: Union[str, Path]) -> List[str]:
        """Read one character per line from a UTF-8 dictionary file."""
        character_list = []
        with open(character_path, "rb") as f:
            lines = f.readlines()
            for line in lines:
                line = line.decode("utf-8").strip("\n").strip("\r\n")
                character_list.append(line)
        return character_list

    @staticmethod
    def insert_special_char(
        character_list: List[str], special_char: str, loc: int = -1
    ) -> List[str]:
        """Insert ``special_char`` at ``loc`` (mutates and returns the list)."""
        character_list.insert(loc, special_char)
        return character_list

    def decode(
        self,
        text_index: np.ndarray,
        text_prob: Optional[np.ndarray] = None,
        return_word_box: bool = False,
        is_remove_duplicate: bool = False,
    ) -> List[Tuple[str, float]]:
        """convert text-index into text-label."""
        result_list = []
        ignored_tokens = self.get_ignored_tokens()
        batch_size = len(text_index)
        for batch_idx in range(batch_size):
            # Keep a timestep if it differs from its predecessor (CTC
            # duplicate collapsing) and is not an ignored token (blank).
            selection = np.ones(len(text_index[batch_idx]), dtype=bool)
            if is_remove_duplicate:
                selection[1:] = text_index[batch_idx][1:] != text_index[batch_idx][:-1]

            for ignored_token in ignored_tokens:
                selection &= text_index[batch_idx] != ignored_token

            if text_prob is not None:
                conf_list = np.array(text_prob[batch_idx][selection]).tolist()
            else:
                conf_list = [1] * len(selection)

            if len(conf_list) == 0:
                conf_list = [0]

            char_list = [
                self.character[text_id]
                for text_id in text_index[batch_idx][selection]
            ]
            text = "".join(char_list)

            if return_word_box:
                word_list, word_col_list, state_list = self.get_word_info(
                    text, selection
                )
                result_list.append(
                    (
                        text,
                        np.mean(conf_list).tolist(),
                        [
                            len(text_index[batch_idx]),
                            word_list,
                            word_col_list,
                            state_list,
                            conf_list,
                        ],
                    )
                )
            else:
                result_list.append((text, np.mean(conf_list).tolist()))
        return result_list

    @staticmethod
    def get_word_info(
        text: str, selection: np.ndarray
    ) -> Tuple[List[List[str]], List[List[int]], List[str]]:
        """
        Group the decoded characters and record the corresponding decoded positions.
        from https://github.com/PaddlePaddle/PaddleOCR/blob/fbba2178d7093f1dffca65a5b963ec277f1a6125/ppocr/postprocess/rec_postprocess.py#L70

        Args:
            text: the decoded text
            selection: the bool array that identifies which columns of features
                are decoded as non-separated characters

        Returns:
            word_list: list of the grouped words
            word_col_list: list of decoding positions corresponding to each
                character in the grouped word
            state_list: list of marker to identify the type of grouping words,
                including two types of grouping words:
                - 'cn': continous chinese characters (e.g., 你好啊)
                - 'en&num': continous english characters (e.g., hello), number
                  (e.g., 123, 1.123), or mixed of them connected by '-'
                  (e.g., VGG-16)
        """
        state = None
        word_content = []
        word_col_content = []
        word_list = []
        word_col_list = []
        state_list = []
        valid_col = np.where(selection)[0]

        # Gap (in timesteps) between consecutive kept columns; a large gap
        # splits words.
        col_width = np.zeros(valid_col.shape)
        if len(valid_col) > 0:
            col_width[1:] = valid_col[1:] - valid_col[:-1]
            col_width[0] = min(
                3 if "\u4e00" <= text[0] <= "\u9fff" else 2, int(valid_col[0])
            )

        for c_i, char in enumerate(text):
            # CJK Unified Ideographs range => treat as Chinese.
            if "\u4e00" <= char <= "\u9fff":
                c_state = "cn"
            else:
                c_state = "en&num"

            if state is None:
                state = c_state

            # Start a new word when the script type changes or the column
            # gap exceeds 4 timesteps.
            if state != c_state or col_width[c_i] > 4:
                if len(word_content) != 0:
                    word_list.append(word_content)
                    word_col_list.append(word_col_content)
                    state_list.append(state)
                    word_content = []
                    word_col_content = []
                state = c_state

            word_content.append(char)
            word_col_content.append(int(valid_col[c_i]))

        # Flush the trailing word.
        if len(word_content) != 0:
            word_list.append(word_content)
            word_col_list.append(word_col_content)
            state_list.append(state)

        return word_list, word_col_list, state_list

    @staticmethod
    def get_ignored_tokens() -> List[int]:
        return [0]  # for ctc blank


class TextRecognizer:
    """Batched text-line recognizer over an ONNX Runtime session."""

    def __init__(self, config: Dict[str, Any]):
        self.session = OrtInferSession(config)

        # Prefer the character list embedded in the model; otherwise fall
        # back to a dictionary file from the config.
        character = None
        if self.session.have_key():
            character = self.session.get_character_list()

        character_path = config.get("rec_keys_path", None)
        self.postprocess_op = CTCLabelDecode(
            character=character, character_path=character_path
        )

        self.rec_batch_num = config["rec_batch_num"]
        self.rec_image_shape = config["rec_img_shape"]

    def __call__(
        self,
        img_list: Union[np.ndarray, List[np.ndarray]],
        return_word_box: bool = False,
    ) -> Tuple[List[Tuple[str, float]], float]:
        """Recognize a list of cropped text-line images.

        Returns:
            (results, elapsed_seconds) where results are in the ORIGINAL
            input order despite the internal width-sorted batching.
        """
        if isinstance(img_list, np.ndarray):
            img_list = [img_list]

        # Calculate the aspect ratio of all text bars
        width_list = [img.shape[1] / float(img.shape[0]) for img in img_list]

        # Sorting can speed up the recognition process
        indices = np.argsort(np.array(width_list))

        img_num = len(img_list)
        rec_res = [("", 0.0)] * img_num

        batch_num = self.rec_batch_num
        elapse = 0
        for beg_img_no in range(0, img_num, batch_num):
            end_img_no = min(img_num, beg_img_no + batch_num)

            # Parameter Alignment for PaddleOCR
            imgC, imgH, imgW = self.rec_image_shape[:3]
            max_wh_ratio = imgW / imgH
            wh_ratio_list = []
            for ino in range(beg_img_no, end_img_no):
                h, w = img_list[indices[ino]].shape[0:2]
                wh_ratio = w * 1.0 / h
                max_wh_ratio = max(max_wh_ratio, wh_ratio)
                wh_ratio_list.append(wh_ratio)

            # Resize/pad every crop in the batch to the widest ratio so
            # they can be stacked into one tensor.
            norm_img_batch = []
            for ino in range(beg_img_no, end_img_no):
                norm_img = self.resize_norm_img(
                    img_list[indices[ino]], max_wh_ratio
                )
                norm_img_batch.append(norm_img[np.newaxis, :])
            norm_img_batch = np.concatenate(norm_img_batch).astype(np.float32)

            starttime = time.time()
            preds = self.session(norm_img_batch)[0]
            rec_result = self.postprocess_op(
                preds,
                return_word_box,
                wh_ratio_list=wh_ratio_list,
                max_wh_ratio=max_wh_ratio,
            )

            # Scatter batch results back to the original positions.
            for rno, one_res in enumerate(rec_result):
                rec_res[indices[beg_img_no + rno]] = one_res
            elapse += time.time() - starttime
        return rec_res, elapse

    def resize_norm_img(self, img: np.ndarray, max_wh_ratio: float) -> np.ndarray:
        """Resize to model height, normalize to [-1, 1], and right-pad.

        The target width is derived from ``max_wh_ratio`` (the configured
        width is overridden), so all crops in a batch share one width.
        """
        img_channel, img_height, img_width = self.rec_image_shape
        assert img_channel == img.shape[2]
        img_width = int(img_height * max_wh_ratio)

        h, w = img.shape[:2]
        ratio = w / float(h)
        if math.ceil(img_height * ratio) > img_width:
            resized_w = img_width
        else:
            resized_w = int(math.ceil(img_height * ratio))

        resized_image = cv2.resize(img, (resized_w, img_height))
        resized_image = resized_image.astype("float32")
        # HWC -> CHW, scale to [0, 1], then shift/scale to [-1, 1].
        resized_image = resized_image.transpose((2, 0, 1)) / 255
        resized_image -= 0.5
        resized_image /= 0.5

        padding_im = np.zeros((img_channel, img_height, img_width), dtype=np.float32)
        padding_im[:, :, 0:resized_w] = resized_image
        return padding_im