|
|
@@ -53,7 +53,7 @@ class DistributedOption:
|
|
|
# https://pytorch.org/docs/stable/distributed.html#torch.distributed.init_process_group
|
|
|
os.environ.setdefault("NCCL_BLOCKING_WAIT", "1")
|
|
|
|
|
|
- torch.distributed.init_process_group(backend='nccl',
|
|
|
+ torch.distributed.init_process_group(backend=self.dist_backend,
|
|
|
init_method=self.dist_init_method,
|
|
|
world_size=args.dist_world_size,
|
|
|
rank=args.dist_rank)
|
|
|
@@ -113,7 +113,7 @@ class DistributedOption:
|
|
|
# https://pytorch.org/docs/stable/distributed.html#torch.distributed.init_process_group
|
|
|
os.environ.setdefault("NCCL_BLOCKING_WAIT", "1")
|
|
|
|
|
|
- torch.distributed.init_process_group(backend='nccl', init_method='env://')
|
|
|
+ torch.distributed.init_process_group(backend=self.dist_backend, init_method='env://')
|
|
|
self.dist_rank = torch.distributed.get_rank()
|
|
|
self.dist_world_size = torch.distributed.get_world_size()
|
|
|
self.local_rank = args.local_rank
|