Spaces:
Sleeping
Sleeping
import argparse | |
import gradio as gr | |
import numpy as np | |
import os | |
import torch | |
import subprocess | |
import output | |
from rdkit import Chem | |
from src import const | |
from src.datasets import get_dataloader, collate_with_fragment_edges, parse_molecule | |
from src.lightning import DDPM | |
from src.linker_size_lightning import SizeClassifier | |
from src.generation import N_SAMPLES, generate_linkers, try_to_convert_to_sdf | |
# Diffusion checkpoints: model name -> {'link': Zenodo download URL,
# 'path': local cache location}. Insertion order is preserved, so models are
# downloaded/loaded in the order listed here.
MODELS_METADATA = {
    name: {'link': link, 'path': path}
    for name, link, path in (
        (
            'geom_difflinker',
            'https://zenodo.org/record/7121300/files/geom_difflinker.ckpt?download=1',
            'models/geom_difflinker.ckpt',
        ),
        (
            'geom_difflinker_given_anchors',
            'https://zenodo.org/record/7775568/files/geom_difflinker_given_anchors.ckpt?download=1',
            'models/geom_difflinker_given_anchors.ckpt',
        ),
        (
            'pockets_difflinker',
            'https://zenodo.org/record/7775568/files/pockets_difflinker_full_no_anchors.ckpt?download=1',
            'models/pockets_difflinker.ckpt',
        ),
        (
            'pockets_difflinker_given_anchors',
            'https://zenodo.org/record/7775568/files/pockets_difflinker_full.ckpt?download=1',
            'models/pockets_difflinker_given_anchors.ckpt',
        ),
    )
}
# Command-line options. --ip is forwarded to demo.launch(server_name=...) at the
# end of this script; when omitted it stays None and Gradio's own default applies.
parser = argparse.ArgumentParser()
parser.add_argument(
    '--ip',
    default=None,
    type=str,
)
args = parser.parse_args()
# Run on the GPU whenever PyTorch reports one is available.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f'Device: {device}')

# Working directories used later in this script: results/ for SDF files,
# models/ for downloaded checkpoints.
os.makedirs('results', exist_ok=True)
os.makedirs('models', exist_ok=True)
def _download_checkpoint(link, path):
    """Download ``link`` to ``path`` with wget, leaving no partial file behind.

    wget writes to a temporary ``<path>.part`` file. That file is renamed onto
    ``path`` only after wget exits successfully. Callers skip the download
    whenever ``path`` already exists, so an interrupted or failed download must
    never be left at ``path``.

    Raises:
        subprocess.CalledProcessError: if wget exits with a non-zero status.
        FileNotFoundError: if the ``wget`` executable cannot be found.
    """
    tmp_path = f'{path}.part'
    try:
        # Argument list (no shell): the URL and path are never shell-interpreted.
        subprocess.run(['wget', link, '-O', tmp_path], check=True)
        os.replace(tmp_path, path)
    finally:
        # Only present here if wget or the rename failed.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)


# Linker-size predictor, used by generate() when the requested linker size is 0.
size_gnn_path = 'models/geom_size_gnn.ckpt'
if not os.path.exists(size_gnn_path):
    print('Downloading SizeGNN model...')
    _download_checkpoint(
        'https://zenodo.org/record/7121300/files/geom_size_gnn.ckpt?download=1',
        size_gnn_path,
    )
size_nn = SizeClassifier.load_from_checkpoint(size_gnn_path, map_location=device).eval().to(device)
print('Loaded SizeGNN model')

# Diffusion models keyed by MODELS_METADATA name.
# NOTE(review): generate() only ever selects the geom_* models; the pockets_*
# checkpoints are downloaded and kept in memory without being used. Consider
# loading them lazily.
diffusion_models = {}
for model_name, metadata in MODELS_METADATA.items():
    diffusion_path = metadata['path']
    if not os.path.exists(diffusion_path):
        print(f'Downloading {model_name}...')
        _download_checkpoint(metadata['link'], diffusion_path)
    diffusion_models[model_name] = DDPM.load_from_checkpoint(diffusion_path, map_location=device).eval().to(device)
    print(f'Loaded model {model_name}')
# Startup diagnostics: the relative current-directory marker, its absolute
# location, and the directory listing, one per line. This helps debug the
# relative paths used throughout (models/, results/, examples/).
print(
    os.curdir,
    os.path.abspath(os.curdir),
    os.listdir(os.curdir),
    sep='\n',
)
def read_molecule_content(path):
    """Return the complete text of the molecule file at ``path``.

    The text is returned unparsed. Callers in this script embed it verbatim
    into the HTML rendering templates.
    """
    with open(path, "r") as f:
        return f.read()
def read_molecule(path):
    """Parse the (first) molecule from a .pdb, .mol, .mol2 or .sdf file with RDKit.

    Hydrogens are removed on load and RDKit sanitization is skipped
    (``sanitize=False``). For .sdf files only the first record is returned.

    Raises:
        ValueError: if the file extension is not supported, or if RDKit fails
            to parse the file. RDKit's readers signal failure by returning
            ``None`` rather than raising.
    """
    if path.endswith('.pdb'):
        molecule = Chem.MolFromPDBFile(path, sanitize=False, removeHs=True)
    elif path.endswith('.mol'):
        molecule = Chem.MolFromMolFile(path, sanitize=False, removeHs=True)
    elif path.endswith('.mol2'):
        molecule = Chem.MolFromMol2File(path, sanitize=False, removeHs=True)
    elif path.endswith('.sdf'):
        molecule = Chem.SDMolSupplier(path, sanitize=False, removeHs=True)[0]
    else:
        # ValueError subclasses Exception, so existing `except Exception` callers still catch it.
        raise ValueError('Unknown file extension')
    if molecule is None:
        raise ValueError(f'RDKit could not parse the molecule in {path}')
    return molecule
def show_input(input_file):
    """Render an uploaded (or example) fragments file in the visualization pane.

    Args:
        input_file: ``None`` (file cleared), a path string (example files), or
            an uploaded file object exposing ``.name``.

    Returns:
        A 3-item list for the ``[visualization, samples, hidden]`` outputs:
        - the iframe HTML (empty string when there is no file);
        - an update that hides the sample selector;
        - ``None``, which resets the hidden textbox.
        Unsupported formats and unreadable files produce an error page in the
        first slot instead of raising.
    """
    if input_file is None:
        return ['', gr.Radio.update(visible=False, value='Sample 1'), None]

    path = input_file if isinstance(input_file, str) else input_file.name
    extension = path.split('.')[-1]
    if extension not in ['sdf', 'pdb', 'mol', 'mol2']:
        msg = output.INVALID_FORMAT_MSG.format(extension=extension)
        return [output.IFRAME_TEMPLATE.format(html=msg), gr.Radio.update(visible=False), None]

    try:
        molecule = read_molecule_content(path)
    except Exception as e:
        # Use the same error template + iframe wrapper as generate() so read
        # failures render consistently instead of as bare text.
        msg = output.ERROR_FORMAT_MSG.format(message=f'Could not read the molecule: {e}')
        return [output.IFRAME_TEMPLATE.format(html=msg), gr.Radio.update(visible=False), None]

    html = output.INITIAL_RENDERING_TEMPLATE.format(molecule=molecule, fmt=extension)
    return [output.IFRAME_TEMPLATE.format(html=html), gr.Radio.update(visible=False), None]
def draw_sample(idx, out_files):
    """Render the input fragments together with one generated sample.

    Args:
        idx: zero-based sample index, or a radio label such as ``'Sample 3'``.
            A label's trailing number is treated as 1-based, so ``'Sample 3'``
            selects index 2.
        out_files: list whose first entry is the input fragments file and whose
            following entries are the generated molecules. Each entry is either
            a path string or a file object exposing ``.name``.

    Returns:
        The iframe-wrapped HTML for the visualization pane.
    """
    def as_path(entry):
        return entry if isinstance(entry, str) else entry.name

    if isinstance(idx, str):
        label = idx.strip()
        idx = int(label.rsplit(' ', 1)[-1]) - 1

    fragments_path = as_path(out_files[0])
    sample_path = as_path(out_files[idx + 1])

    # The format of each file is taken from its extension (text after the last '.').
    template_fields = {
        'fragments': read_molecule_content(fragments_path),
        'fragments_fmt': fragments_path.rpartition('.')[2],
        'molecule': read_molecule_content(sample_path),
        'molecule_fmt': sample_path.rpartition('.')[2],
    }
    rendered = output.SAMPLES_RENDERING_TEMPLATE.format(**template_fields)
    return output.IFRAME_TEMPLATE.format(html=rendered)
def _generation_error(message):
    """Return the four generate() outputs that display ``message`` as an error page."""
    msg = output.ERROR_FORMAT_MSG.format(message=message)
    return [output.IFRAME_TEMPLATE.format(html=msg), None, None, None]


def _parse_selected_atoms(selected_atoms):
    """Parse the comma-separated atom indices produced by the selection JS.

    ``None`` or blank input yields ``[]``. Empty tokens (e.g. from a trailing
    comma) are ignored.

    Raises:
        ValueError: if any non-empty token is not an integer.
    """
    if selected_atoms is None:
        return []
    return [int(token) for token in selected_atoms.split(',') if token.strip()]


def _make_sample_fn(n_atoms):
    """Build the linker-size sampler passed to ``generate_linkers``.

    When ``n_atoms`` is 0, one size per batch element is drawn from the SizeGNN
    predicted categorical distribution. Otherwise every batch element gets
    ``n_atoms`` linker atoms.
    """
    if n_atoms == 0:
        def sample_fn(_data):
            out, _ = size_nn.forward(_data, return_loss=False)
            probabilities = torch.softmax(out, dim=1)
            samples = torch.distributions.Categorical(probs=probabilities).sample()
            # Map predicted class ids back to linker sizes.
            sizes = [size_nn.linker_id2size[label] for label in samples.detach().cpu().numpy()]
            return torch.tensor(sizes, device=samples.device, dtype=torch.long)
    else:
        def sample_fn(_data):
            return torch.ones(_data['positions'].shape[0], device=device, dtype=torch.long) * n_atoms
    return sample_fn


def generate(input_file, n_steps, n_atoms, radio_samples, selected_atoms):
    """Generate linkers for the input fragments and render the selected sample.

    Args:
        input_file: uploaded fragments file (object exposing ``.name``) or a
            path string; ``None`` when nothing is loaded.
        n_steps: number of denoising steps (slider value).
        n_atoms: linker size; 0 lets the SizeGNN model sample it.
        radio_samples: currently selected sample label, e.g. ``'Sample 1'``.
        selected_atoms: comma-separated anchor atom indices from the
            visualization JS; ``''`` or ``None`` when none are selected.

    Returns:
        ``[visualization_html, output_files, samples_update, hidden]``. On
        failure the first item is an error page and the rest are ``None``. With
        no input file, all four items are ``None``.
    """
    if input_file is None:
        return [None, None, None, None]

    try:
        anchor_ids = _parse_selected_atoms(selected_atoms)
    except ValueError:
        return _generation_error(f'Could not parse selected anchor atoms: {selected_atoms!r}')

    # The anchor-conditioned model is used only when the user picked anchor atoms.
    selected_model_name = 'geom_difflinker_given_anchors' if anchor_ids else 'geom_difflinker'
    print(f'Start generating with model {selected_model_name}, selected_atoms:', anchor_ids)
    ddpm = diffusion_models[selected_model_name]

    # input_file may be a plain path string (as handled in show_input/draw_sample)
    # or an uploaded file object.
    path = input_file if isinstance(input_file, str) else input_file.name
    extension = path.split('.')[-1]
    if extension not in ['sdf', 'pdb', 'mol', 'mol2']:
        msg = output.INVALID_FORMAT_MSG.format(extension=extension)
        return [output.IFRAME_TEMPLATE.format(html=msg), None, None, None]

    try:
        molecule = Chem.RemoveAllHs(read_molecule(path))
    except Exception as e:
        return _generation_error(f'Could not read the molecule: {e}')
    name = '.'.join(path.split('/')[-1].split('.')[:-1])
    inp_sdf = f'results/input_{name}.sdf'

    if molecule.GetNumAtoms() > 50:
        return _generation_error('Too large molecule: upper limit is 50 heavy atoms')

    with Chem.SDWriter(inp_sdf) as w:
        w.write(molecule)

    positions, one_hot, charges = parse_molecule(molecule, is_geom=True)

    # Reject indices outside the parsed atom range. Numpy would raise IndexError
    # on too-large ones and silently wrap negative ones.
    out_of_range = [i for i in anchor_ids if not 0 <= i < len(charges)]
    if out_of_range:
        return _generation_error(f'Selected anchor atoms out of range: {out_of_range}')

    anchors = np.zeros_like(charges)
    anchors[anchor_ids] = 1
    fragment_mask = np.ones_like(charges)  # every input atom belongs to a fragment
    linker_mask = np.zeros_like(charges)   # no linker atoms exist yet
    print('Read and parsed molecule')

    def as_tensor(values):
        return torch.tensor(values, dtype=const.TORCH_FLOAT, device=device)

    # The same (read-only) record is repeated N_SAMPLES times, giving one batch of N_SAMPLES samples.
    dataset = [{
        'uuid': '0',
        'name': '0',
        'positions': as_tensor(positions),
        'one_hot': as_tensor(one_hot),
        'charges': as_tensor(charges),
        'anchors': as_tensor(anchors),
        'fragment_mask': as_tensor(fragment_mask),
        'linker_mask': as_tensor(linker_mask),
        'num_atoms': len(positions),
    }] * N_SAMPLES
    dataloader = get_dataloader(dataset, batch_size=N_SAMPLES, collate_fn=collate_with_fragment_edges)
    print('Created dataloader')

    # Slider values are cast defensively: a float would otherwise become the step count
    # and turn the long size tensor into a float one.
    # NOTE(review): this mutates the shared model in place; concurrent requests would race.
    ddpm.edm.T = int(n_steps)
    sample_fn = _make_sample_fn(int(n_atoms))

    for data in dataloader:
        try:
            generate_linkers(ddpm=ddpm, data=data, sample_fn=sample_fn, name=name)
        except Exception as e:
            return _generation_error(f'Caught exception while generating linkers: {e}')

    out_files = [inp_sdf] + try_to_convert_to_sdf(name)
    return [
        draw_sample(radio_samples, out_files),
        out_files,
        gr.Radio.update(visible=True),
        None,
    ]
# Gradio UI: an input column (upload, settings, examples, outputs) and a
# visualization column, plus the event wiring between components.
demo = gr.Blocks()
with demo:
    gr.Markdown('# DiffLinker: Equivariant 3D-Conditional Diffusion Model for Molecular Linker Design')
    gr.Markdown(
        'Given a set of disconnected fragments in 3D, '
        'DiffLinker places missing atoms in between and designs a molecule incorporating all the initial fragments. '
        'Our method can link an arbitrary number of fragments, requires no information on the attachment atoms '
        'and linker size, and can be conditioned on the protein pockets.'
    )
    gr.Markdown(
        '[**[Paper]**](https://arxiv.org/abs/2210.05274) '
        '[**[Code]**](https://github.com/igashov/DiffLinker)'
    )
    with gr.Box():
        with gr.Row():
            # Invisible textbox passed to generate() as `selected_atoms`
            # (comma-separated anchor indices).
            hidden = gr.Textbox(visible=False)
            with gr.Column():
                gr.Markdown('## Input Fragments')
                # .mol is listed too: show_input/generate/read_molecule all accept it.
                gr.Markdown('Upload the file with 3D-coordinates of the input fragments in .pdb, .mol, .mol2 or .sdf format:')
                input_file = gr.File(file_count='single', label='Input Fragments')
                n_steps = gr.Slider(minimum=10, maximum=500, label="Number of Denoising Steps", step=10)
                n_atoms = gr.Slider(
                    minimum=0, maximum=20,
                    label="Linker Size: DiffLinker will predict it if set to 0",
                    step=1
                )
                examples = gr.Dataset(
                    components=[gr.File(visible=False)],
                    samples=[['examples/example_1.sdf'], ['examples/example_2.sdf']],
                    type='index',
                )
                button = gr.Button('Generate Linker!')
                gr.Markdown('')
                gr.Markdown('## Output Files')
                gr.Markdown('Download files with the generated molecules here:')
                output_files = gr.File(file_count='multiple', label='Output Files', interactive=False)
            with gr.Column():
                gr.Markdown('## Visualization')
                gr.Markdown('**Hint:** click on atoms to select anchor points (optionally)')
                samples = gr.Radio(
                    choices=['Sample 1', 'Sample 2', 'Sample 3', 'Sample 4', 'Sample 5'],
                    value='Sample 1',
                    type='value',
                    show_label=False,
                    visible=False,
                    interactive=True,
                )
                visualization = gr.HTML()

    input_file.change(
        fn=show_input,
        inputs=[input_file],
        outputs=[visualization, samples, hidden],
    )
    input_file.clear(
        fn=lambda: [None, '', gr.Radio.update(visible=False), None],
        inputs=[],
        outputs=[input_file, visualization, samples, hidden],
    )
    # Selecting an example loads its file, resets the sliders to (10 steps,
    # predicted size) and renders it via show_input.
    examples.click(
        fn=lambda idx: [f'examples/example_{idx+1}.sdf', 10, 0] + show_input(f'examples/example_{idx+1}.sdf'),
        inputs=[examples],
        outputs=[input_file, n_steps, n_atoms, visualization, samples, hidden]
    )
    # NOTE(review): RETURN_SELECTION_JS presumably supplies the clicked-atom
    # selection as the `hidden` input before generate() runs. Confirm in output.py.
    button.click(
        fn=generate,
        inputs=[input_file, n_steps, n_atoms, samples, hidden],
        outputs=[visualization, output_files, samples, hidden],
        _js=output.RETURN_SELECTION_JS,
    )
    samples.change(
        fn=draw_sample,
        inputs=[samples, output_files],
        outputs=[visualization],
    )
    demo.load(_js=output.STARTUP_JS)

# Bind to --ip when given; None defers to Gradio's default host.
demo.launch(server_name=args.ip)