liuganghuggingface commited on
Commit
43637a6
·
verified ·
1 Parent(s): ca17d9e

Upload loader.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. loader.py +41 -0
loader.py ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import numpy as np
3
+ from graph_decoder.diffusion_model import GraphDiT
4
+
5
+ # model_state = load_model()
6
+ # generate_graph(2.5, 15.4, 21.0, 1.5, 2.8, 2, 0, 1, model_state, 50)
7
+
8
+ def count_parameters(model):
9
+ r"""
10
+ Returns the number of trainable parameters and number of all parameters in the model.
11
+ """
12
+ trainable_params, all_param = 0, 0
13
+ for param in model.parameters():
14
+ num_params = param.numel()
15
+ all_param += num_params
16
+ if param.requires_grad:
17
+ trainable_params += num_params
18
+
19
+ return trainable_params, all_param
20
+
21
+ def load_graph_decoder(device, path='model_labeled'):
22
+ model_config_path = f"{path}/config.yaml"
23
+ data_info_path = f"{path}/data.meta.json"
24
+
25
+ model = GraphDiT(
26
+ model_config_path=model_config_path,
27
+ data_info_path=data_info_path,
28
+ # model_dtype=torch.float16,
29
+ model_dtype=torch.float32,
30
+ )
31
+ model.init_model(path)
32
+ model.disable_grads()
33
+ model.to(device)
34
+ print('Moving model to', device)
35
+
36
+ trainable_params, all_param = count_parameters(model)
37
+ param_stats = "Loaded Graph DiT from {} trainable params: {:,} || all params: {:,} || trainable%: {:.4f}".format(
38
+ path, trainable_params, all_param, 100 * trainable_params / all_param
39
+ )
40
+ print(param_stats)
41
+ return model