diff --git a/funasr/bin/asr_inference_uniasr.py b/funasr/bin/asr_inference_uniasr.py index 9aea1a3eb..2e87675fd 100755 --- a/funasr/bin/asr_inference_uniasr.py +++ b/funasr/bin/asr_inference_uniasr.py @@ -215,14 +215,14 @@ class Speech2Text: lfr_factor = max(1, (speech.size()[-1] // 80) - 1) # lengths: (1,) lengths = speech.new_full([1], dtype=torch.long, fill_value=speech.size(1)) - speech_raw = speech.clone().to(self.device) if self.frontend is not None: feats, feats_len = self.frontend.forward(speech, lengths) feats = to_device(feats, device=self.device) feats_len = feats_len.int() else: - feats = speech_raw + feats = speech feats_len = lengths + feats_raw = feats.clone().to(self.device) batch = {"speech": feats, "speech_lengths": feats_len} # a. To device @@ -235,7 +235,7 @@ class Speech2Text: if self.decoding_mode == "model1": predictor_outs = self.asr_model.calc_predictor_mask(enc, enc_len) else: - enc, enc_len = self.asr_model.encode2(enc, enc_len, feats, feats_len, ind=self.decoding_ind) + enc, enc_len = self.asr_model.encode2(enc, enc_len, feats_raw, feats_len, ind=self.decoding_ind) predictor_outs = self.asr_model.calc_predictor_mask2(enc, enc_len) scama_mask = predictor_outs[4]