嘉渊 2 жил өмнө
parent
commit
95d6db2656

+ 1 - 1
funasr/datasets/small_datasets/sequence_iter_factory.py

@@ -83,7 +83,7 @@ class SequenceIterFactory(AbsIterFactory):
             args.max_update = len(bs_list) * args.max_epoch
             logging.info("Max update: {}".format(args.max_update))
 
-        if args.distributed:
+        if args.distributed and mode=="train":
             world_size = torch.distributed.get_world_size()
             rank = torch.distributed.get_rank()
             for batch in batches: