|
|
@@ -302,17 +302,14 @@ class Trainer:
|
|
|
)
|
|
|
pbar.set_description(description)
|
|
|
if self.writer:
|
|
|
- self.writer.add_scalar(f'rank{self.local_rank}_Loss/train', loss.item(),
|
|
|
- epoch*len(self.dataloader_train) + batch_idx)
|
|
|
+ self.writer.add_scalar(f'rank{self.local_rank}_Loss/train', loss.item(), self.batch_total)
|
|
|
+ self.writer.add_scalar(f'rank{self.local_rank}_lr/train', lr, self.batch_total)
|
|
|
for key, var in stats.items():
|
|
|
- self.writer.add_scalar(f'rank{self.local_rank}_{key}/train', var.item(),
|
|
|
- epoch * len(self.dataloader_train) + batch_idx)
|
|
|
+ self.writer.add_scalar(f'rank{self.local_rank}_{key}/train', var.item(), self.batch_total)
|
|
|
for key, var in speed_stats.items():
|
|
|
- self.writer.add_scalar(f'rank{self.local_rank}_{key}/train', eval(var),
|
|
|
- epoch * len(self.dataloader_train) + batch_idx)
|
|
|
-
|
|
|
- # if batch_idx == 2:
|
|
|
- # break
|
|
|
+ self.writer.add_scalar(f'rank{self.local_rank}_{key}/train', eval(var), self.batch_total)
|
|
|
+
|
|
|
+
|
|
|
pbar.close()
|
|
|
|
|
|
def _validate_epoch(self, epoch):
|
|
|
@@ -356,7 +353,10 @@ class Trainer:
|
|
|
|
|
|
if (batch_idx+1) % self.log_interval == 0 or (batch_idx+1) == len(self.dataloader_val):
|
|
|
pbar.update(self.log_interval)
|
|
|
+ time_now = datetime.now()
|
|
|
+ time_now = time_now.strftime("%Y-%m-%d %H:%M:%S")
|
|
|
description = (
|
|
|
+ f"{time_now}, "
|
|
|
f"rank: {self.local_rank}, "
|
|
|
f"validation epoch: {epoch}/{self.max_epoch}, "
|
|
|
f"step: {batch_idx+1}/{len(self.dataloader_val)}, "
|