diff --git a/examples/industrial_data_pretraining/paraformer-large-long/infer.sh b/examples/industrial_data_pretraining/paraformer-large-long/infer.sh
index d77329e6a..2e6ec0dba 100644
--- a/examples/industrial_data_pretraining/paraformer-large-long/infer.sh
+++ b/examples/industrial_data_pretraining/paraformer-large-long/infer.sh
@@ -4,6 +4,7 @@ cmd="funasr/bin/inference.py"
 python $cmd \
 +model="/Users/zhifu/Downloads/modelscope_models/speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-pytorch" \
 +vad_model="/Users/zhifu/Downloads/modelscope_models/speech_fsmn_vad_zh-cn-16k-common-pytorch" \
++punc_model="/Users/zhifu/Downloads/modelscope_models/punc_ct-transformer_zh-cn-common-vocab272727-pytorch" \
 +input="/Users/zhifu/funasr_github/test_local/vad_example.wav" \
 +output_dir="/Users/zhifu/Downloads/ckpt/funasr2/exp2" \
 +device="cpu" \
diff --git a/examples/industrial_data_pretraining/punc/infer.sh b/examples/industrial_data_pretraining/punc/infer.sh
index 9c4054791..367581502 100644
--- a/examples/industrial_data_pretraining/punc/infer.sh
+++ b/examples/industrial_data_pretraining/punc/infer.sh
@@ -2,8 +2,17 @@
 cmd="funasr/bin/inference.py"
 
 python $cmd \
-+model="/Users/zhifu/Downloads/modelscope_models/punc_ct-transformer_zh-cn-common-vocab272727-pytorch" \
 +input="/Users/zhifu/FunASR/egs_modelscope/punctuation/punc_ct-transformer_zh-cn-common-vocab272727-pytorch/data/punc_example.txt" \
++model="/Users/zhifu/Downloads/modelscope_models/punc_ct-transformer_zh-cn-common-vocab272727-pytorch" \
 +output_dir="/Users/zhifu/Downloads/ckpt/funasr2/exp2_punc" \
 +device="cpu" \
 +debug="true"
+
+
+#+input="/Users/zhifu/FunASR/egs_modelscope/punctuation/punc_ct-transformer_zh-cn-common-vocab272727-pytorch/data/punc_example.txt" \
+
+#+"input='跨境河流是养育沿岸人民的生命之源长期以来为帮助下游地区防灾减灾中方技术人员在上游地区极为恶劣的自然条件下克服巨大困难甚至冒着生命危险向印方提供汛期水文资料处理紧急事件中方重视印方在跨境河流问题上的关切愿意进一步完善双方联合工作机制凡是中方能做的我们都会去做而且会做得更好我请印度朋友们放心中国在上游的任何开发利用都会经过科学规划和论证兼顾上下游的利益'" \
+
+#+input="/Users/zhifu/FunASR/egs_modelscope/punctuation/punc_ct-transformer_zh-cn-common-vocab272727-pytorch/data/punc_example.txt" \
+
+#+"input='那今天的会就到这里吧 happy new year 明年见'" \
\ No newline at end of file
diff --git a/funasr/bin/inference.py b/funasr/bin/inference.py
index fda7abea9..16ad0e2a9 100644
--- a/funasr/bin/inference.py
+++ b/funasr/bin/inference.py
@@ -18,6 +18,7 @@ import string
 from funasr.register import tables
 from funasr.datasets.audio_datasets.load_audio_extract_fbank import load_audio
 from funasr.utils.vad_utils import slice_padding_audio_samples
+from funasr.utils.timestamp_tools import time_stamp_sentence
 
 def build_iter_for_infer(data_in, input_len=None, data_type="sound"):
     """
@@ -46,7 +47,7 @@ def build_iter_for_infer(data_in, input_len=None, data_type="sound"):
                     data = lines["source"]
                     key = data["key"] if "key" in data else key
                 else: # filelist, wav.scp, text.txt: id \t data or data
-                    lines = line.strip().split()
+                    lines = line.strip().split(maxsplit=1)
                     data = lines[1] if len(lines)>1 else lines[0]
                     key = lines[0] if len(lines)>1 else key
 
@@ -227,6 +228,7 @@ class AutoModel:
         # step.1: compute the vad model
         model = self.vad_model
         kwargs = self.vad_kwargs
+        kwargs.update(cfg)
         beg_vad = time.time()
         res = self.generate(input, input_len=input_len, model=model, kwargs=kwargs, **cfg)
         end_vad = time.time()
@@ -322,6 +324,23 @@ class AutoModel:
                 result["key"] = key
                 results_ret_list.append(result)
                 pbar_total.update(1)
+
+        # step.3 compute punc model
+        model = self.punc_model
+        kwargs = self.punc_kwargs
+        kwargs.update(cfg)
+
+        for i, result in enumerate(results_ret_list):
+            beg_punc = time.time()
+            res = self.generate(result["text"], model=model, kwargs=kwargs, **cfg)
+            end_punc = time.time()
+            print(f"time punc: {end_punc - beg_punc:0.3f}")
+
+            # sentences = time_stamp_sentence(model.punc_list, model.sentence_end_id, results_ret_list[i]["timestamp"], res[i]["text"])
+            # results_ret_list[i]["time_stamp"] = res[0]["text_postprocessed_punc"]
+            # results_ret_list[i]["sentences"] = sentences
+            # results_ret_list[i]["text_with_punc"] = res[i]["text"]
+            pbar_total.update(1)
 
         end_total = time.time()
         time_escape_total_all_samples = end_total - beg_total
diff --git a/funasr/models/bicif_paraformer/model.py b/funasr/models/bicif_paraformer/model.py
index 03c889646..c37ba12f2 100644
--- a/funasr/models/bicif_paraformer/model.py
+++ b/funasr/models/bicif_paraformer/model.py
@@ -29,7 +29,7 @@ from funasr.utils.datadir_writer import DatadirWriter
 from funasr.utils.timestamp_tools import ts_prediction_lfr6_standard
 from funasr.register import tables
 from funasr.models.ctc.ctc import CTC
-from funasr.utils.timestamp_tools import time_stamp_sentence
+
 
 from funasr.models.paraformer.model import Paraformer
@@ -321,18 +324,16 @@
 
             text_postprocessed, time_stamp_postprocessed, word_lists = postprocess_utils.sentence_postprocess(
                 token, timestamp)
-            sentences = time_stamp_sentence(None, time_stamp_postprocessed, text_postprocessed)
-            result_i = {"key": key[i], "token": token, "text": text, "text_postprocessed": text_postprocessed,
+
+            result_i = {"key": key[i], "text": text_postprocessed,
                         "timestamp": time_stamp_postprocessed,
-                        "word_lists": word_lists,
-                        "sentences": sentences
                         }
 
             if ibest_writer is not None:
                 ibest_writer["token"][key[i]] = " ".join(token)
-                ibest_writer["text"][key[i]] = text
+                # ibest_writer["text"][key[i]] = text
                 ibest_writer["timestamp"][key[i]] = time_stamp_postprocessed
-                ibest_writer["text_postprocessed"][key[i]] = text_postprocessed
+                ibest_writer["text"][key[i]] = text_postprocessed
         else:
             result_i = {"key": key[i], "token_int": token_int}
         results.append(result_i)
diff --git a/funasr/models/ct_transformer/model.py b/funasr/models/ct_transformer/model.py
index a1aff4720..24a6aea68 100644
--- a/funasr/models/ct_transformer/model.py
+++ b/funasr/models/ct_transformer/model.py
@@ -10,7 +10,7 @@ from funasr.train_utils.device_funcs import force_gatherable
 from funasr.train_utils.device_funcs import to_device
 import torch
 import torch.nn as nn
-from funasr.models.ct_transformer.utils import split_to_mini_sentence
+from funasr.models.ct_transformer.utils import split_to_mini_sentence, split_words
 from funasr.register import tables
 
 
@@ -34,6 +34,7 @@ class CTTransformer(nn.Module):
         ignore_id: int = -1,
         sos: int = 1,
         eos: int = 2,
+        sentence_end_id: int = 3,
         **kwargs,
     ):
         super().__init__()
@@ -54,10 +55,11 @@ class CTTransformer(nn.Module):
         self.ignore_id = ignore_id
         self.sos = sos
         self.eos = eos
+        self.sentence_end_id = sentence_end_id
 
 
-    def punc_forward(self, input: torch.Tensor, text_lengths: torch.Tensor) -> Tuple[torch.Tensor, None]:
+    def punc_forward(self, text: torch.Tensor, text_lengths: torch.Tensor) -> Tuple[torch.Tensor, None]:
         """Compute loss value from buffer sequences.
 
         Args:
@@ -65,7 +67,7 @@ class CTTransformer(nn.Module):
             hidden (torch.Tensor): Target ids. (batch, len)
         """
-        x = self.embed(input)
+        x = self.embed(text)
         # mask = self._target_mask(input)
         h, _, _ = self.encoder(x, text_lengths)
         y = self.decoder(h)
@@ -216,22 +218,26 @@
         frontend=None,
         **kwargs,
     ):
+        assert len(data_in) == 1
+        vad_indexes = kwargs.get("vad_indexes", None)
-        text = data_in
-        text_lengths = data_lengths
+        text = data_in[0]
+        text_lengths = data_lengths[0] if data_lengths is not None else None
         split_size = kwargs.get("split_size", 20)
 
-        data = {"text": text}
-        result = self.preprocessor(data=data, uid="12938712838719")
-        split_text = self.preprocessor.pop_split_text_data(result)
-        mini_sentences = split_to_mini_sentence(split_text, split_size)
-        mini_sentences_id = split_to_mini_sentence(data["text"], split_size)
+        tokens = split_words(text)
+        tokens_int = tokenizer.encode(tokens)
+
+        mini_sentences = split_to_mini_sentence(tokens, split_size)
+        mini_sentences_id = split_to_mini_sentence(tokens_int, split_size)
         assert len(mini_sentences) == len(mini_sentences_id)
         cache_sent = []
         cache_sent_id = torch.from_numpy(np.array([], dtype='int32'))
         new_mini_sentence = ""
         new_mini_sentence_punc = []
         cache_pop_trigger_limit = 200
 
+        results = []
+        meta_data = {}
         for mini_sentence_i in range(len(mini_sentences)):
             mini_sentence = mini_sentences[mini_sentence_i]
             mini_sentence_id = mini_sentences_id[mini_sentence_i]
@@ -241,9 +247,9 @@
                 "text": torch.unsqueeze(torch.from_numpy(mini_sentence_id), 0),
                 "text_lengths": torch.from_numpy(np.array([len(mini_sentence_id)], dtype='int32')),
             }
-            data = to_device(data, self.device)
+            data = to_device(data, kwargs["device"])
             # y, _ = self.wrapped_model(**data)
-            y, _ = self.punc_forward(text, text_lengths)
+            y, _ = self.punc_forward(**data)
             _, indices = y.view(-1, y.shape[-1]).topk(1, dim=1)
             punctuations = indices
             if indices.size()[0] != 1:
@@ -264,7 +270,7 @@
             if sentenceEnd < 0 and len(mini_sentence) > cache_pop_trigger_limit and last_comma_index >= 0:
                 # The sentence is too long, cut off at a comma.
                 sentenceEnd = last_comma_index
-                punctuations[sentenceEnd] = self.period
+                punctuations[sentenceEnd] = self.sentence_end_id
             cache_sent = mini_sentence[sentenceEnd + 1:]
             cache_sent_id = mini_sentence_id[sentenceEnd + 1:]
             mini_sentence = mini_sentence[0:sentenceEnd + 1]
@@ -303,21 +309,19 @@
             if mini_sentence_i == len(mini_sentences) - 1:
                 if new_mini_sentence[-1] == "，" or new_mini_sentence[-1] == "、":
                     new_mini_sentence_out = new_mini_sentence[:-1] + "。"
-                    new_mini_sentence_punc_out = new_mini_sentence_punc[:-1] + [self.period]
+                    new_mini_sentence_punc_out = new_mini_sentence_punc[:-1] + [self.sentence_end_id]
                 elif new_mini_sentence[-1] == ",":
                     new_mini_sentence_out = new_mini_sentence[:-1] + "."
-                    new_mini_sentence_punc_out = new_mini_sentence_punc[:-1] + [self.period]
+                    new_mini_sentence_punc_out = new_mini_sentence_punc[:-1] + [self.sentence_end_id]
                 elif new_mini_sentence[-1] != "。" and new_mini_sentence[-1] != "？" and len(new_mini_sentence[-1].encode())==3:
                     new_mini_sentence_out = new_mini_sentence + "。"
-                    new_mini_sentence_punc_out = new_mini_sentence_punc[:-1] + [self.period]
+                    new_mini_sentence_punc_out = new_mini_sentence_punc[:-1] + [self.sentence_end_id]
                 elif new_mini_sentence[-1] != "." and new_mini_sentence[-1] != "?" and len(new_mini_sentence[-1].encode())==1:
                     new_mini_sentence_out = new_mini_sentence + "."
-                    new_mini_sentence_punc_out = new_mini_sentence_punc[:-1] + [self.period]
-
-        return new_mini_sentence_out, new_mini_sentence_punc_out
-
-        # if self.with_vad():
-        #     assert vad_indexes is not None
-        #     return self.punc_forward(text, text_lengths, vad_indexes)
-        # else:
-        #     return self.punc_forward(text, text_lengths)
\ No newline at end of file
+                    new_mini_sentence_punc_out = new_mini_sentence_punc[:-1] + [self.sentence_end_id]
+
+        result_i = {"key": key[0], "text": new_mini_sentence_out}
+        results.append(result_i)
+
+        return results, meta_data
+
diff --git a/funasr/models/ct_transformer/template.yaml b/funasr/models/ct_transformer/template.yaml
new file mode 100644
index 000000000..cad04be38
--- /dev/null
+++ b/funasr/models/ct_transformer/template.yaml
@@ -0,0 +1,52 @@
+# This is an example that demonstrates how to configure a model file.
+# You can modify the configuration according to your own requirements.
+
+# to print the register_table:
+# from funasr.register import tables
+# tables.print()
+
+model: CTTransformer
+model_conf:
+    ignore_id: 0
+    embed_unit: 256
+    att_unit: 256
+    dropout_rate: 0.1
+    punc_list:
+    - <unk>
+    - _
+    - ','
+    - 。
+    - '?'
+    - 、
+    punc_weight:
+    - 1.0
+    - 1.0
+    - 1.0
+    - 1.0
+    - 1.0
+    - 1.0
+
+encoder: SANMEncoder
+encoder_conf:
+    input_size: 256
+    output_size: 256
+    attention_heads: 8
+    linear_units: 1024
+    num_blocks: 4
+    dropout_rate: 0.1
+    positional_dropout_rate: 0.1
+    attention_dropout_rate: 0.0
+    input_layer: pe
+    pos_enc_class: SinusoidalPositionEncoder
+    normalize_before: true
+    kernel_size: 11
+    sanm_shfit: 0
+    selfattention_layer_type: sanm
+    padding_idx: 0
+
+tokenizer: CharTokenizer
+tokenizer_conf:
+    unk_symbol: <unk>
+
+
+
diff --git a/funasr/models/ct_transformer/utils.py b/funasr/models/ct_transformer/utils.py
index 0291dbc43..a4a00e0f7 100644
--- a/funasr/models/ct_transformer/utils.py
+++ b/funasr/models/ct_transformer/utils.py
@@ -12,3 +12,25 @@ def split_to_mini_sentence(words: list, word_limit: int = 20):
     if length % word_limit > 0:
         sentences.append(words[sentence_len * word_limit:])
     return sentences
+
+
+def split_words(text: str):
+    words = []
+    segs = text.split()
+    for seg in segs:
+        # There is no space in seg.
+        current_word = ""
+        for c in seg:
+            if len(c.encode()) == 1:
+                # This is an ASCII char.
+                current_word += c
+            else:
+                # This is a Chinese char.
+                if len(current_word) > 0:
+                    words.append(current_word)
+                    current_word = ""
+                words.append(c)
+        if len(current_word) > 0:
+            words.append(current_word)
+
+    return words
diff --git a/funasr/models/paraformer/model.py b/funasr/models/paraformer/model.py
index d92d08d5c..1caed90f3 100644
--- a/funasr/models/paraformer/model.py
+++ b/funasr/models/paraformer/model.py
@@ -535,13 +535,13 @@
             text = tokenizer.tokens2text(token)
 
             text_postprocessed, _ = postprocess_utils.sentence_postprocess(token)
-            result_i = {"key": key[i], "token": token, "text": text, "text_postprocessed": text_postprocessed}
+            result_i = {"key": key[i], "text_postprocessed": text_postprocessed}
 
             if ibest_writer is not None:
                 ibest_writer["token"][key[i]] = " ".join(token)
-                ibest_writer["text"][key[i]] = text
-                ibest_writer["text_postprocessed"][key[i]] = text_postprocessed
+                # ibest_writer["text"][key[i]] = text
+                ibest_writer["text"][key[i]] = text_postprocessed
         else:
             result_i = {"key": key[i], "token_int": token_int}
         results.append(result_i)
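The `split(maxsplit=1)` change in `build_iter_for_infer` is what keeps text payloads intact for inputs like punc_example.txt, where the data after the utterance id itself contains spaces. A minimal before/after sketch mirroring the changed lines; the "utt1" key and sample line are hypothetical:

```python
# A wav.scp/text.txt style line: "<key> <payload>", payload contains spaces.
line = "utt1 那今天的会就到这里吧 happy new year 明年见"

# Old behaviour: split() fragments the payload, keeping only its first word.
lines = line.strip().split()
data = lines[1] if len(lines) > 1 else lines[0]
# data == "那今天的会就到这里吧"

# New behaviour: split at most once, so the full payload survives as data.
lines = line.strip().split(maxsplit=1)
data = lines[1] if len(lines) > 1 else lines[0]
# data == "那今天的会就到这里吧 happy new year 明年见"
```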
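A quick illustration of the new text front end added above: `split_words` keeps ASCII runs as whole words while splitting CJK text per character, and `split_to_mini_sentence` then chunks the token list for mini-sentence decoding. The sample string and `word_limit` are chosen only to make the behavior visible, assuming FunASR is importable:

```python
from funasr.models.ct_transformer.utils import split_to_mini_sentence, split_words

# Whitespace separates segments; within a segment, each CJK char becomes its
# own token while consecutive ASCII chars are kept together.
tokens = split_words("那今天的会就到这里吧 happy new year 明年见")
print(tokens)
# ['那', '今', '天', '的', '会', '就', '到', '这', '里', '吧',
#  'happy', 'new', 'year', '明', '年', '见']

# The runtime default is split_size=20; word_limit=3 just shows the chunking.
print(split_to_mini_sentence(tokens, word_limit=3))
# [['那', '今', '天'], ['的', '会', '就'], ['到', '这', '里'],
#  ['吧', 'happy', 'new'], ['year', '明', '年'], ['见']]
```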
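One invariant ties template.yaml to the new `sentence_end_id: 3` constructor default: the inference loop writes `self.sentence_end_id` wherever it forces a sentence break, so index 3 of `punc_list` must be the full stop 。. A small sanity check, with the list literal copied from the template above:

```python
punc_list = ["<unk>", "_", ",", "。", "?", "、"]  # order as in template.yaml
sentence_end_id = 3  # CTTransformer default added in this change
assert punc_list[sentence_end_id] == "。"
```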