diff --git a/funasr/bin/asr_inference_paraformer.py b/funasr/bin/asr_inference_paraformer.py
index e45e575ed..2eeffcd91 100644
--- a/funasr/bin/asr_inference_paraformer.py
+++ b/funasr/bin/asr_inference_paraformer.py
@@ -43,6 +43,7 @@ from funasr.models.frontend.wav_frontend import WavFrontend
 from funasr.models.e2e_asr_paraformer import BiCifParaformer, ContextualParaformer
 from funasr.export.models.e2e_asr_paraformer import Paraformer as Paraformer_export
 from funasr.utils.timestamp_tools import ts_prediction_lfr6_standard
+from funasr.bin.tp_inference import SpeechText2Timestamp
 
 
 class Speech2Text:
@@ -540,7 +541,8 @@ def inference(
     ngram_weight: float = 0.9,
     nbest: int = 1,
     num_workers: int = 1,
-
+    timestamp_infer_config: Union[Path, str] = None,
+    timestamp_model_file: Union[Path, str] = None,
     **kwargs,
 ):
     inference_pipeline = inference_modelscope(
@@ -604,6 +606,8 @@ def inference_modelscope(
     nbest: int = 1,
     num_workers: int = 1,
     output_dir: Optional[str] = None,
+    timestamp_infer_config: Union[Path, str] = None,
+    timestamp_model_file: Union[Path, str] = None,
     param_dict: dict = None,
     **kwargs,
 ):
@@ -661,6 +665,15 @@ def inference_modelscope(
     else:
         speech2text = Speech2Text(**speech2text_kwargs)
 
+    if timestamp_model_file is not None:
+        speechtext2timestamp = SpeechText2Timestamp(
+            timestamp_cmvn_file=cmvn_file,
+            timestamp_model_file=timestamp_model_file,
+            timestamp_infer_config=timestamp_infer_config,
+        )
+    else:
+        speechtext2timestamp = None
+
     def _forward(
         data_path_and_name_and_type,
         raw_inputs: Union[np.ndarray, torch.Tensor] = None,
@@ -744,7 +757,17 @@
                 key = keys[batch_id]
                 for n, result in zip(range(1, nbest + 1), result):
                     text, token, token_int, hyp = result[0], result[1], result[2], result[3]
-                    time_stamp = None if len(result) < 5 else result[4]
+                    timestamp = None if len(result) < 5 else result[4]
+                    # conduct timestamp prediction here;
+                    # timestamp inference requires the token length,
+                    # so the following inference cannot be conducted in batch
+                    if timestamp is None and speechtext2timestamp:
+                        ts_batch = {}
+                        ts_batch['speech'] = batch['speech'][batch_id].unsqueeze(0)
+                        ts_batch['speech_lengths'] = torch.tensor([batch['speech_lengths'][batch_id]])
+                        ts_batch['text_lengths'] = torch.tensor([len(token)])
+                        us_alphas, us_peaks = speechtext2timestamp(**ts_batch)
+                        ts_str, timestamp = ts_prediction_lfr6_standard(us_alphas[0], us_peaks[0], token, force_time_shift=-3.0)
                     # Create a directory: outdir/{n}best_recog
                     if writer is not None:
                         ibest_writer = writer[f"{n}best_recog"]
@@ -756,20 +779,20 @@
                         ibest_writer["rtf"][key] = rtf_cur
 
                     if text is not None:
-                        if use_timestamp and time_stamp is not None:
-                            postprocessed_result = postprocess_utils.sentence_postprocess(token, time_stamp)
+                        if use_timestamp and timestamp is not None:
+                            postprocessed_result = postprocess_utils.sentence_postprocess(token, timestamp)
                         else:
                             postprocessed_result = postprocess_utils.sentence_postprocess(token)
-                        time_stamp_postprocessed = ""
+                        timestamp_postprocessed = ""
                         if len(postprocessed_result) == 3:
-                            text_postprocessed, time_stamp_postprocessed, word_lists = postprocessed_result[0], \
+                            text_postprocessed, timestamp_postprocessed, word_lists = postprocessed_result[0], \
                                                                                        postprocessed_result[1], \
                                                                                        postprocessed_result[2]
                         else:
                             text_postprocessed, word_lists = postprocessed_result[0], postprocessed_result[1]
                         item = {'key': key, 'value': text_postprocessed}
-                        if time_stamp_postprocessed != "":
-                            item['time_stamp'] = time_stamp_postprocessed
+                        if timestamp_postprocessed != "":
+                            item['timestamp'] = timestamp_postprocessed
                         asr_result_list.append(item)
                         finish_count += 1
                         # asr_utils.print_progress(finish_count / file_count)
diff --git a/funasr/bin/tp_inference.py b/funasr/bin/tp_inference.py
index e374a227a..6360b17db 100644
--- a/funasr/bin/tp_inference.py
+++ b/funasr/bin/tp_inference.py
@@ -116,8 +116,8 @@ class SpeechText2Timestamp:
             enc = enc[0]
 
         # c. Forward Predictor
-        _, _, us_alphas, us_cif_peak = self.tp_model.calc_predictor_timestamp(enc, enc_len, text_lengths.to(self.device)+1)
-        return us_alphas, us_cif_peak
+        _, _, us_alphas, us_peaks = self.tp_model.calc_predictor_timestamp(enc, enc_len, text_lengths.to(self.device)+1)
+        return us_alphas, us_peaks
 
 
 def inference(
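
For reference, a minimal usage sketch of the new arguments. This is an assumption based only on the signatures visible in the diff: all paths are hypothetical placeholders, the asr_train_config/asr_model_file/cmvn_file argument names come from the surrounding FunASR code rather than from this change, and use_timestamp is assumed to be an existing flag already accepted by the pipeline.

    from funasr.bin.asr_inference_paraformer import inference_modelscope

    # Build the pipeline; passing timestamp_model_file makes inference_modelscope
    # construct a SpeechText2Timestamp predictor (reusing cmvn_file as its
    # timestamp_cmvn_file) instead of leaving speechtext2timestamp as None.
    forward = inference_modelscope(
        asr_train_config="exp/paraformer/config.yaml",   # hypothetical path
        asr_model_file="exp/paraformer/model.pb",        # hypothetical path
        cmvn_file="exp/paraformer/am.mvn",               # hypothetical path
        timestamp_infer_config="exp/tp/config.yaml",     # hypothetical path
        timestamp_model_file="exp/tp/model.pb",          # hypothetical path
    )

    # Kaldi-style scp input, in the (path, name, type) form taken by _forward's
    # data_path_and_name_and_type argument.
    results = forward([("wav.scp", "speech", "sound")])

    # For a model whose decoder returns no timestamps (len(result) < 5), the new
    # branch predicts them one sample at a time, so with use_timestamp enabled
    # each item may carry a 'timestamp' key alongside 'key' and 'value'.
    for item in results:
        print(item.get('key'), item.get('value'), item.get('timestamp'))

Note that the per-sample loop is the price of the new feature: as the added comment says, timestamp inference needs the decoded token length of each utterance, so this step cannot reuse the batched forward pass.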