average_nbest_models.py 4.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125
  1. import logging
  2. from pathlib import Path
  3. from typing import Optional
  4. from typing import Sequence
  5. from typing import Union
  6. import warnings
  7. import os
  8. from io import BytesIO
  9. import torch
  10. from typing import Collection
  11. from funasr.train.reporter import Reporter
  12. @torch.no_grad()
  13. def average_nbest_models(
  14. output_dir: Path,
  15. reporter: Reporter,
  16. best_model_criterion: Sequence[Sequence[str]],
  17. nbest: Union[Collection[int], int],
  18. suffix: Optional[str] = None,
  19. oss_bucket=None,
  20. pai_output_dir=None,
  21. ) -> None:
  22. """Generate averaged model from n-best models
  23. Args:
  24. output_dir: The directory contains the model file for each epoch
  25. reporter: Reporter instance
  26. best_model_criterion: Give criterions to decide the best model.
  27. e.g. [("valid", "loss", "min"), ("train", "acc", "max")]
  28. nbest: Number of best model files to be averaged
  29. suffix: A suffix added to the averaged model file name
  30. """
  31. if isinstance(nbest, int):
  32. nbests = [nbest]
  33. else:
  34. nbests = list(nbest)
  35. if len(nbests) == 0:
  36. warnings.warn("At least 1 nbest values are required")
  37. nbests = [1]
  38. if suffix is not None:
  39. suffix = suffix + "."
  40. else:
  41. suffix = ""
  42. # 1. Get nbests: List[Tuple[str, str, List[Tuple[epoch, value]]]]
  43. nbest_epochs = [
  44. (ph, k, reporter.sort_epochs_and_values(ph, k, m)[: max(nbests)])
  45. for ph, k, m in best_model_criterion
  46. if reporter.has(ph, k)
  47. ]
  48. _loaded = {}
  49. for ph, cr, epoch_and_values in nbest_epochs:
  50. _nbests = [i for i in nbests if i <= len(epoch_and_values)]
  51. if len(_nbests) == 0:
  52. _nbests = [1]
  53. for n in _nbests:
  54. if n == 0:
  55. continue
  56. elif n == 1:
  57. # The averaged model is same as the best model
  58. e, _ = epoch_and_values[0]
  59. op = output_dir / f"{e}epoch.pb"
  60. sym_op = output_dir / f"{ph}.{cr}.ave_1best.{suffix}pb"
  61. if sym_op.is_symlink() or sym_op.exists():
  62. sym_op.unlink()
  63. sym_op.symlink_to(op.name)
  64. else:
  65. op = output_dir / f"{ph}.{cr}.ave_{n}best.{suffix}pb"
  66. logging.info(
  67. f"Averaging {n}best models: " f'criterion="{ph}.{cr}": {op}'
  68. )
  69. avg = None
  70. # 2.a. Averaging model
  71. for e, _ in epoch_and_values[:n]:
  72. if e not in _loaded:
  73. if oss_bucket is None:
  74. _loaded[e] = torch.load(
  75. output_dir / f"{e}epoch.pb",
  76. map_location="cpu",
  77. )
  78. else:
  79. buffer = BytesIO(
  80. oss_bucket.get_object(os.path.join(pai_output_dir, f"{e}epoch.pb")).read())
  81. _loaded[e] = torch.load(buffer)
  82. states = _loaded[e]
  83. if avg is None:
  84. avg = states
  85. else:
  86. # Accumulated
  87. for k in avg:
  88. avg[k] = avg[k] + states[k]
  89. for k in avg:
  90. if str(avg[k].dtype).startswith("torch.int"):
  91. # For int type, not averaged, but only accumulated.
  92. # e.g. BatchNorm.num_batches_tracked
  93. # (If there are any cases that requires averaging
  94. # or the other reducing method, e.g. max/min, for integer type,
  95. # please report.)
  96. pass
  97. else:
  98. avg[k] = avg[k] / n
  99. # 2.b. Save the ave model and create a symlink
  100. if oss_bucket is None:
  101. torch.save(avg, op)
  102. else:
  103. buffer = BytesIO()
  104. torch.save(avg, buffer)
  105. oss_bucket.put_object(os.path.join(pai_output_dir, f"{ph}.{cr}.ave_{n}best.{suffix}pb"),
  106. buffer.getvalue())
  107. # 3. *.*.ave.pb is a symlink to the max ave model
  108. if oss_bucket is None:
  109. op = output_dir / f"{ph}.{cr}.ave_{max(_nbests)}best.{suffix}pb"
  110. sym_op = output_dir / f"{ph}.{cr}.ave.{suffix}pb"
  111. if sym_op.is_symlink() or sym_op.exists():
  112. sym_op.unlink()
  113. sym_op.symlink_to(op.name)