|
|
@@ -83,7 +83,7 @@ class SequenceIterFactory(AbsIterFactory):
|
|
|
args.max_update = len(bs_list) * args.max_epoch
|
|
|
logging.info("Max update: {}".format(args.max_update))
|
|
|
|
|
|
- if args.distributed:
|
|
|
+ if args.distributed and mode=="train":
|
|
|
world_size = torch.distributed.get_world_size()
|
|
|
rank = torch.distributed.get_rank()
|
|
|
for batch in batches:
|