diff --git a/funasr/bin/asr_inference_uniasr.py b/funasr/bin/asr_inference_uniasr.py
index c50bf17f6..8b31fad13 100644
--- a/funasr/bin/asr_inference_uniasr.py
+++ b/funasr/bin/asr_inference_uniasr.py
@@ -398,6 +398,19 @@ def inference_modelscope(
     else:
         device = "cpu"
 
+    if param_dict is not None and "decoding_model" in param_dict:
+        if param_dict["decoding_model"] == "fast":
+            decoding_ind = 0
+            decoding_mode = "model1"
+        elif param_dict["decoding_model"] == "normal":
+            decoding_ind = 0
+            decoding_mode = "model2"
+        elif param_dict["decoding_model"] == "offline":
+            decoding_ind = 1
+            decoding_mode = "model2"
+        else:
+            raise NotImplementedError("unsupported decoding model {}".format(param_dict["decoding_model"]))
+
     # 1. Set random-seed
     set_all_random_seed(seed)
 
@@ -440,18 +453,6 @@ def inference_modelscope(
         if isinstance(raw_inputs, torch.Tensor):
             raw_inputs = raw_inputs.numpy()
         data_path_and_name_and_type = [raw_inputs, "speech", "waveform"]
-        if param_dict is not None and "decoding_model" in param_dict:
-            if param_dict["decoding_model"] == "fast":
-                speech2text.decoding_ind = 0
-                speech2text.decoding_mode = "model1"
-            elif param_dict["decoding_model"] == "normal":
-                speech2text.decoding_ind = 0
-                speech2text.decoding_mode = "model2"
-            elif param_dict["decoding_model"] == "offline":
-                speech2text.decoding_ind = 1
-                speech2text.decoding_mode = "model2"
-            else:
-                raise NotImplementedError("unsupported decoding model {}".format(param_dict["decoding_model"]))
         loader = ASRTask.build_streaming_iterator(
             data_path_and_name_and_type,
             dtype=dtype,
diff --git a/funasr/bin/asr_inference_uniasr_vad.py b/funasr/bin/asr_inference_uniasr_vad.py
index ac3b4b6a8..e5815df11 100644
--- a/funasr/bin/asr_inference_uniasr_vad.py
+++ b/funasr/bin/asr_inference_uniasr_vad.py
@@ -398,6 +398,19 @@ def inference_modelscope(
     else:
         device = "cpu"
 
+    if param_dict is not None and "decoding_model" in param_dict:
+        if param_dict["decoding_model"] == "fast":
+            decoding_ind = 0
+            decoding_mode = "model1"
+        elif param_dict["decoding_model"] == "normal":
+            decoding_ind = 0
+            decoding_mode = "model2"
+        elif param_dict["decoding_model"] == "offline":
+            decoding_ind = 1
+            decoding_mode = "model2"
+        else:
+            raise NotImplementedError("unsupported decoding model {}".format(param_dict["decoding_model"]))
+
     # 1. Set random-seed
     set_all_random_seed(seed)
 
@@ -440,18 +453,6 @@ def inference_modelscope(
         if isinstance(raw_inputs, torch.Tensor):
             raw_inputs = raw_inputs.numpy()
         data_path_and_name_and_type = [raw_inputs, "speech", "waveform"]
-        if param_dict is not None and "decoding_model" in param_dict:
-            if param_dict["decoding_model"] == "fast":
-                speech2text.decoding_ind = 0
-                speech2text.decoding_mode = "model1"
-            elif param_dict["decoding_model"] == "normal":
-                speech2text.decoding_ind = 0
-                speech2text.decoding_mode = "model2"
-            elif param_dict["decoding_model"] == "offline":
-                speech2text.decoding_ind = 1
-                speech2text.decoding_mode = "model2"
-            else:
-                raise NotImplementedError("unsupported decoding model {}".format(param_dict["decoding_model"]))
         loader = ASRTask.build_streaming_iterator(
             data_path_and_name_and_type,
             dtype=dtype,
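Note (not part of the diff): both hunks apply the same refactor to the two files. The `param_dict["decoding_model"]` lookup is moved from the per-call `_forward` path, where it mutated the `speech2text` object before every decode, up to `inference_modelscope` itself, so `decoding_ind` and `decoding_mode` are resolved once before the model is built. For reference, the mapping the hunks implement can be restated as a standalone helper; `resolve_decoding` and `_DECODING_TABLE` are hypothetical names introduced here, while the value pairs are copied verbatim from the diff:

```python
# Sketch only: a standalone restatement of the decoding_model mapping above.
# resolve_decoding / _DECODING_TABLE are illustrative names, not FunASR API;
# the (decoding_ind, decoding_mode) pairs come directly from the hunks.
from typing import Tuple

_DECODING_TABLE = {
    "fast": (0, "model1"),
    "normal": (0, "model2"),
    "offline": (1, "model2"),
}

def resolve_decoding(decoding_model: str) -> Tuple[int, str]:
    """Map a decoding_model name to (decoding_ind, decoding_mode)."""
    try:
        return _DECODING_TABLE[decoding_model]
    except KeyError:
        # Same error contract as the diff's else branch.
        raise NotImplementedError(
            "unsupported decoding model {}".format(decoding_model)
        )

# Example: resolve_decoding("fast") returns (0, "model1").
```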