Upload app.py with huggingface_hub
Browse files
app.py
CHANGED
@@ -52,23 +52,15 @@ atexit.register(cleanup_temp_files)
|
|
52 |
def random_properties():
|
53 |
return known_labels[all_properties].sample(1).values.tolist()[0]
|
54 |
|
55 |
-
|
56 |
-
|
57 |
-
|
58 |
-
|
59 |
-
model_all = load_graph_decoder(path='model_all')
|
60 |
-
model_labeled = load_graph_decoder(path='model_labeled')
|
61 |
-
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
62 |
-
print('device')
|
63 |
-
|
64 |
|
65 |
# Create a flagged folder if it doesn't exist
|
66 |
flagged_folder = "flagged"
|
67 |
os.makedirs(flagged_folder, exist_ok=True)
|
68 |
|
69 |
-
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
70 |
-
# model.to(device)
|
71 |
-
|
72 |
def save_interesting_log(smiles, properties, suggested_properties):
|
73 |
"""Save interesting polymer data to a CSV file."""
|
74 |
log_file = os.path.join(flagged_folder, "log.csv")
|
@@ -89,8 +81,10 @@ def save_interesting_log(smiles, properties, suggested_properties):
|
|
89 |
}
|
90 |
writer.writerow(log_data)
|
91 |
|
92 |
-
|
93 |
-
def generate_graph(CH4, CO2, H2, N2, O2, guidance_scale, num_nodes, repeating_time, num_chain_steps, fps):
|
|
|
|
|
94 |
properties = [CH4, CO2, H2, N2, O2]
|
95 |
|
96 |
def is_nan_like(x):
|
@@ -112,8 +106,8 @@ def generate_graph(CH4, CO2, H2, N2, O2, guidance_scale, num_nodes, repeating_ti
|
|
112 |
# print('Before generation, move model to', device)
|
113 |
# return generated_molecule, img_list
|
114 |
# generated_molecule, img_list = generate_func()
|
115 |
-
|
116 |
-
|
117 |
generated_molecule, img_list = model.generate(properties, device=device, guide_scale=guidance_scale, num_nodes=num_nodes, number_chain_steps=num_chain_steps)
|
118 |
|
119 |
# Create GIF if img_list is available
|
@@ -186,12 +180,8 @@ def numpy_to_python(obj):
|
|
186 |
else:
|
187 |
return obj
|
188 |
|
189 |
-
|
190 |
-
|
191 |
-
# def on_generate(CH4, CO2, H2, N2, O2, guidance_scale, num_nodes, repeating_time, model_state, num_chain_steps, fps):
|
192 |
-
# model = model_state
|
193 |
-
# result = generate_graph(CH4, CO2, H2, N2, O2, guidance_scale, num_nodes, repeating_time, model, num_chain_steps, fps)
|
194 |
-
result = generate_graph(CH4, CO2, H2, N2, O2, guidance_scale, num_nodes, repeating_time, num_chain_steps, fps)
|
195 |
# Check if the generation was successful
|
196 |
if result[0].startswith("**Generated polymer SMILES:**"):
|
197 |
smiles = result[0].split("**Generated polymer SMILES:** `")[1].split("`")[0]
|
@@ -286,8 +276,7 @@ with gr.Blocks(title="Polymer Design with GraphDiT") as iface:
|
|
286 |
```
|
287 |
""")
|
288 |
|
289 |
-
|
290 |
-
# model_state = model_labeled
|
291 |
|
292 |
with gr.Row():
|
293 |
CH4_input = gr.Slider(0, property_ranges['CH4'][1], value=2.5, label=f"CH₄ (Barrier) [0-{property_ranges['CH4'][1]:.1f}]")
|
@@ -316,17 +305,13 @@ with gr.Blocks(title="Polymer Design with GraphDiT") as iface:
|
|
316 |
feedback_btn = gr.Button("🌟 I think this polymer is interesting!", visible=True, interactive=False)
|
317 |
feedback_result = gr.Textbox(label="Feedback Result", visible=False)
|
318 |
|
319 |
-
#
|
320 |
-
|
321 |
-
|
322 |
-
|
323 |
-
|
324 |
-
|
325 |
-
|
326 |
-
# else:
|
327 |
-
# raise ValueError('Not support model', internal_name)
|
328 |
-
|
329 |
-
# model_choice.change(switch_model, inputs=[model_choice], outputs=[model_state])
|
330 |
|
331 |
# Hidden components to store generation data
|
332 |
hidden_smiles = gr.Textbox(visible=False)
|
@@ -341,8 +326,7 @@ with gr.Blocks(title="Polymer Design with GraphDiT") as iface:
|
|
341 |
|
342 |
generate_btn.click(
|
343 |
on_generate,
|
344 |
-
|
345 |
-
inputs=[CH4_input, CO2_input, H2_input, N2_input, O2_input, guidance_scale, num_nodes, repeating_time, num_chain_steps, fps],
|
346 |
outputs=[result_text, result_image, result_gif, hidden_smiles, hidden_properties, hidden_suggested_properties, feedback_btn]
|
347 |
)
|
348 |
|
|
|
52 |
def random_properties():
|
53 |
return known_labels[all_properties].sample(1).values.tolist()[0]
|
54 |
|
55 |
+
def load_model(model_choice):
|
56 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
57 |
+
model = load_graph_decoder(path=model_choice)
|
58 |
+
return (model, device)
|
|
|
|
|
|
|
|
|
|
|
59 |
|
60 |
# Create a flagged folder if it doesn't exist
|
61 |
flagged_folder = "flagged"
|
62 |
os.makedirs(flagged_folder, exist_ok=True)
|
63 |
|
|
|
|
|
|
|
64 |
def save_interesting_log(smiles, properties, suggested_properties):
|
65 |
"""Save interesting polymer data to a CSV file."""
|
66 |
log_file = os.path.join(flagged_folder, "log.csv")
|
|
|
81 |
}
|
82 |
writer.writerow(log_data)
|
83 |
|
84 |
+
@spaces.GPU(duration=60)
|
85 |
+
def generate_graph(CH4, CO2, H2, N2, O2, guidance_scale, num_nodes, repeating_time, model_state, num_chain_steps, fps):
|
86 |
+
model, device = model_state
|
87 |
+
|
88 |
properties = [CH4, CO2, H2, N2, O2]
|
89 |
|
90 |
def is_nan_like(x):
|
|
|
106 |
# print('Before generation, move model to', device)
|
107 |
# return generated_molecule, img_list
|
108 |
# generated_molecule, img_list = generate_func()
|
109 |
+
|
110 |
+
model.to(device)
|
111 |
generated_molecule, img_list = model.generate(properties, device=device, guide_scale=guidance_scale, num_nodes=num_nodes, number_chain_steps=num_chain_steps)
|
112 |
|
113 |
# Create GIF if img_list is available
|
|
|
180 |
else:
|
181 |
return obj
|
182 |
|
183 |
+
def on_generate(CH4, CO2, H2, N2, O2, guidance_scale, num_nodes, repeating_time, model_state, num_chain_steps, fps):
|
184 |
+
result = generate_graph(CH4, CO2, H2, N2, O2, guidance_scale, num_nodes, repeating_time, model_state, num_chain_steps, fps)
|
|
|
|
|
|
|
|
|
185 |
# Check if the generation was successful
|
186 |
if result[0].startswith("**Generated polymer SMILES:**"):
|
187 |
smiles = result[0].split("**Generated polymer SMILES:** `")[1].split("`")[0]
|
|
|
276 |
```
|
277 |
""")
|
278 |
|
279 |
+
model_state = gr.State(lambda: load_model("model_all"))
|
|
|
280 |
|
281 |
with gr.Row():
|
282 |
CH4_input = gr.Slider(0, property_ranges['CH4'][1], value=2.5, label=f"CH₄ (Barrier) [0-{property_ranges['CH4'][1]:.1f}]")
|
|
|
305 |
feedback_btn = gr.Button("🌟 I think this polymer is interesting!", visible=True, interactive=False)
|
306 |
feedback_result = gr.Textbox(label="Feedback Result", visible=False)
|
307 |
|
308 |
+
# Add model switching functionality
|
309 |
+
def switch_model(choice):
|
310 |
+
# Convert display name back to internal name
|
311 |
+
internal_name = next(key for key, value in model_name_mapping.items() if value == choice)
|
312 |
+
return load_model(internal_name)
|
313 |
+
|
314 |
+
model_choice.change(switch_model, inputs=[model_choice], outputs=[model_state])
|
|
|
|
|
|
|
|
|
315 |
|
316 |
# Hidden components to store generation data
|
317 |
hidden_smiles = gr.Textbox(visible=False)
|
|
|
326 |
|
327 |
generate_btn.click(
|
328 |
on_generate,
|
329 |
+
inputs=[CH4_input, CO2_input, H2_input, N2_input, O2_input, guidance_scale, num_nodes, repeating_time, model_state, num_chain_steps, fps],
|
|
|
330 |
outputs=[result_text, result_image, result_gif, hidden_smiles, hidden_properties, hidden_suggested_properties, feedback_btn]
|
331 |
)
|
332 |
|