This commit is contained in:
九耳 2023-03-30 14:20:35 +08:00
parent 62aab4a6da
commit 19bda23f5e

View File

@ -32,8 +32,7 @@ class TargetDelayTransformer():
self.ort_infer = OrtInferSession(model_file, device_id, intra_op_num_threads=intra_op_num_threads)
self.batch_size = 1
self.encoder_conf = config["encoder_conf"]
self.punc_list = config.punc_list
self.punc_list = config['punc_list']
self.period = 0
for i in range(len(self.punc_list)):
if self.punc_list[i] == ",":
@ -44,13 +43,13 @@ class TargetDelayTransformer():
self.period = i
self.preprocessor = CodeMixTokenizerCommonPreprocessor(
train=False,
token_type=config.token_type,
token_list=config.token_list,
bpemodel=config.bpemodel,
text_cleaner=config.cleaner,
g2p_type=config.g2p,
token_type=config['token_type'],
token_list=config['token_list'],
bpemodel=config['bpemodel'],
text_cleaner=config['cleaner'],
g2p_type=config['g2p'],
text_name="text",
non_linguistic_symbols=config.non_linguistic_symbols,
non_linguistic_symbols=config['non_linguistic_symbols'],
)
def __call__(self, text: Union[list, str], split_size=20):