liuganghuggingface commited on
Commit
e513fba
·
verified ·
1 Parent(s): f0d7ed4

Upload app.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. app.py +368 -1
app.py CHANGED
@@ -6,7 +6,374 @@ import random
6
  from rdkit import Chem
7
  from rdkit.Chem import Draw
8
 
9
- from graph_decoder.diffusion_model import GraphDiT
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10
  def load_graph_decoder(path='model_labeled'):
11
  model = GraphDiT(
12
  model_config_path=f"{path}/config.yaml",
 
6
  from rdkit import Chem
7
  from rdkit.Chem import Draw
8
 
9
+ #####
10
+
11
+ import os
12
+ import yaml
13
+ import json
14
+
15
+ import torch.nn.functional as F
16
+ from graph_decoder import diffusion_utils as utils
17
+ from graph_decoder.molecule_utils import graph_to_smiles, check_valid
18
+ from graph_decoder.transformer import Transformer
19
+ from graph_decoder.visualize_utils import MolecularVisualization
20
+
21
+ class GraphDiT(nn.Module):
22
+ def __init__(
23
+ self,
24
+ model_config_path,
25
+ data_info_path,
26
+ model_dtype,
27
+ ):
28
+ super().__init__()
29
+ pass
30
+
31
+ # dm_cfg, data_info = utils.load_config(model_config_path, data_info_path)
32
+
33
+ # input_dims = data_info.input_dims
34
+ # output_dims = data_info.output_dims
35
+ # nodes_dist = data_info.nodes_dist
36
+ # active_index = data_info.active_index
37
+
38
+ # self.model_config = dm_cfg
39
+ # self.data_info = data_info
40
+ # self.T = dm_cfg.diffusion_steps
41
+ # self.Xdim = input_dims["X"]
42
+ # self.Edim = input_dims["E"]
43
+ # self.ydim = input_dims["y"]
44
+ # self.Xdim_output = output_dims["X"]
45
+ # self.Edim_output = output_dims["E"]
46
+ # self.ydim_output = output_dims["y"]
47
+ # self.node_dist = nodes_dist
48
+ # self.active_index = active_index
49
+ # self.max_n_nodes = data_info.max_n_nodes
50
+ # self.atom_decoder = data_info.atom_decoder
51
+ # self.hidden_size = dm_cfg.hidden_size
52
+ # self.mol_visualizer = MolecularVisualization(self.atom_decoder)
53
+
54
+ # self.denoiser = Transformer(
55
+ # max_n_nodes=self.max_n_nodes,
56
+ # hidden_size=dm_cfg.hidden_size,
57
+ # depth=dm_cfg.depth,
58
+ # num_heads=dm_cfg.num_heads,
59
+ # mlp_ratio=dm_cfg.mlp_ratio,
60
+ # drop_condition=dm_cfg.drop_condition,
61
+ # Xdim=self.Xdim,
62
+ # Edim=self.Edim,
63
+ # ydim=self.ydim,
64
+ # )
65
+
66
+ # self.model_dtype = model_dtype
67
+ # self.noise_schedule = utils.PredefinedNoiseScheduleDiscrete(
68
+ # dm_cfg.diffusion_noise_schedule, timesteps=dm_cfg.diffusion_steps
69
+ # )
70
+ # x_marginals = data_info.node_types.to(self.model_dtype) / torch.sum(
71
+ # data_info.node_types.to(self.model_dtype)
72
+ # )
73
+ # e_marginals = data_info.edge_types.to(self.model_dtype) / torch.sum(
74
+ # data_info.edge_types.to(self.model_dtype)
75
+ # )
76
+ # x_marginals = x_marginals / x_marginals.sum()
77
+ # e_marginals = e_marginals / e_marginals.sum()
78
+
79
+ # xe_conditions = data_info.transition_E.to(self.model_dtype)
80
+ # xe_conditions = xe_conditions[self.active_index][:, self.active_index]
81
+
82
+ # xe_conditions = xe_conditions.sum(dim=1)
83
+ # ex_conditions = xe_conditions.t()
84
+ # xe_conditions = xe_conditions / xe_conditions.sum(dim=-1, keepdim=True)
85
+ # ex_conditions = ex_conditions / ex_conditions.sum(dim=-1, keepdim=True)
86
+
87
+ # self.transition_model = utils.MarginalTransition(
88
+ # x_marginals=x_marginals,
89
+ # e_marginals=e_marginals,
90
+ # xe_conditions=xe_conditions,
91
+ # ex_conditions=ex_conditions,
92
+ # y_classes=self.ydim_output,
93
+ # n_nodes=self.max_n_nodes,
94
+ # )
95
+ # self.limit_dist = utils.PlaceHolder(X=x_marginals, E=e_marginals, y=None)
96
+
97
+ def init_model(self, model_dir):
98
+ model_file = os.path.join(model_dir, 'model.pt')
99
+ if os.path.exists(model_file):
100
+ self.denoiser.load_state_dict(torch.load(model_file, map_location='cpu', weights_only=True))
101
+ else:
102
+ raise FileNotFoundError(f"Model file not found: {model_file}")
103
+
104
+ def disable_grads(self):
105
+ self.denoiser.disable_grads()
106
+
107
+ def forward(
108
+ self, x, edge_index, edge_attr, graph_batch, properties, no_label_index
109
+ ):
110
+ raise ValueError('Not Implement')
111
+
112
+ def _forward(self, noisy_data, unconditioned=False):
113
+ noisy_x, noisy_e, properties = (
114
+ noisy_data["X_t"].to(self.model_dtype),
115
+ noisy_data["E_t"].to(self.model_dtype),
116
+ noisy_data["y_t"].to(self.model_dtype).clone(),
117
+ )
118
+ node_mask, timestep = (
119
+ noisy_data["node_mask"],
120
+ noisy_data["t"],
121
+ )
122
+
123
+ pred = self.denoiser(
124
+ noisy_x,
125
+ noisy_e,
126
+ node_mask,
127
+ properties,
128
+ timestep,
129
+ unconditioned=unconditioned,
130
+ )
131
+ return pred
132
+
133
+ def apply_noise(self, X, E, y, node_mask):
134
+ """Sample noise and apply it to the data."""
135
+
136
+ # Sample a timestep t.
137
+ # When evaluating, the loss for t=0 is computed separately
138
+ lowest_t = 0 if self.training else 1
139
+ t_int = torch.randint(
140
+ lowest_t, self.T + 1, size=(X.size(0), 1), device=X.device
141
+ ).to(
142
+ self.model_dtype
143
+ ) # (bs, 1)
144
+ s_int = t_int - 1
145
+
146
+ t_float = t_int / self.T
147
+ s_float = s_int / self.T
148
+
149
+ # beta_t and alpha_s_bar are used for denoising/loss computation
150
+ beta_t = self.noise_schedule(t_normalized=t_float) # (bs, 1)
151
+ alpha_s_bar = self.noise_schedule.get_alpha_bar(t_normalized=s_float) # (bs, 1)
152
+ alpha_t_bar = self.noise_schedule.get_alpha_bar(t_normalized=t_float) # (bs, 1)
153
+
154
+ Qtb = self.transition_model.get_Qt_bar(
155
+ alpha_t_bar, X.device
156
+ ) # (bs, dx_in, dx_out), (bs, de_in, de_out)
157
+
158
+ bs, n, d = X.shape
159
+ X_all = torch.cat([X, E.reshape(bs, n, -1)], dim=-1)
160
+ prob_all = X_all @ Qtb.X
161
+ probX = prob_all[:, :, : self.Xdim_output]
162
+ probE = prob_all[:, :, self.Xdim_output :].reshape(bs, n, n, -1)
163
+
164
+ sampled_t = utils.sample_discrete_features(
165
+ probX=probX, probE=probE, node_mask=node_mask
166
+ )
167
+
168
+ X_t = F.one_hot(sampled_t.X, num_classes=self.Xdim_output)
169
+ E_t = F.one_hot(sampled_t.E, num_classes=self.Edim_output)
170
+ assert (X.shape == X_t.shape) and (E.shape == E_t.shape)
171
+
172
+ y_t = y
173
+ z_t = utils.PlaceHolder(X=X_t, E=E_t, y=y_t).type_as(X_t).mask(node_mask)
174
+
175
+ noisy_data = {
176
+ "t_int": t_int,
177
+ "t": t_float,
178
+ "beta_t": beta_t,
179
+ "alpha_s_bar": alpha_s_bar,
180
+ "alpha_t_bar": alpha_t_bar,
181
+ "X_t": z_t.X,
182
+ "E_t": z_t.E,
183
+ "y_t": z_t.y,
184
+ "node_mask": node_mask,
185
+ }
186
+ return noisy_data
187
+
188
+ @torch.no_grad()
189
+ def generate(
190
+ self,
191
+ properties,
192
+ device,
193
+ guide_scale=1.,
194
+ num_nodes=None,
195
+ number_chain_steps=50,
196
+ ):
197
+ properties = [float('nan') if x is None else x for x in properties]
198
+ properties = torch.tensor(properties, dtype=torch.float).reshape(1, -1).to(device)
199
+ batch_size = properties.size(0)
200
+ assert batch_size == 1
201
+ if num_nodes is None:
202
+ num_nodes = self.node_dist.sample_n(batch_size, device)
203
+ else:
204
+ num_nodes = torch.LongTensor([num_nodes]).to(device)
205
+
206
+ arange = (
207
+ torch.arange(self.max_n_nodes, device=device)
208
+ .unsqueeze(0)
209
+ .expand(batch_size, -1)
210
+ )
211
+ node_mask = arange < num_nodes.unsqueeze(1)
212
+
213
+ z_T = utils.sample_discrete_feature_noise(
214
+ limit_dist=self.limit_dist, node_mask=node_mask
215
+ )
216
+ X, E = z_T.X, z_T.E
217
+
218
+ assert (E == torch.transpose(E, 1, 2)).all()
219
+
220
+ if number_chain_steps > 0:
221
+ chain_X_size = torch.Size((number_chain_steps, X.size(1)))
222
+ chain_E_size = torch.Size((number_chain_steps, E.size(1), E.size(2)))
223
+ chain_X = torch.zeros(chain_X_size)
224
+ chain_E = torch.zeros(chain_E_size)
225
+
226
+ # Iteratively sample p(z_s | z_t) for t = 1, ..., T, with s = t - 1.
227
+ y = properties
228
+ for s_int in reversed(range(0, self.T)):
229
+ s_array = s_int * torch.ones((batch_size, 1)).type_as(y)
230
+ t_array = s_array + 1
231
+ s_norm = s_array / self.T
232
+ t_norm = t_array / self.T
233
+
234
+ # Sample z_s
235
+ sampled_s, discrete_sampled_s = self.sample_p_zs_given_zt(
236
+ s_norm, t_norm, X, E, y, node_mask, guide_scale, device
237
+ )
238
+ X, E, y = sampled_s.X, sampled_s.E, sampled_s.y
239
+
240
+ if number_chain_steps > 0:
241
+ # Save the first keep_chain graphs
242
+ write_index = (s_int * number_chain_steps) // self.T
243
+ chain_X[write_index] = discrete_sampled_s.X[:1]
244
+ chain_E[write_index] = discrete_sampled_s.E[:1]
245
+
246
+ # Sample
247
+ sampled_s = sampled_s.mask(node_mask, collapse=True)
248
+ X, E, y = sampled_s.X, sampled_s.E, sampled_s.y
249
+
250
+ molecule_list = []
251
+ n = num_nodes[0]
252
+ atom_types = X[0, :n].cpu()
253
+ edge_types = E[0, :n, :n].cpu()
254
+ molecule_list.append([atom_types, edge_types])
255
+ smiles = graph_to_smiles(molecule_list, self.atom_decoder)[0]
256
+
257
+ # Visualize Chains
258
+ if number_chain_steps > 0:
259
+ final_X_chain = X[:1]
260
+ final_E_chain = E[:1]
261
+
262
+ chain_X[0] = final_X_chain # Overwrite last frame with the resulting X, E
263
+ chain_E[0] = final_E_chain
264
+
265
+ chain_X = utils.reverse_tensor(chain_X)
266
+ chain_E = utils.reverse_tensor(chain_E)
267
+
268
+ # Repeat last frame to see final sample better
269
+ chain_X = torch.cat([chain_X, chain_X[-1:].repeat(10, 1)], dim=0)
270
+ chain_E = torch.cat([chain_E, chain_E[-1:].repeat(10, 1, 1)], dim=0)
271
+ mol_img_list = self.mol_visualizer.visualize_chain(chain_X.numpy(), chain_E.numpy())
272
+ else:
273
+ mol_img_list = []
274
+
275
+ return smiles, mol_img_list
276
+
277
+ def check_valid(self, smiles):
278
+ return check_valid(smiles)
279
+
280
+ def sample_p_zs_given_zt(
281
+ self, s, t, X_t, E_t, properties, node_mask, guide_scale, device
282
+ ):
283
+ """Samples from zs ~ p(zs | zt). Only used during sampling.
284
+ if last_step, return the graph prediction as well"""
285
+ bs, n, _ = X_t.shape
286
+ beta_t = self.noise_schedule(t_normalized=t) # (bs, 1)
287
+ alpha_s_bar = self.noise_schedule.get_alpha_bar(t_normalized=s)
288
+ alpha_t_bar = self.noise_schedule.get_alpha_bar(t_normalized=t)
289
+
290
+ # Neural net predictions
291
+ noisy_data = {
292
+ "X_t": X_t,
293
+ "E_t": E_t,
294
+ "y_t": properties,
295
+ "t": t,
296
+ "node_mask": node_mask,
297
+ }
298
+
299
+ def get_prob(noisy_data, unconditioned=False):
300
+ pred = self._forward(noisy_data, unconditioned=unconditioned)
301
+
302
+ # Normalize predictions
303
+ pred_X = F.softmax(pred.X, dim=-1) # bs, n, d0
304
+ pred_E = F.softmax(pred.E, dim=-1) # bs, n, n, d0
305
+
306
+ # Retrieve transitions matrix
307
+ Qtb = self.transition_model.get_Qt_bar(alpha_t_bar, device)
308
+ Qsb = self.transition_model.get_Qt_bar(alpha_s_bar, device)
309
+ Qt = self.transition_model.get_Qt(beta_t, device)
310
+
311
+ Xt_all = torch.cat([X_t, E_t.reshape(bs, n, -1)], dim=-1)
312
+ predX_all = torch.cat([pred_X, pred_E.reshape(bs, n, -1)], dim=-1)
313
+
314
+ unnormalized_probX_all = utils.reverse_diffusion(
315
+ predX_0=predX_all, X_t=Xt_all, Qt=Qt.X, Qsb=Qsb.X, Qtb=Qtb.X
316
+ )
317
+
318
+ unnormalized_prob_X = unnormalized_probX_all[:, :, : self.Xdim_output]
319
+ unnormalized_prob_E = unnormalized_probX_all[
320
+ :, :, self.Xdim_output :
321
+ ].reshape(bs, n * n, -1)
322
+
323
+ unnormalized_prob_X[torch.sum(unnormalized_prob_X, dim=-1) == 0] = 1e-5
324
+ unnormalized_prob_E[torch.sum(unnormalized_prob_E, dim=-1) == 0] = 1e-5
325
+
326
+ prob_X = unnormalized_prob_X / torch.sum(
327
+ unnormalized_prob_X, dim=-1, keepdim=True
328
+ ) # bs, n, d_t-1
329
+ prob_E = unnormalized_prob_E / torch.sum(
330
+ unnormalized_prob_E, dim=-1, keepdim=True
331
+ ) # bs, n, d_t-1
332
+ prob_E = prob_E.reshape(bs, n, n, pred_E.shape[-1])
333
+
334
+ return prob_X, prob_E
335
+
336
+ prob_X, prob_E = get_prob(noisy_data)
337
+
338
+ ### Guidance
339
+ if guide_scale != 1:
340
+ uncon_prob_X, uncon_prob_E = get_prob(
341
+ noisy_data, unconditioned=True
342
+ )
343
+ prob_X = (
344
+ uncon_prob_X
345
+ * (prob_X / uncon_prob_X.clamp_min(1e-5)) ** guide_scale
346
+ )
347
+ prob_E = (
348
+ uncon_prob_E
349
+ * (prob_E / uncon_prob_E.clamp_min(1e-5)) ** guide_scale
350
+ )
351
+ prob_X = prob_X / prob_X.sum(dim=-1, keepdim=True).clamp_min(1e-5)
352
+ prob_E = prob_E / prob_E.sum(dim=-1, keepdim=True).clamp_min(1e-5)
353
+
354
+ # assert ((prob_X.sum(dim=-1) - 1).abs() < 1e-3).all()
355
+ # assert ((prob_E.sum(dim=-1) - 1).abs() < 1e-3).all()
356
+
357
+ sampled_s = utils.sample_discrete_features(
358
+ prob_X, prob_E, node_mask=node_mask, step=s[0, 0].item()
359
+ )
360
+
361
+ X_s = F.one_hot(sampled_s.X, num_classes=self.Xdim_output).to(self.model_dtype)
362
+ E_s = F.one_hot(sampled_s.E, num_classes=self.Edim_output).to(self.model_dtype)
363
+
364
+ assert (E_s == torch.transpose(E_s, 1, 2)).all()
365
+ assert (X_t.shape == X_s.shape) and (E_t.shape == E_s.shape)
366
+
367
+ out_one_hot = utils.PlaceHolder(X=X_s, E=E_s, y=properties)
368
+ out_discrete = utils.PlaceHolder(X=X_s, E=E_s, y=properties)
369
+
370
+ return out_one_hot.mask(node_mask).type_as(properties), out_discrete.mask(
371
+ node_mask, collapse=True
372
+ ).type_as(properties)
373
+
374
+
375
+ #####
376
+ # from graph_decoder.diffusion_model import GraphDiT
377
  def load_graph_decoder(path='model_labeled'):
378
  model = GraphDiT(
379
  model_config_path=f"{path}/config.yaml",