import os
import time
import torch
import logging
from tqdm import tqdm
from datetime import datetime
import torch.distributed as dist
from torch.cuda.amp import autocast, GradScaler
from contextlib import nullcontext, contextmanager
# from torch.utils.tensorboard import SummaryWriter
from tensorboardX import SummaryWriter
from pathlib import Path
from funasr.train_utils.device_funcs import to_device
from funasr.train_utils.recursive_op import recursive_average
from funasr.train_utils.average_nbest_models import average_checkpoints
from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler


@contextmanager
def maybe_autocast(enabled):
    """Run the enclosed block under torch.cuda.amp.autocast when `enabled` is True, otherwise run it unchanged."""
    if enabled:
        with autocast():
            yield
    else:
        yield


class Trainer:
    """
    A simple trainer class for training a PyTorch model, saving checkpoints at the end of each epoch,
    and optionally resuming from a saved checkpoint.

    Attributes:
        max_epoch (int): Maximum number of epochs for training.
        model (torch.nn.Module): The model to be trained.
        optim (torch.optim.Optimizer): The optimizer to use for training.
        scheduler (torch.optim.lr_scheduler._LRScheduler): The learning rate scheduler.
        dataloader_train (torch.utils.data.DataLoader): DataLoader for the training dataset.
        dataloader_val (torch.utils.data.DataLoader): DataLoader for the validation dataset.
        output_dir (str): Directory where model checkpoints will be saved.
        resume (bool): Whether to resume training from the latest checkpoint ("model.pt") in output_dir.
    """

    def __init__(self, model,
                 optim,
                 scheduler,
                 dataloader_train,
                 dataloader_val,
                 local_rank,
                 use_ddp: bool = False,
                 use_fsdp: bool = False,
                 use_fp16: bool = False,
                 output_dir: str = "./",
                 **kwargs):
        """
        Initializes the Trainer class with the model, optimizer, scheduler, dataloaders, and other settings.

        Args:
            model (torch.nn.Module): The model to be trained.
            optim (torch.optim.Optimizer): The optimizer to use for training.
            scheduler (torch.optim.lr_scheduler._LRScheduler): The learning rate scheduler.
            dataloader_train (torch.utils.data.DataLoader): The DataLoader for the training dataset.
            dataloader_val (torch.utils.data.DataLoader): The DataLoader for the validation dataset.
            local_rank (int): The local device rank of this process.
            use_ddp (bool): Whether the model is wrapped with DistributedDataParallel.
            use_fsdp (bool): Whether the model is wrapped with FullyShardedDataParallel.
            use_fp16 (bool): Whether to train with mixed precision (autocast + gradient scaling).
            output_dir (str): The directory where model checkpoints and logs will be saved. Default is './'.
            **kwargs: Additional keyword arguments:
                max_epoch (int): The maximum number of epochs for training. Default is 100.
                resume (bool): Whether to resume training from the latest checkpoint in output_dir. Default is True.
        """
        self.model = model
        self.optim = optim
        self.scheduler = scheduler
        self.dataloader_train = dataloader_train
        self.dataloader_val = dataloader_val
        self.output_dir = output_dir
        self.resume = kwargs.get('resume', True)
        self.start_epoch = 0
        self.max_epoch = kwargs.get('max_epoch', 100)
        self.local_rank = local_rank
        self.use_ddp = use_ddp
        self.use_fsdp = use_fsdp
        self.device = next(model.parameters()).device
        self.avg_nbest_model = kwargs.get("avg_nbest_model", 5)
        self.kwargs = kwargs
        self.log_interval = kwargs.get("log_interval", 50)
        self.batch_total = 0
        self.use_fp16 = use_fp16
        self.disable_gpu_cache = kwargs.get("disable_gpu_cache", True)
        scaler = GradScaler(enabled=use_fp16) if use_fp16 else None
        # FSDP shards parameters and gradients across ranks, so it needs the sharded variant of the scaler.
        scaler = ShardedGradScaler(enabled=use_fp16) if use_fsdp else scaler
        self.scaler = scaler
        try:
            rank = dist.get_rank()
            world_size = dist.get_world_size()
        except Exception:
            rank = 0
            world_size = 1
            logging.warning("distributed is not initialized, only single shard")
        self.rank = rank
        self.world_size = world_size

        os.makedirs(os.path.join(self.output_dir, "tensorboard"), exist_ok=True)
        self.writer = SummaryWriter(os.path.join(self.output_dir, "tensorboard")) if rank == 0 else None

    def _save_checkpoint(self, epoch):
        """
        Saves a checkpoint containing the model's state, the optimizer's state,
        and the scheduler's state at the end of the given epoch. This method is
        intended to be called at the end of each epoch to save the training progress.

        Args:
            epoch (int): The epoch number at which the checkpoint is being saved.
        """
        state = {
            'epoch': epoch,
            'state_dict': self.model.state_dict(),
            'optimizer': self.optim.state_dict(),
            'scheduler': self.scheduler.state_dict(),
        }
        if self.scaler:
            state["scaler_state"] = self.scaler.state_dict()
        # Create output directory if it does not exist
        os.makedirs(self.output_dir, exist_ok=True)
        filename = os.path.join(self.output_dir, f'model.pt.ep{epoch}')
        torch.save(state, filename)
        print(f'\nCheckpoint saved to {filename}\n')
        # Also keep an unversioned copy of the latest checkpoint, used for resuming.
        latest = Path(os.path.join(self.output_dir, 'model.pt'))
        torch.save(state, latest)

    def _resume_checkpoint(self, resume_path):
        """
        Resumes training from a checkpoint in the given directory.
        Loads the model's state, the optimizer's state, and the scheduler's state.

        Args:
            resume_path (str): The directory containing the 'model.pt' checkpoint to resume from.
        """
        ckpt = os.path.join(resume_path, "model.pt")
        if os.path.isfile(ckpt):
            checkpoint = torch.load(ckpt)
            self.start_epoch = checkpoint['epoch'] + 1
            # Remap keys so that checkpoints saved with or without the DDP "module." prefix
            # can be loaded into the current (wrapped or unwrapped) model.
            src_state = checkpoint['state_dict']
            dst_state = self.model.state_dict()
            for k in dst_state.keys():
                if not k.startswith("module.") and "module." + k in src_state.keys():
                    k_ddp = "module." + k
                else:
                    k_ddp = k
                if k_ddp in src_state.keys():
                    dst_state[k] = src_state[k_ddp]
                else:
                    print(f"Missing key in ckpt: model: {k}, ckpt: {k_ddp}")
            self.model.load_state_dict(dst_state)
            self.optim.load_state_dict(checkpoint['optimizer'])
            self.scheduler.load_state_dict(checkpoint['scheduler'])
            if self.scaler and 'scaler_state' in checkpoint:
                self.scaler.load_state_dict(checkpoint['scaler_state'])
            print(f"Checkpoint loaded successfully from '{ckpt}'")
        else:
            print(f"No checkpoint found at '{ckpt}', starting from scratch")

        if self.use_ddp or self.use_fsdp:
            dist.barrier()

    def run(self):
        """
        Starts the training process, iterating over epochs, training the model,
        and saving checkpoints at the end of each epoch.
        """
        if self.resume:
            self._resume_checkpoint(self.output_dir)

        for epoch in range(self.start_epoch, self.max_epoch + 1):
            time1 = time.perf_counter()
            self._train_epoch(epoch)
            if self.use_ddp or self.use_fsdp:
                dist.barrier()

            self._validate_epoch(epoch)
            if self.use_ddp or self.use_fsdp:
                dist.barrier()

            if self.rank == 0:
                self._save_checkpoint(epoch)
            if self.use_ddp or self.use_fsdp:
                dist.barrier()

            self.scheduler.step()

            time2 = time.perf_counter()
            time_elapsed = (time2 - time1) / 3600.0
            print(
                f"\nrank: {self.local_rank}, "
                f"time_elapsed_epoch: {time_elapsed:.3f} hours, "
                f"estimated time to finish {self.max_epoch} epochs: {(self.max_epoch - epoch) * time_elapsed:.3f} hours\n"
            )

        if self.rank == 0:
            # Average the n best checkpoints into a single model after training.
            average_checkpoints(self.output_dir, self.avg_nbest_model)

        if self.use_ddp or self.use_fsdp:
            dist.barrier()

        if self.writer:
            self.writer.close()

    def _train_epoch(self, epoch):
        """
        Defines the training process for a single epoch with gradient accumulation.

        Args:
            epoch (int): The current epoch number.
        """
        self.model.train()
        pbar = tqdm(colour="blue",
                    desc=f"rank: {self.local_rank}, Training Epoch: {epoch + 1}",
                    total=len(self.dataloader_train),
                    dynamic_ncols=True)
        # Set the number of steps for gradient accumulation
        accum_grad = self.kwargs.get("accum_grad", 1)
        # Initialize the gradient accumulation
        self.optim.zero_grad()
        speed_stats = {}
        time5 = time.perf_counter()
        for batch_idx, batch in enumerate(self.dataloader_train):
            self.batch_total += 1
            time1 = time.perf_counter()
            speed_stats["data_load"] = f"{time1 - time5:0.3f}"
            batch = to_device(batch, self.device)
            # Skip gradient synchronization on accumulation micro-batches; sync only on the
            # micro-batch where the optimizer step runs (no_sync() is only available under DDP).
            my_context = (
                self.model.no_sync
                if self.use_ddp and (batch_idx + 1) % accum_grad != 0
                else nullcontext
            )
            with my_context():
                time2 = time.perf_counter()
                with maybe_autocast(self.use_fp16):
                    retval = self.model(**batch)
                if self.disable_gpu_cache:
                    torch.cuda.empty_cache()
                time3 = time.perf_counter()
                speed_stats["forward_time"] = f"{time3 - time2:0.3f}"
                loss, stats, weight = retval
                stats = {k: v for k, v in stats.items() if v is not None}
                if self.use_ddp or self.use_fsdp:
                    # Apply weighted averaging for loss and stats
                    loss = (loss * weight.type(loss.dtype)).sum()
                    # if distributed, this method can also apply all_reduce()
                    stats, weight = recursive_average(stats, weight, distributed=True)
                    # Now weight is summation over all workers
                    loss /= weight
                    # Multiply world_size because DistributedDataParallel
                    # automatically normalizes the gradient by world_size.
                    loss *= self.world_size
                # Scale the loss since we're not updating for every mini-batch
                loss = loss / accum_grad
                if self.use_fp16:
                    self.scaler.scale(loss).backward()
                else:
                    loss.backward()
                time4 = time.perf_counter()
                speed_stats["backward_time"] = f"{time4 - time3:0.3f}"

            # Perform an optimizer step only after accumulating enough gradients
            if (batch_idx + 1) % accum_grad == 0 or (batch_idx + 1) == len(self.dataloader_train):
                # Perform gradient clipping if it is set
                if self.kwargs.get("grad_clip", None) is not None:
                    if self.use_fp16:
                        # Unscale the gradients first so clipping operates on the true gradient values.
                        self.scaler.unscale_(self.optim)
                    grad_norm = torch.nn.utils.clip_grad_norm_(
                        self.model.parameters(),
                        max_norm=self.kwargs.get("grad_clip", 10.0),
                        norm_type=self.kwargs.get("grad_clip_type", 2.0),
                    )
                    if not torch.isfinite(grad_norm):
                        logging.warning(
                            f"The grad norm is {grad_norm}. Skipping updating the model."
                        )
                        self.optim.zero_grad()  # Reset gradients
                        continue
                # Execute an optimization step (update model parameters)
                if self.use_ddp or self.use_fsdp:
                    dist.barrier()
                if self.use_fp16:
                    self.scaler.step(self.optim)
                    self.scaler.update()
                else:
                    self.optim.step()
                self.scheduler.step()
                # Clear gradients for the next accumulation stage
                self.optim.zero_grad(set_to_none=True)

            total_time = f"{time.perf_counter() - time5:0.3f}"
            time5 = time.perf_counter()
            speed_stats["optim_time"] = f"{time5 - time4:0.3f}"
            speed_stats["total_time"] = total_time
            if (batch_idx + 1) % self.log_interval == 0 or (batch_idx + 1) == len(self.dataloader_train):
                pbar.update(self.log_interval)
                gpu_info = "GPU, memory: {:.3f} GB, " \
                           "{:.3f} GB, " \
                           "{:.3f} GB, " \
                           "{:.3f} GB".format(torch.cuda.memory_allocated() / 1024 / 1024 / 1024,
                                              torch.cuda.max_memory_allocated() / 1024 / 1024 / 1024,
                                              torch.cuda.memory_reserved() / 1024 / 1024 / 1024,
                                              torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024,
                                              )
                lr = self.scheduler.get_last_lr()[0]
                time_now = datetime.now()
                time_now = time_now.strftime("%Y-%m-%d %H:%M:%S")
                description = (
                    f"{time_now}, "
                    f"rank: {self.local_rank}, "
                    f"epoch: {epoch}/{self.max_epoch}, "
                    f"step: {batch_idx + 1}/{len(self.dataloader_train)}, total step: {self.batch_total}, "
                    f"(loss: {loss.detach().cpu().item():.3f}), "
                    f"(lr: {lr:.3e}), "
                    f"{[(k, round(v.cpu().item(), 3)) for k, v in stats.items()]}, "
                    f"{speed_stats}, "
                    f"{gpu_info}"
                )
                pbar.set_description(description)
                if self.writer:
                    self.writer.add_scalar(f'rank{self.local_rank}_Loss/train', loss.item(), self.batch_total)
                    self.writer.add_scalar(f'rank{self.local_rank}_lr/train', lr, self.batch_total)
                    for key, var in stats.items():
                        self.writer.add_scalar(f'rank{self.local_rank}_{key}/train', var.item(), self.batch_total)
                    for key, var in speed_stats.items():
                        # speed_stats holds formatted strings; convert back to float instead of using eval()
                        self.writer.add_scalar(f'rank{self.local_rank}_{key}/train', float(var), self.batch_total)

        pbar.close()

    def _validate_epoch(self, epoch):
        """
        Runs validation for a single epoch and logs the validation loss and stats.

        Args:
            epoch (int): The current epoch number.
        """
        self.model.eval()
        with torch.no_grad():
            pbar = tqdm(colour="red",
                        desc=f"rank: {self.local_rank}, Validation Epoch: {epoch + 1}",
                        total=len(self.dataloader_val),
                        dynamic_ncols=True)
            speed_stats = {}
            time5 = time.perf_counter()
            for batch_idx, batch in enumerate(self.dataloader_val):
                time1 = time.perf_counter()
                speed_stats["data_load"] = f"{time1 - time5:0.3f}"
                batch = to_device(batch, self.device)
                time2 = time.perf_counter()
                retval = self.model(**batch)
                time3 = time.perf_counter()
                speed_stats["forward_time"] = f"{time3 - time2:0.3f}"
                loss, stats, weight = retval
                stats = {k: v for k, v in stats.items() if v is not None}
                if self.use_ddp or self.use_fsdp:
                    # Apply weighted averaging for loss and stats
                    loss = (loss * weight.type(loss.dtype)).sum()
                    # if distributed, this method can also apply all_reduce()
                    stats, weight = recursive_average(stats, weight, distributed=True)
                    # Now weight is summation over all workers
                    loss /= weight
                    loss *= self.world_size
                time4 = time.perf_counter()
                if (batch_idx + 1) % self.log_interval == 0 or (batch_idx + 1) == len(self.dataloader_val):
                    pbar.update(self.log_interval)
                    time_now = datetime.now()
                    time_now = time_now.strftime("%Y-%m-%d %H:%M:%S")
                    description = (
                        f"{time_now}, "
                        f"rank: {self.local_rank}, "
                        f"validation epoch: {epoch}/{self.max_epoch}, "
                        f"step: {batch_idx + 1}/{len(self.dataloader_val)}, "
                        f"(loss: {loss.detach().cpu().item():.3f}), "
                        f"{[(k, round(v.cpu().item(), 3)) for k, v in stats.items()]}, "
                        f"{speed_stats}, "
                    )
                    pbar.set_description(description)
                    if self.writer:
                        self.writer.add_scalar(f"rank{self.local_rank}_Loss/val", loss.item(),
                                               epoch * len(self.dataloader_val) + batch_idx)
                        for key, var in stats.items():
                            self.writer.add_scalar(f'rank{self.local_rank}_{key}/val', var.item(),
                                                   epoch * len(self.dataloader_val) + batch_idx)
                        for key, var in speed_stats.items():
                            # speed_stats holds formatted strings; convert back to float instead of using eval()
                            self.writer.add_scalar(f'rank{self.local_rank}_{key}/val', float(var),
                                                   epoch * len(self.dataloader_val) + batch_idx)
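

# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the original trainer): a minimal,
# hypothetical example of how this Trainer can be wired up. The toy model,
# dataset, optimizer and scheduler below are assumptions for demonstration
# only; in real training these objects come from the surrounding setup.
# The key contracts shown here: the model's forward() must return a tuple
# (loss, stats, weight), and each batch must be a dict of tensors, because
# _train_epoch() calls `self.model(**batch)`.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import torch.nn as nn
    from torch.utils.data import DataLoader, TensorDataset

    class ToyModel(nn.Module):
        """Hypothetical model following the (loss, stats, weight) return contract."""

        def __init__(self):
            super().__init__()
            self.linear = nn.Linear(16, 4)

        def forward(self, speech, labels):
            logits = self.linear(speech)
            loss = nn.functional.cross_entropy(logits, labels)
            acc = (logits.argmax(dim=-1) == labels).float().mean()
            # `weight` is the batch size, used for weighted averaging in DDP/FSDP mode.
            weight = torch.tensor([speech.size(0)], dtype=torch.float32)
            return loss, {"acc": acc}, weight

    def collate(samples):
        # Batches are dicts so they can be unpacked as keyword arguments of forward().
        xs = torch.stack([x for x, _ in samples])
        ys = torch.stack([y for _, y in samples])
        return {"speech": xs, "labels": ys}

    dataset = TensorDataset(torch.randn(64, 16), torch.randint(0, 4, (64,)))
    dataloader_train = DataLoader(dataset, batch_size=8, collate_fn=collate)
    dataloader_val = DataLoader(dataset, batch_size=8, collate_fn=collate)

    model = ToyModel()
    optim = torch.optim.AdamW(model.parameters(), lr=1e-3)
    scheduler = torch.optim.lr_scheduler.StepLR(optim, step_size=10, gamma=0.9)

    trainer = Trainer(
        model=model,
        optim=optim,
        scheduler=scheduler,
        dataloader_train=dataloader_train,
        dataloader_val=dataloader_val,
        local_rank=0,
        use_ddp=False,
        use_fsdp=False,
        use_fp16=False,
        output_dir="./exp_toy",
        max_epoch=1,
        log_interval=4,
        resume=False,
    )
    # trainer.run() would launch the full train/validate/checkpoint loop; it is left
    # commented out here because it writes checkpoints and TensorBoard logs under
    # output_dir and finishes by calling FunASR's checkpoint-averaging utility.
    # trainer.run()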