funasr1.0 fix punc model

游雁 2024-01-13 22:42:18 +08:00
parent c0b186b5b6
commit 835369d631
11 changed files with 125 additions and 31 deletions

View File

@@ -10,7 +10,7 @@ model = AutoModel(model="damo/speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k
                   vad_model="damo/speech_fsmn_vad_zh-cn-16k-common-pytorch",
                   vad_model_revision="v2.0.1",
                   punc_model="damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch",
-                  punc_model_revision="v2.0.0",
+                  punc_model_revision="v2.0.1",
                   spk_model="/Users/shixian/code/modelscope_models/speech_campplus_sv_zh-cn_16k-common",
                   )
@@ -23,7 +23,7 @@ model = AutoModel(model="damo/speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k
                   vad_model="damo/speech_fsmn_vad_zh-cn-16k-common-pytorch",
                   vad_model_revision="v2.0.1",
                   punc_model="damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch",
-                  punc_model_revision="v2.0.0",
+                  punc_model_revision="v2.0.1",
                   spk_model="/Users/shixian/code/modelscope_models/speech_campplus_sv_zh-cn_16k-common",
                   spk_mode='punc_segment',
                   )

View File

@@ -4,7 +4,7 @@ model_revision="v2.0.0"
 vad_model="damo/speech_fsmn_vad_zh-cn-16k-common-pytorch"
 vad_model_revision="v2.0.0"
 punc_model="damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch"
-punc_model_revision="v2.0.0"
+punc_model_revision="v2.0.1"
 
 python funasr/bin/inference.py \
 +model=${model} \

View File

@@ -5,7 +5,15 @@
 from funasr import AutoModel
 
-model = AutoModel(model="damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch", model_revision="v2.0.0")
+model = AutoModel(model="damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch", model_revision="v2.0.1")
+res = model(input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_text/punc_example.txt")
+print(res)
+
+from funasr import AutoModel
+
+model = AutoModel(model="damo/punc_ct-transformer_cn-en-common-vocab471067-large", model_revision="v2.0.1")
 res = model(input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_text/punc_example.txt")
 print(res)
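
Since the newly added damo/punc_ct-transformer_cn-en-common-vocab471067-large model targets mixed Chinese/English text, a quick sanity check is to pass an inline string instead of the remote example file. A minimal sketch; the sample sentence is invented, and any mixed-language string will do:

from funasr import AutoModel

# Sketch: punctuate an inline mixed Chinese/English string (the sentence is made up).
model = AutoModel(model="damo/punc_ct-transformer_cn-en-common-vocab471067-large", model_revision="v2.0.1")
res = model(input="那今天的会就到这里吧 happy new year 明年见")
print(res)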

View File

@@ -1,6 +1,9 @@
 model="damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch"
-model_revision="v2.0.0"
+model_revision="v2.0.1"
+
+model="damo/punc_ct-transformer_cn-en-common-vocab471067-large"
+model_revision="v2.0.1"
 
 python funasr/bin/inference.py \
 +model=${model} \

View File

@@ -10,7 +10,7 @@ model = AutoModel(model="damo/speech_seaco_paraformer_large_asr_nat-zh-cn-16k-co
                   vad_model="damo/speech_fsmn_vad_zh-cn-16k-common-pytorch",
                   vad_model_revision="v2.0.1",
                   punc_model="damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch",
-                  punc_model_revision="v2.0.0",
+                  punc_model_revision="v2.0.1",
                   spk_model="damo/speech_campplus_sv_zh-cn_16k-common",
                   spk_model_revision="v2.0.0"
                   )

View File

@@ -4,7 +4,7 @@ model_revision="v2.0.0"
 vad_model="damo/speech_fsmn_vad_zh-cn-16k-common-pytorch"
 vad_model_revision="v2.0.1"
 punc_model="damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch"
-punc_model_revision="v2.0.0"
+punc_model_revision="v2.0.1"
 spk_model="damo/speech_campplus_sv_zh-cn_16k-common"
 spk_model_revision="v2.0.0"

View File

@@ -10,7 +10,7 @@ model = AutoModel(model="damo/speech_seaco_paraformer_large_asr_nat-zh-cn-16k-co
                   vad_model="damo/speech_fsmn_vad_zh-cn-16k-common-pytorch",
                   vad_model_revision="v2.0.1",
                   punc_model="damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch",
-                  punc_model_revision="v2.0.0",
+                  punc_model_revision="v2.0.1",
                   )
 res = model(input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav",

View File

@@ -4,7 +4,7 @@ model_revision="v2.0.0"
 vad_model="damo/speech_fsmn_vad_zh-cn-16k-common-pytorch"
 vad_model_revision="v2.0.1"
 punc_model="damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch"
-punc_model_revision="v2.0.0"
+punc_model_revision="v2.0.1"
 
 python funasr/bin/inference.py \
 +model=${model} \

View File

@@ -37,6 +37,8 @@ def download_from_ms(**kwargs):
             kwargs["model"] = cfg["model"]
         if os.path.exists(os.path.join(model_or_path, "am.mvn")):
             kwargs["frontend_conf"]["cmvn_file"] = os.path.join(model_or_path, "am.mvn")
+        if os.path.exists(os.path.join(model_or_path, "jieba_usr_dict")):
+            kwargs["jieba_usr_dict"] = os.path.join(model_or_path, "jieba_usr_dict")
     else:  # configuration.json
         assert os.path.exists(os.path.join(model_or_path, "configuration.json"))
         with open(os.path.join(model_or_path, "configuration.json"), 'r', encoding='utf-8') as f:
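
For context, the jieba_usr_dict file picked up here is expected to be a standard jieba user dictionary, since it is later passed to jieba.load_userdict (see the ct_transformer hunk below). A hypothetical sketch of that format; the entries are illustrative, not shipped with any model:

import jieba

# jieba's user-dictionary format: one entry per line,
# "word [frequency] [POS tag]", with frequency and tag optional.
with open("jieba_usr_dict", "w", encoding="utf-8") as f:
    f.write("云计算 5\n模型热词 3 n\nFunASR\n")  # made-up entries

jieba.load_userdict("jieba_usr_dict")
print(list(jieba.cut("FunASR支持云计算部署", HMM=False)))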

View File

@@ -225,8 +225,14 @@ class CTTransformer(nn.Module):
         # text = data_in[0]
         # text_lengths = data_lengths[0] if data_lengths is not None else None
         split_size = kwargs.get("split_size", 20)
-        tokens = split_words(text)
+        jieba_usr_dict = kwargs.get("jieba_usr_dict", None)
+        if jieba_usr_dict and isinstance(jieba_usr_dict, str):
+            import jieba
+            jieba.load_userdict(jieba_usr_dict)
+            jieba_usr_dict = jieba
+            kwargs["jieba_usr_dict"] = "jieba_usr_dict"
+        tokens = split_words(text, jieba_usr_dict=jieba_usr_dict)
         tokens_int = tokenizer.encode(tokens)
         mini_sentences = split_to_mini_sentence(tokens, split_size)
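
An aside on the pattern above: the kwarg arrives as a filesystem path, the custom dictionary is registered with jieba on first use, and the jieba module itself then serves as the segmenter handed to split_words. A minimal sketch with assumed names, not repository code:

import jieba

def resolve_segmenter(jieba_usr_dict):
    # A path string means "load this user dictionary once";
    # afterwards the jieba module doubles as the segmenter,
    # since split_words only needs an object exposing .cut().
    if jieba_usr_dict and isinstance(jieba_usr_dict, str):
        jieba.load_userdict(jieba_usr_dict)
        return jieba
    return jieba_usr_dict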

View File

@@ -1,4 +1,4 @@
+import re
+
 def split_to_mini_sentence(words: list, word_limit: int = 20):
     assert word_limit > 1
@@ -14,23 +14,98 @@ def split_to_mini_sentence(words: list, word_limit: int = 20):
     return sentences
 
-def split_words(text: str):
-    words = []
-    segs = text.split()
-    for seg in segs:
-        # There is no space in seg.
-        current_word = ""
-        for c in seg:
-            if len(c.encode()) == 1:
-                # This is an ASCII char.
-                current_word += c
-            else:
-                # This is a Chinese char.
-                if len(current_word) > 0:
-                    words.append(current_word)
-                    current_word = ""
-                words.append(c)
-        if len(current_word) > 0:
-            words.append(current_word)
-
-    return words
+# def split_words(text: str, **kwargs):
+#     words = []
+#     segs = text.split()
+#     for seg in segs:
+#         # There is no space in seg.
+#         current_word = ""
+#         for c in seg:
+#             if len(c.encode()) == 1:
+#                 # This is an ASCII char.
+#                 current_word += c
+#             else:
+#                 # This is a Chinese char.
+#                 if len(current_word) > 0:
+#                     words.append(current_word)
+#                     current_word = ""
+#                 words.append(c)
+#         if len(current_word) > 0:
+#             words.append(current_word)
+#
+#     return words
+
+
+def split_words(text: str, jieba_usr_dict=None, **kwargs):
+    if jieba_usr_dict:
+        input_list = text.split()
+        token_list_all = []
+        langauge_list = []
+        token_list_tmp = []
+        language_flag = None
+        for token in input_list:
+            if isEnglish(token) and language_flag == 'Chinese':
+                token_list_all.append(token_list_tmp)
+                langauge_list.append('Chinese')
+                token_list_tmp = []
+            elif not isEnglish(token) and language_flag == 'English':
+                token_list_all.append(token_list_tmp)
+                langauge_list.append('English')
+                token_list_tmp = []
+            token_list_tmp.append(token)
+            if isEnglish(token):
+                language_flag = 'English'
+            else:
+                language_flag = 'Chinese'
+
+        if token_list_tmp:
+            token_list_all.append(token_list_tmp)
+            langauge_list.append(language_flag)
+
+        result_list = []
+        for token_list_tmp, language_flag in zip(token_list_all, langauge_list):
+            if language_flag == 'English':
+                result_list.extend(token_list_tmp)
+            else:
+                seg_list = jieba_usr_dict.cut(join_chinese_and_english(token_list_tmp), HMM=False)
+                result_list.extend(seg_list)
+
+        return result_list
+    else:
+        words = []
+        segs = text.split()
+        for seg in segs:
+            # There is no space in seg.
+            current_word = ""
+            for c in seg:
+                if len(c.encode()) == 1:
+                    # This is an ASCII char.
+                    current_word += c
+                else:
+                    # This is a Chinese char.
+                    if len(current_word) > 0:
+                        words.append(current_word)
+                        current_word = ""
+                    words.append(c)
+            if len(current_word) > 0:
+                words.append(current_word)
+        return words
+
+
+def isEnglish(text: str):
+    if re.search('^[a-zA-Z\']+$', text):
+        return True
+    else:
+        return False
+
+
+def join_chinese_and_english(input_list):
+    line = ''
+    for token in input_list:
+        if isEnglish(token):
+            line = line + ' ' + token
+        else:
+            line = line + token
+    line = line.strip()
+    return line
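
To make the two code paths concrete, a hedged sketch of calling split_words both ways; the sample sentence is invented, and passing the bare jieba module works because the function only calls .cut() on it:

import jieba
# split_words, isEnglish and join_chinese_and_english as defined above

text = "大家 好 hello world 欢迎 使用 FunASR"

# Fallback path: Chinese is split per character, ASCII runs kept whole.
print(split_words(text))

# jieba path: runs of Chinese tokens are re-joined and segmented by jieba,
# while English tokens pass through untouched.
print(split_words(text, jieba_usr_dict=jieba))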