import json
import os

import torch
import torch.nn as nn
import torch.nn.functional as F
import yaml

from . import diffusion_utils as utils
from .transformer import Transformer
from .molecule_utils import graph_to_smiles, check_valid
from .visualize_utils import MolecularVisualization


class GraphDiT(nn.Module):
    """Discrete graph diffusion model built around a ``Transformer`` denoiser.

    The constructor reads a model config and a dataset description through
    ``utils.load_config`` and builds the denoiser, the discrete noise
    schedule, and the marginal transition model. ``generate`` runs the
    reverse diffusion chain for a single graph and converts the result to a
    SMILES string.
    """

    def __init__(
        self,
        model_config_path,
        data_info_path,
        model_dtype,
    ):
        """Build the denoiser, noise schedule and transition model.

        Args:
            model_config_path: Path handed to ``utils.load_config`` for the
                diffusion-model configuration.
            data_info_path: Path handed to ``utils.load_config`` for the
                dataset description.
            model_dtype: ``torch.dtype`` used for the marginals and for the
                tensors fed to the denoiser.
        """
        super().__init__()

        dm_cfg, data_info = utils.load_config(model_config_path, data_info_path)

        input_dims = data_info.input_dims
        output_dims = data_info.output_dims
        nodes_dist = data_info.nodes_dist
        active_index = data_info.active_index

        self.model_config = dm_cfg
        self.data_info = data_info
        self.T = dm_cfg.diffusion_steps
        self.Xdim = input_dims["X"]
        self.Edim = input_dims["E"]
        self.ydim = input_dims["y"]
        self.Xdim_output = output_dims["X"]
        self.Edim_output = output_dims["E"]
        self.ydim_output = output_dims["y"]
        self.node_dist = nodes_dist
        self.active_index = active_index
        self.max_n_nodes = data_info.max_n_nodes
        self.atom_decoder = data_info.atom_decoder
        self.hidden_size = dm_cfg.hidden_size
        self.mol_visualizer = MolecularVisualization(self.atom_decoder)

        self.denoiser = Transformer(
            max_n_nodes=self.max_n_nodes,
            hidden_size=dm_cfg.hidden_size,
            depth=dm_cfg.depth,
            num_heads=dm_cfg.num_heads,
            mlp_ratio=dm_cfg.mlp_ratio,
            drop_condition=dm_cfg.drop_condition,
            Xdim=self.Xdim,
            Edim=self.Edim,
            ydim=self.ydim,
        )

        self.model_dtype = model_dtype
        self.noise_schedule = utils.PredefinedNoiseScheduleDiscrete(
            dm_cfg.diffusion_noise_schedule, timesteps=dm_cfg.diffusion_steps
        )

        # Node / edge type statistics turned into probability vectors.
        node_types = data_info.node_types.to(self.model_dtype)
        edge_types = data_info.edge_types.to(self.model_dtype)
        x_marginals = node_types / torch.sum(node_types)
        e_marginals = edge_types / torch.sum(edge_types)
        # The original code normalises a second time; kept so the resulting
        # values stay bit-for-bit what downstream code has always received.
        x_marginals = x_marginals / x_marginals.sum()
        e_marginals = e_marginals / e_marginals.sum()

        # Restrict the transition statistics to the active node types on both
        # leading axes, then sum over dim 1 and row-normalise both directions.
        xe_conditions = data_info.transition_E.to(self.model_dtype)
        xe_conditions = xe_conditions[self.active_index][:, self.active_index]
        xe_conditions = xe_conditions.sum(dim=1)
        ex_conditions = xe_conditions.t()
        xe_conditions = xe_conditions / xe_conditions.sum(dim=-1, keepdim=True)
        ex_conditions = ex_conditions / ex_conditions.sum(dim=-1, keepdim=True)

        self.transition_model = utils.MarginalTransition(
            x_marginals=x_marginals,
            e_marginals=e_marginals,
            xe_conditions=xe_conditions,
            ex_conditions=ex_conditions,
            y_classes=self.ydim_output,
            n_nodes=self.max_n_nodes,
        )
        self.limit_dist = utils.PlaceHolder(X=x_marginals, E=e_marginals, y=None)

    def init_model(self, model_dir):
        """Load denoiser weights from ``<model_dir>/model.pt``.

        Raises:
            FileNotFoundError: If ``model.pt`` does not exist in ``model_dir``.
        """
        model_file = os.path.join(model_dir, 'model.pt')
        if os.path.exists(model_file):
            self.denoiser.load_state_dict(
                torch.load(model_file, map_location='cpu', weights_only=True)
            )
        else:
            raise FileNotFoundError(f"Model file not found: {model_file}")

    def disable_grads(self):
        """Delegate to the denoiser's own ``disable_grads``."""
        self.denoiser.disable_grads()

    def forward(
        self, x, edge_index, edge_attr, graph_batch, properties, no_label_index
    ):
        """Not supported; always raises ``ValueError``.

        Use ``generate`` for sampling instead.
        """
        raise ValueError('Not Implement')

    def _forward(self, noisy_data, unconditioned=False):
        """Run the denoiser on a noisy-graph dictionary.

        Args:
            noisy_data: Mapping with keys ``"X_t"``, ``"E_t"``, ``"y_t"``,
                ``"node_mask"`` and ``"t"``. The first three are cast to
                ``self.model_dtype`` before the call.
            unconditioned: Forwarded unchanged to the denoiser.

        Returns:
            Whatever the denoiser returns (callers read ``.X`` and ``.E``).
        """
        noisy_x = noisy_data["X_t"].to(self.model_dtype)
        noisy_e = noisy_data["E_t"].to(self.model_dtype)
        # Cloned so the denoiser cannot modify the caller's property tensor.
        properties = noisy_data["y_t"].to(self.model_dtype).clone()
        node_mask = noisy_data["node_mask"]
        timestep = noisy_data["t"]

        pred = self.denoiser(
            noisy_x,
            noisy_e,
            node_mask,
            properties,
            timestep,
            unconditioned=unconditioned,
        )
        return pred

    def apply_noise(self, X, E, y, node_mask):
        """Sample noise and apply it to the data.

        A timestep is drawn per batch element, the cumulative transition
        matrix for that timestep is applied to the concatenated node/edge
        features, and a discrete noisy graph is sampled from the result.

        Returns:
            dict with keys ``t_int``, ``t``, ``beta_t``, ``alpha_s_bar``,
            ``alpha_t_bar``, ``X_t``, ``E_t``, ``y_t`` and ``node_mask``.
        """
        # Sample a timestep t.
        # When evaluating, the loss for t=0 is computed separately
        lowest_t = 0 if self.training else 1
        t_int = torch.randint(
            lowest_t, self.T + 1, size=(X.size(0), 1), device=X.device
        ).to(
            self.model_dtype
        )  # (bs, 1)
        s_int = t_int - 1

        t_float = t_int / self.T
        s_float = s_int / self.T

        # beta_t and alpha_s_bar are used for denoising/loss computation
        beta_t = self.noise_schedule(t_normalized=t_float)  # (bs, 1)
        alpha_s_bar = self.noise_schedule.get_alpha_bar(t_normalized=s_float)  # (bs, 1)
        alpha_t_bar = self.noise_schedule.get_alpha_bar(t_normalized=t_float)  # (bs, 1)

        Qtb = self.transition_model.get_Qt_bar(
            alpha_t_bar, X.device
        )  # (bs, dx_in, dx_out), (bs, de_in, de_out)

        bs, n, _ = X.shape
        # Node and flattened edge features share one joint transition matrix.
        X_all = torch.cat([X, E.reshape(bs, n, -1)], dim=-1)
        prob_all = X_all @ Qtb.X
        probX = prob_all[:, :, : self.Xdim_output]
        probE = prob_all[:, :, self.Xdim_output :].reshape(bs, n, n, -1)

        sampled_t = utils.sample_discrete_features(
            probX=probX, probE=probE, node_mask=node_mask
        )

        X_t = F.one_hot(sampled_t.X, num_classes=self.Xdim_output)
        E_t = F.one_hot(sampled_t.E, num_classes=self.Edim_output)
        assert (X.shape == X_t.shape) and (E.shape == E_t.shape)

        y_t = y

        z_t = utils.PlaceHolder(X=X_t, E=E_t, y=y_t).type_as(X_t).mask(node_mask)

        noisy_data = {
            "t_int": t_int,
            "t": t_float,
            "beta_t": beta_t,
            "alpha_s_bar": alpha_s_bar,
            "alpha_t_bar": alpha_t_bar,
            "X_t": z_t.X,
            "E_t": z_t.E,
            "y_t": z_t.y,
            "node_mask": node_mask,
        }
        return noisy_data

    def _inference_device(self):
        """Return the device sampling tensors should be created on.

        Uses the device of the module's first parameter so inputs always land
        next to the weights. Falls back to "cuda if available, else cpu" when
        the module has no parameters.
        """
        first_param = next(self.parameters(), None)
        if first_param is not None:
            return first_param.device
        return torch.device("cuda" if torch.cuda.is_available() else "cpu")

    @torch.no_grad()
    def generate(
        self,
        properties,
        guide_scale=1.,
        num_nodes=None,
        number_chain_steps=50,
    ):
        """Sample one graph conditioned on ``properties``.

        Args:
            properties: Sequence of conditioning values; ``None`` entries are
                replaced by NaN before being turned into a ``(1, -1)`` tensor.
            guide_scale: Guidance exponent forwarded to
                ``sample_p_zs_given_zt``; a value of 1 skips the
                unconditioned pass.
            num_nodes: Number of nodes to generate. When ``None`` it is drawn
                from ``self.node_dist``.
            number_chain_steps: Number of intermediate frames recorded for
                visualisation; 0 disables chain recording.

        Returns:
            Tuple ``(smiles, mol_img_list)``: the first entry of
            ``graph_to_smiles`` for the sampled graph, and the output of
            ``mol_visualizer.visualize_chain`` (an empty list when chain
            recording is disabled).
        """
        # Follow the weights rather than guessing from CUDA availability, so a
        # CPU-resident model on a CUDA host (or a model on a non-default GPU)
        # does not receive inputs on a different device.
        device = self._inference_device()

        properties = [float('nan') if x is None else x for x in properties]
        properties = torch.tensor(properties, dtype=torch.float).reshape(1, -1).to(device)
        batch_size = properties.size(0)
        assert batch_size == 1

        if num_nodes is None:
            num_nodes = self.node_dist.sample_n(batch_size, device)
        else:
            num_nodes = torch.LongTensor([num_nodes]).to(device)

        arange = (
            torch.arange(self.max_n_nodes, device=device)
            .unsqueeze(0)
            .expand(batch_size, -1)
        )
        node_mask = arange < num_nodes.unsqueeze(1)

        z_T = utils.sample_discrete_feature_noise(
            limit_dist=self.limit_dist, node_mask=node_mask
        )
        X, E = z_T.X, z_T.E

        assert (E == torch.transpose(E, 1, 2)).all()

        if number_chain_steps > 0:
            chain_X_size = torch.Size((number_chain_steps, X.size(1)))
            chain_E_size = torch.Size((number_chain_steps, E.size(1), E.size(2)))
            chain_X = torch.zeros(chain_X_size)
            chain_E = torch.zeros(chain_E_size)

        # Iteratively sample p(z_s | z_t) for t = 1, ..., T, with s = t - 1.
        y = properties
        for s_int in reversed(range(0, self.T)):
            s_array = s_int * torch.ones((batch_size, 1)).type_as(y)
            t_array = s_array + 1
            s_norm = s_array / self.T
            t_norm = t_array / self.T

            # Sample z_s
            sampled_s, discrete_sampled_s = self.sample_p_zs_given_zt(
                s_norm, t_norm, X, E, y, node_mask, guide_scale, device
            )
            X, E, y = sampled_s.X, sampled_s.E, sampled_s.y

            if number_chain_steps > 0:
                # Several diffusion steps map onto the same frame; later
                # (smaller s) steps overwrite earlier ones.
                write_index = (s_int * number_chain_steps) // self.T
                chain_X[write_index] = discrete_sampled_s.X[:1]
                chain_E[write_index] = discrete_sampled_s.E[:1]

        # Collapse the final one-hot sample to class indices.
        sampled_s = sampled_s.mask(node_mask, collapse=True)
        X, E, y = sampled_s.X, sampled_s.E, sampled_s.y

        molecule_list = []
        n = num_nodes[0]
        atom_types = X[0, :n].cpu()
        edge_types = E[0, :n, :n].cpu()
        molecule_list.append([atom_types, edge_types])
        smiles = graph_to_smiles(molecule_list, self.atom_decoder)[0]

        # Visualize Chains
        if number_chain_steps > 0:
            final_X_chain = X[:1]
            final_E_chain = E[:1]

            chain_X[0] = final_X_chain  # Overwrite last frame with the resulting X, E
            chain_E[0] = final_E_chain

            chain_X = utils.reverse_tensor(chain_X)
            chain_E = utils.reverse_tensor(chain_E)

            # Repeat last frame to see final sample better
            chain_X = torch.cat([chain_X, chain_X[-1:].repeat(10, 1)], dim=0)
            chain_E = torch.cat([chain_E, chain_E[-1:].repeat(10, 1, 1)], dim=0)
            mol_img_list = self.mol_visualizer.visualize_chain(
                chain_X.numpy(), chain_E.numpy()
            )
        else:
            mol_img_list = []

        return smiles, mol_img_list

    def check_valid(self, smiles):
        """Return the result of the module-level ``check_valid`` for ``smiles``."""
        return check_valid(smiles)

    def sample_p_zs_given_zt(
        self, s, t, X_t, E_t, properties, node_mask, guide_scale, device
    ):
        """Sample z_s ~ p(z_s | z_t). Only used during sampling.

        Args:
            s: Normalised timestep of the state being sampled.
            t: Normalised timestep of the current state.
            X_t: Current node features (rank 3; the first two dimensions are
                read as ``bs`` and ``n``).
            E_t: Current edge features, reshaped to ``(bs, n, -1)`` here.
            properties: Conditioning tensor passed to the denoiser as ``y_t``.
            node_mask: Mask forwarded to the sampler and to ``PlaceHolder.mask``.
            guide_scale: When not equal to 1, an unconditioned prediction is
                also computed and the two are combined as
                ``uncond * (cond / uncond) ** guide_scale`` and renormalised.
            device: Device handed to the transition model.

        Returns:
            Tuple ``(one_hot, discrete)`` of masked ``PlaceHolder`` objects
            built from the same sample; the second is masked with
            ``collapse=True``.
        """
        bs, n, _ = X_t.shape
        beta_t = self.noise_schedule(t_normalized=t)  # (bs, 1)
        alpha_s_bar = self.noise_schedule.get_alpha_bar(t_normalized=s)
        alpha_t_bar = self.noise_schedule.get_alpha_bar(t_normalized=t)

        # Neural net predictions
        noisy_data = {
            "X_t": X_t,
            "E_t": E_t,
            "y_t": properties,
            "t": t,
            "node_mask": node_mask,
        }

        # The transition matrices and the concatenated current state do not
        # depend on whether the prediction is conditioned, so they are built
        # once here instead of on every get_prob call.
        # NOTE(review): assumes get_Qt / get_Qt_bar are pure functions of
        # their arguments - confirm against diffusion_utils.
        Qtb = self.transition_model.get_Qt_bar(alpha_t_bar, device)
        Qsb = self.transition_model.get_Qt_bar(alpha_s_bar, device)
        Qt = self.transition_model.get_Qt(beta_t, device)
        Xt_all = torch.cat([X_t, E_t.reshape(bs, n, -1)], dim=-1)

        def get_prob(noisy_data, unconditioned=False):
            pred = self._forward(noisy_data, unconditioned=unconditioned)

            # Normalize predictions
            pred_X = F.softmax(pred.X, dim=-1)  # bs, n, d0
            pred_E = F.softmax(pred.E, dim=-1)  # bs, n, n, d0

            predX_all = torch.cat([pred_X, pred_E.reshape(bs, n, -1)], dim=-1)

            unnormalized_probX_all = utils.reverse_diffusion(
                predX_0=predX_all, X_t=Xt_all, Qt=Qt.X, Qsb=Qsb.X, Qtb=Qtb.X
            )

            unnormalized_prob_X = unnormalized_probX_all[:, :, : self.Xdim_output]
            unnormalized_prob_E = unnormalized_probX_all[
                :, :, self.Xdim_output :
            ].reshape(bs, n * n, -1)

            # Rows that sum to zero would divide by zero below; give them a
            # small uniform mass instead.
            unnormalized_prob_X[torch.sum(unnormalized_prob_X, dim=-1) == 0] = 1e-5
            unnormalized_prob_E[torch.sum(unnormalized_prob_E, dim=-1) == 0] = 1e-5

            prob_X = unnormalized_prob_X / torch.sum(
                unnormalized_prob_X, dim=-1, keepdim=True
            )  # bs, n, d_t-1
            prob_E = unnormalized_prob_E / torch.sum(
                unnormalized_prob_E, dim=-1, keepdim=True
            )  # bs, n, d_t-1
            prob_E = prob_E.reshape(bs, n, n, pred_E.shape[-1])

            return prob_X, prob_E

        prob_X, prob_E = get_prob(noisy_data)

        ### Guidance
        if guide_scale != 1:
            uncon_prob_X, uncon_prob_E = get_prob(
                noisy_data, unconditioned=True
            )
            prob_X = (
                uncon_prob_X
                * (prob_X / uncon_prob_X.clamp_min(1e-5)) ** guide_scale
            )
            prob_E = (
                uncon_prob_E
                * (prob_E / uncon_prob_E.clamp_min(1e-5)) ** guide_scale
            )
            prob_X = prob_X / prob_X.sum(dim=-1, keepdim=True).clamp_min(1e-5)
            prob_E = prob_E / prob_E.sum(dim=-1, keepdim=True).clamp_min(1e-5)

        # assert ((prob_X.sum(dim=-1) - 1).abs() < 1e-3).all()
        # assert ((prob_E.sum(dim=-1) - 1).abs() < 1e-3).all()

        sampled_s = utils.sample_discrete_features(
            prob_X, prob_E, node_mask=node_mask, step=s[0, 0].item()
        )

        X_s = F.one_hot(sampled_s.X, num_classes=self.Xdim_output).to(self.model_dtype)
        E_s = F.one_hot(sampled_s.E, num_classes=self.Edim_output).to(self.model_dtype)

        assert (E_s == torch.transpose(E_s, 1, 2)).all()
        assert (X_t.shape == X_s.shape) and (E_t.shape == E_s.shape)

        # Two separate holders over the same sample: one stays one-hot, the
        # other is collapsed by mask(..., collapse=True).
        out_one_hot = utils.PlaceHolder(X=X_s, E=E_s, y=properties)
        out_discrete = utils.PlaceHolder(X=X_s, E=E_s, y=properties)

        return out_one_hot.mask(node_mask).type_as(properties), out_discrete.mask(
            node_mask, collapse=True
        ).type_as(properties)