mirror of
https://github.com/modelscope/FunASR
synced 2025-09-15 14:48:36 +08:00
update
This commit is contained in:
parent
ef5ea9b05f
commit
259ea7523f
@ -982,7 +982,7 @@ class LLMASR4(nn.Module):
|
||||
fbank_beg: torch.Tensor = None,
|
||||
fbank_mask: torch.Tensor = None,
|
||||
**kwargs,
|
||||
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
|
||||
):
|
||||
"""Encoder + Decoder + Calc loss
|
||||
Args:
|
||||
speech: (Batch, Length, ...)
|
||||
@ -2280,13 +2280,13 @@ class LLMASR5(nn.Module):
|
||||
|
||||
def forward(
|
||||
self,
|
||||
speech: torch.Tensor,
|
||||
speech_lengths: torch.Tensor,
|
||||
input_ids: torch.Tensor,
|
||||
attention_mask: torch.Tensor,
|
||||
labels_ids: torch.Tensor,
|
||||
fbank_beg: torch.Tensor,
|
||||
fbank_mask: torch.Tensor,
|
||||
speech: torch.Tensor = None,
|
||||
speech_lengths: torch.Tensor = None,
|
||||
input_ids: torch.Tensor = None,
|
||||
attention_mask: torch.Tensor = None,
|
||||
labels_ids: torch.Tensor = None,
|
||||
fbank_beg: torch.Tensor = None,
|
||||
fbank_mask: torch.Tensor = None,
|
||||
**kwargs,
|
||||
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
|
||||
"""Encoder + Decoder + Calc loss
|
||||
@ -2299,6 +2299,9 @@ class LLMASR5(nn.Module):
|
||||
# import pdb
|
||||
#
|
||||
# pdb.set_trace()
|
||||
input_ids[input_ids < 0] = 0
|
||||
inputs_embeds = self.llm.model.get_input_embeddings()(input_ids)
|
||||
if speech is not None:
|
||||
if len(speech_lengths.size()) > 1:
|
||||
speech_lengths = speech_lengths[:, 0]
|
||||
|
||||
@ -2312,9 +2315,6 @@ class LLMASR5(nn.Module):
|
||||
# audio_adaptor
|
||||
encoder_out, encoder_out_lens = self.audio_adaptor(encoder_out, encoder_out_lens)
|
||||
|
||||
input_ids[input_ids < 0] = 0
|
||||
inputs_embeds = self.llm.model.get_input_embeddings()(input_ids)
|
||||
|
||||
batch_size, token_num, dims = inputs_embeds.shape
|
||||
fake_token_len = kwargs.get("fake_token_len")
|
||||
fake_token_len[fake_token_len < 0] = 0
|
||||
|
||||
Loading…
Reference in New Issue
Block a user