# liuganghuggingface's picture
# Update graph_decoder/diffusion_model.py
# 7c67898 verified
# raw
# history blame
# 13.9 kB
import os
import yaml
import json
import torch
import torch.nn as nn
import torch.nn.functional as F
from . import diffusion_utils as utils
from .transformer import Transformer
from .molecule_utils import graph_to_smiles, check_valid
from .visualize_utils import MolecularVisualization
# class GraphDiT(nn.Module):
# def __init__(
# self,
# model_config_path,
# data_info_path,
# model_dtype,
# ):
# super().__init__()
# def init_model(self, model_dir):
# pass
# def disable_grads(self):
# pass
# def generate(self, properties, guide_scale, num_nodes, number_chain_steps):
# return 0, 0
class GraphDiT(nn.Module):
    """Graph diffusion model: a ``Transformer`` denoiser plus discrete diffusion.

    The constructor loads a model config and dataset info through
    ``utils.load_config``. It then builds the denoiser, a predefined discrete
    noise schedule, and a ``MarginalTransition`` model from dataset marginals.
    Sampling is done with :meth:`generate`. :meth:`forward` is intentionally
    not implemented.
    """

    def __init__(
        self,
        model_config_path,
        data_info_path,
        model_dtype,
    ):
        """Build the denoiser, noise schedule and transition model.

        Args:
            model_config_path: Model configuration path, forwarded to
                ``utils.load_config``.
            data_info_path: Dataset-info path, forwarded to
                ``utils.load_config``.
            model_dtype: torch dtype used for the marginals, the transition
                conditions and the denoiser inputs.
        """
        super().__init__()
        dm_cfg, data_info = utils.load_config(model_config_path, data_info_path)

        input_dims = data_info.input_dims
        output_dims = data_info.output_dims
        nodes_dist = data_info.nodes_dist
        active_index = data_info.active_index

        self.model_config = dm_cfg
        self.data_info = data_info
        self.T = dm_cfg.diffusion_steps
        self.Xdim = input_dims["X"]
        self.Edim = input_dims["E"]
        self.ydim = input_dims["y"]
        self.Xdim_output = output_dims["X"]
        self.Edim_output = output_dims["E"]
        self.ydim_output = output_dims["y"]
        self.node_dist = nodes_dist
        self.active_index = active_index
        self.max_n_nodes = data_info.max_n_nodes
        self.atom_decoder = data_info.atom_decoder
        self.hidden_size = dm_cfg.hidden_size
        self.mol_visualizer = MolecularVisualization(self.atom_decoder)

        self.denoiser = Transformer(
            max_n_nodes=self.max_n_nodes,
            hidden_size=dm_cfg.hidden_size,
            depth=dm_cfg.depth,
            num_heads=dm_cfg.num_heads,
            mlp_ratio=dm_cfg.mlp_ratio,
            drop_condition=dm_cfg.drop_condition,
            Xdim=self.Xdim,
            Edim=self.Edim,
            ydim=self.ydim,
        )
        self.model_dtype = model_dtype

        self.noise_schedule = utils.PredefinedNoiseScheduleDiscrete(
            dm_cfg.diffusion_noise_schedule, timesteps=dm_cfg.diffusion_steps
        )

        # Empirical node / edge type counts turned into distributions.
        x_marginals = data_info.node_types.to(self.model_dtype) / torch.sum(
            data_info.node_types.to(self.model_dtype)
        )
        e_marginals = data_info.edge_types.to(self.model_dtype) / torch.sum(
            data_info.edge_types.to(self.model_dtype)
        )
        # Re-normalization kept as-is (mathematically a no-op up to rounding).
        x_marginals = x_marginals / x_marginals.sum()
        e_marginals = e_marginals / e_marginals.sum()

        # Restrict the first two axes of transition_E to the active node types,
        # sum over axis 1, and row-normalize it and its transpose.
        # NOTE(review): presumably transition_E holds (node, node, edge)
        # co-occurrence counts, making the rows of xe_conditions
        # p(edge | node) and the rows of ex_conditions p(node | edge) - confirm.
        xe_conditions = data_info.transition_E.to(self.model_dtype)
        xe_conditions = xe_conditions[self.active_index][:, self.active_index]
        xe_conditions = xe_conditions.sum(dim=1)
        ex_conditions = xe_conditions.t()
        xe_conditions = xe_conditions / xe_conditions.sum(dim=-1, keepdim=True)
        ex_conditions = ex_conditions / ex_conditions.sum(dim=-1, keepdim=True)

        self.transition_model = utils.MarginalTransition(
            x_marginals=x_marginals,
            e_marginals=e_marginals,
            xe_conditions=xe_conditions,
            ex_conditions=ex_conditions,
            y_classes=self.ydim_output,
            n_nodes=self.max_n_nodes,
        )
        # Sampling starts from these marginals (see generate).
        self.limit_dist = utils.PlaceHolder(X=x_marginals, E=e_marginals, y=None)

    def init_model(self, model_dir):
        """Load denoiser weights from ``<model_dir>/model.pt``.

        The state dict is loaded onto the CPU with ``weights_only=True``.

        Raises:
            FileNotFoundError: if ``model.pt`` does not exist in ``model_dir``.
        """
        model_file = os.path.join(model_dir, 'model.pt')
        if not os.path.exists(model_file):
            raise FileNotFoundError(f"Model file not found: {model_file}")
        state_dict = torch.load(model_file, map_location='cpu', weights_only=True)
        self.denoiser.load_state_dict(state_dict)

    def disable_grads(self):
        """Delegate to ``self.denoiser.disable_grads()``."""
        self.denoiser.disable_grads()

    def forward(
        self, x, edge_index, edge_attr, graph_batch, properties, no_label_index
    ):
        """Not implemented for this class; always raises ``ValueError``."""
        raise ValueError('Not Implement')

    def _forward(self, noisy_data, unconditioned=False):
        """Run the denoiser on a noisy-data dict.

        Args:
            noisy_data: dict with keys ``X_t``, ``E_t``, ``y_t``, ``node_mask``
                and ``t``. ``X_t``, ``E_t`` and ``y_t`` are cast to
                ``model_dtype``; ``y_t`` is also cloned before use.
            unconditioned: forwarded to the denoiser (used for guidance).

        Returns:
            Whatever the denoiser returns. Callers read ``.X`` and ``.E`` on it.
        """
        noisy_x = noisy_data["X_t"].to(self.model_dtype)
        noisy_e = noisy_data["E_t"].to(self.model_dtype)
        properties = noisy_data["y_t"].to(self.model_dtype).clone()
        node_mask = noisy_data["node_mask"]
        timestep = noisy_data["t"]
        return self.denoiser(
            noisy_x,
            noisy_e,
            node_mask,
            properties,
            timestep,
            unconditioned=unconditioned,
        )

    def apply_noise(self, X, E, y, node_mask):
        """Sample a timestep per graph and corrupt ``X`` / ``E`` to that step.

        Args:
            X: Node features of shape (bs, n, d).
            E: Edge features; reshaped to (bs, n, -1) below, so presumably
                (bs, n, n, de) - confirm with callers.
            y: Graph-level properties, passed through unchanged as ``y_t``.
            node_mask: Mask of valid nodes.

        Returns:
            dict with the timestep (``t_int``, normalized ``t``), schedule
            values (``beta_t``, ``alpha_s_bar``, ``alpha_t_bar``), the noisy
            one-hot ``X_t`` / ``E_t``, ``y_t`` and ``node_mask``.
        """
        # Sample a timestep t. When evaluating, the loss for t=0 is computed
        # separately, so t starts at 1 outside training.
        lowest_t = 0 if self.training else 1
        t_int = torch.randint(
            lowest_t, self.T + 1, size=(X.size(0), 1), device=X.device
        ).to(self.model_dtype)  # (bs, 1)
        s_int = t_int - 1
        t_float = t_int / self.T
        s_float = s_int / self.T

        # beta_t and alpha_s_bar are used for denoising/loss computation.
        beta_t = self.noise_schedule(t_normalized=t_float)  # (bs, 1)
        alpha_s_bar = self.noise_schedule.get_alpha_bar(t_normalized=s_float)  # (bs, 1)
        alpha_t_bar = self.noise_schedule.get_alpha_bar(t_normalized=t_float)  # (bs, 1)

        Qtb = self.transition_model.get_Qt_bar(alpha_t_bar, X.device)

        # One transition matrix acts on node features concatenated with each
        # node's flattened edge row; split the result back into X and E parts.
        bs, n, d = X.shape
        X_all = torch.cat([X, E.reshape(bs, n, -1)], dim=-1)
        prob_all = X_all @ Qtb.X
        probX = prob_all[:, :, : self.Xdim_output]
        probE = prob_all[:, :, self.Xdim_output :].reshape(bs, n, n, -1)

        sampled_t = utils.sample_discrete_features(
            probX=probX, probE=probE, node_mask=node_mask
        )
        X_t = F.one_hot(sampled_t.X, num_classes=self.Xdim_output)
        E_t = F.one_hot(sampled_t.E, num_classes=self.Edim_output)
        assert (X.shape == X_t.shape) and (E.shape == E_t.shape)

        y_t = y
        z_t = utils.PlaceHolder(X=X_t, E=E_t, y=y_t).type_as(X_t).mask(node_mask)

        return {
            "t_int": t_int,
            "t": t_float,
            "beta_t": beta_t,
            "alpha_s_bar": alpha_s_bar,
            "alpha_t_bar": alpha_t_bar,
            "X_t": z_t.X,
            "E_t": z_t.E,
            "y_t": z_t.y,
            "node_mask": node_mask,
        }

    def _model_device(self):
        """Return the device that holds the denoiser's weights.

        Falls back to CUDA-if-available when the denoiser exposes no
        parameters.
        """
        first_param = next(self.denoiser.parameters(), None)
        if first_param is not None:
            return first_param.device
        return torch.device("cuda" if torch.cuda.is_available() else "cpu")

    @torch.no_grad()
    def generate(
        self,
        properties,
        guide_scale=1.,
        num_nodes=None,
        number_chain_steps=50,
    ):
        """Sample a single molecule conditioned on ``properties``.

        Args:
            properties: Sequence of target property values. ``None`` entries
                are replaced with NaN; the values form a (1, -1) tensor.
            guide_scale: Guidance strength. Any value other than 1 enables the
                conditional/unconditional mix in ``sample_p_zs_given_zt``.
            num_nodes: Number of atoms to generate. When ``None`` it is drawn
                from ``self.node_dist``.
            number_chain_steps: Number of intermediate frames recorded for the
                chain visualization. ``0`` or less disables recording.

        Returns:
            ``(smiles, mol_img_list)``: ``smiles`` is the first entry from
            ``graph_to_smiles``. ``mol_img_list`` is the output of
            ``visualize_chain``, or ``[]`` when chain recording is disabled.
        """
        # Use the device holding the denoiser weights rather than guessing from
        # CUDA availability; otherwise a CPU-resident (or cuda:1) model gets
        # its inputs on the wrong device.
        device = self._model_device()

        properties = [float('nan') if x is None else x for x in properties]
        properties = torch.tensor(properties, dtype=torch.float).reshape(1, -1).to(device)
        batch_size = properties.size(0)
        assert batch_size == 1

        if num_nodes is None:
            num_nodes = self.node_dist.sample_n(batch_size, device)
        else:
            num_nodes = torch.LongTensor([num_nodes]).to(device)

        # node_mask[b, i] is True for the first num_nodes[b] node slots.
        arange = (
            torch.arange(self.max_n_nodes, device=device)
            .unsqueeze(0)
            .expand(batch_size, -1)
        )
        node_mask = arange < num_nodes.unsqueeze(1)

        # Initial sample drawn from the limit distribution (the marginals).
        z_T = utils.sample_discrete_feature_noise(
            limit_dist=self.limit_dist, node_mask=node_mask
        )
        X, E = z_T.X, z_T.E
        assert (E == torch.transpose(E, 1, 2)).all()

        if number_chain_steps > 0:
            chain_X_size = torch.Size((number_chain_steps, X.size(1)))
            chain_E_size = torch.Size((number_chain_steps, E.size(1), E.size(2)))
            chain_X = torch.zeros(chain_X_size)
            chain_E = torch.zeros(chain_E_size)

        # Iteratively sample p(z_s | z_t) for t = T, ..., 1, with s = t - 1.
        y = properties
        for s_int in reversed(range(0, self.T)):
            s_array = s_int * torch.ones((batch_size, 1)).type_as(y)
            t_array = s_array + 1
            s_norm = s_array / self.T
            t_norm = t_array / self.T

            sampled_s, discrete_sampled_s = self.sample_p_zs_given_zt(
                s_norm, t_norm, X, E, y, node_mask, guide_scale, device
            )
            X, E, y = sampled_s.X, sampled_s.E, sampled_s.y

            if number_chain_steps > 0:
                # Record the first graph of the batch into one of
                # number_chain_steps slots spread evenly over the T steps
                # (slot 0 corresponds to s_int == 0).
                write_index = (s_int * number_chain_steps) // self.T
                chain_X[write_index] = discrete_sampled_s.X[:1]
                chain_E[write_index] = discrete_sampled_s.E[:1]

        # Mask the final sample with collapse=True (the same form as the
        # recorded chain frames).
        sampled_s = sampled_s.mask(node_mask, collapse=True)
        X, E, y = sampled_s.X, sampled_s.E, sampled_s.y

        n = num_nodes[0]
        atom_types = X[0, :n].cpu()
        edge_types = E[0, :n, :n].cpu()
        molecule_list = [[atom_types, edge_types]]
        smiles = graph_to_smiles(molecule_list, self.atom_decoder)[0]

        if number_chain_steps > 0:
            # Overwrite slot 0 (the last denoising step) with the final graph,
            # then reverse the chain order (presumably along dim 0).
            final_X_chain = X[:1]
            final_E_chain = E[:1]
            chain_X[0] = final_X_chain
            chain_E[0] = final_E_chain
            chain_X = utils.reverse_tensor(chain_X)
            chain_E = utils.reverse_tensor(chain_E)

            # Repeat the last frame so the final sample is easier to see.
            chain_X = torch.cat([chain_X, chain_X[-1:].repeat(10, 1)], dim=0)
            chain_E = torch.cat([chain_E, chain_E[-1:].repeat(10, 1, 1)], dim=0)
            mol_img_list = self.mol_visualizer.visualize_chain(chain_X.numpy(), chain_E.numpy())
        else:
            mol_img_list = []

        return smiles, mol_img_list

    def check_valid(self, smiles):
        """Delegate to ``molecule_utils.check_valid``."""
        return check_valid(smiles)

    def sample_p_zs_given_zt(
        self, s, t, X_t, E_t, properties, node_mask, guide_scale, device
    ):
        """Sample z_s ~ p(z_s | z_t) for one reverse-diffusion step.

        Only used during sampling.

        Args:
            s, t: Normalized timesteps; ``generate`` passes (bs, 1) tensors
                with ``t = s + 1/T``.
            X_t: Current node features, unpacked as (bs, n, _).
            E_t: Current edge features; reshaped to (bs, n, -1), so presumably
                (bs, n, n, de).
            properties: Conditioning values passed to the denoiser as ``y_t``.
            node_mask: Mask of valid nodes.
            guide_scale: Guidance strength. 1 skips the unconditional pass.
            device: Device for the transition matrices.

        Returns:
            ``(out_one_hot, out_discrete)``: the sampled one-hot graph masked
            by ``node_mask``, and the same sample masked with
            ``collapse=True``. Both are cast with ``type_as(properties)`` and
            carry ``properties`` as ``y``.
        """
        bs, n, _ = X_t.shape
        beta_t = self.noise_schedule(t_normalized=t)  # (bs, 1)
        alpha_s_bar = self.noise_schedule.get_alpha_bar(t_normalized=s)
        alpha_t_bar = self.noise_schedule.get_alpha_bar(t_normalized=t)

        # The transition matrices depend only on the timestep, not on whether
        # the denoiser is conditioned. Build them once and share them between
        # the conditional and unconditional passes.
        Qtb = self.transition_model.get_Qt_bar(alpha_t_bar, device)
        Qsb = self.transition_model.get_Qt_bar(alpha_s_bar, device)
        Qt = self.transition_model.get_Qt(beta_t, device)

        # Neural net inputs.
        noisy_data = {
            "X_t": X_t,
            "E_t": E_t,
            "y_t": properties,
            "t": t,
            "node_mask": node_mask,
        }

        def get_prob(noisy_data, unconditioned=False):
            """Return normalized (prob_X, prob_E) over z_s from one denoiser pass."""
            pred = self._forward(noisy_data, unconditioned=unconditioned)

            # Normalize predictions.
            pred_X = F.softmax(pred.X, dim=-1)  # bs, n, d0
            pred_E = F.softmax(pred.E, dim=-1)  # bs, n, n, d0

            # Nodes and flattened edge rows are processed jointly, as in
            # apply_noise.
            Xt_all = torch.cat([X_t, E_t.reshape(bs, n, -1)], dim=-1)
            predX_all = torch.cat([pred_X, pred_E.reshape(bs, n, -1)], dim=-1)
            unnormalized_probX_all = utils.reverse_diffusion(
                predX_0=predX_all, X_t=Xt_all, Qt=Qt.X, Qsb=Qsb.X, Qtb=Qtb.X
            )

            unnormalized_prob_X = unnormalized_probX_all[:, :, : self.Xdim_output]
            unnormalized_prob_E = unnormalized_probX_all[
                :, :, self.Xdim_output :
            ].reshape(bs, n * n, -1)

            # Rows that sum to zero would divide by zero below; give them a
            # small uniform mass instead.
            unnormalized_prob_X[torch.sum(unnormalized_prob_X, dim=-1) == 0] = 1e-5
            unnormalized_prob_E[torch.sum(unnormalized_prob_E, dim=-1) == 0] = 1e-5

            prob_X = unnormalized_prob_X / torch.sum(
                unnormalized_prob_X, dim=-1, keepdim=True
            )  # bs, n, d_t-1
            prob_E = unnormalized_prob_E / torch.sum(
                unnormalized_prob_E, dim=-1, keepdim=True
            )  # bs, n * n, d_t-1
            prob_E = prob_E.reshape(bs, n, n, pred_E.shape[-1])
            return prob_X, prob_E

        prob_X, prob_E = get_prob(noisy_data)

        # Guidance: p = p_uncond * (p_cond / p_uncond) ** guide_scale, then
        # renormalize.
        if guide_scale != 1:
            uncon_prob_X, uncon_prob_E = get_prob(
                noisy_data, unconditioned=True
            )
            prob_X = (
                uncon_prob_X
                * (prob_X / uncon_prob_X.clamp_min(1e-5)) ** guide_scale
            )
            prob_E = (
                uncon_prob_E
                * (prob_E / uncon_prob_E.clamp_min(1e-5)) ** guide_scale
            )
            prob_X = prob_X / prob_X.sum(dim=-1, keepdim=True).clamp_min(1e-5)
            prob_E = prob_E / prob_E.sum(dim=-1, keepdim=True).clamp_min(1e-5)

        sampled_s = utils.sample_discrete_features(
            prob_X, prob_E, node_mask=node_mask, step=s[0, 0].item()
        )

        X_s = F.one_hot(sampled_s.X, num_classes=self.Xdim_output).to(self.model_dtype)
        E_s = F.one_hot(sampled_s.E, num_classes=self.Edim_output).to(self.model_dtype)

        assert (E_s == torch.transpose(E_s, 1, 2)).all()
        assert (X_t.shape == X_s.shape) and (E_t.shape == E_s.shape)

        out_one_hot = utils.PlaceHolder(X=X_s, E=E_s, y=properties)
        out_discrete = utils.PlaceHolder(X=X_s, E=E_s, y=properties)

        return out_one_hot.mask(node_mask).type_as(properties), out_discrete.mask(
            node_mask, collapse=True
        ).type_as(properties)