Merge pull request #1801 from modelscope/dev_sx2

English timestamps for vanilla Paraformer
Shi Xian 2024-06-12 11:20:06 +08:00 committed by GitHub
commit 1300d38bf0
3 changed files with 128 additions and 13 deletions

View File

@@ -21,6 +21,19 @@ res = model.generate(
print(res)
""" call english model like below for detailed timestamps
# choose english paraformer model first
# iic/speech_paraformer_asr-en-16k-vocab4199-pytorch
res = model.generate(
input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_en.wav",
cache={},
pred_timestamp=True,
return_raw_text=True,
sentence_timestamp=True,
en_post_proc=True,
)
"""
""" can not use currently
from funasr import AutoFrontend
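For anyone who wants to run the new English-timestamp path end to end, here is a minimal sketch. It assumes the standard funasr.AutoModel entry point (modified in the next file of this diff) and the English model id named in the comment above; exact keyword arguments may vary between FunASR versions.

from funasr import AutoModel

# Sketch only: the model id comes from the comment above; check your FunASR /
# ModelScope installation for availability.
model = AutoModel(model="iic/speech_paraformer_asr-en-16k-vocab4199-pytorch")

res = model.generate(
    input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_en.wav",
    cache={},
    pred_timestamp=True,      # predict word-level timestamps
    return_raw_text=True,     # also keep the unpunctuated text
    sentence_timestamp=True,  # group word timestamps into sentence entries
    en_post_proc=True,        # route through the English post-processing added in this commit
)
print(res)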

View File

@@ -19,6 +19,7 @@ from funasr.register import tables
from funasr.utils.load_utils import load_bytes
from funasr.download.file import download_from_url
from funasr.utils.timestamp_tools import timestamp_sentence
from funasr.utils.timestamp_tools import timestamp_sentence_en
from funasr.download.download_from_hub import download_model
from funasr.utils.vad_utils import slice_padding_audio_samples
from funasr.utils.vad_utils import merge_vad
@@ -323,7 +324,7 @@ class AutoModel:
input, input_len=input_len, model=self.vad_model, kwargs=self.vad_kwargs, **cfg
)
end_vad = time.time()
# FIX(gcf): concat the VAD clips for the SenseVoice model for better AED
if kwargs.get("merge_vad", False):
for i in range(len(res)):
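The merge_vad branch above concatenates short VAD clips before recognition. The real implementation lives in funasr.utils.vad_utils.merge_vad; the helper below is only a hypothetical illustration of the general idea of merging adjacent [start, end] segments up to a maximum clip length.

def merge_adjacent_segments(segments, max_len_ms=15000):
    # Hypothetical sketch: greedily extend the previous clip while the merged
    # span stays under max_len_ms, otherwise start a new clip.
    merged = []
    for start, end in segments:
        if merged and end - merged[-1][0] <= max_len_ms:
            merged[-1][1] = end
        else:
            merged.append([start, end])
    return merged

# merge_adjacent_segments([[0, 4000], [4200, 9000], [9500, 20000]])
# -> [[0, 9000], [9500, 20000]]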
@@ -519,24 +520,40 @@ class AutoModel:
and 'iic/speech_seaco_paraformer_large_asr_nat-zh-cn-16k-common-vocab8404-pytorch'\
can predict timestamp, and speaker diarization relies on timestamps."
)
sentence_list = timestamp_sentence(
punc_res[0]["punc_array"],
result["timestamp"],
raw_text,
return_raw_text=return_raw_text,
)
if kwargs.get("en_post_proc", False):
sentence_list = timestamp_sentence_en(
punc_res[0]["punc_array"],
result["timestamp"],
raw_text,
return_raw_text=return_raw_text,
)
else:
sentence_list = timestamp_sentence(
punc_res[0]["punc_array"],
result["timestamp"],
raw_text,
return_raw_text=return_raw_text,
)
distribute_spk(sentence_list, sv_output)
result["sentence_info"] = sentence_list
elif kwargs.get("sentence_timestamp", False):
if not len(result["text"].strip()):
sentence_list = []
else:
sentence_list = timestamp_sentence(
punc_res[0]["punc_array"],
result["timestamp"],
raw_text,
return_raw_text=return_raw_text,
)
if kwargs.get("en_post_proc", False):
sentence_list = timestamp_sentence_en(
punc_res[0]["punc_array"],
result["timestamp"],
raw_text,
return_raw_text=return_raw_text,
)
else:
sentence_list = timestamp_sentence(
punc_res[0]["punc_array"],
result["timestamp"],
raw_text,
return_raw_text=return_raw_text,
)
result["sentence_info"] = sentence_list
if "spk_embedding" in result:
del result["spk_embedding"]

View File

@@ -185,3 +185,88 @@ def timestamp_sentence(
ts_list = []
sentence_start = sentence_end
return res
def timestamp_sentence_en(
punc_id_list, timestamp_postprocessed, text_postprocessed, return_raw_text=False
):
punc_list = [",", ".", "?", ","]
res = []
if text_postprocessed is None:
return res
if timestamp_postprocessed is None:
return res
if len(timestamp_postprocessed) == 0:
return res
if len(text_postprocessed) == 0:
return res
if punc_id_list is None or len(punc_id_list) == 0:
res.append(
{
"text": text_postprocessed.split(),
"start": timestamp_postprocessed[0][0],
"end": timestamp_postprocessed[-1][1],
"timestamp": timestamp_postprocessed,
}
)
return res
if len(punc_id_list) != len(timestamp_postprocessed):
logging.warning("length mismatch between punc and timestamp")
sentence_text = ""
sentence_text_seg = ""
ts_list = []
sentence_start = timestamp_postprocessed[0][0]
sentence_end = timestamp_postprocessed[0][1]
texts = text_postprocessed.split()
punc_stamp_text_list = list(
zip_longest(punc_id_list, timestamp_postprocessed, texts, fillvalue=None)
)
for punc_stamp_text in punc_stamp_text_list:
punc_id, timestamp, text = punc_stamp_text
# sentence_text += text if text is not None else ''
if text is not None:
if "a" <= text[0] <= "z" or "A" <= text[0] <= "Z":
sentence_text += " " + text
elif len(sentence_text) and (
"a" <= sentence_text[-1] <= "z" or "A" <= sentence_text[-1] <= "Z"
):
sentence_text += " " + text
else:
sentence_text += text
sentence_text_seg += text + " "
ts_list.append(timestamp)
punc_id = int(punc_id) if punc_id is not None else 1
sentence_end = timestamp[1] if timestamp is not None else sentence_end
sentence_text = sentence_text[1:] if sentence_text[0] == ' ' else sentence_text
if punc_id > 1:
sentence_text += punc_list[punc_id - 2]
sentence_text_seg = (
sentence_text_seg[:-1] if sentence_text_seg[-1] == " " else sentence_text_seg
)
if return_raw_text:
res.append(
{
"text": sentence_text,
"start": sentence_start,
"end": sentence_end,
"timestamp": ts_list,
"raw_text": sentence_text_seg,
}
)
else:
res.append(
{
"text": sentence_text,
"start": sentence_start,
"end": sentence_end,
"timestamp": ts_list,
}
)
sentence_text = ""
sentence_text_seg = ""
ts_list = []
sentence_start = sentence_end
return res
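To make the new function's behaviour concrete, a small self-contained call with made-up inputs. Punctuation ids follow punc_list above (1 = none, 2 = ',', 3 = '.', 4 = '?'), and each timestamp is a [start_ms, end_ms] pair per word.

from funasr.utils.timestamp_tools import timestamp_sentence_en

punc_ids = [1, 1, 3, 1, 3]  # a "." after the 3rd and 5th words
stamps = [[0, 300], [300, 650], [650, 900], [900, 1200], [1200, 1500]]
text = "hello how are you today"

sentences = timestamp_sentence_en(punc_ids, stamps, text, return_raw_text=True)
# Expected to yield two sentence entries, roughly:
# {"text": "hello how are.", "start": 0, "end": 900,
#  "timestamp": [[0, 300], [300, 650], [650, 900]], "raw_text": "hello how are"}
# {"text": "you today.", "start": 900, "end": 1500,
#  "timestamp": [[900, 1200], [1200, 1500]], "raw_text": "you today"}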