```python
import logging
import os

import torch

from funasr.train.distributed_utils import DistributedOption
from funasr.utils.build_dataclass import build_dataclass


def build_distributed(args):
    # Collect the distributed-related fields of args into a DistributedOption dataclass.
    distributed_option = build_dataclass(DistributedOption, args)
    if args.use_pai:
        # PAI cluster path: use the PAI-specific initialization helpers.
        distributed_option.init_options_pai()
        distributed_option.init_torch_distributed_pai(args)
    elif not args.simple_ddp:
        # Default path: standard torch.distributed initialization.
        distributed_option.init_torch_distributed(args)
    elif args.distributed and args.simple_ddp:
        # Simple DDP: reuse the PAI-style init and derive the GPU count
        # from the torch.distributed world size.
        distributed_option.init_torch_distributed_pai(args)
        args.ngpu = torch.distributed.get_world_size()

    # Drop any handlers installed by earlier basicConfig calls so the
    # per-rank configuration below takes effect.
    for handler in logging.root.handlers[:]:
        logging.root.removeHandler(handler)

    if not distributed_option.distributed or distributed_option.dist_rank == 0:
        # Single-process run or rank 0: log at INFO level.
        logging.basicConfig(
            level="INFO",
            format=f"[{os.uname()[1].split('.')[0]}]"
            f" %(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
        )
    else:
        # Non-zero ranks: only log errors to avoid duplicated output.
        logging.basicConfig(
            level="ERROR",
            format=f"[{os.uname()[1].split('.')[0]}]"
            f" %(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
        )
    logging.info(
        "world size: {}, rank: {}, local_rank: {}".format(
            distributed_option.dist_world_size,
            distributed_option.dist_rank,
            distributed_option.local_rank,
        )
    )
    return distributed_option
```
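The core pattern here, initialize `torch.distributed`, then configure logging so that only rank 0 emits INFO messages while other ranks stay quiet, can be reproduced outside of FunASR. Below is a minimal, self-contained sketch of the same idea assuming a `torchrun` launch; the function and variable names are hypothetical and not part of the FunASR API.

```python
# Hypothetical sketch (not FunASR API): initialize torch.distributed from the
# environment variables set by torchrun and keep INFO logging on rank 0 only.
import logging
import os

import torch
import torch.distributed as dist


def setup_distributed_logging():
    # torchrun exports RANK, WORLD_SIZE and LOCAL_RANK for every worker.
    distributed = int(os.environ.get("WORLD_SIZE", "1")) > 1
    rank = 0
    if distributed:
        # init_method="env://" reads the rendezvous info from the environment.
        # Use "gloo" instead of "nccl" when no GPU is available.
        dist.init_process_group(backend="nccl", init_method="env://")
        rank = dist.get_rank()
        torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))

    # Rank 0 logs at INFO; all other ranks only log errors.
    level = logging.INFO if rank == 0 else logging.ERROR
    logging.basicConfig(
        level=level,
        format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
    )
    return distributed, rank


if __name__ == "__main__":
    # Launch with: torchrun --nproc_per_node=2 this_script.py
    distributed, rank = setup_distributed_logging()
    logging.info("world size: %s, rank: %s",
                 dist.get_world_size() if distributed else 1, rank)
```

Launched with `torchrun --nproc_per_node=2 this_script.py`, each process picks up `RANK`, `WORLD_SIZE`, and `LOCAL_RANK` from the environment, and only rank 0 prints the INFO summary line, which is the same behavior `build_distributed` arranges for the FunASR trainer.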