mirror of
https://github.com/modelscope/FunASR
synced 2025-09-15 14:48:36 +08:00
fix uniasr decoding bug
This commit is contained in:
parent
7bb2dfba0c
commit
09ff7d4516
@ -398,6 +398,19 @@ def inference_modelscope(
|
|||||||
else:
|
else:
|
||||||
device = "cpu"
|
device = "cpu"
|
||||||
|
|
||||||
|
if param_dict is not None and "decoding_model" in param_dict:
|
||||||
|
if param_dict["decoding_model"] == "fast":
|
||||||
|
decoding_ind = 0
|
||||||
|
decoding_mode = "model1"
|
||||||
|
elif param_dict["decoding_model"] == "normal":
|
||||||
|
decoding_ind = 0
|
||||||
|
decoding_mode = "model2"
|
||||||
|
elif param_dict["decoding_model"] == "offline":
|
||||||
|
decoding_ind = 1
|
||||||
|
decoding_mode = "model2"
|
||||||
|
else:
|
||||||
|
raise NotImplementedError("unsupported decoding model {}".format(param_dict["decoding_model"]))
|
||||||
|
|
||||||
# 1. Set random-seed
|
# 1. Set random-seed
|
||||||
set_all_random_seed(seed)
|
set_all_random_seed(seed)
|
||||||
|
|
||||||
@ -440,18 +453,6 @@ def inference_modelscope(
|
|||||||
if isinstance(raw_inputs, torch.Tensor):
|
if isinstance(raw_inputs, torch.Tensor):
|
||||||
raw_inputs = raw_inputs.numpy()
|
raw_inputs = raw_inputs.numpy()
|
||||||
data_path_and_name_and_type = [raw_inputs, "speech", "waveform"]
|
data_path_and_name_and_type = [raw_inputs, "speech", "waveform"]
|
||||||
if param_dict is not None and "decoding_model" in param_dict:
|
|
||||||
if param_dict["decoding_model"] == "fast":
|
|
||||||
speech2text.decoding_ind = 0
|
|
||||||
speech2text.decoding_mode = "model1"
|
|
||||||
elif param_dict["decoding_model"] == "normal":
|
|
||||||
speech2text.decoding_ind = 0
|
|
||||||
speech2text.decoding_mode = "model2"
|
|
||||||
elif param_dict["decoding_model"] == "offline":
|
|
||||||
speech2text.decoding_ind = 1
|
|
||||||
speech2text.decoding_mode = "model2"
|
|
||||||
else:
|
|
||||||
raise NotImplementedError("unsupported decoding model {}".format(param_dict["decoding_model"]))
|
|
||||||
loader = ASRTask.build_streaming_iterator(
|
loader = ASRTask.build_streaming_iterator(
|
||||||
data_path_and_name_and_type,
|
data_path_and_name_and_type,
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
|
|||||||
@ -398,6 +398,19 @@ def inference_modelscope(
|
|||||||
else:
|
else:
|
||||||
device = "cpu"
|
device = "cpu"
|
||||||
|
|
||||||
|
if param_dict is not None and "decoding_model" in param_dict:
|
||||||
|
if param_dict["decoding_model"] == "fast":
|
||||||
|
decoding_ind = 0
|
||||||
|
decoding_mode = "model1"
|
||||||
|
elif param_dict["decoding_model"] == "normal":
|
||||||
|
decoding_ind = 0
|
||||||
|
decoding_mode = "model2"
|
||||||
|
elif param_dict["decoding_model"] == "offline":
|
||||||
|
decoding_ind = 1
|
||||||
|
decoding_mode = "model2"
|
||||||
|
else:
|
||||||
|
raise NotImplementedError("unsupported decoding model {}".format(param_dict["decoding_model"]))
|
||||||
|
|
||||||
# 1. Set random-seed
|
# 1. Set random-seed
|
||||||
set_all_random_seed(seed)
|
set_all_random_seed(seed)
|
||||||
|
|
||||||
@ -440,18 +453,6 @@ def inference_modelscope(
|
|||||||
if isinstance(raw_inputs, torch.Tensor):
|
if isinstance(raw_inputs, torch.Tensor):
|
||||||
raw_inputs = raw_inputs.numpy()
|
raw_inputs = raw_inputs.numpy()
|
||||||
data_path_and_name_and_type = [raw_inputs, "speech", "waveform"]
|
data_path_and_name_and_type = [raw_inputs, "speech", "waveform"]
|
||||||
if param_dict is not None and "decoding_model" in param_dict:
|
|
||||||
if param_dict["decoding_model"] == "fast":
|
|
||||||
speech2text.decoding_ind = 0
|
|
||||||
speech2text.decoding_mode = "model1"
|
|
||||||
elif param_dict["decoding_model"] == "normal":
|
|
||||||
speech2text.decoding_ind = 0
|
|
||||||
speech2text.decoding_mode = "model2"
|
|
||||||
elif param_dict["decoding_model"] == "offline":
|
|
||||||
speech2text.decoding_ind = 1
|
|
||||||
speech2text.decoding_mode = "model2"
|
|
||||||
else:
|
|
||||||
raise NotImplementedError("unsupported decoding model {}".format(param_dict["decoding_model"]))
|
|
||||||
loader = ASRTask.build_streaming_iterator(
|
loader = ASRTask.build_streaming_iterator(
|
||||||
data_path_and_name_and_type,
|
data_path_and_name_and_type,
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user