mirror of
https://github.com/modelscope/FunASR
synced 2025-09-15 14:48:36 +08:00
update proc for oov in hotword onnx inference
This commit is contained in:
parent
b883a78256
commit
9c622feb64
@ -5,7 +5,7 @@ model_dir = "./export/damo/speech_paraformer-large-contextual_asr_nat-zh-cn-16k-
|
||||
model = ContextualParaformer(model_dir, batch_size=1)
|
||||
|
||||
wav_path = ['{}/.cache/modelscope/hub/damo/speech_paraformer-large-contextual_asr_nat-zh-cn-16k-common-vocab8404/example/asr_example.wav'.format(Path.home())]
|
||||
hotwords = '随机热词 各种热词 魔搭 阿里巴巴'
|
||||
hotwords = '随机热词 各种热词 魔搭 阿里巴巴 仏'
|
||||
|
||||
result = model(wav_path, hotwords)
|
||||
print(result)
|
||||
|
||||
@ -314,7 +314,14 @@ class ContextualParaformer(Paraformer):
|
||||
hotwords_length = torch.Tensor(hotwords_length).to(torch.int32)
|
||||
# hotwords.append('<s>')
|
||||
def word_map(word):
|
||||
return torch.tensor([self.vocab[i] for i in word])
|
||||
hotwords = []
|
||||
for c in word:
|
||||
if c not in self.vocab.keys():
|
||||
hotwords.append(8403)
|
||||
logging.warning("oov character {} found in hotword {}, replaced by <unk>".format(c, word))
|
||||
else:
|
||||
hotwords.append(self.vocab[c])
|
||||
return torch.tensor(hotwords)
|
||||
hotword_int = [word_map(i) for i in hotwords]
|
||||
# import pdb; pdb.set_trace()
|
||||
hotword_int.append(torch.tensor([1]))
|
||||
|
||||
Loading…
Reference in New Issue
Block a user