This commit is contained in:
游雁 2024-06-13 17:40:40 +08:00
parent c553a8db17
commit 7355e20503

View File

@ -496,11 +496,14 @@ class LLMASR2(nn.Module):
batch_size, frames, _ = speech.shape
# audio encoder
encoder_out, encoder_out_lens = self.audio_encoder(speech.permute(0, 2, 1), speech_lengths)
with torch.cuda.amp.autocast(enabled=False):
# audio encoder
encoder_out, encoder_out_lens = self.audio_encoder(
speech.permute(0, 2, 1), speech_lengths
)
# audio_adaptor
encoder_out, encoder_out_lens = self.audio_adaptor(encoder_out, encoder_out_lens)
# audio_adaptor
encoder_out, encoder_out_lens = self.audio_adaptor(encoder_out, encoder_out_lens)
input_ids[input_ids < 0] = 0
inputs_embeds = self.llm.model.get_input_embeddings()(input_ids)