From 571dc8b55a9b036a5b36f968bb3a5baf5858e395 Mon Sep 17 00:00:00 2001
From: "shixian.shi"
Date: Wed, 21 Feb 2024 14:42:36 +0800
Subject: [PATCH] update raw_text handling

---
 .../paraformer/demo.py                  | 13 ++++++----
 funasr/auto/auto_model.py               | 25 +++++++++++++------
 funasr/models/seaco_paraformer/model.py |  3 +--
 3 files changed, 27 insertions(+), 14 deletions(-)

diff --git a/examples/industrial_data_pretraining/paraformer/demo.py b/examples/industrial_data_pretraining/paraformer/demo.py
index a0c740625..0265b123e 100644
--- a/examples/industrial_data_pretraining/paraformer/demo.py
+++ b/examples/industrial_data_pretraining/paraformer/demo.py
@@ -5,11 +5,14 @@
 
 from funasr import AutoModel
 
-model = AutoModel(model="damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch", model_revision="v2.0.4",
-                  # vad_model="damo/speech_fsmn_vad_zh-cn-16k-common-pytorch",
-                  # vad_model_revision="v2.0.4",
-                  # punc_model="damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch",
-                  # punc_model_revision="v2.0.4",
+model = AutoModel(model="iic/speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-pytorch",
+                  model_revision="v2.0.4",
+                  vad_model="iic/speech_fsmn_vad_zh-cn-16k-common-pytorch",
+                  vad_model_revision="v2.0.4",
+                  punc_model="iic/punc_ct-transformer_zh-cn-common-vocab272727-pytorch",
+                  punc_model_revision="v2.0.4",
+                  # spk_model="iic/speech_campplus_sv_zh-cn_16k-common",
+                  # spk_model_revision="v2.0.2",
                   )
 
 res = model.generate(input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav")
diff --git a/funasr/auto/auto_model.py b/funasr/auto/auto_model.py
index f59bb6b7a..78e47ccd4 100644
--- a/funasr/auto/auto_model.py
+++ b/funasr/auto/auto_model.py
@@ -1,14 +1,13 @@
 import json
 import time
+import copy
 import torch
-import hydra
 import random
 import string
 import logging
 import os.path
 import numpy as np
 from tqdm import tqdm
-from omegaconf import DictConfig, OmegaConf, ListConfig
 
 from funasr.register import tables
 from funasr.utils.load_utils import load_bytes
@@ -17,7 +16,7 @@ from funasr.download.download_from_hub import download_model
 from funasr.utils.vad_utils import slice_padding_audio_samples
 from funasr.train_utils.set_all_random_seed import set_all_random_seed
 from funasr.train_utils.load_pretrained_model import load_pretrained_model
-from funasr.utils.load_utils import load_audio_text_image_video, extract_fbank
+from funasr.utils.load_utils import load_audio_text_image_video
 from funasr.utils.timestamp_tools import timestamp_sentence
 from funasr.models.campplus.utils import sv_chunk, postprocess, distribute_spk
 try:
@@ -385,11 +384,15 @@ class AutoModel:
         if self.punc_model is not None:
             self.punc_kwargs.update(cfg)
             punc_res = self.inference(result["text"], model=self.punc_model, kwargs=self.punc_kwargs, disable_pbar=True, **cfg)
-            import copy; raw_text = copy.copy(result["text"])
+            raw_text = copy.copy(result["text"])
             result["text"] = punc_res[0]["text"]
+        else:
+            raw_text = None
 
         # speaker embedding cluster after resorted
         if self.spk_model is not None and kwargs.get('return_spk_res', True):
+            if raw_text is None:
+                logging.error("Missing punc_model, which is required by spk_model.")
             all_segments = sorted(all_segments, key=lambda x: x[0])
             spk_embedding = result['spk_embedding']
             labels = self.cb_model(spk_embedding.cpu(), oracle_num=kwargs.get('preset_spk_num', None))
@@ -398,20 +401,28 @@
             if self.spk_mode == 'vad_segment':  # recover sentence_list
                 sentence_list = []
                 for res, vadsegment in zip(restored_data, vadsegments):
+                    if 'timestamp' not in res:
+                        logging.error("Only 'iic/speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-pytorch' \
+                                      and 'iic/speech_seaco_paraformer_large_asr_nat-zh-cn-16k-common-vocab8404-pytorch'\
+                                      can predict timestamp, and speaker diarization relies on timestamps.")
                     sentence_list.append({"start": vadsegment[0],\
                                           "end": vadsegment[1],
-                                          "sentence": res['raw_text'],
+                                          "sentence": res['text'],
                                           "timestamp": res['timestamp']})
             elif self.spk_mode == 'punc_segment':
+                if 'timestamp' not in result:
+                    logging.error("Only 'iic/speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-pytorch' \
+                                  and 'iic/speech_seaco_paraformer_large_asr_nat-zh-cn-16k-common-vocab8404-pytorch'\
+                                  can predict timestamp, and speaker diarization relies on timestamps.")
                 sentence_list = timestamp_sentence(punc_res[0]['punc_array'], \
                                                    result['timestamp'], \
-                                                   result['raw_text'])
+                                                   raw_text)
             distribute_spk(sentence_list, sv_output)
             result['sentence_info'] = sentence_list
         elif kwargs.get("sentence_timestamp", False):
             sentence_list = timestamp_sentence(punc_res[0]['punc_array'], \
                                                result['timestamp'], \
-                                               result['raw_text'])
+                                               raw_text)
             result['sentence_info'] = sentence_list
         if "spk_embedding" in result: del result['spk_embedding']
diff --git a/funasr/models/seaco_paraformer/model.py b/funasr/models/seaco_paraformer/model.py
index 06103963b..caf2b15c7 100644
--- a/funasr/models/seaco_paraformer/model.py
+++ b/funasr/models/seaco_paraformer/model.py
@@ -415,12 +415,11 @@ class SeacoParaformer(BiCifParaformer, Paraformer):
                                                                             token, timestamp)
                 result_i = {"key": key[i], "text": text_postprocessed,
-                            "timestamp": time_stamp_postprocessed, "raw_text": copy.copy(text_postprocessed)
+                            "timestamp": time_stamp_postprocessed
                             }
 
                 if ibest_writer is not None:
                     ibest_writer["token"][key[i]] = " ".join(token)
-                    # ibest_writer["raw_text"][key[i]] = text
                     ibest_writer["timestamp"][key[i]] = time_stamp_postprocessed
                     ibest_writer["text"][key[i]] = text_postprocessed
             else:
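A minimal usage sketch (not part of the patch) for exercising the code path changed above: speaker-aware transcription that returns sentence_info built from raw_text and timestamps. Model names and revisions mirror demo.py; the spk_model name and revision are the commented-out values in demo.py, and the per-sentence "spk" field is an assumption based on distribute_spk() in auto_model.py, so adjust both if your setup differs.

    from funasr import AutoModel

    # Models/revisions as in demo.py; spk_model and spk_model_revision are the
    # commented-out values from demo.py and are assumed to be available.
    model = AutoModel(
        model="iic/speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-pytorch",
        model_revision="v2.0.4",
        vad_model="iic/speech_fsmn_vad_zh-cn-16k-common-pytorch",
        vad_model_revision="v2.0.4",
        punc_model="iic/punc_ct-transformer_zh-cn-common-vocab272727-pytorch",
        punc_model_revision="v2.0.4",
        spk_model="iic/speech_campplus_sv_zh-cn_16k-common",
        spk_model_revision="v2.0.2",
    )

    # return_spk_res is read via kwargs.get('return_spk_res', True) in auto_model.py,
    # so passing it explicitly is optional.
    res = model.generate(
        input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav",
        return_spk_res=True,
    )

    # With punc_model and spk_model enabled, each result is expected to carry
    # 'sentence_info' assembled in the vad_segment / punc_segment branches above.
    for sentence in res[0].get("sentence_info", []):
        # "spk" is assumed to be attached by distribute_spk(); the key name may differ.
        print(sentence.get("start"), sentence.get("end"), sentence.get("spk"), sentence.get("sentence"))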