mirror of
https://github.com/modelscope/FunASR
synced 2025-09-15 14:48:36 +08:00
39 lines
1.5 KiB
Python
39 lines
1.5 KiB
Python
import logging
|
|
import os
|
|
|
|
import torch
|
|
|
|
from funasr.train.distributed_utils import DistributedOption
|
|
from funasr.utils.build_dataclass import build_dataclass
|
|
|
|
|
|
def build_distributed(args):
|
|
distributed_option = build_dataclass(DistributedOption, args)
|
|
if args.use_pai:
|
|
distributed_option.init_options_pai()
|
|
distributed_option.init_torch_distributed_pai(args)
|
|
elif not args.simple_ddp:
|
|
distributed_option.init_torch_distributed(args)
|
|
elif args.distributed and args.simple_ddp:
|
|
distributed_option.init_torch_distributed_pai(args)
|
|
args.ngpu = torch.distributed.get_world_size()
|
|
|
|
for handler in logging.root.handlers[:]:
|
|
logging.root.removeHandler(handler)
|
|
if not distributed_option.distributed or distributed_option.dist_rank == 0:
|
|
logging.basicConfig(
|
|
level="INFO",
|
|
format=f"[{os.uname()[1].split('.')[0]}]"
|
|
f" %(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
|
|
)
|
|
else:
|
|
logging.basicConfig(
|
|
level="ERROR",
|
|
format=f"[{os.uname()[1].split('.')[0]}]"
|
|
f" %(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
|
|
)
|
|
logging.info("world size: {}, rank: {}, local_rank: {}".format(distributed_option.dist_world_size,
|
|
distributed_option.dist_rank,
|
|
distributed_option.local_rank))
|
|
return distributed_option
|