import numpy as np
import torch


def get_human_readable_count(number: int) -> str:
    """Abbreviate an integer with K, M, B, T for thousands, millions,
    billions, and trillions, respectively.

    Originated from:
    https://github.com/PyTorchLightning/pytorch-lightning/blob/master/pytorch_lightning/core/memory.py

    Examples:
        >>> get_human_readable_count(123)
        '123  '
        >>> get_human_readable_count(1234)  # (one thousand)
        '1 K'
        >>> get_human_readable_count(2e6)  # (two million)
        '2 M'
        >>> get_human_readable_count(3e9)  # (three billion)
        '3 B'
        >>> get_human_readable_count(4e12)  # (four trillion)
        '4 T'
        >>> get_human_readable_count(5e15)  # (more than a trillion)
        '5,000 T'

    Args:
        number: a non-negative integer

    Returns:
        A string formatted according to the pattern described above.
    """
    assert number >= 0
    labels = [" ", "K", "M", "B", "T"]
    num_digits = int(np.floor(np.log10(number)) + 1 if number > 0 else 1)
    num_groups = int(np.ceil(num_digits / 3))
    num_groups = min(num_groups, len(labels))  # don't abbreviate beyond trillions
    shift = -3 * (num_groups - 1)
    number = number * (10**shift)
    index = num_groups - 1
    return f"{int(number):,d} {labels[index]}"


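# Worked example for get_human_readable_count(2_500_000):
#   num_digits = floor(log10(2.5e6)) + 1 = 7
#   num_groups = ceil(7 / 3) = 3    -> the millions group
#   shift      = -3 * (3 - 1) = -6  -> 2_500_000 * 1e-6 = 2.5
#   result     = "2 M"              (int(2.5) == 2, labels[2] == "M")

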
def to_bytes(dtype: torch.dtype) -> int:
    # Size in bytes of a single element, e.g. torch.float16 -> 2.
    # element_size() handles dtypes whose names don't end in a bit
    # count of two digits (torch.int8, torch.bool, ...), which the
    # original string-parsing approach would choke on.
    return torch.empty((), dtype=dtype).element_size()


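# A minimal sketch (not part of the original snippet) of how to_bytes
# could be combined with parameter counts to estimate the memory
# footprint of a model's weights; model_size_bytes is a hypothetical
# helper name.
def model_size_bytes(model: torch.nn.Module) -> int:
    # Weights only: buffers, activations, and optimizer state are excluded.
    return sum(p.numel() * to_bytes(p.dtype) for p in model.parameters())

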
def model_summary(model: torch.nn.Module) -> str:
    message = "Model structure:\n"
    message += str(model)
    tot_params = sum(p.numel() for p in model.parameters())
    num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    percent_trainable = f"{num_params * 100.0 / tot_params:.1f}"
    dtype = next(model.parameters()).dtype
    message += "\n\nModel summary:\n"
    message += f" Class Name: {model.__class__.__name__}\n"
    message += f" Total Number of model parameters: {get_human_readable_count(tot_params)}\n"
    message += f" Number of trainable parameters: {get_human_readable_count(num_params)} ({percent_trainable}%)\n"
    message += f" Type: {dtype}"
    return message


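if __name__ == "__main__":
    # Quick demonstration on a toy model: torch.nn.Linear(10, 5) has
    # 10 * 5 weights + 5 biases = 55 parameters, all trainable, so the
    # summary reports 55 total parameters and 100.0% trainable.
    print(model_summary(torch.nn.Linear(10, 5)))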