funasr1.0

This commit is contained in:
游雁 2023-12-26 22:35:27 +08:00
parent b66a41fb73
commit bdc7a17c1f
4 changed files with 148 additions and 11 deletions

View File

@ -0,0 +1,9 @@
cmd="funasr/bin/inference.py"
python $cmd \
+model="/Users/zhifu/Downloads/modelscope_models/punc_ct-transformer_zh-cn-common-vocab272727-pytorch" \
+input="/Users/zhifu/FunASR/egs_modelscope/punctuation/punc_ct-transformer_zh-cn-common-vocab272727-pytorch/data/punc_example.txt" \
+output_dir="/Users/zhifu/Downloads/ckpt/funasr2/exp2_punc" \
+device="cpu" \
+debug="true"

View File

@ -26,11 +26,14 @@ def download_fr_ms(**kwargs):
kwargs["init_param"] = init_param
if os.path.exists(os.path.join(model_or_path, "tokens.txt")):
kwargs["tokenizer_conf"]["token_list"] = os.path.join(model_or_path, "tokens.txt")
if os.path.exists(os.path.join(model_or_path, "tokens.json")):
kwargs["tokenizer_conf"]["token_list"] = os.path.join(model_or_path, "tokens.json")
if os.path.exists(os.path.join(model_or_path, "seg_dict")):
kwargs["tokenizer_conf"]["seg_dict"] = os.path.join(model_or_path, "seg_dict")
if os.path.exists(os.path.join(model_or_path, "bpe.model")):
kwargs["tokenizer_conf"]["bpemodel"] = os.path.join(model_or_path, "bpe.model")
kwargs["model"] = cfg["model"]
if os.path.exists(os.path.join(model_or_path, "am.mvn")):
kwargs["frontend_conf"]["cmvn_file"] = os.path.join(model_or_path, "am.mvn")
return OmegaConf.to_container(kwargs, resolve=True)

View File

@ -1,9 +1,16 @@
from typing import Any
from typing import List
from typing import Tuple
from typing import Optional
import numpy as np
import torch.nn.functional as F
from funasr.models.transformer.utils.nets_utils import make_pad_mask
from funasr.train_utils.device_funcs import force_gatherable
from funasr.train_utils.device_funcs import to_device
import torch
import torch.nn as nn
from funasr.models.ct_transformer.utils import split_to_mini_sentence
from funasr.register import tables
@ -17,7 +24,7 @@ class CTTransformer(nn.Module):
def __init__(
self,
encoder: str = None,
encoder_conf: str = None,
encoder_conf: dict = None,
vocab_size: int = -1,
punc_list: list = None,
punc_weight: list = None,
@ -191,7 +198,7 @@ class CTTransformer(nn.Module):
punc_lengths: torch.Tensor,
vad_indexes: Optional[torch.Tensor] = None,
vad_indexes_lengths: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
):
nll, y_lengths = self.nll(text, punc, text_lengths, punc_lengths, vad_indexes=vad_indexes)
ntokens = y_lengths.sum()
loss = nll.sum() / ntokens
@ -202,11 +209,115 @@ class CTTransformer(nn.Module):
return loss, stats, weight
def generate(self,
text: torch.Tensor,
text_lengths: torch.Tensor,
vad_indexes: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, None]:
if self.with_vad():
assert vad_indexes is not None
return self.punc_forward(text, text_lengths, vad_indexes)
else:
return self.punc_forward(text, text_lengths)
data_in,
data_lengths=None,
key: list = None,
tokenizer=None,
frontend=None,
**kwargs,
):
vad_indexes = kwargs.get("vad_indexes", None)
text = data_in
text_lengths = data_lengths
split_size = kwargs.get("split_size", 20)
data = {"text": text}
result = self.preprocessor(data=data, uid="12938712838719")
split_text = self.preprocessor.pop_split_text_data(result)
mini_sentences = split_to_mini_sentence(split_text, split_size)
mini_sentences_id = split_to_mini_sentence(data["text"], split_size)
assert len(mini_sentences) == len(mini_sentences_id)
cache_sent = []
cache_sent_id = torch.from_numpy(np.array([], dtype='int32'))
new_mini_sentence = ""
new_mini_sentence_punc = []
cache_pop_trigger_limit = 200
for mini_sentence_i in range(len(mini_sentences)):
mini_sentence = mini_sentences[mini_sentence_i]
mini_sentence_id = mini_sentences_id[mini_sentence_i]
mini_sentence = cache_sent + mini_sentence
mini_sentence_id = np.concatenate((cache_sent_id, mini_sentence_id), axis=0)
data = {
"text": torch.unsqueeze(torch.from_numpy(mini_sentence_id), 0),
"text_lengths": torch.from_numpy(np.array([len(mini_sentence_id)], dtype='int32')),
}
data = to_device(data, self.device)
# y, _ = self.wrapped_model(**data)
y, _ = self.punc_forward(text, text_lengths)
_, indices = y.view(-1, y.shape[-1]).topk(1, dim=1)
punctuations = indices
if indices.size()[0] != 1:
punctuations = torch.squeeze(indices)
assert punctuations.size()[0] == len(mini_sentence)
# Search for the last Period/QuestionMark as cache
if mini_sentence_i < len(mini_sentences) - 1:
sentenceEnd = -1
last_comma_index = -1
for i in range(len(punctuations) - 2, 1, -1):
if self.punc_list[punctuations[i]] == "" or self.punc_list[punctuations[i]] == "":
sentenceEnd = i
break
if last_comma_index < 0 and self.punc_list[punctuations[i]] == "":
last_comma_index = i
if sentenceEnd < 0 and len(mini_sentence) > cache_pop_trigger_limit and last_comma_index >= 0:
# The sentence it too long, cut off at a comma.
sentenceEnd = last_comma_index
punctuations[sentenceEnd] = self.period
cache_sent = mini_sentence[sentenceEnd + 1:]
cache_sent_id = mini_sentence_id[sentenceEnd + 1:]
mini_sentence = mini_sentence[0:sentenceEnd + 1]
punctuations = punctuations[0:sentenceEnd + 1]
# if len(punctuations) == 0:
# continue
punctuations_np = punctuations.cpu().numpy()
new_mini_sentence_punc += [int(x) for x in punctuations_np]
words_with_punc = []
for i in range(len(mini_sentence)):
if (i==0 or self.punc_list[punctuations[i-1]] == "" or self.punc_list[punctuations[i-1]] == "") and len(mini_sentence[i][0].encode()) == 1:
mini_sentence[i] = mini_sentence[i].capitalize()
if i == 0:
if len(mini_sentence[i][0].encode()) == 1:
mini_sentence[i] = " " + mini_sentence[i]
if i > 0:
if len(mini_sentence[i][0].encode()) == 1 and len(mini_sentence[i - 1][0].encode()) == 1:
mini_sentence[i] = " " + mini_sentence[i]
words_with_punc.append(mini_sentence[i])
if self.punc_list[punctuations[i]] != "_":
punc_res = self.punc_list[punctuations[i]]
if len(mini_sentence[i][0].encode()) == 1:
if punc_res == "":
punc_res = ","
elif punc_res == "":
punc_res = "."
elif punc_res == "":
punc_res = "?"
words_with_punc.append(punc_res)
new_mini_sentence += "".join(words_with_punc)
# Add Period for the end of the sentence
new_mini_sentence_out = new_mini_sentence
new_mini_sentence_punc_out = new_mini_sentence_punc
if mini_sentence_i == len(mini_sentences) - 1:
if new_mini_sentence[-1] == "" or new_mini_sentence[-1] == "":
new_mini_sentence_out = new_mini_sentence[:-1] + ""
new_mini_sentence_punc_out = new_mini_sentence_punc[:-1] + [self.period]
elif new_mini_sentence[-1] == ",":
new_mini_sentence_out = new_mini_sentence[:-1] + "."
new_mini_sentence_punc_out = new_mini_sentence_punc[:-1] + [self.period]
elif new_mini_sentence[-1] != "" and new_mini_sentence[-1] != "" and len(new_mini_sentence[-1].encode())==0:
new_mini_sentence_out = new_mini_sentence + ""
new_mini_sentence_punc_out = new_mini_sentence_punc[:-1] + [self.period]
elif new_mini_sentence[-1] != "." and new_mini_sentence[-1] != "?" and len(new_mini_sentence[-1].encode())==1:
new_mini_sentence_out = new_mini_sentence + "."
new_mini_sentence_punc_out = new_mini_sentence_punc[:-1] + [self.period]
return new_mini_sentence_out, new_mini_sentence_punc_out
# if self.with_vad():
# assert vad_indexes is not None
# return self.punc_forward(text, text_lengths, vad_indexes)
# else:
# return self.punc_forward(text, text_lengths)

View File

@ -0,0 +1,14 @@
def split_to_mini_sentence(words: list, word_limit: int = 20):
assert word_limit > 1
if len(words) <= word_limit:
return [words]
sentences = []
length = len(words)
sentence_len = length // word_limit
for i in range(sentence_len):
sentences.append(words[i * word_limit:(i + 1) * word_limit])
if length % word_limit > 0:
sentences.append(words[sentence_len * word_limit:])
return sentences