diff --git a/funasr/bin/asr_inference_paraformer.py b/funasr/bin/asr_inference_paraformer.py index b807a3452..8265fc590 100644 --- a/funasr/bin/asr_inference_paraformer.py +++ b/funasr/bin/asr_inference_paraformer.py @@ -42,6 +42,7 @@ from funasr.utils import asr_utils, wav_utils, postprocess_utils from funasr.models.frontend.wav_frontend import WavFrontend from funasr.models.e2e_asr_paraformer import BiCifParaformer, ContextualParaformer from funasr.export.models.e2e_asr_paraformer import Paraformer as Paraformer_export +from funasr.utils.timestamp_tools import time_stamp_lfr6_pl, time_stamp_sentence class Speech2Text: @@ -190,7 +191,8 @@ class Speech2Text: @torch.no_grad() def __call__( - self, speech: Union[torch.Tensor, np.ndarray], speech_lengths: Union[torch.Tensor, np.ndarray] = None + self, speech: Union[torch.Tensor, np.ndarray], speech_lengths: Union[torch.Tensor, np.ndarray] = None, + begin_time: int = 0, end_time: int = None, ): """Inference @@ -242,6 +244,10 @@ class Speech2Text: decoder_outs = self.asr_model.cal_decoder_with_predictor(enc, enc_len, pre_acoustic_embeds, pre_token_length, hw_list=self.hotword_list) decoder_out, ys_pad_lens = decoder_outs[0], decoder_outs[1] + if isinstance(self.asr_model, BiCifParaformer): + _, _, us_alphas, us_cif_peak = self.asr_model.calc_predictor_timestamp(enc, enc_len, + pre_token_length) # test no bias cif2 + results = [] b, n, d = decoder_out.size() for i in range(b): @@ -284,7 +290,11 @@ class Speech2Text: else: text = None - results.append((text, token, token_int, hyp, enc_len_batch_total, lfr_factor)) + if isinstance(self.asr_model, BiCifParaformer): + timestamp = time_stamp_lfr6_pl(us_alphas[i], us_cif_peak[i], copy.copy(token), begin_time, end_time) + results.append((text, token, token_int, hyp, timestamp, enc_len_batch_total, lfr_factor)) + else: + results.append((text, token, token_int, hyp, enc_len_batch_total, lfr_factor)) # assert check_return_type(results) return results @@ -683,6 +693,11 @@ def inference_modelscope( inference=True, ) + if param_dict is not None: + use_timestamp = param_dict.get('use_timestamp', True) + else: + use_timestamp = True + forward_time_total = 0.0 length_total = 0.0 finish_count = 0 @@ -724,7 +739,9 @@ def inference_modelscope( result = [results[batch_id][:-2]] key = keys[batch_id] - for n, (text, token, token_int, hyp) in zip(range(1, nbest + 1), result): + for n, result in zip(range(1, nbest + 1), result): + text, token, token_int, hyp = result[0], result[1], result[2], result[3] + time_stamp = None if len(result) < 5 else result[4] # Create a directory: outdir/{n}best_recog if writer is not None: ibest_writer = writer[f"{n}best_recog"] @@ -736,8 +753,20 @@ def inference_modelscope( ibest_writer["rtf"][key] = rtf_cur if text is not None: - text_postprocessed, _ = postprocess_utils.sentence_postprocess(token) + if use_timestamp and time_stamp is not None: + postprocessed_result = postprocess_utils.sentence_postprocess(token, time_stamp) + else: + postprocessed_result = postprocess_utils.sentence_postprocess(token) + time_stamp_postprocessed = "" + if len(postprocessed_result) == 3: + text_postprocessed, time_stamp_postprocessed, word_lists = postprocessed_result[0], \ + postprocessed_result[1], \ + postprocessed_result[2] + else: + text_postprocessed, word_lists = postprocessed_result[0], postprocessed_result[1] item = {'key': key, 'value': text_postprocessed} + if time_stamp_postprocessed != "": + item['time_stamp'] = time_stamp_postprocessed asr_result_list.append(item) finish_count += 1 # asr_utils.print_progress(finish_count / file_count)