import os
import json

import yaml
import torch
import torch.nn as nn
import torch.nn.functional as F

from . import diffusion_utils as utils
from .molecule_utils import graph_to_smiles, check_valid
from .transformer import Transformer
from .visualize_utils import MolecularVisualization


class GraphDiT(nn.Module):
    def __init__(
        self,
        model_config_path,
        data_info_path,
        model_dtype,
    ):
        super().__init__()
        dm_cfg, data_info = utils.load_config(model_config_path, data_info_path)

        input_dims = data_info.input_dims
        output_dims = data_info.output_dims
        nodes_dist = data_info.nodes_dist
        active_index = data_info.active_index

        self.model_config = dm_cfg
        self.data_info = data_info
        self.T = dm_cfg.diffusion_steps
        self.Xdim = input_dims["X"]
        self.Edim = input_dims["E"]
        self.ydim = input_dims["y"]
        self.Xdim_output = output_dims["X"]
        self.Edim_output = output_dims["E"]
        self.ydim_output = output_dims["y"]
        self.node_dist = nodes_dist
        self.active_index = active_index
        self.max_n_nodes = data_info.max_n_nodes
        self.atom_decoder = data_info.atom_decoder
        self.hidden_size = dm_cfg.hidden_size
        self.mol_visualizer = MolecularVisualization(self.atom_decoder)

        # Graph transformer denoiser that predicts clean node/edge types from a noisy graph.
        self.denoiser = Transformer(
            max_n_nodes=self.max_n_nodes,
            hidden_size=dm_cfg.hidden_size,
            depth=dm_cfg.depth,
            num_heads=dm_cfg.num_heads,
            mlp_ratio=dm_cfg.mlp_ratio,
            drop_condition=dm_cfg.drop_condition,
            Xdim=self.Xdim,
            Edim=self.Edim,
            ydim=self.ydim,
        )

        self.model_dtype = model_dtype

        self.noise_schedule = utils.PredefinedNoiseScheduleDiscrete(
            dm_cfg.diffusion_noise_schedule, timesteps=dm_cfg.diffusion_steps
        )

        # Marginal distributions of node and edge types, normalized to sum to 1.
        x_marginals = data_info.node_types.to(self.model_dtype) / torch.sum(
            data_info.node_types.to(self.model_dtype)
        )
        e_marginals = data_info.edge_types.to(self.model_dtype) / torch.sum(
            data_info.edge_types.to(self.model_dtype)
        )
        x_marginals = x_marginals / x_marginals.sum()
        e_marginals = e_marginals / e_marginals.sum()

        # Edge-transition statistics restricted to active node types, row-normalized.
        xe_conditions = data_info.transition_E.to(self.model_dtype)
        xe_conditions = xe_conditions[self.active_index][:, self.active_index]
        xe_conditions = xe_conditions.sum(dim=1)
        ex_conditions = xe_conditions.t()
        xe_conditions = xe_conditions / xe_conditions.sum(dim=-1, keepdim=True)
        ex_conditions = ex_conditions / ex_conditions.sum(dim=-1, keepdim=True)

        self.transition_model = utils.MarginalTransition(
            x_marginals=x_marginals,
            e_marginals=e_marginals,
            xe_conditions=xe_conditions,
            ex_conditions=ex_conditions,
            y_classes=self.ydim_output,
            n_nodes=self.max_n_nodes,
        )
        self.limit_dist = utils.PlaceHolder(X=x_marginals, E=e_marginals, y=None)

    def init_model(self, model_dir):
        model_file = os.path.join(model_dir, "model.pt")
        if os.path.exists(model_file):
            self.denoiser.load_state_dict(
                torch.load(model_file, map_location="cpu", weights_only=True)
            )
        else:
            raise FileNotFoundError(f"Model file not found: {model_file}")

    def disable_grads(self):
        self.denoiser.disable_grads()

    def forward(
        self, x, edge_index, edge_attr, graph_batch, properties, no_label_index
    ):
        raise NotImplementedError("GraphDiT.forward is not implemented; use generate() for sampling.")

    def _forward(self, noisy_data, unconditioned=False):
        noisy_x, noisy_e, properties = (
            noisy_data["X_t"].to(self.model_dtype),
            noisy_data["E_t"].to(self.model_dtype),
            noisy_data["y_t"].to(self.model_dtype).clone(),
        )
        node_mask, timestep = (
            noisy_data["node_mask"],
            noisy_data["t"],
        )

        pred = self.denoiser(
            noisy_x,
            noisy_e,
            node_mask,
            properties,
            timestep,
            unconditioned=unconditioned,
        )
        return pred

    def apply_noise(self, X, E, y, node_mask):
        """Sample noise and apply it to the data."""
        # Sample a timestep t.
        # When evaluating, the loss for t=0 is computed separately.
        lowest_t = 0 if self.training else 1
        t_int = torch.randint(
            lowest_t, self.T + 1, size=(X.size(0), 1), device=X.device
        ).to(
            self.model_dtype
        )  # (bs, 1)
        s_int = t_int - 1

        t_float = t_int / self.T
        s_float = s_int / self.T

        # beta_t and alpha_s_bar are used for denoising/loss computation
        beta_t = self.noise_schedule(t_normalized=t_float)  # (bs, 1)
        alpha_s_bar = self.noise_schedule.get_alpha_bar(t_normalized=s_float)  # (bs, 1)
        alpha_t_bar = self.noise_schedule.get_alpha_bar(t_normalized=t_float)  # (bs, 1)

        Qtb = self.transition_model.get_Qt_bar(
            alpha_t_bar, X.device
        )  # (bs, dx_in, dx_out), (bs, de_in, de_out)

        bs, n, d = X.shape
        X_all = torch.cat([X, E.reshape(bs, n, -1)], dim=-1)
        prob_all = X_all @ Qtb.X
        probX = prob_all[:, :, : self.Xdim_output]
        probE = prob_all[:, :, self.Xdim_output :].reshape(bs, n, n, -1)

        sampled_t = utils.sample_discrete_features(
            probX=probX, probE=probE, node_mask=node_mask
        )

        X_t = F.one_hot(sampled_t.X, num_classes=self.Xdim_output)
        E_t = F.one_hot(sampled_t.E, num_classes=self.Edim_output)
        assert (X.shape == X_t.shape) and (E.shape == E_t.shape)

        y_t = y
        z_t = utils.PlaceHolder(X=X_t, E=E_t, y=y_t).type_as(X_t).mask(node_mask)

        noisy_data = {
            "t_int": t_int,
            "t": t_float,
            "beta_t": beta_t,
            "alpha_s_bar": alpha_s_bar,
            "alpha_t_bar": alpha_t_bar,
            "X_t": z_t.X,
            "E_t": z_t.E,
            "y_t": z_t.y,
            "node_mask": node_mask,
        }
        return noisy_data

    @torch.no_grad()
    def generate(
        self,
        properties,
        guide_scale=1.0,
        num_nodes=None,
        number_chain_steps=50,
    ):
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        properties = [float("nan") if x is None else x for x in properties]
        properties = torch.tensor(properties, dtype=torch.float).reshape(1, -1).to(device)
        batch_size = properties.size(0)
        assert batch_size == 1

        if num_nodes is None:
            num_nodes = self.node_dist.sample_n(batch_size, device)
        else:
            num_nodes = torch.LongTensor([num_nodes]).to(device)

        arange = (
            torch.arange(self.max_n_nodes, device=device)
            .unsqueeze(0)
            .expand(batch_size, -1)
        )
        node_mask = arange < num_nodes.unsqueeze(1)

        z_T = utils.sample_discrete_feature_noise(
            limit_dist=self.limit_dist, node_mask=node_mask
        )
        X, E = z_T.X, z_T.E
        assert (E == torch.transpose(E, 1, 2)).all()

        if number_chain_steps > 0:
            chain_X_size = torch.Size((number_chain_steps, X.size(1)))
            chain_E_size = torch.Size((number_chain_steps, E.size(1), E.size(2)))
            chain_X = torch.zeros(chain_X_size)
            chain_E = torch.zeros(chain_E_size)

        # Iteratively sample p(z_s | z_t) for t = 1, ..., T, with s = t - 1.
        y = properties
        for s_int in reversed(range(0, self.T)):
            s_array = s_int * torch.ones((batch_size, 1)).type_as(y)
            t_array = s_array + 1
            s_norm = s_array / self.T
            t_norm = t_array / self.T

            # Sample z_s
            sampled_s, discrete_sampled_s = self.sample_p_zs_given_zt(
                s_norm, t_norm, X, E, y, node_mask, guide_scale, device
            )
            X, E, y = sampled_s.X, sampled_s.E, sampled_s.y

            if number_chain_steps > 0:
                # Record the current frame of the sampling chain (first graph in the batch)
                write_index = (s_int * number_chain_steps) // self.T
                chain_X[write_index] = discrete_sampled_s.X[:1]
                chain_E[write_index] = discrete_sampled_s.E[:1]

        # Collapse the one-hot sample to discrete node/edge types
        sampled_s = sampled_s.mask(node_mask, collapse=True)
        X, E, y = sampled_s.X, sampled_s.E, sampled_s.y

        molecule_list = []
        n = num_nodes[0]
        atom_types = X[0, :n].cpu()
        edge_types = E[0, :n, :n].cpu()
        molecule_list.append([atom_types, edge_types])
        smiles = graph_to_smiles(molecule_list, self.atom_decoder)[0]

        # Visualize the sampling chain
        if number_chain_steps > 0:
            final_X_chain = X[:1]
            final_E_chain = E[:1]

            chain_X[0] = final_X_chain  # Overwrite last frame with the resulting X, E
            chain_E[0] = final_E_chain

            chain_X = utils.reverse_tensor(chain_X)
            chain_E = utils.reverse_tensor(chain_E)

            # Repeat last frame to see final sample better
            chain_X = torch.cat([chain_X, chain_X[-1:].repeat(10, 1)], dim=0)
            chain_E = torch.cat([chain_E, chain_E[-1:].repeat(10, 1, 1)], dim=0)
            mol_img_list = self.mol_visualizer.visualize_chain(chain_X.numpy(), chain_E.numpy())
        else:
            mol_img_list = []

        return smiles, mol_img_list

    def check_valid(self, smiles):
        return check_valid(smiles)

    def sample_p_zs_given_zt(
        self, s, t, X_t, E_t, properties, node_mask, guide_scale, device
    ):
        """Sample z_s from p(z_s | z_t). Only used during sampling."""
        bs, n, _ = X_t.shape
        beta_t = self.noise_schedule(t_normalized=t)  # (bs, 1)
        alpha_s_bar = self.noise_schedule.get_alpha_bar(t_normalized=s)
        alpha_t_bar = self.noise_schedule.get_alpha_bar(t_normalized=t)

        # Neural net predictions
        noisy_data = {
            "X_t": X_t,
            "E_t": E_t,
            "y_t": properties,
            "t": t,
            "node_mask": node_mask,
        }

        def get_prob(noisy_data, unconditioned=False):
            pred = self._forward(noisy_data, unconditioned=unconditioned)

            # Normalize predictions
            pred_X = F.softmax(pred.X, dim=-1)  # bs, n, d0
            pred_E = F.softmax(pred.E, dim=-1)  # bs, n, n, d0

            # Retrieve transition matrices
            Qtb = self.transition_model.get_Qt_bar(alpha_t_bar, device)
            Qsb = self.transition_model.get_Qt_bar(alpha_s_bar, device)
            Qt = self.transition_model.get_Qt(beta_t, device)

            Xt_all = torch.cat([X_t, E_t.reshape(bs, n, -1)], dim=-1)
            predX_all = torch.cat([pred_X, pred_E.reshape(bs, n, -1)], dim=-1)

            unnormalized_probX_all = utils.reverse_diffusion(
                predX_0=predX_all, X_t=Xt_all, Qt=Qt.X, Qsb=Qsb.X, Qtb=Qtb.X
            )

            unnormalized_prob_X = unnormalized_probX_all[:, :, : self.Xdim_output]
            unnormalized_prob_E = unnormalized_probX_all[
                :, :, self.Xdim_output :
            ].reshape(bs, n * n, -1)

            unnormalized_prob_X[torch.sum(unnormalized_prob_X, dim=-1) == 0] = 1e-5
            unnormalized_prob_E[torch.sum(unnormalized_prob_E, dim=-1) == 0] = 1e-5

            prob_X = unnormalized_prob_X / torch.sum(
                unnormalized_prob_X, dim=-1, keepdim=True
            )  # bs, n, d_t-1
            prob_E = unnormalized_prob_E / torch.sum(
                unnormalized_prob_E, dim=-1, keepdim=True
            )  # bs, n, d_t-1
            prob_E = prob_E.reshape(bs, n, n, pred_E.shape[-1])

            return prob_X, prob_E

        prob_X, prob_E = get_prob(noisy_data)

        # Classifier-free guidance: reweight conditional probabilities against
        # the unconditional prediction, then renormalize.
        if guide_scale != 1:
            uncon_prob_X, uncon_prob_E = get_prob(noisy_data, unconditioned=True)
            prob_X = (
                uncon_prob_X
                * (prob_X / uncon_prob_X.clamp_min(1e-5)) ** guide_scale
            )
            prob_E = (
                uncon_prob_E
                * (prob_E / uncon_prob_E.clamp_min(1e-5)) ** guide_scale
            )
            prob_X = prob_X / prob_X.sum(dim=-1, keepdim=True).clamp_min(1e-5)
            prob_E = prob_E / prob_E.sum(dim=-1, keepdim=True).clamp_min(1e-5)

        # assert ((prob_X.sum(dim=-1) - 1).abs() < 1e-3).all()
        # assert ((prob_E.sum(dim=-1) - 1).abs() < 1e-3).all()

        sampled_s = utils.sample_discrete_features(
            prob_X, prob_E, node_mask=node_mask, step=s[0, 0].item()
        )

        X_s = F.one_hot(sampled_s.X, num_classes=self.Xdim_output).to(self.model_dtype)
        E_s = F.one_hot(sampled_s.E, num_classes=self.Edim_output).to(self.model_dtype)

        assert (E_s == torch.transpose(E_s, 1, 2)).all()
        assert (X_t.shape == X_s.shape) and (E_t.shape == E_s.shape)

        out_one_hot = utils.PlaceHolder(X=X_s, E=E_s, y=properties)
        out_discrete = utils.PlaceHolder(X=X_s, E=E_s, y=properties)

        return out_one_hot.mask(node_mask).type_as(properties), out_discrete.mask(
            node_mask, collapse=True
        ).type_as(properties)
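

# Usage sketch (illustrative only): a minimal example of driving this class end
# to end, assuming a config/data-info pair and a checkpoint directory that
# contains `model.pt`, as `init_model` expects. The file names, property values,
# and node count below are hypothetical placeholders. Because of the relative
# imports, run it as a module (`python -m <package>.<this_module>`), not as a
# standalone script.
if __name__ == "__main__":
    model = GraphDiT(
        model_config_path="configs/model_config.yaml",  # hypothetical path
        data_info_path="configs/data_info.json",        # hypothetical path
        model_dtype=torch.float32,
    )
    model.init_model("checkpoints")  # directory expected to contain model.pt
    model.disable_grads()
    model.eval()

    # One property vector per call (generate asserts batch size 1); None entries
    # are converted to NaN before being passed to the denoiser.
    smiles, chain_images = model.generate(
        properties=[0.5, None],
        guide_scale=2.0,
        num_nodes=20,
        number_chain_steps=50,
    )
    print(smiles, model.check_valid(smiles))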