liuganghuggingface's picture
Update loader.py
debc746 verified
raw
history blame
1.16 kB
import torch
import numpy as np
from graph_decoder.diffusion_model import GraphDiT
def count_parameters(model):
    r"""Count the parameter elements of ``model``.

    Args:
        model: any object whose ``parameters()`` yields items that provide
            ``numel()`` and a ``requires_grad`` attribute (for example a
            ``torch.nn.Module``).

    Returns:
        A ``(trainable_params, all_param)`` tuple. The first value is the total
        element count of parameters whose ``requires_grad`` is truthy. The
        second is the total element count of every parameter.
    """
    # One pass over the parameters; remember each tensor's size and grad flag.
    sizes_and_flags = [(p.numel(), p.requires_grad) for p in model.parameters()]
    total = sum(count for count, _ in sizes_and_flags)
    trainable = sum(count for count, needs_grad in sizes_and_flags if needs_grad)
    return trainable, total
def load_graph_decoder(path='model_labeled', *, model_dtype=torch.float32):
    """Build a GraphDiT from the files in ``path``, initialise it, and report its size.

    The model is constructed from ``{path}/config.yaml`` and
    ``{path}/data.meta.json``. It is then initialised with
    ``model.init_model(path)`` and has ``model.disable_grads()`` called on it.
    A one-line parameter summary is printed.

    Args:
        path: directory containing ``config.yaml``, ``data.meta.json`` and
            whatever ``GraphDiT.init_model`` reads from it.
        model_dtype: dtype forwarded to ``GraphDiT`` (keyword-only). It
            defaults to ``torch.float32``, which matches the previous
            hard-coded value. Pass ``torch.float16`` for half precision.

    Returns:
        The initialised ``GraphDiT`` instance.
    """
    model_config_path = f"{path}/config.yaml"
    data_info_path = f"{path}/data.meta.json"
    model = GraphDiT(
        model_config_path=model_config_path,
        data_info_path=data_info_path,
        model_dtype=model_dtype,
    )
    model.init_model(path)
    # Presumably freezes the parameters for inference; the summary below
    # reports the resulting trainable count.
    model.disable_grads()
    trainable_params, all_param = count_parameters(model)
    # Guard against a parameter-less model, which would otherwise raise
    # ZeroDivisionError while formatting the summary.
    trainable_pct = 100 * trainable_params / all_param if all_param else 0.0
    param_stats = "Loaded Graph DiT from {} trainable params: {:,} || all params: {:,} || trainable%: {:.4f}".format(
        path, trainable_params, all_param, trainable_pct
    )
    print(param_stats)
    return model