This commit is contained in:
九耳 2023-03-30 16:48:55 +08:00
parent c5acc04e2d
commit 0fd9640ced
2 changed files with 26 additions and 18 deletions

View File

@ -8,8 +8,7 @@ import numpy as np
from .utils.utils import (ONNXRuntimeError,
OrtInferSession, get_logger,
read_yaml)
from .utils.preprocessor import CodeMixTokenizerCommonPreprocessor
from .utils.utils import split_to_mini_sentence
from .utils.utils import (TokenIDConverter, split_to_mini_sentence,code_mix_split_words)
logging = get_logger()
@ -30,6 +29,7 @@ class TargetDelayTransformer():
config_file = os.path.join(model_dir, 'punc.yaml')
config = read_yaml(config_file)
self.converter = TokenIDConverter(config['token_list'])
self.ort_infer = OrtInferSession(model_file, device_id, intra_op_num_threads=intra_op_num_threads)
self.batch_size = 1
self.punc_list = config['punc_list']
@ -41,23 +41,12 @@ class TargetDelayTransformer():
self.punc_list[i] = ""
elif self.punc_list[i] == "":
self.period = i
self.preprocessor = CodeMixTokenizerCommonPreprocessor(
train=False,
token_type=config['token_type'],
token_list=config['token_list'],
bpemodel=config['bpemodel'],
text_cleaner=config['cleaner'],
g2p_type=config['g2p'],
text_name="text",
non_linguistic_symbols=config['non_linguistic_symbols'],
)
def __call__(self, text: Union[list, str], split_size=20):
data = {"text": text}
result = self.preprocessor(data=data, uid="12938712838719")
split_text = self.preprocessor.pop_split_text_data(result)
split_text = code_mix_split_words(text)
split_text_id = self.converter.tokens2ids(split_text)
mini_sentences = split_to_mini_sentence(split_text, split_size)
mini_sentences_id = split_to_mini_sentence(data["text"], split_size)
mini_sentences_id = split_to_mini_sentence(split_text_id, split_size)
assert len(mini_sentences) == len(mini_sentences_id)
cache_sent = []
cache_sent_id = []
@ -68,9 +57,9 @@ class TargetDelayTransformer():
mini_sentence = mini_sentences[mini_sentence_i]
mini_sentence_id = mini_sentences_id[mini_sentence_i]
mini_sentence = cache_sent + mini_sentence
mini_sentence_id = np.concatenate((cache_sent_id, mini_sentence_id), axis=0)
mini_sentence_id = np.array(cache_sent_id + mini_sentence_id, dtype='int64')
data = {
"text": mini_sentence_id[None,:].astype(np.int64),
"text": mini_sentence_id[None,:],
"text_lengths": np.array([len(mini_sentence_id)], dtype='int32'),
}
try:

View File

@ -228,6 +228,25 @@ def split_to_mini_sentence(words: list, word_limit: int = 20):
sentences.append(words[sentence_len * word_limit:])
return sentences
def code_mix_split_words(text: str):
words = []
segs = text.split()
for seg in segs:
# There is no space in seg.
current_word = ""
for c in seg:
if len(c.encode()) == 1:
# This is an ASCII char.
current_word += c
else:
# This is a Chinese char.
if len(current_word) > 0:
words.append(current_word)
current_word = ""
words.append(c)
if len(current_word) > 0:
words.append(current_word)
return words
def read_yaml(yaml_path: Union[str, Path]) -> Dict:
if not Path(yaml_path).exists():