mirror of
https://github.com/modelscope/FunASR
synced 2025-09-15 14:48:36 +08:00
fix
This commit is contained in:
parent
4b3c492998
commit
7df8452a85
@ -76,9 +76,8 @@ class TargetDelayTransformer():
|
||||
try:
|
||||
outputs = self.infer(data['text'], data['text_lengths'])
|
||||
y = outputs[0]
|
||||
_, indices = y.view(-1, y.shape[-1]).topk(1, dim=1)
|
||||
punctuations = indices
|
||||
assert punctuations.size()[0] == len(mini_sentence)
|
||||
punctuations = np.argmax(y,axis=-1)[0]
|
||||
assert punctuations.size == len(mini_sentence)
|
||||
except ONNXRuntimeError:
|
||||
logging.warning("error")
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user