@@ -147,9 +147,21 @@ class Trainer:

        for epoch in range(self.start_epoch, self.max_epoch + 1):
            self._train_epoch(epoch)
+
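+            # Wait until every rank has finished the training epoch
+            # before any rank starts validation.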
+            if self.use_ddp or self.use_fsdp:
+                dist.barrier()
+
            self._validate_epoch(epoch)
-
+
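+            # Wait for validation to finish on every rank so rank 0
+            # does not save a checkpoint while others are still validating.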
+            if self.use_ddp or self.use_fsdp:
+                dist.barrier()
+
+
            if self.rank == 0:
                self._save_checkpoint(epoch)


@@ -164,7 +176,10 @@ class Trainer:

        if self.use_ddp or self.use_fsdp:
            dist.barrier()
-        self.writer.close()
+
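+        # self.writer may be None (e.g. if only rank 0 creates one), so guard the close.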
+        if self.writer:
+            self.writer.close()


    def _train_epoch(self, epoch):
@@ -230,6 +245,9 @@ class Trainer:
                continue

            # Execute an optimization step (update model parameters)
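+            # Synchronize all ranks before stepping the optimizer.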
+            if self.use_ddp or self.use_fsdp:
+                dist.barrier()
            self.optim.step()
            self.scheduler.step()
            # Clear gradients for the next accumulation stage
@@ -244,7 +262,7 @@ class Trainer:
            pbar.update(1)
            if self.local_rank == 0:
                description = (
-                    f"Epoch: {epoch}/{self.max_epoch}, "
+                    f"Train epoch: {epoch}/{self.max_epoch}, "
                    f"step {batch_idx}/{len(self.dataloader_train)}, "
                    f"{speed_stats}, "
                    f"(loss: {loss.detach().cpu().item():.3f}), "
@@ -306,7 +324,7 @@ class Trainer:
            pbar.update(1)
            if self.local_rank == 0:
                description = (
-                    f"validation: \nEpoch: {epoch}/{self.max_epoch}, "
+                    f"Validation epoch: {epoch}/{self.max_epoch}, "
                    f"step {batch_idx}/{len(self.dataloader_train)}, "
                    f"{speed_stats}, "
                    f"(loss: {loss.detach().cpu().item():.3f}), "