liuganghuggingface commited on
Commit
157d7fc
·
verified ·
1 Parent(s): f72298f

Upload loader.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. loader.py +1 -3
loader.py CHANGED
@@ -18,7 +18,7 @@ def count_parameters(model):
18
 
19
  return trainable_params, all_param
20
 
21
- def load_graph_decoder(device, path='model_labeled'):
22
  model_config_path = f"{path}/config.yaml"
23
  data_info_path = f"{path}/data.meta.json"
24
 
@@ -30,8 +30,6 @@ def load_graph_decoder(device, path='model_labeled'):
30
  )
31
  model.init_model(path)
32
  model.disable_grads()
33
- model.to(device)
34
- print('Moving model to', device)
35
 
36
  trainable_params, all_param = count_parameters(model)
37
  param_stats = "Loaded Graph DiT from {} trainable params: {:,} || all params: {:,} || trainable%: {:.4f}".format(
 
18
 
19
  return trainable_params, all_param
20
 
21
+ def load_graph_decoder(path='model_labeled'):
22
  model_config_path = f"{path}/config.yaml"
23
  data_info_path = f"{path}/data.meta.json"
24
 
 
30
  )
31
  model.init_model(path)
32
  model.disable_grads()
 
 
33
 
34
  trainable_params, all_param = count_parameters(model)
35
  param_stats = "Loaded Graph DiT from {} trainable params: {:,} || all params: {:,} || trainable%: {:.4f}".format(