funasr1.0 fix punc model

游雁 2024-01-13 22:42:18 +08:00
parent c0b186b5b6
commit 835369d631
11 changed files with 125 additions and 31 deletions


@@ -10,7 +10,7 @@ model = AutoModel(model="damo/speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k
                   vad_model="damo/speech_fsmn_vad_zh-cn-16k-common-pytorch",
                   vad_model_revision="v2.0.1",
                   punc_model="damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch",
-                  punc_model_revision="v2.0.0",
+                  punc_model_revision="v2.0.1",
                   spk_model="/Users/shixian/code/modelscope_models/speech_campplus_sv_zh-cn_16k-common",
                   )
@@ -23,7 +23,7 @@ model = AutoModel(model="damo/speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k
                   vad_model="damo/speech_fsmn_vad_zh-cn-16k-common-pytorch",
                   vad_model_revision="v2.0.1",
                   punc_model="damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch",
-                  punc_model_revision="v2.0.0",
+                  punc_model_revision="v2.0.1",
                   spk_model="/Users/shixian/code/modelscope_models/speech_campplus_sv_zh-cn_16k-common",
                   spk_mode='punc_segment',
                   )


@@ -4,7 +4,7 @@ model_revision="v2.0.0"
 vad_model="damo/speech_fsmn_vad_zh-cn-16k-common-pytorch"
 vad_model_revision="v2.0.0"
 punc_model="damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch"
-punc_model_revision="v2.0.0"
+punc_model_revision="v2.0.1"
 python funasr/bin/inference.py \
 +model=${model} \


@@ -5,7 +5,15 @@
 from funasr import AutoModel
-model = AutoModel(model="damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch", model_revision="v2.0.0")
+model = AutoModel(model="damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch", model_revision="v2.0.1")
 res = model(input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_text/punc_example.txt")
 print(res)
+
+from funasr import AutoModel
+model = AutoModel(model="damo/punc_ct-transformer_cn-en-common-vocab471067-large", model_revision="v2.0.1")
+res = model(input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_text/punc_example.txt")
+print(res)


@@ -1,6 +1,9 @@
 model="damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch"
-model_revision="v2.0.0"
+model_revision="v2.0.1"
+
+model="damo/punc_ct-transformer_cn-en-common-vocab471067-large"
+model_revision="v2.0.1"
 python funasr/bin/inference.py \
 +model=${model} \


@@ -10,7 +10,7 @@ model = AutoModel(model="damo/speech_seaco_paraformer_large_asr_nat-zh-cn-16k-co
                   vad_model="damo/speech_fsmn_vad_zh-cn-16k-common-pytorch",
                   vad_model_revision="v2.0.1",
                   punc_model="damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch",
-                  punc_model_revision="v2.0.0",
+                  punc_model_revision="v2.0.1",
                   spk_model="damo/speech_campplus_sv_zh-cn_16k-common",
                   spk_model_revision="v2.0.0"
                   )


@@ -4,7 +4,7 @@ model_revision="v2.0.0"
 vad_model="damo/speech_fsmn_vad_zh-cn-16k-common-pytorch"
 vad_model_revision="v2.0.1"
 punc_model="damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch"
-punc_model_revision="v2.0.0"
+punc_model_revision="v2.0.1"
 spk_model="damo/speech_campplus_sv_zh-cn_16k-common"
 spk_model_revision="v2.0.0"


@@ -10,7 +10,7 @@ model = AutoModel(model="damo/speech_seaco_paraformer_large_asr_nat-zh-cn-16k-co
                   vad_model="damo/speech_fsmn_vad_zh-cn-16k-common-pytorch",
                   vad_model_revision="v2.0.1",
                   punc_model="damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch",
-                  punc_model_revision="v2.0.0",
+                  punc_model_revision="v2.0.1",
                   )
 res = model(input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav",


@@ -4,7 +4,7 @@ model_revision="v2.0.0"
 vad_model="damo/speech_fsmn_vad_zh-cn-16k-common-pytorch"
 vad_model_revision="v2.0.1"
 punc_model="damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch"
-punc_model_revision="v2.0.0"
+punc_model_revision="v2.0.1"
 python funasr/bin/inference.py \
 +model=${model} \


@@ -37,6 +37,8 @@ def download_from_ms(**kwargs):
         kwargs["model"] = cfg["model"]
         if os.path.exists(os.path.join(model_or_path, "am.mvn")):
             kwargs["frontend_conf"]["cmvn_file"] = os.path.join(model_or_path, "am.mvn")
+        if os.path.exists(os.path.join(model_or_path, "jieba_usr_dict")):
+            kwargs["jieba_usr_dict"] = os.path.join(model_or_path, "jieba_usr_dict")
     else:  # configuration.json
         assert os.path.exists(os.path.join(model_or_path, "configuration.json"))
         with open(os.path.join(model_or_path, "configuration.json"), 'r', encoding='utf-8') as f:
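Note on the two added lines: a downloaded model package can now ship a jieba user dictionary. If a file named jieba_usr_dict sits next to the model files, its path is forwarded through kwargs to the punctuation model. A minimal sketch of preparing such a file, assuming jieba's standard user-dictionary format of one "word [freq] [tag]" entry per line; the model directory and the entries are purely illustrative:

import os

# Illustrative local copy of a punctuation model downloaded from ModelScope.
model_dir = os.path.expanduser("~/models/punc_ct-transformer_cn-en-common-vocab471067-large")

# jieba user-dictionary format: one entry per line, "word [freq] [tag]",
# where the frequency and POS tag are optional.
entries = [
    "魔搭 10 n",
    "FunASR 10 eng",
]

# download_from_ms() will pick this file up as kwargs["jieba_usr_dict"].
with open(os.path.join(model_dir, "jieba_usr_dict"), "w", encoding="utf-8") as f:
    f.write("\n".join(entries) + "\n")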


@@ -225,8 +225,14 @@ class CTTransformer(nn.Module):
         # text = data_in[0]
         # text_lengths = data_lengths[0] if data_lengths is not None else None
         split_size = kwargs.get("split_size", 20)
-        tokens = split_words(text)
+        jieba_usr_dict = kwargs.get("jieba_usr_dict", None)
+        if jieba_usr_dict and isinstance(jieba_usr_dict, str):
+            import jieba
+            jieba.load_userdict(jieba_usr_dict)
+            jieba_usr_dict = jieba
+            kwargs["jieba_usr_dict"] = "jieba_usr_dict"
+        tokens = split_words(text, jieba_usr_dict=jieba_usr_dict)
         tokens_int = tokenizer.encode(tokens)
         mini_sentences = split_to_mini_sentence(tokens, split_size)
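For context on the added branch: kwargs["jieba_usr_dict"] first arrives as a file path, the user dictionary is loaded into jieba, and the imported jieba module itself is then handed to split_words so that Chinese spans are segmented with the custom vocabulary. A standalone sketch of that loading pattern, using jieba's public API (the dictionary filename is illustrative):

import jieba

# Load a custom user dictionary ("word [freq] [tag]" per line), then segment
# with HMM disabled so the cut is driven by the dictionary entries.
jieba.load_userdict("jieba_usr_dict")
print(list(jieba.cut("魔搭社区发布FunASR", HMM=False)))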


@@ -1,4 +1,4 @@
+import re
 def split_to_mini_sentence(words: list, word_limit: int = 20):
     assert word_limit > 1
@@ -14,23 +14,98 @@ def split_to_mini_sentence(words: list, word_limit: int = 20):
     return sentences
-def split_words(text: str):
-    words = []
-    segs = text.split()
-    for seg in segs:
-        # There is no space in seg.
-        current_word = ""
-        for c in seg:
-            if len(c.encode()) == 1:
-                # This is an ASCII char.
-                current_word += c
-            else:
-                # This is a Chinese char.
-                if len(current_word) > 0:
-                    words.append(current_word)
-                    current_word = ""
-                words.append(c)
-        if len(current_word) > 0:
-            words.append(current_word)
-    return words
+# def split_words(text: str, **kwargs):
+#     words = []
+#     segs = text.split()
+#     for seg in segs:
+#         # There is no space in seg.
+#         current_word = ""
+#         for c in seg:
+#             if len(c.encode()) == 1:
+#                 # This is an ASCII char.
+#                 current_word += c
+#             else:
+#                 # This is a Chinese char.
+#                 if len(current_word) > 0:
+#                     words.append(current_word)
+#                     current_word = ""
+#                 words.append(c)
+#         if len(current_word) > 0:
+#             words.append(current_word)
+#
+#     return words
+
+
+def split_words(text: str, jieba_usr_dict=None, **kwargs):
+    if jieba_usr_dict:
+        input_list = text.split()
+        token_list_all = []
+        langauge_list = []
+        token_list_tmp = []
+        language_flag = None
+        for token in input_list:
+            if isEnglish(token) and language_flag == 'Chinese':
+                token_list_all.append(token_list_tmp)
+                langauge_list.append('Chinese')
+                token_list_tmp = []
+            elif not isEnglish(token) and language_flag == 'English':
+                token_list_all.append(token_list_tmp)
+                langauge_list.append('English')
+                token_list_tmp = []
+            token_list_tmp.append(token)
+            if isEnglish(token):
+                language_flag = 'English'
+            else:
+                language_flag = 'Chinese'
+        if token_list_tmp:
+            token_list_all.append(token_list_tmp)
+            langauge_list.append(language_flag)
+        result_list = []
+        for token_list_tmp, language_flag in zip(token_list_all, langauge_list):
+            if language_flag == 'English':
+                result_list.extend(token_list_tmp)
+            else:
+                seg_list = jieba_usr_dict.cut(join_chinese_and_english(token_list_tmp), HMM=False)
+                result_list.extend(seg_list)
+        return result_list
+    else:
+        words = []
+        segs = text.split()
+        for seg in segs:
+            # There is no space in seg.
+            current_word = ""
+            for c in seg:
+                if len(c.encode()) == 1:
+                    # This is an ASCII char.
+                    current_word += c
+                else:
+                    # This is a Chinese char.
+                    if len(current_word) > 0:
+                        words.append(current_word)
+                        current_word = ""
+                    words.append(c)
+            if len(current_word) > 0:
+                words.append(current_word)
+        return words
+
+
+def isEnglish(text: str):
+    if re.search('^[a-zA-Z\']+$', text):
+        return True
+    else:
+        return False
+
+
+def join_chinese_and_english(input_list):
+    line = ''
+    for token in input_list:
+        if isEnglish(token):
+            line = line + ' ' + token
+        else:
+            line = line + token
+    line = line.strip()
+    return line
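Taken together, the rewritten split_words keeps the old per-character handling of Chinese as a fallback when no jieba handle is supplied; otherwise it groups the input into alternating English/Chinese runs, passes English tokens through unchanged, and segments each Chinese run with jieba. A rough usage sketch, assuming the helpers live in funasr/models/ct_transformer/utils.py (the module path is a guess) and that jieba is installed:

import jieba

# Assumed import path for the helpers shown in this diff.
from funasr.models.ct_transformer.utils import split_words

text = "please call 魔搭 社区 today"

# Without a jieba handle: whitespace split plus per-character split of Chinese.
print(split_words(text))
# -> ['please', 'call', '魔', '搭', '社', '区', 'today']

# With the jieba module passed in: English tokens stay whole, the Chinese run is
# re-joined to "魔搭社区" and segmented by jieba (exact pieces depend on its dictionary).
print(split_words(text, jieba_usr_dict=jieba))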