mirror of
https://github.com/modelscope/FunASR
synced 2025-09-15 14:48:36 +08:00
fix paraformer-en model python onnx postprocess
This commit is contained in:
parent
d90de51e76
commit
db149dd897
@ -14,7 +14,8 @@ import numpy as np
|
||||
from .utils.utils import (CharTokenizer, Hypothesis, ONNXRuntimeError,
|
||||
OrtInferSession, TokenIDConverter, get_logger,
|
||||
read_yaml)
|
||||
from .utils.postprocess_utils import sentence_postprocess
|
||||
from .utils.postprocess_utils import (sentence_postprocess,
|
||||
sentence_postprocess_sentencepiece)
|
||||
from .utils.frontend import WavFrontend
|
||||
from .utils.timestamp_utils import time_stamp_lfr6_onnx
|
||||
from .utils.utils import pad_list, make_pad_mask
|
||||
@ -86,6 +87,10 @@ class Paraformer():
|
||||
self.pred_bias = config['model_conf']['predictor_bias']
|
||||
else:
|
||||
self.pred_bias = 0
|
||||
if "lang" in config:
|
||||
self.language = config['lang']
|
||||
else:
|
||||
self.language = None
|
||||
|
||||
def __call__(self, wav_content: Union[str, np.ndarray, List[str]], **kwargs) -> List:
|
||||
waveform_list = self.load_data(wav_content, self.frontend.opts.frame_opts.samp_freq)
|
||||
@ -111,7 +116,10 @@ class Paraformer():
|
||||
preds = self.decode(am_scores, valid_token_lens)
|
||||
if us_peaks is None:
|
||||
for pred in preds:
|
||||
pred = sentence_postprocess(pred)
|
||||
if self.language == "en-bpe":
|
||||
pred = sentence_postprocess_sentencepiece(pred)
|
||||
else:
|
||||
pred = sentence_postprocess(pred)
|
||||
asr_res.append({'preds': pred})
|
||||
else:
|
||||
for pred, us_peaks_ in zip(preds, us_peaks):
|
||||
|
||||
@ -240,3 +240,54 @@ def sentence_postprocess(words: List[Any], time_stamp: List[List] = None):
|
||||
real_word_lists.append(ch)
|
||||
sentence = ''.join(word_lists).strip()
|
||||
return sentence, real_word_lists
|
||||
|
||||
def sentence_postprocess_sentencepiece(words):
|
||||
middle_lists = []
|
||||
word_lists = []
|
||||
word_item = ''
|
||||
|
||||
# wash words lists
|
||||
for i in words:
|
||||
word = ''
|
||||
if isinstance(i, str):
|
||||
word = i
|
||||
else:
|
||||
word = i.decode('utf-8')
|
||||
|
||||
if word in ['<s>', '</s>', '<unk>', '<OOV>']:
|
||||
continue
|
||||
else:
|
||||
middle_lists.append(word)
|
||||
|
||||
# all alpha characters
|
||||
for i, ch in enumerate(middle_lists):
|
||||
word = ''
|
||||
if '\u2581' in ch and i == 0:
|
||||
word_item = ''
|
||||
word = ch.replace('\u2581', '')
|
||||
word_item += word
|
||||
elif '\u2581' in ch and i != 0:
|
||||
word_lists.append(word_item)
|
||||
word_lists.append(' ')
|
||||
word_item = ''
|
||||
word = ch.replace('\u2581', '')
|
||||
word_item += word
|
||||
else:
|
||||
word_item += ch
|
||||
if word_item is not None:
|
||||
word_lists.append(word_item)
|
||||
#word_lists = abbr_dispose(word_lists)
|
||||
real_word_lists = []
|
||||
for ch in word_lists:
|
||||
if ch != ' ':
|
||||
if ch == "i":
|
||||
ch = ch.replace("i", "I")
|
||||
elif ch == "i'm":
|
||||
ch = ch.replace("i'm", "I'm")
|
||||
elif ch == "i've":
|
||||
ch = ch.replace("i've", "I've")
|
||||
elif ch == "i'll":
|
||||
ch = ch.replace("i'll", "I'll")
|
||||
real_word_lists.append(ch)
|
||||
sentence = ''.join(word_lists)
|
||||
return sentence, real_word_lists
|
||||
Loading…
Reference in New Issue
Block a user