From b22c0d228493fe6a662bc8f188b3d46090c09b9c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=BB=81=E8=BF=B7?= Date: Wed, 15 Feb 2023 19:27:33 +0800 Subject: [PATCH] add decoding model parameters --- funasr/bin/asr_inference_uniasr.py | 14 +++++++++++++- funasr/bin/asr_inference_uniasr_vad.py | 12 ++++++++++++ 2 files changed, 25 insertions(+), 1 deletion(-) diff --git a/funasr/bin/asr_inference_uniasr.py b/funasr/bin/asr_inference_uniasr.py index 0a5824c5c..db09d31b3 100644 --- a/funasr/bin/asr_inference_uniasr.py +++ b/funasr/bin/asr_inference_uniasr.py @@ -397,7 +397,7 @@ def inference_modelscope( device = "cuda" else: device = "cpu" - + # 1. Set random-seed set_all_random_seed(seed) @@ -439,6 +439,18 @@ def inference_modelscope( if isinstance(raw_inputs, torch.Tensor): raw_inputs = raw_inputs.numpy() data_path_and_name_and_type = [raw_inputs, "speech", "waveform"] + if param_dict is not None and "decoding_model" in param_dict: + if param_dict["decoding_model"] == "fast": + speech2text.decoding_ind = 0 + speech2text.decoding_mode = "model1" + elif param_dict["decoding_model"] == "normal": + speech2text.decoding_ind = 0 + speech2text.decoding_mode = "model2" + elif param_dict["decoding_model"] == "offline": + speech2text.decoding_ind = 1 + speech2text.decoding_mode = "model2" + else: + raise NotImplementedError("unsupported decoding model {}".format(param_dict["decoding_model"])) loader = ASRTask.build_streaming_iterator( data_path_and_name_and_type, dtype=dtype, diff --git a/funasr/bin/asr_inference_uniasr_vad.py b/funasr/bin/asr_inference_uniasr_vad.py index 0a5824c5c..de32dcf71 100644 --- a/funasr/bin/asr_inference_uniasr_vad.py +++ b/funasr/bin/asr_inference_uniasr_vad.py @@ -439,6 +439,18 @@ def inference_modelscope( if isinstance(raw_inputs, torch.Tensor): raw_inputs = raw_inputs.numpy() data_path_and_name_and_type = [raw_inputs, "speech", "waveform"] + if param_dict is not None and "decoding_model" in param_dict: + if param_dict["decoding_model"] == "fast": + speech2text.decoding_ind = 0 + speech2text.decoding_mode = "model1" + elif param_dict["decoding_model"] == "normal": + speech2text.decoding_ind = 0 + speech2text.decoding_mode = "model2" + elif param_dict["decoding_model"] == "offline": + speech2text.decoding_ind = 1 + speech2text.decoding_mode = "model2" + else: + raise NotImplementedError("unsupported decoding model {}".format(param_dict["decoding_model"])) loader = ASRTask.build_streaming_iterator( data_path_and_name_and_type, dtype=dtype,