This commit is contained in:
九耳 2023-03-30 16:15:49 +08:00
parent 4b3c492998
commit 7df8452a85

View File

@ -76,9 +76,8 @@ class TargetDelayTransformer():
try:
outputs = self.infer(data['text'], data['text_lengths'])
y = outputs[0]
_, indices = y.view(-1, y.shape[-1]).topk(1, dim=1)
punctuations = indices
assert punctuations.size()[0] == len(mini_sentence)
punctuations = np.argmax(y,axis=-1)[0]
assert punctuations.size == len(mini_sentence)
except ONNXRuntimeError:
logging.warning("error")