# average_nbest_models.py
  1. import logging
  2. from pathlib import Path
  3. from typing import Optional
  4. from typing import Sequence
  5. from typing import Union
  6. import warnings
  7. import os
  8. from io import BytesIO
  9. import torch
  10. from typing import Collection
  11. import os
  12. import torch
  13. import re
  14. from collections import OrderedDict
  15. from functools import cmp_to_key
  16. # @torch.no_grad()
  17. # def average_nbest_models(
  18. # output_dir: Path,
  19. # best_model_criterion: Sequence[Sequence[str]],
  20. # nbest: Union[Collection[int], int],
  21. # suffix: Optional[str] = None,
  22. # oss_bucket=None,
  23. # pai_output_dir=None,
  24. # ) -> None:
  25. # """Generate averaged model from n-best models
  26. #
  27. # Args:
  28. # output_dir: The directory contains the model file for each epoch
  29. # reporter: Reporter instance
  30. # best_model_criterion: Give criterions to decide the best model.
  31. # e.g. [("valid", "loss", "min"), ("train", "acc", "max")]
  32. # nbest: Number of best model files to be averaged
  33. # suffix: A suffix added to the averaged model file name
  34. # """
  35. # if isinstance(nbest, int):
  36. # nbests = [nbest]
  37. # else:
  38. # nbests = list(nbest)
  39. # if len(nbests) == 0:
  40. # warnings.warn("At least 1 nbest values are required")
  41. # nbests = [1]
  42. # if suffix is not None:
  43. # suffix = suffix + "."
  44. # else:
  45. # suffix = ""
  46. #
  47. # # 1. Get nbests: List[Tuple[str, str, List[Tuple[epoch, value]]]]
  48. # nbest_epochs = [
  49. # (ph, k, reporter.sort_epochs_and_values(ph, k, m)[: max(nbests)])
  50. # for ph, k, m in best_model_criterion
  51. # if reporter.has(ph, k)
  52. # ]
  53. #
  54. # _loaded = {}
  55. # for ph, cr, epoch_and_values in nbest_epochs:
  56. # _nbests = [i for i in nbests if i <= len(epoch_and_values)]
  57. # if len(_nbests) == 0:
  58. # _nbests = [1]
  59. #
  60. # for n in _nbests:
  61. # if n == 0:
  62. # continue
  63. # elif n == 1:
  64. # # The averaged model is same as the best model
  65. # e, _ = epoch_and_values[0]
  66. # op = output_dir / f"{e}epoch.pb"
  67. # sym_op = output_dir / f"{ph}.{cr}.ave_1best.{suffix}pb"
  68. # if sym_op.is_symlink() or sym_op.exists():
  69. # sym_op.unlink()
  70. # sym_op.symlink_to(op.name)
  71. # else:
  72. # op = output_dir / f"{ph}.{cr}.ave_{n}best.{suffix}pb"
  73. # logging.info(
  74. # f"Averaging {n}best models: " f'criterion="{ph}.{cr}": {op}'
  75. # )
  76. #
  77. # avg = None
  78. # # 2.a. Averaging model
  79. # for e, _ in epoch_and_values[:n]:
  80. # if e not in _loaded:
  81. # if oss_bucket is None:
  82. # _loaded[e] = torch.load(
  83. # output_dir / f"{e}epoch.pb",
  84. # map_location="cpu",
  85. # )
  86. # else:
  87. # buffer = BytesIO(
  88. # oss_bucket.get_object(os.path.join(pai_output_dir, f"{e}epoch.pb")).read())
  89. # _loaded[e] = torch.load(buffer)
  90. # states = _loaded[e]
  91. #
  92. # if avg is None:
  93. # avg = states
  94. # else:
  95. # # Accumulated
  96. # for k in avg:
  97. # avg[k] = avg[k] + states[k]
  98. # for k in avg:
  99. # if str(avg[k].dtype).startswith("torch.int"):
  100. # # For int type, not averaged, but only accumulated.
  101. # # e.g. BatchNorm.num_batches_tracked
  102. # # (If there are any cases that requires averaging
  103. # # or the other reducing method, e.g. max/min, for integer type,
  104. # # please report.)
  105. # pass
  106. # else:
  107. # avg[k] = avg[k] / n
  108. #
  109. # # 2.b. Save the ave model and create a symlink
  110. # if oss_bucket is None:
  111. # torch.save(avg, op)
  112. # else:
  113. # buffer = BytesIO()
  114. # torch.save(avg, buffer)
  115. # oss_bucket.put_object(os.path.join(pai_output_dir, f"{ph}.{cr}.ave_{n}best.{suffix}pb"),
  116. # buffer.getvalue())
  117. #
  118. # # 3. *.*.ave.pb is a symlink to the max ave model
  119. # if oss_bucket is None:
  120. # op = output_dir / f"{ph}.{cr}.ave_{max(_nbests)}best.{suffix}pb"
  121. # sym_op = output_dir / f"{ph}.{cr}.ave.{suffix}pb"
  122. # if sym_op.is_symlink() or sym_op.exists():
  123. # sym_op.unlink()
  124. # sym_op.symlink_to(op.name)
  125. def _get_checkpoint_paths(output_dir: str, last_n: int=5):
  126. """
  127. Get the paths of the last 'last_n' checkpoints by parsing filenames
  128. in the output directory.
  129. """
  130. # List all files in the output directory
  131. files = os.listdir(output_dir)
  132. # Filter out checkpoint files and extract epoch numbers
  133. checkpoint_files = [f for f in files if f.startswith("model.pt.e")]
  134. # Sort files by epoch number in descending order
  135. checkpoint_files.sort(key=lambda x: int(re.search(r'(\d+)', x).group()), reverse=True)
  136. # Get the last 'last_n' checkpoint paths
  137. checkpoint_paths = [os.path.join(output_dir, f) for f in checkpoint_files[:last_n]]
  138. return checkpoint_paths
  139. @torch.no_grad()
  140. def average_checkpoints(output_dir: str, last_n: int=5):
  141. """
  142. Average the last 'last_n' checkpoints' model state_dicts.
  143. If a tensor is of type torch.int, perform sum instead of average.
  144. """
  145. checkpoint_paths = _get_checkpoint_paths(output_dir, last_n)
  146. state_dicts = []
  147. # Load state_dicts from checkpoints
  148. for path in checkpoint_paths:
  149. if os.path.isfile(path):
  150. state_dicts.append(torch.load(path, map_location='cpu')['state_dict'])
  151. else:
  152. print(f"Checkpoint file {path} not found.")
  153. continue
  154. # Check if we have any state_dicts to average
  155. if not state_dicts:
  156. raise RuntimeError("No checkpoints found for averaging.")
  157. # Average or sum weights
  158. avg_state_dict = OrderedDict()
  159. for key in state_dicts[0].keys():
  160. tensors = [state_dict[key].cpu() for state_dict in state_dicts]
  161. # Check the type of the tensor
  162. if str(tensors[0].dtype).startswith("torch.int"):
  163. # Perform sum for integer tensors
  164. summed_tensor = sum(tensors)
  165. avg_state_dict[key] = summed_tensor
  166. else:
  167. # Perform average for other types of tensors
  168. stacked_tensors = torch.stack(tensors)
  169. avg_state_dict[key] = torch.mean(stacked_tensors, dim=0)
  170. torch.save({'state_dict': avg_state_dict}, os.path.join(output_dir, f"model.pt.avg{last_n}"))
  171. return avg_state_dict