liuganghuggingface committed on
Commit
e9af19b
·
verified ·
1 Parent(s): 00c454f

Upload app.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. app.py +315 -402
app.py CHANGED
@@ -1,136 +1,203 @@
1
  import spaces
 
2
 
3
- # Standard Libraries
4
- import os
5
- import io
6
- import csv
7
- import json
8
- import glob
9
  import random
 
 
 
10
  import tempfile
11
  import atexit
 
 
12
  from datetime import datetime
 
13
 
14
- # Third-Party Libraries
15
- import numpy as np
16
- import pandas as pd
17
- import torch
18
- import imageio
19
  from rdkit import Chem
20
  from rdkit.Chem import Draw
21
- import gradio as gr
22
 
23
- # Local Modules
24
  from evaluator import Evaluator
25
- from loader import load_graph_decoder
26
-
27
- # --------------------------- Configuration Constants --------------------------- #
28
-
29
- DATA_DIR = 'data'
30
- EVALUATORS_DIR = 'evaluators'
31
- FLAGGED_FOLDER = "flagged"
32
- KNOWN_LABELS_FILE = os.path.join(DATA_DIR, 'known_labels.csv')
33
- KNOWN_SMILES_FILE = os.path.join(DATA_DIR, 'known_polymers.csv')
34
-
35
- ALL_PROPERTIES = ['CH4', 'CO2', 'H2', 'N2', 'O2']
36
- MODEL_NAME_MAPPING = {
37
- "model_all": "Graph DiT (trained on labeled + unlabeled)",
38
- "model_labeled": "Graph DiT (trained on labeled)"
39
- }
40
-
41
- GIF_TEMP_PREFIX = "polymer_gifs_"
42
-
43
- # --------------------------- Data Loading --------------------------- #
44
-
45
- def load_known_data():
46
- """Load known labels and SMILES data from CSV files."""
47
- try:
48
- known_labels = pd.read_csv(KNOWN_LABELS_FILE)
49
- known_smiles = pd.read_csv(KNOWN_SMILES_FILE)
50
- return known_labels, known_smiles
51
- except Exception as e:
52
- raise FileNotFoundError(f"Error loading data files: {e}")
53
-
54
- # Load data
55
- known_labels, known_smiles = load_known_data()
56
 
57
- # --------------------------- Evaluator Setup --------------------------- #
58
-
59
- def initialize_evaluators(properties, evaluators_dir):
60
- """Initialize evaluators for each property."""
61
- evaluators = {}
62
- for prop in properties:
63
- evaluator_path = os.path.join(evaluators_dir, f'{prop}.joblib')
64
- evaluators[prop] = Evaluator(evaluator_path, prop)
65
- return evaluators
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
66
 
67
- evaluators = initialize_evaluators(ALL_PROPERTIES, EVALUATORS_DIR)
 
 
 
 
 
 
68
 
69
- # --------------------------- Property Ranges --------------------------- #
 
 
70
 
71
- def get_property_ranges(labels, properties):
72
- """Get min and max values for each property."""
73
- return {prop: (labels[prop].min(), labels[prop].max()) for prop in properties}
74
 
75
- property_ranges = get_property_ranges(known_labels, ALL_PROPERTIES)
 
76
 
77
- # --------------------------- Temporary Directory Setup --------------------------- #
 
78
 
79
- temp_dir = tempfile.mkdtemp(prefix=GIF_TEMP_PREFIX)
 
80
 
81
  def cleanup_temp_files():
82
  """Clean up temporary GIF files on exit."""
83
- try:
84
- for file in glob.glob(os.path.join(temp_dir, "*.gif")):
85
  os.remove(file)
 
 
 
86
  os.rmdir(temp_dir)
87
  except Exception as e:
88
- print(f"Error during cleanup: {e}")
89
 
 
90
  atexit.register(cleanup_temp_files)
91
 
92
- # --------------------------- Utility Functions --------------------------- #
93
-
94
  def random_properties():
95
- """Select a random set of properties from known labels."""
96
- return known_labels[ALL_PROPERTIES].sample(1).values.tolist()[0]
97
 
98
  def load_model(model_choice):
99
- """Load the graph decoder model based on the choice."""
100
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
101
  model = load_graph_decoder(path=model_choice)
102
- return model, device
 
 
 
 
103
 
104
  def save_interesting_log(smiles, properties, suggested_properties):
105
- """Save interesting polymer data to a CSV log file."""
106
- log_file = os.path.join(FLAGGED_FOLDER, "log.csv")
107
- os.makedirs(FLAGGED_FOLDER, exist_ok=True)
108
  file_exists = os.path.isfile(log_file)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
109
 
110
- fieldnames = ['timestamp', 'smiles'] + ALL_PROPERTIES + [f'suggested_{prop}' for prop in ALL_PROPERTIES]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
111
 
112
- try:
113
- with open(log_file, 'a', newline='') as csvfile:
114
- writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
115
- if not file_exists:
116
- writer.writeheader()
117
-
118
- log_data = {
119
- 'timestamp': datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
120
- 'smiles': smiles,
121
- **{prop: value for prop, value in zip(ALL_PROPERTIES, properties)},
122
- **{f'suggested_{prop}': value for prop, value in suggested_properties.items()}
123
- }
124
- writer.writerow(log_data)
125
- except Exception as e:
126
- print(f"Error saving log: {e}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
127
 
128
- def is_nan_like(x):
129
- """Check if a value should be treated as NaN."""
130
- return x == 0 or x == '' or (isinstance(x, float) and np.isnan(x))
 
 
 
 
 
131
 
132
  def numpy_to_python(obj):
133
- """Convert NumPy objects to native Python types."""
134
  if isinstance(obj, np.integer):
135
  return int(obj)
136
  elif isinstance(obj, np.floating):
@@ -143,329 +210,175 @@ def numpy_to_python(obj):
143
  return {k: numpy_to_python(v) for k, v in obj.items()}
144
  else:
145
  return obj
 
 
 
 
 
 
 
 
 
 
 
 
 
 
146
 
147
- # --------------------------- Graph Generation Function --------------------------- #
148
-
149
- @spaces.GPU
150
- def generate_graph(CH4, CO2, H2, N2, O2, guidance_scale, num_nodes, repeating_time, model_state, num_chain_steps, fps):
151
- """
152
- Generate a polymer graph based on the input properties and model.
153
- Returns generation results including SMILES, images, and properties.
154
- """
155
- print('Generating graph...')
156
- model, device = model_state
157
- properties = [CH4, CO2, H2, N2, O2]
 
158
 
159
- # Handle NaN-like values
160
- properties = [None if is_nan_like(prop) else prop for prop in properties]
 
161
 
162
- nan_gases = [gas for gas, prop in zip(ALL_PROPERTIES, properties) if prop is None]
163
- nan_message = "The following gas properties were treated as NaN: " + (", ".join(nan_gases) if nan_gases else "None")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
164
 
165
- num_nodes = None if num_nodes == 0 else num_nodes
 
 
166
 
167
- for attempt in range(repeating_time):
168
- try:
169
- generated_molecule, img_list = model.generate(
170
- properties,
171
- device=device,
172
- guide_scale=guidance_scale,
173
- num_nodes=num_nodes,
174
- number_chain_steps=num_chain_steps
175
- )
176
 
177
- gif_path = None
178
- if img_list:
179
- imgs = [np.array(pil_img) for pil_img in img_list]
180
- imgs.extend([imgs[-1]] * 10) # Extend the last image for GIF
181
- gif_path = os.path.join(temp_dir, f"polymer_gen_{random.randint(0, 999999)}.gif")
182
- imageio.mimsave(gif_path, imgs, format='GIF', fps=fps, loop=0)
183
-
184
- if generated_molecule:
185
- mol = Chem.MolFromSmiles(generated_molecule)
186
- if mol:
187
- standardized_smiles = Chem.MolToSmiles(mol, isomericSmiles=True)
188
- is_novel = standardized_smiles not in known_smiles['SMILES'].values
189
- novelty_status = "Novel (Not in Labeled Set)" if is_novel else "Not Novel (Exists in Labeled Set)"
190
- img = Draw.MolToImage(mol)
191
-
192
- # Evaluate the generated molecule
193
- suggested_properties = {prop: evaluator([standardized_smiles])[0] for prop, evaluator in evaluators.items()}
194
-
195
- suggested_properties_text = "\n".join([f"**Suggested {prop}:** {value:.2f}" for prop, value in suggested_properties.items()])
196
-
197
- return (
198
- f"**Generated polymer SMILES:** `{standardized_smiles}`\n\n"
199
- f"**{nan_message}**\n\n"
200
- f"**{novelty_status}**\n\n"
201
- f"**Suggested Properties:**\n{suggested_properties_text}",
202
- img,
203
- gif_path,
204
- standardized_smiles,
205
- properties,
206
- suggested_properties
207
- )
208
- except Exception as e:
209
- print(f"Attempt {attempt + 1} failed: {e}")
210
- continue
211
-
212
- # If all attempts fail
213
- return (
214
- f"**Generation failed:** Could not generate a valid molecule after {repeating_time} attempts.\n\n**{nan_message}**",
215
- None,
216
- None,
217
- "",
218
- [],
219
- {}
220
  )
221
 
222
- # --------------------------- Feedback Processing --------------------------- #
 
 
 
223
 
224
- def process_feedback(checkbox_value, smiles, properties, suggested_properties):
225
- """
226
- Process user feedback. If the user finds the polymer interesting,
227
- log it accordingly.
228
- """
229
- if checkbox_value and smiles:
230
- save_interesting_log(smiles, properties, suggested_properties)
231
- return "Thank you for your feedback! This polymer has been saved to our interesting polymers log."
232
- return "Thank you for your feedback!"
233
-
234
- # --------------------------- Model Switching --------------------------- #
235
-
236
- def switch_model(choice):
237
- """Switch the model based on user selection."""
238
- internal_name = next(key for key, value in MODEL_NAME_MAPPING.items() if value == choice)
239
- return load_model(internal_name)
240
-
241
- # --------------------------- Gradio Interface Setup --------------------------- #
242
-
243
- def create_gradio_interface():
244
- """Create and return the Gradio Blocks interface."""
245
- with gr.Blocks(title="Polymer Design with GraphDiT") as iface:
246
- # Navigation Bar
247
- with gr.Row(elem_id="navbar"):
248
- gr.Markdown("""
249
- <div style="text-align: center;">
250
- <h1>🔗🔬 Polymer Design with GraphDiT</h1>
251
- <div style="display: flex; gap: 20px; justify-content: center; align-items: center; margin-top: 10px;">
252
- <a href="https://github.com/liugangcode/Graph-DiT" target="_blank" style="display: flex; align-items: center; gap: 5px; text-decoration: none; color: inherit;">
253
- <img src="https://img.icons8.com/ios-glyphs/30/000000/github.png" alt="GitHub" />
254
- <span>View Code</span>
255
- </a>
256
- <a href="https://arxiv.org/abs/2401.13858" target="_blank" style="text-decoration: none; color: inherit;">
257
- 📄 View Paper
258
- </a>
259
- </div>
260
- </div>
261
- """)
262
 
263
- # Main Description
264
- gr.Markdown("""
265
- ## Introduction
 
 
266
 
267
- Input the desired gas barrier properties for CH₄, CO₂, H₂, N₂, and O₂ to generate novel polymer structures. The results are visualized as molecular graphs and represented by SMILES strings if they are successfully generated. **Note:** Gas barrier values set to 0 will be treated as `NaN` (unconditionally). If the generation fails, please retry or increase the number of repetition attempts.
 
 
268
  """)
269
 
270
- # Model Selection
271
- model_choice = gr.Radio(
272
- choices=list(MODEL_NAME_MAPPING.values()),
273
- label="Model Zoo",
274
- value=MODEL_NAME_MAPPING["model_labeled"]
275
- )
276
-
277
- # Model Description Accordion
278
- with gr.Accordion("🔍 Model Description", open=False):
279
- gr.Markdown("""
280
- ### GraphDiT: Graph Diffusion Transformer
281
-
282
- GraphDiT is a graph diffusion model designed for targeted molecular generation. It employs a conditional diffusion process to iteratively refine molecular structures based on user-specified properties.
283
-
284
- We have collected a labeled polymer database for gas permeability from [Membrane Database](https://research.csiro.au/virtualscreening/membrane-database-polymer-gas-separation-membranes/). Additionally, we utilize unlabeled polymer structures from [PolyInfo](https://polymer.nims.go.jp/).
285
-
286
- The gas permeability ranges from 0 to over ten thousand, with only hundreds of labeled data points, making this task particularly challenging.
287
-
288
- We are actively working on improving the model. We welcome any feedback regarding model usage or suggestions for improvement.
289
-
290
- #### Currently, we have two variants of Graph DiT:
291
- - **Graph DiT (trained on labeled + unlabeled)**: This model uses both labeled and unlabeled data for training, potentially leading to more diverse/novel polymer generation.
292
- - **Graph DiT (trained on labeled)**: This model is trained exclusively on labeled data, which may result in higher validity but potentially less diverse/novel outputs.
293
- """)
294
-
295
- # Citation Accordion
296
- with gr.Accordion("📄 Citation", open=False):
297
- gr.Markdown("""
298
- If you use this model or interface useful, please cite the following paper:
299
- ```bibtex
300
- @article{graphdit2024,
301
- title={Graph Diffusion Transformers for Multi-Conditional Molecular Generation},
302
- author={Liu, Gang and Xu, Jiaxin and Luo, Tengfei and Jiang, Meng},
303
- journal={NeurIPS},
304
- year={2024},
305
- }
306
- ```
307
- """)
308
-
309
- # Initialize Model State
310
- model_state = gr.State(load_model("model_labeled"))
311
-
312
- # Property Inputs
313
- with gr.Row():
314
- CH4_input = gr.Slider(
315
- minimum=0,
316
- maximum=property_ranges['CH4'][1],
317
- value=2.5,
318
- label=f"CH₄ (Barrier) [0-{property_ranges['CH4'][1]:.1f}]"
319
- )
320
- CO2_input = gr.Slider(
321
- minimum=0,
322
- maximum=property_ranges['CO2'][1],
323
- value=15.4,
324
- label=f"CO₂ (Barrier) [0-{property_ranges['CO2'][1]:.1f}]"
325
- )
326
- H2_input = gr.Slider(
327
- minimum=0,
328
- maximum=property_ranges['H2'][1],
329
- value=21.0,
330
- label=f"H₂ (Barrier) [0-{property_ranges['H2'][1]:.1f}]"
331
- )
332
- N2_input = gr.Slider(
333
- minimum=0,
334
- maximum=property_ranges['N2'][1],
335
- value=1.5,
336
- label=f"N₂ (Barrier) [0-{property_ranges['N2'][1]:.1f}]"
337
- )
338
- O2_input = gr.Slider(
339
- minimum=0,
340
- maximum=property_ranges['O2'][1],
341
- value=2.8,
342
- label=f"O₂ (Barrier) [0-{property_ranges['O2'][1]:.1f}]"
343
- )
344
 
345
- # Generation Parameters
346
- with gr.Row():
347
- guidance_scale = gr.Slider(
348
- minimum=1,
349
- maximum=3,
350
- value=2,
351
- label="Guidance Scale from Properties"
352
- )
353
- num_nodes = gr.Slider(
354
- minimum=0,
355
- maximum=50,
356
- step=1,
357
- value=0,
358
- label="Number of Nodes (0 for Random, Larger Graphs Take More Time)"
359
- )
360
- repeating_time = gr.Slider(
361
- minimum=1,
362
- maximum=10,
363
- step=1,
364
- value=3,
365
- label="Repetition Until Success"
366
- )
367
- num_chain_steps = gr.Slider(
368
- minimum=0,
369
- maximum=499,
370
- step=1,
371
- value=50,
372
- label="Number of Diffusion Steps to Visualize (Larger Numbers Take More Time)"
373
- )
374
- fps = gr.Slider(
375
- minimum=0.25,
376
- maximum=10,
377
- step=0.25,
378
- value=5,
379
- label="Frames Per Second"
380
- )
 
 
 
 
 
 
 
 
 
 
 
381
 
382
- # Action Buttons
383
- with gr.Row():
384
- random_btn = gr.Button("🔀 Randomize Properties (from Labeled Data)")
385
- generate_btn = gr.Button("🚀 Generate Polymer")
386
-
387
- # Results Display
388
- with gr.Row():
389
- result_text = gr.Textbox(label="📝 Generation Result", lines=10)
390
- result_image = gr.Image(label="Final Molecule Visualization", type="pil")
391
- result_gif = gr.Image(label="Generation Process Visualization", type="filepath", format="gif")
392
-
393
- # Feedback Section
394
- with gr.Row():
395
- feedback_btn = gr.Button("🌟 I think this polymer is interesting!", interactive=False)
396
- feedback_result = gr.Textbox(label="Feedback Result", visible=False)
397
-
398
- # Hidden Components to Store Generation Data
399
- hidden_smiles = gr.Textbox(visible=False)
400
- hidden_properties = gr.JSON(visible=False)
401
- hidden_suggested_properties = gr.JSON(visible=False)
402
-
403
- # Event Handlers
404
-
405
- # Model Selection Change
406
- model_choice.change(
407
- switch_model,
408
- inputs=[model_choice],
409
- outputs=[model_state]
410
- )
411
-
412
- # Randomize Properties Button
413
- random_btn.click(
414
- random_properties,
415
- outputs=[CH4_input, CO2_input, H2_input, N2_input, O2_input]
416
- )
417
-
418
- # Generate Polymer Button
419
- generate_btn.click(
420
- generate_graph,
421
- inputs=[
422
- CH4_input, CO2_input, H2_input, N2_input, O2_input,
423
- guidance_scale, num_nodes, repeating_time,
424
- model_state, num_chain_steps, fps
425
- ],
426
- outputs=[
427
- result_text, result_image, result_gif,
428
- hidden_smiles, hidden_properties, hidden_suggested_properties
429
- ]
430
- ).then(
431
- lambda text, img, gif, smiles, props, sugg_props: (
432
- smiles if text.startswith("**Generated polymer SMILES:**") else "",
433
- json.dumps(numpy_to_python(props)),
434
- json.dumps(numpy_to_python(sugg_props)),
435
- gr.Button(interactive=text.startswith("**Generated polymer SMILES:**"))
436
- ),
437
- inputs=[
438
- result_text, result_image, result_gif,
439
- hidden_smiles, hidden_properties, hidden_suggested_properties
440
- ],
441
- outputs=[hidden_smiles, hidden_properties, hidden_suggested_properties, feedback_btn]
442
- )
443
-
444
- # Feedback Button Click
445
- feedback_btn.click(
446
- process_feedback,
447
- inputs=[gr.Checkbox(label="Interested?", value=True, visible=False), hidden_smiles, hidden_properties, hidden_suggested_properties],
448
- outputs=[feedback_result]
449
- ).then(
450
- lambda: gr.Button(interactive=False),
451
- outputs=[feedback_btn]
452
- )
453
-
454
- # # Define the reset_feedback function
455
- # def reset_feedback():
456
- # return gr.Button(interactive=False)
457
-
458
- # CH4_input.change(reset_feedback, outputs=[feedback_btn])
459
- # CO2_input.change(reset_feedback, outputs=[feedback_btn])
460
- # H2_input.change(reset_feedback, outputs=[feedback_btn])
461
- # N2_input.change(reset_feedback, outputs=[feedback_btn])
462
- # O2_input.change(reset_feedback, outputs=[feedback_btn])
463
- # random_btn.click(reset_feedback, outputs=[feedback_btn])
464
-
465
- return iface
466
-
467
- # --------------------------- Main Execution --------------------------- #
468
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
469
  if __name__ == "__main__":
470
- interface = create_gradio_interface()
471
- interface.launch(share=False)
 
1
  import spaces
2
+ import gradio as gr
3
 
4
+ import torch
5
+ import numpy as np
6
+ import pandas as pd
 
 
 
7
  import random
8
+ import io
9
+ import imageio
10
+ import os
11
  import tempfile
12
  import atexit
13
+ import glob
14
+ import csv
15
  from datetime import datetime
16
+ import json
17
 
 
 
 
 
 
18
  from rdkit import Chem
19
  from rdkit.Chem import Draw
 
20
 
 
21
  from evaluator import Evaluator
22
+ # from loader import load_graph_decoder
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
23
 
24
+ ### load model start
25
+ from graph_decoder.diffusion_model import GraphDiT
26
+ def count_parameters(model):
27
+ r"""
28
+ Returns the number of trainable parameters and number of all parameters in the model.
29
+ """
30
+ trainable_params, all_param = 0, 0
31
+ for param in model.parameters():
32
+ num_params = param.numel()
33
+ all_param += num_params
34
+ if param.requires_grad:
35
+ trainable_params += num_params
36
+
37
+ return trainable_params, all_param
38
+
39
+ def load_graph_decoder(path='model_labeled'):
40
+ model_config_path = f"{path}/config.yaml"
41
+ data_info_path = f"{path}/data.meta.json"
42
+
43
+ model = GraphDiT(
44
+ model_config_path=model_config_path,
45
+ data_info_path=data_info_path,
46
+ model_dtype=torch.float32,
47
+ )
48
+ model.init_model(path)
49
+ model.disable_grads()
50
 
51
+ trainable_params, all_param = count_parameters(model)
52
+ param_stats = "Loaded Graph DiT from {} trainable params: {:,} || all params: {:,} || trainable%: {:.4f}".format(
53
+ path, trainable_params, all_param, 100 * trainable_params / all_param
54
+ )
55
+ print(param_stats)
56
+ return model
57
+ ### load model end
58
 
59
+ # Load the CSV data
60
+ known_labels = pd.read_csv('data/known_labels.csv')
61
+ knwon_smiles = pd.read_csv('data/known_polymers.csv')
62
 
63
+ all_properties = ['CH4', 'CO2', 'H2', 'N2', 'O2']
 
 
64
 
65
+ # Initialize evaluators
66
+ evaluators = {prop: Evaluator(f'evaluators/{prop}.joblib', prop) for prop in all_properties}
67
 
68
+ # Get min and max values for each property
69
+ property_ranges = {prop: (known_labels[prop].min(), known_labels[prop].max()) for prop in all_properties}
70
 
71
+ # Create a temporary directory for GIFs
72
+ temp_dir = tempfile.mkdtemp(prefix="polymer_gifs_")
73
 
74
  def cleanup_temp_files():
75
  """Clean up temporary GIF files on exit."""
76
+ for file in glob.glob(os.path.join(temp_dir, "*.gif")):
77
+ try:
78
  os.remove(file)
79
+ except Exception as e:
80
+ print(f"Error deleting {file}: {e}")
81
+ try:
82
  os.rmdir(temp_dir)
83
  except Exception as e:
84
+ print(f"Error deleting temporary directory {temp_dir}: {e}")
85
 
86
+ # Register the cleanup function to be called on exit
87
  atexit.register(cleanup_temp_files)
88
 
 
 
89
  def random_properties():
90
+ return known_labels[all_properties].sample(1).values.tolist()[0]
 
91
 
92
  def load_model(model_choice):
 
93
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
94
  model = load_graph_decoder(path=model_choice)
95
+ return (model, device)
96
+
97
+ # Create a flagged folder if it doesn't exist
98
+ flagged_folder = "flagged"
99
+ os.makedirs(flagged_folder, exist_ok=True)
100
 
101
  def save_interesting_log(smiles, properties, suggested_properties):
102
+ """Save interesting polymer data to a CSV file."""
103
+ log_file = os.path.join(flagged_folder, "log.csv")
 
104
  file_exists = os.path.isfile(log_file)
105
+
106
+ with open(log_file, 'a', newline='') as csvfile:
107
+ fieldnames = ['timestamp', 'smiles'] + all_properties + [f'suggested_{prop}' for prop in all_properties]
108
+ writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
109
+
110
+ if not file_exists:
111
+ writer.writeheader()
112
+
113
+ log_data = {
114
+ 'timestamp': datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
115
+ 'smiles': smiles,
116
+ **{prop: value for prop, value in zip(all_properties, properties)},
117
+ **{f'suggested_{prop}': value for prop, value in suggested_properties.items()}
118
+ }
119
+ writer.writerow(log_data)
120
 
121
+ @spaces.GPU
122
+ def generate_graph(CH4, CO2, H2, N2, O2, guidance_scale, num_nodes, repeating_time, model_state, num_chain_steps, fps):
123
+ print('in generate_graph')
124
+ model, device = model_state
125
+
126
+ properties = [CH4, CO2, H2, N2, O2]
127
+
128
+ def is_nan_like(x):
129
+ return x == 0 or x == '' or (isinstance(x, float) and np.isnan(x))
130
+
131
+ properties = [None if is_nan_like(prop) else prop for prop in properties]
132
+
133
+ nan_message = "The following gas properties were treated as NaN: "
134
+ nan_gases = [gas for gas, prop in zip(all_properties, properties) if prop is None]
135
+ nan_message += ", ".join(nan_gases) if nan_gases else "None"
136
 
137
+ num_nodes = None if num_nodes == 0 else num_nodes
138
+
139
+ for _ in range(repeating_time):
140
+ # try:
141
+ model.to(device)
142
+ generated_molecule, img_list = model.generate(properties, device=device, guide_scale=guidance_scale, num_nodes=num_nodes, number_chain_steps=num_chain_steps)
143
+
144
+ # Create GIF if img_list is available
145
+ gif_path = None
146
+ if img_list and len(img_list) > 0:
147
+ imgs = [np.array(pil_img) for pil_img in img_list]
148
+ imgs.extend([imgs[-1]] * 10)
149
+ gif_path = os.path.join(temp_dir, f"polymer_gen_{random.randint(0, 999999)}.gif")
150
+ imageio.mimsave(gif_path, imgs, format='GIF', fps=fps, loop=0)
151
+
152
+ if generated_molecule is not None:
153
+ mol = Chem.MolFromSmiles(generated_molecule)
154
+ if mol is not None:
155
+ standardized_smiles = Chem.MolToSmiles(mol, isomericSmiles=True)
156
+ is_novel = standardized_smiles not in knwon_smiles['SMILES'].values
157
+ novelty_status = "Novel (Not in Labeled Set)" if is_novel else "Not Novel (Exists in Labeled Set)"
158
+ img = Draw.MolToImage(mol)
159
+
160
+ # Evaluate the generated molecule
161
+ suggested_properties = {}
162
+ for prop, evaluator in evaluators.items():
163
+ suggested_properties[prop] = evaluator([standardized_smiles])[0]
164
+
165
+ suggested_properties_text = "\n".join([f"**Suggested {prop}:** {value:.2f}" for prop, value in suggested_properties.items()])
166
+
167
+ return (
168
+ f"**Generated polymer SMILES:** `{standardized_smiles}`\n\n"
169
+ f"**{nan_message}**\n\n"
170
+ f"**{novelty_status}**\n\n"
171
+ f"**Suggested Properties:**\n{suggested_properties_text}",
172
+ img,
173
+ gif_path,
174
+ properties, # Add this
175
+ suggested_properties # Add this
176
+ )
177
+ else:
178
+ return (
179
+ f"**Generation failed:** Could not generate a valid molecule.\n\n**{nan_message}**",
180
+ None,
181
+ gif_path,
182
+ properties,
183
+ None,
184
+ )
185
+ # except Exception as e:
186
+ # print(f"Error in generation: {e}")
187
+ # continue
188
+
189
+ return f"**Generation failed:** Could not generate a valid molecule after {repeating_time} attempts.\n\n**{nan_message}**", None, None
190
 
191
+ def set_random_properties():
192
+ return random_properties()
193
+
194
+ # Create a mapping of internal names to display names
195
+ model_name_mapping = {
196
+ "model_all": "Graph DiT (trained on labeled + unlabeled)",
197
+ "model_labeled": "Graph DiT (trained on labeled)"
198
+ }
199
 
200
  def numpy_to_python(obj):
 
201
  if isinstance(obj, np.integer):
202
  return int(obj)
203
  elif isinstance(obj, np.floating):
 
210
  return {k: numpy_to_python(v) for k, v in obj.items()}
211
  else:
212
  return obj
213
+
214
+ def on_generate(CH4, CO2, H2, N2, O2, guidance_scale, num_nodes, repeating_time, model_state, num_chain_steps, fps):
215
+ print('in on_generate', on_generate)
216
+ result = generate_graph(CH4, CO2, H2, N2, O2, guidance_scale, num_nodes, repeating_time, model_state, num_chain_steps, fps)
217
+ # Check if the generation was successful
218
+ if result[0].startswith("**Generated polymer SMILES:**"):
219
+ smiles = result[0].split("**Generated polymer SMILES:** `")[1].split("`")[0]
220
+ properties = json.dumps(numpy_to_python(result[3]))
221
+ suggested_properties = json.dumps(numpy_to_python(result[4]))
222
+ # Return the result with an enabled feedback button
223
+ return [*result[:3], smiles, properties, suggested_properties, gr.Button(interactive=True)]
224
+ else:
225
+ # Return the result with a disabled feedback button
226
+ return [*result[:3], "", "[]", "[]", gr.Button(interactive=False)]
227
 
228
+ def process_feedback(checkbox_value, smiles, properties, suggested_properties):
229
+ if checkbox_value:
230
+ # Check if properties and suggested_properties are already Python objects
231
+ if isinstance(properties, str):
232
+ properties = json.loads(properties)
233
+ if isinstance(suggested_properties, str):
234
+ suggested_properties = json.loads(suggested_properties)
235
+
236
+ save_interesting_log(smiles, properties, suggested_properties)
237
+ return gr.Textbox(value="Thank you for your feedback! This polymer has been saved to our interesting polymers log.", visible=True)
238
+ else:
239
+ return gr.Textbox(value="Thank you for your feedback!", visible=True)
240
 
241
+ # ADD THIS FUNCTION
242
+ def reset_feedback_button():
243
+ return gr.Button(interactive=False)
244
 
245
+ # Create the Gradio interface using Blocks
246
+ with gr.Blocks(title="Polymer Design with GraphDiT") as iface:
247
+ # Navigation Bar
248
+ with gr.Row(elem_id="navbar"):
249
+ gr.Markdown("""
250
+ <div style="text-align: center;">
251
+ <h1>🔗🔬 Polymer Design with GraphDiT</h1>
252
+ <div style="display: flex; gap: 20px; justify-content: center; align-items: center; margin-top: 10px;">
253
+ <a href="https://github.com/liugangcode/Graph-DiT" target="_blank" style="display: flex; align-items: center; gap: 5px; text-decoration: none; color: inherit;">
254
+ <img src="https://img.icons8.com/ios-glyphs/30/000000/github.png" alt="GitHub" />
255
+ <span>View Code</span>
256
+ </a>
257
+ <a href="https://arxiv.org/abs/2401.13858" target="_blank" style="text-decoration: none; color: inherit;">
258
+ 📄 View Paper
259
+ </a>
260
+ </div>
261
+ </div>
262
+ """)
263
 
264
+ # Main Description
265
+ gr.Markdown("""
266
+ ## Introduction
267
 
268
+ Input the desired gas barrier properties for CH₄, CO₂, H₂, N₂, and O₂ to generate novel polymer structures. The results are visualized as molecular graphs and represented by SMILES strings if they are successfully generated. Note: Gas barrier values set to 0 will be treated as `NaN` (unconditionally). If the generation fails, please retry or increase the number of repetition attempts.
269
+ """)
 
 
 
 
 
 
 
270
 
271
+ # Model Selection
272
+ model_choice = gr.Radio(
273
+ choices=list(model_name_mapping.values()),
274
+ label="Model Zoo",
275
+ # value="Graph DiT (trained on labeled + unlabeled)"
276
+ value="Graph DiT (trained on labeled)"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
277
  )
278
 
279
+ # Model Description Accordion
280
+ with gr.Accordion("🔍 Model Description", open=False):
281
+ gr.Markdown("""
282
+ ### GraphDiT: Graph Diffusion Transformer
283
 
284
+ GraphDiT is a graph diffusion model designed for targeted molecular generation. It employs a conditional diffusion process to iteratively refine molecular structures based on user-specified properties.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
285
 
286
+ We have collected a labeled polymer database for gas permeability from [Membrane Database](https://research.csiro.au/virtualscreening/membrane-database-polymer-gas-separation-membranes/). Additionally, we utilize unlabeled polymer structures from [PolyInfo](https://polymer.nims.go.jp/).
287
+
288
+ The gas permeability ranges from 0 to over ten thousand, with only hundreds of labeled data points, making this task particularly challenging.
289
+
290
+ We are actively working on improving the model. We welcome any feedback regarding model usage or suggestions for improvement.
291
 
292
+ #### Currently, we have two variants of Graph DiT:
293
+ - **Graph DiT (trained on labeled + unlabeled)**: This model uses both labeled and unlabeled data for training, potentially leading to more diverse/novel polymer generation.
294
+ - **Graph DiT (trained on labeled)**: This model is trained exclusively on labeled data, which may result in higher validity but potentially less diverse/novel outputs.
295
  """)
296
 
297
+ # Citation Accordion
298
+ with gr.Accordion("📄 Citation", open=False):
299
+ gr.Markdown("""
300
+ If you use this model or interface useful, please cite the following paper:
301
+ ```bibtex
302
+ @article{graphdit2024,
303
+ title={Graph Diffusion Transformers for Multi-Conditional Molecular Generation},
304
+ author={Liu, Gang and Xu, Jiaxin and Luo, Tengfei and Jiang, Meng},
305
+ journal={NeurIPS},
306
+ year={2024},
307
+ }
308
+ ```
309
+ """)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
310
 
311
+ model_state = gr.State(lambda: load_model("model_labeled"))
312
+
313
+ with gr.Row():
314
+ CH4_input = gr.Slider(0, property_ranges['CH4'][1], value=2.5, label=f"CH₄ (Barrier) [0-{property_ranges['CH4'][1]:.1f}]")
315
+ CO2_input = gr.Slider(0, property_ranges['CO2'][1], value=15.4, label=f"CO₂ (Barrier) [0-{property_ranges['CO2'][1]:.1f}]")
316
+ H2_input = gr.Slider(0, property_ranges['H2'][1], value=21.0, label=f"H₂ (Barrier) [0-{property_ranges['H2'][1]:.1f}]")
317
+ N2_input = gr.Slider(0, property_ranges['N2'][1], value=1.5, label=f"N₂ (Barrier) [0-{property_ranges['N2'][1]:.1f}]")
318
+ O2_input = gr.Slider(0, property_ranges['O2'][1], value=2.8, label=f"O₂ (Barrier) [0-{property_ranges['O2'][1]:.1f}]")
319
+
320
+ with gr.Row():
321
+ guidance_scale = gr.Slider(1, 3, value=2, label="Guidance Scale from Properties")
322
+ num_nodes = gr.Slider(0, 50, step=1, value=0, label="Number of Nodes (0 for Random, Larger Graphs Take More Time)")
323
+ repeating_time = gr.Slider(1, 10, step=1, value=3, label="Repetition Until Success")
324
+ num_chain_steps = gr.Slider(0, 499, step=1, value=50, label="Number of Diffusion Steps to Visualize (Larger Numbers Take More Time)")
325
+ fps = gr.Slider(0.25, 10, step=0.25, value=5, label="Frames Per Second")
326
+
327
+ with gr.Row():
328
+ random_btn = gr.Button("🔀 Randomize Properties (from Labeled Data)")
329
+ generate_btn = gr.Button("🚀 Generate Polymer")
330
+
331
+ with gr.Row():
332
+ result_text = gr.Textbox(label="📝 Generation Result")
333
+ result_image = gr.Image(label="Final Molecule Visualization", type="pil")
334
+ result_gif = gr.Image(label="Generation Process Visualization", type="filepath", format="gif")
335
+
336
+ with gr.Row() as feedback_row:
337
+ feedback_btn = gr.Button("🌟 I think this polymer is interesting!", visible=True, interactive=False)
338
+ feedback_result = gr.Textbox(label="Feedback Result", visible=False)
339
+
340
+ # Add model switching functionality
341
+ def switch_model(choice):
342
+ # Convert display name back to internal name
343
+ internal_name = next(key for key, value in model_name_mapping.items() if value == choice)
344
+ return load_model(internal_name)
345
+
346
+ model_choice.change(switch_model, inputs=[model_choice], outputs=[model_state])
347
+
348
+ # Hidden components to store generation data
349
+ hidden_smiles = gr.Textbox(visible=False)
350
+ hidden_properties = gr.JSON(visible=False)
351
+ hidden_suggested_properties = gr.JSON(visible=False)
352
+
353
+ # Set up event handlers
354
+ random_btn.click(
355
+ set_random_properties,
356
+ outputs=[CH4_input, CO2_input, H2_input, N2_input, O2_input]
357
+ )
358
 
359
+ generate_btn.click(
360
+ on_generate,
361
+ inputs=[CH4_input, CO2_input, H2_input, N2_input, O2_input, guidance_scale, num_nodes, repeating_time, model_state, num_chain_steps, fps],
362
+ outputs=[result_text, result_image, result_gif, hidden_smiles, hidden_properties, hidden_suggested_properties, feedback_btn]
363
+ )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
364
 
365
+ feedback_btn.click(
366
+ process_feedback,
367
+ inputs=[gr.Checkbox(value=True, visible=False), hidden_smiles, hidden_properties, hidden_suggested_properties],
368
+ outputs=[feedback_result]
369
+ ).then(
370
+ lambda: gr.Button(interactive=False),
371
+ outputs=[feedback_btn]
372
+ )
373
+
374
+ CH4_input.change(reset_feedback_button, outputs=[feedback_btn])
375
+ CO2_input.change(reset_feedback_button, outputs=[feedback_btn])
376
+ H2_input.change(reset_feedback_button, outputs=[feedback_btn])
377
+ N2_input.change(reset_feedback_button, outputs=[feedback_btn])
378
+ O2_input.change(reset_feedback_button, outputs=[feedback_btn])
379
+ random_btn.click(reset_feedback_button, outputs=[feedback_btn])
380
+
381
+ # Launch the interface
382
  if __name__ == "__main__":
383
+ # iface.launch(share=True)
384
+ iface.launch(share=False)