diff --git a/.github/workflows/UnitTest.yml b/.github/workflows/UnitTest.yml index 3b0a1ee2e..19ad1f1f8 100644 --- a/.github/workflows/UnitTest.yml +++ b/.github/workflows/UnitTest.yml @@ -8,6 +8,7 @@ on: branches: - dev_wjm - dev_jy + - dev_wjm_infer jobs: build: diff --git a/funasr/bin/asr_inference_launch.py b/funasr/bin/asr_inference_launch.py index a56552dc6..539e82399 100644 --- a/funasr/bin/asr_inference_launch.py +++ b/funasr/bin/asr_inference_launch.py @@ -31,6 +31,7 @@ from funasr.bin.asr_infer import Speech2TextUniASR from funasr.bin.punc_infer import Text2Punc from funasr.bin.tp_infer import Speech2Timestamp from funasr.bin.vad_infer import Speech2VadSegment +from funasr.build_utils.build_streaming_iterator import build_streaming_iterator from funasr.fileio.datadir_writer import DatadirWriter from funasr.modules.beam_search.beam_search import Hypothesis from funasr.modules.subsampling import TooShortUttError @@ -142,18 +143,16 @@ def inference_asr( if isinstance(raw_inputs, torch.Tensor): raw_inputs = raw_inputs.numpy() data_path_and_name_and_type = [raw_inputs, "speech", "waveform"] - loader = ASRTask.build_streaming_iterator( - data_path_and_name_and_type, + loader = build_streaming_iterator( + task_name="asr", + preprocess_args=speech2text.asr_train_args, + data_path_and_name_and_type=data_path_and_name_and_type, dtype=dtype, fs=fs, mc=mc, batch_size=batch_size, key_file=key_file, num_workers=num_workers, - preprocess_fn=ASRTask.build_preprocess_fn(speech2text.asr_train_args, False), - collate_fn=ASRTask.build_collate_fn(speech2text.asr_train_args, False), - allow_variable_data_keys=allow_variable_data_keys, - inference=True, ) finish_count = 0 @@ -329,17 +328,15 @@ def inference_paraformer( if isinstance(raw_inputs, torch.Tensor): raw_inputs = raw_inputs.numpy() data_path_and_name_and_type = [raw_inputs, "speech", "waveform"] - loader = ASRTask.build_streaming_iterator( - data_path_and_name_and_type, + loader = build_streaming_iterator( + task_name="asr", + preprocess_args=speech2text.asr_train_args, + data_path_and_name_and_type=data_path_and_name_and_type, dtype=dtype, fs=fs, batch_size=batch_size, key_file=key_file, num_workers=num_workers, - preprocess_fn=ASRTask.build_preprocess_fn(speech2text.asr_train_args, False), - collate_fn=ASRTask.build_collate_fn(speech2text.asr_train_args, False), - allow_variable_data_keys=allow_variable_data_keys, - inference=True, ) if param_dict is not None: @@ -580,17 +577,15 @@ def inference_paraformer_vad_punc( if isinstance(raw_inputs, torch.Tensor): raw_inputs = raw_inputs.numpy() data_path_and_name_and_type = [raw_inputs, "speech", "waveform"] - loader = ASRTask.build_streaming_iterator( - data_path_and_name_and_type, + loader = build_streaming_iterator( + task_name="asr", + preprocess_args=None, + data_path_and_name_and_type=data_path_and_name_and_type, dtype=dtype, fs=fs, batch_size=1, key_file=key_file, num_workers=num_workers, - preprocess_fn=VADTask.build_preprocess_fn(speech2vadsegment.vad_infer_args, False), - collate_fn=VADTask.build_collate_fn(speech2vadsegment.vad_infer_args, False), - allow_variable_data_keys=allow_variable_data_keys, - inference=True, ) if param_dict is not None: @@ -1027,17 +1022,15 @@ def inference_uniasr( if isinstance(raw_inputs, torch.Tensor): raw_inputs = raw_inputs.numpy() data_path_and_name_and_type = [raw_inputs, "speech", "waveform"] - loader = ASRTask.build_streaming_iterator( - data_path_and_name_and_type, + loader = build_streaming_iterator( + task_name="asr", + preprocess_args=speech2text.asr_train_args, + data_path_and_name_and_type=data_path_and_name_and_type, dtype=dtype, fs=fs, batch_size=batch_size, key_file=key_file, num_workers=num_workers, - preprocess_fn=ASRTask.build_preprocess_fn(speech2text.asr_train_args, False), - collate_fn=ASRTask.build_collate_fn(speech2text.asr_train_args, False), - allow_variable_data_keys=allow_variable_data_keys, - inference=True, ) finish_count = 0 @@ -1182,18 +1175,16 @@ def inference_mfcca( if isinstance(raw_inputs, torch.Tensor): raw_inputs = raw_inputs.numpy() data_path_and_name_and_type = [raw_inputs, "speech", "waveform"] - loader = ASRTask.build_streaming_iterator( - data_path_and_name_and_type, + loader = build_streaming_iterator( + task_name="asr", + preprocess_args=speech2text.asr_train_args, + data_path_and_name_and_type=data_path_and_name_and_type, dtype=dtype, batch_size=batch_size, fs=fs, mc=True, key_file=key_file, num_workers=num_workers, - preprocess_fn=ASRTask.build_preprocess_fn(speech2text.asr_train_args, False), - collate_fn=ASRTask.build_collate_fn(speech2text.asr_train_args, False), - allow_variable_data_keys=allow_variable_data_keys, - inference=True, ) finish_count = 0 diff --git a/funasr/build_utils/build_streaming_iterator.py b/funasr/build_utils/build_streaming_iterator.py index 57cf8cfbc..732fe097d 100644 --- a/funasr/build_utils/build_streaming_iterator.py +++ b/funasr/build_utils/build_streaming_iterator.py @@ -24,7 +24,10 @@ def build_streaming_iterator( assert check_argument_types() # preprocess - preprocess_fn = build_preprocess(preprocess_args, train) + if preprocess_args is not None: + preprocess_fn = build_preprocess(preprocess_args, train) + else: + preprocess_fn = None # collate if task_name in ["punc", "lm"]: