diff --git a/funasr/models/llm_asr/model.py b/funasr/models/llm_asr/model.py
index cf554384a..24837f0dd 100644
--- a/funasr/models/llm_asr/model.py
+++ b/funasr/models/llm_asr/model.py
@@ -1408,7 +1408,7 @@ class LLMASR4(nn.Module):
 
         return results, meta_data
 
 
-@tables.register("model_classes", "LLMASR5")
+# @tables.register("model_classes", "LLMASR5")
 class LLMASR5(nn.Module):
     """ """
@@ -2011,41 +2011,19 @@ class LLMASR5(nn.Module):
     """ """
 
     def __init__(
-            self,
-            specaug: str = None,
-            specaug_conf: dict = None,
-            normalize: str = None,
-            normalize_conf: dict = None,
-            audio_encoder: str = None,
-            audio_encoder_conf: dict = None,
-            audio_adaptor: str = None,
-            audio_adaptor_conf: dict = None,
-            decoder: str = None,
-            decoder_conf: dict = None,
-            ctc: str = None,
-            ctc_conf: dict = None,
-            ctc_weight: float = 0.5,
-            llm: str = None,
-            llm_conf: dict = None,
-            input_size: int = 80,
-            vocab_size: int = -1,
-            ignore_id: int = -1,
-            blank_id: int = 0,
-            sos: int = 1,
-            eos: int = 2,
-            lsm_weight: float = 0.0,
-            length_normalized_loss: bool = False,
-            report_cer: bool = True,
-            report_wer: bool = True,
-            sym_space: str = "",
-            sym_blank: str = "",
-            # extract_feats_in_collect_stats: bool = True,
-            share_embedding: bool = False,
-            # preencoder: Optional[AbsPreEncoder] = None,
-            # postencoder: Optional[AbsPostEncoder] = None,
-            audio_decoder: str = None,
-            audio_decoder_conf: dict = None,
-            **kwargs,
+        self,
+        audio_encoder: str = None,
+        audio_encoder_conf: dict = None,
+        audio_adaptor: str = None,
+        audio_adaptor_conf: dict = None,
+        llm: str = None,
+        llm_conf: dict = None,
+        input_size: int = 80,
+        lsm_weight: float = 0.0,
+        length_normalized_loss: bool = False,
+        audio_decoder: str = None,
+        audio_decoder_conf: dict = None,
+        **kwargs,
     ):
 
         super().__init__()
@@ -2082,7 +2060,7 @@ class LLMASR5(nn.Module):
                 idx = re.search(r"\.\d+\.", name)
                 if idx is not None:
                     beg, end = idx.regs[0]
-                    layer_id = int(name[beg + 1: end - 1])
+                    layer_id = int(name[beg + 1 : end - 1])
                     if layer_id < freeze_layer_num:
                         param.requires_grad = False
                 elif "ln_post." not in name:
@@ -2134,6 +2112,8 @@ class LLMASR5(nn.Module):
         self.length_normalized_loss = length_normalized_loss
         self.beam_search = None
 
+        self.eos = kwargs.get("eos", 151645)
+
         # audio decoder related
         self.audio_decoder = self.build_audio_decoder(name=audio_decoder, conf=audio_decoder_conf)
         self.audio_decoder_in_proj = torch.nn.Linear(llm_dim, self.audio_decoder.embed_unit)
@@ -2148,16 +2128,12 @@ class LLMASR5(nn.Module):
     def build_audio_decoder(self, name, conf):
         if name == "transformer":
             from funasr.models.llm_asr.transformer_lm import TransformerEmbedLM
+
             if "text_vocab_size" in conf:
-                lm_model = TransformerEmbedLM(
-                    vocab_size=self.lm_out_voc_size,
-                    **conf
-                )
+                lm_model = TransformerEmbedLM(vocab_size=self.lm_out_voc_size, **conf)
             else:
                 lm_model = TransformerEmbedLM(
-                    vocab_size=self.lm_out_voc_size,
-                    text_vocab_size=self.lm_out_voc_size,
-                    **conf
+                    vocab_size=self.lm_out_voc_size, text_vocab_size=self.lm_out_voc_size, **conf
                 )
         else:
             raise TypeError(f"Unknown codec decoder type {name}")
@@ -2175,30 +2151,35 @@ class LLMASR5(nn.Module):
         return self.codec_embedder(codec * mask).sum(dim=-2) * mask
 
     def prepare_audio_decoder_io(
-            self,
-            text: torch.Tensor,
-            text_lengths: torch.Tensor,
-            codec: Optional[torch.Tensor] = None,
-            codec_lengths: Optional[torch.Tensor] = None,
-            need_targets: bool = True,
+        self,
+        text: torch.Tensor,
+        text_lengths: torch.Tensor,
+        codec: Optional[torch.Tensor] = None,
+        codec_lengths: Optional[torch.Tensor] = None,
+        need_targets: bool = True,
     ):
         """build inputs and targets for language model
-            Normally, this function is called in batchify_nll.
-            Args:
-                text: (Batch, Length, Dim)
-                text_lengths: (Batch,)
-                codec: (Batch, Length)
-                codec_lengths: (Batch,)
-                need_targets: bool, whether provide targets
-            """
+        Normally, this function is called in batchify_nll.
+        Args:
+            text: (Batch, Length, Dim)
+            text_lengths: (Batch,)
+            codec: (Batch, Length)
+            codec_lengths: (Batch,)
+            need_targets: bool, whether provide targets
+        """
         if need_targets:
-            assert codec is not None and codec_lengths is not None, \
-                "need_target=True, but codec or codec_length is None"
+            assert (
+                codec is not None and codec_lengths is not None
+            ), "need_target=True, but codec or codec_length is None"
 
-        sos_eos_emb = self.audio_decoder_embedding(torch.tensor([self.ad_sos_eos], dtype=torch.int64, device=text.device))
-        task_id_emb = self.audio_decoder_embedding(torch.tensor([self.ad_task_id], dtype=torch.int64, device=text.device))
+        sos_eos_emb = self.audio_decoder_embedding(
+            torch.tensor([self.ad_sos_eos], dtype=torch.int64, device=text.device)
+        )
+        task_id_emb = self.audio_decoder_embedding(
+            torch.tensor([self.ad_task_id], dtype=torch.int64, device=text.device)
+        )
 
         codec_emb = None
         if codec is not None and codec_lengths is not None:
             codec_emb = self.calc_dense_vector(codec, codec_lengths)
 
@@ -2206,7 +2187,7 @@ class LLMASR5(nn.Module):
             for i, text_len in enumerate(text_lengths):
                 one_input = [sos_eos_emb, text[i, :text_len], task_id_emb]
                 if codec_emb is not None:
-                    one_input.append(codec_emb[i, :codec_lengths[i]])
+                    one_input.append(codec_emb[i, : codec_lengths[i]])
                 inputs_list.append(torch.cat(one_input, dim=0))
             llm_inputs = pad_list(inputs_list, 0.0)
             llm_lengths = text_lengths + 2
@@ -2217,7 +2198,9 @@ class LLMASR5(nn.Module):
             return llm_inputs, llm_lengths
 
         bb, tt = text.shape[0], codec_lengths.max() + 1
-        llm_targets = -1 * torch.ones([bb, tt, self.predict_nq], dtype=torch.int64, device=text.device)
+        llm_targets = -1 * torch.ones(
+            [bb, tt, self.predict_nq], dtype=torch.int64, device=text.device
+        )
         for i, codec_len in enumerate(codec_lengths):
             llm_targets[i, :codec_len] = codec[i, :codec_len]
             llm_targets[i, codec_len] = self.codebook_size + self.sos_eos
@@ -2242,36 +2225,33 @@ class LLMASR5(nn.Module):
         """
         batch_size = text.size(0)
         # For data parallel
-        text = text[:, :text_lengths.max()]
-        codec = codec[:, :codec_lengths.max()]
+        text = text[:, : text_lengths.max()]
+        codec = codec[:, : codec_lengths.max()]
         text = self.audio_decoder_in_proj(text)
 
         # build inputs and targets for language model
         with autocast(False):
             (sequence, target), (x_lengths, y_lengths) = self.prepare_audio_decoder_io(
-                text, text_lengths,
-                codec, codec_lengths,
-                need_targets=True
+                text, text_lengths, codec, codec_lengths, need_targets=True
             )
 
         # 2a. Forward Language model
         # x: (Batch, Length) -> y: (Batch, Length, NVocab)
-        sequence = sequence[:, :x_lengths.max()]
-        target = target[:, :y_lengths.max()]
-        y, _ = self.audio_decoder(sequence, x_lengths, text_lengths+1)
+        sequence = sequence[:, : x_lengths.max()]
+        target = target[:, : y_lengths.max()]
+        y, _ = self.audio_decoder(sequence, x_lengths, text_lengths + 1)
         bb, tt = y.shape[0], y.shape[1]
         y = y.reshape(bb, tt, self.predict_nq, -1)
         # 2b. Extract real logits
         logits_list = []
         for i, (text_len, codec_len) in enumerate(zip(text_lengths, codec_lengths)):
-            logits_list.append(y[i, text_len + 1:text_len + 2 + codec_len])
+            logits_list.append(y[i, text_len + 1 : text_len + 2 + codec_len])
         logits = pad_list(logits_list, 0.0)
 
         # 3. Calc negative log likelihood
         tt = logits.shape[1]
         nll = self.criterion_ce(
-            logits.reshape(bb, tt * self.predict_nq, -1),
-            target.reshape(bb, tt * self.predict_nq)
+            logits.reshape(bb, tt * self.predict_nq, -1), target.reshape(bb, tt * self.predict_nq)
         )
         nll = nll.sum(-1)
         # nll: (BxL,) -> (BxL,)
@@ -2279,18 +2259,18 @@ class LLMASR5(nn.Module):
         # nll: (BxL,) -> (B, L)
         nll = nll.reshape(batch_size, -1).reshape(batch_size, tt, self.predict_nq)
 
-        return nll, logits, target, codec_lengths+1
+        return nll, logits, target, codec_lengths + 1
 
     def forward(
-            self,
-            speech: torch.Tensor,
-            speech_lengths: torch.Tensor,
-            input_ids: torch.Tensor,
-            attention_mask: torch.Tensor,
-            labels_ids: torch.Tensor,
-            fbank_beg: torch.Tensor,
-            fbank_mask: torch.Tensor,
-            **kwargs,
+        self,
+        speech: torch.Tensor,
+        speech_lengths: torch.Tensor,
+        input_ids: torch.Tensor,
+        attention_mask: torch.Tensor,
+        labels_ids: torch.Tensor,
+        fbank_beg: torch.Tensor,
+        fbank_mask: torch.Tensor,
+        **kwargs,
     ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
         """Encoder + Decoder + Calc loss
         Args:
@@ -2334,7 +2314,7 @@ class LLMASR5(nn.Module):
 
                 try:
                     inputs_embeds[
-                        batch_idx, fbank_beg_idx: fbank_beg_idx + speech_token_len, :
+                        batch_idx, fbank_beg_idx : fbank_beg_idx + speech_token_len, :
                     ] = speech_token
                 except Exception as e:
                     #
@@ -2347,13 +2327,13 @@ class LLMASR5(nn.Module):
                     speech_token_len = encoder_out_lens[speech_idx].item()
                     speech_token = encoder_out[speech_idx, :speech_token_len, :]
                     inputs_embeds[
-                        batch_idx, fbank_beg_idx: fbank_beg_idx + speech_token_len, :
+                        batch_idx, fbank_beg_idx : fbank_beg_idx + speech_token_len, :
                     ] = speech_token
 
                 speech_idx += 1
 
         with torch.cuda.amp.autocast(
-                enabled=True if self.llm_dtype != "fp32" else False, dtype=dtype_map[self.llm_dtype]
+            enabled=True if self.llm_dtype != "fp32" else False, dtype=dtype_map[self.llm_dtype]
         ):
             labels_ids[labels_ids == -1] = -100
             attention_mask[attention_mask < 0] = 0
@@ -2364,6 +2344,47 @@ class LLMASR5(nn.Module):
             )
             loss = model_outputs.loss
 
+        codec = kwargs.get("codec")
+        codec_len = kwargs.get("codec_len")
+        if len(codec_len.size()) > 1:
+            codec_len = codec_len[:, 0]
+        hidden_states = model_outputs.hidden_states[-1].float()
+
+        target_ids = []
+        target_ids_len = []
+        hidden_states_select = []
+        for batch_idx in range(labels_ids.shape[0]):
+            beg_i = 0
+            end_i = 0
+            for token_idx in range(labels_ids.shape[1]):
+                token_int = labels_ids[batch_idx, token_idx].item()
+                if token_int == self.eos:
+                    target_ids_i = labels_ids[batch_idx, beg_i:end_i]
+                    target_ids_len_i = end_i - beg_i
+                    target_ids_len.append(target_ids_len_i)
+                    target_ids.append(target_ids_i)
+                    hidden_states_i = hidden_states[batch_idx, beg_i - 1 : end_i - 1, :]
+                    hidden_states_select.append(hidden_states_i)
+                    beg_i = end_i
+                    continue
+
+                end_i += 1
+                if token_int <= 0:
+                    beg_i += 1
+
+        target_ids = torch.nn.utils.rnn.pad_sequence(
+            target_ids, batch_first=True, padding_value=-100
+        )
+        hidden_states_select = torch.nn.utils.rnn.pad_sequence(
+            hidden_states_select, batch_first=True, padding_value=0.0
+        )
+        target_ids_len = torch.tensor(target_ids_len, dtype=torch.int32, device=input_ids.device)
+        target_ids = target_ids.to(device=input_ids.device)
+        hidden_states_select = hidden_states_select.to(device=input_ids.device)
+        loss, logits, target, codec_lengths = self.nll(
+            hidden_states_select, target_ids_len, codec, codec_len
+        )
+
         stats = {}
         with torch.no_grad():
             preds = torch.argmax(model_outputs.logits, -1)
@@ -2487,10 +2508,10 @@ class LLMASR5(nn.Module):
         time3 = time.perf_counter()
         meta_data["extract_feat"] = f"{time3 - time2:0.3f}"
         meta_data["batch_data_time"] = (
-                speech_lengths.sum().item()
-                * frontend.frame_shift
-                * frontend.lfr_n
-                / 1000
+            speech_lengths.sum().item()
+            * frontend.frame_shift
+            * frontend.lfr_n
+            / 1000
         )
 
         if kwargs.get("permute", True):
@@ -2558,13 +2579,13 @@ class LLMASR5(nn.Module):
 
         return output
 
     def inference_prepare(
-            self,
-            data_in,
-            data_lengths=None,
-            key: list = None,
-            tokenizer=None,
-            frontend=None,
-            **kwargs,
+        self,
+        data_in,
+        data_lengths=None,
+        key: list = None,
+        tokenizer=None,
+        frontend=None,
+        **kwargs,
     ):
 
         meta_data = {}
@@ -2619,7 +2640,7 @@ class LLMASR5(nn.Module):
 
                 try:
                     inputs_embeds[
-                        batch_idx, fbank_beg_idx: fbank_beg_idx + speech_token_len, :
+                        batch_idx, fbank_beg_idx : fbank_beg_idx + speech_token_len, :
                     ] = speech_token
                 except Exception as e:
                     #
@@ -2632,20 +2653,20 @@ class LLMASR5(nn.Module):
                     speech_token_len = encoder_out_lens[speech_idx].item()
                     speech_token = encoder_out[speech_idx, :speech_token_len, :]
                     inputs_embeds[
-                        batch_idx, fbank_beg_idx: fbank_beg_idx + speech_token_len, :
+                        batch_idx, fbank_beg_idx : fbank_beg_idx + speech_token_len, :
                     ] = speech_token
 
                 speech_idx += 1
 
         return inputs_embeds, contents, batch, source_ids, meta_data
 
     def inference(
-            self,
-            data_in,
-            data_lengths=None,
-            key: list = None,
-            tokenizer=None,
-            frontend=None,
-            **kwargs,
+        self,
+        data_in,
+        data_lengths=None,
+        key: list = None,
+        tokenizer=None,
+        frontend=None,
+        **kwargs,
    ):
 
        inputs_embeds, contents, batch, source_ids, meta_data = self.inference_prepare(
@@ -2658,7 +2679,7 @@ class LLMASR5(nn.Module):
         llm_dtype = "bf16" if kwargs.get("bf16", False) else llm_dtype
 
         with torch.cuda.amp.autocast(
-                enabled=True if llm_dtype != "fp32" else False, dtype=dtype_map[llm_dtype]
+            enabled=True if llm_dtype != "fp32" else False, dtype=dtype_map[llm_dtype]
         ):
             label = contents["assistant"][-1]
             self.llm = self.llm.to(dtype_map[llm_dtype])
@@ -2688,7 +2709,7 @@ class LLMASR5(nn.Module):
                 inputs_embeds=inputs_embeds, attention_mask=attention_mask, labels=labels_ids
             )
 
-            preds = torch.argmax(model_outputs.logits, -1)[:, source_ids.shape[1]:]
+            preds = torch.argmax(model_outputs.logits, -1)[:, source_ids.shape[1] :]
             response = tokenizer.batch_decode(
                 preds,
                 add_special_tokens=False,
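
Note on the new block in forward(): after the LLM forward pass it scans labels_ids, cuts out the answer span that ends in self.eos, pairs it with the hidden states one position to the left (the states whose next-token predictions produced that span), pads both, and passes them to self.nll() so the codec decoder is conditioned on LLM hidden states rather than token embeddings; as written, the nll loss then overwrites model_outputs.loss. Below is a minimal standalone sketch of just the segmentation step, assuming one eos-terminated answer per sample; the eos id, tensors, and variable names are toy stand-ins for illustration, not part of the patch:

    import torch

    eos = 5  # toy id; the patch reads kwargs.get("eos", 151645)
    labels = torch.tensor([[-100, -100, 7, 8, 9, eos]])  # (B, T), -100 marks prompt/pad
    hidden = torch.randn(1, 6, 4)  # (B, T, D), stand-in for the last hidden layer

    segments, selected, lengths = [], [], []
    for b in range(labels.shape[0]):
        beg = end = 0
        for t in range(labels.shape[1]):
            tok = labels[b, t].item()
            if tok == eos:
                segments.append(labels[b, beg:end])  # answer tokens
                lengths.append(end - beg)
                # states at positions beg-1 .. end-2 predicted tokens beg .. end-1
                selected.append(hidden[b, beg - 1 : end - 1, :])
                beg = end
                continue
            end += 1
            if tok <= 0:  # ignored positions move the window start forward
                beg += 1

    segments = torch.nn.utils.rnn.pad_sequence(segments, batch_first=True, padding_value=-100)
    selected = torch.nn.utils.rnn.pad_sequence(selected, batch_first=True, padding_value=0.0)
    print(segments)        # tensor([[7, 8, 9]])
    print(selected.shape)  # torch.Size([1, 3, 4])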