Update asr_inference_launch.py (#719)

Update BAT (transducer) inference for ModelScope: build Speech2TextTransducer via from_pretrained(model_tag=...), give the launcher arguments defaults, make the on-disk writer optional, and return recognition results from _forward as a list.
aky15 2023-07-10 12:48:50 +08:00 committed by GitHub
parent e97ea2974f
commit c5274e728a

@@ -1272,27 +1272,27 @@ def inference_transducer(
     nbest: int,
     num_workers: int,
     log_level: Union[int, str],
-    data_path_and_name_and_type: Sequence[Tuple[str, str, str]],
+    # data_path_and_name_and_type: Sequence[Tuple[str, str, str]],
     asr_train_config: Optional[str],
     asr_model_file: Optional[str],
-    cmvn_file: Optional[str],
-    beam_search_config: Optional[dict],
-    lm_train_config: Optional[str],
-    lm_file: Optional[str],
-    model_tag: Optional[str],
-    token_type: Optional[str],
-    bpemodel: Optional[str],
-    key_file: Optional[str],
-    allow_variable_data_keys: bool,
-    quantize_asr_model: Optional[bool],
-    quantize_modules: Optional[List[str]],
-    quantize_dtype: Optional[str],
-    streaming: Optional[bool],
-    simu_streaming: Optional[bool],
-    chunk_size: Optional[int],
-    left_context: Optional[int],
-    right_context: Optional[int],
-    display_partial_hypotheses: bool,
+    cmvn_file: Optional[str] = None,
+    beam_search_config: Optional[dict] = None,
+    lm_train_config: Optional[str] = None,
+    lm_file: Optional[str] = None,
+    model_tag: Optional[str] = None,
+    token_type: Optional[str] = None,
+    bpemodel: Optional[str] = None,
+    key_file: Optional[str] = None,
+    allow_variable_data_keys: bool = False,
+    quantize_asr_model: Optional[bool] = False,
+    quantize_modules: Optional[List[str]] = None,
+    quantize_dtype: Optional[str] = "float16",
+    streaming: Optional[bool] = False,
+    simu_streaming: Optional[bool] = False,
+    chunk_size: Optional[int] = 16,
+    left_context: Optional[int] = 16,
+    right_context: Optional[int] = 0,
+    display_partial_hypotheses: bool = False,
     **kwargs,
 ) -> None:
     """Transducer model inference.
@@ -1327,6 +1327,7 @@ def inference_transducer(
         right_context: Number of frames in right context AFTER subsampling.
         display_partial_hypotheses: Whether to display partial hypotheses.
     """
+    # assert check_argument_types()
     if batch_size > 1:
         raise NotImplementedError("batch decoding is not implemented")
@@ -1369,7 +1370,10 @@ def inference_transducer(
         left_context=left_context,
         right_context=right_context,
     )
-    speech2text = Speech2TextTransducer(**speech2text_kwargs)
+    speech2text = Speech2TextTransducer.from_pretrained(
+        model_tag=model_tag,
+        **speech2text_kwargs,
+    )

     def _forward(data_path_and_name_and_type,
                  raw_inputs: Union[np.ndarray, torch.Tensor] = None,
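
The switch from direct construction to Speech2TextTransducer.from_pretrained(model_tag=..., **speech2text_kwargs) is what lets a ModelScope model id drive inference here. The classmethod itself is defined elsewhere in the repo; the sketch below only illustrates the usual shape of such a constructor, assuming ModelScope's snapshot_download resolves the tag to a local model directory. The class name, file names, and merge logic are assumptions for illustration, not code from this commit:

import os
from modelscope.hub.snapshot_download import snapshot_download

class Speech2TextTransducerSketch:
    """Illustrative stand-in; not FunASR's actual Speech2TextTransducer."""

    def __init__(self, asr_train_config=None, asr_model_file=None, **kwargs):
        self.asr_train_config = asr_train_config
        self.asr_model_file = asr_model_file

    @classmethod
    def from_pretrained(cls, model_tag=None, **kwargs):
        if model_tag is not None:
            # Resolve the ModelScope model id to a local snapshot directory.
            model_dir = snapshot_download(model_tag)
            # File names are placeholders for whatever the model card ships;
            # explicit caller kwargs keep precedence over downloaded paths.
            kwargs.setdefault("asr_train_config", os.path.join(model_dir, "config.yaml"))
            kwargs.setdefault("asr_model_file", os.path.join(model_dir, "model.pb"))
        return cls(**kwargs)

Passing **speech2text_kwargs through unchanged, as the hunk above does, is consistent with this merge-style pattern: locally supplied paths win, and the tag only fills in what is missing.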
@@ -1388,9 +1392,14 @@ def inference_transducer(
             key_file=key_file,
             num_workers=num_workers,
         )
+        asr_result_list = []
+        if output_dir is not None:
+            writer = DatadirWriter(output_dir)
+        else:
+            writer = None

         # 4 .Start for-loop
-        with DatadirWriter(output_dir) as writer:
         for keys, batch in loader:
             assert isinstance(batch, dict), type(batch)
             assert all(isinstance(s, str) for s in keys), keys
@@ -1429,6 +1438,9 @@ def inference_transducer(
             key = keys[0]
             for n, (text, token, token_int, hyp) in zip(range(1, nbest + 1), results):
+                item = {'key': key, 'value': text}
+                asr_result_list.append(item)
+                if writer is not None:
                     ibest_writer = writer[f"{n}best_recog"]
                     ibest_writer["token"][key] = " ".join(token)
@@ -1438,6 +1450,8 @@ def inference_transducer(
                     if text is not None:
                         ibest_writer["text"][key] = text
+            logging.info("decoding, utt: {}, predictions: {}".format(key, text))
+        return asr_result_list

     return _forward
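
Taken together, the hunks change the launcher's contract: inference_transducer builds the model once and returns _forward, and _forward now always collects recognition results into a list of {'key': ..., 'value': ...} dicts, writing to disk only when an output_dir is given. A toy, self-contained sketch of that closure pattern; make_launcher, the tuple input, and the plain-file writer are illustrative stand-ins, not FunASR code:

import os
from typing import List, Optional, Sequence, Tuple

def make_launcher(output_dir: Optional[str] = None):
    # Expensive one-time setup (model loading) would happen here.
    def _forward(utterances: Sequence[Tuple[str, str]]) -> List[dict]:
        results = []
        writer = None
        if output_dir is not None:
            os.makedirs(output_dir, exist_ok=True)
            writer = open(os.path.join(output_dir, "text"), "w")
        for key, text in utterances:
            item = {"key": key, "value": text}
            results.append(item)        # always collected, mirroring asr_result_list
            if writer is not None:      # on-disk output is now opt-in
                writer.write(f"{key} {text}\n")
        if writer is not None:
            writer.close()
        return results
    return _forward

forward = make_launcher()  # no output_dir: results are only returned
print(forward([("utt1", "hello world")]))  # [{'key': 'utt1', 'value': 'hello world'}]

One consequence visible above: the with DatadirWriter(...) context manager is replaced by plain conditional construction, so any cleanup the context manager performed must now happen elsewhere; the sketch closes its writer explicitly for that reason.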