mirror of https://github.com/modelscope/FunASR
synced 2025-09-15 14:48:36 +08:00

commit 58fb22cb2b ("update")
parent 05d4176e88
@@ -6,6 +6,7 @@ import torch
from funasr.torch_utils.set_all_random_seed import set_all_random_seed
from funasr.utils import config_argparse
from funasr.utils.build_dataloader import build_dataloader
from funasr.utils.build_distributed import build_distributed
from funasr.utils.prepare_data import prepare_data
from funasr.utils.types import str2bool
@@ -338,14 +339,36 @@ if __name__ == '__main__':
        format=f"[{os.uname()[1].split('.')[0]}]"
               f" %(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
    )
    logging.info("world size: {}, rank: {}, local_rank: {}".format(distributed_option.dist_world_size,
                                                                   distributed_option.dist_rank,
                                                                   distributed_option.local_rank))

    # prepare files for dataloader
    prepare_data(args, distributed_option)

    # set random seed
    set_all_random_seed(args.seed)
    torch.backends.cudnn.enabled = args.cudnn_enabled
    torch.backends.cudnn.benchmark = args.cudnn_benchmark
    torch.backends.cudnn.deterministic = args.cudnn_deterministic

    train_dataloader, valid_dataloader = build_dataloader(args)

    logging.info("world size: {}, rank: {}, local_rank: {}".format(distributed_option.dist_world_size,
                                                                   distributed_option.dist_rank,
                                                                   distributed_option.local_rank))

    # optimizers = cls.build_optimizers(args, model=model)
    # schedulers = []
    # for i, optim in enumerate(optimizers, 1):
    #     suf = "" if i == 1 else str(i)
    #     name = getattr(args, f"scheduler{suf}")
    #     conf = getattr(args, f"scheduler{suf}_conf")
    #     if name is not None:
    #         cls_ = scheduler_classes.get(name)
    #         if cls_ is None:
    #             raise ValueError(
    #                 f"must be one of {list(scheduler_classes)}: {name}"
    #             )
    #         scheduler = cls_(optim, **conf)
    #     else:
    #         scheduler = None
    #
    #     schedulers.append(scheduler)
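The hunk above seeds every RNG via set_all_random_seed(args.seed) before configuring the cuDNN flags. As a reference for what such a helper conventionally does (an assumed, typical implementation, not FunASR's actual funasr.torch_utils code):

import random

import numpy as np
import torch


def set_all_random_seed_sketch(seed: int) -> None:
    # Seed every RNG a training run touches so results are repeatable.
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)  # seeds the CPU generator and, in current PyTorch, the CUDA generators as well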
@@ -1,84 +0,0 @@
import logging

import numpy as np
import torch

from funasr.datasets.small_datasets.collate_fn import CommonCollateFn
from funasr.datasets.small_datasets.dataset import ESPnetDataset
from funasr.datasets.small_datasets.length_batch_sampler import LengthBatchSampler
from funasr.datasets.small_datasets.preprocessor import build_preprocess
from funasr.datasets.small_datasets.sequence_iter_factory import SequenceIterFactory


def build_dataloader(args, mode="train"):
    # preprocess
    preprocess_fn = build_preprocess(args, train=mode == "train")

    # collate
    if args.task_name in ["punc", "lm"]:
        collate_fn = CommonCollateFn(int_pad_value=0)
    else:
        collate_fn = CommonCollateFn(float_pad_value=0.0, int_pad_value=-1)

    # dataset
    dest_sample_rate = args.frontend_conf["fs"] if (
            args.frontend_conf is not None and "fs" in args.frontend_conf) else 16000
    if mode == "train":
        data_path_and_name_and_type = args.train_data_path_and_name_and_type
        shape_files = args.train_shape_file
    elif mode == "valid":
        data_path_and_name_and_type = args.valid_data_path_and_name_and_type
        shape_files = args.valid_shape_file
    else:
        raise NotImplementedError(f"mode={mode}")
    dataset = ESPnetDataset(
        data_path_and_name_and_type,
        preprocess=preprocess_fn,
        dest_sample_rate=dest_sample_rate,
    )

    # sampler
    dataset_conf = args.dataset_conf
    batch_sampler = LengthBatchSampler(
        batch_bins=dataset_conf["batch_size"],
        shape_files=shape_files,
        sort_in_batch=dataset_conf["sort_in_batch"] if hasattr(dataset_conf, "sort_in_batch") else "descending",
        sort_batch=dataset_conf["sort_batch"] if hasattr(dataset_conf, "sort_batch") else "ascending",
        drop_last=False,
        padding=True,
    )

    batches = list(batch_sampler)
    bs_list = [len(batch) for batch in batches]
    logging.info(f"[{mode}] dataset:\n{dataset}")
    logging.info(f"[{mode}] Batch sampler: {batch_sampler}")
    logging.info(
        f"[{mode}] mini-batch sizes summary: N-batch={len(bs_list)}, "
        f"mean={np.mean(bs_list):.1f}, min={np.min(bs_list)}, max={np.max(bs_list)}"
    )

    if args.scheduler == "tri_stage" and mode == "train":
        args.max_update = len(bs_list) * args.max_epoch
        logging.info("Max update: {}".format(args.max_update))

    if args.distributed:
        world_size = torch.distributed.get_world_size()
        rank = torch.distributed.get_rank()
        for batch in batches:
            if len(batch) < world_size:
                raise RuntimeError(
                    f"The batch-size must be equal or more than world_size: "
                    f"{len(batch)} < {world_size}"
                )
        batches = [batch[rank::world_size] for batch in batches]

    # dataloader
    return SequenceIterFactory(
        dataset=dataset,
        batches=batches,
        seed=args.seed,
        shuffle=mode == "train",
        num_workers=args.num_workers,
        collate_fn=collate_fn,
        pin_memory=args.ngpu > 0,
    )
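In the deleted loader's distributed branch, each pre-built mini-batch is sharded across ranks by strided slicing rather than by a DistributedSampler. A self-contained illustration of that slicing, with made-up utterance ids standing in for real FunASR batches:

# Hypothetical mini-batch of 7 utterance ids shared by 3 ranks.
batch = ["utt0", "utt1", "utt2", "utt3", "utt4", "utt5", "utt6"]
world_size = 3
for rank in range(world_size):
    print(rank, batch[rank::world_size])  # every world_size-th item, offset by rank
# 0 ['utt0', 'utt3', 'utt6']
# 1 ['utt1', 'utt4']
# 2 ['utt2', 'utt5']

This is also why the loader raises a RuntimeError when a batch is smaller than world_size: with fewer items than ranks, some rank would receive an empty shard.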
@@ -27,8 +27,7 @@ class RawSampler(AbsSampler):


class SequenceIterFactory(AbsIterFactory):
    """Build iterator for each epoch.
    """Build iterator for each epoch, modified from ESPnet

    """
@@ -1160,7 +1160,8 @@ class AbsTask(ABC):
            args.batch_bins = args.batch_bins * args.ngpu

        # filter samples if wav.scp and text are mismatch
        if (args.train_shape_file is None and args.dataset_type == "small") or args.train_data_file is None and args.dataset_type == "large":
        if (
                args.train_shape_file is None and args.dataset_type == "small") or args.train_data_file is None and args.dataset_type == "large":
            if not args.simple_ddp or distributed_option.dist_rank == 0:
                filter_wav_text(args.data_dir, args.train_set)
                filter_wav_text(args.data_dir, args.dev_set)
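The rewrapped condition keeps its original grouping: in Python, and binds tighter than or, so the test reads as "(no shape file and small dataset) or (no data file and large dataset)" even without extra parentheses. A two-line check of that precedence:

# `A and B or C and D` groups as `(A and B) or (C and D)`.
A, B, C, D = False, True, True, False
print(A and B or C and D, (A and B) or (C and D))  # False False, identical grouping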
@@ -1169,8 +1170,10 @@ class AbsTask(ABC):

        if args.train_shape_file is None and args.dataset_type == "small":
            if not args.simple_ddp or distributed_option.dist_rank == 0:
                calc_shape(args.data_dir, args.train_set, args.frontend_conf, args.speech_length_min, args.speech_length_max)
                calc_shape(args.data_dir, args.dev_set, args.frontend_conf, args.speech_length_min, args.speech_length_max)
                calc_shape(args.data_dir, args.train_set, args.frontend_conf, args.speech_length_min,
                           args.speech_length_max)
                calc_shape(args.data_dir, args.dev_set, args.frontend_conf, args.speech_length_min,
                           args.speech_length_max)
            if args.simple_ddp:
                dist.barrier()
            args.train_shape_file = [os.path.join(args.data_dir, args.train_set, "speech_shape")]
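When simple_ddp is enabled, only rank 0 runs calc_shape, and every rank then meets at dist.barrier() so that no process reads the freshly written speech_shape files before they exist. Stripped of the FunASR specifics, this is the usual "rank 0 prepares, everyone waits" idiom (a sketch assuming torch.distributed is already initialised):

import torch.distributed as dist


def prepare_once(rank: int, do_work) -> None:
    # Only rank 0 produces the shared artifact; all ranks block until it is done.
    if rank == 0:
        do_work()
    dist.barrier()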
@@ -1360,15 +1363,21 @@ class AbsTask(ABC):
        if args.dataset_type == "large":
            from funasr.datasets.large_datasets.build_dataloader import ArkDataLoader
            train_iter_factory = ArkDataLoader(args.train_data_file, args.token_list, args.dataset_conf,
                                               frontend_conf=args.frontend_conf if hasattr(args, "frontend_conf") else None,
                                               seg_dict_file=args.seg_dict_file if hasattr(args, "seg_dict_file") else None,
                                               punc_dict_file=args.punc_list if hasattr(args, "punc_list") else None,
                                               frontend_conf=args.frontend_conf if hasattr(args,
                                                                                           "frontend_conf") else None,
                                               seg_dict_file=args.seg_dict_file if hasattr(args,
                                                                                           "seg_dict_file") else None,
                                               punc_dict_file=args.punc_list if hasattr(args,
                                                                                        "punc_list") else None,
                                               bpemodel_file=args.bpemodel if hasattr(args, "bpemodel") else None,
                                               mode="train")
            valid_iter_factory = ArkDataLoader(args.valid_data_file, args.token_list, args.dataset_conf,
                                               frontend_conf=args.frontend_conf if hasattr(args, "frontend_conf") else None,
                                               seg_dict_file=args.seg_dict_file if hasattr(args, "seg_dict_file") else None,
                                               punc_dict_file=args.punc_list if hasattr(args, "punc_list") else None,
            valid_iter_factory = ArkDataLoader(args.valid_data_file, args.token_list, args.dataset_conf,
                                               frontend_conf=args.frontend_conf if hasattr(args,
                                                                                           "frontend_conf") else None,
                                               seg_dict_file=args.seg_dict_file if hasattr(args,
                                                                                           "seg_dict_file") else None,
                                               punc_dict_file=args.punc_list if hasattr(args,
                                                                                        "punc_list") else None,
                                               bpemodel_file=args.bpemodel if hasattr(args, "bpemodel") else None,
                                               mode="eval")
        elif args.dataset_type == "small":
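Every optional keyword in the ArkDataLoader calls above is guarded with hasattr(args, ...), which works because args is an argparse-style namespace: a missing attribute simply falls back to None. A terser spelling of the same guard (a stylistic aside, not a change this commit makes) is getattr with a default:

from argparse import Namespace

# Hypothetical config that defines frontend_conf but not seg_dict_file.
args = Namespace(frontend_conf={"fs": 16000})

frontend_conf = args.frontend_conf if hasattr(args, "frontend_conf") else None
seg_dict_file = getattr(args, "seg_dict_file", None)  # same effect in one call
print(frontend_conf, seg_dict_file)  # {'fs': 16000} None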
@@ -1,12 +1,15 @@
from funasr.datasets.large_datasets.build_dataloader import LargeDataLoader
from funasr.datasets.small_datasets.build_dataloader import build_dataloader
from funasr.datasets.small_datasets.sequence_iter_factory import SequenceIterFactory


def build_dataloader(args):
    if args.dataset_type == "small":
        train_iter_factory = LargeDataLoader(args, mode="train")
        valid_iter_factory = LargeDataLoader(args, mode="valid")
        train_iter_factory = SequenceIterFactory(args, mode="train")
        valid_iter_factory = SequenceIterFactory(args, mode="valid")
    elif args.dataset_type == "large":
        train_iter_factory = LargeDataLoader(args, mode="train")
        valid_iter_factory = LargeDataLoader(args, mode="valid")
        valid_iter_factory = LargeDataLoader(args, mode="valid")
    else:
        raise ValueError(f"Not supported dataset_type={args.dataset_type}")
        raise ValueError(f"Not supported dataset_type={args.dataset_type}")

    return train_iter_factory, valid_iter_factory
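The training entry point from the first hunk consumes this dispatcher via "from funasr.utils.build_dataloader import build_dataloader". A sketch of that call site, assuming args is the fully parsed training configuration (dataset_type, seed, num_workers and the rest all present):

from funasr.utils.build_dataloader import build_dataloader

# The small-dataset branch hands back SequenceIterFactory objects,
# the large-dataset branch LargeDataLoader objects.
train_dataloader, valid_dataloader = build_dataloader(args)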
funasr/utils/build_model.py (new file, 13 lines)
@@ -0,0 +1,13 @@
import logging


def build_model(args):
    if args.token_list is not None:
        with open(args.token_list, encoding="utf-8") as f:
            token_list = [line.rstrip() for line in f]
        args.token_list = list(token_list)
        vocab_size = len(token_list)
        logging.info(f"Vocabulary size: {vocab_size}")