嘉渊 2 года назад
Родитель
Commit
d5c818131b
1 изменённый файл: 8 добавлено и 4 удалено
  1. 8 4
      funasr/bin/build_trainer.py

+ 8 - 4
funasr/bin/build_trainer.py

@@ -532,11 +532,9 @@ def build_trainer(modelscope_dict,
     args = build_args(args, parser, extra_task_params)
     args = build_args(args, parser, extra_task_params)
 
 
     if args.local_rank is not None:
     if args.local_rank is not None:
-        args.distributed = True
-        args.simple_ddp = True
+        distributed = True
     else:
     else:
-        args.distributed = False
-        args.ngpu = 1
+        distributed = False
     args.local_rank = args.local_rank if args.local_rank is not None else 0
     args.local_rank = args.local_rank if args.local_rank is not None else 0
     local_rank = args.local_rank
     local_rank = args.local_rank
     if "CUDA_VISIBLE_DEVICES" in os.environ.keys():
     if "CUDA_VISIBLE_DEVICES" in os.environ.keys():
@@ -593,6 +591,12 @@ def build_trainer(modelscope_dict,
     args.batch_type = "length"
     args.batch_type = "length"
     args.oss_bucket = None
     args.oss_bucket = None
     args.input_size = None
     args.input_size = None
+    if distributed:
+        args.distributed = True
+        args.simple_ddp = True
+    else:
+        args.distributed = False
+        args.ngpu = 1
     if optim is not None:
     if optim is not None:
         args.optim = optim
         args.optim = optim
     if lr is not None:
     if lr is not None: