diff --git a/egs_modelscope/asr_vad_punc/speech_paraformer-large-vad-punc_asr_nat-en-16k-common-vocab10020/infer.py b/egs_modelscope/asr_vad_punc/speech_paraformer-large-vad-punc_asr_nat-en-16k-common-vocab10020/infer.py
index f54399a14..6f810ffa1 100644
--- a/egs_modelscope/asr_vad_punc/speech_paraformer-large-vad-punc_asr_nat-en-16k-common-vocab10020/infer.py
+++ b/egs_modelscope/asr_vad_punc/speech_paraformer-large-vad-punc_asr_nat-en-16k-common-vocab10020/infer.py
@@ -16,7 +16,7 @@ def modelscope_infer(args):
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
-    parser.add_argument('--model', type=str, default="damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch")
+    parser.add_argument('--model', type=str, default="damo/speech_paraformer-large-vad-punc_asr_nat-en-16k-common-vocab10020")
     parser.add_argument('--audio_in', type=str, default="./data/test/wav.scp")
     parser.add_argument('--output_dir', type=str, default="./results/")
     parser.add_argument('--decoding_mode', type=str, default="normal")
diff --git a/egs_modelscope/asr_vad_punc/speech_paraformer-large-vad-punc_asr_nat-en-16k-common-vocab10020/infer.sh b/egs_modelscope/asr_vad_punc/speech_paraformer-large-vad-punc_asr_nat-en-16k-common-vocab10020/infer.sh
index ef49d7a60..36f40b6b3 100644
--- a/egs_modelscope/asr_vad_punc/speech_paraformer-large-vad-punc_asr_nat-en-16k-common-vocab10020/infer.sh
+++ b/egs_modelscope/asr_vad_punc/speech_paraformer-large-vad-punc_asr_nat-en-16k-common-vocab10020/infer.sh
@@ -6,7 +6,7 @@ set -o pipefail
 
 stage=1
 stop_stage=2
-model="damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch"
+model="damo/speech_paraformer-large-vad-punc_asr_nat-en-16k-common-vocab10020"
 data_dir="./data/test"
 output_dir="./results"
 batch_size=64
diff --git a/funasr/bin/asr_infer.py b/funasr/bin/asr_infer.py
index 43da8bf92..80732133f 100644
--- a/funasr/bin/asr_infer.py
+++ b/funasr/bin/asr_infer.py
@@ -1918,6 +1918,8 @@ class Speech2TextWhisper:
         nbest: int = 1,
         streaming: bool = False,
         frontend_conf: dict = None,
+        language: str = None,
+        task: str = "transcribe",
         **kwargs,
     ):
 
@@ -1960,6 +1962,8 @@ class Speech2TextWhisper:
         self.device = device
         self.dtype = dtype
         self.frontend = frontend
+        self.language = language
+        self.task = task
 
     @torch.no_grad()
     def __call__(
@@ -1986,10 +1990,10 @@ class Speech2TextWhisper:
         mel = log_mel_spectrogram(speech).to(self.device)
 
         if self.asr_model.is_multilingual:
-            options = DecodingOptions(fp16=False)
+            options = DecodingOptions(fp16=False, language=self.language, task=self.task)
             asr_res = decode(self.asr_model, mel, options)
             text = asr_res.text
-            language = asr_res.language
+            language = self.language if self.language else asr_res.language
         else:
             asr_res = transcribe(self.asr_model, speech, fp16=False)
             text = asr_res["text"]
diff --git a/funasr/bin/asr_inference_launch.py b/funasr/bin/asr_inference_launch.py
index e3de05b09..1040f6f61 100644
--- a/funasr/bin/asr_inference_launch.py
+++ b/funasr/bin/asr_inference_launch.py
@@ -2056,6 +2056,8 @@ def inference_whisper(
 
     ncpu = kwargs.get("ncpu", 1)
    torch.set_num_threads(ncpu)
+    language = param_dict.get("language", None)
+    task = param_dict.get("task", "transcribe")
     if batch_size > 1:
         raise NotImplementedError("batch decoding is not implemented")
     if word_lm_train_config is not None:
@@ -2099,6 +2101,8 @@ def inference_whisper(
         penalty=penalty,
         nbest=nbest,
         streaming=streaming,
+        language=language,
+        task=task,
     )
     logging.info("speech2text_kwargs: {}".format(speech2text_kwargs))
     speech2text = Speech2TextWhisper(**speech2text_kwargs)
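For context, the `language` and `task` values added above are forwarded into Whisper's `DecodingOptions`, so decoding can be pinned to a specific language and task instead of relying on auto-detection. Below is a minimal sketch of that behavior using the standalone openai-whisper package directly rather than the FunASR wrappers; the "base" checkpoint and "example.wav" path are placeholders, not part of this change.

import whisper

# Load a multilingual checkpoint and prepare a padded 30 s log-Mel segment.
model = whisper.load_model("base")
audio = whisper.pad_or_trim(whisper.load_audio("example.wav"))
mel = whisper.log_mel_spectrogram(audio).to(model.device)

# Pinning language/task here mirrors what the patched Speech2TextWhisper does
# via self.language and self.task; leaving them unset falls back to Whisper's
# automatic language detection.
options = whisper.DecodingOptions(fp16=False, language="en", task="transcribe")
result = whisper.decode(model, mel, options)
print(result.text, result.language)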