mirror of
https://github.com/modelscope/FunASR
synced 2025-09-15 14:48:36 +08:00
funasr1.0
This commit is contained in:
parent
b66a41fb73
commit
bdc7a17c1f
9
examples/industrial_data_pretraining/punc/infer.sh
Normal file
9
examples/industrial_data_pretraining/punc/infer.sh
Normal file
@ -0,0 +1,9 @@
|
||||
|
||||
cmd="funasr/bin/inference.py"
|
||||
|
||||
python $cmd \
|
||||
+model="/Users/zhifu/Downloads/modelscope_models/punc_ct-transformer_zh-cn-common-vocab272727-pytorch" \
|
||||
+input="/Users/zhifu/FunASR/egs_modelscope/punctuation/punc_ct-transformer_zh-cn-common-vocab272727-pytorch/data/punc_example.txt" \
|
||||
+output_dir="/Users/zhifu/Downloads/ckpt/funasr2/exp2_punc" \
|
||||
+device="cpu" \
|
||||
+debug="true"
|
||||
@ -26,12 +26,15 @@ def download_fr_ms(**kwargs):
|
||||
kwargs["init_param"] = init_param
|
||||
if os.path.exists(os.path.join(model_or_path, "tokens.txt")):
|
||||
kwargs["tokenizer_conf"]["token_list"] = os.path.join(model_or_path, "tokens.txt")
|
||||
if os.path.exists(os.path.join(model_or_path, "tokens.json")):
|
||||
kwargs["tokenizer_conf"]["token_list"] = os.path.join(model_or_path, "tokens.json")
|
||||
if os.path.exists(os.path.join(model_or_path, "seg_dict")):
|
||||
kwargs["tokenizer_conf"]["seg_dict"] = os.path.join(model_or_path, "seg_dict")
|
||||
if os.path.exists(os.path.join(model_or_path, "bpe.model")):
|
||||
kwargs["tokenizer_conf"]["bpemodel"] = os.path.join(model_or_path, "bpe.model")
|
||||
kwargs["model"] = cfg["model"]
|
||||
kwargs["frontend_conf"]["cmvn_file"] = os.path.join(model_or_path, "am.mvn")
|
||||
if os.path.exists(os.path.join(model_or_path, "am.mvn")):
|
||||
kwargs["frontend_conf"]["cmvn_file"] = os.path.join(model_or_path, "am.mvn")
|
||||
|
||||
return OmegaConf.to_container(kwargs, resolve=True)
|
||||
|
||||
|
||||
@ -1,9 +1,16 @@
|
||||
from typing import Any
|
||||
from typing import List
|
||||
from typing import Tuple
|
||||
from typing import Optional
|
||||
import numpy as np
|
||||
import torch.nn.functional as F
|
||||
|
||||
from funasr.models.transformer.utils.nets_utils import make_pad_mask
|
||||
from funasr.train_utils.device_funcs import force_gatherable
|
||||
from funasr.train_utils.device_funcs import to_device
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from funasr.models.ct_transformer.utils import split_to_mini_sentence
|
||||
|
||||
from funasr.register import tables
|
||||
|
||||
@ -17,7 +24,7 @@ class CTTransformer(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
encoder: str = None,
|
||||
encoder_conf: str = None,
|
||||
encoder_conf: dict = None,
|
||||
vocab_size: int = -1,
|
||||
punc_list: list = None,
|
||||
punc_weight: list = None,
|
||||
@ -191,7 +198,7 @@ class CTTransformer(nn.Module):
|
||||
punc_lengths: torch.Tensor,
|
||||
vad_indexes: Optional[torch.Tensor] = None,
|
||||
vad_indexes_lengths: Optional[torch.Tensor] = None,
|
||||
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
|
||||
):
|
||||
nll, y_lengths = self.nll(text, punc, text_lengths, punc_lengths, vad_indexes=vad_indexes)
|
||||
ntokens = y_lengths.sum()
|
||||
loss = nll.sum() / ntokens
|
||||
@ -202,11 +209,115 @@ class CTTransformer(nn.Module):
|
||||
return loss, stats, weight
|
||||
|
||||
def generate(self,
|
||||
text: torch.Tensor,
|
||||
text_lengths: torch.Tensor,
|
||||
vad_indexes: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, None]:
|
||||
if self.with_vad():
|
||||
assert vad_indexes is not None
|
||||
return self.punc_forward(text, text_lengths, vad_indexes)
|
||||
else:
|
||||
return self.punc_forward(text, text_lengths)
|
||||
data_in,
|
||||
data_lengths=None,
|
||||
key: list = None,
|
||||
tokenizer=None,
|
||||
frontend=None,
|
||||
**kwargs,
|
||||
):
|
||||
vad_indexes = kwargs.get("vad_indexes", None)
|
||||
text = data_in
|
||||
text_lengths = data_lengths
|
||||
split_size = kwargs.get("split_size", 20)
|
||||
|
||||
data = {"text": text}
|
||||
result = self.preprocessor(data=data, uid="12938712838719")
|
||||
split_text = self.preprocessor.pop_split_text_data(result)
|
||||
mini_sentences = split_to_mini_sentence(split_text, split_size)
|
||||
mini_sentences_id = split_to_mini_sentence(data["text"], split_size)
|
||||
assert len(mini_sentences) == len(mini_sentences_id)
|
||||
cache_sent = []
|
||||
cache_sent_id = torch.from_numpy(np.array([], dtype='int32'))
|
||||
new_mini_sentence = ""
|
||||
new_mini_sentence_punc = []
|
||||
cache_pop_trigger_limit = 200
|
||||
for mini_sentence_i in range(len(mini_sentences)):
|
||||
mini_sentence = mini_sentences[mini_sentence_i]
|
||||
mini_sentence_id = mini_sentences_id[mini_sentence_i]
|
||||
mini_sentence = cache_sent + mini_sentence
|
||||
mini_sentence_id = np.concatenate((cache_sent_id, mini_sentence_id), axis=0)
|
||||
data = {
|
||||
"text": torch.unsqueeze(torch.from_numpy(mini_sentence_id), 0),
|
||||
"text_lengths": torch.from_numpy(np.array([len(mini_sentence_id)], dtype='int32')),
|
||||
}
|
||||
data = to_device(data, self.device)
|
||||
# y, _ = self.wrapped_model(**data)
|
||||
y, _ = self.punc_forward(text, text_lengths)
|
||||
_, indices = y.view(-1, y.shape[-1]).topk(1, dim=1)
|
||||
punctuations = indices
|
||||
if indices.size()[0] != 1:
|
||||
punctuations = torch.squeeze(indices)
|
||||
assert punctuations.size()[0] == len(mini_sentence)
|
||||
|
||||
# Search for the last Period/QuestionMark as cache
|
||||
if mini_sentence_i < len(mini_sentences) - 1:
|
||||
sentenceEnd = -1
|
||||
last_comma_index = -1
|
||||
for i in range(len(punctuations) - 2, 1, -1):
|
||||
if self.punc_list[punctuations[i]] == "。" or self.punc_list[punctuations[i]] == "?":
|
||||
sentenceEnd = i
|
||||
break
|
||||
if last_comma_index < 0 and self.punc_list[punctuations[i]] == ",":
|
||||
last_comma_index = i
|
||||
|
||||
if sentenceEnd < 0 and len(mini_sentence) > cache_pop_trigger_limit and last_comma_index >= 0:
|
||||
# The sentence it too long, cut off at a comma.
|
||||
sentenceEnd = last_comma_index
|
||||
punctuations[sentenceEnd] = self.period
|
||||
cache_sent = mini_sentence[sentenceEnd + 1:]
|
||||
cache_sent_id = mini_sentence_id[sentenceEnd + 1:]
|
||||
mini_sentence = mini_sentence[0:sentenceEnd + 1]
|
||||
punctuations = punctuations[0:sentenceEnd + 1]
|
||||
|
||||
# if len(punctuations) == 0:
|
||||
# continue
|
||||
|
||||
punctuations_np = punctuations.cpu().numpy()
|
||||
new_mini_sentence_punc += [int(x) for x in punctuations_np]
|
||||
words_with_punc = []
|
||||
for i in range(len(mini_sentence)):
|
||||
if (i==0 or self.punc_list[punctuations[i-1]] == "。" or self.punc_list[punctuations[i-1]] == "?") and len(mini_sentence[i][0].encode()) == 1:
|
||||
mini_sentence[i] = mini_sentence[i].capitalize()
|
||||
if i == 0:
|
||||
if len(mini_sentence[i][0].encode()) == 1:
|
||||
mini_sentence[i] = " " + mini_sentence[i]
|
||||
if i > 0:
|
||||
if len(mini_sentence[i][0].encode()) == 1 and len(mini_sentence[i - 1][0].encode()) == 1:
|
||||
mini_sentence[i] = " " + mini_sentence[i]
|
||||
words_with_punc.append(mini_sentence[i])
|
||||
if self.punc_list[punctuations[i]] != "_":
|
||||
punc_res = self.punc_list[punctuations[i]]
|
||||
if len(mini_sentence[i][0].encode()) == 1:
|
||||
if punc_res == ",":
|
||||
punc_res = ","
|
||||
elif punc_res == "。":
|
||||
punc_res = "."
|
||||
elif punc_res == "?":
|
||||
punc_res = "?"
|
||||
words_with_punc.append(punc_res)
|
||||
new_mini_sentence += "".join(words_with_punc)
|
||||
# Add Period for the end of the sentence
|
||||
new_mini_sentence_out = new_mini_sentence
|
||||
new_mini_sentence_punc_out = new_mini_sentence_punc
|
||||
if mini_sentence_i == len(mini_sentences) - 1:
|
||||
if new_mini_sentence[-1] == "," or new_mini_sentence[-1] == "、":
|
||||
new_mini_sentence_out = new_mini_sentence[:-1] + "。"
|
||||
new_mini_sentence_punc_out = new_mini_sentence_punc[:-1] + [self.period]
|
||||
elif new_mini_sentence[-1] == ",":
|
||||
new_mini_sentence_out = new_mini_sentence[:-1] + "."
|
||||
new_mini_sentence_punc_out = new_mini_sentence_punc[:-1] + [self.period]
|
||||
elif new_mini_sentence[-1] != "。" and new_mini_sentence[-1] != "?" and len(new_mini_sentence[-1].encode())==0:
|
||||
new_mini_sentence_out = new_mini_sentence + "。"
|
||||
new_mini_sentence_punc_out = new_mini_sentence_punc[:-1] + [self.period]
|
||||
elif new_mini_sentence[-1] != "." and new_mini_sentence[-1] != "?" and len(new_mini_sentence[-1].encode())==1:
|
||||
new_mini_sentence_out = new_mini_sentence + "."
|
||||
new_mini_sentence_punc_out = new_mini_sentence_punc[:-1] + [self.period]
|
||||
|
||||
return new_mini_sentence_out, new_mini_sentence_punc_out
|
||||
|
||||
# if self.with_vad():
|
||||
# assert vad_indexes is not None
|
||||
# return self.punc_forward(text, text_lengths, vad_indexes)
|
||||
# else:
|
||||
# return self.punc_forward(text, text_lengths)
|
||||
14
funasr/models/ct_transformer/utils.py
Normal file
14
funasr/models/ct_transformer/utils.py
Normal file
@ -0,0 +1,14 @@
|
||||
|
||||
|
||||
def split_to_mini_sentence(words: list, word_limit: int = 20):
|
||||
assert word_limit > 1
|
||||
if len(words) <= word_limit:
|
||||
return [words]
|
||||
sentences = []
|
||||
length = len(words)
|
||||
sentence_len = length // word_limit
|
||||
for i in range(sentence_len):
|
||||
sentences.append(words[i * word_limit:(i + 1) * word_limit])
|
||||
if length % word_limit > 0:
|
||||
sentences.append(words[sentence_len * word_limit:])
|
||||
return sentences
|
||||
Loading…
Reference in New Issue
Block a user