Saturday, November 16, 2024

Thursday, October 10, 2024

PyTorch - get the total number of model parameters

Total number of model parameters


1. Simple version (counts every parameter, including frozen ones)

pytorch_total_params = sum(p.numel() for p in model.parameters())
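The filter on requires_grad changes the result. The sketch below uses a small hypothetical nn.Sequential model whose first layer is frozen:

```python
import torch.nn as nn

# Hypothetical example model, for illustration only
model = nn.Sequential(nn.Linear(10, 5), nn.ReLU(), nn.Linear(5, 2))

# Freeze the first Linear layer so its weights are no longer trainable
for p in model[0].parameters():
    p.requires_grad = False

# Every parameter, frozen or not: (10*5 + 5) + (5*2 + 2) = 67
total_params = sum(p.numel() for p in model.parameters())
# Only parameters that will receive gradients: 5*2 + 2 = 12
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

print(total_params)      # 67
print(trainable_params)  # 12
```

Use the first form to report model size and the second to report how many weights the optimizer will actually update.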

2. Listed version (prints each trainable parameter and its size)

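def count_parameters(model):
    # Header row: parameter name and its number of elements
    str_name = "name"
    str_parameter = "parameter"
    print(f"{str_name:50s}: {str_parameter:10s}")

    total_params = 0
    for name, parameter in model.named_parameters():
        # Skip frozen parameters; only trainable ones are listed and counted
        if not parameter.requires_grad:
            continue
        params = parameter.numel()
        # params is an int, so format it with 'd' ('s' would raise a ValueError)
        print(f"{name:50s}: {params:10d}")
        total_params += params

    print(f"Total Trainable Params: {total_params}")
    return total_params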
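You may want to keep the per-parameter listing as data (for example, to log it) instead of printing it. The following is a minimal sketch of that. `trainable_param_table` is a hypothetical helper name, and the nn.Sequential model is only an illustration:

```python
import torch.nn as nn


def trainable_param_table(model: nn.Module) -> dict[str, int]:
    """Hypothetical helper: map each trainable parameter name to its element count."""
    return {name: p.numel() for name, p in model.named_parameters() if p.requires_grad}


# Illustrative model; ReLU (index 1) has no parameters, so it does not appear
model = nn.Sequential(nn.Linear(10, 5), nn.ReLU(), nn.Linear(5, 2))
table = trainable_param_table(model)

for name, count in table.items():
    print(f"{name:50s}: {count:10d}")
print(f"Total Trainable Params: {sum(table.values())}")
```

The names come from named_parameters(), which prefixes each tensor with the module's position in the Sequential (for example `0.weight` and `0.bias` for the first Linear layer).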