model_summary.py 2.4 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970
  1. import humanfriendly
  2. import numpy as np
  3. import torch
  4. def get_human_readable_count(number: int) -> str:
  5. """Return human_readable_count
  6. Originated from:
  7. https://github.com/PyTorchLightning/pytorch-lightning/blob/master/pytorch_lightning/core/memory.py
  8. Abbreviates an integer number with K, M, B, T for thousands, millions,
  9. billions and trillions, respectively.
  10. Examples:
  11. >>> get_human_readable_count(123)
  12. '123 '
  13. >>> get_human_readable_count(1234) # (one thousand)
  14. '1 K'
  15. >>> get_human_readable_count(2e6) # (two million)
  16. '2 M'
  17. >>> get_human_readable_count(3e9) # (three billion)
  18. '3 B'
  19. >>> get_human_readable_count(4e12) # (four trillion)
  20. '4 T'
  21. >>> get_human_readable_count(5e15) # (more than trillion)
  22. '5,000 T'
  23. Args:
  24. number: a positive integer number
  25. Return:
  26. A string formatted according to the pattern described above.
  27. """
  28. assert number >= 0
  29. labels = [" ", "K", "M", "B", "T"]
  30. num_digits = int(np.floor(np.log10(number)) + 1 if number > 0 else 1)
  31. num_groups = int(np.ceil(num_digits / 3))
  32. num_groups = min(num_groups, len(labels)) # don't abbreviate beyond trillions
  33. shift = -3 * (num_groups - 1)
  34. number = number * (10**shift)
  35. index = num_groups - 1
  36. return f"{number:.2f} {labels[index]}"
  37. def to_bytes(dtype) -> int:
  38. # torch.float16 -> 16
  39. return int(str(dtype)[-2:]) // 8
  40. def model_summary(model: torch.nn.Module) -> str:
  41. message = "Model structure:\n"
  42. message += str(model)
  43. tot_params = sum(p.numel() for p in model.parameters())
  44. num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
  45. percent_trainable = "{:.1f}".format(num_params * 100.0 / tot_params)
  46. tot_params = get_human_readable_count(tot_params)
  47. num_params = get_human_readable_count(num_params)
  48. message += "\n\nModel summary:\n"
  49. message += f" Class Name: {model.__class__.__name__}\n"
  50. message += f" Total Number of model parameters: {tot_params}\n"
  51. message += (
  52. f" Number of trainable parameters: {num_params} ({percent_trainable}%)\n"
  53. )
  54. num_bytes = humanfriendly.format_size(
  55. sum(
  56. p.numel() * to_bytes(p.dtype) for p in model.parameters() if p.requires_grad
  57. )
  58. )
  59. message += f" Size: {num_bytes}\n"
  60. dtype = next(iter(model.parameters())).dtype
  61. message += f" Type: {dtype}"
  62. return message