From 509bc889037a42204ffff9a1c014560c69f676ad Mon Sep 17 00:00:00 2001 From: "shixian.shi" Date: Wed, 21 Feb 2024 15:07:19 +0800 Subject: [PATCH] update raw_text related funcs --- .../seaco_paraformer/demo.py | 7 ++-- funasr/auto/auto_model.py | 34 +++++++++++-------- funasr/models/paraformer/model.py | 1 - funasr/utils/timestamp_tools.py | 25 +++++++++----- 4 files changed, 40 insertions(+), 27 deletions(-) diff --git a/examples/industrial_data_pretraining/seaco_paraformer/demo.py b/examples/industrial_data_pretraining/seaco_paraformer/demo.py index 804acddb5..a44c649ae 100644 --- a/examples/industrial_data_pretraining/seaco_paraformer/demo.py +++ b/examples/industrial_data_pretraining/seaco_paraformer/demo.py @@ -11,15 +11,16 @@ model = AutoModel(model="iic/speech_seaco_paraformer_large_asr_nat-zh-cn-16k-com vad_model_revision="v2.0.4", punc_model="damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch", punc_model_revision="v2.0.4", - spk_model="damo/speech_campplus_sv_zh-cn_16k-common", - spk_model_revision="v2.0.2", + # spk_model="damo/speech_campplus_sv_zh-cn_16k-common", + # spk_model_revision="v2.0.2", ) # example1 res = model.generate(input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav", hotword='达摩院 魔搭', - # preset_spk_num=2, + # return_raw_text=True, # return raw text recognition results splited by space of equal length with timestamp + # preset_spk_num=2, # preset speaker num for speaker cluster model # sentence_timestamp=True, # return sentence level information when spk_model is not given ) print(res) diff --git a/funasr/auto/auto_model.py b/funasr/auto/auto_model.py index 78e47ccd4..e5faa2aaa 100644 --- a/funasr/auto/auto_model.py +++ b/funasr/auto/auto_model.py @@ -379,12 +379,14 @@ class AutoModel: result[k] = restored_data[j][k] else: result[k] += restored_data[j][k] - + + return_raw_text = kwargs.get('return_raw_text', False) # step.3 compute punc model if self.punc_model is not None: self.punc_kwargs.update(cfg) punc_res = self.inference(result["text"], model=self.punc_model, kwargs=self.punc_kwargs, disable_pbar=True, **cfg) raw_text = copy.copy(result["text"]) + if return_raw_text: result['raw_text'] = raw_text result["text"] = punc_res[0]["text"] else: raw_text = None @@ -403,26 +405,28 @@ class AutoModel: for res, vadsegment in zip(restored_data, vadsegments): if 'timestamp' not in res: logging.error("Only 'iic/speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-pytorch' \ - and 'iic/speech_seaco_paraformer_large_asr_nat-zh-cn-16k-common-vocab8404-pytorch'\ - can predict timestamp, and speaker diarization relies on timestamps.") - sentence_list.append({"start": vadsegment[0],\ - "end": vadsegment[1], - "sentence": res['text'], - "timestamp": res['timestamp']}) + and 'iic/speech_seaco_paraformer_large_asr_nat-zh-cn-16k-common-vocab8404-pytorch'\ + can predict timestamp, and speaker diarization relies on timestamps.") + sentence_list.append({"start": vadsegment[0], + "end": vadsegment[1], + "sentence": res['text'], + "timestamp": res['timestamp']}) elif self.spk_mode == 'punc_segment': if 'timestamp' not in result: logging.error("Only 'iic/speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-pytorch' \ - and 'iic/speech_seaco_paraformer_large_asr_nat-zh-cn-16k-common-vocab8404-pytorch'\ - can predict timestamp, and speaker diarization relies on timestamps.") - sentence_list = timestamp_sentence(punc_res[0]['punc_array'], \ - result['timestamp'], \ - raw_text) + and 'iic/speech_seaco_paraformer_large_asr_nat-zh-cn-16k-common-vocab8404-pytorch'\ + can predict timestamp, and speaker diarization relies on timestamps.") + sentence_list = timestamp_sentence(punc_res[0]['punc_array'], + result['timestamp'], + raw_text, + return_raw_text=return_raw_text) distribute_spk(sentence_list, sv_output) result['sentence_info'] = sentence_list elif kwargs.get("sentence_timestamp", False): - sentence_list = timestamp_sentence(punc_res[0]['punc_array'], \ - result['timestamp'], \ - raw_text) + sentence_list = timestamp_sentence(punc_res[0]['punc_array'], + result['timestamp'], + raw_text, + return_raw_text=return_raw_text) result['sentence_info'] = sentence_list if "spk_embedding" in result: del result['spk_embedding'] diff --git a/funasr/models/paraformer/model.py b/funasr/models/paraformer/model.py index cf31cdb8b..729b8f500 100644 --- a/funasr/models/paraformer/model.py +++ b/funasr/models/paraformer/model.py @@ -537,7 +537,6 @@ class Paraformer(torch.nn.Module): result_i = {"key": key[i], "text": text_postprocessed} - if ibest_writer is not None: ibest_writer["token"][key[i]] = " ".join(token) # ibest_writer["text"][key[i]] = text diff --git a/funasr/utils/timestamp_tools.py b/funasr/utils/timestamp_tools.py index 63f179a08..32f0f8488 100644 --- a/funasr/utils/timestamp_tools.py +++ b/funasr/utils/timestamp_tools.py @@ -98,7 +98,7 @@ def ts_prediction_lfr6_standard(us_alphas, return res_txt, res -def timestamp_sentence(punc_id_list, timestamp_postprocessed, text_postprocessed): +def timestamp_sentence(punc_id_list, timestamp_postprocessed, text_postprocessed, return_raw_text=False): punc_list = [',', '。', '?', '、'] res = [] if text_postprocessed is None: @@ -142,15 +142,24 @@ def timestamp_sentence(punc_id_list, timestamp_postprocessed, text_postprocessed punc_id = int(punc_id) if punc_id is not None else 1 sentence_end = timestamp[1] if timestamp is not None else sentence_end - + sentence_text_seg = sentence_text_seg[:-1] if sentence_text_seg[-1] == ' ' else sentence_text_seg if punc_id > 1: sentence_text += punc_list[punc_id - 2] - res.append({ - 'text': sentence_text, - "start": sentence_start, - "end": sentence_end, - "timestamp": ts_list - }) + if return_raw_text: + res.append({ + 'text': sentence_text, + "start": sentence_start, + "end": sentence_end, + "timestamp": ts_list, + 'raw_text': sentence_text_seg, + }) + else: + res.append({ + 'text': sentence_text, + "start": sentence_start, + "end": sentence_end, + "timestamp": ts_list, + }) sentence_text = '' sentence_text_seg = '' ts_list = []