# -*- encoding: utf-8 -*-
import os.path
from pathlib import Path
from typing import Tuple, Union

import numpy as np

from .utils.preprocessor import CodeMixTokenizerCommonPreprocessor
from .utils.utils import (ONNXRuntimeError, OrtInferSession, get_logger,
                          read_yaml, split_to_mini_sentence)

logging = get_logger()


class TargetDelayTransformer():
    """ONNX inference wrapper for a target-delay transformer punctuation model.

    The input text is split into mini-sentences, each chunk is punctuated by
    the model, and any unfinished sentence tail is carried over as a cache so
    the next chunk starts at a predicted sentence boundary.
    """

    def __init__(self, model_dir: Union[str, Path] = None,
                 batch_size: int = 1,
                 device_id: Union[str, int] = "-1",
                 quantize: bool = False,
                 intra_op_num_threads: int = 4):
        if not Path(model_dir).exists():
            raise FileNotFoundError(f'{model_dir} does not exist.')

        model_file = os.path.join(model_dir, 'model.onnx')
        if quantize:
            model_file = os.path.join(model_dir, 'model_quant.onnx')
        config_file = os.path.join(model_dir, 'punc.yaml')
        config = read_yaml(config_file)

        self.ort_infer = OrtInferSession(model_file, device_id,
                                         intra_op_num_threads=intra_op_num_threads)
        self.batch_size = batch_size
        self.punc_list = config['punc_list']
        self.period = 0
        # Map ASCII punctuation labels to their full-width (CJK) forms and
        # remember the index of the period label. The original loop replaced
        # "," with itself, a no-op that dropped the full-width targets.
        for i in range(len(self.punc_list)):
            if self.punc_list[i] == ",":
                self.punc_list[i] = "，"
            elif self.punc_list[i] == "?":
                self.punc_list[i] = "？"
            elif self.punc_list[i] == "。":
                self.period = i
        self.preprocessor = CodeMixTokenizerCommonPreprocessor(
            train=False,
            token_type=config['token_type'],
            token_list=config['token_list'],
            bpemodel=config['bpemodel'],
            text_cleaner=config['cleaner'],
            g2p_type=config['g2p'],
            text_name="text",
            non_linguistic_symbols=config['non_linguistic_symbols'],
        )

    def __call__(self, text: Union[list, str], split_size=20):
        data = {"text": text}
        result = self.preprocessor(data=data, uid="12938712838719")
        split_text = self.preprocessor.pop_split_text_data(result)
        mini_sentences = split_to_mini_sentence(split_text, split_size)
        mini_sentences_id = split_to_mini_sentence(data["text"], split_size)
        assert len(mini_sentences) == len(mini_sentences_id)
        cache_sent = []
        cache_sent_id = []
        new_mini_sentence = ""
        new_mini_sentence_punc = []
        cache_pop_trigger_limit = 200
        for mini_sentence_i in range(len(mini_sentences)):
            mini_sentence = mini_sentences[mini_sentence_i]
            mini_sentence_id = mini_sentences_id[mini_sentence_i]
            # Prepend the tokens cached from the previous chunk so this chunk
            # starts at the last predicted sentence boundary.
            mini_sentence = cache_sent + mini_sentence
            mini_sentence_id = np.concatenate((cache_sent_id, mini_sentence_id), axis=0)
            data = {
                "text": mini_sentence_id[None, :].astype(np.int64),
                "text_lengths": np.array([len(mini_sentence_id)], dtype='int32'),
            }
            try:
                outputs = self.infer(data['text'], data['text_lengths'])
                y = outputs[0]
                punctuations = np.argmax(y, axis=-1)[0]
                assert punctuations.size == len(mini_sentence)
            except ONNXRuntimeError:
                # The original only logged here and then fell through to use a
                # stale (or undefined) `punctuations`; skip the failed chunk.
                logging.warning("ONNX inference failed for mini-sentence %d; skipping it.",
                                mini_sentence_i)
                continue

            # Search backwards for the last period/question mark; everything
            # after it is cached and re-punctuated with the next chunk.
            if mini_sentence_i < len(mini_sentences) - 1:
                sentenceEnd = -1
                last_comma_index = -1
                for i in range(len(punctuations) - 2, 1, -1):
                    if self.punc_list[punctuations[i]] == "。" or self.punc_list[punctuations[i]] == "？":
                        sentenceEnd = i
                        break
                    if last_comma_index < 0 and self.punc_list[punctuations[i]] == "，":
                        last_comma_index = i

                if sentenceEnd < 0 and len(mini_sentence) > cache_pop_trigger_limit and last_comma_index >= 0:
                    # The sentence is too long; cut it off at a comma.
                    sentenceEnd = last_comma_index
                    punctuations[sentenceEnd] = self.period
                cache_sent = mini_sentence[sentenceEnd + 1:]
                cache_sent_id = mini_sentence_id[sentenceEnd + 1:]
                mini_sentence = mini_sentence[0:sentenceEnd + 1]
                punctuations = punctuations[0:sentenceEnd + 1]

            # `punctuations` is already a NumPy array; the original called
            # .cpu().numpy() on it, which only exists on torch tensors.
            new_mini_sentence_punc += [int(x) for x in punctuations]
            words_with_punc = []
            for i in range(len(mini_sentence)):
                # Insert a space between two adjacent single-byte (ASCII) words.
                if i > 0:
                    if len(mini_sentence[i][0].encode()) == 1 and len(mini_sentence[i - 1][0].encode()) == 1:
                        mini_sentence[i] = " " + mini_sentence[i]
                words_with_punc.append(mini_sentence[i])
                if self.punc_list[punctuations[i]] != "_":
                    words_with_punc.append(self.punc_list[punctuations[i]])
            new_mini_sentence += "".join(words_with_punc)
            # Add a period at the end of the final sentence if it is missing.
            new_mini_sentence_out = new_mini_sentence
            new_mini_sentence_punc_out = new_mini_sentence_punc
            if mini_sentence_i == len(mini_sentences) - 1:
                if new_mini_sentence[-1] == "，" or new_mini_sentence[-1] == "、":
                    new_mini_sentence_out = new_mini_sentence[:-1] + "。"
                    new_mini_sentence_punc_out = new_mini_sentence_punc[:-1] + [self.period]
                elif new_mini_sentence[-1] != "。" and new_mini_sentence[-1] != "？":
                    new_mini_sentence_out = new_mini_sentence + "。"
                    new_mini_sentence_punc_out = new_mini_sentence_punc[:-1] + [self.period]
        return new_mini_sentence_out, new_mini_sentence_punc_out

    def infer(self, feats: np.ndarray,
              feats_len: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
        outputs = self.ort_infer([feats, feats_len])
        return outputs
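
# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the original module). The model
# directory path and the input sentence are hypothetical placeholders; the
# directory is assumed to contain model.onnx (or model_quant.onnx), punc.yaml,
# and the tokenizer assets that punc.yaml references. Because this module uses
# relative imports, it must be run as part of its package, e.g.
# `python -m <package_name>.<this_module>`.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    model_dir = "/path/to/punc_model"  # hypothetical export directory
    punc_model = TargetDelayTransformer(model_dir, device_id="-1")
    # Returns the punctuated string and the per-token punctuation label ids.
    punctuated_text, punc_ids = punc_model("hello how are you i am fine")
    print(punctuated_text)
    print(punc_ids)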