mirror of https://github.com/modelscope/FunASR (synced 2025-09-15 14:48:36 +08:00)

commit 831d00aec2 ("update")
parent d9ad40bf6f
@@ -4,6 +4,7 @@ import sys
 import torch

 from funasr.torch_utils.set_all_random_seed import set_all_random_seed
 from funasr.utils import config_argparse
 from funasr.utils.build_distributed import build_distributed
+from funasr.utils.prepare_data import prepare_data
@@ -340,4 +341,10 @@ if __name__ == '__main__':
         distributed_option.dist_rank,
         distributed_option.local_rank))

+    # prepare files for dataloader
+    prepare_data(args, distributed_option)
+
     set_all_random_seed(args.seed)
     torch.backends.cudnn.enabled = args.cudnn_enabled
     torch.backends.cudnn.benchmark = args.cudnn_benchmark
     torch.backends.cudnn.deterministic = args.cudnn_deterministic
@@ -64,27 +64,17 @@ class SentencepiecesTokenizer(AbsTokenizer):
         return self.sp.DecodePieces(list(tokens))


-class ArkDataLoader(AbsIterFactory):
-    def __init__(self, data_list, dict_file, dataset_conf, frontend_conf=None, seg_dict_file=None, punc_dict_file=None,
-                 bpemodel_file=None, mode="train"):
-        symbol_table = read_symbol_table(dict_file) if dict_file is not None else None
-        if seg_dict_file is not None:
-            seg_dict = load_seg_dict(seg_dict_file)
-        else:
-            seg_dict = None
-        if punc_dict_file is not None:
-            punc_dict = read_symbol_table(punc_dict_file)
-        else:
-            punc_dict = None
-        self.dataset_conf = dataset_conf
-        self.frontend_conf = frontend_conf
+class LargeDataLoader(AbsIterFactory):
+    def __init__(self, args, mode="train"):
+        symbol_table = read_symbol_table(args.token_list) if args.token_list is not None else None
+        seg_dict = load_seg_dict(args.seg_dict_file) if args.seg_dict_file is not None else None
+        punc_dict = load_seg_dict(args.punc_dict_file) if args.punc_dict_file is not None else None
+        bpe_tokenizer = load_seg_dict(args.bpemodel_file) if args.bpemodel_file is not None else None
+        self.dataset_conf = args.dataset_conf
+        self.frontend_conf = args.frontend_conf
         logging.info("dataloader config: {}".format(self.dataset_conf))
         batch_mode = self.dataset_conf.get("batch_mode", "padding")
-        if bpemodel_file is not None:
-            bpe_tokenizer = SentencepiecesTokenizer(bpemodel_file)
-        else:
-            bpe_tokenizer = None
-        self.dataset = Dataset(data_list, symbol_table, seg_dict, punc_dict, bpe_tokenizer,
+        self.dataset = Dataset(args.data_list, symbol_table, seg_dict, punc_dict, bpe_tokenizer,
                                self.dataset_conf, self.frontend_conf, mode=mode, batch_mode=batch_mode)

     def build_iter(self, epoch, shuffle=True):
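Note that the new LargeDataLoader builds both punc_dict and bpe_tokenizer with load_seg_dict, while the removed ArkDataLoader used read_symbol_table for the punctuation dict and a SentencepiecesTokenizer for the BPE model. If the old semantics are wanted under the new args-based interface, a sketch (not the committed code) would be:

```python
# Sketch only: preserves the old loader semantics with the new args interface.
punc_dict = read_symbol_table(args.punc_dict_file) if args.punc_dict_file is not None else None
bpe_tokenizer = SentencepiecesTokenizer(args.bpemodel_file) if args.bpemodel_file is not None else None
```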
funasr/datasets/small_datasets/dataset.py (new file, 442 lines)
@@ -0,0 +1,442 @@
# Copyright ESPnet (https://github.com/espnet/espnet). All Rights Reserved.
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)

from abc import ABC
from abc import abstractmethod
import collections
import copy
import functools
import logging
import numbers
import re
from typing import Any
from typing import Callable
from typing import Collection
from typing import Dict
from typing import Mapping
from typing import Tuple
from typing import Union

import h5py
import humanfriendly
import kaldiio
import numpy as np
import torch
from torch.utils.data.dataset import Dataset
from typeguard import check_argument_types
from typeguard import check_return_type

from funasr.fileio.npy_scp import NpyScpReader
from funasr.fileio.rand_gen_dataset import FloatRandomGenerateDataset
from funasr.fileio.rand_gen_dataset import IntRandomGenerateDataset
from funasr.fileio.read_text import load_num_sequence_text
from funasr.fileio.read_text import read_2column_text
from funasr.fileio.sound_scp import SoundScpReader
from funasr.utils.sized_dict import SizedDict

class AdapterForSoundScpReader(collections.abc.Mapping):
    def __init__(self, loader, dtype=None):
        assert check_argument_types()
        self.loader = loader
        self.dtype = dtype
        self.rate = None

    def keys(self):
        return self.loader.keys()

    def __len__(self):
        return len(self.loader)

    def __iter__(self):
        return iter(self.loader)

    def __getitem__(self, key: str) -> np.ndarray:
        retval = self.loader[key]

        if isinstance(retval, tuple):
            assert len(retval) == 2, len(retval)
            if isinstance(retval[0], int) and isinstance(retval[1], np.ndarray):
                # sound scp case
                rate, array = retval
            elif isinstance(retval[1], int) and isinstance(retval[0], np.ndarray):
                # Extended ark format case
                array, rate = retval
            else:
                raise RuntimeError(
                    f"Unexpected type: {type(retval[0])}, {type(retval[1])}"
                )

            if self.rate is not None and self.rate != rate:
                raise RuntimeError(
                    f"Sampling rates are mismatched: {self.rate} != {rate}"
                )
            self.rate = rate
            # Multichannel wave file
            # array: (NSample, Channel) or (NSample,)
            if self.dtype is not None:
                array = array.astype(self.dtype)

        else:
            # Normal ark case
            assert isinstance(retval, np.ndarray), type(retval)
            array = retval
            if self.dtype is not None:
                array = array.astype(self.dtype)

        assert isinstance(array, np.ndarray), type(array)
        return array
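The adapter's only job is to normalize the two possible (rate, array) tuple orders into a bare ndarray and remember the sampling rate. A toy check, using a hypothetical in-memory dict standing in for the scp loader:

```python
adapter = AdapterForSoundScpReader({"utt1": (16000, np.zeros(160))}, dtype="float32")
wav = adapter["utt1"]
assert wav.dtype == np.float32 and adapter.rate == 16000
```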

class H5FileWrapper:
    def __init__(self, path: str):
        self.path = path
        self.h5_file = h5py.File(path, "r")

    def __repr__(self) -> str:
        return str(self.h5_file)

    def __len__(self) -> int:
        return len(self.h5_file)

    def __iter__(self):
        return iter(self.h5_file)

    def __getitem__(self, key) -> np.ndarray:
        value = self.h5_file[key]
        return value[()]

def sound_loader(path, dest_sample_rate=16000, float_dtype=None):
    # The file is as follows:
    #   utterance_id_A /some/where/a.wav
    #   utterance_id_B /some/where/a.flac

    # NOTE(kamo): SoundScpReader doesn't support pipe-fashion
    # like Kaldi, e.g. "cat a.wav |".
    # NOTE(kamo): The audio signal is normalized to the [-1, 1] range.
    loader = SoundScpReader(path, dest_sample_rate, normalize=True, always_2d=False)

    # SoundScpReader.__getitem__() returns Tuple[int, ndarray],
    # but a bare ndarray is desired, so the Adapter class is inserted here
    return AdapterForSoundScpReader(loader, float_dtype)


def kaldi_loader(path, float_dtype=None, max_cache_fd: int = 0):
    loader = kaldiio.load_scp(path, max_cache_fd=max_cache_fd)
    return AdapterForSoundScpReader(loader, float_dtype)


def rand_int_loader(filepath, loader_type):
    # e.g. rand_int_3_10
    try:
        low, high = map(int, loader_type[len("rand_int_"):].split("_"))
    except ValueError:
        raise RuntimeError(f"e.g. rand_int_3_10: but got {loader_type}")
    return IntRandomGenerateDataset(filepath, low, high)
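For reference, the rand_int_* parsing above just strips the prefix and splits on underscores:

```python
loader_type = "rand_int_3_10"
low, high = map(int, loader_type[len("rand_int_"):].split("_"))
assert (low, high) == (3, 10)
```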

DATA_TYPES = {
    "sound": dict(
        func=sound_loader,
        kwargs=["dest_sample_rate", "float_dtype"],
        help="Audio format types which are supported by sndfile: wav, flac, etc."
        "\n\n"
        "   utterance_id_a a.wav\n"
        "   utterance_id_b b.wav\n"
        "   ...",
    ),
    "kaldi_ark": dict(
        func=kaldi_loader,
        kwargs=["max_cache_fd"],
        help="Kaldi-ark file type."
        "\n\n"
        "   utterance_id_A /some/where/a.ark:123\n"
        "   utterance_id_B /some/where/a.ark:456\n"
        "   ...",
    ),
    "npy": dict(
        func=NpyScpReader,
        kwargs=[],
        help="Npy file format."
        "\n\n"
        "   utterance_id_A /some/where/a.npy\n"
        "   utterance_id_B /some/where/b.npy\n"
        "   ...",
    ),
    "text_int": dict(
        func=functools.partial(load_num_sequence_text, loader_type="text_int"),
        kwargs=[],
        help="A text file in which is written a sequence of integer numbers "
        "separated by space."
        "\n\n"
        "   utterance_id_A 12 0 1 3\n"
        "   utterance_id_B 3 3 1\n"
        "   ...",
    ),
    "csv_int": dict(
        func=functools.partial(load_num_sequence_text, loader_type="csv_int"),
        kwargs=[],
        help="A text file in which is written a sequence of integer numbers "
        "separated by comma."
        "\n\n"
        "   utterance_id_A 100,80\n"
        "   utterance_id_B 143,80\n"
        "   ...",
    ),
    "text_float": dict(
        func=functools.partial(load_num_sequence_text, loader_type="text_float"),
        kwargs=[],
        help="A text file in which is written a sequence of float numbers "
        "separated by space."
        "\n\n"
        "   utterance_id_A 12. 3.1 3.4 4.4\n"
        "   utterance_id_B 3. 3.12 1.1\n"
        "   ...",
    ),
    "csv_float": dict(
        func=functools.partial(load_num_sequence_text, loader_type="csv_float"),
        kwargs=[],
        help="A text file in which is written a sequence of float numbers "
        "separated by comma."
        "\n\n"
        "   utterance_id_A 12.,3.1,3.4,4.4\n"
        "   utterance_id_B 3.,3.12,1.1\n"
        "   ...",
    ),
    "text": dict(
        func=read_2column_text,
        kwargs=[],
        help="Return text as is. The text must be converted to ndarray "
        "by 'preprocess'."
        "\n\n"
        "   utterance_id_A hello world\n"
        "   utterance_id_B foo bar\n"
        "   ...",
    ),
    "hdf5": dict(
        func=H5FileWrapper,
        kwargs=[],
        help="An HDF5 file which contains arrays at the first level or the second level."
        "\n\n"
        "   >>> f = h5py.File('file.h5')\n"
        "   >>> array1 = f['utterance_id_A']\n"
        "   >>> array2 = f['utterance_id_B']\n",
    ),
    "rand_float": dict(
        func=FloatRandomGenerateDataset,
        kwargs=[],
        help="Generate random float-ndarray which has the given shapes "
        "in the file."
        "\n\n"
        "   utterance_id_A 3,4\n"
        "   utterance_id_B 10,4\n"
        "   ...",
    ),
    "rand_int_\\d+_\\d+": dict(
        func=rand_int_loader,
        kwargs=["loader_type"],
        help="e.g. 'rand_int_0_10'. Generate random int-ndarray which has the given "
        "shapes in the path. "
        "Give the lower and upper value by the file type, e.g. "
        "rand_int_0_10 -> generate integers from 0 to 10."
        "\n\n"
        "   utterance_id_A 3,4\n"
        "   utterance_id_B 10,4\n"
        "   ...",
    ),
}
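These entries are consumed by ESPnetDataset._build_loader below, which regex-matches loader_type against the keys, so a plain string like "sound" and a patterned one like "rand_int_3_10" both resolve to an entry. A minimal sketch of that lookup, assuming the definitions in this module:

```python
def resolve(loader_type):
    # Mirrors the matching loop in ESPnetDataset._build_loader.
    for key, dic in DATA_TYPES.items():
        if re.match(key, loader_type):
            return dic["func"]
    raise RuntimeError(f"Not supported: loader_type={loader_type}")

assert resolve("sound") is sound_loader
assert resolve("rand_int_3_10") is rand_int_loader
```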

class AbsDataset(Dataset, ABC):
    @abstractmethod
    def has_name(self, name) -> bool:
        raise NotImplementedError

    @abstractmethod
    def names(self) -> Tuple[str, ...]:
        raise NotImplementedError

    @abstractmethod
    def __getitem__(self, uid) -> Tuple[Any, Dict[str, np.ndarray]]:
        raise NotImplementedError

class ESPnetDataset(AbsDataset):
    """PyTorch Dataset class for FunASR, simplified from ESPnet."""

    def __init__(
        self,
        path_name_type_list: Collection[Tuple[str, str, str]],
        preprocess: Callable[
            [str, Dict[str, np.ndarray]], Dict[str, np.ndarray]
        ] = None,
        float_dtype: str = "float32",
        int_dtype: str = "long",
        max_cache_size: Union[float, int, str] = 0.0,
        max_cache_fd: int = 0,
        dest_sample_rate: int = 16000,
    ):
        assert check_argument_types()
        if len(path_name_type_list) == 0:
            raise ValueError(
                '1 or more elements are required for "path_name_type_list"'
            )

        path_name_type_list = copy.deepcopy(path_name_type_list)
        self.preprocess = preprocess

        self.float_dtype = float_dtype
        self.int_dtype = int_dtype
        self.max_cache_fd = max_cache_fd
        self.dest_sample_rate = dest_sample_rate

        self.loader_dict = {}
        self.debug_info = {}
        for path, name, _type in path_name_type_list:
            if name in self.loader_dict:
                raise RuntimeError(f'"{name}" is duplicated for data-key')

            loader = self._build_loader(path, _type)
            self.loader_dict[name] = loader
            self.debug_info[name] = path, _type
            if len(self.loader_dict[name]) == 0:
                raise RuntimeError(f"{path} has no samples")

        # TODO(kamo): Should check consistency of each utt-keys?

        if isinstance(max_cache_size, str):
            max_cache_size = humanfriendly.parse_size(max_cache_size)
        self.max_cache_size = max_cache_size
        if max_cache_size > 0:
            self.cache = SizedDict(shared=True)
        else:
            self.cache = None
    def _build_loader(
        self, path: str, loader_type: str
    ) -> Mapping[str, Union[np.ndarray, torch.Tensor, str, numbers.Number]]:
        """Helper function to instantiate the loader.

        Args:
            path: The file path
            loader_type: loader type: sound, npy, text_int, text_float, etc.
        """
        for key, dic in DATA_TYPES.items():
            # e.g. loader_type="sound"
            # -> return DATA_TYPES["sound"]["func"](path)
            if re.match(key, loader_type):
                kwargs = {}
                for key2 in dic["kwargs"]:
                    if key2 == "loader_type":
                        kwargs["loader_type"] = loader_type
                    elif key2 == "dest_sample_rate" and loader_type == "sound":
                        kwargs["dest_sample_rate"] = self.dest_sample_rate
                    elif key2 == "float_dtype":
                        kwargs["float_dtype"] = self.float_dtype
                    elif key2 == "int_dtype":
                        kwargs["int_dtype"] = self.int_dtype
                    elif key2 == "max_cache_fd":
                        kwargs["max_cache_fd"] = self.max_cache_fd
                    else:
                        raise RuntimeError(f"Not implemented keyword argument: {key2}")

                func = dic["func"]
                try:
                    return func(path, **kwargs)
                except Exception:
                    if hasattr(func, "__name__"):
                        name = func.__name__
                    else:
                        name = str(func)
                    logging.error(f"An error happened with {name}({path})")
                    raise
        else:
            raise RuntimeError(f"Not supported: loader_type={loader_type}")
    def has_name(self, name) -> bool:
        return name in self.loader_dict

    def names(self) -> Tuple[str, ...]:
        return tuple(self.loader_dict)

    def __iter__(self):
        return iter(next(iter(self.loader_dict.values())))

    def __repr__(self):
        _mes = self.__class__.__name__
        _mes += "("
        for name, (path, _type) in self.debug_info.items():
            _mes += f'\n  {name}: {{"path": "{path}", "type": "{_type}"}}'
        _mes += f"\n  preprocess: {self.preprocess})"
        return _mes
    def __getitem__(self, uid: Union[str, int]) -> Tuple[str, Dict[str, np.ndarray]]:
        assert check_argument_types()

        # Change integer-id to string-id
        if isinstance(uid, int):
            d = next(iter(self.loader_dict.values()))
            uid = list(d)[uid]

        if self.cache is not None and uid in self.cache:
            data = self.cache[uid]
            return uid, data

        data = {}
        # 1. Load data from each loader
        for name, loader in self.loader_dict.items():
            try:
                value = loader[uid]
                if isinstance(value, (list, tuple)):
                    value = np.array(value)
                if not isinstance(
                    value, (np.ndarray, torch.Tensor, str, numbers.Number)
                ):
                    raise TypeError(
                        f"Must be ndarray, torch.Tensor, str or Number: {type(value)}"
                    )
            except Exception:
                path, _type = self.debug_info[name]
                logging.error(
                    f"Error happened with path={path}, type={_type}, id={uid}"
                )
                raise

            # torch.Tensor is converted to ndarray
            if isinstance(value, torch.Tensor):
                value = value.numpy()
            elif isinstance(value, numbers.Number):
                value = np.array([value])
            data[name] = value

        # 2. [Option] Apply preprocessing
        #    e.g. funasr.train.preprocessor:CommonPreprocessor
        if self.preprocess is not None:
            data = self.preprocess(uid, data)

        # 3. Force data-precision
        for name in data:
            value = data[name]
            if not isinstance(value, np.ndarray):
                raise RuntimeError(
                    f"All values must be converted to np.ndarray object "
                    f'by preprocessing, but "{name}" is still {type(value)}.'
                )

            # Cast to desired type
            if value.dtype.kind == "f":
                value = value.astype(self.float_dtype)
            elif value.dtype.kind == "i":
                value = value.astype(self.int_dtype)
            else:
                raise NotImplementedError(f"Not supported dtype: {value.dtype}")
            data[name] = value

        if self.cache is not None and self.cache.size < self.max_cache_size:
            self.cache[uid] = data

        retval = uid, data
        assert check_return_type(retval)
        return retval
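A minimal usage sketch of the new class (the scp paths and the "speech"/"text" names here are illustrative, not part of the commit):

```python
dataset = ESPnetDataset(
    path_name_type_list=[
        ("dump/train/wav.scp", "speech", "sound"),    # hypothetical paths
        ("dump/train/text_int", "text", "text_int"),
    ],
)
uid, sample = dataset[0]   # an integer index is mapped to an utterance id
# sample["speech"]: float32 waveform, sample["text"]: integer token ids
```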
funasr/utils/build_dataloader.py (new file, 11 lines)
@@ -0,0 +1,11 @@
from funasr.datasets.large_datasets.build_dataloader import LargeDataLoader


def build_dataloader(args):
    if args.dataset_type == "small":
        pass
    elif args.dataset_type == "large":
        train_iter_factory = LargeDataLoader(args, mode="train")
        valid_iter_factory = LargeDataLoader(args, mode="valid")
    else:
        raise ValueError(f"Not supported dataset_type={args.dataset_type}")
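As committed, the "small" branch is a stub and neither factory is returned, so callers cannot yet consume the result; presumably a follow-up fills this in. A sketch of the likely intended shape (the return statement is an assumption, not in this commit):

```python
def build_dataloader(args):
    if args.dataset_type == "large":
        train_iter_factory = LargeDataLoader(args, mode="train")
        valid_iter_factory = LargeDataLoader(args, mode="valid")
    else:
        raise ValueError(f"Not supported dataset_type={args.dataset_type}")
    return train_iter_factory, valid_iter_factory  # assumed, not committed
```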
funasr/utils/prepare_data.py
@@ -1,9 +1,11 @@
-import os
+import logging
+import os
 import shutil
 from multiprocessing import Pool

 import numpy as np
 import torch.distributed as dist
+import torchaudio


 def filter_wav_text(data_dir, dataset):
@@ -34,25 +36,37 @@ def filter_wav_text(data_dir, dataset):
             f_text.write(sample_name + " " + text_dict[sample_name] + "\n")
         else:
             filter_count += 1
-    logging.info("{}/{} samples in {} are filtered because of the mismatch between wav.scp and text".format(len(wav_lines),
-                                                                                                            filter_count,
-                                                                                                            dataset))
+    logging.info(
+        "{}/{} samples in {} are filtered because of the mismatch between wav.scp and text".format(len(wav_lines),
+                                                                                                   filter_count,
+                                                                                                   dataset))


-def calc_shape_core(root_path, frontend_conf, speech_length_min, speech_length_max, idx):
+def wav2num_frame(wav_path, frontend_conf):
+    waveform, sampling_rate = torchaudio.load(wav_path)
+    n_frames = (waveform.shape[1] * 1000.0) / (sampling_rate * frontend_conf["frame_shift"] * frontend_conf["lfr_n"])
+    feature_dim = frontend_conf["n_mels"] * frontend_conf["lfr_m"]
+    return n_frames, feature_dim
+
+
+def calc_shape_core(root_path, args, idx):
     wav_scp_file = os.path.join(root_path, "wav.scp.{}".format(idx))
     shape_file = os.path.join(root_path, "speech_shape.{}".format(idx))
     with open(wav_scp_file) as f:
         lines = f.readlines()
+    frontend_conf = args.frontend_conf
+    dataset_conf = args.dataset_conf
+    speech_length_min = dataset_conf.speech_length_min if hasattr(dataset_conf, "speech_length_min") else -1
+    speech_length_max = dataset_conf.speech_length_max if hasattr(dataset_conf, "speech_length_max") else -1
     with open(shape_file, "w") as f:
         for line in lines:
             sample_name, wav_path = line.strip().split()
-            n_frames, feature_dim, speech_length = wav2num_frame(wav_path, frontend_conf)
+            n_frames, feature_dim = wav2num_frame(wav_path, frontend_conf)
             write_flag = True
-            if speech_length_min > 0 and speech_length < speech_length_min:
-                write_flag = False
-            if speech_length_max > 0 and speech_length > speech_length_max:
-                write_flag = False
+            if n_frames > 0 and speech_length_min > 0:
+                write_flag = n_frames >= speech_length_min
+            if n_frames > 0 and speech_length_max > 0:
+                write_flag = n_frames <= speech_length_max
             if write_flag:
                 f.write("{} {},{}\n".format(sample_name, str(int(np.ceil(n_frames))), str(int(feature_dim))))
                 f.flush()
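The frame count in wav2num_frame is the sample count scaled to milliseconds, divided by the effective (LFR-decimated) frame shift. A worked example with typical values (assumed here, not taken from the commit):

```python
num_samples, sampling_rate = 16000, 16000       # 1 second of audio
frontend_conf = {"frame_shift": 10, "lfr_n": 6, "lfr_m": 7, "n_mels": 80}
n_frames = (num_samples * 1000.0) / (sampling_rate * frontend_conf["frame_shift"] * frontend_conf["lfr_n"])
feature_dim = frontend_conf["n_mels"] * frontend_conf["lfr_m"]
print(n_frames, feature_dim)                    # ~16.7 LFR frames, 560-dim features
```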
@@ -61,12 +75,13 @@ def calc_shape_core(root_path, frontend_conf, speech_length_min, speech_length_max, idx):

 def calc_shape(args, dataset, nj=32):
     shape_path = os.path.join(args.data_dir, dataset, "speech_shape")
     if os.path.exists(shape_path):
-        print('Shape file for small dataset already exists.')
+        logging.info('Shape file for small dataset already exists.')
         return

     split_shape_path = os.path.join(args.data_dir, dataset, "shape_files")
-    os.makedirs(split_shape_path, exist_ok=True)
+    if os.path.exists(split_shape_path):
+        shutil.rmtree(split_shape_path)
+    os.mkdir(split_shape_path)

     # split
     wav_scp_file = os.path.join(args.data_dir, dataset, "wav.scp")
@@ -87,21 +102,58 @@ def calc_shape(args, dataset, nj=32):

     p = Pool(nj)
     for i in range(nj):
-        p.apply_async(calc_shape_core,
-                      args=(shape_path, frontend_conf, speech_length_min, speech_length_max, str(i + 1)))
-    print('Generating shape files, please wait a few minutes...')
+        p.apply_async(calc_shape_core, args=(split_shape_path, args, str(i + 1)))
+    logging.info("Generating shape files, please wait a few minutes...")
     p.close()
     p.join()

     # combine
-    file = os.path.join(data_dir, dataset, "speech_shape")
-    with open(file, "w") as f:
+    with open(shape_path, "w") as f:
         for i in range(nj):
-            job_file = os.path.join(shape_path, "speech_shape.{}".format(str(i + 1)))
+            job_file = os.path.join(split_shape_path, "speech_shape.{}".format(str(i + 1)))
             with open(job_file) as job_f:
                 lines = job_f.readlines()
                 f.writelines(lines)
-    print('Generating shape files done.')
+    logging.info('Generating shape files done.')
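Each line of the combined speech_shape file therefore reads `utt_id n_frames,feature_dim`, matching the f.write format above; for example (utterance ids illustrative):

```
utterance_id_A 167,560
utterance_id_B 154,560
```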
+
+
+def generate_data_list(data_dir, dataset, nj=100):
+    list_file = os.path.join(data_dir, dataset, "data.list")
+    if os.path.exists(list_file):
+        logging.info('Data list for large dataset already exists.')
+        return
+    split_path = os.path.join(data_dir, dataset, "split")
+    if os.path.exists(split_path):
+        shutil.rmtree(split_path)
+    os.mkdir(split_path)
+
+    with open(os.path.join(data_dir, dataset, "wav.scp")) as f_wav:
+        wav_lines = f_wav.readlines()
+    with open(os.path.join(data_dir, dataset, "text")) as f_text:
+        text_lines = f_text.readlines()
+    num_lines = len(wav_lines)
+    num_job_lines = num_lines // nj
+    start = 0
+    for i in range(nj):
+        end = start + num_job_lines
+        split_path_nj = os.path.join(split_path, str(i + 1))
+        os.mkdir(split_path_nj)
+        wav_file = os.path.join(split_path_nj, "wav.scp")
+        text_file = os.path.join(split_path_nj, "text")
+        with open(wav_file, "w") as fw, open(text_file, "w") as ft:
+            if i == nj - 1:
+                fw.writelines(wav_lines[start:])
+                ft.writelines(text_lines[start:])
+            else:
+                fw.writelines(wav_lines[start:end])
+                ft.writelines(text_lines[start:end])
+        start = end
+
+    with open(list_file, "w") as f_data:
+        for i in range(nj):
+            wav_path = os.path.join(split_path, str(i + 1), "wav.scp")
+            text_path = os.path.join(split_path, str(i + 1), "text")
+            f_data.write(wav_path + " " + text_path + "\n")
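The resulting data.list pairs one shard's wav.scp with its text file per line, presumably the args.data_list consumed by LargeDataLoader above; e.g. (paths illustrative):

```
/data/train/split/1/wav.scp /data/train/split/1/text
/data/train/split/2/wav.scp /data/train/split/2/text
```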
@@ -109,6 +161,18 @@ def prepare_data(args, distributed_option):
     if not distributed or distributed_option.dist_rank == 0:
         filter_wav_text(args.data_dir, args.train_set)
         filter_wav_text(args.data_dir, args.dev_set)
-    dist.barrier()
+
+        if args.dataset_type == "small" and args.train_shape_file is None:
+            calc_shape(args, args.train_set)
+            calc_shape(args, args.dev_set)
+
+        if args.dataset_type == "large" and args.train_data_file is None:
+            generate_data_list(args.data_dir, args.train_set)
+            generate_data_list(args.data_dir, args.dev_set)
+
+    args.train_shape_file = [os.path.join(args.data_dir, args.train_set, "speech_shape")]
+    args.valid_shape_file = [os.path.join(args.data_dir, args.dev_set, "speech_shape")]
+    args.train_data_file = os.path.join(args.data_dir, args.train_set, "data.list")
+    args.valid_data_file = os.path.join(args.data_dir, args.dev_set, "data.list")
+    if distributed:
+        dist.barrier()
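Taken together with the first hunk, the intended call order in the training entry point is roughly (simplified sketch, argument names assumed):

```python
distributed_option = build_distributed(args)   # assumed signature
prepare_data(args, distributed_option)         # filter wav/text, emit speech_shape or data.list
set_all_random_seed(args.seed)
torch.backends.cudnn.enabled = args.cudnn_enabled
# later, build_dataloader(args) reads the args.train_data_file /
# args.train_shape_file fields that prepare_data populated above
```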