fix uniasr decoding bug

仁迷 2023-02-23 16:28:05 +08:00
parent 7bb2dfba0c
commit 09ff7d4516
2 changed files with 26 additions and 24 deletions

View File

@@ -398,6 +398,19 @@ def inference_modelscope(
     else:
         device = "cpu"
+    if param_dict is not None and "decoding_model" in param_dict:
+        if param_dict["decoding_model"] == "fast":
+            decoding_ind = 0
+            decoding_mode = "model1"
+        elif param_dict["decoding_model"] == "normal":
+            decoding_ind = 0
+            decoding_mode = "model2"
+        elif param_dict["decoding_model"] == "offline":
+            decoding_ind = 1
+            decoding_mode = "model2"
+        else:
+            raise NotImplementedError("unsupported decoding model {}".format(param_dict["decoding_model"]))
     # 1. Set random-seed
     set_all_random_seed(seed)
@@ -440,18 +453,6 @@ def inference_modelscope(
             if isinstance(raw_inputs, torch.Tensor):
                 raw_inputs = raw_inputs.numpy()
             data_path_and_name_and_type = [raw_inputs, "speech", "waveform"]
-        if param_dict is not None and "decoding_model" in param_dict:
-            if param_dict["decoding_model"] == "fast":
-                speech2text.decoding_ind = 0
-                speech2text.decoding_mode = "model1"
-            elif param_dict["decoding_model"] == "normal":
-                speech2text.decoding_ind = 0
-                speech2text.decoding_mode = "model2"
-            elif param_dict["decoding_model"] == "offline":
-                speech2text.decoding_ind = 1
-                speech2text.decoding_mode = "model2"
-            else:
-                raise NotImplementedError("unsupported decoding model {}".format(param_dict["decoding_model"]))
         loader = ASRTask.build_streaming_iterator(
             data_path_and_name_and_type,
             dtype=dtype,

View File

@@ -398,6 +398,19 @@ def inference_modelscope(
     else:
         device = "cpu"
+    if param_dict is not None and "decoding_model" in param_dict:
+        if param_dict["decoding_model"] == "fast":
+            decoding_ind = 0
+            decoding_mode = "model1"
+        elif param_dict["decoding_model"] == "normal":
+            decoding_ind = 0
+            decoding_mode = "model2"
+        elif param_dict["decoding_model"] == "offline":
+            decoding_ind = 1
+            decoding_mode = "model2"
+        else:
+            raise NotImplementedError("unsupported decoding model {}".format(param_dict["decoding_model"]))
     # 1. Set random-seed
     set_all_random_seed(seed)
@@ -440,18 +453,6 @@ def inference_modelscope(
             if isinstance(raw_inputs, torch.Tensor):
                 raw_inputs = raw_inputs.numpy()
             data_path_and_name_and_type = [raw_inputs, "speech", "waveform"]
-        if param_dict is not None and "decoding_model" in param_dict:
-            if param_dict["decoding_model"] == "fast":
-                speech2text.decoding_ind = 0
-                speech2text.decoding_mode = "model1"
-            elif param_dict["decoding_model"] == "normal":
-                speech2text.decoding_ind = 0
-                speech2text.decoding_mode = "model2"
-            elif param_dict["decoding_model"] == "offline":
-                speech2text.decoding_ind = 1
-                speech2text.decoding_mode = "model2"
-            else:
-                raise NotImplementedError("unsupported decoding model {}".format(param_dict["decoding_model"]))
         loader = ASRTask.build_streaming_iterator(
            data_path_and_name_and_type,
            dtype=dtype,
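
For context, both files receive the same change: the mapping from param_dict["decoding_model"] to (decoding_ind, decoding_mode) now runs once in the setup path of inference_modelscope, as plain local variables, instead of mutating speech2text attributes in the request path right before the data loader is built. The mapping itself is unchanged; a minimal sketch of it is below, for illustration only (the dict decoding_model_map does not exist in the codebase):

    # Lookup equivalent to the added if/elif chain; any other value raises
    # NotImplementedError, exactly as in the diff.
    decoding_model_map = {
        "fast":    (0, "model1"),
        "normal":  (0, "model2"),
        "offline": (1, "model2"),
    }
    decoding_ind, decoding_mode = decoding_model_map[param_dict["decoding_model"]]

Callers are unaffected: they still pass, for example, param_dict={"decoding_model": "offline"}; only the point where the value is consumed has moved.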