From da98950b422bd14d2c9357a878c19268b196b9c0 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E4=BB=81=E8=BF=B7?=
Date: Thu, 29 Dec 2022 14:12:58 +0800
Subject: [PATCH] fix uniasr inference bug

---
 funasr/bin/asr_inference_uniasr.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/funasr/bin/asr_inference_uniasr.py b/funasr/bin/asr_inference_uniasr.py
index 9aea1a3eb..2e87675fd 100755
--- a/funasr/bin/asr_inference_uniasr.py
+++ b/funasr/bin/asr_inference_uniasr.py
@@ -215,14 +215,14 @@ class Speech2Text:
         lfr_factor = max(1, (speech.size()[-1] // 80) - 1)
         # lengths: (1,)
         lengths = speech.new_full([1], dtype=torch.long, fill_value=speech.size(1))
-        speech_raw = speech.clone().to(self.device)
         if self.frontend is not None:
             feats, feats_len = self.frontend.forward(speech, lengths)
             feats = to_device(feats, device=self.device)
             feats_len = feats_len.int()
         else:
-            feats = speech_raw
+            feats = speech
             feats_len = lengths
+        feats_raw = feats.clone().to(self.device)
         batch = {"speech": feats, "speech_lengths": feats_len}

         # a. To device
@@ -235,7 +235,7 @@ class Speech2Text:
         if self.decoding_mode == "model1":
             predictor_outs = self.asr_model.calc_predictor_mask(enc, enc_len)
         else:
-            enc, enc_len = self.asr_model.encode2(enc, enc_len, feats, feats_len, ind=self.decoding_ind)
+            enc, enc_len = self.asr_model.encode2(enc, enc_len, feats_raw, feats_len, ind=self.decoding_ind)
             predictor_outs = self.asr_model.calc_predictor_mask2(enc, enc_len)

         scama_mask = predictor_outs[4]