This commit is contained in:
游雁 2024-07-08 16:05:35 +08:00
parent ef5ea9b05f
commit 259ea7523f

View File

@ -982,7 +982,7 @@ class LLMASR4(nn.Module):
fbank_beg: torch.Tensor = None,
fbank_mask: torch.Tensor = None,
**kwargs,
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
):
"""Encoder + Decoder + Calc loss
Args:
speech: (Batch, Length, ...)
@ -2280,13 +2280,13 @@ class LLMASR5(nn.Module):
def forward(
self,
speech: torch.Tensor,
speech_lengths: torch.Tensor,
input_ids: torch.Tensor,
attention_mask: torch.Tensor,
labels_ids: torch.Tensor,
fbank_beg: torch.Tensor,
fbank_mask: torch.Tensor,
speech: torch.Tensor = None,
speech_lengths: torch.Tensor = None,
input_ids: torch.Tensor = None,
attention_mask: torch.Tensor = None,
labels_ids: torch.Tensor = None,
fbank_beg: torch.Tensor = None,
fbank_mask: torch.Tensor = None,
**kwargs,
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
"""Encoder + Decoder + Calc loss
@ -2299,55 +2299,55 @@ class LLMASR5(nn.Module):
# import pdb
#
# pdb.set_trace()
if len(speech_lengths.size()) > 1:
speech_lengths = speech_lengths[:, 0]
batch_size_speech, frames, _ = speech.shape
batch_size, token_num = input_ids.shape
with torch.cuda.amp.autocast(enabled=False):
# audio encoder
encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
# audio_adaptor
encoder_out, encoder_out_lens = self.audio_adaptor(encoder_out, encoder_out_lens)
input_ids[input_ids < 0] = 0
inputs_embeds = self.llm.model.get_input_embeddings()(input_ids)
if speech is not None:
if len(speech_lengths.size()) > 1:
speech_lengths = speech_lengths[:, 0]
batch_size, token_num, dims = inputs_embeds.shape
fake_token_len = kwargs.get("fake_token_len")
fake_token_len[fake_token_len < 0] = 0
fbank_beg[fbank_beg < 0] = 0
batch_size_speech, frames, _ = speech.shape
batch_size, token_num = input_ids.shape
speech_idx = 0
for batch_idx in range(batch_size):
with torch.cuda.amp.autocast(enabled=False):
# audio encoder
encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
for turn_id in range(fbank_beg.shape[1]):
fbank_beg_idx = fbank_beg[batch_idx, turn_id].item()
if fbank_beg_idx > 0:
speech_token_len = fake_token_len[batch_idx, turn_id]
speech_token = encoder_out[speech_idx, :speech_token_len, :]
# audio_adaptor
encoder_out, encoder_out_lens = self.audio_adaptor(encoder_out, encoder_out_lens)
try:
inputs_embeds[
batch_idx, fbank_beg_idx : fbank_beg_idx + speech_token_len, :
] = speech_token
except Exception as e:
#
logging.error(f"{str(e)}, {traceback.format_exc()}")
logging.info(
f"batch_idx: {batch_idx}, inputs_embeds: {inputs_embeds.shape}, fbank_beg_idx: {fbank_beg_idx}, speech_token_len: {speech_token_len}, encoder_out: {encoder_out.shape}, encoder_out_lens: {encoder_out_lens}, fake_token_len: {fake_token_len}, speech_lengths: {speech_lengths}"
)
# import pdb;
# pdb.set_trace()
speech_token_len = encoder_out_lens[speech_idx].item()
batch_size, token_num, dims = inputs_embeds.shape
fake_token_len = kwargs.get("fake_token_len")
fake_token_len[fake_token_len < 0] = 0
fbank_beg[fbank_beg < 0] = 0
speech_idx = 0
for batch_idx in range(batch_size):
for turn_id in range(fbank_beg.shape[1]):
fbank_beg_idx = fbank_beg[batch_idx, turn_id].item()
if fbank_beg_idx > 0:
speech_token_len = fake_token_len[batch_idx, turn_id]
speech_token = encoder_out[speech_idx, :speech_token_len, :]
inputs_embeds[
batch_idx, fbank_beg_idx : fbank_beg_idx + speech_token_len, :
] = speech_token
speech_idx += 1
try:
inputs_embeds[
batch_idx, fbank_beg_idx : fbank_beg_idx + speech_token_len, :
] = speech_token
except Exception as e:
#
logging.error(f"{str(e)}, {traceback.format_exc()}")
logging.info(
f"batch_idx: {batch_idx}, inputs_embeds: {inputs_embeds.shape}, fbank_beg_idx: {fbank_beg_idx}, speech_token_len: {speech_token_len}, encoder_out: {encoder_out.shape}, encoder_out_lens: {encoder_out_lens}, fake_token_len: {fake_token_len}, speech_lengths: {speech_lengths}"
)
# import pdb;
# pdb.set_trace()
speech_token_len = encoder_out_lens[speech_idx].item()
speech_token = encoder_out[speech_idx, :speech_token_len, :]
inputs_embeds[
batch_idx, fbank_beg_idx : fbank_beg_idx + speech_token_len, :
] = speech_token
speech_idx += 1
with torch.cuda.amp.autocast(
enabled=True if self.llm_dtype != "fp32" else False, dtype=dtype_map[self.llm_dtype]