update raw_text handling

shixian.shi 2024-02-21 14:42:36 +08:00
parent f91205399f
commit 571dc8b55a
3 changed files with 27 additions and 14 deletions

View File

@@ -5,11 +5,14 @@
 from funasr import AutoModel
-model = AutoModel(model="damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch", model_revision="v2.0.4",
-                  # vad_model="damo/speech_fsmn_vad_zh-cn-16k-common-pytorch",
-                  # vad_model_revision="v2.0.4",
-                  # punc_model="damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch",
-                  # punc_model_revision="v2.0.4",
+model = AutoModel(model="iic/speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-pytorch",
+                  model_revision="v2.0.4",
+                  vad_model="iic/speech_fsmn_vad_zh-cn-16k-common-pytorch",
+                  vad_model_revision="v2.0.4",
+                  punc_model="iic/punc_ct-transformer_zh-cn-common-vocab272727-pytorch",
+                  punc_model_revision="v2.0.4",
+                  # spk_model="iic/speech_campplus_sv_zh-cn_16k-common",
+                  # spk_model_revision="v2.0.2",
                   )
 res = model.generate(input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav")
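For context, a minimal usage sketch of the updated example with the commented-out speaker model switched on; the model IDs and revisions are the ones shown in the diff, while the output handling is illustrative:

from funasr import AutoModel

# Sketch: same pipeline as above, with the CAM++ speaker model enabled
# (model IDs/revisions taken from the diff; prints are illustrative).
model = AutoModel(model="iic/speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-pytorch",
                  model_revision="v2.0.4",
                  vad_model="iic/speech_fsmn_vad_zh-cn-16k-common-pytorch",
                  vad_model_revision="v2.0.4",
                  punc_model="iic/punc_ct-transformer_zh-cn-common-vocab272727-pytorch",
                  punc_model_revision="v2.0.4",
                  spk_model="iic/speech_campplus_sv_zh-cn_16k-common",
                  spk_model_revision="v2.0.2",
                  )
res = model.generate(input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav")
print(res[0]["text"])  # punctuated transcript
# With spk_model set, res[0] should also carry "sentence_info"
# (per-sentence segments with speaker labels; see the AutoModel hunks below).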

View File

@@ -1,14 +1,13 @@
 import json
 import time
+import copy
 import torch
-import hydra
 import random
 import string
 import logging
 import os.path
 import numpy as np
 from tqdm import tqdm
-from omegaconf import DictConfig, OmegaConf, ListConfig
 from funasr.register import tables
 from funasr.utils.load_utils import load_bytes
@@ -17,7 +16,7 @@ from funasr.download.download_from_hub import download_model
 from funasr.utils.vad_utils import slice_padding_audio_samples
 from funasr.train_utils.set_all_random_seed import set_all_random_seed
 from funasr.train_utils.load_pretrained_model import load_pretrained_model
-from funasr.utils.load_utils import load_audio_text_image_video, extract_fbank
+from funasr.utils.load_utils import load_audio_text_image_video
 from funasr.utils.timestamp_tools import timestamp_sentence
 from funasr.models.campplus.utils import sv_chunk, postprocess, distribute_spk
 try:
@@ -385,11 +384,15 @@ class AutoModel:
             if self.punc_model is not None:
                 self.punc_kwargs.update(cfg)
                 punc_res = self.inference(result["text"], model=self.punc_model, kwargs=self.punc_kwargs, disable_pbar=True, **cfg)
-                import copy; raw_text = copy.copy(result["text"])
+                raw_text = copy.copy(result["text"])
                 result["text"] = punc_res[0]["text"]
+            else:
+                raw_text = None
             # speaker embedding cluster after resorted
             if self.spk_model is not None and kwargs.get('return_spk_res', True):
+                if raw_text is None:
+                    logging.error("Missing punc_model, which is required by spk_model.")
                 all_segments = sorted(all_segments, key=lambda x: x[0])
                 spk_embedding = result['spk_embedding']
                 labels = self.cb_model(spk_embedding.cpu(), oracle_num=kwargs.get('preset_spk_num', None))
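The hunk above keeps a pre-punctuation copy of the transcript before punc_res overwrites result["text"], because the speaker stage works against the unpunctuated text. A minimal sketch of that pattern, where restore_punctuation is a hypothetical stand-in for the punc-model call (not the FunASR API):

import copy
import logging

def attach_punctuation(result: dict, restore_punctuation=None):
    # Keep the unpunctuated text around before overwriting result["text"],
    # mirroring the raw_text bookkeeping added in the hunk above.
    if restore_punctuation is not None:
        raw_text = copy.copy(result["text"])
        result["text"] = restore_punctuation(result["text"])
    else:
        raw_text = None  # downstream speaker code must handle this case
    return result, raw_text

result, raw_text = attach_punctuation({"text": "hello world"},
                                      restore_punctuation=lambda t: t + ".")
if raw_text is None:
    # Same failure mode the diff guards against: spk_model needs raw_text.
    logging.error("Missing punc_model, which is required by spk_model.")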
@@ -398,20 +401,28 @@ class AutoModel:
                 if self.spk_mode == 'vad_segment':  # recover sentence_list
                     sentence_list = []
                     for res, vadsegment in zip(restored_data, vadsegments):
+                        if 'timestamp' not in res:
+                            logging.error("Only 'iic/speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-pytorch' \
+                            and 'iic/speech_seaco_paraformer_large_asr_nat-zh-cn-16k-common-vocab8404-pytorch'\
+                            can predict timestamp, and speaker diarization relies on timestamps.")
                         sentence_list.append({"start": vadsegment[0],\
                                               "end": vadsegment[1],
-                                              "sentence": res['raw_text'],
+                                              "sentence": res['text'],
                                               "timestamp": res['timestamp']})
                 elif self.spk_mode == 'punc_segment':
+                    if 'timestamp' not in result:
+                        logging.error("Only 'iic/speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-pytorch' \
+                        and 'iic/speech_seaco_paraformer_large_asr_nat-zh-cn-16k-common-vocab8404-pytorch'\
+                        can predict timestamp, and speaker diarization relies on timestamps.")
                     sentence_list = timestamp_sentence(punc_res[0]['punc_array'], \
                                                        result['timestamp'], \
-                                                       result['raw_text'])
+                                                       raw_text)
                 distribute_spk(sentence_list, sv_output)
                 result['sentence_info'] = sentence_list
             elif kwargs.get("sentence_timestamp", False):
                 sentence_list = timestamp_sentence(punc_res[0]['punc_array'], \
                                                    result['timestamp'], \
-                                                   result['raw_text'])
+                                                   raw_text)
                 result['sentence_info'] = sentence_list
             if "spk_embedding" in result: del result['spk_embedding']

View File

@@ -415,12 +415,11 @@ class SeacoParaformer(BiCifParaformer, Paraformer):
                                                                    token, timestamp)
                 result_i = {"key": key[i], "text": text_postprocessed,
-                            "timestamp": time_stamp_postprocessed, "raw_text": copy.copy(text_postprocessed)
+                            "timestamp": time_stamp_postprocessed
                             }
                 if ibest_writer is not None:
                     ibest_writer["token"][key[i]] = " ".join(token)
-                    # ibest_writer["raw_text"][key[i]] = text
                     ibest_writer["timestamp"][key[i]] = time_stamp_postprocessed
                     ibest_writer["text"][key[i]] = text_postprocessed
                 else:
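Net effect of this last hunk: SeacoParaformer no longer duplicates the transcript under "raw_text", since AutoModel now keeps its own pre-punctuation copy (second file). A sketch of the per-utterance result after the change, with invented values:

# Illustrative result_i after this change (key names from the diff, values invented):
result_i = {
    "key": "asr_example_zh",
    "text": "欢迎大家来体验达摩院推出的语音识别模型",
    "timestamp": [[380, 600], [600, 980]],
}
# No "raw_text" here; the unpunctuated copy now lives in AutoModel.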