From 54a91194901ad72562d5cb5856ee8c302d93fb0e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=B8=B8=E9=9B=81?= Date: Mon, 27 Nov 2023 14:11:54 +0800 Subject: [PATCH] dataloader --- funasr/datasets/data_sampler.py | 4 ++-- funasr/datasets/dataloader_fn.py | 2 +- funasr/datasets/dataset_jsonl.py | 7 +++++++ 3 files changed, 10 insertions(+), 3 deletions(-) diff --git a/funasr/datasets/data_sampler.py b/funasr/datasets/data_sampler.py index c8e7b0d71..60c7c84d7 100644 --- a/funasr/datasets/data_sampler.py +++ b/funasr/datasets/data_sampler.py @@ -46,8 +46,8 @@ class BatchSampler(torch.utils.data.BatchSampler): idx_map = self.shuffle_idx[idx] # prompt = self.dataset.indexed_dataset[idx_map]["prompt"] - sample_len_cur = self.dataset.indexed_dataset[idx_map]["source_len"] + \ - self.dataset.indexed_dataset[idx_map]["target_len"] + sample_len_cur = self.dataset.indexed_dataset.get_source_len(self.dataset.indexed_dataset[idx_map]) + \ + self.dataset.indexed_dataset.get_target_len(self.dataset.indexed_dataset[idx_map]) datalen_with_index.append([idx, sample_len_cur]) diff --git a/funasr/datasets/dataloader_fn.py b/funasr/datasets/dataloader_fn.py index a43c94773..13d35a5ec 100644 --- a/funasr/datasets/dataloader_fn.py +++ b/funasr/datasets/dataloader_fn.py @@ -47,7 +47,7 @@ if __name__ == "__main__": collate_fn=dataset.collator, batch_sampler=batch_sampler, shuffle=False, - num_workers=8, + num_workers=0, pin_memory=True) print(len(dataset)) diff --git a/funasr/datasets/dataset_jsonl.py b/funasr/datasets/dataset_jsonl.py index 543b60e6a..3a548c8d0 100644 --- a/funasr/datasets/dataset_jsonl.py +++ b/funasr/datasets/dataset_jsonl.py @@ -78,6 +78,13 @@ class IndexedDatasetJsonl(torch.utils.data.Dataset): def __getitem__(self, index): return self.contents[index] + + def get_source_len(self, data_dict): + return data_dict["source_len"] + + def get_target_len(self, data_dict): + + return data_dict["target_len"] if "target_len" in data_dict else 0 class AudioDataset(torch.utils.data.Dataset):