From bdc7a17c1f3efccb437517e74c780f64923ea647 Mon Sep 17 00:00:00 2001
From: 游雁
Date: Tue, 26 Dec 2023 22:35:27 +0800
Subject: [PATCH] funasr1.0

---
 .../industrial_data_pretraining/punc/infer.sh |   9 ++
 funasr/download/download_from_hub.py          |   5 +-
 funasr/models/ct_transformer/model.py         | 131 ++++++++++++++++--
 funasr/models/ct_transformer/utils.py         |  14 ++
 4 files changed, 148 insertions(+), 11 deletions(-)
 create mode 100644 examples/industrial_data_pretraining/punc/infer.sh
 create mode 100644 funasr/models/ct_transformer/utils.py

diff --git a/examples/industrial_data_pretraining/punc/infer.sh b/examples/industrial_data_pretraining/punc/infer.sh
new file mode 100644
index 000000000..9c4054791
--- /dev/null
+++ b/examples/industrial_data_pretraining/punc/infer.sh
@@ -0,0 +1,9 @@
+
+cmd="funasr/bin/inference.py"
+
+python $cmd \
++model="/Users/zhifu/Downloads/modelscope_models/punc_ct-transformer_zh-cn-common-vocab272727-pytorch" \
++input="/Users/zhifu/FunASR/egs_modelscope/punctuation/punc_ct-transformer_zh-cn-common-vocab272727-pytorch/data/punc_example.txt" \
++output_dir="/Users/zhifu/Downloads/ckpt/funasr2/exp2_punc" \
++device="cpu" \
++debug="true"
diff --git a/funasr/download/download_from_hub.py b/funasr/download/download_from_hub.py
index 2e7578f37..4f05b42ad 100644
--- a/funasr/download/download_from_hub.py
+++ b/funasr/download/download_from_hub.py
@@ -26,12 +26,15 @@ def download_fr_ms(**kwargs):
         kwargs["init_param"] = init_param
     if os.path.exists(os.path.join(model_or_path, "tokens.txt")):
         kwargs["tokenizer_conf"]["token_list"] = os.path.join(model_or_path, "tokens.txt")
+    if os.path.exists(os.path.join(model_or_path, "tokens.json")):
+        kwargs["tokenizer_conf"]["token_list"] = os.path.join(model_or_path, "tokens.json")
     if os.path.exists(os.path.join(model_or_path, "seg_dict")):
         kwargs["tokenizer_conf"]["seg_dict"] = os.path.join(model_or_path, "seg_dict")
     if os.path.exists(os.path.join(model_or_path, "bpe.model")):
         kwargs["tokenizer_conf"]["bpemodel"] = os.path.join(model_or_path, "bpe.model")
     kwargs["model"] = cfg["model"]
-    kwargs["frontend_conf"]["cmvn_file"] = os.path.join(model_or_path, "am.mvn")
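+    # Text-only models such as this punctuation model ship no am.mvn CMVN file,
+    # so only set frontend_conf["cmvn_file"] when the file actually exists.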
+    if os.path.exists(os.path.join(model_or_path, "am.mvn")):
+        kwargs["frontend_conf"]["cmvn_file"] = os.path.join(model_or_path, "am.mvn")
 
     return OmegaConf.to_container(kwargs, resolve=True)
diff --git a/funasr/models/ct_transformer/model.py b/funasr/models/ct_transformer/model.py
index d8c7fc3bf..a1aff4720 100644
--- a/funasr/models/ct_transformer/model.py
+++ b/funasr/models/ct_transformer/model.py
@@ -1,9 +1,16 @@
 from typing import Any
 from typing import List
 from typing import Tuple
+from typing import Optional
+import numpy as np
+import torch.nn.functional as F
+from funasr.models.transformer.utils.nets_utils import make_pad_mask
+from funasr.train_utils.device_funcs import force_gatherable
+from funasr.train_utils.device_funcs import to_device
 import torch
 import torch.nn as nn
+from funasr.models.ct_transformer.utils import split_to_mini_sentence
 
 from funasr.register import tables
 
@@ -17,7 +24,7 @@ class CTTransformer(nn.Module):
     def __init__(
         self,
         encoder: str = None,
-        encoder_conf: str = None,
+        encoder_conf: dict = None,
         vocab_size: int = -1,
         punc_list: list = None,
         punc_weight: list = None,
@@ -191,7 +198,7 @@ class CTTransformer(nn.Module):
         punc_lengths: torch.Tensor,
         vad_indexes: Optional[torch.Tensor] = None,
         vad_indexes_lengths: Optional[torch.Tensor] = None,
-    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
+    ):
         nll, y_lengths = self.nll(text, punc, text_lengths, punc_lengths, vad_indexes=vad_indexes)
         ntokens = y_lengths.sum()
         loss = nll.sum() / ntokens
@@ -202,11 +209,115 @@ class CTTransformer(nn.Module):
         return loss, stats, weight
 
     def generate(self,
-                 text: torch.Tensor,
-                 text_lengths: torch.Tensor,
-                 vad_indexes: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, None]:
-        if self.with_vad():
-            assert vad_indexes is not None
-            return self.punc_forward(text, text_lengths, vad_indexes)
-        else:
-            return self.punc_forward(text, text_lengths)
\ No newline at end of file
+                 data_in,
+                 data_lengths=None,
+                 key: list = None,
+                 tokenizer=None,
+                 frontend=None,
+                 **kwargs,
+                 ):
+        vad_indexes = kwargs.get("vad_indexes", None)
+        text = data_in
+        text_lengths = data_lengths
+        split_size = kwargs.get("split_size", 20)
+
+        data = {"text": text}
+        result = self.preprocessor(data=data, uid="12938712838719")
+        split_text = self.preprocessor.pop_split_text_data(result)
+        mini_sentences = split_to_mini_sentence(split_text, split_size)
+        mini_sentences_id = split_to_mini_sentence(data["text"], split_size)
+        assert len(mini_sentences) == len(mini_sentences_id)
+        cache_sent = []
+        cache_sent_id = torch.from_numpy(np.array([], dtype='int32'))
+        new_mini_sentence = ""
+        new_mini_sentence_punc = []
+        cache_pop_trigger_limit = 200
+        for mini_sentence_i in range(len(mini_sentences)):
+            mini_sentence = mini_sentences[mini_sentence_i]
+            mini_sentence_id = mini_sentences_id[mini_sentence_i]
+            mini_sentence = cache_sent + mini_sentence
+            mini_sentence_id = np.concatenate((cache_sent_id, mini_sentence_id), axis=0)
+            data = {
+                "text": torch.unsqueeze(torch.from_numpy(mini_sentence_id), 0),
+                "text_lengths": torch.from_numpy(np.array([len(mini_sentence_id)], dtype='int32')),
+            }
+            data = to_device(data, self.device)
+            # Run the punctuation model on this mini-sentence (plus the cached tail).
+            y, _ = self.punc_forward(**data)
+            _, indices = y.view(-1, y.shape[-1]).topk(1, dim=1)
+            punctuations = indices
+            if indices.size()[0] != 1:
+                punctuations = torch.squeeze(indices)
+            assert punctuations.size()[0] == len(mini_sentence)
+
+            # Search backwards for the last period/question mark; the tail after it
+            # is carried over as cache for the next mini-sentence.
+            if mini_sentence_i < len(mini_sentences) - 1:
+                sentenceEnd = -1
+                last_comma_index = -1
+                for i in range(len(punctuations) - 2, 1, -1):
+                    if self.punc_list[punctuations[i]] == "。" or self.punc_list[punctuations[i]] == "?":
+                        sentenceEnd = i
+                        break
+                    if last_comma_index < 0 and self.punc_list[punctuations[i]] == ",":
+                        last_comma_index = i
+
+                if sentenceEnd < 0 and len(mini_sentence) > cache_pop_trigger_limit and last_comma_index >= 0:
+                    # The sentence is too long; cut it off at a comma.
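+                    # cache_pop_trigger_limit (200 tokens) keeps the carried-over cache
+                    # from growing without bound when the model predicts no sentence end.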
+                    sentenceEnd = last_comma_index
+                    punctuations[sentenceEnd] = self.period
+                cache_sent = mini_sentence[sentenceEnd + 1:]
+                cache_sent_id = mini_sentence_id[sentenceEnd + 1:]
+                mini_sentence = mini_sentence[0:sentenceEnd + 1]
+                punctuations = punctuations[0:sentenceEnd + 1]
+
+            punctuations_np = punctuations.cpu().numpy()
+            new_mini_sentence_punc += [int(x) for x in punctuations_np]
+            words_with_punc = []
+            for i in range(len(mini_sentence)):
+                # Capitalize an ASCII word that starts a sentence, and insert spaces
+                # between consecutive ASCII words.
+                if (i == 0 or self.punc_list[punctuations[i - 1]] == "。" or self.punc_list[punctuations[i - 1]] == "?") and len(mini_sentence[i][0].encode()) == 1:
+                    mini_sentence[i] = mini_sentence[i].capitalize()
+                if i == 0:
+                    if len(mini_sentence[i][0].encode()) == 1:
+                        mini_sentence[i] = " " + mini_sentence[i]
+                if i > 0:
+                    if len(mini_sentence[i][0].encode()) == 1 and len(mini_sentence[i - 1][0].encode()) == 1:
+                        mini_sentence[i] = " " + mini_sentence[i]
+                words_with_punc.append(mini_sentence[i])
+                if self.punc_list[punctuations[i]] != "_":
+                    punc_res = self.punc_list[punctuations[i]]
+                    # Use half-width punctuation after ASCII words.
+                    if len(mini_sentence[i][0].encode()) == 1:
+                        if punc_res == ",":
+                            punc_res = ","
+                        elif punc_res == "。":
+                            punc_res = "."
+                        elif punc_res == "?":
+                            punc_res = "?"
+                    words_with_punc.append(punc_res)
+            new_mini_sentence += "".join(words_with_punc)
+            # Make sure the very last sentence ends with a period.
+            new_mini_sentence_out = new_mini_sentence
+            new_mini_sentence_punc_out = new_mini_sentence_punc
+            if mini_sentence_i == len(mini_sentences) - 1:
+                if new_mini_sentence[-1] == "," or new_mini_sentence[-1] == "、":
+                    new_mini_sentence_out = new_mini_sentence[:-1] + "。"
+                    new_mini_sentence_punc_out = new_mini_sentence_punc[:-1] + [self.period]
+                elif new_mini_sentence[-1] == ",":
+                    new_mini_sentence_out = new_mini_sentence[:-1] + "."
+                    new_mini_sentence_punc_out = new_mini_sentence_punc[:-1] + [self.period]
+                elif new_mini_sentence[-1] != "。" and new_mini_sentence[-1] != "?" and len(new_mini_sentence[-1].encode()) > 1:
+                    new_mini_sentence_out = new_mini_sentence + "。"
+                    new_mini_sentence_punc_out = new_mini_sentence_punc[:-1] + [self.period]
+                elif new_mini_sentence[-1] != "." and new_mini_sentence[-1] != "?" and len(new_mini_sentence[-1].encode()) == 1:
+                    new_mini_sentence_out = new_mini_sentence + "."
+                    new_mini_sentence_punc_out = new_mini_sentence_punc[:-1] + [self.period]
+
+        return new_mini_sentence_out, new_mini_sentence_punc_out
diff --git a/funasr/models/ct_transformer/utils.py b/funasr/models/ct_transformer/utils.py
new file mode 100644
index 000000000..0291dbc43
--- /dev/null
+++ b/funasr/models/ct_transformer/utils.py
@@ -0,0 +1,14 @@
+
+
+def split_to_mini_sentence(words: list, word_limit: int = 20):
+    assert word_limit > 1
+    if len(words) <= word_limit:
+        return [words]
+    sentences = []
+    length = len(words)
+    sentence_len = length // word_limit
+    for i in range(sentence_len):
+        sentences.append(words[i * word_limit:(i + 1) * word_limit])
+    if length % word_limit > 0:
+        sentences.append(words[sentence_len * word_limit:])
+    return sentences
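
For reference, split_to_mini_sentence chunks the token list into windows of at
most word_limit items; generate then punctuates one window at a time, carrying
any unfinished tail forward via cache_sent. A minimal sketch of the expected
behavior (illustrative only, not part of the patch):

    >>> from funasr.models.ct_transformer.utils import split_to_mini_sentence
    >>> split_to_mini_sentence(list("abcdefg"), word_limit=3)
    [['a', 'b', 'c'], ['d', 'e', 'f'], ['g']]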