mirror of
https://github.com/modelscope/FunASR
synced 2025-09-15 14:48:36 +08:00
update
This commit is contained in:
parent
ef5ea9b05f
commit
259ea7523f
@ -982,7 +982,7 @@ class LLMASR4(nn.Module):
|
|||||||
fbank_beg: torch.Tensor = None,
|
fbank_beg: torch.Tensor = None,
|
||||||
fbank_mask: torch.Tensor = None,
|
fbank_mask: torch.Tensor = None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
|
):
|
||||||
"""Encoder + Decoder + Calc loss
|
"""Encoder + Decoder + Calc loss
|
||||||
Args:
|
Args:
|
||||||
speech: (Batch, Length, ...)
|
speech: (Batch, Length, ...)
|
||||||
@ -2280,13 +2280,13 @@ class LLMASR5(nn.Module):
|
|||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
speech: torch.Tensor,
|
speech: torch.Tensor = None,
|
||||||
speech_lengths: torch.Tensor,
|
speech_lengths: torch.Tensor = None,
|
||||||
input_ids: torch.Tensor,
|
input_ids: torch.Tensor = None,
|
||||||
attention_mask: torch.Tensor,
|
attention_mask: torch.Tensor = None,
|
||||||
labels_ids: torch.Tensor,
|
labels_ids: torch.Tensor = None,
|
||||||
fbank_beg: torch.Tensor,
|
fbank_beg: torch.Tensor = None,
|
||||||
fbank_mask: torch.Tensor,
|
fbank_mask: torch.Tensor = None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
|
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
|
||||||
"""Encoder + Decoder + Calc loss
|
"""Encoder + Decoder + Calc loss
|
||||||
@ -2299,6 +2299,9 @@ class LLMASR5(nn.Module):
|
|||||||
# import pdb
|
# import pdb
|
||||||
#
|
#
|
||||||
# pdb.set_trace()
|
# pdb.set_trace()
|
||||||
|
input_ids[input_ids < 0] = 0
|
||||||
|
inputs_embeds = self.llm.model.get_input_embeddings()(input_ids)
|
||||||
|
if speech is not None:
|
||||||
if len(speech_lengths.size()) > 1:
|
if len(speech_lengths.size()) > 1:
|
||||||
speech_lengths = speech_lengths[:, 0]
|
speech_lengths = speech_lengths[:, 0]
|
||||||
|
|
||||||
@ -2312,9 +2315,6 @@ class LLMASR5(nn.Module):
|
|||||||
# audio_adaptor
|
# audio_adaptor
|
||||||
encoder_out, encoder_out_lens = self.audio_adaptor(encoder_out, encoder_out_lens)
|
encoder_out, encoder_out_lens = self.audio_adaptor(encoder_out, encoder_out_lens)
|
||||||
|
|
||||||
input_ids[input_ids < 0] = 0
|
|
||||||
inputs_embeds = self.llm.model.get_input_embeddings()(input_ids)
|
|
||||||
|
|
||||||
batch_size, token_num, dims = inputs_embeds.shape
|
batch_size, token_num, dims = inputs_embeds.shape
|
||||||
fake_token_len = kwargs.get("fake_token_len")
|
fake_token_len = kwargs.get("fake_token_len")
|
||||||
fake_token_len[fake_token_len < 0] = 0
|
fake_token_len[fake_token_len < 0] = 0
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user