# distributed_utils.py
# Copyright ESPnet (https://github.com/espnet/espnet). All Rights Reserved.
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
import dataclasses
import logging
import os
import socket
from typing import Optional

import torch
import torch.distributed
  10. @dataclasses.dataclass
  11. class DistributedOption:
  12. # Enable distributed Training
  13. distributed: bool = False
  14. # torch.distributed.Backend: "nccl", "mpi", "gloo", or "tcp"
  15. dist_backend: str = "nccl"
  16. # if init_method="env://",
  17. # env values of "MASTER_PORT", "MASTER_ADDR", "WORLD_SIZE", and "RANK" are referred.
  18. dist_init_method: str = "env://"
  19. dist_world_size: Optional[int] = None
  20. dist_rank: Optional[int] = None
  21. local_rank: Optional[int] = None
  22. ngpu: int = 0
  23. dist_master_addr: Optional[str] = None
  24. dist_master_port: Optional[int] = None
  25. dist_launcher: Optional[str] = None
  26. multiprocessing_distributed: bool = True
  27. def init_options(self):
  28. if self.distributed:
  29. if self.dist_init_method == "env://":
  30. if get_master_addr(self.dist_master_addr, self.dist_launcher) is None:
  31. raise RuntimeError(
  32. "--dist_master_addr or MASTER_ADDR must be set "
  33. "if --dist_init_method == 'env://'"
  34. )
  35. if get_master_port(self.dist_master_port) is None:
  36. raise RuntimeError(
  37. "--dist_master_port or MASTER_PORT must be set "
  38. "if --dist_init_port == 'env://'"
  39. )
  40. def init_torch_distributed(self, args):
  41. if self.distributed:
  42. # See:
  43. # https://docs.nvidia.com/deeplearning/sdk/nccl-developer-guide/docs/env.html
  44. os.environ.setdefault("NCCL_DEBUG", "INFO")
  45. # See:
  46. # https://pytorch.org/docs/stable/distributed.html#torch.distributed.init_process_group
  47. os.environ.setdefault("NCCL_BLOCKING_WAIT", "1")
  48. torch.distributed.init_process_group(backend='nccl',
  49. init_method=self.dist_init_method,
  50. world_size=args.dist_world_size,
  51. rank=args.dist_rank)
  52. self.dist_rank = torch.distributed.get_rank()
  53. self.dist_world_size = torch.distributed.get_world_size()
  54. self.local_rank = args.local_rank
  55. logging.info("world size: {}, rank: {}, local_rank: {}".format(self.dist_world_size, self.dist_rank,
  56. self.local_rank))
  57. def init_options_pai(self):
  58. if self.distributed:
  59. if self.dist_init_method == "env://":
  60. if get_master_addr(self.dist_master_addr, self.dist_launcher) is None:
  61. raise RuntimeError(
  62. "--dist_master_addr or MASTER_ADDR must be set "
  63. "if --dist_init_method == 'env://'"
  64. )
  65. if get_master_port(self.dist_master_port) is None:
  66. raise RuntimeError(
  67. "--dist_master_port or MASTER_PORT must be set "
  68. "if --dist_init_port == 'env://'"
  69. )
  70. self.dist_rank = get_rank(self.dist_rank, self.dist_launcher)
  71. self.dist_world_size = get_world_size(
  72. self.dist_world_size, self.dist_launcher
  73. )
  74. self.local_rank = get_local_rank(self.local_rank, self.dist_launcher)
  75. if (
  76. self.dist_rank is not None
  77. and self.dist_world_size is not None
  78. and self.dist_rank >= self.dist_world_size
  79. ):
  80. raise RuntimeError(
  81. f"RANK >= WORLD_SIZE: {self.dist_rank} >= {self.dist_world_size}"
  82. )
  83. if self.dist_init_method == "env://":
  84. self.dist_master_addr = get_master_addr(
  85. self.dist_master_addr, self.dist_launcher
  86. )
  87. self.dist_master_port = get_master_port(self.dist_master_port)
  88. if (
  89. self.dist_master_addr is not None
  90. and self.dist_master_port is not None
  91. ):
  92. self.dist_init_method = (
  93. f"tcp://{self.dist_master_addr}:{self.dist_master_port}"
  94. )
  95. def init_torch_distributed_pai(self, args):
  96. if self.distributed:
  97. # See:
  98. # https://docs.nvidia.com/deeplearning/sdk/nccl-developer-guide/docs/env.html
  99. os.environ.setdefault("NCCL_DEBUG", "INFO")
  100. # See:
  101. # https://pytorch.org/docs/stable/distributed.html#torch.distributed.init_process_group
  102. os.environ.setdefault("NCCL_BLOCKING_WAIT", "1")
  103. torch.distributed.init_process_group(backend='nccl', init_method='env://')
  104. self.dist_rank = torch.distributed.get_rank()
  105. self.dist_world_size = torch.distributed.get_world_size()
  106. self.local_rank = args.local_rank
  107. logging.info("world size: {}, rank: {}, local_rank: {}".format(self.dist_world_size, self.dist_rank,
  108. self.local_rank))
  109. def resolve_distributed_mode(args):
  110. # Note that args.distributed is set by only this function.
  111. # and ArgumentParser doesn't have such option
  112. if args.multiprocessing_distributed:
  113. num_nodes = get_num_nodes(args.dist_world_size, args.dist_launcher)
  114. # a. multi-node
  115. if num_nodes > 1:
  116. args.distributed = True
  117. # b. single-node and multi-gpu with multiprocessing_distributed mode
  118. elif args.ngpu > 1:
  119. args.distributed = True
  120. # c. single-node and single-gpu
  121. else:
  122. args.distributed = False
  123. if args.ngpu <= 1:
  124. # Disable multiprocessing_distributed mode if 1process per node or cpu mode
  125. args.multiprocessing_distributed = False
  126. if args.ngpu == 1:
  127. # If the number of GPUs equals to 1 with multiprocessing_distributed mode,
  128. # LOCAL_RANK is always 0
  129. args.local_rank = 0
  130. if num_nodes > 1 and get_node_rank(args.dist_rank, args.dist_launcher) is None:
  131. raise RuntimeError(
  132. "--dist_rank or RANK must be set "
  133. "if --multiprocessing_distributed == true"
  134. )
  135. # Note that RANK, LOCAL_RANK, and WORLD_SIZE is automatically set,
  136. # so we don't need to check here
  137. else:
  138. # d. multiprocess and multi-gpu with external launcher
  139. # e.g. torch.distributed.launch
  140. if get_world_size(args.dist_world_size, args.dist_launcher) > 1:
  141. args.distributed = True
  142. # e. single-process
  143. else:
  144. args.distributed = False
  145. if args.distributed and args.ngpu > 0:
  146. if get_local_rank(args.local_rank, args.dist_launcher) is None:
  147. raise RuntimeError(
  148. "--local_rank or LOCAL_RANK must be set "
  149. "if --multiprocessing_distributed == false"
  150. )
  151. if args.distributed:
  152. if get_node_rank(args.dist_rank, args.dist_launcher) is None:
  153. raise RuntimeError(
  154. "--dist_rank or RANK must be set "
  155. "if --multiprocessing_distributed == false"
  156. )
  157. if args.distributed and args.dist_launcher == "slurm" and not is_in_slurm_step():
  158. raise RuntimeError("Launch by 'srun' command if --dist_launcher='slurm'")
  159. def is_in_slurm_job() -> bool:
  160. return "SLURM_PROCID" in os.environ and "SLURM_NTASKS" in os.environ
  161. def is_in_slurm_step() -> bool:
  162. return (
  163. is_in_slurm_job()
  164. and "SLURM_STEP_NUM_NODES" in os.environ
  165. and "SLURM_STEP_NODELIST" in os.environ
  166. )
  167. def _int_or_none(x: Optional[str]) -> Optional[int]:
  168. if x is None:
  169. return x
  170. return int(x)
  171. def free_port():
  172. """Find free port using bind().
  173. There are some interval between finding this port and using it
  174. and the other process might catch the port by that time.
  175. Thus it is not guaranteed that the port is really empty.
  176. """
  177. with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
  178. sock.bind(("", 0))
  179. return sock.getsockname()[1]
  180. def get_rank(prior=None, launcher: str = None) -> Optional[int]:
  181. if prior is None:
  182. if launcher == "slurm":
  183. if not is_in_slurm_step():
  184. raise RuntimeError("This process seems not to be launched by 'srun'")
  185. prior = os.environ["SLURM_PROCID"]
  186. elif launcher == "mpi":
  187. raise RuntimeError(
  188. "launcher=mpi is used for 'multiprocessing-distributed' mode"
  189. )
  190. elif launcher is not None:
  191. raise RuntimeError(f"launcher='{launcher}' is not supported")
  192. if prior is not None:
  193. return int(prior)
  194. else:
  195. # prior is None and RANK is None -> RANK = None
  196. return _int_or_none(os.environ.get("RANK"))
  197. def get_world_size(prior=None, launcher: str = None) -> int:
  198. if prior is None:
  199. if launcher == "slurm":
  200. if not is_in_slurm_step():
  201. raise RuntimeError("This process seems not to be launched by 'srun'")
  202. prior = int(os.environ["SLURM_NTASKS"])
  203. elif launcher == "mpi":
  204. raise RuntimeError(
  205. "launcher=mpi is used for 'multiprocessing-distributed' mode"
  206. )
  207. elif launcher is not None:
  208. raise RuntimeError(f"launcher='{launcher}' is not supported")
  209. if prior is not None:
  210. return int(prior)
  211. else:
  212. # prior is None and WORLD_SIZE is None -> WORLD_SIZE = 1
  213. return int(os.environ.get("WORLD_SIZE", "1"))
  214. def get_local_rank(prior=None, launcher: str = None) -> Optional[int]:
  215. # LOCAL_RANK is same as GPU device id
  216. if prior is None:
  217. if launcher == "slurm":
  218. if not is_in_slurm_step():
  219. raise RuntimeError("This process seems not to be launched by 'srun'")
  220. prior = int(os.environ["SLURM_LOCALID"])
  221. elif launcher == "mpi":
  222. raise RuntimeError(
  223. "launcher=mpi is used for 'multiprocessing-distributed' mode"
  224. )
  225. elif launcher is not None:
  226. raise RuntimeError(f"launcher='{launcher}' is not supported")
  227. if prior is not None:
  228. return int(prior)
  229. elif "LOCAL_RANK" in os.environ:
  230. return int(os.environ["LOCAL_RANK"])
  231. elif "CUDA_VISIBLE_DEVICES" in os.environ:
  232. # There are two possibility:
  233. # - "CUDA_VISIBLE_DEVICES" is set to multiple GPU ids. e.g. "0.1,2"
  234. # => This intends to specify multiple devices to to be used exactly
  235. # and local_rank information is possibly insufficient.
  236. # - "CUDA_VISIBLE_DEVICES" is set to an id. e.g. "1"
  237. # => This could be used for LOCAL_RANK
  238. cvd = os.environ["CUDA_VISIBLE_DEVICES"].split(",")
  239. if len(cvd) == 1 and "LOCAL_RANK" not in os.environ:
  240. # If CUDA_VISIBLE_DEVICES is set and LOCAL_RANK is not set,
  241. # then use it as LOCAL_RANK.
  242. # Unset CUDA_VISIBLE_DEVICES
  243. # because the other device must be visible to communicate
  244. return int(os.environ.pop("CUDA_VISIBLE_DEVICES"))
  245. else:
  246. return None
  247. else:
  248. return None
  249. def get_master_addr(prior=None, launcher: str = None) -> Optional[str]:
  250. if prior is None:
  251. if launcher == "slurm":
  252. if not is_in_slurm_step():
  253. raise RuntimeError("This process seems not to be launched by 'srun'")
  254. # e.g nodelist = foo[1-10],bar[3-8] or foo4,bar[2-10]
  255. nodelist = os.environ["SLURM_STEP_NODELIST"]
  256. prior = nodelist.split(",")[0].split("-")[0].replace("[", "")
  257. if prior is not None:
  258. return str(prior)
  259. else:
  260. return os.environ.get("MASTER_ADDR")
  261. def get_master_port(prior=None) -> Optional[int]:
  262. if prior is not None:
  263. return prior
  264. else:
  265. return _int_or_none(os.environ.get("MASTER_PORT"))
  266. def get_node_rank(prior=None, launcher: str = None) -> Optional[int]:
  267. """Get Node Rank.
  268. Use for "multiprocessing distributed" mode.
  269. The initial RANK equals to the Node id in this case and
  270. the real Rank is set as (nGPU * NodeID) + LOCAL_RANK in torch.distributed.
  271. """
  272. if prior is not None:
  273. return prior
  274. elif launcher == "slurm":
  275. if not is_in_slurm_step():
  276. raise RuntimeError("This process seems not to be launched by 'srun'")
  277. # Assume ntasks_per_node == 1
  278. if os.environ["SLURM_STEP_NUM_NODES"] != os.environ["SLURM_NTASKS"]:
  279. raise RuntimeError(
  280. "Run with --ntasks_per_node=1 if mutliprocessing_distributed=true"
  281. )
  282. return int(os.environ["SLURM_NODEID"])
  283. elif launcher == "mpi":
  284. # Use mpi4py only for initialization and not using for communication
  285. from mpi4py import MPI
  286. comm = MPI.COMM_WORLD
  287. # Assume ntasks_per_node == 1 (We can't check whether it is or not)
  288. return comm.Get_rank()
  289. elif launcher is not None:
  290. raise RuntimeError(f"launcher='{launcher}' is not supported")
  291. else:
  292. return _int_or_none(os.environ.get("RANK"))
  293. def get_num_nodes(prior=None, launcher: str = None) -> Optional[int]:
  294. """Get the number of nodes.
  295. Use for "multiprocessing distributed" mode.
  296. RANK equals to the Node id in this case and
  297. the real Rank is set as (nGPU * NodeID) + LOCAL_RANK in torch.distributed.
  298. """
  299. if prior is not None:
  300. return prior
  301. elif launcher == "slurm":
  302. if not is_in_slurm_step():
  303. raise RuntimeError("This process seems not to be launched by 'srun'")
  304. # Assume ntasks_per_node == 1
  305. if os.environ["SLURM_STEP_NUM_NODES"] != os.environ["SLURM_NTASKS"]:
  306. raise RuntimeError(
  307. "Run with --ntasks_per_node=1 if mutliprocessing_distributed=true"
  308. )
  309. return int(os.environ["SLURM_STEP_NUM_NODES"])
  310. elif launcher == "mpi":
  311. # Use mpi4py only for initialization and not using for communication
  312. from mpi4py import MPI
  313. comm = MPI.COMM_WORLD
  314. # Assume ntasks_per_node == 1 (We can't check whether it is or not)
  315. return comm.Get_size()
  316. elif launcher is not None:
  317. raise RuntimeError(f"launcher='{launcher}' is not supported")
  318. else:
  319. # prior is None -> NUM_NODES = 1
  320. return int(os.environ.get("WORLD_SIZE", 1))