liuganghuggingface commited on
Commit
c2b7ab5
·
verified ·
1 Parent(s): 71b6e98

Update graph_decoder/diffusion_model.py

Browse files
Files changed (1) hide show
  1. graph_decoder/diffusion_model.py +351 -351
graph_decoder/diffusion_model.py CHANGED
@@ -11,22 +11,6 @@ from .molecule_utils import graph_to_smiles, check_valid
11
  from .transformer import Transformer
12
  from .visualize_utils import MolecularVisualization
13
 
14
- # class GraphDiT(nn.Module):
15
- # def __init__(
16
- # self,
17
- # model_config_path,
18
- # data_info_path,
19
- # model_dtype,
20
- # ):
21
- # super().__init__()
22
-
23
- # def init_model(self, model_dir):
24
- # pass
25
-
26
- # def disable_grads(self):
27
- # pass
28
-
29
-
30
  class GraphDiT(nn.Module):
31
  def __init__(
32
  self,
@@ -35,346 +19,362 @@ class GraphDiT(nn.Module):
35
  model_dtype,
36
  ):
37
  super().__init__()
38
- dm_cfg, data_info = utils.load_config(model_config_path, data_info_path)
39
-
40
- input_dims = data_info.input_dims
41
- output_dims = data_info.output_dims
42
- nodes_dist = data_info.nodes_dist
43
- active_index = data_info.active_index
44
-
45
- self.model_config = dm_cfg
46
- self.data_info = data_info
47
- self.T = dm_cfg.diffusion_steps
48
- self.Xdim = input_dims["X"]
49
- self.Edim = input_dims["E"]
50
- self.ydim = input_dims["y"]
51
- self.Xdim_output = output_dims["X"]
52
- self.Edim_output = output_dims["E"]
53
- self.ydim_output = output_dims["y"]
54
- self.node_dist = nodes_dist
55
- self.active_index = active_index
56
- self.max_n_nodes = data_info.max_n_nodes
57
- self.atom_decoder = data_info.atom_decoder
58
- self.hidden_size = dm_cfg.hidden_size
59
- self.mol_visualizer = MolecularVisualization(self.atom_decoder)
60
-
61
- self.denoiser = Transformer(
62
- max_n_nodes=self.max_n_nodes,
63
- hidden_size=dm_cfg.hidden_size,
64
- depth=dm_cfg.depth,
65
- num_heads=dm_cfg.num_heads,
66
- mlp_ratio=dm_cfg.mlp_ratio,
67
- drop_condition=dm_cfg.drop_condition,
68
- Xdim=self.Xdim,
69
- Edim=self.Edim,
70
- ydim=self.ydim,
71
- )
72
-
73
- self.model_dtype = model_dtype
74
- self.noise_schedule = utils.PredefinedNoiseScheduleDiscrete(
75
- dm_cfg.diffusion_noise_schedule, timesteps=dm_cfg.diffusion_steps
76
- )
77
- x_marginals = data_info.node_types.to(self.model_dtype) / torch.sum(
78
- data_info.node_types.to(self.model_dtype)
79
- )
80
- e_marginals = data_info.edge_types.to(self.model_dtype) / torch.sum(
81
- data_info.edge_types.to(self.model_dtype)
82
- )
83
- x_marginals = x_marginals / x_marginals.sum()
84
- e_marginals = e_marginals / e_marginals.sum()
85
-
86
- xe_conditions = data_info.transition_E.to(self.model_dtype)
87
- xe_conditions = xe_conditions[self.active_index][:, self.active_index]
88
-
89
- xe_conditions = xe_conditions.sum(dim=1)
90
- ex_conditions = xe_conditions.t()
91
- xe_conditions = xe_conditions / xe_conditions.sum(dim=-1, keepdim=True)
92
- ex_conditions = ex_conditions / ex_conditions.sum(dim=-1, keepdim=True)
93
-
94
- self.transition_model = utils.MarginalTransition(
95
- x_marginals=x_marginals,
96
- e_marginals=e_marginals,
97
- xe_conditions=xe_conditions,
98
- ex_conditions=ex_conditions,
99
- y_classes=self.ydim_output,
100
- n_nodes=self.max_n_nodes,
101
- )
102
- self.limit_dist = utils.PlaceHolder(X=x_marginals, E=e_marginals, y=None)
103
 
104
  def init_model(self, model_dir):
105
- model_file = os.path.join(model_dir, 'model.pt')
106
- if os.path.exists(model_file):
107
- self.denoiser.load_state_dict(torch.load(model_file, map_location='cpu', weights_only=True))
108
- else:
109
- raise FileNotFoundError(f"Model file not found: {model_file}")
110
-
111
  def disable_grads(self):
112
- self.denoiser.disable_grads()
113
 
114
- def forward(
115
- self, x, edge_index, edge_attr, graph_batch, properties, no_label_index
116
- ):
117
- raise ValueError('Not Implement')
118
-
119
- def _forward(self, noisy_data, unconditioned=False):
120
- noisy_x, noisy_e, properties = (
121
- noisy_data["X_t"].to(self.model_dtype),
122
- noisy_data["E_t"].to(self.model_dtype),
123
- noisy_data["y_t"].to(self.model_dtype).clone(),
124
- )
125
- node_mask, timestep = (
126
- noisy_data["node_mask"],
127
- noisy_data["t"],
128
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
129
 
130
- pred = self.denoiser(
131
- noisy_x,
132
- noisy_e,
133
- node_mask,
134
- properties,
135
- timestep,
136
- unconditioned=unconditioned,
137
- )
138
- return pred
139
-
140
- def apply_noise(self, X, E, y, node_mask):
141
- """Sample noise and apply it to the data."""
142
-
143
- # Sample a timestep t.
144
- # When evaluating, the loss for t=0 is computed separately
145
- lowest_t = 0 if self.training else 1
146
- t_int = torch.randint(
147
- lowest_t, self.T + 1, size=(X.size(0), 1), device=X.device
148
- ).to(
149
- self.model_dtype
150
- ) # (bs, 1)
151
- s_int = t_int - 1
152
-
153
- t_float = t_int / self.T
154
- s_float = s_int / self.T
155
-
156
- # beta_t and alpha_s_bar are used for denoising/loss computation
157
- beta_t = self.noise_schedule(t_normalized=t_float) # (bs, 1)
158
- alpha_s_bar = self.noise_schedule.get_alpha_bar(t_normalized=s_float) # (bs, 1)
159
- alpha_t_bar = self.noise_schedule.get_alpha_bar(t_normalized=t_float) # (bs, 1)
160
-
161
- Qtb = self.transition_model.get_Qt_bar(
162
- alpha_t_bar, X.device
163
- ) # (bs, dx_in, dx_out), (bs, de_in, de_out)
164
-
165
- bs, n, d = X.shape
166
- X_all = torch.cat([X, E.reshape(bs, n, -1)], dim=-1)
167
- prob_all = X_all @ Qtb.X
168
- probX = prob_all[:, :, : self.Xdim_output]
169
- probE = prob_all[:, :, self.Xdim_output :].reshape(bs, n, n, -1)
170
-
171
- sampled_t = utils.sample_discrete_features(
172
- probX=probX, probE=probE, node_mask=node_mask
173
- )
174
-
175
- X_t = F.one_hot(sampled_t.X, num_classes=self.Xdim_output)
176
- E_t = F.one_hot(sampled_t.E, num_classes=self.Edim_output)
177
- assert (X.shape == X_t.shape) and (E.shape == E_t.shape)
178
-
179
- y_t = y
180
- z_t = utils.PlaceHolder(X=X_t, E=E_t, y=y_t).type_as(X_t).mask(node_mask)
181
-
182
- noisy_data = {
183
- "t_int": t_int,
184
- "t": t_float,
185
- "beta_t": beta_t,
186
- "alpha_s_bar": alpha_s_bar,
187
- "alpha_t_bar": alpha_t_bar,
188
- "X_t": z_t.X,
189
- "E_t": z_t.E,
190
- "y_t": z_t.y,
191
- "node_mask": node_mask,
192
- }
193
- return noisy_data
194
-
195
- @torch.no_grad()
196
- def generate(
197
- self,
198
- properties,
199
- guide_scale=1.,
200
- num_nodes=None,
201
- number_chain_steps=50,
202
- ):
203
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
204
- properties = [float('nan') if x is None else x for x in properties]
205
- properties = torch.tensor(properties, dtype=torch.float).reshape(1, -1).to(device)
206
- batch_size = properties.size(0)
207
- assert batch_size == 1
208
- if num_nodes is None:
209
- num_nodes = self.node_dist.sample_n(batch_size, device)
210
- else:
211
- num_nodes = torch.LongTensor([num_nodes]).to(device)
212
-
213
- arange = (
214
- torch.arange(self.max_n_nodes, device=device)
215
- .unsqueeze(0)
216
- .expand(batch_size, -1)
217
- )
218
- node_mask = arange < num_nodes.unsqueeze(1)
219
-
220
- z_T = utils.sample_discrete_feature_noise(
221
- limit_dist=self.limit_dist, node_mask=node_mask
222
- )
223
- X, E = z_T.X, z_T.E
224
-
225
- assert (E == torch.transpose(E, 1, 2)).all()
226
-
227
- if number_chain_steps > 0:
228
- chain_X_size = torch.Size((number_chain_steps, X.size(1)))
229
- chain_E_size = torch.Size((number_chain_steps, E.size(1), E.size(2)))
230
- chain_X = torch.zeros(chain_X_size)
231
- chain_E = torch.zeros(chain_E_size)
232
-
233
- # Iteratively sample p(z_s | z_t) for t = 1, ..., T, with s = t - 1.
234
- y = properties
235
- for s_int in reversed(range(0, self.T)):
236
- s_array = s_int * torch.ones((batch_size, 1)).type_as(y)
237
- t_array = s_array + 1
238
- s_norm = s_array / self.T
239
- t_norm = t_array / self.T
240
-
241
- # Sample z_s
242
- sampled_s, discrete_sampled_s = self.sample_p_zs_given_zt(
243
- s_norm, t_norm, X, E, y, node_mask, guide_scale, device
244
- )
245
- X, E, y = sampled_s.X, sampled_s.E, sampled_s.y
246
 
247
- if number_chain_steps > 0:
248
- # Save the first keep_chain graphs
249
- write_index = (s_int * number_chain_steps) // self.T
250
- chain_X[write_index] = discrete_sampled_s.X[:1]
251
- chain_E[write_index] = discrete_sampled_s.E[:1]
252
-
253
- # Sample
254
- sampled_s = sampled_s.mask(node_mask, collapse=True)
255
- X, E, y = sampled_s.X, sampled_s.E, sampled_s.y
256
-
257
- molecule_list = []
258
- n = num_nodes[0]
259
- atom_types = X[0, :n].cpu()
260
- edge_types = E[0, :n, :n].cpu()
261
- molecule_list.append([atom_types, edge_types])
262
- smiles = graph_to_smiles(molecule_list, self.atom_decoder)[0]
263
-
264
- # Visualize Chains
265
- if number_chain_steps > 0:
266
- final_X_chain = X[:1]
267
- final_E_chain = E[:1]
268
-
269
- chain_X[0] = final_X_chain # Overwrite last frame with the resulting X, E
270
- chain_E[0] = final_E_chain
271
-
272
- chain_X = utils.reverse_tensor(chain_X)
273
- chain_E = utils.reverse_tensor(chain_E)
274
-
275
- # Repeat last frame to see final sample better
276
- chain_X = torch.cat([chain_X, chain_X[-1:].repeat(10, 1)], dim=0)
277
- chain_E = torch.cat([chain_E, chain_E[-1:].repeat(10, 1, 1)], dim=0)
278
- mol_img_list = self.mol_visualizer.visualize_chain(chain_X.numpy(), chain_E.numpy())
279
- else:
280
- mol_img_list = []
281
-
282
- return smiles, mol_img_list
283
-
284
- def check_valid(self, smiles):
285
- return check_valid(smiles)
286
 
287
- def sample_p_zs_given_zt(
288
- self, s, t, X_t, E_t, properties, node_mask, guide_scale, device
289
- ):
290
- """Samples from zs ~ p(zs | zt). Only used during sampling.
291
- if last_step, return the graph prediction as well"""
292
- bs, n, _ = X_t.shape
293
- beta_t = self.noise_schedule(t_normalized=t) # (bs, 1)
294
- alpha_s_bar = self.noise_schedule.get_alpha_bar(t_normalized=s)
295
- alpha_t_bar = self.noise_schedule.get_alpha_bar(t_normalized=t)
296
-
297
- # Neural net predictions
298
- noisy_data = {
299
- "X_t": X_t,
300
- "E_t": E_t,
301
- "y_t": properties,
302
- "t": t,
303
- "node_mask": node_mask,
304
- }
305
-
306
- def get_prob(noisy_data, unconditioned=False):
307
- pred = self._forward(noisy_data, unconditioned=unconditioned)
308
-
309
- # Normalize predictions
310
- pred_X = F.softmax(pred.X, dim=-1) # bs, n, d0
311
- pred_E = F.softmax(pred.E, dim=-1) # bs, n, n, d0
312
-
313
- # Retrieve transitions matrix
314
- Qtb = self.transition_model.get_Qt_bar(alpha_t_bar, device)
315
- Qsb = self.transition_model.get_Qt_bar(alpha_s_bar, device)
316
- Qt = self.transition_model.get_Qt(beta_t, device)
317
-
318
- Xt_all = torch.cat([X_t, E_t.reshape(bs, n, -1)], dim=-1)
319
- predX_all = torch.cat([pred_X, pred_E.reshape(bs, n, -1)], dim=-1)
320
-
321
- unnormalized_probX_all = utils.reverse_diffusion(
322
- predX_0=predX_all, X_t=Xt_all, Qt=Qt.X, Qsb=Qsb.X, Qtb=Qtb.X
323
- )
324
-
325
- unnormalized_prob_X = unnormalized_probX_all[:, :, : self.Xdim_output]
326
- unnormalized_prob_E = unnormalized_probX_all[
327
- :, :, self.Xdim_output :
328
- ].reshape(bs, n * n, -1)
329
-
330
- unnormalized_prob_X[torch.sum(unnormalized_prob_X, dim=-1) == 0] = 1e-5
331
- unnormalized_prob_E[torch.sum(unnormalized_prob_E, dim=-1) == 0] = 1e-5
332
-
333
- prob_X = unnormalized_prob_X / torch.sum(
334
- unnormalized_prob_X, dim=-1, keepdim=True
335
- ) # bs, n, d_t-1
336
- prob_E = unnormalized_prob_E / torch.sum(
337
- unnormalized_prob_E, dim=-1, keepdim=True
338
- ) # bs, n, d_t-1
339
- prob_E = prob_E.reshape(bs, n, n, pred_E.shape[-1])
340
-
341
- return prob_X, prob_E
342
-
343
- prob_X, prob_E = get_prob(noisy_data)
344
-
345
- ### Guidance
346
- if guide_scale != 1:
347
- uncon_prob_X, uncon_prob_E = get_prob(
348
- noisy_data, unconditioned=True
349
- )
350
- prob_X = (
351
- uncon_prob_X
352
- * (prob_X / uncon_prob_X.clamp_min(1e-5)) ** guide_scale
353
- )
354
- prob_E = (
355
- uncon_prob_E
356
- * (prob_E / uncon_prob_E.clamp_min(1e-5)) ** guide_scale
357
- )
358
- prob_X = prob_X / prob_X.sum(dim=-1, keepdim=True).clamp_min(1e-5)
359
- prob_E = prob_E / prob_E.sum(dim=-1, keepdim=True).clamp_min(1e-5)
360
-
361
- # assert ((prob_X.sum(dim=-1) - 1).abs() < 1e-3).all()
362
- # assert ((prob_E.sum(dim=-1) - 1).abs() < 1e-3).all()
363
-
364
- sampled_s = utils.sample_discrete_features(
365
- prob_X, prob_E, node_mask=node_mask, step=s[0, 0].item()
366
- )
367
-
368
- X_s = F.one_hot(sampled_s.X, num_classes=self.Xdim_output).to(self.model_dtype)
369
- E_s = F.one_hot(sampled_s.E, num_classes=self.Edim_output).to(self.model_dtype)
370
-
371
- assert (E_s == torch.transpose(E_s, 1, 2)).all()
372
- assert (X_t.shape == X_s.shape) and (E_t.shape == E_s.shape)
373
-
374
- out_one_hot = utils.PlaceHolder(X=X_s, E=E_s, y=properties)
375
- out_discrete = utils.PlaceHolder(X=X_s, E=E_s, y=properties)
376
-
377
- return out_one_hot.mask(node_mask).type_as(properties), out_discrete.mask(
378
- node_mask, collapse=True
379
- ).type_as(properties)
380
 
 
11
  from .transformer import Transformer
12
  from .visualize_utils import MolecularVisualization
13
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
14
  class GraphDiT(nn.Module):
15
  def __init__(
16
  self,
 
19
  model_dtype,
20
  ):
21
  super().__init__()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
22
 
23
  def init_model(self, model_dir):
24
+ pass
25
+
 
 
 
 
26
  def disable_grads(self):
27
+ pass
28
 
29
+
30
+ # class GraphDiT(nn.Module):
31
+ # def __init__(
32
+ # self,
33
+ # model_config_path,
34
+ # data_info_path,
35
+ # model_dtype,
36
+ # ):
37
+ # super().__init__()
38
+ # dm_cfg, data_info = utils.load_config(model_config_path, data_info_path)
39
+
40
+ # input_dims = data_info.input_dims
41
+ # output_dims = data_info.output_dims
42
+ # nodes_dist = data_info.nodes_dist
43
+ # active_index = data_info.active_index
44
+
45
+ # self.model_config = dm_cfg
46
+ # self.data_info = data_info
47
+ # self.T = dm_cfg.diffusion_steps
48
+ # self.Xdim = input_dims["X"]
49
+ # self.Edim = input_dims["E"]
50
+ # self.ydim = input_dims["y"]
51
+ # self.Xdim_output = output_dims["X"]
52
+ # self.Edim_output = output_dims["E"]
53
+ # self.ydim_output = output_dims["y"]
54
+ # self.node_dist = nodes_dist
55
+ # self.active_index = active_index
56
+ # self.max_n_nodes = data_info.max_n_nodes
57
+ # self.atom_decoder = data_info.atom_decoder
58
+ # self.hidden_size = dm_cfg.hidden_size
59
+ # self.mol_visualizer = MolecularVisualization(self.atom_decoder)
60
+
61
+ # self.denoiser = Transformer(
62
+ # max_n_nodes=self.max_n_nodes,
63
+ # hidden_size=dm_cfg.hidden_size,
64
+ # depth=dm_cfg.depth,
65
+ # num_heads=dm_cfg.num_heads,
66
+ # mlp_ratio=dm_cfg.mlp_ratio,
67
+ # drop_condition=dm_cfg.drop_condition,
68
+ # Xdim=self.Xdim,
69
+ # Edim=self.Edim,
70
+ # ydim=self.ydim,
71
+ # )
72
+
73
+ # self.model_dtype = model_dtype
74
+ # self.noise_schedule = utils.PredefinedNoiseScheduleDiscrete(
75
+ # dm_cfg.diffusion_noise_schedule, timesteps=dm_cfg.diffusion_steps
76
+ # )
77
+ # x_marginals = data_info.node_types.to(self.model_dtype) / torch.sum(
78
+ # data_info.node_types.to(self.model_dtype)
79
+ # )
80
+ # e_marginals = data_info.edge_types.to(self.model_dtype) / torch.sum(
81
+ # data_info.edge_types.to(self.model_dtype)
82
+ # )
83
+ # x_marginals = x_marginals / x_marginals.sum()
84
+ # e_marginals = e_marginals / e_marginals.sum()
85
+
86
+ # xe_conditions = data_info.transition_E.to(self.model_dtype)
87
+ # xe_conditions = xe_conditions[self.active_index][:, self.active_index]
88
+
89
+ # xe_conditions = xe_conditions.sum(dim=1)
90
+ # ex_conditions = xe_conditions.t()
91
+ # xe_conditions = xe_conditions / xe_conditions.sum(dim=-1, keepdim=True)
92
+ # ex_conditions = ex_conditions / ex_conditions.sum(dim=-1, keepdim=True)
93
+
94
+ # self.transition_model = utils.MarginalTransition(
95
+ # x_marginals=x_marginals,
96
+ # e_marginals=e_marginals,
97
+ # xe_conditions=xe_conditions,
98
+ # ex_conditions=ex_conditions,
99
+ # y_classes=self.ydim_output,
100
+ # n_nodes=self.max_n_nodes,
101
+ # )
102
+ # self.limit_dist = utils.PlaceHolder(X=x_marginals, E=e_marginals, y=None)
103
+
104
+ # def init_model(self, model_dir):
105
+ # model_file = os.path.join(model_dir, 'model.pt')
106
+ # if os.path.exists(model_file):
107
+ # self.denoiser.load_state_dict(torch.load(model_file, map_location='cpu', weights_only=True))
108
+ # else:
109
+ # raise FileNotFoundError(f"Model file not found: {model_file}")
110
+
111
+ # def disable_grads(self):
112
+ # self.denoiser.disable_grads()
113
+
114
+ # def forward(
115
+ # self, x, edge_index, edge_attr, graph_batch, properties, no_label_index
116
+ # ):
117
+ # raise ValueError('Not Implement')
118
+
119
+ # def _forward(self, noisy_data, unconditioned=False):
120
+ # noisy_x, noisy_e, properties = (
121
+ # noisy_data["X_t"].to(self.model_dtype),
122
+ # noisy_data["E_t"].to(self.model_dtype),
123
+ # noisy_data["y_t"].to(self.model_dtype).clone(),
124
+ # )
125
+ # node_mask, timestep = (
126
+ # noisy_data["node_mask"],
127
+ # noisy_data["t"],
128
+ # )
129
 
130
+ # pred = self.denoiser(
131
+ # noisy_x,
132
+ # noisy_e,
133
+ # node_mask,
134
+ # properties,
135
+ # timestep,
136
+ # unconditioned=unconditioned,
137
+ # )
138
+ # return pred
139
+
140
+ # def apply_noise(self, X, E, y, node_mask):
141
+ # """Sample noise and apply it to the data."""
142
+
143
+ # # Sample a timestep t.
144
+ # # When evaluating, the loss for t=0 is computed separately
145
+ # lowest_t = 0 if self.training else 1
146
+ # t_int = torch.randint(
147
+ # lowest_t, self.T + 1, size=(X.size(0), 1), device=X.device
148
+ # ).to(
149
+ # self.model_dtype
150
+ # ) # (bs, 1)
151
+ # s_int = t_int - 1
152
+
153
+ # t_float = t_int / self.T
154
+ # s_float = s_int / self.T
155
+
156
+ # # beta_t and alpha_s_bar are used for denoising/loss computation
157
+ # beta_t = self.noise_schedule(t_normalized=t_float) # (bs, 1)
158
+ # alpha_s_bar = self.noise_schedule.get_alpha_bar(t_normalized=s_float) # (bs, 1)
159
+ # alpha_t_bar = self.noise_schedule.get_alpha_bar(t_normalized=t_float) # (bs, 1)
160
+
161
+ # Qtb = self.transition_model.get_Qt_bar(
162
+ # alpha_t_bar, X.device
163
+ # ) # (bs, dx_in, dx_out), (bs, de_in, de_out)
164
+
165
+ # bs, n, d = X.shape
166
+ # X_all = torch.cat([X, E.reshape(bs, n, -1)], dim=-1)
167
+ # prob_all = X_all @ Qtb.X
168
+ # probX = prob_all[:, :, : self.Xdim_output]
169
+ # probE = prob_all[:, :, self.Xdim_output :].reshape(bs, n, n, -1)
170
+
171
+ # sampled_t = utils.sample_discrete_features(
172
+ # probX=probX, probE=probE, node_mask=node_mask
173
+ # )
174
+
175
+ # X_t = F.one_hot(sampled_t.X, num_classes=self.Xdim_output)
176
+ # E_t = F.one_hot(sampled_t.E, num_classes=self.Edim_output)
177
+ # assert (X.shape == X_t.shape) and (E.shape == E_t.shape)
178
+
179
+ # y_t = y
180
+ # z_t = utils.PlaceHolder(X=X_t, E=E_t, y=y_t).type_as(X_t).mask(node_mask)
181
+
182
+ # noisy_data = {
183
+ # "t_int": t_int,
184
+ # "t": t_float,
185
+ # "beta_t": beta_t,
186
+ # "alpha_s_bar": alpha_s_bar,
187
+ # "alpha_t_bar": alpha_t_bar,
188
+ # "X_t": z_t.X,
189
+ # "E_t": z_t.E,
190
+ # "y_t": z_t.y,
191
+ # "node_mask": node_mask,
192
+ # }
193
+ # return noisy_data
194
+
195
+ # @torch.no_grad()
196
+ # def generate(
197
+ # self,
198
+ # properties,
199
+ # guide_scale=1.,
200
+ # num_nodes=None,
201
+ # number_chain_steps=50,
202
+ # ):
203
+ # device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
204
+ # properties = [float('nan') if x is None else x for x in properties]
205
+ # properties = torch.tensor(properties, dtype=torch.float).reshape(1, -1).to(device)
206
+ # batch_size = properties.size(0)
207
+ # assert batch_size == 1
208
+ # if num_nodes is None:
209
+ # num_nodes = self.node_dist.sample_n(batch_size, device)
210
+ # else:
211
+ # num_nodes = torch.LongTensor([num_nodes]).to(device)
212
+
213
+ # arange = (
214
+ # torch.arange(self.max_n_nodes, device=device)
215
+ # .unsqueeze(0)
216
+ # .expand(batch_size, -1)
217
+ # )
218
+ # node_mask = arange < num_nodes.unsqueeze(1)
219
+
220
+ # z_T = utils.sample_discrete_feature_noise(
221
+ # limit_dist=self.limit_dist, node_mask=node_mask
222
+ # )
223
+ # X, E = z_T.X, z_T.E
224
+
225
+ # assert (E == torch.transpose(E, 1, 2)).all()
226
+
227
+ # if number_chain_steps > 0:
228
+ # chain_X_size = torch.Size((number_chain_steps, X.size(1)))
229
+ # chain_E_size = torch.Size((number_chain_steps, E.size(1), E.size(2)))
230
+ # chain_X = torch.zeros(chain_X_size)
231
+ # chain_E = torch.zeros(chain_E_size)
232
+
233
+ # # Iteratively sample p(z_s | z_t) for t = 1, ..., T, with s = t - 1.
234
+ # y = properties
235
+ # for s_int in reversed(range(0, self.T)):
236
+ # s_array = s_int * torch.ones((batch_size, 1)).type_as(y)
237
+ # t_array = s_array + 1
238
+ # s_norm = s_array / self.T
239
+ # t_norm = t_array / self.T
240
+
241
+ # # Sample z_s
242
+ # sampled_s, discrete_sampled_s = self.sample_p_zs_given_zt(
243
+ # s_norm, t_norm, X, E, y, node_mask, guide_scale, device
244
+ # )
245
+ # X, E, y = sampled_s.X, sampled_s.E, sampled_s.y
246
 
247
+ # if number_chain_steps > 0:
248
+ # # Save the first keep_chain graphs
249
+ # write_index = (s_int * number_chain_steps) // self.T
250
+ # chain_X[write_index] = discrete_sampled_s.X[:1]
251
+ # chain_E[write_index] = discrete_sampled_s.E[:1]
252
+
253
+ # # Sample
254
+ # sampled_s = sampled_s.mask(node_mask, collapse=True)
255
+ # X, E, y = sampled_s.X, sampled_s.E, sampled_s.y
256
+
257
+ # molecule_list = []
258
+ # n = num_nodes[0]
259
+ # atom_types = X[0, :n].cpu()
260
+ # edge_types = E[0, :n, :n].cpu()
261
+ # molecule_list.append([atom_types, edge_types])
262
+ # smiles = graph_to_smiles(molecule_list, self.atom_decoder)[0]
263
+
264
+ # # Visualize Chains
265
+ # if number_chain_steps > 0:
266
+ # final_X_chain = X[:1]
267
+ # final_E_chain = E[:1]
268
+
269
+ # chain_X[0] = final_X_chain # Overwrite last frame with the resulting X, E
270
+ # chain_E[0] = final_E_chain
271
+
272
+ # chain_X = utils.reverse_tensor(chain_X)
273
+ # chain_E = utils.reverse_tensor(chain_E)
274
+
275
+ # # Repeat last frame to see final sample better
276
+ # chain_X = torch.cat([chain_X, chain_X[-1:].repeat(10, 1)], dim=0)
277
+ # chain_E = torch.cat([chain_E, chain_E[-1:].repeat(10, 1, 1)], dim=0)
278
+ # mol_img_list = self.mol_visualizer.visualize_chain(chain_X.numpy(), chain_E.numpy())
279
+ # else:
280
+ # mol_img_list = []
281
+
282
+ # return smiles, mol_img_list
283
+
284
+ # def check_valid(self, smiles):
285
+ # return check_valid(smiles)
286
 
287
+ # def sample_p_zs_given_zt(
288
+ # self, s, t, X_t, E_t, properties, node_mask, guide_scale, device
289
+ # ):
290
+ # """Samples from zs ~ p(zs | zt). Only used during sampling.
291
+ # if last_step, return the graph prediction as well"""
292
+ # bs, n, _ = X_t.shape
293
+ # beta_t = self.noise_schedule(t_normalized=t) # (bs, 1)
294
+ # alpha_s_bar = self.noise_schedule.get_alpha_bar(t_normalized=s)
295
+ # alpha_t_bar = self.noise_schedule.get_alpha_bar(t_normalized=t)
296
+
297
+ # # Neural net predictions
298
+ # noisy_data = {
299
+ # "X_t": X_t,
300
+ # "E_t": E_t,
301
+ # "y_t": properties,
302
+ # "t": t,
303
+ # "node_mask": node_mask,
304
+ # }
305
+
306
+ # def get_prob(noisy_data, unconditioned=False):
307
+ # pred = self._forward(noisy_data, unconditioned=unconditioned)
308
+
309
+ # # Normalize predictions
310
+ # pred_X = F.softmax(pred.X, dim=-1) # bs, n, d0
311
+ # pred_E = F.softmax(pred.E, dim=-1) # bs, n, n, d0
312
+
313
+ # # Retrieve transitions matrix
314
+ # Qtb = self.transition_model.get_Qt_bar(alpha_t_bar, device)
315
+ # Qsb = self.transition_model.get_Qt_bar(alpha_s_bar, device)
316
+ # Qt = self.transition_model.get_Qt(beta_t, device)
317
+
318
+ # Xt_all = torch.cat([X_t, E_t.reshape(bs, n, -1)], dim=-1)
319
+ # predX_all = torch.cat([pred_X, pred_E.reshape(bs, n, -1)], dim=-1)
320
+
321
+ # unnormalized_probX_all = utils.reverse_diffusion(
322
+ # predX_0=predX_all, X_t=Xt_all, Qt=Qt.X, Qsb=Qsb.X, Qtb=Qtb.X
323
+ # )
324
+
325
+ # unnormalized_prob_X = unnormalized_probX_all[:, :, : self.Xdim_output]
326
+ # unnormalized_prob_E = unnormalized_probX_all[
327
+ # :, :, self.Xdim_output :
328
+ # ].reshape(bs, n * n, -1)
329
+
330
+ # unnormalized_prob_X[torch.sum(unnormalized_prob_X, dim=-1) == 0] = 1e-5
331
+ # unnormalized_prob_E[torch.sum(unnormalized_prob_E, dim=-1) == 0] = 1e-5
332
+
333
+ # prob_X = unnormalized_prob_X / torch.sum(
334
+ # unnormalized_prob_X, dim=-1, keepdim=True
335
+ # ) # bs, n, d_t-1
336
+ # prob_E = unnormalized_prob_E / torch.sum(
337
+ # unnormalized_prob_E, dim=-1, keepdim=True
338
+ # ) # bs, n, d_t-1
339
+ # prob_E = prob_E.reshape(bs, n, n, pred_E.shape[-1])
340
+
341
+ # return prob_X, prob_E
342
+
343
+ # prob_X, prob_E = get_prob(noisy_data)
344
+
345
+ # ### Guidance
346
+ # if guide_scale != 1:
347
+ # uncon_prob_X, uncon_prob_E = get_prob(
348
+ # noisy_data, unconditioned=True
349
+ # )
350
+ # prob_X = (
351
+ # uncon_prob_X
352
+ # * (prob_X / uncon_prob_X.clamp_min(1e-5)) ** guide_scale
353
+ # )
354
+ # prob_E = (
355
+ # uncon_prob_E
356
+ # * (prob_E / uncon_prob_E.clamp_min(1e-5)) ** guide_scale
357
+ # )
358
+ # prob_X = prob_X / prob_X.sum(dim=-1, keepdim=True).clamp_min(1e-5)
359
+ # prob_E = prob_E / prob_E.sum(dim=-1, keepdim=True).clamp_min(1e-5)
360
+
361
+ # # assert ((prob_X.sum(dim=-1) - 1).abs() < 1e-3).all()
362
+ # # assert ((prob_E.sum(dim=-1) - 1).abs() < 1e-3).all()
363
+
364
+ # sampled_s = utils.sample_discrete_features(
365
+ # prob_X, prob_E, node_mask=node_mask, step=s[0, 0].item()
366
+ # )
367
+
368
+ # X_s = F.one_hot(sampled_s.X, num_classes=self.Xdim_output).to(self.model_dtype)
369
+ # E_s = F.one_hot(sampled_s.E, num_classes=self.Edim_output).to(self.model_dtype)
370
+
371
+ # assert (E_s == torch.transpose(E_s, 1, 2)).all()
372
+ # assert (X_t.shape == X_s.shape) and (E_t.shape == E_s.shape)
373
+
374
+ # out_one_hot = utils.PlaceHolder(X=X_s, E=E_s, y=properties)
375
+ # out_discrete = utils.PlaceHolder(X=X_s, E=E_s, y=properties)
376
+
377
+ # return out_one_hot.mask(node_mask).type_as(properties), out_discrete.mask(
378
+ # node_mask, collapse=True
379
+ # ).type_as(properties)
380