From 09ff7d4516128bfe1db8a81ca6de0d89ea55d88c Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E4=BB=81=E8=BF=B7?=
Date: Thu, 23 Feb 2023 16:28:05 +0800
Subject: [PATCH] fix uniasr decoding bug

---
 funasr/bin/asr_inference_uniasr.py     | 25 +++++++++++++------------
 funasr/bin/asr_inference_uniasr_vad.py | 25 +++++++++++++------------
 2 files changed, 26 insertions(+), 24 deletions(-)

diff --git a/funasr/bin/asr_inference_uniasr.py b/funasr/bin/asr_inference_uniasr.py
index c50bf17f6..8b31fad13 100644
--- a/funasr/bin/asr_inference_uniasr.py
+++ b/funasr/bin/asr_inference_uniasr.py
@@ -398,6 +398,19 @@ def inference_modelscope(
     else:
         device = "cpu"
 
+    if param_dict is not None and "decoding_model" in param_dict:
+        if param_dict["decoding_model"] == "fast":
+            decoding_ind = 0
+            decoding_mode = "model1"
+        elif param_dict["decoding_model"] == "normal":
+            decoding_ind = 0
+            decoding_mode = "model2"
+        elif param_dict["decoding_model"] == "offline":
+            decoding_ind = 1
+            decoding_mode = "model2"
+        else:
+            raise NotImplementedError("unsupported decoding model {}".format(param_dict["decoding_model"]))
+
     # 1. Set random-seed
     set_all_random_seed(seed)
 
@@ -440,18 +453,6 @@ def inference_modelscope(
             if isinstance(raw_inputs, torch.Tensor):
                 raw_inputs = raw_inputs.numpy()
             data_path_and_name_and_type = [raw_inputs, "speech", "waveform"]
-        if param_dict is not None and "decoding_model" in param_dict:
-            if param_dict["decoding_model"] == "fast":
-                speech2text.decoding_ind = 0
-                speech2text.decoding_mode = "model1"
-            elif param_dict["decoding_model"] == "normal":
-                speech2text.decoding_ind = 0
-                speech2text.decoding_mode = "model2"
-            elif param_dict["decoding_model"] == "offline":
-                speech2text.decoding_ind = 1
-                speech2text.decoding_mode = "model2"
-            else:
-                raise NotImplementedError("unsupported decoding model {}".format(param_dict["decoding_model"]))
         loader = ASRTask.build_streaming_iterator(
             data_path_and_name_and_type,
             dtype=dtype,
diff --git a/funasr/bin/asr_inference_uniasr_vad.py b/funasr/bin/asr_inference_uniasr_vad.py
index ac3b4b6a8..e5815df11 100644
--- a/funasr/bin/asr_inference_uniasr_vad.py
+++ b/funasr/bin/asr_inference_uniasr_vad.py
@@ -398,6 +398,19 @@ def inference_modelscope(
     else:
         device = "cpu"
 
+    if param_dict is not None and "decoding_model" in param_dict:
+        if param_dict["decoding_model"] == "fast":
+            decoding_ind = 0
+            decoding_mode = "model1"
+        elif param_dict["decoding_model"] == "normal":
+            decoding_ind = 0
+            decoding_mode = "model2"
+        elif param_dict["decoding_model"] == "offline":
+            decoding_ind = 1
+            decoding_mode = "model2"
+        else:
+            raise NotImplementedError("unsupported decoding model {}".format(param_dict["decoding_model"]))
+
     # 1. Set random-seed
     set_all_random_seed(seed)
 
@@ -440,18 +453,6 @@ def inference_modelscope(
             if isinstance(raw_inputs, torch.Tensor):
                 raw_inputs = raw_inputs.numpy()
             data_path_and_name_and_type = [raw_inputs, "speech", "waveform"]
-        if param_dict is not None and "decoding_model" in param_dict:
-            if param_dict["decoding_model"] == "fast":
-                speech2text.decoding_ind = 0
-                speech2text.decoding_mode = "model1"
-            elif param_dict["decoding_model"] == "normal":
-                speech2text.decoding_ind = 0
-                speech2text.decoding_mode = "model2"
-            elif param_dict["decoding_model"] == "offline":
-                speech2text.decoding_ind = 1
-                speech2text.decoding_mode = "model2"
-            else:
-                raise NotImplementedError("unsupported decoding model {}".format(param_dict["decoding_model"]))
         loader = ASRTask.build_streaming_iterator(
             data_path_and_name_and_type,
             dtype=dtype,
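
Note for reviewers: the effect of this change is that `decoding_model` in `param_dict` is now resolved into `decoding_ind`/`decoding_mode` once, before the recognizer is constructed, instead of being patched onto the `speech2text` object inside the per-call forward path after it already exists. A minimal usage sketch follows, assuming the ModelScope pipeline API through which `param_dict` reaches `inference_modelscope`; the model ID and the `audio_in` keyword are illustrative examples from the public UniASR model cards, not part of this diff.

    # Hypothetical example; model ID and audio path are placeholders.
    from modelscope.pipelines import pipeline
    from modelscope.utils.constant import Tasks

    # Per the mapping in this patch: "fast" selects the first-pass model
    # (model1), "normal" and "offline" select the second-pass model (model2),
    # with "offline" additionally using decoding_ind = 1.
    param_dict = {"decoding_model": "offline"}

    inference_pipeline = pipeline(
        task=Tasks.auto_speech_recognition,
        model="damo/speech_UniASR_asr_2pass-zh-cn-16k-common-vocab8358-tensorflow1-online",
        param_dict=param_dict,
    )
    rec_result = inference_pipeline(audio_in="example.wav")
    print(rec_result)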