liuganghuggingface commited on
Commit
66142af
·
verified ·
1 Parent(s): 7173a2e

Upload app.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. app.py +21 -37
app.py CHANGED
@@ -52,23 +52,15 @@ atexit.register(cleanup_temp_files)
52
  def random_properties():
53
  return known_labels[all_properties].sample(1).values.tolist()[0]
54
 
55
- # def load_model(model_choice):
56
- # device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
57
- # return (model, device)
58
-
59
- model_all = load_graph_decoder(path='model_all')
60
- model_labeled = load_graph_decoder(path='model_labeled')
61
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
62
- print('device')
63
-
64
 
65
  # Create a flagged folder if it doesn't exist
66
  flagged_folder = "flagged"
67
  os.makedirs(flagged_folder, exist_ok=True)
68
 
69
- # device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
70
- # model.to(device)
71
-
72
  def save_interesting_log(smiles, properties, suggested_properties):
73
  """Save interesting polymer data to a CSV file."""
74
  log_file = os.path.join(flagged_folder, "log.csv")
@@ -89,8 +81,10 @@ def save_interesting_log(smiles, properties, suggested_properties):
89
  }
90
  writer.writerow(log_data)
91
 
92
- # def generate_graph(CH4, CO2, H2, N2, O2, guidance_scale, num_nodes, repeating_time, model, num_chain_steps, fps):
93
- def generate_graph(CH4, CO2, H2, N2, O2, guidance_scale, num_nodes, repeating_time, num_chain_steps, fps):
 
 
94
  properties = [CH4, CO2, H2, N2, O2]
95
 
96
  def is_nan_like(x):
@@ -112,8 +106,8 @@ def generate_graph(CH4, CO2, H2, N2, O2, guidance_scale, num_nodes, repeating_ti
112
  # print('Before generation, move model to', device)
113
  # return generated_molecule, img_list
114
  # generated_molecule, img_list = generate_func()
115
- # print('Before generation, move model to', device)
116
- # model.to(device)
117
  generated_molecule, img_list = model.generate(properties, device=device, guide_scale=guidance_scale, num_nodes=num_nodes, number_chain_steps=num_chain_steps)
118
 
119
  # Create GIF if img_list is available
@@ -186,12 +180,8 @@ def numpy_to_python(obj):
186
  else:
187
  return obj
188
 
189
- @spaces.GPU(duration=60)
190
- def on_generate(CH4, CO2, H2, N2, O2, guidance_scale, num_nodes, repeating_time, num_chain_steps, fps):
191
- # def on_generate(CH4, CO2, H2, N2, O2, guidance_scale, num_nodes, repeating_time, model_state, num_chain_steps, fps):
192
- # model = model_state
193
- # result = generate_graph(CH4, CO2, H2, N2, O2, guidance_scale, num_nodes, repeating_time, model, num_chain_steps, fps)
194
- result = generate_graph(CH4, CO2, H2, N2, O2, guidance_scale, num_nodes, repeating_time, num_chain_steps, fps)
195
  # Check if the generation was successful
196
  if result[0].startswith("**Generated polymer SMILES:**"):
197
  smiles = result[0].split("**Generated polymer SMILES:** `")[1].split("`")[0]
@@ -286,8 +276,7 @@ with gr.Blocks(title="Polymer Design with GraphDiT") as iface:
286
  ```
287
  """)
288
 
289
- # model_state = gr.State(value=model_labeled)
290
- # model_state = model_labeled
291
 
292
  with gr.Row():
293
  CH4_input = gr.Slider(0, property_ranges['CH4'][1], value=2.5, label=f"CH₄ (Barrier) [0-{property_ranges['CH4'][1]:.1f}]")
@@ -316,17 +305,13 @@ with gr.Blocks(title="Polymer Design with GraphDiT") as iface:
316
  feedback_btn = gr.Button("🌟 I think this polymer is interesting!", visible=True, interactive=False)
317
  feedback_result = gr.Textbox(label="Feedback Result", visible=False)
318
 
319
- # def switch_model(choice):
320
- # # Convert display name back to internal name
321
- # internal_name = next(key for key, value in model_name_mapping.items() if value == choice)
322
- # if internal_name == 'model_labeled':
323
- # return model_labeled
324
- # elif internal_name == 'model_all':
325
- # return model_all
326
- # else:
327
- # raise ValueError('Not support model', internal_name)
328
-
329
- # model_choice.change(switch_model, inputs=[model_choice], outputs=[model_state])
330
 
331
  # Hidden components to store generation data
332
  hidden_smiles = gr.Textbox(visible=False)
@@ -341,8 +326,7 @@ with gr.Blocks(title="Polymer Design with GraphDiT") as iface:
341
 
342
  generate_btn.click(
343
  on_generate,
344
- # inputs=[CH4_input, CO2_input, H2_input, N2_input, O2_input, guidance_scale, num_nodes, repeating_time, model_state, num_chain_steps, fps],
345
- inputs=[CH4_input, CO2_input, H2_input, N2_input, O2_input, guidance_scale, num_nodes, repeating_time, num_chain_steps, fps],
346
  outputs=[result_text, result_image, result_gif, hidden_smiles, hidden_properties, hidden_suggested_properties, feedback_btn]
347
  )
348
 
 
52
  def random_properties():
53
  return known_labels[all_properties].sample(1).values.tolist()[0]
54
 
55
+ def load_model(model_choice):
56
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
57
+ model = load_graph_decoder(path=model_choice)
58
+ return (model, device)
 
 
 
 
 
59
 
60
  # Create a flagged folder if it doesn't exist
61
  flagged_folder = "flagged"
62
  os.makedirs(flagged_folder, exist_ok=True)
63
 
 
 
 
64
  def save_interesting_log(smiles, properties, suggested_properties):
65
  """Save interesting polymer data to a CSV file."""
66
  log_file = os.path.join(flagged_folder, "log.csv")
 
81
  }
82
  writer.writerow(log_data)
83
 
84
+ @spaces.GPU(duration=60)
85
+ def generate_graph(CH4, CO2, H2, N2, O2, guidance_scale, num_nodes, repeating_time, model_state, num_chain_steps, fps):
86
+ model, device = model_state
87
+
88
  properties = [CH4, CO2, H2, N2, O2]
89
 
90
  def is_nan_like(x):
 
106
  # print('Before generation, move model to', device)
107
  # return generated_molecule, img_list
108
  # generated_molecule, img_list = generate_func()
109
+
110
+ model.to(device)
111
  generated_molecule, img_list = model.generate(properties, device=device, guide_scale=guidance_scale, num_nodes=num_nodes, number_chain_steps=num_chain_steps)
112
 
113
  # Create GIF if img_list is available
 
180
  else:
181
  return obj
182
 
183
+ def on_generate(CH4, CO2, H2, N2, O2, guidance_scale, num_nodes, repeating_time, model_state, num_chain_steps, fps):
184
+ result = generate_graph(CH4, CO2, H2, N2, O2, guidance_scale, num_nodes, repeating_time, model_state, num_chain_steps, fps)
 
 
 
 
185
  # Check if the generation was successful
186
  if result[0].startswith("**Generated polymer SMILES:**"):
187
  smiles = result[0].split("**Generated polymer SMILES:** `")[1].split("`")[0]
 
276
  ```
277
  """)
278
 
279
+ model_state = gr.State(lambda: load_model("model_all"))
 
280
 
281
  with gr.Row():
282
  CH4_input = gr.Slider(0, property_ranges['CH4'][1], value=2.5, label=f"CH₄ (Barrier) [0-{property_ranges['CH4'][1]:.1f}]")
 
305
  feedback_btn = gr.Button("🌟 I think this polymer is interesting!", visible=True, interactive=False)
306
  feedback_result = gr.Textbox(label="Feedback Result", visible=False)
307
 
308
+ # Add model switching functionality
309
+ def switch_model(choice):
310
+ # Convert display name back to internal name
311
+ internal_name = next(key for key, value in model_name_mapping.items() if value == choice)
312
+ return load_model(internal_name)
313
+
314
+ model_choice.change(switch_model, inputs=[model_choice], outputs=[model_state])
 
 
 
 
315
 
316
  # Hidden components to store generation data
317
  hidden_smiles = gr.Textbox(visible=False)
 
326
 
327
  generate_btn.click(
328
  on_generate,
329
+ inputs=[CH4_input, CO2_input, H2_input, N2_input, O2_input, guidance_scale, num_nodes, repeating_time, model_state, num_chain_steps, fps],
 
330
  outputs=[result_text, result_image, result_gif, hidden_smiles, hidden_properties, hidden_suggested_properties, feedback_btn]
331
  )
332