From 3e9319263835bd018abe2dcd59e029603b714022 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=BB=B4=E7=9F=B3?= Date: Tue, 11 Jun 2024 11:52:26 +0800 Subject: [PATCH] english timestamp for vanilla paraformer --- .../paraformer/demo.py | 13 +++ funasr/auto/auto_model.py | 43 +++++++--- funasr/utils/timestamp_tools.py | 85 +++++++++++++++++++ 3 files changed, 128 insertions(+), 13 deletions(-) diff --git a/examples/industrial_data_pretraining/paraformer/demo.py b/examples/industrial_data_pretraining/paraformer/demo.py index eb7e72f74..647669f42 100644 --- a/examples/industrial_data_pretraining/paraformer/demo.py +++ b/examples/industrial_data_pretraining/paraformer/demo.py @@ -21,6 +21,19 @@ res = model.generate( print(res) +""" call english model like below for detailed timestamps +# choose english paraformer model first +# iic/speech_paraformer_asr-en-16k-vocab4199-pytorch +res = model.generate( + input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_en.wav", + cache={}, + pred_timestamp=True, + return_raw_text=True, + sentence_timestamp=True, + en_post_proc=True, +) +""" + """ can not use currently from funasr import AutoFrontend diff --git a/funasr/auto/auto_model.py b/funasr/auto/auto_model.py index 7b5a02fff..fb81608c6 100644 --- a/funasr/auto/auto_model.py +++ b/funasr/auto/auto_model.py @@ -19,6 +19,7 @@ from funasr.register import tables from funasr.utils.load_utils import load_bytes from funasr.download.file import download_from_url from funasr.utils.timestamp_tools import timestamp_sentence +from funasr.utils.timestamp_tools import timestamp_sentence_en from funasr.download.download_from_hub import download_model from funasr.utils.vad_utils import slice_padding_audio_samples from funasr.utils.vad_utils import merge_vad @@ -321,7 +322,7 @@ class AutoModel: input, input_len=input_len, model=self.vad_model, kwargs=self.vad_kwargs, **cfg ) end_vad = time.time() - + # FIX(gcf): concat the vad clips for sense vocie model for better 
aed if kwargs.get("merge_vad", False): for i in range(len(res)): @@ -513,24 +514,40 @@ class AutoModel: and 'iic/speech_seaco_paraformer_large_asr_nat-zh-cn-16k-common-vocab8404-pytorch'\ can predict timestamp, and speaker diarization relies on timestamps." ) - sentence_list = timestamp_sentence( - punc_res[0]["punc_array"], - result["timestamp"], - raw_text, - return_raw_text=return_raw_text, - ) + if kwargs.get("en_post_proc", False): + sentence_list = timestamp_sentence_en( + punc_res[0]["punc_array"], + result["timestamp"], + raw_text, + return_raw_text=return_raw_text, + ) + else: + sentence_list = timestamp_sentence( + punc_res[0]["punc_array"], + result["timestamp"], + raw_text, + return_raw_text=return_raw_text, + ) distribute_spk(sentence_list, sv_output) result["sentence_info"] = sentence_list elif kwargs.get("sentence_timestamp", False): if not len(result["text"].strip()): sentence_list = [] else: - sentence_list = timestamp_sentence( - punc_res[0]["punc_array"], - result["timestamp"], - raw_text, - return_raw_text=return_raw_text, - ) + if kwargs.get("en_post_proc", False): + sentence_list = timestamp_sentence_en( + punc_res[0]["punc_array"], + result["timestamp"], + raw_text, + return_raw_text=return_raw_text, + ) + else: + sentence_list = timestamp_sentence( + punc_res[0]["punc_array"], + result["timestamp"], + raw_text, + return_raw_text=return_raw_text, + ) result["sentence_info"] = sentence_list if "spk_embedding" in result: del result["spk_embedding"] diff --git a/funasr/utils/timestamp_tools.py b/funasr/utils/timestamp_tools.py index af61e5a8f..6abebe165 100644 --- a/funasr/utils/timestamp_tools.py +++ b/funasr/utils/timestamp_tools.py @@ -185,3 +185,88 @@ def timestamp_sentence( ts_list = [] sentence_start = sentence_end return res + + +def timestamp_sentence_en( + punc_id_list, timestamp_postprocessed, text_postprocessed, return_raw_text=False +): + punc_list = [",", ".", "?", ","] + res = [] + if text_postprocessed is None: + return res + if 
timestamp_postprocessed is None: + return res + if len(timestamp_postprocessed) == 0: + return res + if len(text_postprocessed) == 0: + return res + + if punc_id_list is None or len(punc_id_list) == 0: + res.append( + { + "text": text_postprocessed.split(), + "start": timestamp_postprocessed[0][0], + "end": timestamp_postprocessed[-1][1], + "timestamp": timestamp_postprocessed, + } + ) + return res + if len(punc_id_list) != len(timestamp_postprocessed): + logging.warning("length mismatch between punc and timestamp") + sentence_text = "" + sentence_text_seg = "" + ts_list = [] + sentence_start = timestamp_postprocessed[0][0] + sentence_end = timestamp_postprocessed[0][1] + texts = text_postprocessed.split() + punc_stamp_text_list = list( + zip_longest(punc_id_list, timestamp_postprocessed, texts, fillvalue=None) + ) + for punc_stamp_text in punc_stamp_text_list: + punc_id, timestamp, text = punc_stamp_text + # sentence_text += text if text is not None else '' + if text is not None: + if "a" <= text[0] <= "z" or "A" <= text[0] <= "Z": + sentence_text += " " + text + elif len(sentence_text) and ( + "a" <= sentence_text[-1] <= "z" or "A" <= sentence_text[-1] <= "Z" + ): + sentence_text += " " + text + else: + sentence_text += text + sentence_text_seg += text + " " + ts_list.append(timestamp) + + punc_id = int(punc_id) if punc_id is not None else 1 + sentence_end = timestamp[1] if timestamp is not None else sentence_end + sentence_text = sentence_text[1:] if sentence_text[0] == ' ' else sentence_text + + if punc_id > 1: + sentence_text += punc_list[punc_id - 2] + sentence_text_seg = ( + sentence_text_seg[:-1] if sentence_text_seg[-1] == " " else sentence_text_seg + ) + if return_raw_text: + res.append( + { + "text": sentence_text, + "start": sentence_start, + "end": sentence_end, + "timestamp": ts_list, + "raw_text": sentence_text_seg, + } + ) + else: + res.append( + { + "text": sentence_text, + "start": sentence_start, + "end": sentence_end, + "timestamp": ts_list, + } 
+ ) + sentence_text = "" + sentence_text_seg = "" + ts_list = [] + sentence_start = sentence_end + return res \ No newline at end of file