model_summary.py 2.2 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364
  1. import numpy as np
  2. import torch
  3. def get_human_readable_count(number: int) -> str:
  4. """Return human_readable_count
  5. Originated from:
  6. https://github.com/PyTorchLightning/pytorch-lightning/blob/master/pytorch_lightning/core/memory.py
  7. Abbreviates an integer number with K, M, B, T for thousands, millions,
  8. billions and trillions, respectively.
  9. Examples:
  10. >>> get_human_readable_count(123)
  11. '123 '
  12. >>> get_human_readable_count(1234) # (one thousand)
  13. '1 K'
  14. >>> get_human_readable_count(2e6) # (two million)
  15. '2 M'
  16. >>> get_human_readable_count(3e9) # (three billion)
  17. '3 B'
  18. >>> get_human_readable_count(4e12) # (four trillion)
  19. '4 T'
  20. >>> get_human_readable_count(5e15) # (more than trillion)
  21. '5,000 T'
  22. Args:
  23. number: a positive integer number
  24. Return:
  25. A string formatted according to the pattern described above.
  26. """
  27. assert number >= 0
  28. labels = [" ", "K", "M", "B", "T"]
  29. num_digits = int(np.floor(np.log10(number)) + 1 if number > 0 else 1)
  30. num_groups = int(np.ceil(num_digits / 3))
  31. num_groups = min(num_groups, len(labels)) # don't abbreviate beyond trillions
  32. shift = -3 * (num_groups - 1)
  33. number = number * (10**shift)
  34. index = num_groups - 1
  35. return f"{number:.2f} {labels[index]}"
  36. def to_bytes(dtype) -> int:
  37. # torch.float16 -> 16
  38. return int(str(dtype)[-2:]) // 8
  39. def model_summary(model: torch.nn.Module) -> str:
  40. message = "Model structure:\n"
  41. message += str(model)
  42. tot_params = sum(p.numel() for p in model.parameters())
  43. num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
  44. percent_trainable = "{:.1f}".format(num_params * 100.0 / tot_params)
  45. tot_params = get_human_readable_count(tot_params)
  46. num_params = get_human_readable_count(num_params)
  47. message += "\n\nModel summary:\n"
  48. message += f" Class Name: {model.__class__.__name__}\n"
  49. message += f" Total Number of model parameters: {tot_params}\n"
  50. message += (
  51. f" Number of trainable parameters: {num_params} ({percent_trainable}%)\n"
  52. )
  53. dtype = next(iter(model.parameters())).dtype
  54. message += f" Type: {dtype}"
  55. return message