游雁 2024-07-08 16:05:35 +08:00
parent ef5ea9b05f
commit 259ea7523f

@@ -982,7 +982,7 @@ class LLMASR4(nn.Module):
         fbank_beg: torch.Tensor = None,
         fbank_mask: torch.Tensor = None,
         **kwargs,
-    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
+    ):
         """Encoder + Decoder + Calc loss
         Args:
             speech: (Batch, Length, ...)
@@ -2280,13 +2280,13 @@ class LLMASR5(nn.Module):
     def forward(
         self,
-        speech: torch.Tensor,
-        speech_lengths: torch.Tensor,
-        input_ids: torch.Tensor,
-        attention_mask: torch.Tensor,
-        labels_ids: torch.Tensor,
-        fbank_beg: torch.Tensor,
-        fbank_mask: torch.Tensor,
+        speech: torch.Tensor = None,
+        speech_lengths: torch.Tensor = None,
+        input_ids: torch.Tensor = None,
+        attention_mask: torch.Tensor = None,
+        labels_ids: torch.Tensor = None,
+        fbank_beg: torch.Tensor = None,
+        fbank_mask: torch.Tensor = None,
         **kwargs,
     ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
         """Encoder + Decoder + Calc loss
@@ -2299,6 +2299,9 @@ class LLMASR5(nn.Module):
         # import pdb
         #
         # pdb.set_trace()
+        input_ids[input_ids < 0] = 0
+        inputs_embeds = self.llm.model.get_input_embeddings()(input_ids)
+        if speech is not None:
         if len(speech_lengths.size()) > 1:
             speech_lengths = speech_lengths[:, 0]
@@ -2312,9 +2315,6 @@ class LLMASR5(nn.Module):
         # audio_adaptor
         encoder_out, encoder_out_lens = self.audio_adaptor(encoder_out, encoder_out_lens)
-        input_ids[input_ids < 0] = 0
-        inputs_embeds = self.llm.model.get_input_embeddings()(input_ids)
         batch_size, token_num, dims = inputs_embeds.shape
         fake_token_len = kwargs.get("fake_token_len")
         fake_token_len[fake_token_len < 0] = 0
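
Taken together, the hunks make the speech-related arguments of LLMASR5.forward optional and move the token-embedding lookup ahead of the audio branch, so batches without audio can still be embedded. A minimal sketch of the resulting control flow follows; forward_sketch, llm, audio_encoder and audio_adaptor are placeholder names for illustration, not the repository's exact attributes.

def forward_sketch(llm, audio_encoder, audio_adaptor, input_ids,
                   speech=None, speech_lengths=None):
    # Negative ids (e.g. padding/ignore labels such as -100) are invalid
    # embedding indices; clamp them to 0 before the lookup, as the added
    # lines in the diff do.
    input_ids[input_ids < 0] = 0
    inputs_embeds = llm.model.get_input_embeddings()(input_ids)

    if speech is not None:
        # Audio branch: encode fbank features and project them into the
        # LLM embedding space via the adaptor (splicing the audio
        # embeddings into inputs_embeds is omitted here).
        if len(speech_lengths.size()) > 1:
            speech_lengths = speech_lengths[:, 0]
        encoder_out, encoder_out_lens = audio_encoder(speech, speech_lengths)
        encoder_out, encoder_out_lens = audio_adaptor(encoder_out, encoder_out_lens)

    return inputs_embeds

With the arguments defaulted to None, a text-only sample can call forward without speech or speech_lengths, and the audio encoder/adaptor path is skipped entirely.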