From b01c9f1c25282c8376f8e25eabcc6dd29d14ad13 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E6=B8=B8=E9=9B=81?=
Date: Mon, 17 Jun 2024 14:08:57 +0800
Subject: [PATCH] decoding

---
 funasr/datasets/openai_datasets/datasets.py |  9 ++--
 funasr/models/llm_asr/model.py              | 51 ++++++++++-----------
 2 files changed, 30 insertions(+), 30 deletions(-)

diff --git a/funasr/datasets/openai_datasets/datasets.py b/funasr/datasets/openai_datasets/datasets.py
index ae9f289f7..04ddcfdfc 100644
--- a/funasr/datasets/openai_datasets/datasets.py
+++ b/funasr/datasets/openai_datasets/datasets.py
@@ -300,9 +300,9 @@ class OpenAIDatasetMultiTurn(torch.utils.data.Dataset):
         return len(self.index_ds)
 
     def __getitem__(self, index):
-        import pdb
-
-        pdb.set_trace()
+        # import pdb
+        #
+        # pdb.set_trace()
 
         output = None
 
@@ -397,6 +397,7 @@ class OpenAIDatasetMultiTurn(torch.utils.data.Dataset):
             labels += source_mask + target_ids
             fbank.append(speech[0, :, :])
             fbank_mask += fbank_mask_i
+            fbank_lens.append(speech_lengths)
 
         if len(input_ids) > self.max_token_length:
             logging.info(
@@ -410,7 +411,7 @@ class OpenAIDatasetMultiTurn(torch.utils.data.Dataset):
         labels = torch.tensor(labels, dtype=torch.int64)  # [: self.max_token_length]
 
         # fbank = speech[0, :, :]
-        fbank_lens = speech_lengths
+        # fbank_lens = torch.tensor(fbank_lens, dtype=torch.int32)
         fbank_mask = torch.tensor(fbank_mask, dtype=torch.float32)
         fbank_beg = torch.tensor(fbank_beg, dtype=torch.int32)
         fake_token_len = torch.tensor(fake_token_len, dtype=torch.int32)
diff --git a/funasr/models/llm_asr/model.py b/funasr/models/llm_asr/model.py
index 1e3515fc9..03a2c08b0 100644
--- a/funasr/models/llm_asr/model.py
+++ b/funasr/models/llm_asr/model.py
@@ -990,12 +990,14 @@ class LLMASR4(nn.Module):
                text: (Batch, Length)
                text_lengths: (Batch,)
        """
-        # import pdb;
-        # pdb.set_trace()
+        import pdb
+
+        pdb.set_trace()
        if len(speech_lengths.size()) > 1:
            speech_lengths = speech_lengths[:, 0]
 
-        batch_size, frames, _ = speech.shape
+        batch_size_speech, frames, _ = speech.shape
+        batch_size, token_num = input_ids.shape
 
        with torch.cuda.amp.autocast(enabled=False):
            # audio encoder
@@ -1008,38 +1010,34 @@ class LLMASR4(nn.Module):
        inputs_embeds = self.llm.model.get_input_embeddings()(input_ids)
        batch_size, token_num, dims = inputs_embeds.shape
 
-        fbank_mask[fbank_mask < 0] = 0
-        fbank_fake_lens = fbank_mask.sum(-1).to(torch.int32)
-        # _, l, _ = encoder_out.shape
        fake_token_len = kwargs.get("fake_token_len")
        fake_token_len[fake_token_len < 0] = 0
        fbank_beg[fbank_beg < 0] = 0
 
+        speech_idx = 0
        for batch_idx in range(batch_size):
            for turn_id in range(fbank_beg.shape[1]):
                fbank_beg_idx = fbank_beg[batch_idx, turn_id].item()
-                if fbank_beg[turn_id] > 0:
+                if fbank_beg[batch_idx, turn_id] > 0:
                    speech_token_len = fake_token_len[batch_idx, turn_id]
-                    speech_token = encoder_out[batch_idx + turn_id, turn_id, :speech_token_len, :]
+                    speech_token = encoder_out[speech_idx, :speech_token_len, :]
 
-                    fbank_fake_len = fbank_fake_lens[batch_idx].item()
-                    fbank_beg_idx = fbank_beg[batch_idx, 0].item()
-                    min_len = min(fbank_fake_len, inputs_embeds.shape[1] - fbank_beg_idx)
+                    try:
+                        inputs_embeds[
+                            batch_idx, fbank_beg_idx : fbank_beg_idx + speech_token_len, :
+                        ] = speech_token
+                    except Exception as e:
+                        logging.error(f"{str(e)}, {traceback.format_exc()}")
+                        logging.info(
+                            f"batch_idx: {batch_idx}, inputs_embeds: {inputs_embeds.shape}, fbank_beg_idx: {fbank_beg_idx}, speech_token_len: {speech_token_len}, encoder_out: {encoder_out.shape}, encoder_out_lens: {encoder_out_lens[speech_idx].item()}"
+                        )
+                        speech_token_len = encoder_out_lens[speech_idx].item()
+                        speech_token = encoder_out[speech_idx, turn_id, :speech_token_len, :]
+                        inputs_embeds[
+                            batch_idx, fbank_beg_idx : fbank_beg_idx + speech_token_len, :
+                        ] = speech_token
 
-                    try:
-                        inputs_embeds[batch_idx, fbank_beg_idx : fbank_beg_idx + min_len, :] = encoder_out[
-                            batch_idx, :min_len, :
-                        ]
-                    except Exception as e:
-                        logging.error(f"{str(e)}, {traceback.format_exc()}")
-                        logging.info(
-                            f"batch_idx: {batch_idx}, inputs_embeds: {inputs_embeds.shape}, fbank_beg_idx: {fbank_beg_idx}, min_len: {min_len}, fbank_fake_len: {fbank_fake_len}, encoder_out: {encoder_out.shape}, encoder_out_lens: {encoder_out_lens[batch_idx].item()}"
-                        )
-                        fbank_fake_len = encoder_out_lens[batch_idx].item()
-                        min_len = min(fbank_fake_len, min_len)
-                        inputs_embeds[batch_idx, fbank_beg_idx : fbank_beg_idx + min_len, :] = encoder_out[
-                            batch_idx, :min_len, :
-                        ]
+                    speech_idx += 1
 
        with torch.cuda.amp.autocast(
            enabled=True if self.llm_dtype != "fp32" else False, dtype=dtype_map[self.llm_dtype]
@@ -1061,7 +1059,8 @@ class LLMASR4(nn.Module):
 
        stats["loss"] = torch.clone(loss.detach())
        stats["batch_size"] = batch_size
-        stats["batch_size_x_frames"] = frames * batch_size
+        stats["batch_size_speech"] = batch_size_speech
+        stats["batch_size_x_frames"] = frames * batch_size_speech
        stats["batch_size_real_frames"] = speech_lengths.sum().item()
        stats["padding_frames"] = stats["batch_size_x_frames"] - stats["batch_size_real_frames"]
        stats["batch_size_x_tokens"] = token_num * batch_size