diff --git a/funasr/bin/tp_inference.py b/funasr/bin/tp_inference.py index b3f15d484..766f94f1c 100644 --- a/funasr/bin/tp_inference.py +++ b/funasr/bin/tp_inference.py @@ -100,9 +100,9 @@ def time_stamp_lfr6_advance(us_alphas, us_cif_peak, char_list): class SpeechText2Timestamp: def __init__( self, - tp_train_config: Union[Path, str] = None, - tp_model_file: Union[Path, str] = None, - tp_cmvn_file: Union[Path, str] = None, + timestamp_infer_config: Union[Path, str] = None, + timestamp_model_file: Union[Path, str] = None, + timestamp_cmvn_file: Union[Path, str] = None, device: str = "cpu", dtype: str = "float32", **kwargs, @@ -110,11 +110,11 @@ class SpeechText2Timestamp: assert check_argument_types() # 1. Build ASR model tp_model, tp_train_args = ASRTask.build_model_from_file( - tp_train_config, tp_model_file, device + timestamp_infer_config, timestamp_model_file, device ) frontend = None if tp_train_args.frontend is not None: - frontend = WavFrontend(cmvn_file=tp_cmvn_file, **tp_train_args.frontend_conf) + frontend = WavFrontend(cmvn_file=timestamp_cmvn_file, **tp_train_args.frontend_conf) logging.info("tp_model: {}".format(tp_model)) logging.info("tp_train_args: {}".format(tp_train_args)) @@ -178,9 +178,9 @@ def inference( ngpu: int, log_level: Union[int, str], data_path_and_name_and_type, - tp_train_config: Optional[str], - tp_model_file: Optional[str], - tp_cmvn_file: Optional[str] = None, + timestamp_infer_config: Optional[str], + timestamp_model_file: Optional[str], + timestamp_cmvn_file: Optional[str] = None, raw_inputs: Union[np.ndarray, torch.Tensor] = None, key_file: Optional[str] = None, allow_variable_data_keys: bool = False, @@ -194,9 +194,9 @@ def inference( batch_size=batch_size, ngpu=ngpu, log_level=log_level, - tp_train_config=tp_train_config, - tp_model_file=tp_model_file, - tp_cmvn_file=tp_cmvn_file, + timestamp_infer_config=timestamp_infer_config, + timestamp_model_file=timestamp_model_file, + timestamp_cmvn_file=timestamp_cmvn_file, key_file=key_file, allow_variable_data_keys=allow_variable_data_keys, output_dir=output_dir, @@ -213,9 +213,9 @@ def inference_modelscope( ngpu: int, log_level: Union[int, str], # data_path_and_name_and_type, - tp_train_config: Optional[str], - tp_model_file: Optional[str], - tp_cmvn_file: Optional[str] = None, + timestamp_infer_config: Optional[str], + timestamp_model_file: Optional[str], + timestamp_cmvn_file: Optional[str] = None, # raw_inputs: Union[np.ndarray, torch.Tensor] = None, key_file: Optional[str] = None, allow_variable_data_keys: bool = False, @@ -246,9 +246,9 @@ def inference_modelscope( # 2. Build speech2vadsegment speechtext2timestamp_kwargs = dict( - tp_train_config=tp_train_config, - tp_model_file=tp_model_file, - tp_cmvn_file=tp_cmvn_file, + timestamp_infer_config=timestamp_infer_config, + timestamp_model_file=timestamp_model_file, + timestamp_cmvn_file=timestamp_cmvn_file, device=device, dtype=dtype, ) @@ -365,17 +365,17 @@ def get_parser(): group = parser.add_argument_group("The model configuration related") group.add_argument( - "--tp_train_config", + "--timestamp_infer_config", type=str, help="VAD infer configuration", ) group.add_argument( - "--tp_model_file", + "--timestamp_model_file", type=str, help="VAD model parameter file", ) group.add_argument( - "--tp_cmvn_file", + "--timestamp_cmvn_file", type=str, help="Global cmvn file", ) diff --git a/funasr/bin/tp_inference_launch.py b/funasr/bin/tp_inference_launch.py index 903e041fe..dd76df61b 100644 --- a/funasr/bin/tp_inference_launch.py +++ b/funasr/bin/tp_inference_launch.py @@ -76,17 +76,17 @@ def get_parser(): group = parser.add_argument_group("The model configuration related") group.add_argument( - "--tp_train_config", + "--timestamp_infer_config", type=str, help="VAD infer configuration", ) group.add_argument( - "--tp_model_file", + "--timestamp_model_file", type=str, help="VAD model parameter file", ) group.add_argument( - "--tp_cmvn_file", + "--timestamp_cmvn_file", type=str, help="Global CMVN file", )