This commit is contained in:
嘉渊 2023-04-24 23:06:16 +08:00
parent e86b95e747
commit cd5db9f2dd

View File

@ -59,6 +59,7 @@ def get_parser():
) )
parser.add_argument( parser.add_argument(
"--dist_world_size", "--dist_world_size",
type=int,
default=1, default=1,
help="number of nodes for distributed training", help="number of nodes for distributed training",
) )
@ -69,6 +70,7 @@ def get_parser():
) )
parser.add_argument( parser.add_argument(
"--local_rank", "--local_rank",
type=int,
default=None, default=None,
help="local rank for distributed training", help="local rank for distributed training",
) )
@ -462,9 +464,9 @@ def get_parser():
if __name__ == '__main__': if __name__ == '__main__':
parser = get_parser() parser = get_parser()
common_args, extra_task_params = parser.parse_known_args() args, extra_task_params = parser.parse_known_args()
if extra_task_params: if extra_task_params:
args = build_args(common_args, parser, extra_task_params) args = build_args(args, parser, extra_task_params)
# set random seed # set random seed
set_all_random_seed(args.seed) set_all_random_seed(args.seed)