From c5274e728aa3350a778889b77dac288234dbb9a0 Mon Sep 17 00:00:00 2001 From: aky15 Date: Mon, 10 Jul 2023 12:48:50 +0800 Subject: [PATCH] Update asr_inference_launch.py (#719) update bat infer for modelscope --- funasr/bin/asr_inference_launch.py | 118 ++++++++++++++++------------- 1 file changed, 66 insertions(+), 52 deletions(-) diff --git a/funasr/bin/asr_inference_launch.py b/funasr/bin/asr_inference_launch.py index de1889453..10f8e5024 100644 --- a/funasr/bin/asr_inference_launch.py +++ b/funasr/bin/asr_inference_launch.py @@ -1272,27 +1272,27 @@ def inference_transducer( nbest: int, num_workers: int, log_level: Union[int, str], - data_path_and_name_and_type: Sequence[Tuple[str, str, str]], + # data_path_and_name_and_type: Sequence[Tuple[str, str, str]], asr_train_config: Optional[str], asr_model_file: Optional[str], - cmvn_file: Optional[str], - beam_search_config: Optional[dict], - lm_train_config: Optional[str], - lm_file: Optional[str], - model_tag: Optional[str], - token_type: Optional[str], - bpemodel: Optional[str], - key_file: Optional[str], - allow_variable_data_keys: bool, - quantize_asr_model: Optional[bool], - quantize_modules: Optional[List[str]], - quantize_dtype: Optional[str], - streaming: Optional[bool], - simu_streaming: Optional[bool], - chunk_size: Optional[int], - left_context: Optional[int], - right_context: Optional[int], - display_partial_hypotheses: bool, + cmvn_file: Optional[str] = None, + beam_search_config: Optional[dict] = None, + lm_train_config: Optional[str] = None, + lm_file: Optional[str] = None, + model_tag: Optional[str] = None, + token_type: Optional[str] = None, + bpemodel: Optional[str] = None, + key_file: Optional[str] = None, + allow_variable_data_keys: bool = False, + quantize_asr_model: Optional[bool] = False, + quantize_modules: Optional[List[str]] = None, + quantize_dtype: Optional[str] = "float16", + streaming: Optional[bool] = False, + simu_streaming: Optional[bool] = False, + chunk_size: Optional[int] = 16, + left_context: Optional[int] = 16, + right_context: Optional[int] = 0, + display_partial_hypotheses: bool = False, **kwargs, ) -> None: """Transducer model inference. @@ -1327,6 +1327,7 @@ def inference_transducer( right_context: Number of frames in right context AFTER subsampling. display_partial_hypotheses: Whether to display partial hypotheses. """ + # assert check_argument_types() if batch_size > 1: raise NotImplementedError("batch decoding is not implemented") @@ -1369,7 +1370,10 @@ def inference_transducer( left_context=left_context, right_context=right_context, ) - speech2text = Speech2TextTransducer(**speech2text_kwargs) + speech2text = Speech2TextTransducer.from_pretrained( + model_tag=model_tag, + **speech2text_kwargs, + ) def _forward(data_path_and_name_and_type, raw_inputs: Union[np.ndarray, torch.Tensor] = None, @@ -1388,47 +1392,55 @@ def inference_transducer( key_file=key_file, num_workers=num_workers, ) + asr_result_list = [] + + if output_dir is not None: + writer = DatadirWriter(output_dir) + else: + writer = None # 4 .Start for-loop - with DatadirWriter(output_dir) as writer: - for keys, batch in loader: - assert isinstance(batch, dict), type(batch) - assert all(isinstance(s, str) for s in keys), keys + for keys, batch in loader: + assert isinstance(batch, dict), type(batch) + assert all(isinstance(s, str) for s in keys), keys - _bs = len(next(iter(batch.values()))) - assert len(keys) == _bs, f"{len(keys)} != {_bs}" - batch = {k: v[0] for k, v in batch.items() if not k.endswith("_lengths")} - assert len(batch.keys()) == 1 + _bs = len(next(iter(batch.values()))) + assert len(keys) == _bs, f"{len(keys)} != {_bs}" + batch = {k: v[0] for k, v in batch.items() if not k.endswith("_lengths")} + assert len(batch.keys()) == 1 - try: - if speech2text.streaming: - speech = batch["speech"] + try: + if speech2text.streaming: + speech = batch["speech"] - _steps = len(speech) // speech2text._ctx - _end = 0 - for i in range(_steps): - _end = (i + 1) * speech2text._ctx + _steps = len(speech) // speech2text._ctx + _end = 0 + for i in range(_steps): + _end = (i + 1) * speech2text._ctx - speech2text.streaming_decode( - speech[i * speech2text._ctx: _end], is_final=False - ) - - final_hyps = speech2text.streaming_decode( - speech[_end: len(speech)], is_final=True + speech2text.streaming_decode( + speech[i * speech2text._ctx: _end], is_final=False ) - elif speech2text.simu_streaming: - final_hyps = speech2text.simu_streaming_decode(**batch) - else: - final_hyps = speech2text(**batch) - results = speech2text.hypotheses_to_results(final_hyps) - except TooShortUttError as e: - logging.warning(f"Utterance {keys} {e}") - hyp = Hypothesis(score=0.0, yseq=[], dec_state=None) - results = [[" ", [""], [2], hyp]] * nbest + final_hyps = speech2text.streaming_decode( + speech[_end: len(speech)], is_final=True + ) + elif speech2text.simu_streaming: + final_hyps = speech2text.simu_streaming_decode(**batch) + else: + final_hyps = speech2text(**batch) - key = keys[0] - for n, (text, token, token_int, hyp) in zip(range(1, nbest + 1), results): + results = speech2text.hypotheses_to_results(final_hyps) + except TooShortUttError as e: + logging.warning(f"Utterance {keys} {e}") + hyp = Hypothesis(score=0.0, yseq=[], dec_state=None) + results = [[" ", [""], [2], hyp]] * nbest + + key = keys[0] + for n, (text, token, token_int, hyp) in zip(range(1, nbest + 1), results): + item = {'key': key, 'value': text} + asr_result_list.append(item) + if writer is not None: ibest_writer = writer[f"{n}best_recog"] ibest_writer["token"][key] = " ".join(token) @@ -1438,6 +1450,8 @@ def inference_transducer( if text is not None: ibest_writer["text"][key] = text + logging.info("decoding, utt: {}, predictions: {}".format(key, text)) + return asr_result_list return _forward