This commit is contained in:
维石 2024-04-10 11:37:27 +08:00
parent 112c8e6eb7
commit b8bf792ce7

View File

@@ -347,8 +347,10 @@ class CTTransformer(torch.nn.Module):
punc_array = punctuations
else:
punc_array = torch.cat([punc_array, punctuations], dim=0)
# post-processing when using a word-level punctuation model
if self.jieba_usr_dict is not None:
punc_array = punc_array.reshape(-1)
len_tokens = len(tokens)
new_punc_array = copy.copy(punc_array).tolist()
# for i, (token, punc_id) in enumerate(zip(tokens[::-1], punc_array.tolist()[::-1])):