File size: 1,156 Bytes
43637a6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
157d7fc
43637a6
 
 
 
 
 
 
 
 
 
 
 
10c5b4b
 
 
 
 
43637a6
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
import torch
import numpy as np
from graph_decoder.diffusion_model import GraphDiT

def count_parameters(model):
    r"""
    Count the parameter elements of ``model``.

    Every tensor yielded by ``model.parameters()`` contributes its
    ``numel()`` to the total. Only tensors with ``requires_grad`` set
    contribute to the trainable count.

    Returns:
        A ``(trainable, total)`` tuple of integers.
    """
    params = list(model.parameters())
    total = sum(p.numel() for p in params)
    trainable = sum(p.numel() for p in params if p.requires_grad)
    return trainable, total

def load_graph_decoder(path='model_labeled'):
    r"""
    Construct a ``GraphDiT`` from the files in ``path`` and initialise it.

    ``path/config.yaml`` is passed as ``model_config_path`` and
    ``path/data.meta.json`` as ``data_info_path`` to the ``GraphDiT``
    constructor (with ``model_dtype=torch.float32``). ``path`` itself is then
    handed to ``init_model``, and ``disable_grads()`` is called (presumably
    freezing all parameters; confirm in ``GraphDiT``). A one-line parameter
    summary is printed to stdout.

    Args:
        path: Directory holding the model files. Defaults to
            ``'model_labeled'``.

    Returns:
        The initialised ``GraphDiT`` instance.
    """
    import os

    # os.path.join avoids a doubled separator when ``path`` ends with "/".
    model_config_path = os.path.join(path, "config.yaml")
    data_info_path = os.path.join(path, "data.meta.json")

    model = GraphDiT(
        model_config_path=model_config_path,
        data_info_path=data_info_path,
        # NOTE(review): a float16 variant used to be commented out here;
        # confirm whether reduced precision is still wanted.
        model_dtype=torch.float32,
    )
    model.init_model(path)
    model.disable_grads()

    trainable_params, all_param = count_parameters(model)
    # Guard against ZeroDivisionError for a model that exposes no parameters.
    trainable_pct = 100 * trainable_params / all_param if all_param else 0.0
    param_stats = "Loaded Graph DiT from {} trainable params: {:,} || all params: {:,} || trainable%: {:.4f}".format(
        path, trainable_params, all_param, trainable_pct
    )
    print(param_stats)
    return model