diff --git a/funasr/models/ct_transformer/model.py b/funasr/models/ct_transformer/model.py index 57a23ccb2..edb398ff3 100644 --- a/funasr/models/ct_transformer/model.py +++ b/funasr/models/ct_transformer/model.py @@ -347,8 +347,10 @@ class CTTransformer(torch.nn.Module): punc_array = punctuations else: punc_array = torch.cat([punc_array, punctuations], dim=0) + # post processing when using word level punc model if self.jieba_usr_dict is not None: + punc_array = punc_array.reshape(-1) len_tokens = len(tokens) new_punc_array = copy.copy(punc_array).tolist() # for i, (token, punc_id) in enumerate(zip(tokens[::-1], punc_array.tolist()[::-1])):