Merge branch 'dev_cmz2' of github.com:alibaba-damo-academy/FunASR into dev_cmz2

add
This commit is contained in:
游雁 2023-03-30 16:35:03 +08:00
commit c5acc04e2d
2 changed files with 5 additions and 7 deletions

View File

@@ -4,6 +4,6 @@ model_dir = "/disk1/mengzhe.cmz/workspace/FunASR/funasr/export/damo/punc_ct-tran
model = TargetDelayTransformer(model_dir)
text_in = "我们都是木头人不会讲话不会动"
text_in="跨境河流是养育沿岸人民的生命之源长期以来为帮助下游地区防灾减灾中方技术人员在上游地区极为恶劣的自然条件下克服巨大困难甚至冒着生命危险向印方提供汛期水文资料处理紧急事件中方重视印方在跨境河流问题上的关切愿意进一步完善双方联合工作机制凡是中方能做的我们都会去做而且会做得更好我请印度朋友们放心中国在上游的任何开发利用都会经过科学规划和论证兼顾上下游的利益"
result = model(text_in)
print(result)
print(result[0])

View File

@@ -76,9 +76,8 @@ class TargetDelayTransformer():
try:
outputs = self.infer(data['text'], data['text_lengths'])
y = outputs[0]
-_, indices = y.view(-1, y.shape[-1]).topk(1, dim=1)
-punctuations = indices
-assert punctuations.size()[0] == len(mini_sentence)
+punctuations = np.argmax(y,axis=-1)[0]
+assert punctuations.size == len(mini_sentence)
except ONNXRuntimeError:
logging.warning("error")
@@ -102,8 +101,7 @@ class TargetDelayTransformer():
mini_sentence = mini_sentence[0:sentenceEnd + 1]
punctuations = punctuations[0:sentenceEnd + 1]
-punctuations_np = punctuations.cpu().numpy()
-new_mini_sentence_punc += [int(x) for x in punctuations_np]
+new_mini_sentence_punc += [int(x) for x in punctuations]
words_with_punc = []
for i in range(len(mini_sentence)):
if i > 0: