import numpy as np
import torch


def get_human_readable_count(number: int) -> str:
    """Return a human-readable count string.

    Originated from:
    https://github.com/PyTorchLightning/pytorch-lightning/blob/master/pytorch_lightning/core/memory.py

    Abbreviates a number with K, M, B, T for thousands, millions,
    billions and trillions, respectively, always showing two decimal places.

    Examples:
        >>> get_human_readable_count(123)
        '123.00  '
        >>> get_human_readable_count(1234)  # (one thousand)
        '1.23 K'
        >>> get_human_readable_count(2e6)  # (two million)
        '2.00 M'
        >>> get_human_readable_count(3e9)  # (three billion)
        '3.00 B'
        >>> get_human_readable_count(4e12)  # (four trillion)
        '4.00 T'
        >>> get_human_readable_count(5e15)  # (more than a trillion)
        '5000.00 T'

    Args:
        number: a non-negative number

    Returns:
        A string formatted according to the pattern described above.
    """
    assert number >= 0
    labels = [" ", "K", "M", "B", "T"]
    num_digits = int(np.floor(np.log10(number)) + 1 if number > 0 else 1)
    num_groups = int(np.ceil(num_digits / 3))
    num_groups = min(num_groups, len(labels))  # don't abbreviate beyond trillions
    shift = -3 * (num_groups - 1)
    number = number * (10**shift)
    index = num_groups - 1
    return f"{number:.2f} {labels[index]}"


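# Worked example of the clamping branch above: for 5e15, num_digits = 16 and
# num_groups = ceil(16 / 3) = 6, which min() clamps to 5 ("T"); shift is then
# -12, so 5e15 * 1e-12 = 5000.0 and the function returns '5000.00 T' instead of
# introducing a unit beyond trillions.

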
def to_bytes(dtype) -> int:
    # Parse the bit width from the dtype name (e.g. torch.float16 -> 16) and
    # convert it to bytes; this assumes the name ends in a two-digit bit width.
    return int(str(dtype)[-2:]) // 8


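# Illustrative sketch of how to_bytes pairs with numel() to estimate a model's
# parameter-memory footprint; the helper name estimate_param_bytes is assumed
# here purely for demonstration.
def estimate_param_bytes(model: torch.nn.Module) -> int:
    # elements * bytes-per-element, summed over all parameters
    return sum(p.numel() * to_bytes(p.dtype) for p in model.parameters())

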
def model_summary(model: torch.nn.Module) -> str:
    message = "Model structure:\n"
    message += str(model)

    # Count total and trainable parameters, logging each tensor as we go.
    tot_params, num_params = 0, 0
    for name, param in model.named_parameters():
        print(
            "name: {}, dtype: {}, device: {}, trainable: {}, shape: {}, numel: {}".format(
                name, param.dtype, param.device, param.requires_grad, param.shape, param.numel()
            )
        )
        tot_params += param.numel()
        if param.requires_grad:
            num_params += param.numel()

    percent_trainable = "{:.1f}".format(num_params * 100.0 / tot_params)
    tot_params = get_human_readable_count(tot_params)
    num_params = get_human_readable_count(num_params)
    message += "\n\nModel summary:\n"
    message += f" Class Name: {model.__class__.__name__}\n"
    message += f" Total Number of model parameters: {tot_params}\n"
    message += f" Number of trainable parameters: {num_params} ({percent_trainable}%)\n"

    # Report the dtype of the first parameter as the model's parameter type.
    dtype = next(iter(model.parameters())).dtype
    message += f" Type: {dtype}"
    return message


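# Minimal usage sketch, assuming a toy model built ad hoc for the example:
# print the per-parameter log and the human-readable summary.
if __name__ == "__main__":
    demo_model = torch.nn.Sequential(
        torch.nn.Linear(16, 32),
        torch.nn.ReLU(),
        torch.nn.Linear(32, 4),
    )
    print(model_summary(demo_model))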