File size: 1,156 Bytes
43637a6 157d7fc 43637a6 10c5b4b 43637a6 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 |
import torch
import numpy as np
from graph_decoder.diffusion_model import GraphDiT
def count_parameters(model):
    r"""
    Count the parameters of ``model``.

    Walks ``model.parameters()`` once, reading each parameter's ``numel()``
    and ``requires_grad``.

    Returns:
        A ``(trainable_params, all_param)`` tuple: the summed element count of
        parameters whose ``requires_grad`` is truthy, followed by the summed
        element count of every parameter. Both are 0 when the model yields no
        parameters.
    """
    # Materialize once so a generator-backed ``parameters()`` is only consumed
    # a single time, exactly like a plain loop would.
    sizes = [(p.numel(), p.requires_grad) for p in model.parameters()]
    total = sum(count for count, _ in sizes)
    trainable = sum(count for count, needs_grad in sizes if needs_grad)
    return trainable, total
def load_graph_decoder(path='model_labeled'):
    """Build a ``GraphDiT`` from the files under ``path`` and return it.

    The model is constructed from ``{path}/config.yaml`` and
    ``{path}/data.meta.json`` with ``torch.float32`` as its dtype, then
    ``init_model(path)`` and ``disable_grads()`` are called on it
    (presumably loading the weights and freezing them -- confirm against
    ``GraphDiT``). A one-line summary of trainable vs. total parameter
    counts is printed before returning.

    Args:
        path: Directory holding the model config and data metadata files.

    Returns:
        The initialized ``GraphDiT`` instance.
    """
    decoder = GraphDiT(
        model_config_path=f"{path}/config.yaml",
        data_info_path=f"{path}/data.meta.json",
        # float32 is used here; a float16 variant was left disabled.
        model_dtype=torch.float32,
    )
    decoder.init_model(path)
    decoder.disable_grads()

    n_trainable, n_total = count_parameters(decoder)
    trainable_pct = 100 * n_trainable / n_total
    summary = "Loaded Graph DiT from {} trainable params: {:,} || all params: {:,} || trainable%: {:.4f}".format(
        path, n_trainable, n_total, trainable_pct
    )
    print(summary)
    return decoder
|