Merge pull request #939 from alibaba-damo-academy/dev_sxfix

Bug fix
This commit is contained in:
Xian Shi 2023-09-12 19:56:29 +08:00 committed by GitHub
commit 57ccdf04e0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 14 additions and 7 deletions

View File

@ -415,7 +415,7 @@ def inference_paraformer(
ibest_writer["rtf"][key] = rtf_cur
if text is not None:
if use_timestamp and timestamp is not None:
if use_timestamp and timestamp is not None and len(timestamp):
postprocessed_result = postprocess_utils.sentence_postprocess(token, timestamp)
else:
postprocessed_result = postprocess_utils.sentence_postprocess(token)
@ -427,7 +427,7 @@ def inference_paraformer(
else:
text_postprocessed, word_lists = postprocessed_result[0], postprocessed_result[1]
item = {'key': key, 'value': text_postprocessed}
if timestamp_postprocessed != "":
if timestamp_postprocessed != "" or len(timestamp) == 0:
item['timestamp'] = timestamp_postprocessed
asr_result_list.append(item)
finish_count += 1
@ -692,7 +692,7 @@ def inference_paraformer_vad_punc(
text, token, token_int = result[0], result[1], result[2]
time_stamp = result[4] if len(result[4]) > 0 else None
if use_timestamp and time_stamp is not None:
if use_timestamp and time_stamp is not None and len(time_stamp):
postprocessed_result = postprocess_utils.sentence_postprocess(token, time_stamp)
else:
postprocessed_result = postprocess_utils.sentence_postprocess(token)
@ -717,7 +717,7 @@ def inference_paraformer_vad_punc(
item = {'key': key, 'value': text_postprocessed_punc}
if text_postprocessed != "":
item['text_postprocessed'] = text_postprocessed
if time_stamp_postprocessed != "":
if time_stamp_postprocessed != "" or len(time_stamp) == 0:
item['time_stamp'] = time_stamp_postprocessed
item['sentences'] = time_stamp_sentence(punc_id_list, time_stamp_postprocessed, text_postprocessed)

View File

@ -254,7 +254,7 @@ class ModelExport:
if not os.path.exists(quant_model_path):
onnx_model = onnx.load(model_path)
nodes = [n.name for n in onnx_model.graph.node]
nodes_to_exclude = [m for m in nodes if 'output' in m]
nodes_to_exclude = [m for m in nodes if 'output' in m or 'bias_encoder' in m or 'bias_decoder' in m]
quantize_dynamic(
model_input=model_path,
model_output=quant_model_path,

View File

@ -5,7 +5,7 @@ model_dir = "./export/damo/speech_paraformer-large-contextual_asr_nat-zh-cn-16k-
model = ContextualParaformer(model_dir, batch_size=1)
wav_path = ['{}/.cache/modelscope/hub/damo/speech_paraformer-large-contextual_asr_nat-zh-cn-16k-common-vocab8404/example/asr_example.wav'.format(Path.home())]
hotwords = '随机热词 各种热词 魔搭 阿里巴巴'
hotwords = '随机热词 各种热词 魔搭 阿里巴巴'
result = model(wav_path, hotwords)
print(result)

View File

@ -314,7 +314,14 @@ class ContextualParaformer(Paraformer):
hotwords_length = torch.Tensor(hotwords_length).to(torch.int32)
# hotwords.append('<s>')
def word_map(word):
return torch.tensor([self.vocab[i] for i in word])
hotwords = []
for c in word:
if c not in self.vocab.keys():
hotwords.append(8403)
logging.warning("oov character {} found in hotword {}, replaced by <unk>".format(c, word))
else:
hotwords.append(self.vocab[c])
return torch.tensor(hotwords)
hotword_int = [word_map(i) for i in hotwords]
# import pdb; pdb.set_trace()
hotword_int.append(torch.tensor([1]))