diff --git a/examples/industrial_data_pretraining/bicif_paraformer/demo.py b/examples/industrial_data_pretraining/bicif_paraformer/demo.py
index 84b0e80b8..4d921ea46 100644
--- a/examples/industrial_data_pretraining/bicif_paraformer/demo.py
+++ b/examples/industrial_data_pretraining/bicif_paraformer/demo.py
@@ -10,7 +10,7 @@ model = AutoModel(model="damo/speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k
                   vad_model="damo/speech_fsmn_vad_zh-cn-16k-common-pytorch",
                   vad_model_revision="v2.0.1",
                   punc_model="damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch",
-                  punc_model_revision="v2.0.0",
+                  punc_model_revision="v2.0.1",
                   spk_model="/Users/shixian/code/modelscope_models/speech_campplus_sv_zh-cn_16k-common",
                   )
@@ -23,7 +23,7 @@ model = AutoModel(model="damo/speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k
                   vad_model="damo/speech_fsmn_vad_zh-cn-16k-common-pytorch",
                   vad_model_revision="v2.0.1",
                   punc_model="damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch",
-                  punc_model_revision="v2.0.0",
+                  punc_model_revision="v2.0.1",
                   spk_model="/Users/shixian/code/modelscope_models/speech_campplus_sv_zh-cn_16k-common",
                   spk_mode='punc_segment',
                   )
diff --git a/examples/industrial_data_pretraining/bicif_paraformer/infer.sh b/examples/industrial_data_pretraining/bicif_paraformer/infer.sh
index 04cb6f2aa..57c5838d1 100644
--- a/examples/industrial_data_pretraining/bicif_paraformer/infer.sh
+++ b/examples/industrial_data_pretraining/bicif_paraformer/infer.sh
@@ -4,7 +4,7 @@ model_revision="v2.0.0"
 vad_model="damo/speech_fsmn_vad_zh-cn-16k-common-pytorch"
 vad_model_revision="v2.0.0"
 punc_model="damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch"
-punc_model_revision="v2.0.0"
+punc_model_revision="v2.0.1"
 
 python funasr/bin/inference.py \
 +model=${model} \
diff --git a/examples/industrial_data_pretraining/ct_transformer/demo.py b/examples/industrial_data_pretraining/ct_transformer/demo.py
index 58ebd2aef..23965e017 100644
--- a/examples/industrial_data_pretraining/ct_transformer/demo.py
+++ b/examples/industrial_data_pretraining/ct_transformer/demo.py
@@ -5,7 +5,15 @@
 
 from funasr import AutoModel
 
-model = AutoModel(model="damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch", model_revision="v2.0.0")
+model = AutoModel(model="damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch", model_revision="v2.0.1")
+
+res = model(input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_text/punc_example.txt")
+print(res)
+
+
+from funasr import AutoModel
+
+model = AutoModel(model="damo/punc_ct-transformer_cn-en-common-vocab471067-large", model_revision="v2.0.1")
 
 res = model(input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_text/punc_example.txt")
 print(res)
\ No newline at end of file
diff --git a/examples/industrial_data_pretraining/ct_transformer/infer.sh b/examples/industrial_data_pretraining/ct_transformer/infer.sh
index a48d56208..4b4e94954 100644
--- a/examples/industrial_data_pretraining/ct_transformer/infer.sh
+++ b/examples/industrial_data_pretraining/ct_transformer/infer.sh
@@ -1,6 +1,9 @@
 model="damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch"
-model_revision="v2.0.0"
+model_revision="v2.0.1"
+
+model="damo/punc_ct-transformer_cn-en-common-vocab471067-large"
+model_revision="v2.0.1"
 
 python funasr/bin/inference.py \
 +model=${model} \
diff --git a/examples/industrial_data_pretraining/paraformer-zh-spk/demo.py b/examples/industrial_data_pretraining/paraformer-zh-spk/demo.py
index 774d757e1..fcf5f60ca 100644
--- a/examples/industrial_data_pretraining/paraformer-zh-spk/demo.py
+++ b/examples/industrial_data_pretraining/paraformer-zh-spk/demo.py
@@ -10,7 +10,7 @@ model = AutoModel(model="damo/speech_seaco_paraformer_large_asr_nat-zh-cn-16k-co
                   vad_model="damo/speech_fsmn_vad_zh-cn-16k-common-pytorch",
                   vad_model_revision="v2.0.1",
                   punc_model="damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch",
-                  punc_model_revision="v2.0.0",
+                  punc_model_revision="v2.0.1",
                   spk_model="damo/speech_campplus_sv_zh-cn_16k-common",
                   spk_model_revision="v2.0.0"
                   )
diff --git a/examples/industrial_data_pretraining/paraformer-zh-spk/infer.sh b/examples/industrial_data_pretraining/paraformer-zh-spk/infer.sh
index a45740194..63347b6a1 100644
--- a/examples/industrial_data_pretraining/paraformer-zh-spk/infer.sh
+++ b/examples/industrial_data_pretraining/paraformer-zh-spk/infer.sh
@@ -4,7 +4,7 @@ model_revision="v2.0.0"
 vad_model="damo/speech_fsmn_vad_zh-cn-16k-common-pytorch"
 vad_model_revision="v2.0.1"
 punc_model="damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch"
-punc_model_revision="v2.0.0"
+punc_model_revision="v2.0.1"
 spk_model="damo/speech_campplus_sv_zh-cn_16k-common"
 spk_model_revision="v2.0.0"
diff --git a/examples/industrial_data_pretraining/seaco_paraformer/demo.py b/examples/industrial_data_pretraining/seaco_paraformer/demo.py
index 63f155eb2..3b5963a4b 100644
--- a/examples/industrial_data_pretraining/seaco_paraformer/demo.py
+++ b/examples/industrial_data_pretraining/seaco_paraformer/demo.py
@@ -10,7 +10,7 @@ model = AutoModel(model="damo/speech_seaco_paraformer_large_asr_nat-zh-cn-16k-co
                   vad_model="damo/speech_fsmn_vad_zh-cn-16k-common-pytorch",
                   vad_model_revision="v2.0.1",
                   punc_model="damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch",
-                  punc_model_revision="v2.0.0",
+                  punc_model_revision="v2.0.1",
                   )
 
 res = model(input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav",
diff --git a/examples/industrial_data_pretraining/seaco_paraformer/infer.sh b/examples/industrial_data_pretraining/seaco_paraformer/infer.sh
index 26eeee1d3..c46449f74 100644
--- a/examples/industrial_data_pretraining/seaco_paraformer/infer.sh
+++ b/examples/industrial_data_pretraining/seaco_paraformer/infer.sh
@@ -4,7 +4,7 @@ model_revision="v2.0.0"
 vad_model="damo/speech_fsmn_vad_zh-cn-16k-common-pytorch"
 vad_model_revision="v2.0.1"
 punc_model="damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch"
-punc_model_revision="v2.0.0"
+punc_model_revision="v2.0.1"
 
 python funasr/bin/inference.py \
 +model=${model} \
diff --git a/funasr/download/download_from_hub.py b/funasr/download/download_from_hub.py
index 946572fce..27bd79d81 100644
--- a/funasr/download/download_from_hub.py
+++ b/funasr/download/download_from_hub.py
@@ -37,6 +37,8 @@ def download_from_ms(**kwargs):
         kwargs["model"] = cfg["model"]
         if os.path.exists(os.path.join(model_or_path, "am.mvn")):
             kwargs["frontend_conf"]["cmvn_file"] = os.path.join(model_or_path, "am.mvn")
+        if os.path.exists(os.path.join(model_or_path, "jieba_usr_dict")):
+            kwargs["jieba_usr_dict"] = os.path.join(model_or_path, "jieba_usr_dict")
     else:# configuration.json
         assert os.path.exists(os.path.join(model_or_path, "configuration.json"))
         with open(os.path.join(model_or_path, "configuration.json"), 'r', encoding='utf-8') as f:
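Note: with the change above, download_from_ms() forwards a user dictionary into the model kwargs whenever a file named jieba_usr_dict ships inside the downloaded model directory. For orientation, a jieba user dictionary is plain text with one entry per line: word, optional frequency, optional part-of-speech tag. A minimal sketch of the format and how jieba consumes it (assumes jieba is installed; the entry, file name, and expected output are illustrative):

    # Write a one-entry user dictionary and load it the same way the patch does.
    import jieba

    with open("jieba_usr_dict", "w", encoding="utf-8") as f:
        f.write("云原生 100 n\n")  # one entry per line: word [freq] [POS]

    jieba.load_userdict("jieba_usr_dict")
    # HMM=False keeps segmentation dictionary-driven, so the entry is respected.
    print(list(jieba.cut("云原生应用", HMM=False)))  # -> ['云原生', '应用']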
diff --git a/funasr/models/ct_transformer/model.py b/funasr/models/ct_transformer/model.py
index fbf180408..d84368636 100644
--- a/funasr/models/ct_transformer/model.py
+++ b/funasr/models/ct_transformer/model.py
@@ -225,8 +225,14 @@ class CTTransformer(nn.Module):
         # text = data_in[0]
         # text_lengths = data_lengths[0] if data_lengths is not None else None
         split_size = kwargs.get("split_size", 20)
-
-        tokens = split_words(text)
+
+        jieba_usr_dict = kwargs.get("jieba_usr_dict", None)
+        if jieba_usr_dict and isinstance(jieba_usr_dict, str):
+            import jieba
+            jieba.load_userdict(jieba_usr_dict)  # load once on the first call
+            jieba_usr_dict = jieba  # hand the loaded module to split_words
+            kwargs["jieba_usr_dict"] = jieba  # cache the module so later calls skip the reload
+        tokens = split_words(text, jieba_usr_dict=jieba_usr_dict)
         tokens_int = tokenizer.encode(tokens)
 
         mini_sentences = split_to_mini_sentence(tokens, split_size)
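Note: the block above loads the dictionary lazily. On the first inference call the kwarg is still a file path, so it is parsed once and the loaded jieba module is stored back into kwargs; as long as the same kwargs mapping is reused across calls, the isinstance check then fails and the file is not re-parsed. A standalone sketch of that pattern, assuming jieba is installed (punctuate is a hypothetical stand-in for CTTransformer.inference):

    from funasr.models.ct_transformer.utils import split_words

    def punctuate(text: str, **kwargs):
        # Reduced to the lazy-load-and-cache pattern introduced by this patch.
        usr_dict = kwargs.get("jieba_usr_dict", None)
        if usr_dict and isinstance(usr_dict, str):
            import jieba
            jieba.load_userdict(usr_dict)   # first call: usr_dict is a file path
            usr_dict = jieba                # afterwards: the loaded module itself
            kwargs["jieba_usr_dict"] = jieba
        return split_words(text, jieba_usr_dict=usr_dict)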
diff --git a/funasr/models/ct_transformer/utils.py b/funasr/models/ct_transformer/utils.py
index a4a00e0f7..917f2e035 100644
--- a/funasr/models/ct_transformer/utils.py
+++ b/funasr/models/ct_transformer/utils.py
@@ -1,4 +1,4 @@
-
+import re
 
 def split_to_mini_sentence(words: list, word_limit: int = 20):
     assert word_limit > 1
@@ -14,23 +14,98 @@ def split_to_mini_sentence(words: list, word_limit: int = 20):
     return sentences
 
 
-def split_words(text: str):
-    words = []
-    segs = text.split()
-    for seg in segs:
-        # There is no space in seg.
-        current_word = ""
-        for c in seg:
-            if len(c.encode()) == 1:
-                # This is an ASCII char.
-                current_word += c
-            else:
-                # This is a Chinese char.
-                if len(current_word) > 0:
-                    words.append(current_word)
-                    current_word = ""
-                words.append(c)
-        if len(current_word) > 0:
-            words.append(current_word)
-
-    return words
+# def split_words(text: str, **kwargs):
+#     words = []
+#     segs = text.split()
+#     for seg in segs:
+#         # There is no space in seg.
+#         current_word = ""
+#         for c in seg:
+#             if len(c.encode()) == 1:
+#                 # This is an ASCII char.
+#                 current_word += c
+#             else:
+#                 # This is a Chinese char.
+#                 if len(current_word) > 0:
+#                     words.append(current_word)
+#                     current_word = ""
+#                 words.append(c)
+#         if len(current_word) > 0:
+#             words.append(current_word)
+#
+#     return words
+
+def split_words(text: str, jieba_usr_dict=None, **kwargs):
+    if jieba_usr_dict:
+        input_list = text.split()
+        token_list_all = []
+        language_list = []
+        token_list_tmp = []
+        language_flag = None
+        for token in input_list:
+            if isEnglish(token) and language_flag == 'Chinese':
+                token_list_all.append(token_list_tmp)
+                language_list.append('Chinese')
+                token_list_tmp = []
+            elif not isEnglish(token) and language_flag == 'English':
+                token_list_all.append(token_list_tmp)
+                language_list.append('English')
+                token_list_tmp = []
+
+            token_list_tmp.append(token)
+
+            if isEnglish(token):
+                language_flag = 'English'
+            else:
+                language_flag = 'Chinese'
+
+        if token_list_tmp:
+            token_list_all.append(token_list_tmp)
+            language_list.append(language_flag)
+
+        result_list = []
+        for token_list_tmp, language_flag in zip(token_list_all, language_list):
+            if language_flag == 'English':
+                result_list.extend(token_list_tmp)
+            else:
+                seg_list = jieba_usr_dict.cut(join_chinese_and_english(token_list_tmp), HMM=False)
+                result_list.extend(seg_list)
+
+        return result_list
+
+    else:
+        words = []
+        segs = text.split()
+        for seg in segs:
+            # There is no space in seg.
+            current_word = ""
+            for c in seg:
+                if len(c.encode()) == 1:
+                    # This is an ASCII char.
+                    current_word += c
+                else:
+                    # This is a Chinese char.
+                    if len(current_word) > 0:
+                        words.append(current_word)
+                        current_word = ""
+                    words.append(c)
+            if len(current_word) > 0:
+                words.append(current_word)
+        return words
+
+def isEnglish(text: str):
+    if re.search("^[a-zA-Z']+$", text):
+        return True
+    else:
+        return False
+
+def join_chinese_and_english(input_list):
+    line = ''
+    for token in input_list:
+        if isEnglish(token):
+            line = line + ' ' + token
+        else:
+            line = line + token
+
+    line = line.strip()
+    return line
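Note: the rewritten split_words keeps runs of English tokens intact, re-joins each contiguous Chinese span with join_chinese_and_english, and segments only those spans through the passed-in jieba module (HMM=False, so user-dictionary entries win); without a segmenter it falls back to the old per-character Chinese split. A minimal usage sketch, assuming jieba is installed (the sample string and outputs are illustrative but follow directly from the code above):

    import jieba
    from funasr.models.ct_transformer.utils import split_words

    # Fallback path: Chinese split per character, ASCII runs kept whole.
    print(split_words("hello 世界"))                        # ['hello', '世', '界']

    # jieba path: the Chinese span is segmented into words.
    print(split_words("hello 世界", jieba_usr_dict=jieba))  # ['hello', '世界']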