diff --git a/funasr/bin/asr_infer.py b/funasr/bin/asr_infer.py index 59462711f..acb5fd870 100644 --- a/funasr/bin/asr_infer.py +++ b/funasr/bin/asr_infer.py @@ -488,15 +488,20 @@ class Speech2TextParaformer: nbest_hyps = nbest_hyps[: self.nbest] else: - yseq = am_scores.argmax(dim=-1) - score = am_scores.max(dim=-1)[0] - score = torch.sum(score, dim=-1) - # pad with mask tokens to ensure compatibility with sos/eos tokens - yseq = torch.tensor( - [self.asr_model.sos] + yseq.tolist() + [self.asr_model.eos], device=yseq.device - ) + if pre_token_length[i] == 0: + yseq = torch.tensor( + [self.asr_model.sos] + [self.asr_model.eos], device=yseq.device + ) + score = torch.tensor(0.0, device=yseq.device) + else: + yseq = am_scores.argmax(dim=-1) + score = am_scores.max(dim=-1)[0] + score = torch.sum(score, dim=-1) + # pad with mask tokens to ensure compatibility with sos/eos tokens + yseq = torch.tensor( + [self.asr_model.sos] + yseq.tolist() + [self.asr_model.eos], device=yseq.device + ) nbest_hyps = [Hypothesis(yseq=yseq, score=score)] - for hyp in nbest_hyps: assert isinstance(hyp, (Hypothesis)), type(hyp)