From d1b1fdd520cd33f0fc297eaa6ee6f451a85781cd Mon Sep 17 00:00:00 2001
From: zhifu gao
Date: Thu, 28 Mar 2024 10:17:50 +0800
Subject: [PATCH] Dev gzf new (#1555)

* train

* train

* train

* train

* train

* train

* train

* train

* train

* train

* train

* train

* train
---
 funasr/datasets/audio_datasets/espnet_samplers.py | 8 ++++----
 funasr/datasets/audio_datasets/samplers.py        | 4 ++--
 2 files changed, 6 insertions(+), 6 deletions(-)

diff --git a/funasr/datasets/audio_datasets/espnet_samplers.py b/funasr/datasets/audio_datasets/espnet_samplers.py
index d38e2bf12..1524a6a34 100644
--- a/funasr/datasets/audio_datasets/espnet_samplers.py
+++ b/funasr/datasets/audio_datasets/espnet_samplers.py
@@ -56,7 +56,7 @@ class EspnetStyleBatchSampler(DistributedSampler):
         self.shuffle = shuffle and is_training
         self.drop_last = drop_last
 
-        self.total_size = len(self.dataset)
+        # self.total_size = len(self.dataset)
         # self.num_samples = int(math.ceil(self.total_size / self.num_replicas))
         self.epoch = 0
         self.sort_size = sort_size * num_replicas
@@ -71,10 +71,10 @@ class EspnetStyleBatchSampler(DistributedSampler):
             g = torch.Generator()
             g.manual_seed(self.epoch)
             random.seed(self.epoch)
-            indices = torch.randperm(self.total_size, generator=g).tolist()
+            indices = torch.randperm(len(self.dataset), generator=g).tolist()
         else:
-            indices = list(range(self.total_size))
-
+            indices = list(range(len(self.dataset)))
+
         # Sort indices by sample length
         sorted_indices = sorted(indices, key=lambda idx: self.dataset.get_source_len(idx))
 
diff --git a/funasr/datasets/audio_datasets/samplers.py b/funasr/datasets/audio_datasets/samplers.py
index c274f75cb..b4fb84605 100644
--- a/funasr/datasets/audio_datasets/samplers.py
+++ b/funasr/datasets/audio_datasets/samplers.py
@@ -323,8 +323,8 @@ class CustomDistributedBufferDynamicBatchSampler(DistributedSampler):
         self.shuffle = shuffle and is_training
         self.drop_last = drop_last
 
-        self.total_size = len(self.dataset)
-        # self.num_samples = int(math.ceil(self.total_size / self.num_replicas))
+        # self.total_size = len(self.dataset)
+        self.num_samples = int(math.ceil(self.total_size / self.num_replicas))
         self.epoch = 0
         self.sort_size = sort_size * num_replicas
         self.max_token_length = kwargs.get("max_token_length", 2048)