mirror of
https://github.com/modelscope/FunASR
synced 2025-09-15 14:48:36 +08:00
update model class
This commit is contained in:
parent
fd0992af3d
commit
d2e9bf0142
@ -7,7 +7,7 @@ device_id = 0 if torch.cuda.is_available() else -1
|
|||||||
model = ContextualParaformer(model_dir, batch_size=1, device_id=device_id) # gpu
|
model = ContextualParaformer(model_dir, batch_size=1, device_id=device_id) # gpu
|
||||||
|
|
||||||
wav_path = "{}/.cache/modelscope/hub/{}/example/asr_example.wav".format(Path.home(), model_dir)
|
wav_path = "{}/.cache/modelscope/hub/{}/example/asr_example.wav".format(Path.home(), model_dir)
|
||||||
hotwords = "你的热词 魔搭 达摩苑"
|
hotwords = "你的热词 魔哒"
|
||||||
|
|
||||||
result = model(wav_path, hotwords)
|
result = model(wav_path, hotwords)
|
||||||
print(result)
|
print(result)
|
||||||
|
|||||||
@ -282,7 +282,7 @@ class ContextualParaformer(Paraformer):
|
|||||||
raise "You are exporting onnx, please install funasr and try it again. To install funasr, you could:\n" "\npip3 install -U funasr\n" "For the users in China, you could install with the command:\n" "\npip3 install -U funasr -i https://mirror.sjtu.edu.cn/pypi/web/simple"
|
raise "You are exporting onnx, please install funasr and try it again. To install funasr, you could:\n" "\npip3 install -U funasr\n" "For the users in China, you could install with the command:\n" "\npip3 install -U funasr -i https://mirror.sjtu.edu.cn/pypi/web/simple"
|
||||||
|
|
||||||
model = AutoModel(model=model_dir)
|
model = AutoModel(model=model_dir)
|
||||||
model_dir = model.export(type="onnx", quantize=quantize, **kwargs)
|
model_dir = model.export(type="torchscripts", quantize=quantize, **kwargs)
|
||||||
|
|
||||||
config_file = os.path.join(model_dir, "config.yaml")
|
config_file = os.path.join(model_dir, "config.yaml")
|
||||||
cmvn_file = os.path.join(model_dir, "am.mvn")
|
cmvn_file = os.path.join(model_dir, "am.mvn")
|
||||||
@ -316,9 +316,9 @@ class ContextualParaformer(Paraformer):
|
|||||||
) -> List:
|
) -> List:
|
||||||
# make hotword list
|
# make hotword list
|
||||||
hotwords, hotwords_length = self.proc_hotword(hotwords)
|
hotwords, hotwords_length = self.proc_hotword(hotwords)
|
||||||
[bias_embed] = self.eb_infer(torch.Tensor(hotwords), torch.Tensor(hotwords_length))
|
bias_embed = self.eb_infer(torch.Tensor(hotwords))
|
||||||
# index from bias_embed
|
# index from bias_embed
|
||||||
bias_embed = bias_embed.transpose(1, 0, 2)
|
bias_embed = torch.transpose(bias_embed, 0, 1)
|
||||||
_ind = np.arange(0, len(hotwords)).tolist()
|
_ind = np.arange(0, len(hotwords)).tolist()
|
||||||
bias_embed = bias_embed[_ind, hotwords_length.tolist()]
|
bias_embed = bias_embed[_ind, hotwords_length.tolist()]
|
||||||
waveform_list = self.load_data(wav_content, self.frontend.opts.frame_opts.samp_freq)
|
waveform_list = self.load_data(wav_content, self.frontend.opts.frame_opts.samp_freq)
|
||||||
@ -327,15 +327,14 @@ class ContextualParaformer(Paraformer):
|
|||||||
for beg_idx in range(0, waveform_nums, self.batch_size):
|
for beg_idx in range(0, waveform_nums, self.batch_size):
|
||||||
end_idx = min(waveform_nums, beg_idx + self.batch_size)
|
end_idx = min(waveform_nums, beg_idx + self.batch_size)
|
||||||
feats, feats_len = self.extract_feat(waveform_list[beg_idx:end_idx])
|
feats, feats_len = self.extract_feat(waveform_list[beg_idx:end_idx])
|
||||||
bias_embed = np.expand_dims(bias_embed, axis=0)
|
bias_embed = torch.unsqueeze(bias_embed, 0).repeat(feats.shape[0], 1, 1)
|
||||||
bias_embed = np.repeat(bias_embed, feats.shape[0], axis=0)
|
|
||||||
try:
|
try:
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
if int(self.device_id) == -1:
|
if int(self.device_id) == -1:
|
||||||
outputs = self.bb_infer(feats, feats_len)
|
outputs = self.bb_infer(feats, feats_len, bias_embed)
|
||||||
am_scores, valid_token_lens = outputs[0], outputs[1]
|
am_scores, valid_token_lens = outputs[0], outputs[1]
|
||||||
else:
|
else:
|
||||||
outputs = self.bb_infer_infer(feats.cuda(), feats_len.cuda())
|
outputs = self.bb_infer_infer(feats.cuda(), feats_len.cuda(), bias_embed)
|
||||||
am_scores, valid_token_lens = outputs[0].cpu(), outputs[1].cpu()
|
am_scores, valid_token_lens = outputs[0].cpu(), outputs[1].cpu()
|
||||||
except:
|
except:
|
||||||
# logging.warning(traceback.format_exc())
|
# logging.warning(traceback.format_exc())
|
||||||
@ -374,12 +373,12 @@ class ContextualParaformer(Paraformer):
|
|||||||
|
|
||||||
def bb_infer(
|
def bb_infer(
|
||||||
self, feats, feats_len, bias_embed
|
self, feats, feats_len, bias_embed
|
||||||
) -> Tuple[np.ndarray, np.ndarray]:
|
):
|
||||||
outputs = self.ort_infer_bb([feats, feats_len, bias_embed])
|
outputs = self.ort_infer_bb(feats, feats_len, bias_embed)
|
||||||
return outputs
|
return outputs
|
||||||
|
|
||||||
def eb_infer(self, hotwords, hotwords_length):
|
def eb_infer(self, hotwords):
|
||||||
outputs = self.ort_infer_eb([hotwords, hotwords_length])
|
outputs = self.ort_infer_eb(hotwords.long())
|
||||||
return outputs
|
return outputs
|
||||||
|
|
||||||
def decode(self, am_scores: np.ndarray, token_nums: int) -> List[str]:
|
def decode(self, am_scores: np.ndarray, token_nums: int) -> List[str]:
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user