FunASR/funasr/build_utils/build_distributed.py
import logging
import os
import torch
from funasr.train.distributed_utils import DistributedOption
from funasr.utils.build_dataclass import build_dataclass


def build_distributed(args):
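    """Build a DistributedOption from the parsed training args.

    Copies the DistributedOption fields from ``args`` via build_dataclass,
    initializes torch.distributed through the PAI, standard, or simple-DDP
    path, and configures per-rank logging (INFO on rank 0, ERROR elsewhere).
    """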
    distributed_option = build_dataclass(DistributedOption, args)
    # Initialize the distributed backend: PAI cluster, regular torch.distributed,
    # or the simplified DDP path that reuses the PAI initializer.
    if args.use_pai:
        distributed_option.init_options_pai()
        distributed_option.init_torch_distributed_pai(args)
    elif not args.simple_ddp:
        distributed_option.init_torch_distributed(args)
    elif args.distributed and args.simple_ddp:
        distributed_option.init_torch_distributed_pai(args)
        args.ngpu = torch.distributed.get_world_size()
    # Reset the root logger, then configure it so that rank 0 (or a
    # non-distributed run) logs at INFO while all other ranks only log errors.
    for handler in logging.root.handlers[:]:
        logging.root.removeHandler(handler)
    if not distributed_option.distributed or distributed_option.dist_rank == 0:
        logging.basicConfig(
            level="INFO",
            format=f"[{os.uname()[1].split('.')[0]}]"
                   f" %(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
        )
    else:
        logging.basicConfig(
            level="ERROR",
            format=f"[{os.uname()[1].split('.')[0]}]"
                   f" %(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
        )
    logging.info(
        "world size: {}, rank: {}, local_rank: {}".format(
            distributed_option.dist_world_size,
            distributed_option.dist_rank,
            distributed_option.local_rank,
        )
    )
    return distributed_option
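

# --- Usage sketch (illustrative, not part of the original module) ---
# build_distributed expects an argparse.Namespace that carries `use_pai`,
# `simple_ddp`, `distributed`, and every field declared on DistributedOption
# (build_dataclass copies each dataclass field from `args`); the exact field
# list lives in funasr.train.distributed_utils and is not enumerated here.
#
#     args = parser.parse_args()  # FunASR training argument parser (assumed)
#     distributed_option = build_distributed(args)
#     if distributed_option.distributed:
#         model = torch.nn.parallel.DistributedDataParallel(
#             model, device_ids=[distributed_option.local_rank]
#         )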