This commit is contained in:
嘉渊 2023-04-24 16:41:43 +08:00
parent 15fc867ba5
commit d17d48935a
2 changed files with 6 additions and 5 deletions

View File

@ -388,9 +388,10 @@ def get_parser():
if __name__ == '__main__': if __name__ == '__main__':
parser = get_parser() parser = get_parser()
args = parser.parse_args() args, extra_task_params = parser.parse_known_args()
task_args = build_args(args) if extra_task_params:
args = argparse.Namespace(**vars(args), **vars(task_args)) task_args = build_args(args, extra_task_params)
args = argparse.Namespace(**vars(args), **vars(task_args))
# set random seed # set random seed
set_all_random_seed(args.seed) set_all_random_seed(args.seed)

View File

@ -8,7 +8,7 @@ from funasr.utils.types import str2bool
from funasr.utils.types import str_or_none from funasr.utils.types import str_or_none
def build_args(args): def build_args(args, extra_task_params):
parser = argparse.ArgumentParser("Task related config") parser = argparse.ArgumentParser("Task related config")
if args.task_name == "asr": if args.task_name == "asr":
from funasr.build_utils.build_asr_model import class_choices_list from funasr.build_utils.build_asr_model import class_choices_list
@ -85,5 +85,5 @@ def build_args(args):
else: else:
raise NotImplementedError("Not supported task: {}".format(args.task_name)) raise NotImplementedError("Not supported task: {}".format(args.task_name))
task_args = parser.parse_args() task_args = parser.parse_args(extra_task_params)
return task_args return task_args