This commit is contained in:
嘉渊 2023-04-25 15:10:24 +08:00
parent 0237e32625
commit 95d6db2656

View File

@ -83,7 +83,7 @@ class SequenceIterFactory(AbsIterFactory):
args.max_update = len(bs_list) * args.max_epoch
logging.info("Max update: {}".format(args.max_update))
if args.distributed:
if args.distributed and mode=="train":
world_size = torch.distributed.get_world_size()
rank = torch.distributed.get_rank()
for batch in batches: