mirror of
https://github.com/modelscope/FunASR
synced 2025-09-15 14:48:36 +08:00
infer for word punc model
This commit is contained in:
parent
294c1162df
commit
e451eb799a
@ -3,6 +3,7 @@
|
||||
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
|
||||
# MIT License (https://opensource.org/licenses/MIT)
|
||||
|
||||
import copy
|
||||
import torch
|
||||
import numpy as np
|
||||
import torch.nn.functional as F
|
||||
@ -341,11 +342,25 @@ class CTTransformer(torch.nn.Module):
|
||||
new_mini_sentence_out = new_mini_sentence + "."
|
||||
new_mini_sentence_punc_out = new_mini_sentence_punc[:-1] + [self.sentence_end_id]
|
||||
if len(punctuations): punctuations[-1] = 2
|
||||
# keep a punctuations array for punc segment
|
||||
# keep a punctuations array for punc segment
|
||||
if punc_array is None:
|
||||
punc_array = punctuations
|
||||
else:
|
||||
punc_array = torch.cat([punc_array, punctuations], dim=0)
|
||||
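# Post-processing for the word-level punc model (active when jieba_usr_dict is set):
# for each multi-character token whose first character lies in U+0E00..U+9FA5,
# insert len(token) - 1 extra entries with id 1 at that token's position in
# punc_array, so such a token ends up with one entry per character.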
|
||||
if jieba_usr_dict:
|
||||
len_tokens = len(tokens)
|
||||
new_punc_array = copy.copy(punc_array).tolist()
|
||||
# for i, (token, punc_id) in enumerate(zip(tokens[::-1], punc_array.tolist()[::-1])):
|
||||
for i, token in enumerate(tokens[::-1]):
|
||||
if '\u0e00' <= token[0] <= '\u9fa5': # ignore en words
|
||||
if len(token) > 1:
|
||||
num_append = len(token) - 1
|
||||
ind_append = len_tokens - i - 1
|
||||
for _ in range(num_append):
|
||||
new_punc_array.insert(ind_append, 1)
|
||||
punc_array = torch.tensor(new_punc_array)
|
||||
|
||||
result_i = {"key": key[0], "text": new_mini_sentence_out, "punc_array": punc_array}
|
||||
results.append(result_i)
|
||||
return results, meta_data
|
||||
|
||||
Loading…
Reference in New Issue
Block a user