Add extract-token run_mode

This commit is contained in:
志浩 2024-09-24 17:15:30 +08:00
parent 1fb762d9be
commit ce5b79d234

View File

@@ -2034,11 +2034,11 @@ class SenseVoiceL(nn.Module):
lfr_n = frontend.lfr_n if hasattr(frontend, "lfr_n") else 1
meta_data["batch_data_time"] = speech_lengths.sum().item() * frame_shift * lfr_n / 1000
-speech = speech.to(device=kwargs["device"])[0, :, :]
+speech = speech.to(device=kwargs["device"])
speech_lengths = speech_lengths.to(device=kwargs["device"])
(outs, ret_dict), out_lens = self.model.encoder(
-speech.permute(0, 2, 1), speech_lengths,
+speech, speech_lengths,
only_extract_tokens=True
)
tokens = ret_dict["indices"]