From bd7455ec7da3178d9acc8d704ee63cb443a8887e Mon Sep 17 00:00:00 2001
From: speech_asr
Date: Wed, 12 Apr 2023 10:43:01 +0800
Subject: [PATCH] update

---
 funasr/bin/train.py               | 326 ++++++++++++++++++++++++++++++
 funasr/tasks/abs_task.py          |   6 +-
 funasr/utils/build_distributed.py |  38 ++++
 3 files changed, 367 insertions(+), 3 deletions(-)
 create mode 100644 funasr/bin/train.py
 create mode 100644 funasr/utils/build_distributed.py

diff --git a/funasr/bin/train.py b/funasr/bin/train.py
new file mode 100644
index 000000000..94dc75ca6
--- /dev/null
+++ b/funasr/bin/train.py
@@ -0,0 +1,326 @@
+import sys
+
+import torch
+
+from funasr.utils import config_argparse
+from funasr.utils.build_distributed import build_distributed
+from funasr.utils.types import str2bool
+
+
+def get_parser():
+    parser = config_argparse.ArgumentParser(
+        description="FunASR Common Training Parser",
+    )
+
+    # common configuration
+    parser.add_argument("--output_dir", help="model save path")
+    parser.add_argument(
+        "--ngpu",
+        type=int,
+        default=0,
+        help="The number of GPUs. 0 indicates CPU mode",
+    )
+    parser.add_argument("--seed", type=int, default=0, help="Random seed")
+
+    # ddp related
+    parser.add_argument(
+        "--dist_backend",
+        default="nccl",
+        type=str,
+        help="distributed backend",
+    )
+    parser.add_argument(
+        "--dist_init_method",
+        type=str,
+        default="env://",
+        help='if init_method="env://", the env values of "MASTER_PORT", '
+        '"MASTER_ADDR", "WORLD_SIZE", and "RANK" are used.',
+    )
+    parser.add_argument(
+        "--dist_world_size",
+        type=int,
+        default=None,
+        help="number of nodes for distributed training",
+    )
+    parser.add_argument(
+        "--dist_rank",
+        type=int,
+        default=None,
+        help="node rank for distributed training",
+    )
+    parser.add_argument(
+        "--local_rank",
+        type=int,
+        default=None,
+        help="local rank for distributed training",
+    )
+    parser.add_argument(
+        "--unused_parameters",
+        type=str2bool,
+        default=False,
+        help="Whether to set find_unused_parameters in "
+        "torch.nn.parallel.DistributedDataParallel",
+    )
+
+    # cudnn related
+    parser.add_argument(
+        "--cudnn_enabled",
+        type=str2bool,
+        default=torch.backends.cudnn.enabled,
+        help="Enable CUDNN",
+    )
+    parser.add_argument(
+        "--cudnn_benchmark",
+        type=str2bool,
+        default=torch.backends.cudnn.benchmark,
+        help="Enable cudnn-benchmark mode",
+    )
+    parser.add_argument(
+        "--cudnn_deterministic",
+        type=str2bool,
+        default=True,
+        help="Enable cudnn-deterministic mode",
+    )
+
+    # trainer related
+    parser.add_argument(
+        "--max_epoch",
+        type=int,
+        default=40,
+        help="The maximum number of epochs to train",
+    )
+    parser.add_argument(
+        "--max_update",
+        type=int,
+        default=sys.maxsize,
+        help="The maximum number of update steps to train",
+    )
+    parser.add_argument(
+        "--batch_interval",
+        type=int,
+        default=10000,
+        help="The batch interval for saving the model.",
+    )
+    parser.add_argument(
+        "--patience",
+        type=int,
+        default=None,
+        help="Number of epochs to wait without improvement "
+        "before stopping the training",
+    )
+    parser.add_argument(
+        "--val_scheduler_criterion",
+        type=str,
+        nargs=2,
+        default=("valid", "loss"),
+        help="The criterion used for the value given to the lr scheduler. "
+        'Give a pair referring to the phase, "train" or "valid", '
+        'and the criterion name. The mode, "min" or "max", can '
+        "be changed via --scheduler_conf",
+    )
+    parser.add_argument(
+        "--early_stopping_criterion",
+        type=str,
+        nargs=3,
+        default=("valid", "loss", "min"),
+        help="The criterion used for judging early stopping. "
" + 'Give a pair referring the phase, "train" or "valid",' + 'the criterion name and the mode, "min" or "max", e.g. "acc,max".', + ) + parser.add_argument( + "--best_model_criterion", + nargs="+", + default=[ + ("train", "loss", "min"), + ("valid", "loss", "min"), + ("train", "acc", "max"), + ("valid", "acc", "max"), + ], + help="The criterion used for judging of the best model. " + 'Give a pair referring the phase, "train" or "valid",' + 'the criterion name, and the mode, "min" or "max", e.g. "acc,max".', + ) + parser.add_argument( + "--keep_nbest_models", + type=int, + nargs="+", + default=[10], + help="Remove previous snapshots excluding the n-best scored epochs", + ) + parser.add_argument( + "--nbest_averaging_interval", + type=int, + default=0, + help="The epoch interval to apply model averaging and save nbest models", + ) + parser.add_argument( + "--grad_clip", + type=float, + default=5.0, + help="Gradient norm threshold to clip", + ) + parser.add_argument( + "--grad_clip_type", + type=float, + default=2.0, + help="The type of the used p-norm for gradient clip. Can be inf", + ) + parser.add_argument( + "--grad_noise", + type=str2bool, + default=False, + help="The flag to switch to use noise injection to " + "gradients during training", + ) + parser.add_argument( + "--accum_grad", + type=int, + default=1, + help="The number of gradient accumulation", + ) + parser.add_argument( + "--resume", + type=str2bool, + default=False, + help="Enable resuming if checkpoint is existing", + ) + parser.add_argument( + "--use_amp", + type=str2bool, + default=False, + help="Enable Automatic Mixed Precision. This feature requires pytorch>=1.6", + ) + parser.add_argument( + "--log_interval", + default=None, + help="Show the logs every the number iterations in each epochs at the " + "training phase. If None is given, it is decided according the number " + "of training samples automatically .", + ) + + # pretrained model related + parser.add_argument( + "--init_param", + type=str, + default=[], + nargs="*", + help="Specify the file path used for initialization of parameters. " + "The format is ':::', " + "where file_path is the model file path, " + "src_key specifies the key of model states to be used in the model file, " + "dst_key specifies the attribute of the model to be initialized, " + "and exclude_keys excludes keys of model states for the initialization." + "e.g.\n" + " # Load all parameters" + " --init_param some/where/model.pb\n" + " # Load only decoder parameters" + " --init_param some/where/model.pb:decoder:decoder\n" + " # Load only decoder parameters excluding decoder.embed" + " --init_param some/where/model.pb:decoder:decoder:decoder.embed\n" + " --init_param some/where/model.pb:decoder:decoder:decoder.embed\n", + ) + parser.add_argument( + "--ignore_init_mismatch", + type=str2bool, + default=False, + help="Ignore size mismatch when loading pre-trained model", + ) + parser.add_argument( + "--freeze_param", + type=str, + default=[], + nargs="*", + help="Freeze parameters", + ) + + # dataset related + parser.add_argument( + "--dataset_type", + type=str, + default="small", + help="whether to use dataloader for large dataset", + ) + parser.add_argument( + "--train_data_file", + type=str, + default=None, + help="train_list for large dataset", + ) + parser.add_argument( + "--valid_data_file", + type=str, + default=None, + help="valid_list for large dataset", + ) + parser.add_argument( + "--train_data_path_and_name_and_type", + action="append", + default=[], + help="e.g. 
+        "'--train_data_path_and_name_and_type some/path/a.scp,foo,sound'.",
+    )
+    parser.add_argument(
+        "--valid_data_path_and_name_and_type",
+        action="append",
+        default=[],
+        help="Same format as --train_data_path_and_name_and_type.",
+    )
+
+    # pai related
+    parser.add_argument(
+        "--use_pai",
+        type=str2bool,
+        default=False,
+        help="flag to indicate whether training runs on PAI",
+    )
+    parser.add_argument(
+        "--simple_ddp",
+        type=str2bool,
+        default=False,
+    )
+    parser.add_argument(
+        "--num_worker_count",
+        type=int,
+        default=1,
+        help="The number of machines on PAI.",
+    )
+    parser.add_argument(
+        "--access_key_id",
+        type=str,
+        default=None,
+        help="The access key ID (username) for OSS.",
+    )
+    parser.add_argument(
+        "--access_key_secret",
+        type=str,
+        default=None,
+        help="The access key secret (password) for OSS.",
+    )
+    parser.add_argument(
+        "--endpoint",
+        type=str,
+        default=None,
+        help="The endpoint for OSS.",
+    )
+    parser.add_argument(
+        "--bucket_name",
+        type=str,
+        default=None,
+        help="The bucket name for OSS.",
+    )
+    parser.add_argument(
+        "--oss_bucket",
+        default=None,
+        help="OSS bucket.",
+    )
+
+    # task related
+    parser.add_argument("--task_name", help="name of the training task")
+
+    return parser
+
+
+if __name__ == '__main__':
+    parser = get_parser()
+    args = parser.parse_args()
+
+    # Guard against the default None before the integer comparison.
+    args.distributed = args.dist_world_size is not None and args.dist_world_size > 1
+    distributed_option = build_distributed(args)
+
+    #

diff --git a/funasr/tasks/abs_task.py b/funasr/tasks/abs_task.py
index 775cba86a..86957d9f8 100644
--- a/funasr/tasks/abs_task.py
+++ b/funasr/tasks/abs_task.py
@@ -30,6 +30,7 @@ import torch.multiprocessing
 import torch.nn
 import torch.optim
 import yaml
+from funasr.train.abs_espnet_model import AbsESPnetModel
 from torch.utils.data import DataLoader
 from typeguard import check_argument_types
 from typeguard import check_return_type
@@ -44,19 +45,18 @@ from funasr.iterators.chunk_iter_factory import ChunkIterFactory
 from funasr.iterators.multiple_iter_factory import MultipleIterFactory
 from funasr.iterators.sequence_iter_factory import SequenceIterFactory
 from funasr.main_funcs.collect_stats import collect_stats
-from funasr.optimizers.sgd import SGD
 from funasr.optimizers.fairseq_adam import FairseqAdam
+from funasr.optimizers.sgd import SGD
 from funasr.samplers.build_batch_sampler import BATCH_TYPES
 from funasr.samplers.build_batch_sampler import build_batch_sampler
 from funasr.samplers.unsorted_batch_sampler import UnsortedBatchSampler
 from funasr.schedulers.noam_lr import NoamLR
-from funasr.schedulers.warmup_lr import WarmupLR
 from funasr.schedulers.tri_stage_scheduler import TriStageLR
+from funasr.schedulers.warmup_lr import WarmupLR
 from funasr.torch_utils.load_pretrained_model import load_pretrained_model
 from funasr.torch_utils.model_summary import model_summary
 from funasr.torch_utils.pytorch_version import pytorch_cudnn_version
 from funasr.torch_utils.set_all_random_seed import set_all_random_seed
-from funasr.train.abs_espnet_model import AbsESPnetModel
 from funasr.train.class_choices import ClassChoices
 from funasr.train.distributed_utils import DistributedOption
 from funasr.train.trainer import Trainer

diff --git a/funasr/utils/build_distributed.py b/funasr/utils/build_distributed.py
new file mode 100644
index 000000000..b64b4c03c
--- /dev/null
+++ b/funasr/utils/build_distributed.py
@@ -0,0 +1,38 @@
+import logging
+import os
+
+import torch
+
+from funasr.train.distributed_utils import DistributedOption
+from funasr.utils.build_dataclass import build_dataclass
+
+
+def build_distributed(args):
+    distributed_option = build_dataclass(DistributedOption, args)
+    if args.use_pai:
+        distributed_option.init_options_pai()
+        distributed_option.init_torch_distributed_pai(args)
+    elif not args.simple_ddp:
+        distributed_option.init_torch_distributed(args)
+    elif args.distributed and args.simple_ddp:
+        distributed_option.init_torch_distributed_pai(args)
+        args.ngpu = torch.distributed.get_world_size()
+
+    # Reset any pre-existing logging handlers, then log at INFO on rank 0
+    # and at ERROR on all other ranks.
+    for handler in logging.root.handlers[:]:
+        logging.root.removeHandler(handler)
+    if not distributed_option.distributed or distributed_option.dist_rank == 0:
+        logging.basicConfig(
+            level="INFO",
+            format=f"[{os.uname()[1].split('.')[0]}]"
+            f" %(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
+        )
+    else:
+        logging.basicConfig(
+            level="ERROR",
+            format=f"[{os.uname()[1].split('.')[0]}]"
+            f" %(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
+        )
+    logging.info("world size: {}, rank: {}, local_rank: {}".format(distributed_option.dist_world_size,
+                                                                   distributed_option.dist_rank,
+                                                                   distributed_option.local_rank))
+    return distributed_option
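
Reviewer note: a minimal, hypothetical single-node smoke test for the new entry point, not part of the patch. It pre-sets the four environment variables that the --dist_init_method help says "env://" initialization reads, then mirrors what train.py's __main__ does. All argument values (output path, world size, etc.) are illustrative, and whether init_torch_distributed() is effectively a no-op for a non-distributed run depends on DistributedOption, which this patch does not touch.

import os

# env:// initialization reads these variables (see the --dist_init_method help).
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
os.environ.setdefault("WORLD_SIZE", "1")
os.environ.setdefault("RANK", "0")

from funasr.bin.train import get_parser
from funasr.utils.build_distributed import build_distributed

args = get_parser().parse_args(
    [
        "--output_dir", "exp/debug",  # illustrative output path
        "--ngpu", "1",
        "--dist_world_size", "1",
        "--dist_rank", "0",
        "--local_rank", "0",
    ]
)
# Same guard as train.py's __main__: only go distributed for world size > 1.
args.distributed = args.dist_world_size is not None and args.dist_world_size > 1
distributed_option = build_distributed(args)  # sets up DDP and per-rank logging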