diff --git a/funasr/datasets/openai_datasets/datasets.py b/funasr/datasets/openai_datasets/datasets.py
index a6bd0445a..fc690ccf7 100644
--- a/funasr/datasets/openai_datasets/datasets.py
+++ b/funasr/datasets/openai_datasets/datasets.py
@@ -351,15 +351,17 @@ class OpenAIDatasetMultiTurn(torch.utils.data.Dataset):
                 splits = self.pattern.split(source_input)
                 source_ids = []
                 fbank_i = []
-                fbank_mask_i = []
+                # fbank_mask_i = []
                 fake_token_len_i = 0
                 fbank_beg_i = -1
                 fbank_lens_i = []
+                speech = []
+                speech_lengths = []
                 for k, sub_str in enumerate(splits):
                     if not sub_str.startswith("<|startofspeech|>"):
                         sub_token = self.tokenizer.encode(sub_str)
                         source_ids += sub_token
-                        fbank_mask_i += [0] * len(sub_token)
+                        # fbank_mask_i += [0] * len(sub_token)
                     else:
                         sub_str = sub_str.replace("<|startofspeech|>", "").replace(
                             "<|endofspeech|>", ""
@@ -395,22 +397,25 @@ class OpenAIDatasetMultiTurn(torch.utils.data.Dataset):
                             fake_token = [0] * fake_token_len_i
                             fbank_beg_i = len(source_ids)
                             source_ids += fake_token
-                            fbank_mask_i += [1] * len(fake_token)
+                            # fbank_mask_i += [1] * len(fake_token)
 
                 if badcase_flag:
                     continue
+                if fbank_beg_i > 0:
+                    fbank_beg += [fbank_beg_i + len(input_ids)]
+                    fake_token_len += [fake_token_len_i]
+                else:
+                    fbank_beg += [-1]
+                    fake_token_len += [0]
-                fbank_beg += [fbank_beg_i + len(input_ids)]
-                fake_token_len += [fake_token_len_i]
 
                 source_mask = [-100] * len(source_ids)
                 target_out = f"{target_out}<|im_end|>"
                 target_ids = self.tokenizer.encode(target_out)
                 input_ids += source_ids + target_ids
                 labels += source_mask + target_ids
-                fbank.append(speech[0, :, :])
-                fbank_mask += fbank_mask_i
-                fbank_lens.append(speech_lengths)
-
+                if len(speech) > 0:
+                    fbank.append(speech[0, :, :])
+                    fbank_lens.append(speech_lengths)
 
             if badcase_flag:
                 continue
@@ -420,20 +425,20 @@ class OpenAIDatasetMultiTurn(torch.utils.data.Dataset):
             # fbank = speech[0, :, :]
             # fbank_lens = torch.tensor(fbank_lens, dtype=torch.int32)
-            fbank_mask = torch.tensor(fbank_mask, dtype=torch.float32)
+            # fbank_mask = torch.tensor(fbank_mask, dtype=torch.float32)
             fbank_beg = torch.tensor(fbank_beg, dtype=torch.int32)
             fake_token_len = torch.tensor(fake_token_len, dtype=torch.int32)
 
             output = {
-                "speech": fbank,
-                "speech_lengths": fbank_lens,
-                "fbank_mask": fbank_mask,
                 "fbank_beg": fbank_beg,
                 "fake_token_len": fake_token_len,
                 "input_ids": input_ids,
                 "attention_mask": attention_mask,
                 "labels_ids": labels,
             }
+            if len(fbank) > 0:
+                output["speech"] = fbank
+                output["speech_lengths"] = fbank_lens
 
             break
 
         return output
diff --git a/funasr/models/llm_asr/model.py b/funasr/models/llm_asr/model.py
index f93212e17..17b25d1b8 100644
--- a/funasr/models/llm_asr/model.py
+++ b/funasr/models/llm_asr/model.py
@@ -1021,6 +1021,8 @@ class LLMASR4(nn.Module):
         # import pdb
         #
         # pdb.set_trace()
+        batch_size, token_num = input_ids.shape
+        stats = {}
         input_ids[input_ids < 0] = 0
         inputs_embeds = self.llm.model.get_input_embeddings()(input_ids)
         if speech is not None:
@@ -1028,9 +1030,7 @@ class LLMASR4(nn.Module):
             speech_lengths = speech_lengths[:, 0]
             batch_size_speech, frames, _ = speech.shape
-
-            batch_size, token_num = input_ids.shape
-
+
             # with torch.cuda.amp.autocast(enabled=False):
             # audio encoder
             if self.audio_encoder_activation_checkpoint:
                 from torch.utils.checkpoint import checkpoint
@@ -1078,6 +1078,11 @@ class LLMASR4(nn.Module):
 
                     speech_idx += 1
 
+            stats["batch_size_speech"] = batch_size_speech
+            stats["batch_size_x_frames"] = frames * batch_size_speech
+            stats["batch_size_real_frames"] = speech_lengths.sum().item()
+            stats["padding_frames"] = stats["batch_size_x_frames"] - stats["batch_size_real_frames"]
+
         with torch.cuda.amp.autocast(
             enabled=True if self.llm_dtype != "fp32" else False, dtype=dtype_map[self.llm_dtype]
         ):
@@ -1090,7 +1095,6 @@ class LLMASR4(nn.Module):
             )
 
         loss = model_outputs.loss
-        stats = {}
         with torch.no_grad():
             preds = torch.argmax(model_outputs.logits, -1)
             acc_att = compute_accuracy(preds[:, :-1], labels_ids[:, 1:], ignore_label=-100)
@@ -1098,10 +1102,7 @@ class LLMASR4(nn.Module):
 
         stats["loss"] = torch.clone(loss.detach())
         stats["batch_size"] = batch_size
-        stats["batch_size_speech"] = batch_size_speech
-        stats["batch_size_x_frames"] = frames * batch_size_speech
-        stats["batch_size_real_frames"] = speech_lengths.sum().item()
-        stats["padding_frames"] = stats["batch_size_x_frames"] - stats["batch_size_real_frames"]
+
         stats["batch_size_x_tokens"] = token_num * batch_size
         stats["batch_size_real_tokens"] = attention_mask.sum().item()
         stats["padding_tokens"] = stats["batch_size_x_tokens"] - stats["batch_size_real_tokens"]