liuganghuggingface commited on
Commit
ac6db00
·
verified ·
1 Parent(s): debc746

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +44 -58
app.py CHANGED
@@ -56,18 +56,6 @@ def random_properties():
56
  def load_model(model_choice):
57
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
58
  model = load_graph_decoder(path=model_choice)
59
- #### test
60
- # from graph_decoder.diffusion_model import GraphDiT
61
-
62
- # model_config_path = f"model_labeled/config.yaml"
63
- # data_info_path = f"model_labeled/data.meta.json"
64
- # model = GraphDiT(
65
- # model_config_path=model_config_path,
66
- # data_info_path=data_info_path,
67
- # # model_dtype=torch.float16,
68
- # model_dtype=torch.float32,
69
- # )
70
- ### test
71
  return (model, device)
72
 
73
  # Create a flagged folder if it doesn't exist
@@ -112,55 +100,53 @@ def generate_graph(CH4, CO2, H2, N2, O2, guidance_scale, num_nodes, repeating_ti
112
  num_nodes = None if num_nodes == 0 else num_nodes
113
 
114
  for _ in range(repeating_time):
115
- # try:
116
- model.to(device)
117
- generated_molecule, img_list = model.generate(properties, guide_scale=guidance_scale, num_nodes=num_nodes, number_chain_steps=num_chain_steps)
118
- # generated_molecule = 'C'
119
- # img_list = []
120
- # Create GIF if img_list is available
121
- gif_path = None
122
- if img_list and len(img_list) > 0:
123
- imgs = [np.array(pil_img) for pil_img in img_list]
124
- imgs.extend([imgs[-1]] * 10)
125
- gif_path = os.path.join(temp_dir, f"polymer_gen_{random.randint(0, 999999)}.gif")
126
- imageio.mimsave(gif_path, imgs, format='GIF', fps=fps, loop=0)
127
-
128
- if generated_molecule is not None:
129
- mol = Chem.MolFromSmiles(generated_molecule)
130
- if mol is not None:
131
- standardized_smiles = Chem.MolToSmiles(mol, isomericSmiles=True)
132
- is_novel = standardized_smiles not in knwon_smiles['SMILES'].values
133
- novelty_status = "Novel (Not in Labeled Set)" if is_novel else "Not Novel (Exists in Labeled Set)"
134
- img = Draw.MolToImage(mol)
135
-
136
- # Evaluate the generated molecule
137
- suggested_properties = {}
138
- for prop, evaluator in evaluators.items():
139
- suggested_properties[prop] = evaluator([standardized_smiles])[0]
140
-
141
- suggested_properties_text = "\n".join([f"**Suggested {prop}:** {value:.2f}" for prop, value in suggested_properties.items()])
142
-
 
 
 
 
 
 
 
 
 
143
  return (
144
- f"**Generated polymer SMILES:** `{standardized_smiles}`\n\n"
145
- f"**{nan_message}**\n\n"
146
- f"**{novelty_status}**\n\n"
147
- f"**Suggested Properties:**\n{suggested_properties_text}",
148
- img,
149
  gif_path,
150
- properties, # Add this
151
- suggested_properties # Add this
152
  )
153
- else:
154
- return (
155
- f"**Generation failed:** Could not generate a valid molecule.\n\n**{nan_message}**",
156
- None,
157
- gif_path,
158
- properties,
159
- None,
160
- )
161
- # except Exception as e:
162
- # print(f"Error in generation: {e}")
163
- # continue
164
 
165
  return f"**Generation failed:** Could not generate a valid molecule after {repeating_time} attempts.\n\n**{nan_message}**", None, None
166
 
 
56
  def load_model(model_choice):
57
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
58
  model = load_graph_decoder(path=model_choice)
 
 
 
 
 
 
 
 
 
 
 
 
59
  return (model, device)
60
 
61
  # Create a flagged folder if it doesn't exist
 
100
  num_nodes = None if num_nodes == 0 else num_nodes
101
 
102
  for _ in range(repeating_time):
103
+ try:
104
+ model.to(device)
105
+ generated_molecule, img_list = model.generate(properties, guide_scale=guidance_scale, num_nodes=num_nodes, number_chain_steps=num_chain_steps)
106
+ # Create GIF if img_list is available
107
+ gif_path = None
108
+ if img_list and len(img_list) > 0:
109
+ imgs = [np.array(pil_img) for pil_img in img_list]
110
+ imgs.extend([imgs[-1]] * 10)
111
+ gif_path = os.path.join(temp_dir, f"polymer_gen_{random.randint(0, 999999)}.gif")
112
+ imageio.mimsave(gif_path, imgs, format='GIF', fps=fps, loop=0)
113
+
114
+ if generated_molecule is not None:
115
+ mol = Chem.MolFromSmiles(generated_molecule)
116
+ if mol is not None:
117
+ standardized_smiles = Chem.MolToSmiles(mol, isomericSmiles=True)
118
+ is_novel = standardized_smiles not in knwon_smiles['SMILES'].values
119
+ novelty_status = "Novel (Not in Labeled Set)" if is_novel else "Not Novel (Exists in Labeled Set)"
120
+ img = Draw.MolToImage(mol)
121
+
122
+ # Evaluate the generated molecule
123
+ suggested_properties = {}
124
+ for prop, evaluator in evaluators.items():
125
+ suggested_properties[prop] = evaluator([standardized_smiles])[0]
126
+
127
+ suggested_properties_text = "\n".join([f"**Suggested {prop}:** {value:.2f}" for prop, value in suggested_properties.items()])
128
+
129
+ return (
130
+ f"**Generated polymer SMILES:** `{standardized_smiles}`\n\n"
131
+ f"**{nan_message}**\n\n"
132
+ f"**{novelty_status}**\n\n"
133
+ f"**Suggested Properties:**\n{suggested_properties_text}",
134
+ img,
135
+ gif_path,
136
+ properties, # Add this
137
+ suggested_properties # Add this
138
+ )
139
+ else:
140
  return (
141
+ f"**Generation failed:** Could not generate a valid molecule.\n\n**{nan_message}**",
142
+ None,
 
 
 
143
  gif_path,
144
+ properties,
145
+ None,
146
  )
147
+ except Exception as e:
148
+ print(f"Error in generation: {e}")
149
+ continue
 
 
 
 
 
 
 
 
150
 
151
  return f"**Generation failed:** Could not generate a valid molecule after {repeating_time} attempts.\n\n**{nan_message}**", None, None
152