diff --git a/funasr/bin/train.py b/funasr/bin/train.py index 15180710b..f056b08a7 100755 --- a/funasr/bin/train.py +++ b/funasr/bin/train.py @@ -388,9 +388,10 @@ def get_parser(): if __name__ == '__main__': parser = get_parser() - args = parser.parse_args() - task_args = build_args(args) - args = argparse.Namespace(**vars(args), **vars(task_args)) + args, extra_task_params = parser.parse_known_args() + if extra_task_params: + task_args = build_args(args, extra_task_params) + args = argparse.Namespace(**vars(args), **vars(task_args)) # set random seed set_all_random_seed(args.seed) diff --git a/funasr/build_utils/build_args.py b/funasr/build_utils/build_args.py index b77cbcc4c..1d0728bc1 100644 --- a/funasr/build_utils/build_args.py +++ b/funasr/build_utils/build_args.py @@ -8,7 +8,7 @@ from funasr.utils.types import str2bool from funasr.utils.types import str_or_none -def build_args(args): +def build_args(args, extra_task_params): parser = argparse.ArgumentParser("Task related config") if args.task_name == "asr": from funasr.build_utils.build_asr_model import class_choices_list @@ -85,5 +85,5 @@ def build_args(args): else: raise NotImplementedError("Not supported task: {}".format(args.task_name)) - task_args = parser.parse_args() + task_args = parser.parse_args(extra_task_params) return task_args