This commit is contained in:
haoneng.lhn 2023-05-18 21:10:16 +08:00
parent 47bd60924c
commit 025df72c10

View File

@ -488,15 +488,20 @@ class Speech2TextParaformer:
nbest_hyps = nbest_hyps[: self.nbest]
else:
yseq = am_scores.argmax(dim=-1)
score = am_scores.max(dim=-1)[0]
score = torch.sum(score, dim=-1)
# pad with mask tokens to ensure compatibility with sos/eos tokens
yseq = torch.tensor(
[self.asr_model.sos] + yseq.tolist() + [self.asr_model.eos], device=yseq.device
)
if pre_token_length[i] == 0:
yseq = torch.tensor(
[self.asr_model.sos] + [self.asr_model.eos], device=yseq.device
)
score = torch.tensor(0.0, device=yseq.device)
else:
yseq = am_scores.argmax(dim=-1)
score = am_scores.max(dim=-1)[0]
score = torch.sum(score, dim=-1)
# pad with mask tokens to ensure compatibility with sos/eos tokens
yseq = torch.tensor(
[self.asr_model.sos] + yseq.tolist() + [self.asr_model.eos], device=yseq.device
)
nbest_hyps = [Hypothesis(yseq=yseq, score=score)]
for hyp in nbest_hyps:
assert isinstance(hyp, (Hypothesis)), type(hyp)