fix paraformer-en model python onnx postprocess

This commit is contained in:
北念 2023-11-09 11:49:56 +08:00
parent d90de51e76
commit db149dd897
2 changed files with 61 additions and 2 deletions

View File

@ -14,7 +14,8 @@ import numpy as np
from .utils.utils import (CharTokenizer, Hypothesis, ONNXRuntimeError,
OrtInferSession, TokenIDConverter, get_logger,
read_yaml)
from .utils.postprocess_utils import sentence_postprocess
from .utils.postprocess_utils import (sentence_postprocess,
sentence_postprocess_sentencepiece)
from .utils.frontend import WavFrontend
from .utils.timestamp_utils import time_stamp_lfr6_onnx
from .utils.utils import pad_list, make_pad_mask
@ -86,6 +87,10 @@ class Paraformer():
self.pred_bias = config['model_conf']['predictor_bias']
else:
self.pred_bias = 0
if "lang" in config:
self.language = config['lang']
else:
self.language = None
def __call__(self, wav_content: Union[str, np.ndarray, List[str]], **kwargs) -> List:
waveform_list = self.load_data(wav_content, self.frontend.opts.frame_opts.samp_freq)
@ -111,7 +116,10 @@ class Paraformer():
preds = self.decode(am_scores, valid_token_lens)
if us_peaks is None:
for pred in preds:
pred = sentence_postprocess(pred)
if self.language == "en-bpe":
pred = sentence_postprocess_sentencepiece(pred)
else:
pred = sentence_postprocess(pred)
asr_res.append({'preds': pred})
else:
for pred, us_peaks_ in zip(preds, us_peaks):

View File

@ -240,3 +240,54 @@ def sentence_postprocess(words: List[Any], time_stamp: List[List] = None):
real_word_lists.append(ch)
sentence = ''.join(word_lists).strip()
return sentence, real_word_lists
def sentence_postprocess_sentencepiece(words):
    """Join SentencePiece tokens into a sentence and its list of words.

    Tokens may be ``str`` or UTF-8 encoded ``bytes``. The special tokens
    ``<s>``, ``</s>``, ``<unk>`` and ``<OOV>`` are dropped. A token that
    contains the marker U+2581 starts a new word (the marker itself is
    removed); any other token is appended to the word being built. The
    standalone word ``i`` and the contractions ``i'm``, ``i've`` and
    ``i'll`` are capitalised.

    Args:
        words: Iterable of tokens (``str`` or ``bytes``).

    Returns:
        A ``(sentence, word_list)`` tuple. ``word_list`` holds the non-empty
        words in order and ``sentence`` is those same words joined by single
        spaces, so both values always agree. Empty input (or input made only
        of special tokens) yields ``('', [])``.
    """
    special_tokens = ('<s>', '</s>', '<unk>', '<OOV>')
    capitalised = {"i": "I", "i'm": "I'm", "i've": "I've", "i'll": "I'll"}
    boundary = '\u2581'

    # Normalise every token to str and drop the special symbols.
    pieces = []
    for token in words:
        piece = token if isinstance(token, str) else token.decode('utf-8')
        if piece not in special_tokens:
            pieces.append(piece)

    real_word_lists = []
    current = ''
    for piece in pieces:
        if boundary in piece:
            # A marker closes the word collected so far. Empty words (e.g.
            # from a lone marker piece) are skipped so they never reach the
            # output as '' entries or doubled spaces.
            if current:
                real_word_lists.append(capitalised.get(current, current))
            current = piece.replace(boundary, '')
        else:
            current += piece
    if current:
        real_word_lists.append(capitalised.get(current, current))

    # Build the sentence from the capitalised words so that it matches the
    # returned word list.
    sentence = ' '.join(real_word_lists)
    return sentence, real_word_lists