update proc for oov in hotword onnx inference

This commit is contained in:
shixian.shi 2023-09-12 19:54:10 +08:00
parent b883a78256
commit 9c622feb64
2 changed files with 9 additions and 2 deletions

View File

@ -5,7 +5,7 @@ model_dir = "./export/damo/speech_paraformer-large-contextual_asr_nat-zh-cn-16k-
model = ContextualParaformer(model_dir, batch_size=1)
wav_path = ['{}/.cache/modelscope/hub/damo/speech_paraformer-large-contextual_asr_nat-zh-cn-16k-common-vocab8404/example/asr_example.wav'.format(Path.home())]
hotwords = '随机热词 各种热词 魔搭 阿里巴巴'
hotwords = '随机热词 各种热词 魔搭 阿里巴巴'
result = model(wav_path, hotwords)
print(result)

View File

@ -314,7 +314,14 @@ class ContextualParaformer(Paraformer):
hotwords_length = torch.Tensor(hotwords_length).to(torch.int32)
# hotwords.append('<s>')
def word_map(word):
return torch.tensor([self.vocab[i] for i in word])
hotwords = []
for c in word:
if c not in self.vocab.keys():
hotwords.append(8403)
logging.warning("oov character {} found in hotword {}, replaced by <unk>".format(c, word))
else:
hotwords.append(self.vocab[c])
return torch.tensor(hotwords)
hotword_int = [word_map(i) for i in hotwords]
# import pdb; pdb.set_trace()
hotword_int.append(torch.tensor([1]))