import json
import os

import torch
import torch.nn as nn
import torch.nn.functional as F
import yaml

from . import diffusion_utils as utils
from .transformer import Transformer
from .molecule_utils import graph_to_smiles, check_valid
from .visualize_utils import MolecularVisualization

class GraphDiT(nn.Module):
    """Graph Diffusion Transformer for property-conditioned molecular graph generation.

    Wraps a Transformer denoiser together with the discrete noise schedule and
    the marginal transition model used for forward noising and reverse sampling
    of atom (node) and bond (edge) types.
    """

    def __init__(
        self,
        model_config_path,
        data_info_path,
        model_dtype,
    ):
        super().__init__()
        dm_cfg, data_info = utils.load_config(model_config_path, data_info_path)

        input_dims = data_info.input_dims
        output_dims = data_info.output_dims
        nodes_dist = data_info.nodes_dist
        active_index = data_info.active_index

        self.model_config = dm_cfg
        self.data_info = data_info
        self.T = dm_cfg.diffusion_steps
        self.Xdim = input_dims["X"]
        self.Edim = input_dims["E"]
        self.ydim = input_dims["y"]
        self.Xdim_output = output_dims["X"]
        self.Edim_output = output_dims["E"]
        self.ydim_output = output_dims["y"]
        self.node_dist = nodes_dist
        self.active_index = active_index
        self.max_n_nodes = data_info.max_n_nodes
        self.atom_decoder = data_info.atom_decoder
        self.hidden_size = dm_cfg.hidden_size
        self.mol_visualizer = MolecularVisualization(self.atom_decoder)

        self.denoiser = Transformer(
            max_n_nodes=self.max_n_nodes,
            hidden_size=dm_cfg.hidden_size,
            depth=dm_cfg.depth,
            num_heads=dm_cfg.num_heads,
            mlp_ratio=dm_cfg.mlp_ratio,
            drop_condition=dm_cfg.drop_condition,
            Xdim=self.Xdim,
            Edim=self.Edim,
            ydim=self.ydim,
        )

        self.model_dtype = model_dtype
        self.noise_schedule = utils.PredefinedNoiseScheduleDiscrete(
            dm_cfg.diffusion_noise_schedule, timesteps=dm_cfg.diffusion_steps
        )

        # Marginal distributions over atom and bond types, normalized to sum
        # to one; they define the limit distribution of the forward process.
        x_marginals = data_info.node_types.to(self.model_dtype) / torch.sum(
            data_info.node_types.to(self.model_dtype)
        )
        e_marginals = data_info.edge_types.to(self.model_dtype) / torch.sum(
            data_info.edge_types.to(self.model_dtype)
        )
        x_marginals = x_marginals / x_marginals.sum()
        e_marginals = e_marginals / e_marginals.sum()

        # Node-conditioned edge transition probabilities (and their transpose),
        # restricted to the active atom types and row-normalized.
        xe_conditions = data_info.transition_E.to(self.model_dtype)
        xe_conditions = xe_conditions[self.active_index][:, self.active_index]

        xe_conditions = xe_conditions.sum(dim=1)
        ex_conditions = xe_conditions.t()
        xe_conditions = xe_conditions / xe_conditions.sum(dim=-1, keepdim=True)
        ex_conditions = ex_conditions / ex_conditions.sum(dim=-1, keepdim=True)

        self.transition_model = utils.MarginalTransition(
            x_marginals=x_marginals,
            e_marginals=e_marginals,
            xe_conditions=xe_conditions,
            ex_conditions=ex_conditions,
            y_classes=self.ydim_output,
            n_nodes=self.max_n_nodes,
        )
        self.limit_dist = utils.PlaceHolder(X=x_marginals, E=e_marginals, y=None)

    def init_model(self, model_dir):
        model_file = os.path.join(model_dir, 'model.pt')
        if os.path.exists(model_file):
            self.denoiser.load_state_dict(
                torch.load(model_file, map_location='cpu', weights_only=True)
            )
        else:
            raise FileNotFoundError(f"Model file not found: {model_file}")

    def disable_grads(self):
        self.denoiser.disable_grads()

    def forward(
        self, x, edge_index, edge_attr, graph_batch, properties, no_label_index
    ):
        raise NotImplementedError("GraphDiT.forward is not implemented.")

    def _forward(self, noisy_data, unconditioned=False):
        """Run the denoiser on a noisy graph and return its prediction."""
        noisy_x, noisy_e, properties = (
            noisy_data["X_t"].to(self.model_dtype),
            noisy_data["E_t"].to(self.model_dtype),
            noisy_data["y_t"].to(self.model_dtype).clone(),
        )
        node_mask, timestep = (
            noisy_data["node_mask"],
            noisy_data["t"],
        )

        pred = self.denoiser(
            noisy_x,
            noisy_e,
            node_mask,
            properties,
            timestep,
            unconditioned=unconditioned,
        )
        return pred

    def apply_noise(self, X, E, y, node_mask):
        """Sample noise and apply it to the data."""
        # During training t=0 (clean data) is allowed; during evaluation the
        # lowest timestep is 1 so the model always has something to denoise.
        lowest_t = 0 if self.training else 1
        t_int = torch.randint(
            lowest_t, self.T + 1, size=(X.size(0), 1), device=X.device
        ).to(self.model_dtype)
        s_int = t_int - 1

        t_float = t_int / self.T
        s_float = s_int / self.T

        beta_t = self.noise_schedule(t_normalized=t_float)
        alpha_s_bar = self.noise_schedule.get_alpha_bar(t_normalized=s_float)
        alpha_t_bar = self.noise_schedule.get_alpha_bar(t_normalized=t_float)

        # Cumulative transition matrix, so that q(z_t | x_0) = x_0 @ Qt_bar.
        Qtb = self.transition_model.get_Qt_bar(alpha_t_bar, X.device)

        # Treat node and edge features jointly, then split the resulting
        # categorical distributions back into node and edge probabilities.
        bs, n, d = X.shape
        X_all = torch.cat([X, E.reshape(bs, n, -1)], dim=-1)
        prob_all = X_all @ Qtb.X
        probX = prob_all[:, :, : self.Xdim_output]
        probE = prob_all[:, :, self.Xdim_output :].reshape(bs, n, n, -1)

        sampled_t = utils.sample_discrete_features(
            probX=probX, probE=probE, node_mask=node_mask
        )

        X_t = F.one_hot(sampled_t.X, num_classes=self.Xdim_output)
        E_t = F.one_hot(sampled_t.E, num_classes=self.Edim_output)
        assert (X.shape == X_t.shape) and (E.shape == E_t.shape)

        y_t = y
        z_t = utils.PlaceHolder(X=X_t, E=E_t, y=y_t).type_as(X_t).mask(node_mask)

        noisy_data = {
            "t_int": t_int,
            "t": t_float,
            "beta_t": beta_t,
            "alpha_s_bar": alpha_s_bar,
            "alpha_t_bar": alpha_t_bar,
            "X_t": z_t.X,
            "E_t": z_t.E,
            "y_t": z_t.y,
            "node_mask": node_mask,
        }
        return noisy_data

    @torch.no_grad()
    def generate(
        self,
        properties,
        guide_scale=1.,
        num_nodes=None,
        number_chain_steps=50,
    ):
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # Unspecified properties (None) are encoded as NaN and treated as
        # unconditioned dimensions by the denoiser.
        properties = [float('nan') if x is None else x for x in properties]
        properties = torch.tensor(properties, dtype=torch.float).reshape(1, -1).to(device)
        batch_size = properties.size(0)
        assert batch_size == 1
        if num_nodes is None:
            num_nodes = self.node_dist.sample_n(batch_size, device)
        else:
            num_nodes = torch.LongTensor([num_nodes]).to(device)

        arange = (
            torch.arange(self.max_n_nodes, device=device)
            .unsqueeze(0)
            .expand(batch_size, -1)
        )
        node_mask = arange < num_nodes.unsqueeze(1)

        # Sample the initial graph z_T from the limit (marginal) distribution.
        z_T = utils.sample_discrete_feature_noise(
            limit_dist=self.limit_dist, node_mask=node_mask
        )
        X, E = z_T.X, z_T.E

        assert (E == torch.transpose(E, 1, 2)).all()

        if number_chain_steps > 0:
            chain_X_size = torch.Size((number_chain_steps, X.size(1)))
            chain_E_size = torch.Size((number_chain_steps, E.size(1), E.size(2)))
            chain_X = torch.zeros(chain_X_size)
            chain_E = torch.zeros(chain_E_size)

        # Iteratively sample p(z_s | z_t) for s = T - 1, ..., 0.
        y = properties
        for s_int in reversed(range(0, self.T)):
            s_array = s_int * torch.ones((batch_size, 1)).type_as(y)
            t_array = s_array + 1
            s_norm = s_array / self.T
            t_norm = t_array / self.T

            sampled_s, discrete_sampled_s = self.sample_p_zs_given_zt(
                s_norm, t_norm, X, E, y, node_mask, guide_scale, device
            )
            X, E, y = sampled_s.X, sampled_s.E, sampled_s.y

            if number_chain_steps > 0:
                # Save the first graph of the batch at evenly spaced steps
                # for visualizing the reverse chain.
                write_index = (s_int * number_chain_steps) // self.T
                chain_X[write_index] = discrete_sampled_s.X[:1]
                chain_E[write_index] = discrete_sampled_s.E[:1]

        # Collapse the one-hot encodings to discrete atom and bond types.
        sampled_s = sampled_s.mask(node_mask, collapse=True)
        X, E, y = sampled_s.X, sampled_s.E, sampled_s.y

        molecule_list = []
        n = num_nodes[0]
        atom_types = X[0, :n].cpu()
        edge_types = E[0, :n, :n].cpu()
        molecule_list.append([atom_types, edge_types])
        smiles = graph_to_smiles(molecule_list, self.atom_decoder)[0]

        if number_chain_steps > 0:
            # Store the final sample, reverse the chain so it runs from noise
            # to molecule, and repeat the last frame so it lingers in the
            # rendered animation.
            final_X_chain = X[:1]
            final_E_chain = E[:1]

            chain_X[0] = final_X_chain
            chain_E[0] = final_E_chain

            chain_X = utils.reverse_tensor(chain_X)
            chain_E = utils.reverse_tensor(chain_E)

            chain_X = torch.cat([chain_X, chain_X[-1:].repeat(10, 1)], dim=0)
            chain_E = torch.cat([chain_E, chain_E[-1:].repeat(10, 1, 1)], dim=0)
            mol_img_list = self.mol_visualizer.visualize_chain(chain_X.numpy(), chain_E.numpy())
        else:
            mol_img_list = []

        return smiles, mol_img_list

    def check_valid(self, smiles):
        return check_valid(smiles)

    def sample_p_zs_given_zt(
        self, s, t, X_t, E_t, properties, node_mask, guide_scale, device
    ):
        """Sample z_s ~ p(z_s | z_t). Only used during sampling.

        Returns both the one-hot sample and its collapsed (discrete) version.
        """
        bs, n, _ = X_t.shape
        beta_t = self.noise_schedule(t_normalized=t)
        alpha_s_bar = self.noise_schedule.get_alpha_bar(t_normalized=s)
        alpha_t_bar = self.noise_schedule.get_alpha_bar(t_normalized=t)

        noisy_data = {
            "X_t": X_t,
            "E_t": E_t,
            "y_t": properties,
            "t": t,
            "node_mask": node_mask,
        }

        def get_prob(noisy_data, unconditioned=False):
            pred = self._forward(noisy_data, unconditioned=unconditioned)

            pred_X = F.softmax(pred.X, dim=-1)
            pred_E = F.softmax(pred.E, dim=-1)

            Qtb = self.transition_model.get_Qt_bar(alpha_t_bar, device)
            Qsb = self.transition_model.get_Qt_bar(alpha_s_bar, device)
            Qt = self.transition_model.get_Qt(beta_t, device)

            # Treat node and edge features jointly, compute the reverse
            # diffusion probabilities weighted by the network prediction,
            # then split the result back into node and edge parts.
            Xt_all = torch.cat([X_t, E_t.reshape(bs, n, -1)], dim=-1)
            predX_all = torch.cat([pred_X, pred_E.reshape(bs, n, -1)], dim=-1)

            unnormalized_probX_all = utils.reverse_diffusion(
                predX_0=predX_all, X_t=Xt_all, Qt=Qt.X, Qsb=Qsb.X, Qtb=Qtb.X
            )

            unnormalized_prob_X = unnormalized_probX_all[:, :, : self.Xdim_output]
            unnormalized_prob_E = unnormalized_probX_all[
                :, :, self.Xdim_output :
            ].reshape(bs, n * n, -1)

            # Avoid all-zero rows before normalizing.
            unnormalized_prob_X[torch.sum(unnormalized_prob_X, dim=-1) == 0] = 1e-5
            unnormalized_prob_E[torch.sum(unnormalized_prob_E, dim=-1) == 0] = 1e-5

            prob_X = unnormalized_prob_X / torch.sum(
                unnormalized_prob_X, dim=-1, keepdim=True
            )
            prob_E = unnormalized_prob_E / torch.sum(
                unnormalized_prob_E, dim=-1, keepdim=True
            )
            prob_E = prob_E.reshape(bs, n, n, pred_E.shape[-1])

            return prob_X, prob_E

        prob_X, prob_E = get_prob(noisy_data)

        # Classifier-free guidance: sharpen the conditional distribution
        # against the unconditional one and renormalize.
        if guide_scale != 1:
            uncon_prob_X, uncon_prob_E = get_prob(noisy_data, unconditioned=True)
            prob_X = (
                uncon_prob_X
                * (prob_X / uncon_prob_X.clamp_min(1e-5)) ** guide_scale
            )
            prob_E = (
                uncon_prob_E
                * (prob_E / uncon_prob_E.clamp_min(1e-5)) ** guide_scale
            )
            prob_X = prob_X / prob_X.sum(dim=-1, keepdim=True).clamp_min(1e-5)
            prob_E = prob_E / prob_E.sum(dim=-1, keepdim=True).clamp_min(1e-5)

        sampled_s = utils.sample_discrete_features(
            prob_X, prob_E, node_mask=node_mask, step=s[0, 0].item()
        )

        X_s = F.one_hot(sampled_s.X, num_classes=self.Xdim_output).to(self.model_dtype)
        E_s = F.one_hot(sampled_s.E, num_classes=self.Edim_output).to(self.model_dtype)

        assert (E_s == torch.transpose(E_s, 1, 2)).all()
        assert (X_t.shape == X_s.shape) and (E_t.shape == E_s.shape)

        out_one_hot = utils.PlaceHolder(X=X_s, E=E_s, y=properties)
        out_discrete = utils.PlaceHolder(X=X_s, E=E_s, y=properties)

        return out_one_hot.mask(node_mask).type_as(properties), out_discrete.mask(
            node_mask, collapse=True
        ).type_as(properties)
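

# ---------------------------------------------------------------------------
# Illustrative usage sketch (not part of the original module). The config,
# data-info, and checkpoint paths are placeholders, and the number and meaning
# of the conditioning properties depend on how the model was trained (their
# count must match the model's ydim); `None` entries are treated as
# unconditioned. Because this module uses relative imports, the sketch assumes
# it is executed inside its package, e.g. `python -m <package>.<module>`.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    model = GraphDiT(
        model_config_path="configs/model.yaml",   # placeholder path
        data_info_path="configs/data_info.yaml",  # placeholder path
        model_dtype=torch.float32,
    )
    model.init_model("checkpoints")  # expects checkpoints/model.pt
    model.disable_grads()
    model.eval()
    if torch.cuda.is_available():
        model.cuda()  # generate() places its tensors on CUDA when available

    # One property vector per call (generate asserts batch_size == 1);
    # None marks a property left unconditioned. Values are placeholders.
    example_properties = [0.5, None]
    smiles, chain_images = model.generate(
        example_properties,
        guide_scale=2.0,
        num_nodes=None,          # sample the node count from the data prior
        number_chain_steps=50,   # set to 0 to skip chain visualization
    )
    print(smiles, model.check_valid(smiles))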