liuganghuggingface commited on
Commit
491cb22
Β·
verified Β·
1 Parent(s): 5fc2dd1

Upload app.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. app.py +349 -0
app.py ADDED
@@ -0,0 +1,349 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from rdkit import Chem
3
+ from rdkit.Chem import Draw
4
+ import numpy as np
5
+ import torch
6
+ import pandas as pd
7
+ import random
8
+ import io
9
+ import imageio
10
+ import os
11
+ import tempfile
12
+ import atexit
13
+ import glob
14
+ import csv
15
+ from datetime import datetime
16
+ import json
17
+
18
+ from evaluator import Evaluator
19
+ from loader import load_graph_decoder
20
+
21
+ # Load the CSV data
22
+ known_labels = pd.read_csv('data/known_labels.csv')
23
+ knwon_smiles = pd.read_csv('data/known_polymers.csv')
24
+
25
+ all_properties = ['CH4', 'CO2', 'H2', 'N2', 'O2']
26
+
27
+ # Initialize evaluators
28
+ evaluators = {prop: Evaluator(f'evaluators/{prop}.joblib', prop) for prop in all_properties}
29
+
30
+ # Get min and max values for each property
31
+ property_ranges = {prop: (known_labels[prop].min(), known_labels[prop].max()) for prop in all_properties}
32
+
33
+ # Create a temporary directory for GIFs
34
+ temp_dir = tempfile.mkdtemp(prefix="polymer_gifs_")
35
+
36
+ def cleanup_temp_files():
37
+ """Clean up temporary GIF files on exit."""
38
+ for file in glob.glob(os.path.join(temp_dir, "*.gif")):
39
+ try:
40
+ os.remove(file)
41
+ except Exception as e:
42
+ print(f"Error deleting {file}: {e}")
43
+ try:
44
+ os.rmdir(temp_dir)
45
+ except Exception as e:
46
+ print(f"Error deleting temporary directory {temp_dir}: {e}")
47
+
48
+ # Register the cleanup function to be called on exit
49
+ atexit.register(cleanup_temp_files)
50
+
51
+ def random_properties():
52
+ return known_labels[all_properties].sample(1).values.tolist()[0]
53
+
54
+ def load_model(model_choice):
55
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
56
+ model = load_graph_decoder(device, path=model_choice)
57
+ return (model, device)
58
+
59
+ # Create a flagged folder if it doesn't exist
60
+ flagged_folder = "flagged"
61
+ os.makedirs(flagged_folder, exist_ok=True)
62
+
63
+ def save_interesting_log(smiles, properties, suggested_properties):
64
+ """Save interesting polymer data to a CSV file."""
65
+ log_file = os.path.join(flagged_folder, "log.csv")
66
+ file_exists = os.path.isfile(log_file)
67
+
68
+ with open(log_file, 'a', newline='') as csvfile:
69
+ fieldnames = ['timestamp', 'smiles'] + all_properties + [f'suggested_{prop}' for prop in all_properties]
70
+ writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
71
+
72
+ if not file_exists:
73
+ writer.writeheader()
74
+
75
+ log_data = {
76
+ 'timestamp': datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
77
+ 'smiles': smiles,
78
+ **{prop: value for prop, value in zip(all_properties, properties)},
79
+ **{f'suggested_{prop}': value for prop, value in suggested_properties.items()}
80
+ }
81
+ writer.writerow(log_data)
82
+
83
+ def generate_graph(CH4, CO2, H2, N2, O2, guidance_scale, num_nodes, repeating_time, model_state, num_chain_steps, fps):
84
+ model, device = model_state
85
+
86
+ properties = [CH4, CO2, H2, N2, O2]
87
+
88
+ def is_nan_like(x):
89
+ return x == 0 or x == '' or (isinstance(x, float) and np.isnan(x))
90
+
91
+ properties = [None if is_nan_like(prop) else prop for prop in properties]
92
+
93
+ nan_message = "The following gas properties were treated as NaN: "
94
+ nan_gases = [gas for gas, prop in zip(all_properties, properties) if prop is None]
95
+ nan_message += ", ".join(nan_gases) if nan_gases else "None"
96
+
97
+ num_nodes = None if num_nodes == 0 else num_nodes
98
+
99
+ for _ in range(repeating_time):
100
+ try:
101
+ @spaces.GPU(duration=60)
102
+ def generate_func():
103
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
104
+ model.to(device)
105
+ print('Before generation, move model to', device)
106
+ generated_molecule, img_list = model.generate(properties, device=device, guide_scale=guidance_scale, num_nodes=num_nodes, number_chain_steps=num_chain_steps)
107
+ return generated_molecule, img_list
108
+
109
+ generated_molecule, img_list = generate_func()
110
+
111
+ # Create GIF if img_list is available
112
+ gif_path = None
113
+ if img_list and len(img_list) > 0:
114
+ imgs = [np.array(pil_img) for pil_img in img_list]
115
+ imgs.extend([imgs[-1]] * 10)
116
+ gif_path = os.path.join(temp_dir, f"polymer_gen_{random.randint(0, 999999)}.gif")
117
+ imageio.mimsave(gif_path, imgs, format='GIF', fps=fps, loop=0)
118
+
119
+ if generated_molecule is not None:
120
+ mol = Chem.MolFromSmiles(generated_molecule)
121
+ if mol is not None:
122
+ standardized_smiles = Chem.MolToSmiles(mol, isomericSmiles=True)
123
+ is_novel = standardized_smiles not in knwon_smiles['smiles'].values
124
+ novelty_status = "Novel (Not in Labeled Set)" if is_novel else "Not Novel (Exists in Labeled Set)"
125
+ img = Draw.MolToImage(mol)
126
+
127
+ # Evaluate the generated molecule
128
+ suggested_properties = {}
129
+ for prop, evaluator in evaluators.items():
130
+ suggested_properties[prop] = evaluator([standardized_smiles])[0]
131
+
132
+ suggested_properties_text = "\n".join([f"**Suggested {prop}:** {value:.2f}" for prop, value in suggested_properties.items()])
133
+
134
+ return (
135
+ f"**Generated polymer SMILES:** `{standardized_smiles}`\n\n"
136
+ f"**{nan_message}**\n\n"
137
+ f"**{novelty_status}**\n\n"
138
+ f"**Suggested Properties:**\n{suggested_properties_text}",
139
+ img,
140
+ gif_path,
141
+ properties, # Add this
142
+ suggested_properties # Add this
143
+ )
144
+ else:
145
+ return (
146
+ f"**Generation failed:** Could not generate a valid molecule.\n\n**{nan_message}**",
147
+ None,
148
+ gif_path,
149
+ properties,
150
+ None,
151
+ )
152
+ except Exception as e:
153
+ print(f"Error in generation: {e}")
154
+ continue
155
+
156
+ return f"**Generation failed:** Could not generate a valid molecule after {repeating_time} attempts.\n\n**{nan_message}**", None, None
157
+
158
+ def set_random_properties():
159
+ return random_properties()
160
+
161
+ # Create a mapping of internal names to display names
162
+ model_name_mapping = {
163
+ "model_all": "Graph DiT (trained on labeled + unlabeled)",
164
+ "model_labeled": "Graph DiT (trained on labeled)"
165
+ }
166
+
167
+ def numpy_to_python(obj):
168
+ if isinstance(obj, np.integer):
169
+ return int(obj)
170
+ elif isinstance(obj, np.floating):
171
+ return float(obj)
172
+ elif isinstance(obj, np.ndarray):
173
+ return obj.tolist()
174
+ elif isinstance(obj, list):
175
+ return [numpy_to_python(item) for item in obj]
176
+ elif isinstance(obj, dict):
177
+ return {k: numpy_to_python(v) for k, v in obj.items()}
178
+ else:
179
+ return obj
180
+
181
+ def on_generate(CH4, CO2, H2, N2, O2, guidance_scale, num_nodes, repeating_time, model_state, num_chain_steps, fps):
182
+ result = generate_graph(CH4, CO2, H2, N2, O2, guidance_scale, num_nodes, repeating_time, model_state, num_chain_steps, fps)
183
+ # Check if the generation was successful
184
+ if result[0].startswith("**Generated polymer SMILES:**"):
185
+ smiles = result[0].split("**Generated polymer SMILES:** `")[1].split("`")[0]
186
+ properties = json.dumps(numpy_to_python(result[3]))
187
+ suggested_properties = json.dumps(numpy_to_python(result[4]))
188
+ # Return the result with an enabled feedback button
189
+ return [*result[:3], smiles, properties, suggested_properties, gr.Button(interactive=True)]
190
+ else:
191
+ # Return the result with a disabled feedback button
192
+ return [*result[:3], "", "[]", "[]", gr.Button(interactive=False)]
193
+
194
+ def process_feedback(checkbox_value, smiles, properties, suggested_properties):
195
+ if checkbox_value:
196
+ # Check if properties and suggested_properties are already Python objects
197
+ if isinstance(properties, str):
198
+ properties = json.loads(properties)
199
+ if isinstance(suggested_properties, str):
200
+ suggested_properties = json.loads(suggested_properties)
201
+
202
+ save_interesting_log(smiles, properties, suggested_properties)
203
+ return gr.Textbox(value="Thank you for your feedback! This polymer has been saved to our interesting polymers log.", visible=True)
204
+ else:
205
+ return gr.Textbox(value="Thank you for your feedback!", visible=True)
206
+
207
+ # ADD THIS FUNCTION
208
+ def reset_feedback_button():
209
+ return gr.Button(interactive=False)
210
+
211
+ # Create the Gradio interface using Blocks
212
+ with gr.Blocks(title="Polymer Design with GraphDiT") as iface:
213
+ # Navigation Bar
214
+ with gr.Row(elem_id="navbar"):
215
+ gr.Markdown("""
216
+ <div style="text-align: center;">
217
+ <h1>πŸ”—πŸ”¬ Polymer Design with GraphDiT</h1>
218
+ <div style="display: flex; gap: 20px; justify-content: center; align-items: center; margin-top: 10px;">
219
+ <a href="https://github.com/liugangcode/Graph-DiT" target="_blank" style="display: flex; align-items: center; gap: 5px; text-decoration: none; color: inherit;">
220
+ <img src="https://img.icons8.com/ios-glyphs/30/000000/github.png" alt="GitHub" />
221
+ <span>View Code</span>
222
+ </a>
223
+ <a href="https://arxiv.org/abs/2401.13858" target="_blank" style="text-decoration: none; color: inherit;">
224
+ πŸ“„ View Paper
225
+ </a>
226
+ </div>
227
+ </div>
228
+ """)
229
+
230
+ # Main Description
231
+ gr.Markdown("""
232
+ ## Introduction
233
+
234
+ Input the desired gas barrier properties for CHβ‚„, COβ‚‚, Hβ‚‚, Nβ‚‚, and Oβ‚‚ to generate novel polymer structures. The results are visualized as molecular graphs and represented by SMILES strings if they are successfully generated. Note: Gas barrier values set to 0 will be treated as `NaN` (unconditionally). If the generation fails, please retry or increase the number of repetition attempts.
235
+ """)
236
+
237
+ # Model Selection
238
+ model_choice = gr.Radio(
239
+ choices=list(model_name_mapping.values()),
240
+ label="Model Zoo",
241
+ # value="Graph DiT (trained on labeled + unlabeled)"
242
+ value="Graph DiT (trained on labeled)"
243
+ )
244
+
245
+ # Model Description Accordion
246
+ with gr.Accordion("πŸ” Model Description", open=False):
247
+ gr.Markdown("""
248
+ ### GraphDiT: Graph Diffusion Transformer
249
+
250
+ GraphDiT is a graph diffusion model designed for targeted molecular generation. It employs a conditional diffusion process to iteratively refine molecular structures based on user-specified properties.
251
+
252
+ We have collected a labeled polymer database for gas permeability from [Membrane Database](https://research.csiro.au/virtualscreening/membrane-database-polymer-gas-separation-membranes/). Additionally, we utilize unlabeled polymer structures from [PolyInfo](https://polymer.nims.go.jp/).
253
+
254
+ The gas permeability ranges from 0 to over ten thousand, with only hundreds of labeled data points, making this task particularly challenging.
255
+
256
+ We are actively working on improving the model. We welcome any feedback regarding model usage or suggestions for improvement.
257
+
258
+ #### Currently, we have two variants of Graph DiT:
259
+ - **Graph DiT (trained on labeled + unlabeled)**: This model uses both labeled and unlabeled data for training, potentially leading to more diverse/novel polymer generation.
260
+ - **Graph DiT (trained on labeled)**: This model is trained exclusively on labeled data, which may result in higher validity but potentially less diverse/novel outputs.
261
+ """)
262
+
263
+ # Citation Accordion
264
+ with gr.Accordion("πŸ“„ Citation", open=False):
265
+ gr.Markdown("""
266
+ If you use this model or interface useful, please cite the following paper:
267
+ ```bibtex
268
+ @article{graphdit2024,
269
+ title={Graph Diffusion Transformers for Multi-Conditional Molecular Generation},
270
+ author={Liu, Gang and Xu, Jiaxin and Luo, Tengfei and Jiang, Meng},
271
+ journal={NeurIPS},
272
+ year={2024},
273
+ }
274
+ ```
275
+ """)
276
+
277
+ model_state = gr.State(lambda: load_model("model_all"))
278
+
279
+ with gr.Row():
280
+ CH4_input = gr.Slider(0, property_ranges['CH4'][1], value=2.5, label=f"CHβ‚„ (Barrier) [0-{property_ranges['CH4'][1]:.1f}]")
281
+ CO2_input = gr.Slider(0, property_ranges['CO2'][1], value=15.4, label=f"COβ‚‚ (Barrier) [0-{property_ranges['CO2'][1]:.1f}]")
282
+ H2_input = gr.Slider(0, property_ranges['H2'][1], value=21.0, label=f"Hβ‚‚ (Barrier) [0-{property_ranges['H2'][1]:.1f}]")
283
+ N2_input = gr.Slider(0, property_ranges['N2'][1], value=1.5, label=f"Nβ‚‚ (Barrier) [0-{property_ranges['N2'][1]:.1f}]")
284
+ O2_input = gr.Slider(0, property_ranges['O2'][1], value=2.8, label=f"Oβ‚‚ (Barrier) [0-{property_ranges['O2'][1]:.1f}]")
285
+
286
+ with gr.Row():
287
+ guidance_scale = gr.Slider(1, 3, value=2, label="Guidance Scale from Properties")
288
+ num_nodes = gr.Slider(0, 50, step=1, value=0, label="Number of Nodes (0 for Random, Larger Graphs Take More Time)")
289
+ repeating_time = gr.Slider(1, 10, step=1, value=3, label="Repetition Until Success")
290
+ num_chain_steps = gr.Slider(0, 499, step=1, value=50, label="Number of Diffusion Steps to Visualize (Larger Numbers Take More Time)")
291
+ fps = gr.Slider(0.25, 10, step=0.25, value=5, label="Frames Per Second")
292
+
293
+ with gr.Row():
294
+ random_btn = gr.Button("πŸ”€ Randomize Properties (from Labeled Data)")
295
+ generate_btn = gr.Button("πŸš€ Generate Polymer")
296
+
297
+ with gr.Row():
298
+ result_text = gr.Textbox(label="πŸ“ Generation Result")
299
+ result_image = gr.Image(label="Final Molecule Visualization", type="pil")
300
+ result_gif = gr.Image(label="Generation Process Visualization", type="filepath", format="gif")
301
+
302
+ with gr.Row() as feedback_row:
303
+ feedback_btn = gr.Button("🌟 I think this polymer is interesting!", visible=True, interactive=False)
304
+ feedback_result = gr.Textbox(label="Feedback Result", visible=False)
305
+
306
+ # Add model switching functionality
307
+ def switch_model(choice):
308
+ # Convert display name back to internal name
309
+ internal_name = next(key for key, value in model_name_mapping.items() if value == choice)
310
+ return load_model(internal_name)
311
+
312
+ model_choice.change(switch_model, inputs=[model_choice], outputs=[model_state])
313
+
314
+ # Hidden components to store generation data
315
+ hidden_smiles = gr.Textbox(visible=False)
316
+ hidden_properties = gr.JSON(visible=False)
317
+ hidden_suggested_properties = gr.JSON(visible=False)
318
+
319
+ # Set up event handlers
320
+ random_btn.click(
321
+ set_random_properties,
322
+ outputs=[CH4_input, CO2_input, H2_input, N2_input, O2_input]
323
+ )
324
+
325
+ generate_btn.click(
326
+ on_generate,
327
+ inputs=[CH4_input, CO2_input, H2_input, N2_input, O2_input, guidance_scale, num_nodes, repeating_time, model_state, num_chain_steps, fps],
328
+ outputs=[result_text, result_image, result_gif, hidden_smiles, hidden_properties, hidden_suggested_properties, feedback_btn]
329
+ )
330
+
331
+ feedback_btn.click(
332
+ process_feedback,
333
+ inputs=[gr.Checkbox(value=True, visible=False), hidden_smiles, hidden_properties, hidden_suggested_properties],
334
+ outputs=[feedback_result]
335
+ ).then(
336
+ lambda: gr.Button(interactive=False),
337
+ outputs=[feedback_btn]
338
+ )
339
+
340
+ CH4_input.change(reset_feedback_button, outputs=[feedback_btn])
341
+ CO2_input.change(reset_feedback_button, outputs=[feedback_btn])
342
+ H2_input.change(reset_feedback_button, outputs=[feedback_btn])
343
+ N2_input.change(reset_feedback_button, outputs=[feedback_btn])
344
+ O2_input.change(reset_feedback_button, outputs=[feedback_btn])
345
+ random_btn.click(reset_feedback_button, outputs=[feedback_btn])
346
+
347
+ # Launch the interface
348
+ if __name__ == "__main__":
349
+ iface.launch(share=True)