Update asr_inference_launch.py (#719)

Update batch inference for ModelScope: give inference_transducer keyword defaults, load the model via Speech2TextTransducer.from_pretrained (so a hub model_tag can be used), make the DatadirWriter optional, and return decoding results as a list of {'key', 'value'} dicts.
aky15 2023-07-10 12:48:50 +08:00 committed by GitHub
parent e97ea2974f
commit c5274e728a

@@ -1272,27 +1272,27 @@ def inference_transducer(
     nbest: int,
     num_workers: int,
     log_level: Union[int, str],
-    data_path_and_name_and_type: Sequence[Tuple[str, str, str]],
+    # data_path_and_name_and_type: Sequence[Tuple[str, str, str]],
     asr_train_config: Optional[str],
     asr_model_file: Optional[str],
-    cmvn_file: Optional[str],
-    beam_search_config: Optional[dict],
-    lm_train_config: Optional[str],
-    lm_file: Optional[str],
-    model_tag: Optional[str],
-    token_type: Optional[str],
-    bpemodel: Optional[str],
-    key_file: Optional[str],
-    allow_variable_data_keys: bool,
-    quantize_asr_model: Optional[bool],
-    quantize_modules: Optional[List[str]],
-    quantize_dtype: Optional[str],
-    streaming: Optional[bool],
-    simu_streaming: Optional[bool],
-    chunk_size: Optional[int],
-    left_context: Optional[int],
-    right_context: Optional[int],
-    display_partial_hypotheses: bool,
+    cmvn_file: Optional[str] = None,
+    beam_search_config: Optional[dict] = None,
+    lm_train_config: Optional[str] = None,
+    lm_file: Optional[str] = None,
+    model_tag: Optional[str] = None,
+    token_type: Optional[str] = None,
+    bpemodel: Optional[str] = None,
+    key_file: Optional[str] = None,
+    allow_variable_data_keys: bool = False,
+    quantize_asr_model: Optional[bool] = False,
+    quantize_modules: Optional[List[str]] = None,
+    quantize_dtype: Optional[str] = "float16",
+    streaming: Optional[bool] = False,
+    simu_streaming: Optional[bool] = False,
+    chunk_size: Optional[int] = 16,
+    left_context: Optional[int] = 16,
+    right_context: Optional[int] = 0,
+    display_partial_hypotheses: bool = False,
     **kwargs,
 ) -> None:
     """Transducer model inference.
@@ -1327,6 +1327,7 @@ def inference_transducer(
         right_context: Number of frames in right context AFTER subsampling.
         display_partial_hypotheses: Whether to display partial hypotheses.
     """
+    # assert check_argument_types()
     if batch_size > 1:
         raise NotImplementedError("batch decoding is not implemented")
@@ -1369,7 +1370,10 @@ def inference_transducer(
         left_context=left_context,
         right_context=right_context,
     )
-    speech2text = Speech2TextTransducer(**speech2text_kwargs)
+    speech2text = Speech2TextTransducer.from_pretrained(
+        model_tag=model_tag,
+        **speech2text_kwargs,
+    )
 
     def _forward(data_path_and_name_and_type,
                  raw_inputs: Union[np.ndarray, torch.Tensor] = None,
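
Model construction now goes through Speech2TextTransducer.from_pretrained, with model_tag forwarded so a ModelScope hub identifier can stand in for local config/weight files. A minimal sketch of that pattern, assuming a hypothetical _resolve_model_tag helper (FunASR's real implementation will differ):

from typing import Any, Dict, Optional

def _resolve_model_tag(model_tag: str) -> Dict[str, Any]:
    # Hypothetical helper: in FunASR/ModelScope this would download the
    # model for `model_tag` and return local config/weight paths. Stubbed
    # here so the sketch runs standalone.
    return {
        "asr_train_config": f"/cache/{model_tag}/config.yaml",
        "asr_model_file": f"/cache/{model_tag}/model.pb",
    }

class Speech2TextSketch:
    # Stand-in for Speech2TextTransducer, for illustration only.
    def __init__(self, asr_train_config: Optional[str] = None,
                 asr_model_file: Optional[str] = None, **kwargs: Any):
        self.asr_train_config = asr_train_config
        self.asr_model_file = asr_model_file

    @classmethod
    def from_pretrained(cls, model_tag: Optional[str] = None,
                        **kwargs: Any) -> "Speech2TextSketch":
        # A hub tag, when given, overrides locally supplied file paths;
        # otherwise construction falls back to the plain constructor.
        if model_tag is not None:
            kwargs.update(_resolve_model_tag(model_tag))
        return cls(**kwargs)

s2t = Speech2TextSketch.from_pretrained(model_tag="demo/asr-transducer")
print(s2t.asr_model_file)  # /cache/demo/asr-transducer/model.pb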
@@ -1388,47 +1392,55 @@ def inference_transducer(
             key_file=key_file,
             num_workers=num_workers,
         )
+        asr_result_list = []
+        if output_dir is not None:
+            writer = DatadirWriter(output_dir)
+        else:
+            writer = None
         # 4 .Start for-loop
-        with DatadirWriter(output_dir) as writer:
-            for keys, batch in loader:
-                assert isinstance(batch, dict), type(batch)
-                assert all(isinstance(s, str) for s in keys), keys
+        for keys, batch in loader:
+            assert isinstance(batch, dict), type(batch)
+            assert all(isinstance(s, str) for s in keys), keys
 
-                _bs = len(next(iter(batch.values())))
-                assert len(keys) == _bs, f"{len(keys)} != {_bs}"
-                batch = {k: v[0] for k, v in batch.items() if not k.endswith("_lengths")}
-                assert len(batch.keys()) == 1
+            _bs = len(next(iter(batch.values())))
+            assert len(keys) == _bs, f"{len(keys)} != {_bs}"
+            batch = {k: v[0] for k, v in batch.items() if not k.endswith("_lengths")}
+            assert len(batch.keys()) == 1
 
-                try:
-                    if speech2text.streaming:
-                        speech = batch["speech"]
-                        _steps = len(speech) // speech2text._ctx
-                        _end = 0
-                        for i in range(_steps):
-                            _end = (i + 1) * speech2text._ctx
-                            speech2text.streaming_decode(
-                                speech[i * speech2text._ctx: _end], is_final=False
-                            )
-                        final_hyps = speech2text.streaming_decode(
-                            speech[_end: len(speech)], is_final=True
-                        )
-                    elif speech2text.simu_streaming:
-                        final_hyps = speech2text.simu_streaming_decode(**batch)
-                    else:
-                        final_hyps = speech2text(**batch)
+            try:
+                if speech2text.streaming:
+                    speech = batch["speech"]
+                    _steps = len(speech) // speech2text._ctx
+                    _end = 0
+                    for i in range(_steps):
+                        _end = (i + 1) * speech2text._ctx
+                        speech2text.streaming_decode(
+                            speech[i * speech2text._ctx: _end], is_final=False
+                        )
+                    final_hyps = speech2text.streaming_decode(
+                        speech[_end: len(speech)], is_final=True
+                    )
+                elif speech2text.simu_streaming:
+                    final_hyps = speech2text.simu_streaming_decode(**batch)
+                else:
+                    final_hyps = speech2text(**batch)
 
-                    results = speech2text.hypotheses_to_results(final_hyps)
-                except TooShortUttError as e:
-                    logging.warning(f"Utterance {keys} {e}")
-                    hyp = Hypothesis(score=0.0, yseq=[], dec_state=None)
-                    results = [[" ", ["<space>"], [2], hyp]] * nbest
+                results = speech2text.hypotheses_to_results(final_hyps)
+            except TooShortUttError as e:
+                logging.warning(f"Utterance {keys} {e}")
+                hyp = Hypothesis(score=0.0, yseq=[], dec_state=None)
+                results = [[" ", ["<space>"], [2], hyp]] * nbest
 
-                key = keys[0]
-                for n, (text, token, token_int, hyp) in zip(range(1, nbest + 1), results):
+            key = keys[0]
+            for n, (text, token, token_int, hyp) in zip(range(1, nbest + 1), results):
+                item = {'key': key, 'value': text}
+                asr_result_list.append(item)
+                if writer is not None:
                     ibest_writer = writer[f"{n}best_recog"]
                     ibest_writer["token"][key] = " ".join(token)
@@ -1438,6 +1450,8 @@ def inference_transducer(
                     if text is not None:
                         ibest_writer["text"][key] = text
+                logging.info("decoding, utt: {}, predictions: {}".format(key, text))
+        return asr_result_list
 
     return _forward
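
Taken together, the launcher keeps its build-once, call-many shape: inference_transducer does the heavy setup and returns the _forward closure, which now yields the result list on each call instead of only writing files. A toy analogue of that closure pattern (names and return values are illustrative, not FunASR's API):

import numpy as np

def build_decoder(model_file: str):
    # Expensive setup happens once here (model loading in the real code).
    model_name = model_file

    def _forward(raw_inputs: np.ndarray):
        # Per-call decoding; the real code returns [{'key', 'value'}, ...].
        n = len(raw_inputs)
        return [{"key": "demo", "value": f"{n} samples via {model_name}"}]

    return _forward

run = build_decoder("model.pb")
print(run(np.zeros(16000, dtype=np.float32)))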