Update asr_inference_launch.py (#719)

update bat infer for modelscope
This commit is contained in:
aky15 2023-07-10 12:48:50 +08:00 committed by GitHub
parent e97ea2974f
commit c5274e728a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -1272,27 +1272,27 @@ def inference_transducer(
nbest: int, nbest: int,
num_workers: int, num_workers: int,
log_level: Union[int, str], log_level: Union[int, str],
data_path_and_name_and_type: Sequence[Tuple[str, str, str]], # data_path_and_name_and_type: Sequence[Tuple[str, str, str]],
asr_train_config: Optional[str], asr_train_config: Optional[str],
asr_model_file: Optional[str], asr_model_file: Optional[str],
cmvn_file: Optional[str], cmvn_file: Optional[str] = None,
beam_search_config: Optional[dict], beam_search_config: Optional[dict] = None,
lm_train_config: Optional[str], lm_train_config: Optional[str] = None,
lm_file: Optional[str], lm_file: Optional[str] = None,
model_tag: Optional[str], model_tag: Optional[str] = None,
token_type: Optional[str], token_type: Optional[str] = None,
bpemodel: Optional[str], bpemodel: Optional[str] = None,
key_file: Optional[str], key_file: Optional[str] = None,
allow_variable_data_keys: bool, allow_variable_data_keys: bool = False,
quantize_asr_model: Optional[bool], quantize_asr_model: Optional[bool] = False,
quantize_modules: Optional[List[str]], quantize_modules: Optional[List[str]] = None,
quantize_dtype: Optional[str], quantize_dtype: Optional[str] = "float16",
streaming: Optional[bool], streaming: Optional[bool] = False,
simu_streaming: Optional[bool], simu_streaming: Optional[bool] = False,
chunk_size: Optional[int], chunk_size: Optional[int] = 16,
left_context: Optional[int], left_context: Optional[int] = 16,
right_context: Optional[int], right_context: Optional[int] = 0,
display_partial_hypotheses: bool, display_partial_hypotheses: bool = False,
**kwargs, **kwargs,
) -> None: ) -> None:
"""Transducer model inference. """Transducer model inference.
@ -1327,6 +1327,7 @@ def inference_transducer(
right_context: Number of frames in right context AFTER subsampling. right_context: Number of frames in right context AFTER subsampling.
display_partial_hypotheses: Whether to display partial hypotheses. display_partial_hypotheses: Whether to display partial hypotheses.
""" """
# assert check_argument_types()
if batch_size > 1: if batch_size > 1:
raise NotImplementedError("batch decoding is not implemented") raise NotImplementedError("batch decoding is not implemented")
@ -1369,7 +1370,10 @@ def inference_transducer(
left_context=left_context, left_context=left_context,
right_context=right_context, right_context=right_context,
) )
speech2text = Speech2TextTransducer(**speech2text_kwargs) speech2text = Speech2TextTransducer.from_pretrained(
model_tag=model_tag,
**speech2text_kwargs,
)
def _forward(data_path_and_name_and_type, def _forward(data_path_and_name_and_type,
raw_inputs: Union[np.ndarray, torch.Tensor] = None, raw_inputs: Union[np.ndarray, torch.Tensor] = None,
@ -1388,9 +1392,14 @@ def inference_transducer(
key_file=key_file, key_file=key_file,
num_workers=num_workers, num_workers=num_workers,
) )
asr_result_list = []
if output_dir is not None:
writer = DatadirWriter(output_dir)
else:
writer = None
# 4 .Start for-loop # 4 .Start for-loop
with DatadirWriter(output_dir) as writer:
for keys, batch in loader: for keys, batch in loader:
assert isinstance(batch, dict), type(batch) assert isinstance(batch, dict), type(batch)
assert all(isinstance(s, str) for s in keys), keys assert all(isinstance(s, str) for s in keys), keys
@ -1429,6 +1438,9 @@ def inference_transducer(
key = keys[0] key = keys[0]
for n, (text, token, token_int, hyp) in zip(range(1, nbest + 1), results): for n, (text, token, token_int, hyp) in zip(range(1, nbest + 1), results):
item = {'key': key, 'value': text}
asr_result_list.append(item)
if writer is not None:
ibest_writer = writer[f"{n}best_recog"] ibest_writer = writer[f"{n}best_recog"]
ibest_writer["token"][key] = " ".join(token) ibest_writer["token"][key] = " ".join(token)
@ -1438,6 +1450,8 @@ def inference_transducer(
if text is not None: if text is not None:
ibest_writer["text"][key] = text ibest_writer["text"][key] = text
logging.info("decoding, utt: {}, predictions: {}".format(key, text))
return asr_result_list
return _forward return _forward