# trainer.py

import os
import time
import torch
import logging
from tqdm import tqdm
from datetime import datetime
import torch.distributed as dist
from contextlib import nullcontext
# from torch.utils.tensorboard import SummaryWriter
from tensorboardX import SummaryWriter
from pathlib import Path

from funasr.train_utils.device_funcs import to_device
from funasr.train_utils.recursive_op import recursive_average
from funasr.train_utils.average_nbest_models import average_checkpoints


class Trainer:
    """
    A simple trainer class for training a PyTorch model, saving checkpoints at
    the end of each epoch, and optionally resuming from a saved checkpoint.
    A minimal usage sketch is given at the bottom of this file.

    Attributes:
        max_epoch (int): Maximum number of epochs for training.
        model (torch.nn.Module): The model to be trained.
        optim (torch.optim.Optimizer): The optimizer to use for training.
        scheduler (torch.optim.lr_scheduler._LRScheduler): The learning rate scheduler.
        dataloader_train (torch.utils.data.DataLoader): DataLoader for the training dataset.
        dataloader_val (torch.utils.data.DataLoader): DataLoader for the validation dataset.
        output_dir (str): Directory where model checkpoints will be saved.
        resume (bool): Whether to resume training from the latest checkpoint in output_dir.
    """

    def __init__(self, model,
                 optim,
                 scheduler,
                 dataloader_train,
                 dataloader_val,
                 local_rank,
                 use_ddp=False,
                 use_fsdp=False,
                 output_dir: str = "./",
                 **kwargs):
        """
        Initializes the Trainer class with the model, optimizer, scheduler, dataloaders, and other settings.

        Args:
            model (torch.nn.Module): The model to be trained.
            optim (torch.optim.Optimizer): The optimizer to use for training.
            scheduler (torch.optim.lr_scheduler._LRScheduler): The learning rate scheduler.
            dataloader_train (torch.utils.data.DataLoader): The DataLoader for the training dataset.
            dataloader_val (torch.utils.data.DataLoader): The DataLoader for the validation dataset.
            local_rank (int): The local device index, used for progress-bar and TensorBoard tags.
            use_ddp (bool): Whether the model is wrapped with DistributedDataParallel.
            use_fsdp (bool): Whether the model is wrapped with FullyShardedDataParallel.
            output_dir (str): The directory where model checkpoints will be saved. Default is './'.
            **kwargs: Additional keyword arguments:
                max_epoch (int): The maximum number of epochs for training. Default is 100.
                resume (bool): Whether to resume training from the latest checkpoint in output_dir. Default is True.
                accum_grad (int): Number of gradient-accumulation steps per optimizer update. Default is 1.
                grad_clip (float, optional): Max gradient norm for clipping; clipping is skipped if unset.
                grad_clip_type (float): Norm type used for gradient clipping. Default is 2.0.
                log_interval (int): Number of batches between progress/TensorBoard logs. Default is 50.
                avg_nbest_model (int): Number of best checkpoints to average after training. Default is 5.
        """
        self.model = model
        self.optim = optim
        self.scheduler = scheduler
        self.dataloader_train = dataloader_train
        self.dataloader_val = dataloader_val
        self.output_dir = output_dir
        self.resume = kwargs.get('resume', True)
        self.start_epoch = 0
        self.max_epoch = kwargs.get('max_epoch', 100)
        self.local_rank = local_rank
        self.use_ddp = use_ddp
        self.use_fsdp = use_fsdp
        self.device = next(model.parameters()).device
        self.avg_nbest_model = kwargs.get("avg_nbest_model", 5)
        self.kwargs = kwargs
        self.log_interval = kwargs.get("log_interval", 50)
        self.batch_total = 0

        try:
            rank = dist.get_rank()
            world_size = dist.get_world_size()
        except Exception:
            rank = 0
            world_size = 1
            logging.warning("distributed is not initialized, only single shard")
        self.rank = rank
        self.world_size = world_size

        os.makedirs(os.path.join(self.output_dir, "tensorboard"), exist_ok=True)
        self.writer = SummaryWriter(os.path.join(self.output_dir, "tensorboard")) if rank == 0 else None

    def _save_checkpoint(self, epoch):
        """
        Saves a checkpoint containing the model's state, the optimizer's state,
        and the scheduler's state at the end of the given epoch. This method is
        intended to be called at the end of each epoch to save the training progress.

        Args:
            epoch (int): The epoch number at which the checkpoint is being saved.
        """
        state = {
            'epoch': epoch,
            'state_dict': self.model.state_dict(),
            'optimizer': self.optim.state_dict(),
            'scheduler': self.scheduler.state_dict(),
        }
        # Create output directory if it does not exist
        os.makedirs(self.output_dir, exist_ok=True)
        filename = os.path.join(self.output_dir, f'model.pt.ep{epoch}')
        torch.save(state, filename)
        print(f'\nCheckpoint saved to {filename}\n')
        latest = Path(os.path.join(self.output_dir, 'model.pt'))
        torch.save(state, latest)

    def _resume_checkpoint(self, resume_path):
        """
        Resumes training from the latest checkpoint in the given directory.
        Loads the model's state, the optimizer's state, and the scheduler's state.

        Args:
            resume_path (str): The directory containing the 'model.pt' checkpoint to resume from.
        """
        ckpt = os.path.join(resume_path, "model.pt")
        if os.path.isfile(ckpt):
            checkpoint = torch.load(ckpt)
            self.start_epoch = checkpoint['epoch'] + 1
            # self.model.load_state_dict(checkpoint['state_dict'])
            src_state = checkpoint['state_dict']
            dst_state = self.model.state_dict()
            for k in dst_state.keys():
                if not k.startswith("module.") and "module." + k in src_state.keys():
                    k_ddp = "module." + k
                else:
                    k_ddp = k
                if k_ddp in src_state.keys():
                    dst_state[k] = src_state[k_ddp]
                else:
                    print(f"Missing key in checkpoint, model: {k}, ckpt: {k_ddp}")
            self.model.load_state_dict(dst_state)
            self.optim.load_state_dict(checkpoint['optimizer'])
            self.scheduler.load_state_dict(checkpoint['scheduler'])
            print(f"Checkpoint loaded successfully from '{ckpt}'")
        else:
            print(f"No checkpoint found at '{ckpt}', starting from scratch")

        if self.use_ddp or self.use_fsdp:
            dist.barrier()

    def run(self):
        """
        Starts the training process, iterating over epochs, training the model,
        and saving checkpoints at the end of each epoch.
        """
        if self.resume:
            self._resume_checkpoint(self.output_dir)

        for epoch in range(self.start_epoch, self.max_epoch + 1):
            time1 = time.perf_counter()

            self._train_epoch(epoch)
            if self.use_ddp or self.use_fsdp:
                dist.barrier()

            self._validate_epoch(epoch)
            if self.use_ddp or self.use_fsdp:
                dist.barrier()

            if self.rank == 0:
                self._save_checkpoint(epoch)
            if self.use_ddp or self.use_fsdp:
                dist.barrier()

            self.scheduler.step()

            time2 = time.perf_counter()
            time_elapsed = (time2 - time1) / 3600.0
            print(f"\nepoch time: {time_elapsed:.3f} hours, "
                  f"estimated time to finish {self.max_epoch} epochs: "
                  f"{(self.max_epoch - epoch) * time_elapsed:.3f} hours\n")

        if self.rank == 0:
            average_checkpoints(self.output_dir, self.avg_nbest_model)

        if self.use_ddp or self.use_fsdp:
            dist.barrier()

        if self.writer:
            self.writer.close()

    def _train_epoch(self, epoch):
        """
        Defines the training process for a single epoch with gradient accumulation.

        Args:
            epoch (int): The current epoch number.
        """
        self.model.train()
        pbar = tqdm(colour="blue", desc=f"rank: {self.local_rank}, Training Epoch: {epoch + 1}",
                    total=len(self.dataloader_train), dynamic_ncols=True)

        # Set the number of steps for gradient accumulation
        accum_grad = self.kwargs.get("accum_grad", 1)
        # Initialize the gradient accumulation
        self.optim.zero_grad()
        speed_stats = {}
        time5 = time.perf_counter()

        for batch_idx, batch in enumerate(self.dataloader_train):
            self.batch_total += 1
            time1 = time.perf_counter()
            speed_stats["data_load"] = f"{time1 - time5:0.3f}"

            batch = to_device(batch, self.device)

            my_context = self.model.no_sync if batch_idx % accum_grad != 0 else nullcontext
            with my_context():
                time2 = time.perf_counter()
                retval = self.model(**batch)
                torch.cuda.empty_cache()
                time3 = time.perf_counter()
                speed_stats["forward_time"] = f"{time3 - time2:0.3f}"

                loss, stats, weight = retval
                stats = {k: v for k, v in stats.items() if v is not None}
                if self.use_ddp or self.use_fsdp:
                    # Apply weighted averaging for loss and stats
                    loss = (loss * weight.type(loss.dtype)).sum()
                    # If distributed, this method can also apply all_reduce()
                    stats, weight = recursive_average(stats, weight, distributed=True)
                    # Now weight is the summation over all workers
                    loss /= weight
                    # Multiply by world_size because DistributedDataParallel
                    # automatically normalizes the gradient by world_size.
                    loss *= self.world_size

                # Scale the loss since we're not updating for every mini-batch
                loss = loss / accum_grad
                loss.backward()
                time4 = time.perf_counter()
                speed_stats["backward_time"] = f"{time4 - time3:0.3f}"

            # Perform an optimizer step only after accumulating enough gradients
            if (batch_idx + 1) % accum_grad == 0 or (batch_idx + 1) == len(self.dataloader_train):
                # Perform gradient clipping if it is set
                if self.kwargs.get("grad_clip", None) is not None:
                    grad_norm = torch.nn.utils.clip_grad_norm_(
                        self.model.parameters(),
                        max_norm=self.kwargs.get("grad_clip", 10.0),
                        norm_type=self.kwargs.get("grad_clip_type", 2.0),
                    )
                    if not torch.isfinite(grad_norm):
                        logging.warning(
                            f"The grad norm is {grad_norm}. Skipping updating the model."
                        )
                        self.optim.zero_grad()  # Reset gradients
                        continue

                # Execute an optimization step (update model parameters)
                if self.use_ddp or self.use_fsdp:
                    dist.barrier()
                self.optim.step()
                self.scheduler.step()
                # Clear gradients for the next accumulation stage
                self.optim.zero_grad()

            total_time = f"{time.perf_counter() - time5:0.3f}"
            time5 = time.perf_counter()
            speed_stats["optim_time"] = f"{time5 - time4:0.3f}"
            speed_stats["total_time"] = total_time

            if (batch_idx + 1) % self.log_interval == 0 or (batch_idx + 1) == len(self.dataloader_train):
                pbar.update(self.log_interval)
                gpu_info = "GPU memory: allocated {:.3f} GB, " \
                           "max allocated {:.3f} GB, " \
                           "reserved {:.3f} GB, " \
                           "max reserved {:.3f} GB".format(
                               torch.cuda.memory_allocated() / 1024 / 1024 / 1024,
                               torch.cuda.max_memory_allocated() / 1024 / 1024 / 1024,
                               torch.cuda.memory_reserved() / 1024 / 1024 / 1024,
                               torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024,
                           )
                lr = self.scheduler.get_last_lr()[0]
                time_now = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
                description = (
                    f"{time_now}, "
                    f"rank: {self.local_rank}, "
                    f"epoch: {epoch}/{self.max_epoch}, "
                    f"step: {batch_idx + 1}/{len(self.dataloader_train)}, total: {self.batch_total}, "
                    f"(loss: {loss.detach().cpu().item():.3f}), "
                    f"(lr: {lr:.3e}), "
                    f"{[(k, round(v.cpu().item(), 3)) for k, v in stats.items()]}, "
                    f"{speed_stats}, "
                    f"{gpu_info}"
                )
                pbar.set_description(description)
                if self.writer:
                    self.writer.add_scalar(f'rank{self.local_rank}_Loss/train', loss.item(),
                                           epoch * len(self.dataloader_train) + batch_idx)
                    for key, var in stats.items():
                        self.writer.add_scalar(f'rank{self.local_rank}_{key}/train', var.item(),
                                               epoch * len(self.dataloader_train) + batch_idx)
                    for key, var in speed_stats.items():
                        self.writer.add_scalar(f'rank{self.local_rank}_{key}/train', float(var),
                                               epoch * len(self.dataloader_train) + batch_idx)

            # if batch_idx == 2:
            #     break

        pbar.close()

    def _validate_epoch(self, epoch):
        """
        Runs the validation loop for a single epoch: computes the loss and
        stats on the validation set without updating the model.

        Args:
            epoch (int): The current epoch number.
        """
        self.model.eval()
        with torch.no_grad():
            pbar = tqdm(colour="red", desc=f"rank: {self.local_rank}, Validation Epoch: {epoch + 1}",
                        total=len(self.dataloader_val), dynamic_ncols=True)
            speed_stats = {}
            time5 = time.perf_counter()

            for batch_idx, batch in enumerate(self.dataloader_val):
                time1 = time.perf_counter()
                speed_stats["data_load"] = f"{time1 - time5:0.3f}"

                batch = to_device(batch, self.device)

                time2 = time.perf_counter()
                retval = self.model(**batch)
                time3 = time.perf_counter()
                speed_stats["forward_time"] = f"{time3 - time2:0.3f}"

                loss, stats, weight = retval
                stats = {k: v for k, v in stats.items() if v is not None}
                if self.use_ddp or self.use_fsdp:
                    # Apply weighted averaging for loss and stats
                    loss = (loss * weight.type(loss.dtype)).sum()
                    # If distributed, this method can also apply all_reduce()
                    stats, weight = recursive_average(stats, weight, distributed=True)
                    # Now weight is the summation over all workers
                    loss /= weight
                    # Multiply by world_size because DistributedDataParallel
                    # automatically normalizes the gradient by world_size.
                    loss *= self.world_size

                if (batch_idx + 1) % self.log_interval == 0 or (batch_idx + 1) == len(self.dataloader_val):
                    pbar.update(self.log_interval)
                    description = (
                        f"rank: {self.local_rank}, "
                        f"validation epoch: {epoch}/{self.max_epoch}, "
                        f"step: {batch_idx + 1}/{len(self.dataloader_val)}, "
                        f"(loss: {loss.detach().cpu().item():.3f}), "
                        f"{[(k, round(v.cpu().item(), 3)) for k, v in stats.items()]}, "
                        f"{speed_stats}, "
                    )
                    pbar.set_description(description)
                    if self.writer:
                        self.writer.add_scalar(f"rank{self.local_rank}_Loss/val", loss.item(),
                                               epoch * len(self.dataloader_val) + batch_idx)
                        for key, var in stats.items():
                            self.writer.add_scalar(f'rank{self.local_rank}_{key}/val', var.item(),
                                                   epoch * len(self.dataloader_val) + batch_idx)
                        for key, var in speed_stats.items():
                            self.writer.add_scalar(f'rank{self.local_rank}_{key}/val', float(var),
                                                   epoch * len(self.dataloader_val) + batch_idx)

            pbar.close()
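

# ---------------------------------------------------------------------------
# Example usage (illustrative sketch only, not part of the original module).
# In FunASR the model, optimizer, scheduler, and dataloaders are built by the
# training entry point; `ToyModel` and `ToyDataset` below are hypothetical
# stand-ins that only satisfy the contract this Trainer relies on, namely that
# model(**batch) returns a (loss, stats, weight) tuple. A CUDA device is
# assumed, since the logging path queries torch.cuda memory statistics, and
# the final n-best averaging step depends on FunASR's average_checkpoints
# helper, whose behavior on this toy setup is not guaranteed.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import torch.nn as nn
    from torch.utils.data import DataLoader, Dataset

    class ToyDataset(Dataset):
        # Eight random samples, each a dict whose keys match ToyModel.forward.
        def __len__(self):
            return 8

        def __getitem__(self, idx):
            return {"x": torch.randn(4), "y": torch.randn(1)}

    class ToyModel(nn.Module):
        # Minimal model returning the (loss, stats, weight) tuple the Trainer expects.
        def __init__(self):
            super().__init__()
            self.linear = nn.Linear(4, 1)

        def forward(self, x, y):
            loss = ((self.linear(x) - y) ** 2).mean()
            stats = {"mse": loss.detach()}
            weight = torch.tensor(float(x.shape[0]), device=x.device)
            return loss, stats, weight

    model = ToyModel().cuda()
    optim = torch.optim.AdamW(model.parameters(), lr=1e-3)
    scheduler = torch.optim.lr_scheduler.StepLR(optim, step_size=1)
    loader = DataLoader(ToyDataset(), batch_size=4)

    trainer = Trainer(
        model=model,
        optim=optim,
        scheduler=scheduler,
        dataloader_train=loader,
        dataloader_val=loader,
        local_rank=0,
        output_dir="./exp_toy",
        resume=False,   # start from scratch instead of looking for model.pt
        max_epoch=0,    # the epoch loop is inclusive, so this runs one epoch
    )
    trainer.run()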