liuganghuggingface commited on
Commit
80330c5
·
verified ·
1 Parent(s): 31acec8

Upload app.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. app.py +3 -1
app.py CHANGED
@@ -58,6 +58,8 @@ def random_properties():
58
  model_all = load_graph_decoder(path='model_all')
59
  model_labeled = load_graph_decoder(path='model_labeled')
60
  # return (model, device)
 
 
61
 
62
  # Create a flagged folder if it doesn't exist
63
  flagged_folder = "flagged"
@@ -110,7 +112,7 @@ def generate_graph(CH4, CO2, H2, N2, O2, guidance_scale, num_nodes, repeating_ti
110
  # return generated_molecule, img_list
111
  # generated_molecule, img_list = generate_func()
112
  # print('Before generation, move model to', device)
113
- model.to(device)
114
  generated_molecule, img_list = model.generate(properties, device=device, guide_scale=guidance_scale, num_nodes=num_nodes, number_chain_steps=num_chain_steps)
115
 
116
  # Create GIF if img_list is available
 
58
  model_all = load_graph_decoder(path='model_all')
59
  model_labeled = load_graph_decoder(path='model_labeled')
60
  # return (model, device)
61
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
62
+ print('device')
63
 
64
  # Create a flagged folder if it doesn't exist
65
  flagged_folder = "flagged"
 
112
  # return generated_molecule, img_list
113
  # generated_molecule, img_list = generate_func()
114
  # print('Before generation, move model to', device)
115
+ # model.to(device)
116
  generated_molecule, img_list = model.generate(properties, device=device, guide_scale=guidance_scale, num_nodes=num_nodes, number_chain_steps=num_chain_steps)
117
 
118
  # Create GIF if img_list is available