From 7ea3836893bfdf1aac03952bb1ff2da2c6ef6e57 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E5=98=89=E6=B8=8A?=
Date: Tue, 1 Aug 2023 14:18:32 +0800
Subject: [PATCH] update

---
 funasr/bin/build_trainer_bak.py | 159 --------------------------------
 1 file changed, 159 deletions(-)
 delete mode 100644 funasr/bin/build_trainer_bak.py

diff --git a/funasr/bin/build_trainer_bak.py b/funasr/bin/build_trainer_bak.py
deleted file mode 100644
index e7f28edd1..000000000
--- a/funasr/bin/build_trainer_bak.py
+++ /dev/null
@@ -1,159 +0,0 @@
-import os
-
-import yaml
-
-def update_dct(fin_configs, root):
-    if root == {}:
-        return {}
-    for root_key, root_value in root.items():
-        if not isinstance(root[root_key], dict):
-            fin_configs[root_key] = root[root_key]
-        else:
-            if root_key in fin_configs.keys():
-                result = update_dct(fin_configs[root_key], root[root_key])
-                fin_configs[root_key] = result
-            else:
-                fin_configs[root_key] = root[root_key]
-    return fin_configs
-
-
-def parse_args(mode):
-    if mode == "asr":
-        from funasr.tasks.asr import ASRTask as ASRTask
-    elif mode == "paraformer":
-        from funasr.tasks.asr import ASRTaskParaformer as ASRTask
-    elif mode == "paraformer_streaming":
-        from funasr.tasks.asr import ASRTaskParaformer as ASRTask
-    elif mode == "paraformer_vad_punc":
-        from funasr.tasks.asr import ASRTaskParaformer as ASRTask
-    elif mode == "uniasr":
-        from funasr.tasks.asr import ASRTaskUniASR as ASRTask
-    elif mode == "mfcca":
-        from funasr.tasks.asr import ASRTaskMFCCA as ASRTask
-    elif mode == "tp":
-        from funasr.tasks.asr import ASRTaskAligner as ASRTask
-    else:
-        raise ValueError("Unknown mode: {}".format(mode))
-    parser = ASRTask.get_parser()
-    args = parser.parse_args()
-    return args, ASRTask
-
-
-def build_trainer(modelscope_dict,
-                  data_dir,
-                  output_dir,
-                  train_set="train",
-                  dev_set="validation",
-                  distributed=False,
-                  dataset_type="small",
-                  batch_bins=None,
-                  max_epoch=None,
-                  optim=None,
-                  lr=None,
-                  scheduler=None,
-                  scheduler_conf=None,
-                  specaug=None,
-                  specaug_conf=None,
-                  mate_params=None,
-                  **kwargs):
-    mode = modelscope_dict['mode']
-    args, ASRTask = parse_args(mode=mode)
-    # ddp related
-    if args.local_rank is not None:
-        distributed = True
-    else:
-        distributed = False
-    args.local_rank = args.local_rank if args.local_rank is not None else 0
-    local_rank = args.local_rank
-    if "CUDA_VISIBLE_DEVICES" in os.environ.keys():
-        gpu_list = os.environ['CUDA_VISIBLE_DEVICES'].split(",")
-        os.environ['CUDA_VISIBLE_DEVICES'] = str(gpu_list[args.local_rank])
-    else:
-        os.environ['CUDA_VISIBLE_DEVICES'] = str(args.local_rank)
-
-    config = modelscope_dict['am_model_config']
-    finetune_config = modelscope_dict['finetune_config']
-    init_param = modelscope_dict['init_model']
-    cmvn_file = modelscope_dict['cmvn_file']
-    seg_dict_file = modelscope_dict['seg_dict']
-
-    # overwrite parameters
-    with open(config) as f:
-        configs = yaml.safe_load(f)
-    with open(finetune_config) as f:
-        finetune_configs = yaml.safe_load(f)
-    # set data_types
-    if dataset_type == "large":
-        # finetune_configs["dataset_conf"]["data_types"] = "sound,text"
-        if 'data_types' not in finetune_configs['dataset_conf']:
-            finetune_configs["dataset_conf"]["data_types"] = "sound,text"
-    finetune_configs = update_dct(configs, finetune_configs)
-    for key, value in finetune_configs.items():
-        if hasattr(args, key):
-            setattr(args, key, value)
-    if mate_params is not None:
-        for key, value in mate_params.items():
-            if hasattr(args, key):
-                setattr(args, key, value)
-    if mate_params is not None and "lora_params" in mate_params:
-        lora_params = mate_params['lora_params']
-        configs['encoder_conf'].update(lora_params)
-        configs['decoder_conf'].update(lora_params)
-
-    # prepare data
-    args.dataset_type = dataset_type
-    if args.dataset_type == "small":
-        args.train_data_path_and_name_and_type = [["{}/{}/wav.scp".format(data_dir, train_set), "speech", "sound"],
-                                                  ["{}/{}/text".format(data_dir, train_set), "text", "text"]]
-        args.valid_data_path_and_name_and_type = [["{}/{}/wav.scp".format(data_dir, dev_set), "speech", "sound"],
-                                                  ["{}/{}/text".format(data_dir, dev_set), "text", "text"]]
-    elif args.dataset_type == "large":
-        args.train_data_file = None
-        args.valid_data_file = None
-    else:
-        raise ValueError(f"Not supported dataset_type={args.dataset_type}")
-    args.init_param = [init_param]
-    if mate_params is not None and "init_param" in mate_params:
-        if len(mate_params["init_param"]) != 0:
-            args.init_param = mate_params["init_param"]
-    args.cmvn_file = cmvn_file
-    if os.path.exists(seg_dict_file):
-        args.seg_dict_file = seg_dict_file
-    else:
-        args.seg_dict_file = None
-    args.data_dir = data_dir
-    args.train_set = train_set
-    args.dev_set = dev_set
-    args.output_dir = output_dir
-    args.gpu_id = args.local_rank
-    args.config = finetune_config
-    if optim is not None:
-        args.optim = optim
-    if lr is not None:
-        args.optim_conf["lr"] = lr
-    if scheduler is not None:
-        args.scheduler = scheduler
-    if scheduler_conf is not None:
-        args.scheduler_conf = scheduler_conf
-    if specaug is not None:
-        args.specaug = specaug
-    if specaug_conf is not None:
-        args.specaug_conf = specaug_conf
-    if max_epoch is not None:
-        args.max_epoch = max_epoch
-    if batch_bins is not None:
-        if args.dataset_type == "small":
-            args.batch_bins = batch_bins
-        elif args.dataset_type == "large":
-            args.dataset_conf["batch_conf"]["batch_size"] = batch_bins
-        else:
-            raise ValueError(f"Not supported dataset_type={args.dataset_type}")
-    if args.normalize in ["null", "none", "None"]:
-        args.normalize = None
-    if args.patience in ["null", "none", "None"]:
-        args.patience = None
-    args.local_rank = local_rank
-    args.distributed = distributed
-    ASRTask.finetune_args = args
-
-    return ASRTask