From cd5db9f2ddd8e9c1098bab3ea2f574cad3d10c3f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=98=89=E6=B8=8A?= Date: Mon, 24 Apr 2023 23:06:16 +0800 Subject: [PATCH] update --- funasr/bin/train.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/funasr/bin/train.py b/funasr/bin/train.py index 8f40e2417..64391d9f5 100755 --- a/funasr/bin/train.py +++ b/funasr/bin/train.py @@ -59,6 +59,7 @@ def get_parser(): ) parser.add_argument( "--dist_world_size", + type=int, default=1, help="number of nodes for distributed training", ) @@ -69,6 +70,7 @@ def get_parser(): ) parser.add_argument( "--local_rank", + type=int, default=None, help="local rank for distributed training", ) @@ -462,9 +464,9 @@ def get_parser(): if __name__ == '__main__': parser = get_parser() - common_args, extra_task_params = parser.parse_known_args() + args, extra_task_params = parser.parse_known_args() if extra_task_params: - args = build_args(common_args, parser, extra_task_params) + args = build_args(args, parser, extra_task_params) # set random seed set_all_random_seed(args.seed)