build_distributed.py

import logging
import os

import torch

from funasr.train.distributed_utils import DistributedOption
from funasr.utils.build_dataclass import build_dataclass


def build_distributed(args):
    # Collect the distributed-training fields from args into a DistributedOption.
    distributed_option = build_dataclass(DistributedOption, args)
    if args.use_pai:
        # PAI launcher: initialize options and torch.distributed the PAI way.
        distributed_option.init_options_pai()
        distributed_option.init_torch_distributed_pai(args)
    elif not args.simple_ddp:
        # Default path: standard torch.distributed initialization.
        distributed_option.init_torch_distributed(args)
    elif args.distributed and args.simple_ddp:
        # Simple DDP: reuse the PAI-style initialization and take ngpu from the world size.
        distributed_option.init_torch_distributed_pai(args)
        args.ngpu = torch.distributed.get_world_size()

    # Reset logging: rank 0 (or a non-distributed run) logs at INFO, all other ranks at ERROR.
    for handler in logging.root.handlers[:]:
        logging.root.removeHandler(handler)
    if not distributed_option.distributed or distributed_option.dist_rank == 0:
        logging.basicConfig(
            level="INFO",
            format=f"[{os.uname()[1].split('.')[0]}]"
            f" %(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
        )
    else:
        logging.basicConfig(
            level="ERROR",
            format=f"[{os.uname()[1].split('.')[0]}]"
            f" %(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
        )
    logging.info(
        "world size: {}, rank: {}, local_rank: {}".format(
            distributed_option.dist_world_size,
            distributed_option.dist_rank,
            distributed_option.local_rank,
        )
    )
    return distributed_option
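

# ---------------------------------------------------------------------------
# Illustrative sketch, not part of the original module: downstream training
# code typically consumes the returned DistributedOption by pinning the model
# to the local rank's GPU and wrapping it in DistributedDataParallel. The
# helper below is an assumption added for illustration only; the fields it
# reads (distributed, local_rank) are the ones logged by build_distributed.
# ---------------------------------------------------------------------------
def wrap_model_for_ddp(model: torch.nn.Module, distributed_option):
    """Hypothetical helper: wrap `model` in DDP when distributed training is enabled."""
    if distributed_option.distributed:
        torch.cuda.set_device(distributed_option.local_rank)
        model = model.cuda(distributed_option.local_rank)
        model = torch.nn.parallel.DistributedDataParallel(
            model, device_ids=[distributed_option.local_rank]
        )
    return model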