FunASR/funasr/datasets/small_datasets/build_loader.py
speech_asr c2b99916f5 update
2023-04-18 01:07:41 +08:00

39 lines
1.6 KiB
Python

import os
import torch
from funasr.datasets.small_datasets.dataset import ESPnetDataset
from funasr.datasets.small_datasets.preprocessor import build_preprocess
from funasr.samplers.build_batch_sampler import build_batch_sampler
def build_dataloader(args, mode="train"):
    """Build the dataset and batch sampler for one data split.

    Args:
        args: Configuration namespace. This function reads ``frontend_conf``,
            ``train_data_path_and_name_and_type`` / ``train_shape_file`` or
            ``valid_data_path_and_name_and_type`` / ``valid_shape_file``
            (depending on ``mode``), ``batch_type``, ``fold_length``,
            ``batch_size``, ``batch_bins``, ``sort_in_batch``, ``sort_batch``
            and ``distributed``.
        mode: Which split to build, ``"train"`` or ``"valid"``. The
            preprocessor is built with ``train=True`` only for ``"train"``.

    Raises:
        NotImplementedError: If ``mode`` is neither ``"train"`` nor ``"valid"``.

    Note:
        The dataset and batch sampler are constructed but not returned, so this
        function currently returns ``None``.
        TODO(review): wrap them in a DataLoader and return it.
    """
    preprocess_fn = build_preprocess(args, train=mode == "train")

    # Target sample rate handed to the dataset: frontend_conf["fs"] when it is
    # configured, otherwise 16000.
    frontend_conf = args.frontend_conf
    if frontend_conf is not None and "fs" in frontend_conf:
        dest_sample_rate = frontend_conf["fs"]
    else:
        dest_sample_rate = 16000

    if mode == "train":
        data_path_and_name_and_type = args.train_data_path_and_name_and_type
        shape_files = args.train_shape_file
    elif mode == "valid":
        data_path_and_name_and_type = args.valid_data_path_and_name_and_type
        shape_files = args.valid_shape_file
    else:
        raise NotImplementedError(f"mode={mode}")

    dataset = ESPnetDataset(
        data_path_and_name_and_type,
        preprocess=preprocess_fn,
        dest_sample_rate=dest_sample_rate,
    )

    # Look for an optional "utt2category" file in the directory of the first
    # data file. os.path.dirname accepts both str and os.PathLike entries; a
    # plain str has no ``.parent`` attribute.
    first_data_path = data_path_and_name_and_type[0][0]
    utt2category_path = os.path.join(os.path.dirname(first_data_path), "utt2category")
    utt2category_file = utt2category_path if os.path.exists(utt2category_path) else None

    batch_sampler = build_batch_sampler(
        type=args.batch_type,
        shape_files=shape_files,
        fold_lengths=args.fold_length,
        # NOTE(review): batch_size / batch_bins are read from ``args`` like the
        # other sampler options — confirm these attribute names in the parser.
        batch_size=args.batch_size,
        batch_bins=args.batch_bins,
        sort_in_batch=args.sort_in_batch,
        sort_batch=args.sort_batch,
        drop_last=False,
        # Under distributed training the minimum batch size is the world size.
        min_batch_size=torch.distributed.get_world_size() if args.distributed else 1,
        utt2category_file=utt2category_file,
    )
    # TODO(review): ``dataset`` and ``batch_sampler`` are not yet returned.