From d475a13df0d9339866ab3739e2c1bb30079c368b Mon Sep 17 00:00:00 2001 From: speech_asr Date: Thu, 2 Mar 2023 14:58:17 +0800 Subject: [PATCH] update dist_backend option --- funasr/train/distributed_utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/funasr/train/distributed_utils.py b/funasr/train/distributed_utils.py index c89793019..13f57447c 100644 --- a/funasr/train/distributed_utils.py +++ b/funasr/train/distributed_utils.py @@ -53,7 +53,7 @@ class DistributedOption: # https://pytorch.org/docs/stable/distributed.html#torch.distributed.init_process_group os.environ.setdefault("NCCL_BLOCKING_WAIT", "1") - torch.distributed.init_process_group(backend='nccl', + torch.distributed.init_process_group(backend=self.dist_backend, init_method=self.dist_init_method, world_size=args.dist_world_size, rank=args.dist_rank) @@ -113,7 +113,7 @@ class DistributedOption: # https://pytorch.org/docs/stable/distributed.html#torch.distributed.init_process_group os.environ.setdefault("NCCL_BLOCKING_WAIT", "1") - torch.distributed.init_process_group(backend='nccl', init_method='env://') + torch.distributed.init_process_group(backend=self.dist_backend, init_method='env://') self.dist_rank = torch.distributed.get_rank() self.dist_world_size = torch.distributed.get_world_size() self.local_rank = args.local_rank