diff --git a/funasr/datasets/sense_voice_datasets/datasets.py b/funasr/datasets/sense_voice_datasets/datasets.py
index 6a79e7541..6d9b03505 100644
--- a/funasr/datasets/sense_voice_datasets/datasets.py
+++ b/funasr/datasets/sense_voice_datasets/datasets.py
@@ -50,6 +50,7 @@ class SenseVoiceDataset(torch.utils.data.Dataset):
         self.eos = kwargs.get("eos", "<|endoftext|>")
         self.batch_size = kwargs.get("batch_size")
         self.batch_type = kwargs.get("batch_type")
+        self.prompt_ids_len = 0  # updated per sample; reused by the collator fallback
 
     def get_source_len(self, index):
         item = self.index_ds[index]
@@ -73,6 +74,9 @@ class SenseVoiceDataset(torch.utils.data.Dataset):
         speech, speech_lengths = extract_fbank(
             data_src, data_type=self.data_type, frontend=self.frontend, is_final=True
         )  # speech: [b, T, d]
+
+        if speech_lengths > self.batch_size:  # single clip exceeds the whole batch frame budget
+            return None
         speech = speech.permute(0, 2, 1)
         target = item["target"]
         if self.preprocessor_text:
@@ -84,9 +88,12 @@ class SenseVoiceDataset(torch.utils.data.Dataset):
         prompt = f"{self.sos}{task}{text_language}"
         prompt_ids = self.tokenizer.encode(prompt, allowed_special="all")
         prompt_ids_len = len(prompt_ids) - 1  # [sos, task]
+        self.prompt_ids_len = prompt_ids_len
 
         target_ids = self.tokenizer.encode(target, allowed_special="all")
         target_ids_len = len(target_ids) + 1  # [lid, text]
+        if target_ids_len > 200:  # drop over-long transcripts
+            return None
 
         eos = self.tokenizer.encode(self.eos, allowed_special="all")  # [eos]
 
@@ -108,16 +115,30 @@ class SenseVoiceDataset(torch.utils.data.Dataset):
             "text": text,
             "text_lengths": text_lengths,
             "target_mask": target_mask,
+            "target_mask_lengths": target_mask_lengths,
         }
 
     def collator(self, samples: list = None):
         outputs = {}
         for sample in samples:
+            if sample is None:  # skip samples filtered out in __getitem__
+                continue
             for key in sample.keys():
                 if key not in outputs:
                     outputs[key] = []
                 outputs[key].append(sample[key])
 
+        if len(outputs) < 1:  # every sample in the batch was filtered out
+            logging.error("data is empty!")
+            outputs = {  # minimal dummy batch so the training step does not crash
+                "speech": torch.rand((1, 10, 128), dtype=torch.float32),  # [b, T, d]
+                "speech_lengths": torch.tensor([10], dtype=torch.int32),
+                "text": torch.tensor([[58836]], dtype=torch.int32),  # [b, L]
+                "text_lengths": torch.tensor([1], dtype=torch.int32),
+                "target_mask": torch.tensor([[0] * self.prompt_ids_len + [1, 1]]),
+            }
+            return outputs
+
         for key, data_list in outputs.items():
             if isinstance(data_list[0], torch.Tensor):
                 if data_list[0].dtype == torch.int64 or data_list[0].dtype == torch.int32:
@@ -132,25 +153,29 @@ class SenseVoiceDataset(torch.utils.data.Dataset):
 
         if self.batch_type != "example":
             for i in range(3):
-                outputs = self._filter_badcase(outputs)
+                outputs = self._filter_badcase(outputs, i=i)
 
         return outputs
 
     def _filter_badcase(self, outputs, i=0):
         b, t, _ = outputs["speech"].shape
-        if b * t > self.batch_size:
+
+        if b * t > self.batch_size * 1.25:  # allow 25% slack before dropping data
             beg = torch.randint(0, 2, ()).item()
+            if b < 2:  # a single-sample batch cannot be halved
+                beg = 0
             logging.info(
                 f"Warning, b * t: {b * t} > {self.batch_size}, drop half data {i}th, beg:{beg}"
             )
             for key, data_list in outputs.items():
                 outputs[key] = outputs[key][beg : beg + b : 2]
-        speech_lengths_max = outputs["speech_lengths_max"].max().item()
+
+        speech_lengths_max = outputs["speech_lengths"].max().item()
         outputs["speech"] = outputs["speech"][:, :speech_lengths_max, :]
         text_lengths_max = outputs["text_lengths"].max().item()
         outputs["text"] = outputs["text"][:, :text_lengths_max]
-        target_mask_lengths_max = outputs["target_mask_lengths_max"].max().item()
+        target_mask_lengths_max = outputs["target_mask_lengths"].max().item()
         outputs["target_mask"] = outputs["target_mask"][:, :target_mask_lengths_max]
         return outputs
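For context, here is a minimal sketch of the contract these dataset changes establish: `__getitem__` returns `None` for over-long samples, and `collator` must both skip those entries and survive a batch where everything was filtered out. The helper names (`make_dummy_batch`, `collate`) are hypothetical stand-ins, not FunASR API; shapes follow the `# speech: [b, T, d]` comment in the file.

```python
import torch

def make_dummy_batch(prompt_ids_len: int = 0) -> dict:
    # Mirrors the hard-coded fallback: one fake sample so the training step still runs.
    return {
        "speech": torch.rand((1, 10, 128), dtype=torch.float32),  # [b, T, d]
        "speech_lengths": torch.tensor([10], dtype=torch.int32),
        "text": torch.tensor([[58836]], dtype=torch.int32),  # [b, L]
        "text_lengths": torch.tensor([1], dtype=torch.int32),
        "target_mask": torch.tensor([[0] * prompt_ids_len + [1, 1]]),
    }

def collate(samples: list):
    kept = [s for s in samples if s is not None]  # drop filtered samples
    return kept if kept else make_dummy_batch()   # never return an empty batch

print(sorted(collate([None, None]).keys()))                       # dummy batch keys
print(len(collate([None, {"speech": torch.rand(10, 128)}])))      # 1 surviving sample
```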
diff --git a/funasr/models/sense_voice/model.py b/funasr/models/sense_voice/model.py
index b731bb677..07fb4eb58 100644
--- a/funasr/models/sense_voice/model.py
+++ b/funasr/models/sense_voice/model.py
@@ -309,7 +309,7 @@ class SenseVoiceRWKV(nn.Module):
         if len(speech_lengths.size()) > 1:
             speech_lengths = speech_lengths[:, 0]
 
-        batch_size = speech.shape[0]
+        batch_size, frames, _ = speech.shape
 
         if self.activation_checkpoint:
             from torch.utils.checkpoint import checkpoint
@@ -328,6 +328,7 @@ class SenseVoiceRWKV(nn.Module):
         stats["acc"] = acc_att
         stats["loss"] = torch.clone(loss.detach())
         stats["batch_size"] = batch_size
+        stats["batch_size_x_frames"] = frames * batch_size
 
         # force_gatherable: to-device and to-tensor if scalar for DataParallel
         if self.length_normalized_loss:
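A small sketch of what the new stat measures (tensor shapes assumed from the dataset's `[b, T, d]` convention): `batch_size * frames` is the padded acoustic load of the batch, the same `b * t` quantity the dataset-side `_filter_badcase` bounds, so the two can now be cross-checked in the training logs.

```python
import torch

# Hypothetical batch: 4 utterances padded to 750 frames of 560-dim features.
speech = torch.rand(4, 750, 560)  # [b, T, d]
batch_size, frames, _ = speech.shape

stats = {
    "batch_size": batch_size,
    "batch_size_x_frames": frames * batch_size,  # padded frames actually processed
}
print(stats)  # {'batch_size': 4, 'batch_size_x_frames': 3000}
```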