diff --git a/funasr/bin/train.py b/funasr/bin/train.py index c54e851ac..474e85753 100755 --- a/funasr/bin/train.py +++ b/funasr/bin/train.py @@ -389,8 +389,8 @@ if __name__ == '__main__': parser = get_parser() args, extra_task_params = parser.parse_known_args() if extra_task_params: - task_args = build_args(args, extra_task_params) - args = argparse.Namespace(**vars(args), **vars(task_args)) + args = build_args(args, parser, extra_task_params) + # args = argparse.Namespace(**vars(args), **vars(task_args)) # set random seed set_all_random_seed(args.seed) diff --git a/funasr/build_utils/build_args.py b/funasr/build_utils/build_args.py index a50122da3..a0baaa72d 100644 --- a/funasr/build_utils/build_args.py +++ b/funasr/build_utils/build_args.py @@ -7,37 +7,37 @@ from funasr.utils.types import str2bool from funasr.utils.types import str_or_none -def build_args(args, extra_task_params): - parser = config_argparse.ArgumentParser("Task related config") +def build_args(args, parser, extra_task_params): + task_parser = config_argparse.ArgumentParser("Task related config") if args.task_name == "asr": from funasr.build_utils.build_asr_model import class_choices_list for class_choices in class_choices_list: - class_choices.add_arguments(parser) - parser.add_argument( + class_choices.add_arguments(task_parser) + task_parser.add_argument( "--split_with_space", type=str2bool, default=True, help="whether to split text using ", ) - parser.add_argument( + task_parser.add_argument( "--seg_dict_file", type=str, default=None, help="seg_dict_file for text processing", ) - parser.add_argument( + task_parser.add_argument( "--input_size", type=int_or_none, default=None, help="The number of input dimension of the feature", ) - parser.add_argument( + task_parser.add_argument( "--ctc_conf", action=NestedDictAction, default=get_default_kwargs(CTC), help="The keyword arguments for CTC class.", ) - parser.add_argument( + task_parser.add_argument( "--cmvn_file", type=str_or_none, default=None, @@ -47,8 +47,8 @@ def build_args(args, extra_task_params): elif args.task_name == "pretrain": from funasr.build_utils.build_pretrain_model import class_choices_list for class_choices in class_choices_list: - class_choices.add_arguments(parser) - parser.add_argument( + class_choices.add_arguments(task_parser) + task_parser.add_argument( "--input_size", type=int_or_none, default=None, @@ -58,18 +58,18 @@ def build_args(args, extra_task_params): elif args.task_name == "lm": from funasr.build_utils.build_lm_model import class_choices_list for class_choices in class_choices_list: - class_choices.add_arguments(parser) + class_choices.add_arguments(task_parser) elif args.task_name == "punc": from funasr.build_utils.build_punc_model import class_choices_list for class_choices in class_choices_list: - class_choices.add_arguments(parser) + class_choices.add_arguments(task_parser) elif args.task_name == "vad": from funasr.build_utils.build_vad_model import class_choices_list for class_choices in class_choices_list: - class_choices.add_arguments(parser) - parser.add_argument( + class_choices.add_arguments(task_parser) + task_parser.add_argument( "--input_size", type=int_or_none, default=None, @@ -79,10 +79,13 @@ def build_args(args, extra_task_params): elif args.task_name == "diar": from funasr.build_utils.build_diar_model import class_choices_list for class_choices in class_choices_list: - class_choices.add_arguments(parser) + class_choices.add_arguments(task_parser) else: raise NotImplementedError("Not supported task: {}".format(args.task_name)) + for action in parser._actions: + task_parser._add_action(action) + task_args = parser.parse_args(extra_task_params) return task_args