mirror of
https://github.com/modelscope/FunASR
synced 2025-09-15 14:48:36 +08:00
39 lines
1.6 KiB
Python
39 lines
1.6 KiB
Python
import os
|
|
|
|
import torch
|
|
from funasr.datasets.small_datasets.dataset import ESPnetDataset
|
|
from funasr.datasets.small_datasets.preprocessor import build_preprocess
|
|
from funasr.samplers.build_batch_sampler import build_batch_sampler
|
|
|
|
def build_dataloader(args, mode="train"):
|
|
preprocess_fn = build_preprocess(args, train=mode=="train")
|
|
dest_sample_rate = args.frontend_conf["fs"] if (args.frontend_conf is not None and "fs" in args.frontend_conf) else 16000
|
|
if mode == "train":
|
|
data_path_and_name_and_type = args.train_data_path_and_name_and_type
|
|
shape_files = args.train_shape_file
|
|
elif mode == "valid":
|
|
data_path_and_name_and_type = args.valid_data_path_and_name_and_type
|
|
shape_files = args.valid_shape_file
|
|
else:
|
|
raise NotImplementedError(f"mode={mode}")
|
|
dataset = ESPnetDataset(
|
|
data_path_and_name_and_type,
|
|
preprocess=preprocess_fn,
|
|
dest_sample_rate=dest_sample_rate,
|
|
)
|
|
if os.path.exists(os.path.join(data_path_and_name_and_type[0][0].parent, "utt2category")):
|
|
utt2category_file = os.path.join(data_path_and_name_and_type[0][0].parent, "utt2category")
|
|
else:
|
|
utt2category_file = None
|
|
batch_sampler = build_batch_sampler(
|
|
type=args.batch_type,
|
|
shape_files=iter_options.shape_files,
|
|
fold_lengths=args.fold_length,
|
|
batch_size=iter_options.batch_size,
|
|
batch_bins=iter_options.batch_bins,
|
|
sort_in_batch=args.sort_in_batch,
|
|
sort_batch=args.sort_batch,
|
|
drop_last=False,
|
|
min_batch_size=torch.distributed.get_world_size() if args.distributed else 1,
|
|
utt2category_file=utt2category_file,
|
|
) |