trainer.py
import os
import time
import torch
import logging
from tqdm import tqdm
import torch.distributed as dist
from contextlib import nullcontext
# from torch.utils.tensorboard import SummaryWriter
from tensorboardX import SummaryWriter
from pathlib import Path

from funasr.train_utils.device_funcs import to_device
from funasr.train_utils.recursive_op import recursive_average
from funasr.train_utils.average_nbest_models import average_checkpoints


class Trainer:
    """
    A simple trainer class for training a PyTorch model, saving checkpoints at the end of each epoch,
    and optionally resuming from a saved checkpoint. A hedged usage sketch with a toy model is
    provided at the bottom of this file.

    Attributes:
        max_epoch (int): Maximum number of epochs for training.
        model (torch.nn.Module): The model to be trained.
        optim (torch.optim.Optimizer): The optimizer to use for training.
        scheduler (torch.optim.lr_scheduler._LRScheduler): The learning rate scheduler.
        dataloader_train (torch.utils.data.DataLoader): DataLoader for the training dataset.
        dataloader_val (torch.utils.data.DataLoader): DataLoader for the validation dataset.
        output_dir (str): Directory where model checkpoints will be saved.
        resume (bool): Whether to resume training from the latest checkpoint in output_dir.
    """

    def __init__(self,
                 model,
                 optim,
                 scheduler,
                 dataloader_train,
                 dataloader_val,
                 local_rank,
                 use_ddp=False,
                 use_fsdp=False,
                 output_dir: str = "./",
                 **kwargs):
        """
        Initializes the Trainer class with the model, optimizer, scheduler, dataloaders, and other settings.

        Args:
            model (torch.nn.Module): The model to be trained.
            optim (torch.optim.Optimizer): The optimizer to use for training.
            scheduler (torch.optim.lr_scheduler._LRScheduler): The learning rate scheduler.
            dataloader_train (torch.utils.data.DataLoader): The DataLoader for the training dataset.
            dataloader_val (torch.utils.data.DataLoader): The DataLoader for the validation dataset.
            local_rank (int): The local device rank, used in log messages.
            use_ddp (bool): Whether the model is wrapped in DistributedDataParallel.
            use_fsdp (bool): Whether the model is wrapped in FullyShardedDataParallel.
            output_dir (str): The directory where model checkpoints will be saved. Default is './'.
            **kwargs: Additional keyword arguments:
                max_epoch (int): The maximum number of epochs for training. Default is 100.
                resume (bool): Whether to resume from the latest checkpoint in output_dir. Default is True.
                accum_grad (int): Number of gradient accumulation steps. Default is 1.
                grad_clip (float, optional): Maximum gradient norm for clipping.
                avg_nbest_model (int): Number of checkpoints to average at the end of training. Default is 5.
                log_interval (int): Number of steps between progress logs. Default is 50.
        """
        self.model = model
        self.optim = optim
        self.scheduler = scheduler
        self.dataloader_train = dataloader_train
        self.dataloader_val = dataloader_val
        self.output_dir = output_dir
        self.resume = kwargs.get('resume', True)
        self.start_epoch = 0
        self.max_epoch = kwargs.get('max_epoch', 100)
        self.local_rank = local_rank
        self.use_ddp = use_ddp
        self.use_fsdp = use_fsdp
        self.device = next(model.parameters()).device
        self.avg_nbest_model = kwargs.get("avg_nbest_model", 5)
        self.kwargs = kwargs
        self.log_interval = kwargs.get("log_interval", 50)

        if dist.is_available() and dist.is_initialized():
            rank = dist.get_rank()
            world_size = dist.get_world_size()
        else:
            rank = 0
            world_size = 1
            logging.warning("distributed is not initialized, only single shard")
        self.rank = rank
        self.world_size = world_size

        os.makedirs(os.path.join(self.output_dir, "tensorboard"), exist_ok=True)
        self.writer = SummaryWriter(os.path.join(self.output_dir, "tensorboard")) if rank == 0 else None

    def _save_checkpoint(self, epoch):
        """
        Saves a checkpoint containing the model's state, the optimizer's state,
        and the scheduler's state at the end of the given epoch. This method is
        intended to be called at the end of each epoch to save the training progress.

        Args:
            epoch (int): The epoch number at which the checkpoint is being saved.
        """
        state = {
            'epoch': epoch,
            'state_dict': self.model.state_dict(),
            'optimizer': self.optim.state_dict(),
            'scheduler': self.scheduler.state_dict(),
        }
        # Create output directory if it does not exist
        os.makedirs(self.output_dir, exist_ok=True)
        filename = os.path.join(self.output_dir, f'model.pt.ep{epoch}')
        torch.save(state, filename)
        print(f'Checkpoint saved to {filename}')
        # Refresh the "model.pt" symlink so it always points at the latest checkpoint.
        # The target is the bare file name so the link resolves inside output_dir
        # regardless of the current working directory.
        latest = Path(os.path.join(self.output_dir, 'model.pt'))
        try:
            latest.unlink()
        except OSError:
            pass
        latest.symlink_to(os.path.basename(filename))

    def _resume_checkpoint(self, resume_path):
        """
        Resumes training from the "model.pt" checkpoint found in the given directory.
        Loads the model's state, the optimizer's state, and the scheduler's state.

        Args:
            resume_path (str): The directory containing the checkpoint to resume from.
        """
        ckpt = os.path.join(resume_path, "model.pt")
        if os.path.isfile(ckpt):
            checkpoint = torch.load(ckpt)
            self.start_epoch = checkpoint['epoch'] + 1
            self.model.load_state_dict(checkpoint['state_dict'])
            self.optim.load_state_dict(checkpoint['optimizer'])
            self.scheduler.load_state_dict(checkpoint['scheduler'])
            print(f"Checkpoint loaded successfully from '{ckpt}'")
        else:
            print(f"No checkpoint found at '{ckpt}', starting from scratch")
        if self.use_ddp or self.use_fsdp:
            dist.barrier()

    def run(self):
        """
        Starts the training process, iterating over epochs, training the model,
        and saving checkpoints at the end of each epoch.
        """
        if self.resume:
            self._resume_checkpoint(self.output_dir)

        for epoch in range(self.start_epoch, self.max_epoch + 1):
            self._train_epoch(epoch)
            if self.use_ddp or self.use_fsdp:
                dist.barrier()

            self._validate_epoch(epoch)
            if self.use_ddp or self.use_fsdp:
                dist.barrier()

            if self.rank == 0:
                self._save_checkpoint(epoch)
            if self.use_ddp or self.use_fsdp:
                dist.barrier()

            self.scheduler.step()

        if self.rank == 0:
            average_checkpoints(self.output_dir, self.avg_nbest_model)
        if self.use_ddp or self.use_fsdp:
            dist.barrier()

        if self.writer:
            self.writer.close()

    def _train_epoch(self, epoch):
        """
        Defines the training process for a single epoch with gradient accumulation.

        Args:
            epoch (int): The current epoch number.
        """
        self.model.train()
        pbar = tqdm(colour="blue", desc=f"Training Epoch: {epoch + 1}", total=len(self.dataloader_train),
                    dynamic_ncols=True)

        # Set the number of steps for gradient accumulation
        accum_grad = self.kwargs.get("accum_grad", 1)
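        # With gradient accumulation, gradients from `accum_grad` consecutive mini-batches
        # are summed before each optimizer step, so the effective batch size is roughly
        # accum_grad * (per-step batch size); the loss is divided by accum_grad further
        # below so the accumulated gradient matches a single large-batch update.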
        # Initialize the gradient accumulation
        self.optim.zero_grad()
        speed_stats = {}
        time5 = time.perf_counter()
        for batch_idx, batch in enumerate(self.dataloader_train):
            time1 = time.perf_counter()
            speed_stats["data_load"] = f"{time1 - time5:0.3f}"
            batch = to_device(batch, self.device)

            # Gradients only need to be synchronized across workers on the step that
            # triggers an optimizer update below, so skip DDP/FSDP sync while accumulating.
            update_step = (batch_idx + 1) % accum_grad == 0 or (batch_idx + 1) == len(self.dataloader_train)
            use_no_sync = (self.use_ddp or self.use_fsdp) and not update_step
            my_context = self.model.no_sync if use_no_sync else nullcontext
            with my_context():
                time2 = time.perf_counter()
                # print("before, GPU, memory: {:.3f} GB, "
                #       "{:.3f} GB, "
                #       "{:.3f} GB, "
                #       "{:.3f} GB".format(torch.cuda.memory_allocated() / 1024 / 1024 / 1024,
                #                          torch.cuda.max_memory_allocated() / 1024 / 1024 / 1024,
                #                          torch.cuda.memory_reserved() / 1024 / 1024 / 1024,
                #                          torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024,
                #                          ))
                retval = self.model(**batch)
                torch.cuda.empty_cache()
                # print("after, GPU, memory: {:.3f} GB, "
                #       "{:.3f} GB, "
                #       "{:.3f} GB, "
                #       "{:.3f} GB".format(torch.cuda.memory_allocated() / 1024 / 1024 / 1024,
                #                          torch.cuda.max_memory_allocated() / 1024 / 1024 / 1024,
                #                          torch.cuda.memory_reserved() / 1024 / 1024 / 1024,
                #                          torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024,
                #                          ))
                time3 = time.perf_counter()
                speed_stats["forward_time"] = f"{time3 - time2:0.3f}"
                loss, stats, weight = retval
                stats = {k: v for k, v in stats.items() if v is not None}
                if self.use_ddp or self.use_fsdp:
                    # Apply weighted averaging for loss and stats
                    loss = (loss * weight.type(loss.dtype)).sum()
                    # if distributed, this method can also apply all_reduce()
                    stats, weight = recursive_average(stats, weight, distributed=True)
                    # Now weight is summation over all workers
                    loss /= weight
                    # Multiply world_size because DistributedDataParallel
                    # automatically normalizes the gradient by world_size.
                    loss *= self.world_size
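                    # Together with DDP's gradient averaging (division by world_size),
                    # the scaling above is intended to optimize the globally weighted
                    # mean loss: sum_r(loss_r * w_r) / sum_r(w_r), where r indexes ranks
                    # and w_r is the batch weight returned by the model.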
                # Scale the loss since we're not updating for every mini-batch
                loss = loss / accum_grad
                loss.backward()
                time4 = time.perf_counter()
                speed_stats["backward_time"] = f"{time4 - time3:0.3f}"

            # Perform an optimizer step only after accumulating enough gradients
            if (batch_idx + 1) % accum_grad == 0 or (batch_idx + 1) == len(self.dataloader_train):
                # Perform gradient clipping if it is set
                if self.kwargs.get("grad_clip", None) is not None:
                    grad_norm = torch.nn.utils.clip_grad_norm_(
                        self.model.parameters(),
                        max_norm=self.kwargs.get("grad_clip", 10.0),
                        norm_type=self.kwargs.get("grad_clip_type", 2.0),
                    )
                    if not torch.isfinite(grad_norm):
                        logging.warning(
                            f"The grad norm is {grad_norm}. Skipping updating the model."
                        )
                        self.optim.zero_grad()  # Reset gradients
                        continue
                # Execute an optimization step (update model parameters)
                if self.use_ddp or self.use_fsdp:
                    dist.barrier()
                self.optim.step()
                self.scheduler.step()
                # Clear gradients for the next accumulation stage
                self.optim.zero_grad()
            total_time = f"{time.perf_counter() - time5:0.3f}"
            time5 = time.perf_counter()
            speed_stats["optim_time"] = f"{time5 - time4:0.3f}"
            speed_stats["total_time"] = total_time
            if batch_idx % self.log_interval == 0 or batch_idx == len(self.dataloader_train) - 1:
                pbar.update(self.log_interval)
                gpu_info = "GPU, memory: {:.3f} GB, " \
                           "{:.3f} GB, " \
                           "{:.3f} GB, " \
                           "{:.3f} GB".format(torch.cuda.memory_allocated() / 1024 / 1024 / 1024,
                                              torch.cuda.max_memory_allocated() / 1024 / 1024 / 1024,
                                              torch.cuda.memory_reserved() / 1024 / 1024 / 1024,
                                              torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024,
                                              )
                description = (
                    f"rank: {self.local_rank}, "
                    f"Train epoch: {epoch}/{self.max_epoch}, "
                    f"step {batch_idx}/{len(self.dataloader_train)}, "
                    f"{speed_stats}, "
                    f"(loss: {loss.detach().cpu().item():.3f}), "
                    f"{[(k, round(v.cpu().item(), 3)) for k, v in stats.items()]}, "
                    f"{gpu_info}"
                )
                pbar.set_description(description)
                if self.writer:
                    self.writer.add_scalar(f'rank{self.local_rank}, Loss/train', loss.item(),
                                           epoch * len(self.dataloader_train) + batch_idx)
                    for key, var in stats.items():
                        self.writer.add_scalar(f'rank{self.local_rank}, {key}/train', var.item(),
                                               epoch * len(self.dataloader_train) + batch_idx)
                    for key, var in speed_stats.items():
                        self.writer.add_scalar(f'rank{self.local_rank}, {key}/train', float(var),
                                               epoch * len(self.dataloader_train) + batch_idx)
            # if batch_idx == 2:
            #     break
        pbar.close()

    def _validate_epoch(self, epoch):
        """
        Defines the validation process for a single epoch.
        Runs the model over the validation set without gradient updates and logs
        the loss and statistics.

        Args:
            epoch (int): The current epoch number.
        """
        self.model.eval()
        with torch.no_grad():
            pbar = tqdm(colour="red", desc=f"Validation Epoch: {epoch + 1}", total=len(self.dataloader_val),
                        dynamic_ncols=True)
            speed_stats = {}
            time5 = time.perf_counter()
            for batch_idx, batch in enumerate(self.dataloader_val):
                time1 = time.perf_counter()
                speed_stats["data_load"] = f"{time1 - time5:0.3f}"
                batch = to_device(batch, self.device)
                time2 = time.perf_counter()
                retval = self.model(**batch)
                time3 = time.perf_counter()
                speed_stats["forward_time"] = f"{time3 - time2:0.3f}"
                loss, stats, weight = retval
                stats = {k: v for k, v in stats.items() if v is not None}
                if self.use_ddp or self.use_fsdp:
                    # Apply weighted averaging for loss and stats
                    loss = (loss * weight.type(loss.dtype)).sum()
                    # if distributed, this method can also apply all_reduce()
                    stats, weight = recursive_average(stats, weight, distributed=True)
                    # Now weight is summation over all workers
                    loss /= weight
                    # Multiply world_size because DistributedDataParallel
                    # automatically normalizes the gradient by world_size.
                    loss *= self.world_size
                time5 = time.perf_counter()
                if batch_idx % self.log_interval == 0 or batch_idx == len(self.dataloader_val) - 1:
                    pbar.update(self.log_interval)
                    description = (
                        f"rank: {self.local_rank}, "
                        f"validation epoch: {epoch}/{self.max_epoch}, "
                        f"step {batch_idx}/{len(self.dataloader_val)}, "
                        f"{speed_stats}, "
                        f"(loss: {loss.detach().cpu().item():.3f}), "
                        f"{[(k, round(v.cpu().item(), 3)) for k, v in stats.items()]}"
                    )
                    pbar.set_description(description)
                    if self.writer:
                        self.writer.add_scalar(f'rank{self.local_rank}, Loss/val', loss.item(),
                                               epoch * len(self.dataloader_val) + batch_idx)
                        for key, var in stats.items():
                            self.writer.add_scalar(f'rank{self.local_rank}, {key}/val', var.item(),
                                                   epoch * len(self.dataloader_val) + batch_idx)
                        for key, var in speed_stats.items():
                            self.writer.add_scalar(f'rank{self.local_rank}, {key}/val', float(var),
                                                   epoch * len(self.dataloader_val) + batch_idx)
            pbar.close()
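

# ---------------------------------------------------------------------------
# Hedged usage sketch (illustrative only). It assumes a toy model whose
# forward(**batch) returns the (loss, stats, weight) tuple that
# _train_epoch/_validate_epoch expect; ToyModel, make_loader, and the
# hyper-parameters below are hypothetical and not part of any FunASR recipe.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import torch.nn as nn
    from torch.utils.data import DataLoader, TensorDataset

    class ToyModel(nn.Module):
        """Hypothetical model following the (loss, stats, weight) convention."""

        def __init__(self):
            super().__init__()
            self.proj = nn.Linear(8, 1)

        def forward(self, x, y):
            loss = nn.functional.mse_loss(self.proj(x), y)
            stats = {"loss": loss.detach()}
            weight = torch.tensor(float(x.size(0)))
            return loss, stats, weight

    def make_loader(n=64):
        # Batches are dicts so they can be unpacked as self.model(**batch).
        ds = TensorDataset(torch.randn(n, 8), torch.randn(n, 1))
        return DataLoader(
            ds,
            batch_size=8,
            collate_fn=lambda items: {
                "x": torch.stack([i[0] for i in items]),
                "y": torch.stack([i[1] for i in items]),
            },
        )

    model = ToyModel()
    optim = torch.optim.Adam(model.parameters(), lr=1e-3)
    scheduler = torch.optim.lr_scheduler.StepLR(optim, step_size=1)
    trainer = Trainer(
        model,
        optim,
        scheduler,
        dataloader_train=make_loader(),
        dataloader_val=make_loader(),
        local_rank=0,
        output_dir="./exp_toy",
        resume=False,
        max_epoch=1,
        avg_nbest_model=1,
    )
    trainer.run()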