diff --git a/funasr/models/sense_voice/model_small.py b/funasr/models/sense_voice/model_small.py index b84ffeeca..61897bffb 100644 --- a/funasr/models/sense_voice/model_small.py +++ b/funasr/models/sense_voice/model_small.py @@ -2034,11 +2034,11 @@ class SenseVoiceL(nn.Module): lfr_n = frontend.lfr_n if hasattr(frontend, "lfr_n") else 1 meta_data["batch_data_time"] = speech_lengths.sum().item() * frame_shift * lfr_n / 1000 - speech = speech.to(device=kwargs["device"])[0, :, :] + speech = speech.to(device=kwargs["device"]) speech_lengths = speech_lengths.to(device=kwargs["device"]) (outs, ret_dict), out_lens = self.model.encoder( - speech.permute(0, 2, 1), speech_lengths, + speech, speech_lengths, only_extract_tokens=True ) tokens = ret_dict["indices"]