Mirror of https://github.com/modelscope/FunASR (synced 2025-09-15 14:48:36 +08:00)
Merge pull request #1763 from coolEphemeroptera/main
Fixed the issues with the seaco-onnx timestamps
Commit 50b2668019
@@ -163,7 +163,11 @@ def export_backbone_forward(
     dha_ids = dha_pred.max(-1)[-1]
     dha_mask = (dha_ids == self.NOBIAS).int().unsqueeze(-1)
     decoder_out = decoder_out * dha_mask + dha_pred * (1 - dha_mask)
-    return decoder_out, pre_token_length, alphas
+
+    # get predicted timestamps
+    us_alphas, us_cif_peak = self.predictor.get_upsample_timestmap(enc, mask, pre_token_length)
+
+    return decoder_out, pre_token_length, us_alphas, us_cif_peak


 def export_backbone_dummy_inputs(self):
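Not part of the diff: the exported forward pass now also returns the predictor's upsampled alphas and CIF peak track, which downstream code turns into token timestamps. A minimal sketch of how such a peak track can be read, assuming roughly 20 ms per upsampled frame and a hypothetical helper name (the real conversion is done later by time_stamp_lfr6_onnx):

# Illustration only; frame_ms=20.0 and rough_token_spans are assumptions for this sketch.
import numpy as np

def rough_token_spans(us_cif_peak, tokens, frame_ms=20.0):
    # frames where the upsampled CIF peak "fires" (value close to 1.0)
    fire_frames = np.where(us_cif_peak > 1.0 - 1e-4)[0]
    spans = []
    for tok, start, end in zip(tokens, fire_frames[:-1], fire_frames[1:]):
        spans.append((tok, float(start) * frame_ms, float(end) * frame_ms))
    return spans

# toy check with a synthetic peak track
peaks = np.zeros(50)
peaks[[5, 20, 40]] = 1.0
print(rough_token_spans(peaks, ["he", "llo"]))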
@@ -178,7 +182,7 @@ def export_backbone_input_names(self):


 def export_backbone_output_names(self):
-    return ["logits", "token_num", "alphas"]
+    return ["logits", "token_num", "us_alphas", "us_cif_peak"]


 def export_backbone_dynamic_axes(self):
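Not part of the diff: with the renamed outputs, an ONNX Runtime session over the exported backbone should expose four output tensors. A quick way to check (the model path below is a placeholder, not a file shipped with FunASR):

# Illustration only; "model.onnx" is a placeholder path.
import onnxruntime as ort

sess = ort.InferenceSession("model.onnx")
print([o.name for o in sess.get_outputs()])
# expected after this change: ['logits', 'token_num', 'us_alphas', 'us_cif_peak']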
@@ -190,6 +194,8 @@ def export_backbone_dynamic_axes(self):
         "bias_embed": {0: "batch_size", 1: "num_hotwords"},
         "logits": {0: "batch_size", 1: "logits_length"},
         "pre_acoustic_embeds": {1: "feats_length1"},
+        "us_alphas": {0: "batch_size", 1: "alphas_length"},
+        "us_cif_peak": {0: "batch_size", 1: "alphas_length"},
     }
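Not part of the diff: a dictionary of this shape is ultimately handed to torch.onnx.export as dynamic_axes, so the batch and length dimensions stay symbolic in the exported graph. A self-contained sketch with a toy module standing in for the backbone (module, tensor shape and axis names are placeholders):

# Illustration only; Toy, the dummy shape and the axis names are assumptions.
import torch

class Toy(torch.nn.Module):
    def forward(self, speech):
        return speech.mean(-1)

torch.onnx.export(
    Toy(),
    (torch.randn(2, 100, 560),),  # dummy (batch, frames, feat) input
    "toy.onnx",
    input_names=["speech"],
    output_names=["logits"],
    dynamic_axes={
        "speech": {0: "batch_size", 1: "feats_length"},
        "logits": {0: "batch_size", 1: "logits_length"},
    },
)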
@@ -326,6 +326,9 @@ class ContextualParaformer(Paraformer):
     def __call__(
         self, wav_content: Union[str, np.ndarray, List[str]], hotwords: str, **kwargs
     ) -> List:
+        # def __call__(
+        #     self, waveform_list:list, hotwords: str, **kwargs
+        # ) -> List:
         # make hotword list
         hotwords, hotwords_length = self.proc_hotword(hotwords)
         # import pdb; pdb.set_trace()
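Not part of the diff: the calling convention implied by the signature above, following the funasr_onnx runtime usage pattern; the model directory, wav path and hotword string are placeholders:

# Illustration only; paths and hotwords are placeholders.
from funasr_onnx import ContextualParaformer

model = ContextualParaformer("path/to/contextual-paraformer-onnx", batch_size=1)
# hotwords are passed as a single space-separated string
result = model("path/to/audio.wav", hotwords="hotword1 hotword2")
print(result)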
@@ -345,15 +348,47 @@ class ContextualParaformer(Paraformer):
         try:
             outputs = self.bb_infer(feats, feats_len, bias_embed)
             am_scores, valid_token_lens = outputs[0], outputs[1]
+
+            if len(outputs) == 4:
+                # for BiCifParaformer Inference
+                us_alphas, us_peaks = outputs[2], outputs[3]
+            else:
+                us_alphas, us_peaks = None, None
+
         except ONNXRuntimeError:
             # logging.warning(traceback.format_exc())
             logging.warning("input wav is silence or noise")
             preds = [""]
         else:
             preds = self.decode(am_scores, valid_token_lens)
-            for pred in preds:
-                pred = sentence_postprocess(pred)
-                asr_res.append({"preds": pred})
+            if us_peaks is None:
+                for pred in preds:
+                    if self.language == "en-bpe":
+                        pred = sentence_postprocess_sentencepiece(pred)
+                    else:
+                        pred = sentence_postprocess(pred)
+                    asr_res.append({"preds": pred})
+            else:
+                for pred, us_peaks_ in zip(preds, us_peaks):
+                    raw_tokens = pred
+                    timestamp, timestamp_raw = time_stamp_lfr6_onnx(
+                        us_peaks_, copy.copy(raw_tokens)
+                    )
+                    text_proc, timestamp_proc, _ = sentence_postprocess(
+                        raw_tokens, timestamp_raw
+                    )
+                    # logging.warning(timestamp)
+                    if len(self.plot_timestamp_to):
+                        self.plot_wave_timestamp(
+                            waveform_list[0], timestamp, self.plot_timestamp_to
+                        )
+                    asr_res.append(
+                        {
+                            "preds": text_proc,
+                            "timestamp": timestamp_proc,
+                            "raw_tokens": raw_tokens,
+                        }
+                    )
         return asr_res

     def proc_hotword(self, hotwords):
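Not part of the diff: after this change each per-utterance result dict can also carry "timestamp" and "raw_tokens". A stand-in example of consuming that structure (the values below are hand-written, not real model output):

# Illustration only; `results` is hand-written stand-in data shaped like the dict built above.
results = [
    {
        "preds": "hello world",
        "timestamp": [[0, 380], [380, 760]],  # illustrative ms ranges, one per token
        "raw_tokens": ["hello", "world"],
    }
]

for res in results:
    for token, (start_ms, end_ms) in zip(res["raw_tokens"], res["timestamp"]):
        print(f"{token}: {start_ms}-{end_ms} ms")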