diff --git a/funasr/runtime/python/onnxruntime/funasr_onnx/punc_bin.py b/funasr/runtime/python/onnxruntime/funasr_onnx/punc_bin.py index 64ced69be..8ea4517ec 100644 --- a/funasr/runtime/python/onnxruntime/funasr_onnx/punc_bin.py +++ b/funasr/runtime/python/onnxruntime/funasr_onnx/punc_bin.py @@ -32,8 +32,7 @@ class TargetDelayTransformer(): self.ort_infer = OrtInferSession(model_file, device_id, intra_op_num_threads=intra_op_num_threads) self.batch_size = 1 - self.encoder_conf = config["encoder_conf"] - self.punc_list = config.punc_list + self.punc_list = config['punc_list'] self.period = 0 for i in range(len(self.punc_list)): if self.punc_list[i] == ",": @@ -44,13 +43,13 @@ class TargetDelayTransformer(): self.period = i self.preprocessor = CodeMixTokenizerCommonPreprocessor( train=False, - token_type=config.token_type, - token_list=config.token_list, - bpemodel=config.bpemodel, - text_cleaner=config.cleaner, - g2p_type=config.g2p, + token_type=config['token_type'], + token_list=config['token_list'], + bpemodel=config['bpemodel'], + text_cleaner=config['cleaner'], + g2p_type=config['g2p'], text_name="text", - non_linguistic_symbols=config.non_linguistic_symbols, + non_linguistic_symbols=config['non_linguistic_symbols'], ) def __call__(self, text: Union[list, str], split_size=20):