import os

import torch
import numpy as np

from graph_decoder.diffusion_model import GraphDiT

# Example usage (kept from the original notes):
#   model_state = load_model()
#   generate_graph(2.5, 15.4, 21.0, 1.5, 2.8, 2, 0, 1, model_state, 50)
# NOTE(review): `load_model` and `generate_graph` are not defined in this
# file; confirm where they live before relying on this example.


def count_parameters(model):
    r"""
    Returns the number of trainable parameters and number of all parameters in the model.

    Args:
        model: any object exposing ``parameters()`` that yields tensors
            (e.g. a ``torch.nn.Module``).

    Returns:
        tuple ``(trainable_params, all_param)``: the element count of
        parameters with ``requires_grad=True``, and the element count of
        all parameters.
    """
    trainable_params, all_param = 0, 0
    for param in model.parameters():
        num_params = param.numel()
        all_param += num_params
        if param.requires_grad:
            trainable_params += num_params
    return trainable_params, all_param


def load_graph_decoder(device, path='model_labeled'):
    """Build a GraphDiT from the files in ``path``, move it to ``device``, and return it.

    The model is built from ``<path>/config.yaml`` and ``<path>/data.meta.json``
    in float32. ``init_model(path)`` is called on it, then ``disable_grads()``.
    After that the model is moved to ``device`` and a parameter summary is
    printed.

    Args:
        device: target passed to ``model.to`` (e.g. ``"cpu"``, ``"cuda"``,
            or a ``torch.device``).
        path: directory that holds the model files. Defaults to
            ``'model_labeled'``.

    Returns:
        The initialized ``GraphDiT`` instance.
    """
    model_config_path = os.path.join(path, "config.yaml")
    data_info_path = os.path.join(path, "data.meta.json")

    model = GraphDiT(
        model_config_path=model_config_path,
        data_info_path=data_info_path,
        # model_dtype=torch.float16,
        model_dtype=torch.float32,
    )
    # NOTE(review): presumably init_model loads weights from `path` and
    # disable_grads freezes parameters (so trainable% is typically 0) --
    # confirm against GraphDiT.
    model.init_model(path)
    model.disable_grads()

    model.to(device)
    print('Moving model to', device)

    trainable_params, all_param = count_parameters(model)
    # Guard against a parameterless model: the original divided by zero here.
    trainable_pct = 100 * trainable_params / all_param if all_param else 0.0
    param_stats = "Loaded Graph DiT from {} trainable params: {:,} || all params: {:,} || trainable%: {:.4f}".format(
        path, trainable_params, all_param, trainable_pct
    )
    print(param_stats)
    return model