|
import torch |
|
import numpy as np |
|
from graph_decoder.diffusion_model import GraphDiT |
|
|
|
|
|
|
|
|
|
def count_parameters(model):
    r"""
    Count the parameters of ``model``.

    Every tensor yielded by ``model.parameters()`` contributes its
    ``numel()`` to the overall total; tensors whose ``requires_grad`` flag
    is set also contribute to the trainable total.

    Returns:
        A ``(trainable_params, all_param)`` tuple of element counts.
    """
    # Collect (size, trainable?) pairs once so both totals are derived from
    # a single pass over the model's parameters.
    sizes = [(p.numel(), p.requires_grad) for p in model.parameters()]

    total = sum(count for count, _ in sizes)
    trainable = sum(count for count, needs_grad in sizes if needs_grad)

    return trainable, total
|
|
|
def load_graph_decoder(device, path='model_labeled'):
    """Build a ``GraphDiT`` from the files under ``path`` and move it to ``device``.

    The model is constructed from ``{path}/config.yaml`` and
    ``{path}/data.meta.json`` with ``torch.float32`` as its dtype, then
    initialised via ``model.init_model(path)``, passed through
    ``model.disable_grads()`` and moved to ``device``. A one-line summary of
    trainable versus total parameter counts is printed before returning.

    Args:
        device: Target passed unchanged to ``model.to``.
        path: Directory holding the config and data-meta files; it is also
            handed to ``init_model``.

    Returns:
        The initialised ``GraphDiT`` instance.
    """
    model_config_path = f"{path}/config.yaml"
    data_info_path = f"{path}/data.meta.json"

    model = GraphDiT(
        model_config_path=model_config_path,
        data_info_path=data_info_path,
        model_dtype=torch.float32,
    )
    # NOTE(review): presumably loads the checkpoint stored under `path` --
    # confirm against GraphDiT.init_model.
    model.init_model(path)
    # NOTE(review): presumably freezes the parameters (requires_grad=False) --
    # confirm against GraphDiT.disable_grads.
    model.disable_grads()

    # Announce the transfer before it starts so the message is still shown
    # if moving to the device fails.
    print('Moving model to', device)
    model.to(device)

    trainable_params, all_param = count_parameters(model)
    # Guard the ratio: a model reporting zero parameters must not turn this
    # purely informational log line into a ZeroDivisionError.
    trainable_pct = 100 * trainable_params / all_param if all_param else 0.0
    param_stats = "Loaded Graph DiT from {} trainable params: {:,} || all params: {:,} || trainable%: {:.4f}".format(
        path, trainable_params, all_param, trainable_pct
    )
    print(param_stats)
    return model
|
|