@@ -5,6 +5,8 @@ import logging
 from tqdm import tqdm
 import torch.distributed as dist
 from contextlib import nullcontext
+# from torch.utils.tensorboard import SummaryWriter
+from tensorboardX import SummaryWriter
 
 from funasr.train_utils.device_funcs import to_device
 from funasr.train_utils.recursive_op import recursive_average
@@ -34,6 +36,7 @@ class Trainer:
                  local_rank,
                  use_ddp=False,
                  use_fsdp=False,
+                 output_dir: str = "./",
                  **kwargs):
         """
         Initializes the Trainer class with the model, optimizer, scheduler, dataloaders, and other settings.
@@ -55,7 +58,7 @@ class Trainer:
         self.scheduler = scheduler
         self.dataloader_train = dataloader_train
         self.dataloader_val = dataloader_val
-        self.output_dir = kwargs.get('output_dir', './')
+        self.output_dir = output_dir
         self.resume = kwargs.get('resume', True)
         self.start_epoch = 0
         self.max_epoch = kwargs.get('max_epoch', 100)
@@ -77,6 +80,10 @@ class Trainer:
             logging.warning("distributed is not initialized, only single shard")
         self.rank = rank
         self.world_size = world_size
+
+        os.makedirs(os.path.join(self.output_dir, "tensorboard"), exist_ok=True)
+        self.writer = SummaryWriter(os.path.join(self.output_dir, "tensorboard")) if rank == 0 else None
+
 
     def _save_checkpoint(self, epoch):
         """
@@ -128,6 +135,8 @@
             if self.rank == 0:
                 self._save_checkpoint(epoch)
             self.scheduler.step()
+        if self.writer is not None:
+            self.writer.close()
 
     def _train_epoch(self, epoch):
         """
@@ -215,7 +224,16 @@
                 f"{[(k, round(v.cpu().item(), 3)) for k, v in stats.items()]}"
             )
             pbar.set_description(description)
-
+            if self.writer:
+                self.writer.add_scalar('Loss/train', loss.item(),
+                                       epoch * len(self.dataloader_train) + batch_idx)
+                for key, var in stats.items():
+                    self.writer.add_scalar(f'{key}/train', var.item(),
+                                           epoch * len(self.dataloader_train) + batch_idx)
+                # speed_stats values are numeric strings, so float() is safer than eval()
+                for key, var in speed_stats.items():
+                    self.writer.add_scalar(f'{key}/train', float(var),
+                                           epoch * len(self.dataloader_train) + batch_idx)
             # if batch_idx == 2:
             #     break
         pbar.close()
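
Note: as a quick sanity check of the pattern this patch introduces (a rank-0-only writer under output_dir/tensorboard, scalars tagged "<name>/train"), the following standalone sketch can be run outside the trainer, assuming tensorboardX is installed. The make_writer helper, the "./exp_demo" path, and the logged values are illustrative only and not part of the patch.

    import os
    from tensorboardX import SummaryWriter

    def make_writer(output_dir: str, rank: int):
        # Only rank 0 writes event files, mirroring the patch; other ranks get None.
        if rank != 0:
            return None
        log_dir = os.path.join(output_dir, "tensorboard")
        os.makedirs(log_dir, exist_ok=True)
        return SummaryWriter(log_dir)

    writer = make_writer("./exp_demo", rank=0)  # rank would come from dist.get_rank() in training
    if writer is not None:
        for step in range(3):
            writer.add_scalar("Loss/train", 1.0 / (step + 1), step)
        # flush event files so `tensorboard --logdir ./exp_demo/tensorboard` can read them
        writer.close()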