This commit is contained in:
游雁 2024-06-13 09:56:44 +08:00
parent f97f3e8dd5
commit dce53b268b
2 changed files with 3 additions and 2 deletions

View File

@ -66,6 +66,7 @@ def main(**kwargs):
# open tf32
torch.backends.cuda.matmul.allow_tf32 = kwargs.get("enable_tf32", True)
rank = int(os.environ.get("RANK", 0))
local_rank = int(os.environ.get("LOCAL_RANK", 0))
world_size = int(os.environ.get("WORLD_SIZE", 1))
@ -83,7 +84,7 @@ def main(**kwargs):
dist.init_process_group(backend=kwargs.get("backend", "nccl"), init_method="env://")
torch.cuda.set_device(local_rank)
rank = dist.get_rank()
# rank = dist.get_rank()
logging.info("Build model, frontend, tokenizer")
device = kwargs.get("device", "cuda")

View File

@ -78,7 +78,7 @@ class Trainer:
output_dir (str): The directory where model checkpoints will be saved. Default is './'.
resume (str, optional): The file path to a checkpoint to resume training from.
"""
self.rank = kwargs.get("rank", 0)
self.rank = rank
self.local_rank = local_rank
self.world_size = world_size
self.use_ddp = use_ddp