From 1163110135c625a8a3ebd94e050d9adb5b55bb84 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=B8=B8=E9=9B=81?= Date: Sun, 9 Jun 2024 03:29:41 +0800 Subject: [PATCH] fix bug --- funasr/datasets/openai_datasets/datasets.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/funasr/datasets/openai_datasets/datasets.py b/funasr/datasets/openai_datasets/datasets.py index e13c15618..9b36078f3 100644 --- a/funasr/datasets/openai_datasets/datasets.py +++ b/funasr/datasets/openai_datasets/datasets.py @@ -51,7 +51,7 @@ class OpenAIDataset(torch.utils.data.Dataset): self.batch_size = kwargs.get("batch_size") self.batch_type = kwargs.get("batch_type") self.prompt_ids_len = 0 - self.retry = kwargs.get("retry", 10) + self.retry = kwargs.get("retry", 100) self.permute = False from funasr.frontends.whisper_frontend import WhisperFrontend @@ -212,13 +212,9 @@ class OpenAIDataset(torch.utils.data.Dataset): if self.batch_type != "example": b, t = outputs["input_ids"].shape if b > 1 and b * t > self.batch_size * self.batch_size_scale_ratio_max: - # beg = torch.randint(0, 2, ()).item() - # if b < 2: - # beg = 0 logging.info( - f"Warning, b*t: {b}*{t}={b * t} > batch_size*relax: {self.batch_size_scale_ratio_max}*{self.batch_size}={self.batch_size_scale_ratio_max*self.batch_size}, drop half data {idx}th, beg:{beg}" + f"Warning, b*t: {b}*{t}={b * t} > batch_size*relax: {self.batch_size_scale_ratio_max}*{self.batch_size}={self.batch_size_scale_ratio_max*self.batch_size}, drop last data" ) - # samples = samples[beg : beg + b : 2] samples = samples[:-1] continue