diff --git a/funasr/train/distributed_utils.py b/funasr/train/distributed_utils.py index c89793019..13f57447c 100644 --- a/funasr/train/distributed_utils.py +++ b/funasr/train/distributed_utils.py @@ -53,7 +53,7 @@ class DistributedOption: # https://pytorch.org/docs/stable/distributed.html#torch.distributed.init_process_group os.environ.setdefault("NCCL_BLOCKING_WAIT", "1") - torch.distributed.init_process_group(backend='nccl', + torch.distributed.init_process_group(backend=self.dist_backend, init_method=self.dist_init_method, world_size=args.dist_world_size, rank=args.dist_rank) @@ -113,7 +113,7 @@ class DistributedOption: # https://pytorch.org/docs/stable/distributed.html#torch.distributed.init_process_group os.environ.setdefault("NCCL_BLOCKING_WAIT", "1") - torch.distributed.init_process_group(backend='nccl', init_method='env://') + torch.distributed.init_process_group(backend=self.dist_backend, init_method='env://') self.dist_rank = torch.distributed.get_rank() self.dist_world_size = torch.distributed.get_world_size() self.local_rank = args.local_rank