def collator(self, samples: list = None):
    """Pad a list of sample dicts into a single batch dict.

    Values are grouped per key across samples. Keys whose values are
    ``torch.Tensor`` are padded with ``pad_sequence(batch_first=True)``,
    using ``self.int_pad_value`` for int32/int64 tensors and
    ``self.float_pad_value`` otherwise; all other values are returned as
    plain lists.

    When ``self.batch_type != "example"`` the padded ``input_ids`` tensor
    (unpacked as ``b, t``) is checked against ``self.batch_size * 2``. If
    ``b * t`` exceeds that limit, every other sample is dropped (random
    starting offset) and the batch is collated again, at most
    ``self.retry`` times.

    Args:
        samples: Per-item dicts produced by the dataset; ``None`` entries
            are ignored.

    Returns:
        dict mapping each key to a padded tensor or a list of values.
        An empty dict is returned when no usable sample is left.
    """
    # Drop failed (None) samples up front so the halving below operates on
    # exactly the items that were collated (b == len(samples)).
    samples = [sample for sample in (samples or []) if sample is not None]
    if not samples:
        return {}

    def _pad(batch):
        # Group values per key, then pad the tensor-valued keys.
        grouped = {}
        for sample in batch:
            for key, value in sample.items():
                grouped.setdefault(key, []).append(value)

        for key, data_list in grouped.items():
            if isinstance(data_list[0], torch.Tensor):
                if data_list[0].dtype in (torch.int64, torch.int32):
                    pad_value = self.int_pad_value
                else:
                    pad_value = self.float_pad_value
                grouped[key] = torch.nn.utils.rnn.pad_sequence(
                    data_list, batch_first=True, padding_value=pad_value
                )
        return grouped

    limit = self.batch_size * 2
    for idx in range(self.retry):
        outputs = _pad(samples)
        if self.batch_type == "example":
            break

        b, t = outputs["input_ids"].shape
        if b * t <= limit:
            break
        if b < 2:
            # A single sample cannot be halved; retrying would only re-pad
            # the same data, so keep it as is.
            logging.info(
                f"Warning, b * t: {b * t} > {limit}, b: {b}, t: {t}, single sample kept"
            )
            break

        beg = torch.randint(0, 2, ()).item()
        logging.info(
            f"Warning, b * t: {b * t} > {limit}, b: {b}, t: {t}, drop half data {idx}th, beg:{beg}"
        )
        samples = samples[beg::2]
    else:
        # Reached when the retries are exhausted right after a halving (or
        # when self.retry <= 0): collate the current samples so the returned
        # batch reflects the last halving and `outputs` is always bound.
        outputs = _pad(samples)

    return outputs