From 1eb85d7d17a12877e36feefcb3bb2f20a2c171f0 Mon Sep 17 00:00:00 2001
From: speech_asr
Date: Tue, 18 Apr 2023 16:09:25 +0800
Subject: [PATCH] update

---
 .../datasets/small_datasets/build_loader.py   | 41 ++++++++++++-------
 1 file changed, 26 insertions(+), 15 deletions(-)

diff --git a/funasr/datasets/small_datasets/build_loader.py b/funasr/datasets/small_datasets/build_loader.py
index cdc648b83..a96627a37 100644
--- a/funasr/datasets/small_datasets/build_loader.py
+++ b/funasr/datasets/small_datasets/build_loader.py
@@ -1,13 +1,17 @@
+import logging
 import os
-import torch
 
+import numpy as np
+
 from funasr.datasets.small_datasets.dataset import ESPnetDataset
 from funasr.datasets.small_datasets.preprocessor import build_preprocess
-from funasr.samplers.build_batch_sampler import build_batch_sampler
+from funasr.samplers.length_batch_sampler import LengthBatchSampler
+
 
 def build_dataloader(args, mode="train"):
-    preprocess_fn = build_preprocess(args, train=mode=="train")
-    dest_sample_rate = args.frontend_conf["fs"] if (args.frontend_conf is not None and "fs" in args.frontend_conf) else 16000
+    preprocess_fn = build_preprocess(args, train=mode == "train")
+    dest_sample_rate = args.frontend_conf["fs"] if (
+        args.frontend_conf is not None and "fs" in args.frontend_conf) else 16000
     if mode == "train":
         data_path_and_name_and_type = args.train_data_path_and_name_and_type
         shape_files = args.train_shape_file
@@ -25,15 +29,22 @@
         utt2category_file = os.path.join(data_path_and_name_and_type[0][0].parent, "utt2category")
     else:
         utt2category_file = None
-    batch_sampler = build_batch_sampler(
-        type=args.batch_type,
-        shape_files=iter_options.shape_files,
-        fold_lengths=args.fold_length,
-        batch_size=iter_options.batch_size,
-        batch_bins=iter_options.batch_bins,
-        sort_in_batch=args.sort_in_batch,
-        sort_batch=args.sort_batch,
+
+    dataset_conf = args.dataset_conf
+    batch_sampler = LengthBatchSampler(
+        batch_bins=dataset_conf["batch_size"],
+        shape_files=shape_files,
+        sort_in_batch=dataset_conf.get("sort_in_batch", "descending"),
+        sort_batch=dataset_conf.get("sort_batch", "ascending"),
         drop_last=False,
-        min_batch_size=torch.distributed.get_world_size() if args.distributed else 1,
-        utt2category_file=utt2category_file,
-    )
\ No newline at end of file
+        padding=True,
+    )
+
+    batches = list(batch_sampler)
+    bs_list = [len(batch) for batch in batches]
+    logging.info(f"[{mode}] dataset:\n{dataset}")
+    logging.info(f"[{mode}] Batch sampler: {batch_sampler}")
+    logging.info(
+        f"[{mode}] mini-batch sizes summary: N-batch={len(bs_list)}, "
+        f"mean={np.mean(bs_list):.1f}, min={np.min(bs_list)}, max={np.max(bs_list)}"
+    )
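
Below is a minimal standalone sketch of driving the LengthBatchSampler this patch switches to, using only the keyword arguments that appear in the diff. The shape-file path, its "utt-id length" line format, and the batch_bins value are illustrative assumptions, not values taken from the patch or FunASR configs.

# Standalone sketch (assumptions marked inline); exercises the sampler with
# only the keyword arguments that appear in the diff above.
import logging

import numpy as np

from funasr.samplers.length_batch_sampler import LengthBatchSampler

logging.basicConfig(level=logging.INFO)

# Hypothetical shape file; assumed "utt-id length" per line, e.g. "utt1 1250".
shape_file = "exp/train_shapes/speech_shape"

batch_sampler = LengthBatchSampler(
    batch_bins=4000,             # stands in for dataset_conf["batch_size"]
    shape_files=[shape_file],
    sort_in_batch="descending",  # the patch's default when the key is absent
    sort_batch="ascending",      # likewise the patch's default
    drop_last=False,
    padding=True,
)

# The same mini-batch summary that build_dataloader() now logs.
batches = list(batch_sampler)
bs_list = [len(batch) for batch in batches]
logging.info(
    f"mini-batch sizes summary: N-batch={len(bs_list)}, "
    f"mean={np.mean(bs_list):.1f}, min={np.min(bs_list)}, max={np.max(bs_list)}"
)

Presumably the point of budgeting batches by a length total (batch_bins) rather than a fixed utterance count, combined with sort_in_batch="descending", is to group utterances of similar length so that padding adds few wasted frames per mini-batch.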