liuganghuggingface
commited on
Upload loader.py with huggingface_hub
Browse files
loader.py
CHANGED
@@ -18,7 +18,7 @@ def count_parameters(model):
|
|
18 |
|
19 |
return trainable_params, all_param
|
20 |
|
21 |
-
def load_graph_decoder(
|
22 |
model_config_path = f"{path}/config.yaml"
|
23 |
data_info_path = f"{path}/data.meta.json"
|
24 |
|
@@ -30,8 +30,6 @@ def load_graph_decoder(device, path='model_labeled'):
|
|
30 |
)
|
31 |
model.init_model(path)
|
32 |
model.disable_grads()
|
33 |
-
model.to(device)
|
34 |
-
print('Moving model to', device)
|
35 |
|
36 |
trainable_params, all_param = count_parameters(model)
|
37 |
param_stats = "Loaded Graph DiT from {} trainable params: {:,} || all params: {:,} || trainable%: {:.4f}".format(
|
|
|
18 |
|
19 |
return trainable_params, all_param
|
20 |
|
21 |
+
def load_graph_decoder(path='model_labeled'):
|
22 |
model_config_path = f"{path}/config.yaml"
|
23 |
data_info_path = f"{path}/data.meta.json"
|
24 |
|
|
|
30 |
)
|
31 |
model.init_model(path)
|
32 |
model.disable_grads()
|
|
|
|
|
33 |
|
34 |
trainable_params, all_param = count_parameters(model)
|
35 |
param_stats = "Loaded Graph DiT from {} trainable params: {:,} || all params: {:,} || trainable%: {:.4f}".format(
|