Merge pull request #1763 from coolEphemeroptera/main

fixed the issues about seaco-onnx timestamp
This commit is contained in:
Shi Xian 2024-05-28 17:46:30 +08:00 committed by GitHub
commit 50b2668019
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 46 additions and 5 deletions

View File

@ -163,7 +163,11 @@ def export_backbone_forward(
dha_ids = dha_pred.max(-1)[-1]
dha_mask = (dha_ids == self.NOBIAS).int().unsqueeze(-1)
decoder_out = decoder_out * dha_mask + dha_pred * (1 - dha_mask)
return decoder_out, pre_token_length, alphas
# get predicted timestamps
us_alphas, us_cif_peak = self.predictor.get_upsample_timestmap(enc, mask, pre_token_length)
return decoder_out, pre_token_length, us_alphas, us_cif_peak
def export_backbone_dummy_inputs(self):
@ -178,7 +182,7 @@ def export_backbone_input_names(self):
def export_backbone_output_names(self):
return ["logits", "token_num", "alphas"]
return ["logits", "token_num", "us_alphas", "us_cif_peak"]
def export_backbone_dynamic_axes(self):
@ -190,6 +194,8 @@ def export_backbone_dynamic_axes(self):
"bias_embed": {0: "batch_size", 1: "num_hotwords"},
"logits": {0: "batch_size", 1: "logits_length"},
"pre_acoustic_embeds": {1: "feats_length1"},
"us_alphas": {0: "batch_size", 1: "alphas_length"},
"us_cif_peak": {0: "batch_size", 1: "alphas_length"},
}

View File

@ -326,6 +326,9 @@ class ContextualParaformer(Paraformer):
def __call__(
self, wav_content: Union[str, np.ndarray, List[str]], hotwords: str, **kwargs
) -> List:
# def __call__(
# self, waveform_list:list, hotwords: str, **kwargs
# ) -> List:
# make hotword list
hotwords, hotwords_length = self.proc_hotword(hotwords)
# import pdb; pdb.set_trace()
@ -345,15 +348,47 @@ class ContextualParaformer(Paraformer):
try:
outputs = self.bb_infer(feats, feats_len, bias_embed)
am_scores, valid_token_lens = outputs[0], outputs[1]
if len(outputs) == 4:
# for BiCifParaformer Inference
us_alphas, us_peaks = outputs[2], outputs[3]
else:
us_alphas, us_peaks = None, None
except ONNXRuntimeError:
# logging.warning(traceback.format_exc())
logging.warning("input wav is silence or noise")
preds = [""]
else:
preds = self.decode(am_scores, valid_token_lens)
for pred in preds:
pred = sentence_postprocess(pred)
asr_res.append({"preds": pred})
if us_peaks is None:
for pred in preds:
if self.language == "en-bpe":
pred = sentence_postprocess_sentencepiece(pred)
else:
pred = sentence_postprocess(pred)
asr_res.append({"preds": pred})
else:
for pred, us_peaks_ in zip(preds, us_peaks):
raw_tokens = pred
timestamp, timestamp_raw = time_stamp_lfr6_onnx(
us_peaks_, copy.copy(raw_tokens)
)
text_proc, timestamp_proc, _ = sentence_postprocess(
raw_tokens, timestamp_raw
)
# logging.warning(timestamp)
if len(self.plot_timestamp_to):
self.plot_wave_timestamp(
waveform_list[0], timestamp, self.plot_timestamp_to
)
asr_res.append(
{
"preds": text_proc,
"timestamp": timestamp_proc,
"raw_tokens": raw_tokens,
}
)
return asr_res
def proc_hotword(self, hotwords):