diff --git a/funasr/models/llm_asr/model.py b/funasr/models/llm_asr/model.py
index f8fedf25c..3544e43ce 100644
--- a/funasr/models/llm_asr/model.py
+++ b/funasr/models/llm_asr/model.py
@@ -982,7 +982,7 @@ class LLMASR4(nn.Module):
         fbank_beg: torch.Tensor = None,
         fbank_mask: torch.Tensor = None,
         **kwargs,
-    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
+    ):
         """Encoder + Decoder + Calc loss
         Args:
                 speech: (Batch, Length, ...)
@@ -2280,13 +2280,13 @@ class LLMASR5(nn.Module):
 
     def forward(
         self,
-        speech: torch.Tensor,
-        speech_lengths: torch.Tensor,
-        input_ids: torch.Tensor,
-        attention_mask: torch.Tensor,
-        labels_ids: torch.Tensor,
-        fbank_beg: torch.Tensor,
-        fbank_mask: torch.Tensor,
+        speech: torch.Tensor = None,
+        speech_lengths: torch.Tensor = None,
+        input_ids: torch.Tensor = None,
+        attention_mask: torch.Tensor = None,
+        labels_ids: torch.Tensor = None,
+        fbank_beg: torch.Tensor = None,
+        fbank_mask: torch.Tensor = None,
         **kwargs,
     ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
         """Encoder + Decoder + Calc loss
@@ -2299,55 +2299,55 @@ class LLMASR5(nn.Module):
         # import pdb
         #
         # pdb.set_trace()
-        if len(speech_lengths.size()) > 1:
-            speech_lengths = speech_lengths[:, 0]
-
-        batch_size_speech, frames, _ = speech.shape
-        batch_size, token_num = input_ids.shape
-
-        with torch.cuda.amp.autocast(enabled=False):
-            # audio encoder
-            encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
-
-            # audio_adaptor
-            encoder_out, encoder_out_lens = self.audio_adaptor(encoder_out, encoder_out_lens)
-
         input_ids[input_ids < 0] = 0
         inputs_embeds = self.llm.model.get_input_embeddings()(input_ids)
+        if speech is not None:
+            if len(speech_lengths.size()) > 1:
+                speech_lengths = speech_lengths[:, 0]
 
-        batch_size, token_num, dims = inputs_embeds.shape
-        fake_token_len = kwargs.get("fake_token_len")
-        fake_token_len[fake_token_len < 0] = 0
-        fbank_beg[fbank_beg < 0] = 0
+            batch_size_speech, frames, _ = speech.shape
+            batch_size, token_num = input_ids.shape
 
-        speech_idx = 0
-        for batch_idx in range(batch_size):
+            with torch.cuda.amp.autocast(enabled=False):
+                # audio encoder
+                encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
 
-            for turn_id in range(fbank_beg.shape[1]):
-                fbank_beg_idx = fbank_beg[batch_idx, turn_id].item()
-                if fbank_beg_idx > 0:
-                    speech_token_len = fake_token_len[batch_idx, turn_id]
-                    speech_token = encoder_out[speech_idx, :speech_token_len, :]
+                # audio_adaptor
+                encoder_out, encoder_out_lens = self.audio_adaptor(encoder_out, encoder_out_lens)
 
-                    try:
-                        inputs_embeds[
-                            batch_idx, fbank_beg_idx : fbank_beg_idx + speech_token_len, :
-                        ] = speech_token
-                    except Exception as e:
-                        #
-                        logging.error(f"{str(e)}, {traceback.format_exc()}")
-                        logging.info(
-                            f"batch_idx: {batch_idx}, inputs_embeds: {inputs_embeds.shape}, fbank_beg_idx: {fbank_beg_idx}, speech_token_len: {speech_token_len}, encoder_out: {encoder_out.shape}, encoder_out_lens: {encoder_out_lens}, fake_token_len: {fake_token_len}, speech_lengths: {speech_lengths}"
-                        )
-                        # import pdb;
-                        # pdb.set_trace()
-                        speech_token_len = encoder_out_lens[speech_idx].item()
+            batch_size, token_num, dims = inputs_embeds.shape
+            fake_token_len = kwargs.get("fake_token_len")
+            fake_token_len[fake_token_len < 0] = 0
+            fbank_beg[fbank_beg < 0] = 0
+
+            speech_idx = 0
+            for batch_idx in range(batch_size):
+
+                for turn_id in range(fbank_beg.shape[1]):
+                    fbank_beg_idx = fbank_beg[batch_idx, turn_id].item()
+                    if fbank_beg_idx > 0:
+                        speech_token_len = fake_token_len[batch_idx, turn_id]
                         speech_token = encoder_out[speech_idx, :speech_token_len, :]
-                        inputs_embeds[
-                            batch_idx, fbank_beg_idx : fbank_beg_idx + speech_token_len, :
-                        ] = speech_token
-
-                    speech_idx += 1
+
+                        try:
+                            inputs_embeds[
+                                batch_idx, fbank_beg_idx : fbank_beg_idx + speech_token_len, :
+                            ] = speech_token
+                        except Exception as e:
+                            #
+                            logging.error(f"{str(e)}, {traceback.format_exc()}")
+                            logging.info(
+                                f"batch_idx: {batch_idx}, inputs_embeds: {inputs_embeds.shape}, fbank_beg_idx: {fbank_beg_idx}, speech_token_len: {speech_token_len}, encoder_out: {encoder_out.shape}, encoder_out_lens: {encoder_out_lens}, fake_token_len: {fake_token_len}, speech_lengths: {speech_lengths}"
+                            )
+                            # import pdb;
+                            # pdb.set_trace()
+                            speech_token_len = encoder_out_lens[speech_idx].item()
+                            speech_token = encoder_out[speech_idx, :speech_token_len, :]
+                            inputs_embeds[
+                                batch_idx, fbank_beg_idx : fbank_beg_idx + speech_token_len, :
+                            ] = speech_token
+
+                        speech_idx += 1
 
         with torch.cuda.amp.autocast(
             enabled=True if self.llm_dtype != "fp32" else False, dtype=dtype_map[self.llm_dtype]
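
Note: the hunk above makes every tensor argument of LLMASR5.forward optional and wraps the encoder/adaptor path in `if speech is not None:`, so text-only batches skip the audio branch and only the LLM embedding lookup runs. The snippet below is a minimal, self-contained sketch of that splice pattern with toy tensors; the helper name `splice_speech_embeddings` and the shapes are illustrative only and are not part of the FunASR API.

import torch

def splice_speech_embeddings(inputs_embeds, encoder_out, fbank_beg, fake_token_len):
    # Overwrite the reserved text-embedding slots with speech embeddings, turn by turn,
    # mirroring the loop over fbank_beg / fake_token_len in the patched forward().
    speech_idx = 0
    for batch_idx in range(inputs_embeds.shape[0]):
        for turn_id in range(fbank_beg.shape[1]):
            beg = fbank_beg[batch_idx, turn_id].item()
            if beg > 0:
                n = fake_token_len[batch_idx, turn_id].item()
                inputs_embeds[batch_idx, beg : beg + n, :] = encoder_out[speech_idx, :n, :]
                speech_idx += 1
    return inputs_embeds

# Toy example: batch of 1, 10 token slots, embedding dim 4, one speech segment of 3 frames.
inputs_embeds = torch.zeros(1, 10, 4)      # stands in for the LLM embedding lookup output
encoder_out = torch.ones(1, 3, 4)          # stands in for the audio_adaptor output
fbank_beg = torch.tensor([[2]])            # speech placeholder starts at token position 2
fake_token_len = torch.tensor([[3]])       # and occupies 3 positions

speech = encoder_out                       # would be None for a text-only batch
if speech is not None:
    inputs_embeds = splice_speech_embeddings(inputs_embeds, speech, fbank_beg, fake_token_len)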