From 0fd9640ced9c8ae9af43e5300068a8837d8ce26e Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E4=B9=9D=E8=80=B3?=
Date: Thu, 30 Mar 2023 16:48:55 +0800
Subject: [PATCH] fix

---
 .../onnxruntime/funasr_onnx/punc_bin.py     | 25 ++++++-------------
 .../onnxruntime/funasr_onnx/utils/utils.py  | 19 ++++++++++++++
 2 files changed, 26 insertions(+), 18 deletions(-)

diff --git a/funasr/runtime/python/onnxruntime/funasr_onnx/punc_bin.py b/funasr/runtime/python/onnxruntime/funasr_onnx/punc_bin.py
index 3f649bcdb..e1f35f207 100644
--- a/funasr/runtime/python/onnxruntime/funasr_onnx/punc_bin.py
+++ b/funasr/runtime/python/onnxruntime/funasr_onnx/punc_bin.py
@@ -8,8 +8,7 @@ import numpy as np
 
 from .utils.utils import (ONNXRuntimeError, OrtInferSession, get_logger,
                           read_yaml)
-from .utils.preprocessor import CodeMixTokenizerCommonPreprocessor
-from .utils.utils import split_to_mini_sentence
+from .utils.utils import (TokenIDConverter, split_to_mini_sentence, code_mix_split_words)
 
 logging = get_logger()
 
@@ -30,6 +29,7 @@ class TargetDelayTransformer():
 
         config_file = os.path.join(model_dir, 'punc.yaml')
         config = read_yaml(config_file)
+        self.converter = TokenIDConverter(config['token_list'])
         self.ort_infer = OrtInferSession(model_file, device_id, intra_op_num_threads=intra_op_num_threads)
         self.batch_size = 1
         self.punc_list = config['punc_list']
@@ -41,23 +41,12 @@ class TargetDelayTransformer():
                 self.punc_list[i] = "?"
             elif self.punc_list[i] == "。":
                 self.period = i
-        self.preprocessor = CodeMixTokenizerCommonPreprocessor(
-            train=False,
-            token_type=config['token_type'],
-            token_list=config['token_list'],
-            bpemodel=config['bpemodel'],
-            text_cleaner=config['cleaner'],
-            g2p_type=config['g2p'],
-            text_name="text",
-            non_linguistic_symbols=config['non_linguistic_symbols'],
-        )
 
     def __call__(self, text: Union[list, str], split_size=20):
-        data = {"text": text}
-        result = self.preprocessor(data=data, uid="12938712838719")
-        split_text = self.preprocessor.pop_split_text_data(result)
+        split_text = code_mix_split_words(text)
+        split_text_id = self.converter.tokens2ids(split_text)
         mini_sentences = split_to_mini_sentence(split_text, split_size)
-        mini_sentences_id = split_to_mini_sentence(data["text"], split_size)
+        mini_sentences_id = split_to_mini_sentence(split_text_id, split_size)
         assert len(mini_sentences) == len(mini_sentences_id)
         cache_sent = []
         cache_sent_id = []
@@ -68,9 +57,9 @@ class TargetDelayTransformer():
             mini_sentence = mini_sentences[mini_sentence_i]
             mini_sentence_id = mini_sentences_id[mini_sentence_i]
             mini_sentence = cache_sent + mini_sentence
-            mini_sentence_id = np.concatenate((cache_sent_id, mini_sentence_id), axis=0)
+            mini_sentence_id = np.array(cache_sent_id + mini_sentence_id, dtype='int64')
             data = {
-                "text": mini_sentence_id[None,:].astype(np.int64),
+                "text": mini_sentence_id[None, :],
                 "text_lengths": np.array([len(mini_sentence_id)], dtype='int32'),
             }
             try:
diff --git a/funasr/runtime/python/onnxruntime/funasr_onnx/utils/utils.py b/funasr/runtime/python/onnxruntime/funasr_onnx/utils/utils.py
index 63bc0e46f..0df954ed7 100644
--- a/funasr/runtime/python/onnxruntime/funasr_onnx/utils/utils.py
+++ b/funasr/runtime/python/onnxruntime/funasr_onnx/utils/utils.py
@@ -228,6 +228,25 @@ def split_to_mini_sentence(words: list, word_limit: int = 20):
         sentences.append(words[sentence_len * word_limit:])
     return sentences
 
+def code_mix_split_words(text: str):
+    words = []
+    segs = text.split()
+    for seg in segs:
+        # There is no space in seg.
+        current_word = ""
+        for c in seg:
+            if len(c.encode()) == 1:
+                # This is an ASCII char.
+                current_word += c
+            else:
+                # This is a Chinese char.
+                if len(current_word) > 0:
+                    words.append(current_word)
+                    current_word = ""
+                words.append(c)
+        if len(current_word) > 0:
+            words.append(current_word)
+    return words
 
 def read_yaml(yaml_path: Union[str, Path]) -> Dict:
     if not Path(yaml_path).exists():