diff --git a/funasr/datasets/small_datasets/build_loader.py b/funasr/datasets/small_datasets/build_loader.py
index 6727602dc..d5d6f7680 100644
--- a/funasr/datasets/small_datasets/build_loader.py
+++ b/funasr/datasets/small_datasets/build_loader.py
@@ -1,15 +1,40 @@
+import os
+
 import torch
 from funasr.datasets.small_datasets.dataset import ESPnetDataset
 from funasr.datasets.small_datasets.preprocessor import build_preprocess
+from funasr.samplers.build_batch_sampler import build_batch_sampler
 
-def build_dataloader(args, train=False):
-    preprocess_fn = build_preprocess(args, train=train)
+def build_dataloader(args, mode="train"):
+    preprocess_fn = build_preprocess(args, train=mode == "train")
     dest_sample_rate = args.frontend_conf["fs"] if (args.frontend_conf is not None and "fs" in args.frontend_conf) else 16000
+    if mode == "train":
+        data_path_and_name_and_type = args.train_data_path_and_name_and_type
+        shape_files = args.train_shape_file
+    elif mode == "valid":
+        data_path_and_name_and_type = args.valid_data_path_and_name_and_type
+        shape_files = args.valid_shape_file
+    else:
+        raise NotImplementedError(f"mode={mode}")
     dataset = ESPnetDataset(
-        iter_options.data_path_and_name_and_type,
+        data_path_and_name_and_type,
         float_dtype=args.train_dtype,
         preprocess=preprocess_fn,
-        max_cache_size=args.max_cache_size,
-        max_cache_fd=args.max_cache_fd,
         dest_sample_rate=dest_sample_rate,
     )
+    if os.path.exists(os.path.join(os.path.dirname(data_path_and_name_and_type[0][0]), "utt2category")):
+        utt2category_file = os.path.join(os.path.dirname(data_path_and_name_and_type[0][0]), "utt2category")
+    else:
+        utt2category_file = None
+    batch_sampler = build_batch_sampler(
+        type=args.batch_type,
+        shape_files=shape_files,
+        fold_lengths=args.fold_length,
+        batch_size=args.batch_size,
+        batch_bins=args.batch_bins,
+        sort_in_batch=args.sort_in_batch,
+        sort_batch=args.sort_batch,
+        drop_last=False,
+        min_batch_size=torch.distributed.get_world_size() if args.distributed else 1,
+        utt2category_file=utt2category_file,
+    )
\ No newline at end of file
diff --git a/funasr/datasets/small_datasets/dataset.py b/funasr/datasets/small_datasets/dataset.py
index 9bf063029..6ba8a0291 100644
--- a/funasr/datasets/small_datasets/dataset.py
+++ b/funasr/datasets/small_datasets/dataset.py
@@ -12,7 +12,6 @@ from typing import Mapping
 from typing import Tuple
 from typing import Union
 
-import humanfriendly
 import kaldiio
 import numpy as np
 import torch
@@ -22,7 +21,6 @@ from typeguard import check_return_type
 
 from funasr.fileio.npy_scp import NpyScpReader
 from funasr.fileio.sound_scp import SoundScpReader
-from funasr.utils.sized_dict import SizedDict
 
 
 class AdapterForSoundScpReader(collections.abc.Mapping):
@@ -111,8 +109,6 @@ class ESPnetDataset(Dataset):
         ] = None,
         float_dtype: str = "float32",
         int_dtype: str = "long",
-        max_cache_size: Union[float, int, str] = 0.0,
-        max_cache_fd: int = 0,
         dest_sample_rate: int = 16000,
     ):
         assert check_argument_types()
@@ -126,7 +122,6 @@ class ESPnetDataset(Dataset):
 
         self.float_dtype = float_dtype
         self.int_dtype = int_dtype
-        self.max_cache_fd = max_cache_fd
         self.dest_sample_rate = dest_sample_rate
 
         self.loader_dict = {}
@@ -141,14 +136,6 @@ class ESPnetDataset(Dataset):
             if len(self.loader_dict[name]) == 0:
                 raise RuntimeError(f"{path} has no samples")
 
-        if isinstance(max_cache_size, str):
-            max_cache_size = humanfriendly.parse_size(max_cache_size)
-        self.max_cache_size = max_cache_size
-        if max_cache_size > 0:
-            self.cache = SizedDict(shared=True)
-        else:
-            self.cache = None
-
     def _build_loader(
         self, path: str, loader_type: str
     ) -> Mapping[str, Union[np.ndarray, torch.Tensor, str, numbers.Number]]:
@@ -162,7 +149,7 @@ class ESPnetDataset(Dataset):
             loader = SoundScpReader(path, self.dest_sample_rate, normalize=True, always_2d=False)
             return AdapterForSoundScpReader(loader, self.float_dtype)
         elif loader_type == "kaldi_ark":
-            loader = kaldiio.load_scp(path, max_cache_fd=self.max_cache_fd)
+            loader = kaldiio.load_scp(path)
             return AdapterForSoundScpReader(loader, self.float_dtype)
         elif loader_type == "npy":
             return NpyScpReader(path)
@@ -207,10 +194,6 @@ class ESPnetDataset(Dataset):
             d = next(iter(self.loader_dict.values()))
             uid = list(d)[uid]
 
-        if self.cache is not None and uid in self.cache:
-            data = self.cache[uid]
-            return uid, data
-
         data = {}
         # 1. Load data from each loaders
         for name, loader in self.loader_dict.items():
@@ -261,9 +244,6 @@ class ESPnetDataset(Dataset):
                 raise NotImplementedError(f"Not supported dtype: {value.dtype}")
             data[name] = value
 
-        if self.cache is not None and self.cache.size < self.max_cache_size:
-            self.cache[uid] = data
-
         retval = uid, data
         assert check_return_type(retval)
         return retval
diff --git a/funasr/datasets/small_datasets/preprocessor.py b/funasr/datasets/small_datasets/preprocessor.py
index 4708cabd5..ecd4478dc 100644
--- a/funasr/datasets/small_datasets/preprocessor.py
+++ b/funasr/datasets/small_datasets/preprocessor.py
@@ -855,6 +855,19 @@ def build_preprocess(args, train):
             text_name=text_names,
             non_linguistic_symbols=args.non_linguistic_symbols,
         )
+    elif args.task_name == "lm":
+        retval = LMPreprocessor(
+            train=train,
+            token_type=args.token_type,
+            token_list=args.token_list,
+            bpemodel=args.bpemodel,
+            text_cleaner=args.cleaner,
+            g2p_type=args.g2p,
+            text_name="text",
+            non_linguistic_symbols=args.non_linguistic_symbols,
+            split_with_space=args.split_with_space,
+            seg_dict_file=args.seg_dict_file,
+        )
     elif args.task_name == "vad":
         retval = None
     else:
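
With this change, callers select the split via mode="train" / mode="valid" instead of a boolean train flag. The build_loader.py hunk builds dataset and batch_sampler but stops there; a minimal sketch of how those two objects would typically be wired into a standard PyTorch loader is shown below. This is an assumption for illustration only, not part of the diff: collate_fn and args.num_workers are placeholders for whatever FunASR actually configures, not names taken from this patch.

    # Hypothetical continuation inside build_dataloader (not in the diff):
    # combine the ESPnetDataset and the uid-batch sampler built above into a
    # plain torch DataLoader. `collate_fn` and `args.num_workers` are
    # placeholders, not FunASR-confirmed names.
    data_loader = torch.utils.data.DataLoader(
        dataset,
        batch_sampler=batch_sampler,
        collate_fn=collate_fn,
        num_workers=args.num_workers,
    )
    return data_loader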