Update raw_text handling: capture the pre-punctuation text in AutoModel instead of returning raw_text from SeacoParaformer

This commit is contained in:
shixian.shi 2024-02-21 14:42:36 +08:00
parent f91205399f
commit 571dc8b55a
3 changed files with 27 additions and 14 deletions

View File

@ -5,11 +5,14 @@
from funasr import AutoModel
model = AutoModel(model="damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch", model_revision="v2.0.4",
# vad_model="damo/speech_fsmn_vad_zh-cn-16k-common-pytorch",
# vad_model_revision="v2.0.4",
# punc_model="damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch",
# punc_model_revision="v2.0.4",
model = AutoModel(model="iic/speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-pytorch",
model_revision="v2.0.4",
vad_model="iic/speech_fsmn_vad_zh-cn-16k-common-pytorch",
vad_model_revision="v2.0.4",
punc_model="iic/punc_ct-transformer_zh-cn-common-vocab272727-pytorch",
punc_model_revision="v2.0.4",
# spk_model="iic/speech_campplus_sv_zh-cn_16k-common",
# spk_model_revision="v2.0.2",
)
res = model.generate(input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav")

View File

@ -1,14 +1,13 @@
import json
import time
import copy
import torch
import hydra
import random
import string
import logging
import os.path
import numpy as np
from tqdm import tqdm
from omegaconf import DictConfig, OmegaConf, ListConfig
from funasr.register import tables
from funasr.utils.load_utils import load_bytes
@ -17,7 +16,7 @@ from funasr.download.download_from_hub import download_model
from funasr.utils.vad_utils import slice_padding_audio_samples
from funasr.train_utils.set_all_random_seed import set_all_random_seed
from funasr.train_utils.load_pretrained_model import load_pretrained_model
from funasr.utils.load_utils import load_audio_text_image_video, extract_fbank
from funasr.utils.load_utils import load_audio_text_image_video
from funasr.utils.timestamp_tools import timestamp_sentence
from funasr.models.campplus.utils import sv_chunk, postprocess, distribute_spk
try:
@ -385,11 +384,15 @@ class AutoModel:
if self.punc_model is not None:
self.punc_kwargs.update(cfg)
punc_res = self.inference(result["text"], model=self.punc_model, kwargs=self.punc_kwargs, disable_pbar=True, **cfg)
import copy; raw_text = copy.copy(result["text"])
raw_text = copy.copy(result["text"])
result["text"] = punc_res[0]["text"]
else:
raw_text = None
# speaker embedding cluster after resorted
if self.spk_model is not None and kwargs.get('return_spk_res', True):
if raw_text is None:
logging.error("Missing punc_model, which is required by spk_model.")
all_segments = sorted(all_segments, key=lambda x: x[0])
spk_embedding = result['spk_embedding']
labels = self.cb_model(spk_embedding.cpu(), oracle_num=kwargs.get('preset_spk_num', None))
@ -398,20 +401,28 @@ class AutoModel:
if self.spk_mode == 'vad_segment': # recover sentence_list
sentence_list = []
for res, vadsegment in zip(restored_data, vadsegments):
if 'timestamp' not in res:
logging.error("Only 'iic/speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-pytorch' \
and 'iic/speech_seaco_paraformer_large_asr_nat-zh-cn-16k-common-vocab8404-pytorch'\
can predict timestamp, and speaker diarization relies on timestamps.")
sentence_list.append({"start": vadsegment[0],\
"end": vadsegment[1],
"sentence": res['raw_text'],
"sentence": res['text'],
"timestamp": res['timestamp']})
elif self.spk_mode == 'punc_segment':
if 'timestamp' not in result:
logging.error("Only 'iic/speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-pytorch' \
and 'iic/speech_seaco_paraformer_large_asr_nat-zh-cn-16k-common-vocab8404-pytorch'\
can predict timestamp, and speaker diarization relies on timestamps.")
sentence_list = timestamp_sentence(punc_res[0]['punc_array'], \
result['timestamp'], \
result['raw_text'])
raw_text)
distribute_spk(sentence_list, sv_output)
result['sentence_info'] = sentence_list
elif kwargs.get("sentence_timestamp", False):
sentence_list = timestamp_sentence(punc_res[0]['punc_array'], \
result['timestamp'], \
result['raw_text'])
raw_text)
result['sentence_info'] = sentence_list
if "spk_embedding" in result: del result['spk_embedding']

View File

@ -415,12 +415,11 @@ class SeacoParaformer(BiCifParaformer, Paraformer):
token, timestamp)
result_i = {"key": key[i], "text": text_postprocessed,
"timestamp": time_stamp_postprocessed, "raw_text": copy.copy(text_postprocessed)
"timestamp": time_stamp_postprocessed
}
if ibest_writer is not None:
ibest_writer["token"][key[i]] = " ".join(token)
# ibest_writer["raw_text"][key[i]] = text
ibest_writer["timestamp"][key[i]] = time_stamp_postprocessed
ibest_writer["text"][key[i]] = text_postprocessed
else: