update repo

嘉渊 2023-06-14 15:02:38 +08:00
parent 0b07e6af04
commit fee48e7a62
3 changed files with 26 additions and 31 deletions

View File

@@ -8,6 +8,7 @@ on:
     branches:
       - dev_wjm
       - dev_jy
+      - dev_wjm_infer
 jobs:
   build:

View File

@@ -31,6 +31,7 @@ from funasr.bin.asr_infer import Speech2TextUniASR
 from funasr.bin.punc_infer import Text2Punc
 from funasr.bin.tp_infer import Speech2Timestamp
 from funasr.bin.vad_infer import Speech2VadSegment
+from funasr.build_utils.build_streaming_iterator import build_streaming_iterator
 from funasr.fileio.datadir_writer import DatadirWriter
 from funasr.modules.beam_search.beam_search import Hypothesis
 from funasr.modules.subsampling import TooShortUttError
@@ -142,18 +143,16 @@ def inference_asr(
     if isinstance(raw_inputs, torch.Tensor):
         raw_inputs = raw_inputs.numpy()
     data_path_and_name_and_type = [raw_inputs, "speech", "waveform"]
-    loader = ASRTask.build_streaming_iterator(
-        data_path_and_name_and_type,
+    loader = build_streaming_iterator(
+        task_name="asr",
+        preprocess_args=speech2text.asr_train_args,
+        data_path_and_name_and_type=data_path_and_name_and_type,
         dtype=dtype,
         fs=fs,
         mc=mc,
         batch_size=batch_size,
         key_file=key_file,
         num_workers=num_workers,
-        preprocess_fn=ASRTask.build_preprocess_fn(speech2text.asr_train_args, False),
-        collate_fn=ASRTask.build_collate_fn(speech2text.asr_train_args, False),
-        allow_variable_data_keys=allow_variable_data_keys,
-        inference=True,
     )
     finish_count = 0
@@ -329,17 +328,15 @@ def inference_paraformer(
     if isinstance(raw_inputs, torch.Tensor):
         raw_inputs = raw_inputs.numpy()
     data_path_and_name_and_type = [raw_inputs, "speech", "waveform"]
-    loader = ASRTask.build_streaming_iterator(
-        data_path_and_name_and_type,
+    loader = build_streaming_iterator(
+        task_name="asr",
+        preprocess_args=speech2text.asr_train_args,
+        data_path_and_name_and_type=data_path_and_name_and_type,
         dtype=dtype,
         fs=fs,
         batch_size=batch_size,
         key_file=key_file,
         num_workers=num_workers,
-        preprocess_fn=ASRTask.build_preprocess_fn(speech2text.asr_train_args, False),
-        collate_fn=ASRTask.build_collate_fn(speech2text.asr_train_args, False),
-        allow_variable_data_keys=allow_variable_data_keys,
-        inference=True,
     )
     if param_dict is not None:
@@ -580,17 +577,15 @@ def inference_paraformer_vad_punc(
     if isinstance(raw_inputs, torch.Tensor):
         raw_inputs = raw_inputs.numpy()
     data_path_and_name_and_type = [raw_inputs, "speech", "waveform"]
-    loader = ASRTask.build_streaming_iterator(
-        data_path_and_name_and_type,
+    loader = build_streaming_iterator(
+        task_name="asr",
+        preprocess_args=None,
+        data_path_and_name_and_type=data_path_and_name_and_type,
         dtype=dtype,
         fs=fs,
         batch_size=1,
         key_file=key_file,
         num_workers=num_workers,
-        preprocess_fn=VADTask.build_preprocess_fn(speech2vadsegment.vad_infer_args, False),
-        collate_fn=VADTask.build_collate_fn(speech2vadsegment.vad_infer_args, False),
-        allow_variable_data_keys=allow_variable_data_keys,
-        inference=True,
     )
     if param_dict is not None:
@@ -1027,17 +1022,15 @@ def inference_uniasr(
     if isinstance(raw_inputs, torch.Tensor):
         raw_inputs = raw_inputs.numpy()
     data_path_and_name_and_type = [raw_inputs, "speech", "waveform"]
-    loader = ASRTask.build_streaming_iterator(
-        data_path_and_name_and_type,
+    loader = build_streaming_iterator(
+        task_name="asr",
+        preprocess_args=speech2text.asr_train_args,
+        data_path_and_name_and_type=data_path_and_name_and_type,
         dtype=dtype,
         fs=fs,
         batch_size=batch_size,
         key_file=key_file,
         num_workers=num_workers,
-        preprocess_fn=ASRTask.build_preprocess_fn(speech2text.asr_train_args, False),
-        collate_fn=ASRTask.build_collate_fn(speech2text.asr_train_args, False),
-        allow_variable_data_keys=allow_variable_data_keys,
-        inference=True,
     )
     finish_count = 0
@@ -1182,18 +1175,16 @@ def inference_mfcca(
     if isinstance(raw_inputs, torch.Tensor):
         raw_inputs = raw_inputs.numpy()
     data_path_and_name_and_type = [raw_inputs, "speech", "waveform"]
-    loader = ASRTask.build_streaming_iterator(
-        data_path_and_name_and_type,
+    loader = build_streaming_iterator(
+        task_name="asr",
+        preprocess_args=speech2text.asr_train_args,
+        data_path_and_name_and_type=data_path_and_name_and_type,
        dtype=dtype,
         batch_size=batch_size,
         fs=fs,
         mc=True,
         key_file=key_file,
         num_workers=num_workers,
-        preprocess_fn=ASRTask.build_preprocess_fn(speech2text.asr_train_args, False),
-        collate_fn=ASRTask.build_collate_fn(speech2text.asr_train_args, False),
-        allow_variable_data_keys=allow_variable_data_keys,
-        inference=True,
     )
     finish_count = 0
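
Taken together, the five hunks above swap the task-bound ASRTask.build_streaming_iterator call for the shared build_streaming_iterator helper imported at the top of the file; the helper receives task_name and preprocess_args and builds the preprocess and collate functions itself, so the per-call preprocess_fn, collate_fn, allow_variable_data_keys, and inference arguments disappear. A minimal usage sketch of the new call pattern for a raw waveform input, assuming only the keyword arguments visible in these hunks (the sample values and the (keys, batch) iteration are illustrative, not part of this commit):

# Illustrative sketch only: argument names follow the hunks above, values are made up.
import numpy as np
from funasr.build_utils.build_streaming_iterator import build_streaming_iterator

raw_inputs = np.zeros(16000, dtype="float32")    # one second of 16 kHz audio
loader = build_streaming_iterator(
    task_name="asr",
    preprocess_args=None,                        # or a loaded asr_train_args, as in inference_asr
    data_path_and_name_and_type=[raw_inputs, "speech", "waveform"],
    dtype="float32",
    fs=16000,
    batch_size=1,
    key_file=None,
    num_workers=1,
)
for keys, batch in loader:                       # assumed iteration pattern of the returned loader
    print(keys, {k: getattr(v, "shape", v) for k, v in batch.items()})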

View File

@@ -24,7 +24,10 @@ def build_streaming_iterator(
     assert check_argument_types()
     # preprocess
-    preprocess_fn = build_preprocess(preprocess_args, train)
+    if preprocess_args is not None:
+        preprocess_fn = build_preprocess(preprocess_args, train)
+    else:
+        preprocess_fn = None
     # collate
     if task_name in ["punc", "lm"]:
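
The new guard covers call sites that have no training configuration to derive a preprocessor from: inference_paraformer_vad_punc now passes preprocess_args=None, while the other launchers pass speech2text.asr_train_args. A minimal standalone sketch of the same pattern, assuming the build_preprocess(args, train) signature shown in the hunk (the helper name _maybe_build_preprocess is hypothetical):

# Hypothetical standalone version of the guard added above.
def _maybe_build_preprocess(preprocess_args, train=False):
    # Only construct a preprocessor when configuration is available;
    # callers such as inference_paraformer_vad_punc pass preprocess_args=None
    # and get no per-sample preprocessing.
    if preprocess_args is not None:
        return build_preprocess(preprocess_args, train)
    return None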