liuganghuggingface committed on
Commit
064fd0a
·
verified ·
1 Parent(s): 750cfac

Upload app.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. app.py +400 -280
app.py CHANGED
@@ -1,168 +1,137 @@
1
  import spaces
2
- import gradio as gr
3
- import torch
4
 
5
- import numpy as np
6
- import pandas as pd
7
- import random
8
- import io
9
- import imageio
10
  import os
 
 
 
 
 
11
  import tempfile
12
  import atexit
13
- import glob
14
- import csv
15
  from datetime import datetime
16
- import json
17
 
 
 
 
 
 
18
  from rdkit import Chem
19
  from rdkit.Chem import Draw
 
20
 
 
21
  from evaluator import Evaluator
22
  from loader import load_graph_decoder
23
 
24
- # Load the CSV data
25
- known_labels = pd.read_csv('data/known_labels.csv')
26
- knwon_smiles = pd.read_csv('data/known_polymers.csv')
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
27
 
28
- all_properties = ['CH4', 'CO2', 'H2', 'N2', 'O2']
29
 
30
- # Initialize evaluators
31
- evaluators = {prop: Evaluator(f'evaluators/{prop}.joblib', prop) for prop in all_properties}
32
 
33
- # Get min and max values for each property
34
- property_ranges = {prop: (known_labels[prop].min(), known_labels[prop].max()) for prop in all_properties}
 
35
 
36
- # Create a temporary directory for GIFs
37
- temp_dir = tempfile.mkdtemp(prefix="polymer_gifs_")
 
 
 
38
 
39
  def cleanup_temp_files():
40
  """Clean up temporary GIF files on exit."""
41
- for file in glob.glob(os.path.join(temp_dir, "*.gif")):
42
- try:
43
- os.remove(file)
44
- except Exception as e:
45
- print(f"Error deleting {file}: {e}")
46
  try:
 
 
47
  os.rmdir(temp_dir)
48
  except Exception as e:
49
- print(f"Error deleting temporary directory {temp_dir}: {e}")
50
 
51
- # Register the cleanup function to be called on exit
52
  atexit.register(cleanup_temp_files)
53
 
 
 
54
  def random_properties():
55
- return known_labels[all_properties].sample(1).values.tolist()[0]
 
56
 
57
  def load_model(model_choice):
 
58
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
59
  model = load_graph_decoder(path=model_choice)
60
- return (model, device)
61
-
62
- # Create a flagged folder if it doesn't exist
63
- flagged_folder = "flagged"
64
- os.makedirs(flagged_folder, exist_ok=True)
65
 
66
  def save_interesting_log(smiles, properties, suggested_properties):
67
- """Save interesting polymer data to a CSV file."""
68
- log_file = os.path.join(flagged_folder, "log.csv")
 
69
  file_exists = os.path.isfile(log_file)
70
-
71
- with open(log_file, 'a', newline='') as csvfile:
72
- fieldnames = ['timestamp', 'smiles'] + all_properties + [f'suggested_{prop}' for prop in all_properties]
73
- writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
74
-
75
- if not file_exists:
76
- writer.writeheader()
77
-
78
- log_data = {
79
- 'timestamp': datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
80
- 'smiles': smiles,
81
- **{prop: value for prop, value in zip(all_properties, properties)},
82
- **{f'suggested_{prop}': value for prop, value in suggested_properties.items()}
83
- }
84
- writer.writerow(log_data)
85
-
86
- @spaces.GPU
87
- def generate_graph(CH4, CO2, H2, N2, O2, guidance_scale, num_nodes, repeating_time, model_state, num_chain_steps, fps):
88
- print('in generate_graph')
89
- model, device = model_state
90
-
91
- properties = [CH4, CO2, H2, N2, O2]
92
-
93
- def is_nan_like(x):
94
- return x == 0 or x == '' or (isinstance(x, float) and np.isnan(x))
95
-
96
- properties = [None if is_nan_like(prop) else prop for prop in properties]
97
-
98
- nan_message = "The following gas properties were treated as NaN: "
99
- nan_gases = [gas for gas, prop in zip(all_properties, properties) if prop is None]
100
- nan_message += ", ".join(nan_gases) if nan_gases else "None"
101
 
102
- num_nodes = None if num_nodes == 0 else num_nodes
103
-
104
- for _ in range(repeating_time):
105
- # try:
106
- model.to(device)
107
- generated_molecule, img_list = model.generate(properties, device=device, guide_scale=guidance_scale, num_nodes=num_nodes, number_chain_steps=num_chain_steps)
108
-
109
- # Create GIF if img_list is available
110
- gif_path = None
111
- if img_list and len(img_list) > 0:
112
- imgs = [np.array(pil_img) for pil_img in img_list]
113
- imgs.extend([imgs[-1]] * 10)
114
- gif_path = os.path.join(temp_dir, f"polymer_gen_{random.randint(0, 999999)}.gif")
115
- imageio.mimsave(gif_path, imgs, format='GIF', fps=fps, loop=0)
116
-
117
- if generated_molecule is not None:
118
- mol = Chem.MolFromSmiles(generated_molecule)
119
- if mol is not None:
120
- standardized_smiles = Chem.MolToSmiles(mol, isomericSmiles=True)
121
- is_novel = standardized_smiles not in knwon_smiles['SMILES'].values
122
- novelty_status = "Novel (Not in Labeled Set)" if is_novel else "Not Novel (Exists in Labeled Set)"
123
- img = Draw.MolToImage(mol)
124
-
125
- # Evaluate the generated molecule
126
- suggested_properties = {}
127
- for prop, evaluator in evaluators.items():
128
- suggested_properties[prop] = evaluator([standardized_smiles])[0]
129
-
130
- suggested_properties_text = "\n".join([f"**Suggested {prop}:** {value:.2f}" for prop, value in suggested_properties.items()])
131
-
132
- return (
133
- f"**Generated polymer SMILES:** `{standardized_smiles}`\n\n"
134
- f"**{nan_message}**\n\n"
135
- f"**{novelty_status}**\n\n"
136
- f"**Suggested Properties:**\n{suggested_properties_text}",
137
- img,
138
- gif_path,
139
- properties, # Add this
140
- suggested_properties # Add this
141
- )
142
- else:
143
- return (
144
- f"**Generation failed:** Could not generate a valid molecule.\n\n**{nan_message}**",
145
- None,
146
- gif_path,
147
- properties,
148
- None,
149
- )
150
- # except Exception as e:
151
- # print(f"Error in generation: {e}")
152
- # continue
153
-
154
- return f"**Generation failed:** Could not generate a valid molecule after {repeating_time} attempts.\n\n**{nan_message}**", None, None
155
 
156
- def set_random_properties():
157
- return random_properties()
 
 
 
 
 
 
 
 
 
 
 
 
 
158
 
159
- # Create a mapping of internal names to display names
160
- model_name_mapping = {
161
- "model_all": "Graph DiT (trained on labeled + unlabeled)",
162
- "model_labeled": "Graph DiT (trained on labeled)"
163
- }
164
 
165
  def numpy_to_python(obj):
 
166
  if isinstance(obj, np.integer):
167
  return int(obj)
168
  elif isinstance(obj, np.floating):
@@ -175,175 +144,326 @@ def numpy_to_python(obj):
175
  return {k: numpy_to_python(v) for k, v in obj.items()}
176
  else:
177
  return obj
178
-
179
- def on_generate(CH4, CO2, H2, N2, O2, guidance_scale, num_nodes, repeating_time, model_state, num_chain_steps, fps):
180
- print('in on_generate', on_generate)
181
- result = generate_graph(CH4, CO2, H2, N2, O2, guidance_scale, num_nodes, repeating_time, model_state, num_chain_steps, fps)
182
- # Check if the generation was successful
183
- if result[0].startswith("**Generated polymer SMILES:**"):
184
- smiles = result[0].split("**Generated polymer SMILES:** `")[1].split("`")[0]
185
- properties = json.dumps(numpy_to_python(result[3]))
186
- suggested_properties = json.dumps(numpy_to_python(result[4]))
187
- # Return the result with an enabled feedback button
188
- return [*result[:3], smiles, properties, suggested_properties, gr.Button(interactive=True)]
189
- else:
190
- # Return the result with a disabled feedback button
191
- return [*result[:3], "", "[]", "[]", gr.Button(interactive=False)]
192
 
193
- def process_feedback(checkbox_value, smiles, properties, suggested_properties):
194
- if checkbox_value:
195
- # Check if properties and suggested_properties are already Python objects
196
- if isinstance(properties, str):
197
- properties = json.loads(properties)
198
- if isinstance(suggested_properties, str):
199
- suggested_properties = json.loads(suggested_properties)
200
-
201
- save_interesting_log(smiles, properties, suggested_properties)
202
- return gr.Textbox(value="Thank you for your feedback! This polymer has been saved to our interesting polymers log.", visible=True)
203
- else:
204
- return gr.Textbox(value="Thank you for your feedback!", visible=True)
205
 
206
- # ADD THIS FUNCTION
207
- def reset_feedback_button():
208
- return gr.Button(interactive=False)
209
 
210
- # Create the Gradio interface using Blocks
211
- with gr.Blocks(title="Polymer Design with GraphDiT") as iface:
212
- # Navigation Bar
213
- with gr.Row(elem_id="navbar"):
214
- gr.Markdown("""
215
- <div style="text-align: center;">
216
- <h1>πŸ”—πŸ”¬ Polymer Design with GraphDiT</h1>
217
- <div style="display: flex; gap: 20px; justify-content: center; align-items: center; margin-top: 10px;">
218
- <a href="https://github.com/liugangcode/Graph-DiT" target="_blank" style="display: flex; align-items: center; gap: 5px; text-decoration: none; color: inherit;">
219
- <img src="https://img.icons8.com/ios-glyphs/30/000000/github.png" alt="GitHub" />
220
- <span>View Code</span>
221
- </a>
222
- <a href="https://arxiv.org/abs/2401.13858" target="_blank" style="text-decoration: none; color: inherit;">
223
- πŸ“„ View Paper
224
- </a>
225
- </div>
226
- </div>
227
- """)
228
 
229
- # Main Description
230
- gr.Markdown("""
231
- ## Introduction
232
 
233
- Input the desired gas barrier properties for CHβ‚„, COβ‚‚, Hβ‚‚, Nβ‚‚, and Oβ‚‚ to generate novel polymer structures. The results are visualized as molecular graphs and represented by SMILES strings if they are successfully generated. Note: Gas barrier values set to 0 will be treated as `NaN` (unconditionally). If the generation fails, please retry or increase the number of repetition attempts.
234
- """)
 
 
 
 
 
 
 
235
 
236
- # Model Selection
237
- model_choice = gr.Radio(
238
- choices=list(model_name_mapping.values()),
239
- label="Model Zoo",
240
- # value="Graph DiT (trained on labeled + unlabeled)"
241
- value="Graph DiT (trained on labeled)"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
242
  )
243
 
244
- # Model Description Accordion
245
- with gr.Accordion("πŸ” Model Description", open=False):
246
- gr.Markdown("""
247
- ### GraphDiT: Graph Diffusion Transformer
248
 
249
- GraphDiT is a graph diffusion model designed for targeted molecular generation. It employs a conditional diffusion process to iteratively refine molecular structures based on user-specified properties.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
250
 
251
- We have collected a labeled polymer database for gas permeability from [Membrane Database](https://research.csiro.au/virtualscreening/membrane-database-polymer-gas-separation-membranes/). Additionally, we utilize unlabeled polymer structures from [PolyInfo](https://polymer.nims.go.jp/).
 
 
252
 
253
- The gas permeability ranges from 0 to over ten thousand, with only hundreds of labeled data points, making this task particularly challenging.
 
254
 
255
- We are actively working on improving the model. We welcome any feedback regarding model usage or suggestions for improvement.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
256
 
257
- #### Currently, we have two variants of Graph DiT:
258
- - **Graph DiT (trained on labeled + unlabeled)**: This model uses both labeled and unlabeled data for training, potentially leading to more diverse/novel polymer generation.
259
- - **Graph DiT (trained on labeled)**: This model is trained exclusively on labeled data, which may result in higher validity but potentially less diverse/novel outputs.
260
- """)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
261
 
262
- # Citation Accordion
263
- with gr.Accordion("πŸ“„ Citation", open=False):
264
- gr.Markdown("""
265
- If you use this model or interface useful, please cite the following paper:
266
- ```bibtex
267
- @article{graphdit2024,
268
- title={Graph Diffusion Transformers for Multi-Conditional Molecular Generation},
269
- author={Liu, Gang and Xu, Jiaxin and Luo, Tengfei and Jiang, Meng},
270
- journal={NeurIPS},
271
- year={2024},
272
- }
273
- ```
274
- """)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
275
 
276
- model_state = gr.State(lambda: load_model("model_labeled"))
277
-
278
- with gr.Row():
279
- CH4_input = gr.Slider(0, property_ranges['CH4'][1], value=2.5, label=f"CHβ‚„ (Barrier) [0-{property_ranges['CH4'][1]:.1f}]")
280
- CO2_input = gr.Slider(0, property_ranges['CO2'][1], value=15.4, label=f"COβ‚‚ (Barrier) [0-{property_ranges['CO2'][1]:.1f}]")
281
- H2_input = gr.Slider(0, property_ranges['H2'][1], value=21.0, label=f"Hβ‚‚ (Barrier) [0-{property_ranges['H2'][1]:.1f}]")
282
- N2_input = gr.Slider(0, property_ranges['N2'][1], value=1.5, label=f"Nβ‚‚ (Barrier) [0-{property_ranges['N2'][1]:.1f}]")
283
- O2_input = gr.Slider(0, property_ranges['O2'][1], value=2.8, label=f"Oβ‚‚ (Barrier) [0-{property_ranges['O2'][1]:.1f}]")
284
-
285
- with gr.Row():
286
- guidance_scale = gr.Slider(1, 3, value=2, label="Guidance Scale from Properties")
287
- num_nodes = gr.Slider(0, 50, step=1, value=0, label="Number of Nodes (0 for Random, Larger Graphs Take More Time)")
288
- repeating_time = gr.Slider(1, 10, step=1, value=3, label="Repetition Until Success")
289
- num_chain_steps = gr.Slider(0, 499, step=1, value=50, label="Number of Diffusion Steps to Visualize (Larger Numbers Take More Time)")
290
- fps = gr.Slider(0.25, 10, step=0.25, value=5, label="Frames Per Second")
291
-
292
- with gr.Row():
293
- random_btn = gr.Button("πŸ”€ Randomize Properties (from Labeled Data)")
294
- generate_btn = gr.Button("πŸš€ Generate Polymer")
295
-
296
- with gr.Row():
297
- result_text = gr.Textbox(label="πŸ“ Generation Result")
298
- result_image = gr.Image(label="Final Molecule Visualization", type="pil")
299
- result_gif = gr.Image(label="Generation Process Visualization", type="filepath", format="gif")
300
-
301
- with gr.Row() as feedback_row:
302
- feedback_btn = gr.Button("🌟 I think this polymer is interesting!", visible=True, interactive=False)
303
- feedback_result = gr.Textbox(label="Feedback Result", visible=False)
304
-
305
- # Add model switching functionality
306
- def switch_model(choice):
307
- # Convert display name back to internal name
308
- internal_name = next(key for key, value in model_name_mapping.items() if value == choice)
309
- return load_model(internal_name)
310
-
311
- model_choice.change(switch_model, inputs=[model_choice], outputs=[model_state])
312
-
313
- # Hidden components to store generation data
314
- hidden_smiles = gr.Textbox(visible=False)
315
- hidden_properties = gr.JSON(visible=False)
316
- hidden_suggested_properties = gr.JSON(visible=False)
317
-
318
- # Set up event handlers
319
- random_btn.click(
320
- set_random_properties,
321
- outputs=[CH4_input, CO2_input, H2_input, N2_input, O2_input]
322
- )
323
 
324
- generate_btn.click(
325
- on_generate,
326
- inputs=[CH4_input, CO2_input, H2_input, N2_input, O2_input, guidance_scale, num_nodes, repeating_time, model_state, num_chain_steps, fps],
327
- outputs=[result_text, result_image, result_gif, hidden_smiles, hidden_properties, hidden_suggested_properties, feedback_btn]
328
- )
329
 
330
- feedback_btn.click(
331
- process_feedback,
332
- inputs=[gr.Checkbox(value=True, visible=False), hidden_smiles, hidden_properties, hidden_suggested_properties],
333
- outputs=[feedback_result]
334
- ).then(
335
- lambda: gr.Button(interactive=False),
336
- outputs=[feedback_btn]
337
- )
338
-
339
- CH4_input.change(reset_feedback_button, outputs=[feedback_btn])
340
- CO2_input.change(reset_feedback_button, outputs=[feedback_btn])
341
- H2_input.change(reset_feedback_button, outputs=[feedback_btn])
342
- N2_input.change(reset_feedback_button, outputs=[feedback_btn])
343
- O2_input.change(reset_feedback_button, outputs=[feedback_btn])
344
- random_btn.click(reset_feedback_button, outputs=[feedback_btn])
345
-
346
- # Launch the interface
347
  if __name__ == "__main__":
348
- # iface.launch(share=True)
349
- iface.launch(share=False)
 
1
  import spaces
 
 
2
 
3
+ # Standard Libraries
 
 
 
 
4
  import os
5
+ import io
6
+ import csv
7
+ import json
8
+ import glob
9
+ import random
10
  import tempfile
11
  import atexit
 
 
12
  from datetime import datetime
 
13
 
14
+ # Third-Party Libraries
15
+ import numpy as np
16
+ import pandas as pd
17
+ import torch
18
+ import imageio
19
  from rdkit import Chem
20
  from rdkit.Chem import Draw
21
+ import gradio as gr
22
 
23
+ # Local Modules
24
  from evaluator import Evaluator
25
  from loader import load_graph_decoder
26
 
27
+ # --------------------------- Configuration Constants --------------------------- #
28
+
29
+ DATA_DIR = 'data'
30
+ EVALUATORS_DIR = 'evaluators'
31
+ FLAGGED_FOLDER = "flagged"
32
+ KNOWN_LABELS_FILE = os.path.join(DATA_DIR, 'known_labels.csv')
33
+ KNOWN_SMILES_FILE = os.path.join(DATA_DIR, 'known_polymers.csv')
34
+
35
+ ALL_PROPERTIES = ['CH4', 'CO2', 'H2', 'N2', 'O2']
36
+ MODEL_NAME_MAPPING = {
37
+ "model_all": "Graph DiT (trained on labeled + unlabeled)",
38
+ "model_labeled": "Graph DiT (trained on labeled)"
39
+ }
40
+
41
+ GIF_TEMP_PREFIX = "polymer_gifs_"
42
+
43
+ # --------------------------- Data Loading --------------------------- #
44
+
45
def load_known_data():
    """Load the known property labels and known polymer SMILES tables.

    Reads ``KNOWN_LABELS_FILE`` and ``KNOWN_SMILES_FILE`` with
    ``pandas.read_csv``.

    Returns:
        tuple: ``(known_labels, known_smiles)``, one DataFrame per file.

    Raises:
        FileNotFoundError: If either file cannot be read or parsed. The
            underlying error is attached as ``__cause__``.
    """
    try:
        known_labels = pd.read_csv(KNOWN_LABELS_FILE)
        known_smiles = pd.read_csv(KNOWN_SMILES_FILE)
    except Exception as e:
        # The exception type is kept for existing callers; chaining with
        # ``from e`` preserves the real failure (missing file, parse error,
        # permissions, ...) in the traceback instead of discarding it.
        raise FileNotFoundError(f"Error loading data files: {e}") from e
    return known_labels, known_smiles
53
+
54
+ # Load data
55
+ known_labels, known_smiles = load_known_data()
56
+
57
+ # --------------------------- Evaluator Setup --------------------------- #
58
+
59
def initialize_evaluators(properties, evaluators_dir):
    """Build one ``Evaluator`` per property name.

    Args:
        properties: Iterable of property names.
        evaluators_dir: Directory holding one ``<property>.joblib`` file per
            property.

    Returns:
        dict: Property name mapped to ``Evaluator(<joblib path>, <name>)``.
    """
    return {
        name: Evaluator(os.path.join(evaluators_dir, f'{name}.joblib'), name)
        for name in properties
    }
66
 
67
+ evaluators = initialize_evaluators(ALL_PROPERTIES, EVALUATORS_DIR)
68
 
69
+ # --------------------------- Property Ranges --------------------------- #
 
70
 
71
def get_property_ranges(labels, properties):
    """Compute the ``(min, max)`` pair of every requested column.

    Args:
        labels: Table indexable by column name whose columns expose
            ``min()`` and ``max()``.
        properties: Iterable of column names to summarise.

    Returns:
        dict: Column name mapped to a ``(minimum, maximum)`` tuple.
    """
    ranges = {}
    for name in properties:
        column = labels[name]
        ranges[name] = (column.min(), column.max())
    return ranges
74
 
75
+ property_ranges = get_property_ranges(known_labels, ALL_PROPERTIES)
76
+
77
+ # --------------------------- Temporary Directory Setup --------------------------- #
78
+
79
+ temp_dir = tempfile.mkdtemp(prefix=GIF_TEMP_PREFIX)
80
 
81
def cleanup_temp_files():
    """Clean up temporary GIF files on exit.

    Deletes every ``*.gif`` file inside ``temp_dir`` and then removes the
    directory itself. Cleanup is best-effort: each failure is printed and
    never raised.
    """
    for gif_file in glob.glob(os.path.join(temp_dir, "*.gif")):
        # Each file gets its own handler so that one undeletable entry does
        # not prevent the remaining files from being removed.
        try:
            os.remove(gif_file)
        except OSError as e:
            print(f"Error deleting {gif_file}: {e}")
    try:
        # rmdir only succeeds on an empty directory; anything left behind
        # above is reported here rather than deleted.
        os.rmdir(temp_dir)
    except OSError as e:
        print(f"Error deleting temporary directory {temp_dir}: {e}")
89
 
 
90
  atexit.register(cleanup_temp_files)
91
 
92
+ # --------------------------- Utility Functions --------------------------- #
93
+
94
def random_properties():
    """Pick the property values of one randomly sampled known-label row.

    Returns:
        list: The sampled row's values, ordered like ``ALL_PROPERTIES``.
    """
    sampled_row = known_labels[ALL_PROPERTIES].sample(1)
    rows_as_lists = sampled_row.values.tolist()
    return rows_as_lists[0]
97
 
98
def load_model(model_choice):
    """Load a graph decoder and move it to the available device.

    Args:
        model_choice: Value forwarded as ``path`` to ``load_graph_decoder``.

    Returns:
        tuple: ``(model, device)`` where ``device`` is CUDA when
        ``torch.cuda.is_available()`` reports it and CPU otherwise.
    """
    if torch.cuda.is_available():
        device = torch.device("cuda")
    else:
        device = torch.device("cpu")
    model = load_graph_decoder(path=model_choice)
    model.to(device)
    return model, device
 
 
 
104
 
105
def save_interesting_log(smiles, properties, suggested_properties):
    """Append one flagged polymer to ``<FLAGGED_FOLDER>/log.csv``.

    The row holds a timestamp, the SMILES string, the requested value of
    every property in ``ALL_PROPERTIES`` and the ``suggested_<property>``
    values. Logging is best-effort: any failure is printed and never raised.

    Args:
        smiles: SMILES string of the flagged polymer.
        properties: Requested property values, ordered like
            ``ALL_PROPERTIES``.
        suggested_properties: Mapping of property name to suggested value.
    """
    log_file = os.path.join(FLAGGED_FOLDER, "log.csv")
    fieldnames = (
        ['timestamp', 'smiles']
        + ALL_PROPERTIES
        + [f'suggested_{prop}' for prop in ALL_PROPERTIES]
    )

    try:
        os.makedirs(FLAGGED_FOLDER, exist_ok=True)
        # A file that exists but is still empty needs a header too, otherwise
        # the first data row would later be read as the column names.
        needs_header = not os.path.isfile(log_file) or os.path.getsize(log_file) == 0

        with open(log_file, 'a', newline='', encoding='utf-8') as csvfile:
            writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
            if needs_header:
                writer.writeheader()

            log_data = {
                'timestamp': datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
                'smiles': smiles,
                **dict(zip(ALL_PROPERTIES, properties)),
                **{f'suggested_{prop}': value for prop, value in suggested_properties.items()},
            }
            writer.writerow(log_data)
    except Exception as e:
        # Deliberately broad: a logging failure must not break the UI callback.
        print(f"Error saving log: {e}")
128
 
129
def is_nan_like(x):
    """Tell whether a property input counts as "not provided".

    A value is NaN-like when it equals ``0``, equals the empty string, or is
    a ``float`` for which ``np.isnan`` is true.
    """
    if x == 0 or x == '':
        return True
    return isinstance(x, float) and np.isnan(x)
 
 
132
 
133
  def numpy_to_python(obj):
134
+ """Convert NumPy objects to native Python types."""
135
  if isinstance(obj, np.integer):
136
  return int(obj)
137
  elif isinstance(obj, np.floating):
 
144
  return {k: numpy_to_python(v) for k, v in obj.items()}
145
  else:
146
  return obj
 
 
 
 
 
 
 
 
 
 
 
 
 
 
147
 
148
+ # --------------------------- Graph Generation Function --------------------------- #
149
+
150
def _save_generation_gif(img_list, fps):
    """Write the generation snapshots to a uniquely named GIF and return its path."""
    frames = [np.array(pil_img) for pil_img in img_list]
    frames.extend([frames[-1]] * 10)  # Hold the last image so the result stays visible
    # mkstemp reserves a unique name inside temp_dir, so concurrent requests
    # cannot overwrite each other's GIF the way a random integer suffix could.
    fd, gif_path = tempfile.mkstemp(prefix="polymer_gen_", suffix=".gif", dir=temp_dir)
    os.close(fd)
    imageio.mimsave(gif_path, frames, format='GIF', fps=fps, loop=0)
    return gif_path


@spaces.GPU
def generate_graph(CH4, CO2, H2, N2, O2, guidance_scale, num_nodes, repeating_time, model_state, num_chain_steps, fps):
    """Generate a polymer conditioned on the requested gas properties.

    Args:
        CH4, CO2, H2, N2, O2: Requested property values. Values for which
            ``is_nan_like`` is true are passed to the model as ``None``.
        guidance_scale: Forwarded to ``model.generate`` as ``guide_scale``.
        num_nodes: Forwarded as ``num_nodes``; ``0`` is replaced by ``None``.
        repeating_time: Maximum number of generation attempts.
        model_state: ``(model, device)`` pair.
        num_chain_steps: Forwarded as ``number_chain_steps``.
        fps: Frame rate of the generated GIF.

    Returns:
        tuple: ``(message, image, gif_path, smiles, properties,
        suggested_properties)``. When no attempt yields a molecule that RDKit
        can parse, the tuple is ``(failure message, None, None, "", [], {})``.
    """
    print('Generating graph...')
    model, device = model_state

    # Handle NaN-like values
    properties = [None if is_nan_like(prop) else prop for prop in (CH4, CO2, H2, N2, O2)]

    nan_gases = [gas for gas, prop in zip(ALL_PROPERTIES, properties) if prop is None]
    nan_message = "The following gas properties were treated as NaN: " + (", ".join(nan_gases) if nan_gases else "None")

    num_nodes = None if num_nodes == 0 else num_nodes

    for attempt in range(repeating_time):
        try:
            generated_molecule, img_list = model.generate(
                properties,
                device=device,
                guide_scale=guidance_scale,
                num_nodes=num_nodes,
                number_chain_steps=num_chain_steps
            )

            gif_path = _save_generation_gif(img_list, fps) if img_list else None

            if not generated_molecule:
                continue  # Nothing generated; try again
            mol = Chem.MolFromSmiles(generated_molecule)
            if mol is None:
                continue  # RDKit rejected the SMILES; try again

            standardized_smiles = Chem.MolToSmiles(mol, isomericSmiles=True)
            is_novel = standardized_smiles not in known_smiles['SMILES'].values
            novelty_status = "Novel (Not in Labeled Set)" if is_novel else "Not Novel (Exists in Labeled Set)"
            img = Draw.MolToImage(mol)

            # Evaluate the generated molecule
            suggested_properties = {prop: evaluator([standardized_smiles])[0] for prop, evaluator in evaluators.items()}

            suggested_properties_text = "\n".join([f"**Suggested {prop}:** {value:.2f}" for prop, value in suggested_properties.items()])

            return (
                f"**Generated polymer SMILES:** `{standardized_smiles}`\n\n"
                f"**{nan_message}**\n\n"
                f"**{novelty_status}**\n\n"
                f"**Suggested Properties:**\n{suggested_properties_text}",
                img,
                gif_path,
                standardized_smiles,
                properties,
                suggested_properties
            )
        except Exception as e:
            # Best-effort retry loop: report the failure and move on to the
            # next attempt instead of surfacing the error to the UI.
            print(f"Attempt {attempt + 1} failed: {e}")
            continue

    # If all attempts fail
    return (
        f"**Generation failed:** Could not generate a valid molecule after {repeating_time} attempts.\n\n**{nan_message}**",
        None,
        None,
        "",
        [],
        {}
    )
222
 
223
+ # --------------------------- Feedback Processing --------------------------- #
 
 
 
224
 
225
def process_feedback(checkbox_value, smiles, properties, suggested_properties):
    """Record a polymer the user flagged as interesting.

    Args:
        checkbox_value: Truthy when the user marked the polymer as
            interesting.
        smiles: SMILES string of the generated polymer; empty when the last
            generation failed.
        properties: Requested property values, either as a list or as a
            JSON-encoded string.
        suggested_properties: Suggested property values, either as a dict or
            as a JSON-encoded string.

    Returns:
        str: Acknowledgement message for the user.
    """
    if not (checkbox_value and smiles):
        return "Thank you for your feedback!"

    # The generation handler serialises these values with json.dumps before
    # storing them, so decode them if they arrive as strings.
    if isinstance(properties, str):
        properties = json.loads(properties)
    if isinstance(suggested_properties, str):
        suggested_properties = json.loads(suggested_properties)

    save_interesting_log(smiles, properties, suggested_properties)
    return "Thank you for your feedback! This polymer has been saved to our interesting polymers log."
234
+
235
+ # --------------------------- Model Switching --------------------------- #
236
+
237
def switch_model(choice):
    """Load the model whose display name equals ``choice``.

    Args:
        choice: A display name from ``MODEL_NAME_MAPPING``.

    Returns:
        Whatever ``load_model`` returns for the matching internal name.
    """
    # Reverse lookup: first internal name whose display label matches.
    matching_names = (
        name for name, label in MODEL_NAME_MAPPING.items() if label == choice
    )
    internal_name = next(matching_names)
    return load_model(internal_name)
241
+
242
+ # --------------------------- Gradio Interface Setup --------------------------- #
243
+
244
def create_gradio_interface():
    """Create and return the Gradio Blocks interface.

    Builds the page (header, description, model selector, property and
    generation sliders, result widgets, feedback row) and wires the event
    handlers:

    * changing the model selector reloads the model state;
    * the randomize button fills the property sliders from labeled data;
    * the generate button runs ``generate_graph`` and then enables the
      feedback button only when a polymer was produced;
    * the feedback button logs the polymer, shows the acknowledgement and
      disables itself;
    * changing any property slider, or clicking the randomize button,
      disables the feedback button again.

    Returns:
        gr.Blocks: The assembled interface, not yet launched.
    """
    success_prefix = "**Generated polymer SMILES:**"

    def update_feedback_state(text, smiles, props, sugg_props):
        """Store the generation data and toggle the feedback button."""
        succeeded = text.startswith(success_prefix)
        return (
            smiles if succeeded else "",
            json.dumps(numpy_to_python(props)),
            json.dumps(numpy_to_python(sugg_props)),
            gr.Button(interactive=succeeded),
        )

    def reset_feedback_button():
        """Disable the feedback button once the inputs no longer match the result."""
        return gr.Button(interactive=False)

    with gr.Blocks(title="Polymer Design with GraphDiT") as iface:
        # Navigation Bar
        with gr.Row(elem_id="navbar"):
            gr.Markdown("""
            <div style="text-align: center;">
            <h1>🔗🔬 Polymer Design with GraphDiT</h1>
            <div style="display: flex; gap: 20px; justify-content: center; align-items: center; margin-top: 10px;">
            <a href="https://github.com/liugangcode/Graph-DiT" target="_blank" style="display: flex; align-items: center; gap: 5px; text-decoration: none; color: inherit;">
            <img src="https://img.icons8.com/ios-glyphs/30/000000/github.png" alt="GitHub" />
            <span>View Code</span>
            </a>
            <a href="https://arxiv.org/abs/2401.13858" target="_blank" style="text-decoration: none; color: inherit;">
            📄 View Paper
            </a>
            </div>
            </div>
            """)

        # Main Description
        gr.Markdown("""
        ## Introduction

        Input the desired gas barrier properties for CH₄, CO₂, H₂, N₂, and O₂ to generate novel polymer structures. The results are visualized as molecular graphs and represented by SMILES strings if they are successfully generated. **Note:** Gas barrier values set to 0 will be treated as `NaN` (unconditionally). If the generation fails, please retry or increase the number of repetition attempts.
        """)

        # Model Selection
        model_choice = gr.Radio(
            choices=list(MODEL_NAME_MAPPING.values()),
            label="Model Zoo",
            value=MODEL_NAME_MAPPING["model_labeled"]
        )

        # Model Description Accordion
        with gr.Accordion("🔍 Model Description", open=False):
            gr.Markdown("""
            ### GraphDiT: Graph Diffusion Transformer

            GraphDiT is a graph diffusion model designed for targeted molecular generation. It employs a conditional diffusion process to iteratively refine molecular structures based on user-specified properties.

            We have collected a labeled polymer database for gas permeability from [Membrane Database](https://research.csiro.au/virtualscreening/membrane-database-polymer-gas-separation-membranes/). Additionally, we utilize unlabeled polymer structures from [PolyInfo](https://polymer.nims.go.jp/).

            The gas permeability ranges from 0 to over ten thousand, with only hundreds of labeled data points, making this task particularly challenging.

            We are actively working on improving the model. We welcome any feedback regarding model usage or suggestions for improvement.

            #### Currently, we have two variants of Graph DiT:
            - **Graph DiT (trained on labeled + unlabeled)**: This model uses both labeled and unlabeled data for training, potentially leading to more diverse/novel polymer generation.
            - **Graph DiT (trained on labeled)**: This model is trained exclusively on labeled data, which may result in higher validity but potentially less diverse/novel outputs.
            """)

        # Citation Accordion
        with gr.Accordion("📄 Citation", open=False):
            gr.Markdown("""
            If you use this model or interface useful, please cite the following paper:
            ```bibtex
            @article{graphdit2024,
            title={Graph Diffusion Transformers for Multi-Conditional Molecular Generation},
            author={Liu, Gang and Xu, Jiaxin and Luo, Tengfei and Jiang, Meng},
            journal={NeurIPS},
            year={2024},
            }
            ```
            """)

        # Initialize Model State
        model_state = gr.State(load_model("model_labeled"))

        # Property Inputs
        with gr.Row():
            CH4_input = gr.Slider(
                minimum=0,
                maximum=property_ranges['CH4'][1],
                value=2.5,
                label=f"CH₄ (Barrier) [0-{property_ranges['CH4'][1]:.1f}]"
            )
            CO2_input = gr.Slider(
                minimum=0,
                maximum=property_ranges['CO2'][1],
                value=15.4,
                label=f"CO₂ (Barrier) [0-{property_ranges['CO2'][1]:.1f}]"
            )
            H2_input = gr.Slider(
                minimum=0,
                maximum=property_ranges['H2'][1],
                value=21.0,
                label=f"H₂ (Barrier) [0-{property_ranges['H2'][1]:.1f}]"
            )
            N2_input = gr.Slider(
                minimum=0,
                maximum=property_ranges['N2'][1],
                value=1.5,
                label=f"N₂ (Barrier) [0-{property_ranges['N2'][1]:.1f}]"
            )
            O2_input = gr.Slider(
                minimum=0,
                maximum=property_ranges['O2'][1],
                value=2.8,
                label=f"O₂ (Barrier) [0-{property_ranges['O2'][1]:.1f}]"
            )

        # Generation Parameters
        with gr.Row():
            guidance_scale = gr.Slider(
                minimum=1,
                maximum=3,
                value=2,
                label="Guidance Scale from Properties"
            )
            num_nodes = gr.Slider(
                minimum=0,
                maximum=50,
                step=1,
                value=0,
                label="Number of Nodes (0 for Random, Larger Graphs Take More Time)"
            )
            repeating_time = gr.Slider(
                minimum=1,
                maximum=10,
                step=1,
                value=3,
                label="Repetition Until Success"
            )
            num_chain_steps = gr.Slider(
                minimum=0,
                maximum=499,
                step=1,
                value=50,
                label="Number of Diffusion Steps to Visualize (Larger Numbers Take More Time)"
            )
            fps = gr.Slider(
                minimum=0.25,
                maximum=10,
                step=0.25,
                value=5,
                label="Frames Per Second"
            )

        # Action Buttons
        with gr.Row():
            random_btn = gr.Button("🔀 Randomize Properties (from Labeled Data)")
            generate_btn = gr.Button("🚀 Generate Polymer")

        # Results Display
        with gr.Row():
            result_text = gr.Textbox(label="📝 Generation Result", lines=10)
            result_image = gr.Image(label="Final Molecule Visualization", type="pil")
            result_gif = gr.Image(label="Generation Process Visualization", type="filepath", format="gif")

        # Feedback Section
        with gr.Row():
            feedback_btn = gr.Button("🌟 I think this polymer is interesting!", interactive=False)
            feedback_result = gr.Textbox(label="Feedback Result", visible=False)

        # Hidden Components to Store Generation Data
        hidden_smiles = gr.Textbox(visible=False)
        hidden_properties = gr.JSON(visible=False)
        hidden_suggested_properties = gr.JSON(visible=False)

        # Event Handlers

        # Model Selection Change
        model_choice.change(
            switch_model,
            inputs=[model_choice],
            outputs=[model_state]
        )

        # Randomize Properties Button
        random_btn.click(
            random_properties,
            outputs=[CH4_input, CO2_input, H2_input, N2_input, O2_input]
        )

        # Generate Polymer Button
        generate_btn.click(
            generate_graph,
            inputs=[
                CH4_input, CO2_input, H2_input, N2_input, O2_input,
                guidance_scale, num_nodes, repeating_time,
                model_state, num_chain_steps, fps
            ],
            outputs=[
                result_text, result_image, result_gif,
                hidden_smiles, hidden_properties, hidden_suggested_properties
            ]
        ).then(
            # Only the text and the stored data are needed here; the image
            # and GIF are not sent back through the handler.
            update_feedback_state,
            inputs=[result_text, hidden_smiles, hidden_properties, hidden_suggested_properties],
            outputs=[hidden_smiles, hidden_properties, hidden_suggested_properties, feedback_btn]
        )

        # Feedback Button Click
        feedback_btn.click(
            process_feedback,
            inputs=[gr.Checkbox(label="Interested?", value=True, visible=False), hidden_smiles, hidden_properties, hidden_suggested_properties],
            outputs=[feedback_result]
        ).then(
            # feedback_result is created hidden and process_feedback returns
            # a plain string, so the box has to be revealed explicitly.
            lambda: (gr.Button(interactive=False), gr.update(visible=True)),
            outputs=[feedback_btn, feedback_result]
        )

        # Reset Feedback Button on Input Changes.
        # Sliders use `change`; the button is wired through `click`. The reset
        # runs in Python rather than through a JavaScript snippet, because no
        # `feedback_btn` variable exists in the browser.
        for slider in (CH4_input, CO2_input, H2_input, N2_input, O2_input):
            slider.change(reset_feedback_button, outputs=[feedback_btn])
        random_btn.click(reset_feedback_button, outputs=[feedback_btn])

    return iface
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
464
 
465
+ # --------------------------- Main Execution --------------------------- #
 
 
 
 
466
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
467
if __name__ == "__main__":
    # Build the UI and serve it without a public share link.
    create_gradio_interface().launch(share=False)