From 8cd93a1fc77e8a4a119f248ec0eb9018a74f728b Mon Sep 17 00:00:00 2001
From: "shixian.shi"
Date: Tue, 21 Mar 2023 19:41:40 +0800
Subject: [PATCH] asr_inference pipeline supports combining the tp model

---
 funasr/bin/asr_inference_paraformer.py | 24 +++++++++++++-----------
 funasr/bin/tp_inference.py             |  4 ++--
 2 files changed, 15 insertions(+), 13 deletions(-)

diff --git a/funasr/bin/asr_inference_paraformer.py b/funasr/bin/asr_inference_paraformer.py
index 7e159fd1c..2eeffcd91 100644
--- a/funasr/bin/asr_inference_paraformer.py
+++ b/funasr/bin/asr_inference_paraformer.py
@@ -756,16 +756,18 @@ def inference_modelscope(
                 key = keys[batch_id]
 
                 for n, result in zip(range(1, nbest + 1), result):
-                    # import pdb; pdb.set_trace()
                     text, token, token_int, hyp = result[0], result[1], result[2], result[3]
-                    time_stamp = None if len(result) < 5 else result[4]
+                    timestamp = None if len(result) < 5 else result[4]
                     # conduct timestamp prediction here
-                    if time_stamp is None and speechtext2timestamp:
+                    # timestamp inference requires the token length,
+                    # thus the following inference cannot be conducted in batch
+                    if timestamp is None and speechtext2timestamp:
                         ts_batch = {}
-                        ts_batch['speech'] = batch['speech'][batch_id].squeeze(0)
+                        ts_batch['speech'] = batch['speech'][batch_id].unsqueeze(0)
                         ts_batch['speech_lengths'] = torch.tensor([batch['speech_lengths'][batch_id]])
                         ts_batch['text_lengths'] = torch.tensor([len(token)])
-                        import pdb; pdb.set_trace()
+                        us_alphas, us_peaks = speechtext2timestamp(**ts_batch)
+                        ts_str, timestamp = ts_prediction_lfr6_standard(us_alphas[0], us_peaks[0], token, force_time_shift=-3.0)
                     # Create a directory: outdir/{n}best_recog
                     if writer is not None:
                         ibest_writer = writer[f"{n}best_recog"]
@@ -777,20 +779,20 @@ def inference_modelscope(
                         ibest_writer["rtf"][key] = rtf_cur
 
                     if text is not None:
-                        if use_timestamp and time_stamp is not None:
-                            postprocessed_result = postprocess_utils.sentence_postprocess(token, time_stamp)
+                        if use_timestamp and timestamp is not None:
+                            postprocessed_result = postprocess_utils.sentence_postprocess(token, timestamp)
                         else:
                             postprocessed_result = postprocess_utils.sentence_postprocess(token)
-                        time_stamp_postprocessed = ""
+                        timestamp_postprocessed = ""
                         if len(postprocessed_result) == 3:
-                            text_postprocessed, time_stamp_postprocessed, word_lists = postprocessed_result[0], \
+                            text_postprocessed, timestamp_postprocessed, word_lists = postprocessed_result[0], \
                                                                                        postprocessed_result[1], \
                                                                                        postprocessed_result[2]
                         else:
                             text_postprocessed, word_lists = postprocessed_result[0], postprocessed_result[1]
                         item = {'key': key, 'value': text_postprocessed}
-                        if time_stamp_postprocessed != "":
-                            item['time_stamp'] = time_stamp_postprocessed
+                        if timestamp_postprocessed != "":
+                            item['timestamp'] = timestamp_postprocessed
                         asr_result_list.append(item)
                         finish_count += 1
                         # asr_utils.print_progress(finish_count / file_count)
diff --git a/funasr/bin/tp_inference.py b/funasr/bin/tp_inference.py
index e374a227a..6360b17db 100644
--- a/funasr/bin/tp_inference.py
+++ b/funasr/bin/tp_inference.py
@@ -116,8 +116,8 @@ class SpeechText2Timestamp:
             enc = enc[0]
 
         # c. Forward Predictor
-        _, _, us_alphas, us_cif_peak = self.tp_model.calc_predictor_timestamp(enc, enc_len, text_lengths.to(self.device)+1)
-        return us_alphas, us_cif_peak
+        _, _, us_alphas, us_peaks = self.tp_model.calc_predictor_timestamp(enc, enc_len, text_lengths.to(self.device)+1)
+        return us_alphas, us_peaks
 
 
 def inference(
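Reviewer note: the substance of this patch is a per-utterance timestamp fallback in inference_modelscope(). When a decoding result arrives without timestamps and a tp (timestamp prediction) model is available, the pipeline now runs SpeechText2Timestamp on that single utterance and converts the resulting CIF peaks into per-token timestamps. The sketch below restates that call pattern in isolation. It is a minimal sketch, not code from the patch: the helper name predict_timestamp_for_utterance is hypothetical, and the import path for ts_prediction_lfr6_standard is an assumption, since only the bare call is visible in the diff.

    import torch

    # Assumed import path for the peak-to-timestamp converter used in the
    # patch; the diff shows the call site, not the import.
    from funasr.utils.timestamp_tools import ts_prediction_lfr6_standard

    def predict_timestamp_for_utterance(speechtext2timestamp, batch, batch_id, token):
        # The tp model needs the token length of the decoded hypothesis,
        # which is only known after ASR decoding, so this step runs once per
        # utterance rather than batched (per the comment added in the patch).
        ts_batch = {
            # Indexing batch['speech'] drops the batch dimension, so it is
            # restored with unsqueeze(0): (T, ...) -> (1, T, ...).
            'speech': batch['speech'][batch_id].unsqueeze(0),
            'speech_lengths': torch.tensor([batch['speech_lengths'][batch_id]]),
            'text_lengths': torch.tensor([len(token)]),
        }
        # SpeechText2Timestamp returns upsampled CIF alphas and peaks.
        us_alphas, us_peaks = speechtext2timestamp(**ts_batch)
        # Convert this utterance's peaks into per-token timestamps; the
        # force_time_shift value is taken verbatim from the patch.
        ts_str, timestamp = ts_prediction_lfr6_standard(
            us_alphas[0], us_peaks[0], token, force_time_shift=-3.0
        )
        return ts_str, timestamp

Two details worth flagging in review: replacing squeeze(0) with unsqueeze(0) is a genuine fix, since batch['speech'][batch_id] has already lost its batch dimension and squeeze(0) could not restore it; and the output key of inference_modelscope() is renamed from 'time_stamp' to 'timestamp', which downstream consumers of the result list will need to track.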