diff --git a/runtime/python/onnxruntime/funasr_onnx/paraformer_bin.py b/runtime/python/onnxruntime/funasr_onnx/paraformer_bin.py index 7b13654d4..c4c558ea6 100644 --- a/runtime/python/onnxruntime/funasr_onnx/paraformer_bin.py +++ b/runtime/python/onnxruntime/funasr_onnx/paraformer_bin.py @@ -14,7 +14,8 @@ import numpy as np from .utils.utils import (CharTokenizer, Hypothesis, ONNXRuntimeError, OrtInferSession, TokenIDConverter, get_logger, read_yaml) -from .utils.postprocess_utils import sentence_postprocess +from .utils.postprocess_utils import (sentence_postprocess, + sentence_postprocess_sentencepiece) from .utils.frontend import WavFrontend from .utils.timestamp_utils import time_stamp_lfr6_onnx from .utils.utils import pad_list, make_pad_mask @@ -86,6 +87,10 @@ class Paraformer(): self.pred_bias = config['model_conf']['predictor_bias'] else: self.pred_bias = 0 + if "lang" in config: + self.language = config['lang'] + else: + self.language = None def __call__(self, wav_content: Union[str, np.ndarray, List[str]], **kwargs) -> List: waveform_list = self.load_data(wav_content, self.frontend.opts.frame_opts.samp_freq) @@ -111,7 +116,10 @@ class Paraformer(): preds = self.decode(am_scores, valid_token_lens) if us_peaks is None: for pred in preds: - pred = sentence_postprocess(pred) + if self.language == "en-bpe": + pred = sentence_postprocess_sentencepiece(pred) + else: + pred = sentence_postprocess(pred) asr_res.append({'preds': pred}) else: for pred, us_peaks_ in zip(preds, us_peaks): diff --git a/runtime/python/onnxruntime/funasr_onnx/utils/postprocess_utils.py b/runtime/python/onnxruntime/funasr_onnx/utils/postprocess_utils.py index c005fc985..14d6c7687 100644 --- a/runtime/python/onnxruntime/funasr_onnx/utils/postprocess_utils.py +++ b/runtime/python/onnxruntime/funasr_onnx/utils/postprocess_utils.py @@ -240,3 +240,54 @@ def sentence_postprocess(words: List[Any], time_stamp: List[List] = None): real_word_lists.append(ch) sentence = ''.join(word_lists).strip() 
def sentence_postprocess_sentencepiece(words):
    """Rebuild a sentence from SentencePiece sub-word tokens.

    SentencePiece marks the start of every word with U+2581 ("\u2581").
    This function normalizes tokens to ``str``, drops special tokens,
    merges sub-word pieces back into whole words, and returns both the
    joined sentence and the list of words.

    Args:
        words: iterable of tokens, each a ``str`` or UTF-8 encoded ``bytes``.

    Returns:
        tuple: ``(sentence, real_word_lists)`` where ``sentence`` is the
        space-joined text and ``real_word_lists`` is the word list with
        leading-"i" contractions capitalized.
    """
    # NOTE(review): the special-token list was garbled in the source
    # (four empty strings); these are the conventional SentencePiece /
    # ASR specials -- confirm against the model's actual token list.
    special_tokens = ('<s>', '</s>', '<unk>', '<blank>')

    # Wash the token list: decode bytes, skip special tokens.
    middle_lists = []
    for token in words:
        word = token if isinstance(token, str) else token.decode('utf-8')
        if word in special_tokens:
            continue
        middle_lists.append(word)

    # Merge sub-word pieces: "\u2581" starts a new word.
    word_lists = []
    word_item = ''
    for i, ch in enumerate(middle_lists):
        if '\u2581' in ch:
            if i != 0:
                # Flush the previous word before starting a new one.
                word_lists.append(word_item)
                word_lists.append(' ')
            word_item = ch.replace('\u2581', '')
        else:
            word_item += ch
    # Bug fix: the original tested ``word_item is not None``, which is
    # always true for a str, so empty input produced [''] instead of [].
    if word_item:
        word_lists.append(word_item)

    # Capitalize leading-"i" forms in the word list only; the joined
    # sentence keeps the raw casing (matches the original behaviour).
    # "i'd" added for consistency with i'm / i've / i'll.
    capitalize = {
        "i": "I",
        "i'm": "I'm",
        "i've": "I've",
        "i'll": "I'll",
        "i'd": "I'd",
    }
    real_word_lists = [capitalize.get(ch, ch)
                       for ch in word_lists if ch != ' ']
    sentence = ''.join(word_lists)
    return sentence, real_word_lists