update repo

嘉渊 2023-06-14 15:02:38 +08:00
parent 0b07e6af04
commit fee48e7a62
3 changed files with 26 additions and 31 deletions

View File

@@ -8,6 +8,7 @@ on:
     branches:
       - dev_wjm
       - dev_jy
+      - dev_wjm_infer
 jobs:
   build:

View File

@@ -31,6 +31,7 @@ from funasr.bin.asr_infer import Speech2TextUniASR
 from funasr.bin.punc_infer import Text2Punc
 from funasr.bin.tp_infer import Speech2Timestamp
 from funasr.bin.vad_infer import Speech2VadSegment
+from funasr.build_utils.build_streaming_iterator import build_streaming_iterator
 from funasr.fileio.datadir_writer import DatadirWriter
 from funasr.modules.beam_search.beam_search import Hypothesis
 from funasr.modules.subsampling import TooShortUttError
@@ -142,18 +143,16 @@ def inference_asr(
             if isinstance(raw_inputs, torch.Tensor):
                 raw_inputs = raw_inputs.numpy()
             data_path_and_name_and_type = [raw_inputs, "speech", "waveform"]
-        loader = ASRTask.build_streaming_iterator(
-            data_path_and_name_and_type,
+        loader = build_streaming_iterator(
+            task_name="asr",
+            preprocess_args=speech2text.asr_train_args,
+            data_path_and_name_and_type=data_path_and_name_and_type,
             dtype=dtype,
             fs=fs,
             mc=mc,
             batch_size=batch_size,
             key_file=key_file,
             num_workers=num_workers,
-            preprocess_fn=ASRTask.build_preprocess_fn(speech2text.asr_train_args, False),
-            collate_fn=ASRTask.build_collate_fn(speech2text.asr_train_args, False),
-            allow_variable_data_keys=allow_variable_data_keys,
-            inference=True,
         )

         finish_count = 0
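
Each of the inference_* entry points in this file gets the same treatment: the loader is no longer built through ASRTask.build_streaming_iterator with pre-built preprocess_fn/collate_fn arguments, but through the standalone build_streaming_iterator helper, which takes task_name and preprocess_args and derives those pieces itself (the VAD+punctuation path further down even passes preprocess_args=None). A hedged usage sketch of the new helper, assuming the omitted keyword arguments have defaults and that the loader yields (keys, batch) pairs consumable by the Speech2Text object; run_asr_on_waveform is an illustrative wrapper, not code from this repo:

import numpy as np

from funasr.build_utils.build_streaming_iterator import build_streaming_iterator


def run_asr_on_waveform(speech2text, raw_inputs: np.ndarray, fs: int = 16000):
    """Hypothetical driver mirroring how the refactored launcher builds its loader."""
    loader = build_streaming_iterator(
        task_name="asr",
        preprocess_args=speech2text.asr_train_args,
        data_path_and_name_and_type=[raw_inputs, "speech", "waveform"],
        dtype="float32",  # assumed dtype value
        fs=fs,            # assumed 16 kHz default
        batch_size=1,
    )
    results = []
    # Assumed iteration contract: (keys, batch) pairs, with batch unpacked into Speech2Text.
    for keys, batch in loader:
        results.append((keys, speech2text(**batch)))
    return results
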
@@ -329,17 +328,15 @@ def inference_paraformer(
             if isinstance(raw_inputs, torch.Tensor):
                 raw_inputs = raw_inputs.numpy()
             data_path_and_name_and_type = [raw_inputs, "speech", "waveform"]
-        loader = ASRTask.build_streaming_iterator(
-            data_path_and_name_and_type,
+        loader = build_streaming_iterator(
+            task_name="asr",
+            preprocess_args=speech2text.asr_train_args,
+            data_path_and_name_and_type=data_path_and_name_and_type,
             dtype=dtype,
             fs=fs,
             batch_size=batch_size,
             key_file=key_file,
             num_workers=num_workers,
-            preprocess_fn=ASRTask.build_preprocess_fn(speech2text.asr_train_args, False),
-            collate_fn=ASRTask.build_collate_fn(speech2text.asr_train_args, False),
-            allow_variable_data_keys=allow_variable_data_keys,
-            inference=True,
         )

         if param_dict is not None:
@@ -580,17 +577,15 @@ def inference_paraformer_vad_punc(
             if isinstance(raw_inputs, torch.Tensor):
                 raw_inputs = raw_inputs.numpy()
             data_path_and_name_and_type = [raw_inputs, "speech", "waveform"]
-        loader = ASRTask.build_streaming_iterator(
-            data_path_and_name_and_type,
+        loader = build_streaming_iterator(
+            task_name="asr",
+            preprocess_args=None,
+            data_path_and_name_and_type=data_path_and_name_and_type,
             dtype=dtype,
             fs=fs,
             batch_size=1,
             key_file=key_file,
             num_workers=num_workers,
-            preprocess_fn=VADTask.build_preprocess_fn(speech2vadsegment.vad_infer_args, False),
-            collate_fn=VADTask.build_collate_fn(speech2vadsegment.vad_infer_args, False),
-            allow_variable_data_keys=allow_variable_data_keys,
-            inference=True,
         )

         if param_dict is not None:
@@ -1027,17 +1022,15 @@ def inference_uniasr(
             if isinstance(raw_inputs, torch.Tensor):
                 raw_inputs = raw_inputs.numpy()
             data_path_and_name_and_type = [raw_inputs, "speech", "waveform"]
-        loader = ASRTask.build_streaming_iterator(
-            data_path_and_name_and_type,
+        loader = build_streaming_iterator(
+            task_name="asr",
+            preprocess_args=speech2text.asr_train_args,
+            data_path_and_name_and_type=data_path_and_name_and_type,
             dtype=dtype,
             fs=fs,
             batch_size=batch_size,
             key_file=key_file,
             num_workers=num_workers,
-            preprocess_fn=ASRTask.build_preprocess_fn(speech2text.asr_train_args, False),
-            collate_fn=ASRTask.build_collate_fn(speech2text.asr_train_args, False),
-            allow_variable_data_keys=allow_variable_data_keys,
-            inference=True,
         )

         finish_count = 0
@@ -1182,18 +1175,16 @@ def inference_mfcca(
             if isinstance(raw_inputs, torch.Tensor):
                 raw_inputs = raw_inputs.numpy()
             data_path_and_name_and_type = [raw_inputs, "speech", "waveform"]
-        loader = ASRTask.build_streaming_iterator(
-            data_path_and_name_and_type,
+        loader = build_streaming_iterator(
+            task_name="asr",
+            preprocess_args=speech2text.asr_train_args,
+            data_path_and_name_and_type=data_path_and_name_and_type,
             dtype=dtype,
             batch_size=batch_size,
             fs=fs,
             mc=True,
             key_file=key_file,
             num_workers=num_workers,
-            preprocess_fn=ASRTask.build_preprocess_fn(speech2text.asr_train_args, False),
-            collate_fn=ASRTask.build_collate_fn(speech2text.asr_train_args, False),
-            allow_variable_data_keys=allow_variable_data_keys,
-            inference=True,
         )

         finish_count = 0

View File

@@ -24,7 +24,10 @@ def build_streaming_iterator(
     assert check_argument_types()

     # preprocess
-    preprocess_fn = build_preprocess(preprocess_args, train)
+    if preprocess_args is not None:
+        preprocess_fn = build_preprocess(preprocess_args, train)
+    else:
+        preprocess_fn = None

     # collate
     if task_name in ["punc", "lm"]:
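
This guard exists because the inference_paraformer_vad_punc path above now passes preprocess_args=None: with no preprocessing config available, the helper should simply skip building a preprocess function rather than attempt to build one from a missing config. A minimal self-contained sketch of the pattern, not the actual funasr implementation (build_preprocess is stubbed out and resolve_preprocess_fn is an illustrative name):

from typing import Any, Callable, Optional


def build_preprocess(preprocess_args: Any, train: bool) -> Callable:
    # Stand-in for the real funasr preprocessing factory.
    return lambda uid, data: data


def resolve_preprocess_fn(preprocess_args: Optional[Any], train: bool = False) -> Optional[Callable]:
    # Mirrors the guard added above: a None config yields no preprocessing at all.
    if preprocess_args is not None:
        return build_preprocess(preprocess_args, train)
    return None


assert resolve_preprocess_fn(None) is None                     # VAD + punctuation path
assert callable(resolve_preprocess_fn({"frontend": "wav"}))    # hypothetical ASR-style config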