mirror of
https://github.com/modelscope/FunASR
synced 2025-09-15 14:48:36 +08:00
update
This commit is contained in:
parent
ef5ea9b05f
commit
259ea7523f
@ -982,7 +982,7 @@ class LLMASR4(nn.Module):
|
||||
fbank_beg: torch.Tensor = None,
|
||||
fbank_mask: torch.Tensor = None,
|
||||
**kwargs,
|
||||
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
|
||||
):
|
||||
"""Encoder + Decoder + Calc loss
|
||||
Args:
|
||||
speech: (Batch, Length, ...)
|
||||
@ -2280,13 +2280,13 @@ class LLMASR5(nn.Module):
|
||||
|
||||
def forward(
|
||||
self,
|
||||
speech: torch.Tensor,
|
||||
speech_lengths: torch.Tensor,
|
||||
input_ids: torch.Tensor,
|
||||
attention_mask: torch.Tensor,
|
||||
labels_ids: torch.Tensor,
|
||||
fbank_beg: torch.Tensor,
|
||||
fbank_mask: torch.Tensor,
|
||||
speech: torch.Tensor = None,
|
||||
speech_lengths: torch.Tensor = None,
|
||||
input_ids: torch.Tensor = None,
|
||||
attention_mask: torch.Tensor = None,
|
||||
labels_ids: torch.Tensor = None,
|
||||
fbank_beg: torch.Tensor = None,
|
||||
fbank_mask: torch.Tensor = None,
|
||||
**kwargs,
|
||||
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
|
||||
"""Encoder + Decoder + Calc loss
|
||||
@ -2299,55 +2299,55 @@ class LLMASR5(nn.Module):
|
||||
# import pdb
|
||||
#
|
||||
# pdb.set_trace()
|
||||
if len(speech_lengths.size()) > 1:
|
||||
speech_lengths = speech_lengths[:, 0]
|
||||
|
||||
batch_size_speech, frames, _ = speech.shape
|
||||
batch_size, token_num = input_ids.shape
|
||||
|
||||
with torch.cuda.amp.autocast(enabled=False):
|
||||
# audio encoder
|
||||
encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
|
||||
|
||||
# audio_adaptor
|
||||
encoder_out, encoder_out_lens = self.audio_adaptor(encoder_out, encoder_out_lens)
|
||||
|
||||
input_ids[input_ids < 0] = 0
|
||||
inputs_embeds = self.llm.model.get_input_embeddings()(input_ids)
|
||||
if speech is not None:
|
||||
if len(speech_lengths.size()) > 1:
|
||||
speech_lengths = speech_lengths[:, 0]
|
||||
|
||||
batch_size, token_num, dims = inputs_embeds.shape
|
||||
fake_token_len = kwargs.get("fake_token_len")
|
||||
fake_token_len[fake_token_len < 0] = 0
|
||||
fbank_beg[fbank_beg < 0] = 0
|
||||
batch_size_speech, frames, _ = speech.shape
|
||||
batch_size, token_num = input_ids.shape
|
||||
|
||||
speech_idx = 0
|
||||
for batch_idx in range(batch_size):
|
||||
with torch.cuda.amp.autocast(enabled=False):
|
||||
# audio encoder
|
||||
encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
|
||||
|
||||
for turn_id in range(fbank_beg.shape[1]):
|
||||
fbank_beg_idx = fbank_beg[batch_idx, turn_id].item()
|
||||
if fbank_beg_idx > 0:
|
||||
speech_token_len = fake_token_len[batch_idx, turn_id]
|
||||
speech_token = encoder_out[speech_idx, :speech_token_len, :]
|
||||
# audio_adaptor
|
||||
encoder_out, encoder_out_lens = self.audio_adaptor(encoder_out, encoder_out_lens)
|
||||
|
||||
try:
|
||||
inputs_embeds[
|
||||
batch_idx, fbank_beg_idx : fbank_beg_idx + speech_token_len, :
|
||||
] = speech_token
|
||||
except Exception as e:
|
||||
#
|
||||
logging.error(f"{str(e)}, {traceback.format_exc()}")
|
||||
logging.info(
|
||||
f"batch_idx: {batch_idx}, inputs_embeds: {inputs_embeds.shape}, fbank_beg_idx: {fbank_beg_idx}, speech_token_len: {speech_token_len}, encoder_out: {encoder_out.shape}, encoder_out_lens: {encoder_out_lens}, fake_token_len: {fake_token_len}, speech_lengths: {speech_lengths}"
|
||||
)
|
||||
# import pdb;
|
||||
# pdb.set_trace()
|
||||
speech_token_len = encoder_out_lens[speech_idx].item()
|
||||
batch_size, token_num, dims = inputs_embeds.shape
|
||||
fake_token_len = kwargs.get("fake_token_len")
|
||||
fake_token_len[fake_token_len < 0] = 0
|
||||
fbank_beg[fbank_beg < 0] = 0
|
||||
|
||||
speech_idx = 0
|
||||
for batch_idx in range(batch_size):
|
||||
|
||||
for turn_id in range(fbank_beg.shape[1]):
|
||||
fbank_beg_idx = fbank_beg[batch_idx, turn_id].item()
|
||||
if fbank_beg_idx > 0:
|
||||
speech_token_len = fake_token_len[batch_idx, turn_id]
|
||||
speech_token = encoder_out[speech_idx, :speech_token_len, :]
|
||||
inputs_embeds[
|
||||
batch_idx, fbank_beg_idx : fbank_beg_idx + speech_token_len, :
|
||||
] = speech_token
|
||||
|
||||
speech_idx += 1
|
||||
try:
|
||||
inputs_embeds[
|
||||
batch_idx, fbank_beg_idx : fbank_beg_idx + speech_token_len, :
|
||||
] = speech_token
|
||||
except Exception as e:
|
||||
#
|
||||
logging.error(f"{str(e)}, {traceback.format_exc()}")
|
||||
logging.info(
|
||||
f"batch_idx: {batch_idx}, inputs_embeds: {inputs_embeds.shape}, fbank_beg_idx: {fbank_beg_idx}, speech_token_len: {speech_token_len}, encoder_out: {encoder_out.shape}, encoder_out_lens: {encoder_out_lens}, fake_token_len: {fake_token_len}, speech_lengths: {speech_lengths}"
|
||||
)
|
||||
# import pdb;
|
||||
# pdb.set_trace()
|
||||
speech_token_len = encoder_out_lens[speech_idx].item()
|
||||
speech_token = encoder_out[speech_idx, :speech_token_len, :]
|
||||
inputs_embeds[
|
||||
batch_idx, fbank_beg_idx : fbank_beg_idx + speech_token_len, :
|
||||
] = speech_token
|
||||
|
||||
speech_idx += 1
|
||||
|
||||
with torch.cuda.amp.autocast(
|
||||
enabled=True if self.llm_dtype != "fp32" else False, dtype=dtype_map[self.llm_dtype]
|
||||
|
||||
Loading…
Reference in New Issue
Block a user