fix: solve problems in sensevoice_bin.py related to argmax and unique, as mentioned in issue #2331 (#2332)

This commit is contained in:
majic31 2024-12-24 17:51:11 +08:00 committed by GitHub
parent 2e0b208658
commit ae7aff2e9c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@@ -183,8 +183,10 @@ class SenseVoiceSmall:
# ctc_logits = torch.from_numpy(ctc_logits).float()
# support batch_size=1 only currently
x = ctc_logits[b, : encoder_out_lens[b].item(), :]
yseq = x.argmax(dim=-1)
yseq = np.unique(yseq)
yseq = np.argmax(x, axis=-1)
# Drop consecutive duplicates with np.diff and a boolean mask (NumPy stand-in for torch.unique_consecutive).
mask = np.concatenate(([True], np.diff(yseq) != 0))
yseq = yseq[mask]
mask = yseq != self.blank_id
token_int = yseq[mask].tolist()