# average_nbest_models.py
import copy
import logging
import os
import warnings
from io import BytesIO
from pathlib import Path
from typing import Collection
from typing import Optional
from typing import Sequence
from typing import Union

import torch
from typeguard import check_argument_types

from funasr.train.reporter import Reporter
  13. @torch.no_grad()
  14. def average_nbest_models(
  15. output_dir: Path,
  16. reporter: Reporter,
  17. best_model_criterion: Sequence[Sequence[str]],
  18. nbest: Union[Collection[int], int],
  19. suffix: Optional[str] = None,
  20. oss_bucket=None,
  21. pai_output_dir=None,
  22. ) -> None:
  23. """Generate averaged model from n-best models
  24. Args:
  25. output_dir: The directory contains the model file for each epoch
  26. reporter: Reporter instance
  27. best_model_criterion: Give criterions to decide the best model.
  28. e.g. [("valid", "loss", "min"), ("train", "acc", "max")]
  29. nbest: Number of best model files to be averaged
  30. suffix: A suffix added to the averaged model file name
  31. """
  32. assert check_argument_types()
  33. if isinstance(nbest, int):
  34. nbests = [nbest]
  35. else:
  36. nbests = list(nbest)
  37. if len(nbests) == 0:
  38. warnings.warn("At least 1 nbest values are required")
  39. nbests = [1]
  40. if suffix is not None:
  41. suffix = suffix + "."
  42. else:
  43. suffix = ""
  44. # 1. Get nbests: List[Tuple[str, str, List[Tuple[epoch, value]]]]
  45. nbest_epochs = [
  46. (ph, k, reporter.sort_epochs_and_values(ph, k, m)[: max(nbests)])
  47. for ph, k, m in best_model_criterion
  48. if reporter.has(ph, k)
  49. ]
  50. _loaded = {}
  51. for ph, cr, epoch_and_values in nbest_epochs:
  52. _nbests = [i for i in nbests if i <= len(epoch_and_values)]
  53. if len(_nbests) == 0:
  54. _nbests = [1]
  55. for n in _nbests:
  56. if n == 0:
  57. continue
  58. elif n == 1:
  59. # The averaged model is same as the best model
  60. e, _ = epoch_and_values[0]
  61. op = output_dir / f"{e}epoch.pb"
  62. sym_op = output_dir / f"{ph}.{cr}.ave_1best.{suffix}pb"
  63. if sym_op.is_symlink() or sym_op.exists():
  64. sym_op.unlink()
  65. sym_op.symlink_to(op.name)
  66. else:
  67. op = output_dir / f"{ph}.{cr}.ave_{n}best.{suffix}pb"
  68. logging.info(
  69. f"Averaging {n}best models: " f'criterion="{ph}.{cr}": {op}'
  70. )
  71. avg = None
  72. # 2.a. Averaging model
  73. for e, _ in epoch_and_values[:n]:
  74. if e not in _loaded:
  75. if oss_bucket is None:
  76. _loaded[e] = torch.load(
  77. output_dir / f"{e}epoch.pb",
  78. map_location="cpu",
  79. )
  80. else:
  81. buffer = BytesIO(
  82. oss_bucket.get_object(os.path.join(pai_output_dir, f"{e}epoch.pb")).read())
  83. _loaded[e] = torch.load(buffer)
  84. states = _loaded[e]
  85. if avg is None:
  86. avg = states
  87. else:
  88. # Accumulated
  89. for k in avg:
  90. avg[k] = avg[k] + states[k]
  91. for k in avg:
  92. if str(avg[k].dtype).startswith("torch.int"):
  93. # For int type, not averaged, but only accumulated.
  94. # e.g. BatchNorm.num_batches_tracked
  95. # (If there are any cases that requires averaging
  96. # or the other reducing method, e.g. max/min, for integer type,
  97. # please report.)
  98. pass
  99. else:
  100. avg[k] = avg[k] / n
  101. # 2.b. Save the ave model and create a symlink
  102. if oss_bucket is None:
  103. torch.save(avg, op)
  104. else:
  105. buffer = BytesIO()
  106. torch.save(avg, buffer)
  107. oss_bucket.put_object(os.path.join(pai_output_dir, f"{ph}.{cr}.ave_{n}best.{suffix}pb"),
  108. buffer.getvalue())
  109. # 3. *.*.ave.pb is a symlink to the max ave model
  110. if oss_bucket is None:
  111. op = output_dir / f"{ph}.{cr}.ave_{max(_nbests)}best.{suffix}pb"
  112. sym_op = output_dir / f"{ph}.{cr}.ave.{suffix}pb"
  113. if sym_op.is_symlink() or sym_op.exists():
  114. sym_op.unlink()
  115. sym_op.symlink_to(op.name)