funasr1.0

游雁 2023-12-26 22:35:27 +08:00
parent b66a41fb73
commit bdc7a17c1f
4 changed files with 148 additions and 11 deletions


@@ -0,0 +1,9 @@
+cmd="funasr/bin/inference.py"
+python $cmd \
+model="/Users/zhifu/Downloads/modelscope_models/punc_ct-transformer_zh-cn-common-vocab272727-pytorch" \
+input="/Users/zhifu/FunASR/egs_modelscope/punctuation/punc_ct-transformer_zh-cn-common-vocab272727-pytorch/data/punc_example.txt" \
+output_dir="/Users/zhifu/Downloads/ckpt/funasr2/exp2_punc" \
+device="cpu" \
+debug="true"
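Note: the same run can be expressed through the Python entry point. A minimal sketch, assuming the funasr1.0 AutoModel API and the same local model directory (paths and sample text are illustrative):

    from funasr import AutoModel

    # assumed local punctuation model directory, as in the script above
    model = AutoModel(model="/Users/zhifu/Downloads/modelscope_models/punc_ct-transformer_zh-cn-common-vocab272727-pytorch")
    res = model.generate(input="跨境河流是养育沿岸人民的生命之源")  # any raw text works
    print(res)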


@@ -26,11 +26,14 @@ def download_fr_ms(**kwargs):
         kwargs["init_param"] = init_param
     if os.path.exists(os.path.join(model_or_path, "tokens.txt")):
         kwargs["tokenizer_conf"]["token_list"] = os.path.join(model_or_path, "tokens.txt")
+    if os.path.exists(os.path.join(model_or_path, "tokens.json")):
+        kwargs["tokenizer_conf"]["token_list"] = os.path.join(model_or_path, "tokens.json")
     if os.path.exists(os.path.join(model_or_path, "seg_dict")):
         kwargs["tokenizer_conf"]["seg_dict"] = os.path.join(model_or_path, "seg_dict")
     if os.path.exists(os.path.join(model_or_path, "bpe.model")):
         kwargs["tokenizer_conf"]["bpemodel"] = os.path.join(model_or_path, "bpe.model")
     kwargs["model"] = cfg["model"]
-    kwargs["frontend_conf"]["cmvn_file"] = os.path.join(model_or_path, "am.mvn")
+    if os.path.exists(os.path.join(model_or_path, "am.mvn")):
+        kwargs["frontend_conf"]["cmvn_file"] = os.path.join(model_or_path, "am.mvn")
     return OmegaConf.to_container(kwargs, resolve=True)
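Note: both tokens.txt and tokens.json write to the same tokenizer_conf["token_list"] key, and the tokens.json check runs second, so it wins when both files exist. A standalone sketch of that precedence (illustrative helper, not part of FunASR):

    import os

    def resolve_token_list(model_dir: str):
        # later match overwrites the earlier one, mirroring the diff above
        token_list = None
        for name in ("tokens.txt", "tokens.json"):
            path = os.path.join(model_dir, name)
            if os.path.exists(path):
                token_list = path
        return token_list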


@@ -1,9 +1,16 @@
 from typing import Any
 from typing import List
 from typing import Tuple
+from typing import Optional
+import numpy as np
+import torch.nn.functional as F
+from funasr.models.transformer.utils.nets_utils import make_pad_mask
+from funasr.train_utils.device_funcs import force_gatherable
+from funasr.train_utils.device_funcs import to_device
 import torch
 import torch.nn as nn
+from funasr.models.ct_transformer.utils import split_to_mini_sentence
 from funasr.register import tables
@@ -17,7 +24,7 @@ class CTTransformer(nn.Module):
     def __init__(
         self,
         encoder: str = None,
-        encoder_conf: str = None,
+        encoder_conf: dict = None,
         vocab_size: int = -1,
         punc_list: list = None,
         punc_weight: list = None,
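Note: encoder_conf is retyped from str to dict because it carries nested encoder hyperparameters rather than a string. A hypothetical value in the usual FunASR/ESPnet config style (key names are placeholders, not taken from this commit):

    encoder_conf = {
        "attention_heads": 8,
        "linear_units": 2048,
        "num_blocks": 4,
        "dropout_rate": 0.1,
    }  # would be passed as CTTransformer(encoder=..., encoder_conf=encoder_conf, ...)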
@@ -191,7 +198,7 @@
         punc_lengths: torch.Tensor,
         vad_indexes: Optional[torch.Tensor] = None,
         vad_indexes_lengths: Optional[torch.Tensor] = None,
-    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
+    ):
         nll, y_lengths = self.nll(text, punc, text_lengths, punc_lengths, vad_indexes=vad_indexes)
         ntokens = y_lengths.sum()
         loss = nll.sum() / ntokens
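Note: the loss above is the summed negative log-likelihood normalized by the number of valid tokens, so padded positions do not dilute it. A toy recomputation (values invented for illustration):

    import torch

    nll = torch.tensor([[0.5, 0.3, 0.2],
                        [0.4, 0.6, 0.0]])   # per-token NLL; last entry is padding
    y_lengths = torch.tensor([3, 2])        # valid tokens per sequence
    ntokens = y_lengths.sum()               # 5
    loss = nll.sum() / ntokens              # 2.0 / 5 = 0.4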
@@ -202,11 +209,115 @@
         return loss, stats, weight

     def generate(self,
-                 text: torch.Tensor,
-                 text_lengths: torch.Tensor,
-                 vad_indexes: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, None]:
-        if self.with_vad():
-            assert vad_indexes is not None
-            return self.punc_forward(text, text_lengths, vad_indexes)
-        else:
-            return self.punc_forward(text, text_lengths)
+                 data_in,
+                 data_lengths=None,
+                 key: list = None,
+                 tokenizer=None,
+                 frontend=None,
+                 **kwargs,
+                 ):
+        vad_indexes = kwargs.get("vad_indexes", None)
+        text = data_in
+        text_lengths = data_lengths
+        split_size = kwargs.get("split_size", 20)
+
+        data = {"text": text}
+        result = self.preprocessor(data=data, uid="12938712838719")
+        split_text = self.preprocessor.pop_split_text_data(result)
+        mini_sentences = split_to_mini_sentence(split_text, split_size)
+        mini_sentences_id = split_to_mini_sentence(data["text"], split_size)
+        assert len(mini_sentences) == len(mini_sentences_id)
+        cache_sent = []
+        cache_sent_id = torch.from_numpy(np.array([], dtype='int32'))
+        new_mini_sentence = ""
+        new_mini_sentence_punc = []
+        cache_pop_trigger_limit = 200
+        for mini_sentence_i in range(len(mini_sentences)):
+            mini_sentence = mini_sentences[mini_sentence_i]
+            mini_sentence_id = mini_sentences_id[mini_sentence_i]
+            mini_sentence = cache_sent + mini_sentence
+            mini_sentence_id = np.concatenate((cache_sent_id, mini_sentence_id), axis=0)
+            data = {
+                "text": torch.unsqueeze(torch.from_numpy(mini_sentence_id), 0),
+                "text_lengths": torch.from_numpy(np.array([len(mini_sentence_id)], dtype='int32')),
+            }
+            data = to_device(data, self.device)
+            # y, _ = self.wrapped_model(**data)
+            y, _ = self.punc_forward(**data)
+            _, indices = y.view(-1, y.shape[-1]).topk(1, dim=1)
+            punctuations = indices
+            if indices.size()[0] != 1:
+                punctuations = torch.squeeze(indices)
+            assert punctuations.size()[0] == len(mini_sentence)
+            # Search for the last Period/QuestionMark as cache
+            if mini_sentence_i < len(mini_sentences) - 1:
+                sentenceEnd = -1
+                last_comma_index = -1
+                for i in range(len(punctuations) - 2, 1, -1):
+                    if self.punc_list[punctuations[i]] == "。" or self.punc_list[punctuations[i]] == "？":
+                        sentenceEnd = i
+                        break
+                    if last_comma_index < 0 and self.punc_list[punctuations[i]] == "，":
+                        last_comma_index = i
+                if sentenceEnd < 0 and len(mini_sentence) > cache_pop_trigger_limit and last_comma_index >= 0:
+                    # The sentence is too long, cut it off at a comma.
+                    sentenceEnd = last_comma_index
+                    punctuations[sentenceEnd] = self.period
+                cache_sent = mini_sentence[sentenceEnd + 1:]
+                cache_sent_id = mini_sentence_id[sentenceEnd + 1:]
+                mini_sentence = mini_sentence[0:sentenceEnd + 1]
+                punctuations = punctuations[0:sentenceEnd + 1]
+            # if len(punctuations) == 0:
+            #     continue
+            punctuations_np = punctuations.cpu().numpy()
+            new_mini_sentence_punc += [int(x) for x in punctuations_np]
+            words_with_punc = []
+            for i in range(len(mini_sentence)):
+                if (i == 0 or self.punc_list[punctuations[i - 1]] == "。" or self.punc_list[punctuations[i - 1]] == "？") and len(mini_sentence[i][0].encode()) == 1:
+                    mini_sentence[i] = mini_sentence[i].capitalize()
+                if i == 0:
+                    if len(mini_sentence[i][0].encode()) == 1:
+                        mini_sentence[i] = " " + mini_sentence[i]
+                if i > 0:
+                    if len(mini_sentence[i][0].encode()) == 1 and len(mini_sentence[i - 1][0].encode()) == 1:
+                        mini_sentence[i] = " " + mini_sentence[i]
+                words_with_punc.append(mini_sentence[i])
+                if self.punc_list[punctuations[i]] != "_":
+                    punc_res = self.punc_list[punctuations[i]]
+                    if len(mini_sentence[i][0].encode()) == 1:
+                        if punc_res == "，":
+                            punc_res = ","
+                        elif punc_res == "。":
+                            punc_res = "."
+                        elif punc_res == "？":
+                            punc_res = "?"
+                    words_with_punc.append(punc_res)
+            new_mini_sentence += "".join(words_with_punc)
+            # Add a period at the end of the sentence
+            new_mini_sentence_out = new_mini_sentence
+            new_mini_sentence_punc_out = new_mini_sentence_punc
+            if mini_sentence_i == len(mini_sentences) - 1:
+                if new_mini_sentence[-1] == "，" or new_mini_sentence[-1] == "、":
+                    new_mini_sentence_out = new_mini_sentence[:-1] + "。"
+                    new_mini_sentence_punc_out = new_mini_sentence_punc[:-1] + [self.period]
+                elif new_mini_sentence[-1] == ",":
+                    new_mini_sentence_out = new_mini_sentence[:-1] + "."
+                    new_mini_sentence_punc_out = new_mini_sentence_punc[:-1] + [self.period]
+                elif new_mini_sentence[-1] != "。" and new_mini_sentence[-1] != "？" and len(new_mini_sentence[-1].encode()) != 1:
+                    new_mini_sentence_out = new_mini_sentence + "。"
+                    new_mini_sentence_punc_out = new_mini_sentence_punc[:-1] + [self.period]
+                elif new_mini_sentence[-1] != "." and new_mini_sentence[-1] != "?" and len(new_mini_sentence[-1].encode()) == 1:
+                    new_mini_sentence_out = new_mini_sentence + "."
+                    new_mini_sentence_punc_out = new_mini_sentence_punc[:-1] + [self.period]
+        return new_mini_sentence_out, new_mini_sentence_punc_out
+        # if self.with_vad():
+        #     assert vad_indexes is not None
+        #     return self.punc_forward(text, text_lengths, vad_indexes)
+        # else:
+        #     return self.punc_forward(text, text_lengths)
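Note: the spacing and capitalization branches above hinge on len(word[0].encode()) == 1, i.e. whether the token's first character is a single UTF-8 byte. A quick illustrative check (not part of the commit):

    # an ASCII character is 1 byte in UTF-8, a Chinese character is 3
    assert len("a".encode()) == 1   # Latin token -> spaced words, ASCII , . ?
    assert len("你".encode()) == 3  # CJK token  -> no spaces, full-width 。，？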


@@ -0,0 +1,14 @@
+def split_to_mini_sentence(words: list, word_limit: int = 20):
+    assert word_limit > 1
+    if len(words) <= word_limit:
+        return [words]
+    sentences = []
+    length = len(words)
+    sentence_len = length // word_limit
+    for i in range(sentence_len):
+        sentences.append(words[i * word_limit:(i + 1) * word_limit])
+    if length % word_limit > 0:
+        sentences.append(words[sentence_len * word_limit:])
+    return sentences
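Note: split_to_mini_sentence chunks a token list into word_limit-sized windows, with any remainder kept as a final shorter chunk. For example:

    words = ["we", "are", "testing", "punctuation", "restoration"]
    print(split_to_mini_sentence(words, word_limit=2))
    # [['we', 'are'], ['testing', 'punctuation'], ['restoration']]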