"""Gradio demo application for DiffLinker.

On start-up the script makes sure the SizeGNN and diffusion checkpoints are
present under ``models/`` (downloading them with ``wget`` if needed), loads
both models, builds the Gradio UI and launches it.
"""
import argparse
import os
import subprocess

import gradio as gr
import numpy as np
import torch
from rdkit import Chem

import output
from src import const
from src.visualizer import save_xyz_file
from src.datasets import get_dataloader, collate_with_fragment_edges, parse_molecule
from src.lightning import DDPM
from src.linker_size_lightning import SizeClassifier


# Number of molecules generated per request; also drives the sample selector.
N_SAMPLES = 5
# Upper bound on the number of heavy atoms accepted in the input fragments.
MAX_HEAVY_ATOMS = 50
# File extensions accepted for the uploaded fragments.
SUPPORTED_EXTENSIONS = ('sdf', 'pdb', 'mol', 'mol2')
# Labels of the sample selector: 'Sample 1', 'Sample 2', ...
SAMPLE_LABELS = [f'Sample {i + 1}' for i in range(N_SAMPLES)]

parser = argparse.ArgumentParser()
parser.add_argument('--ip', type=str, default=None)
args = parser.parse_args()

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
os.makedirs("results", exist_ok=True)
os.makedirs("models", exist_ok=True)


def _ensure_checkpoint(path, link, label):
    """Download the checkpoint at ``link`` to ``path`` unless it already exists.

    Raises:
        subprocess.CalledProcessError: if ``wget`` exits with a non-zero code.
        OSError: if ``wget`` cannot be started.
    """
    if os.path.exists(path):
        return
    print(f'Downloading {label} model...')
    try:
        # Argument list instead of a shell string: nothing is shell-interpreted.
        subprocess.run(['wget', link, '-O', path], check=True)
    except (subprocess.CalledProcessError, OSError):
        # `wget -O` creates the target even on failure; remove it so that the
        # next start retries the download instead of loading a broken file.
        if os.path.exists(path):
            os.remove(path)
        raise


size_gnn_path = 'models/geom_size_gnn.ckpt'
_ensure_checkpoint(
    size_gnn_path,
    'https://zenodo.org/record/7121300/files/geom_size_gnn.ckpt?download=1',
    'SizeGNN',
)
size_nn = SizeClassifier.load_from_checkpoint(size_gnn_path, map_location=device).eval().to(device)
print('Loaded SizeGNN model')

diffusion_path = 'models/geom_difflinker.ckpt'
_ensure_checkpoint(
    diffusion_path,
    'https://zenodo.org/record/7121300/files/geom_difflinker.ckpt?download=1',
    'Diffusion',
)
ddpm = DDPM.load_from_checkpoint(diffusion_path, map_location=device).eval().to(device)
print('Loaded diffusion model')


def sample_fn(_data):
    """Sample one linker size per batch item from the SizeGNN prediction.

    The classifier output is turned into a categorical distribution (softmax
    over dim 1), one class is drawn per row, and each class id is mapped to a
    size through ``size_nn.linker_id2size``.

    Returns:
        A ``torch.long`` tensor of sizes on the same device as the samples.
    """
    logits, _ = size_nn.forward(_data, return_loss=False)
    probabilities = torch.softmax(logits, dim=1)
    samples = torch.distributions.Categorical(probs=probabilities).sample()
    sizes = [size_nn.linker_id2size[label] for label in samples.detach().cpu().numpy()]
    return torch.tensor(sizes, device=samples.device, dtype=torch.long)


def _resolve_path(file_obj):
    """Return the filesystem path of ``file_obj``.

    Accepts either a plain path string or an object exposing the path as
    ``.name`` (both forms are passed around by the UI callbacks below).
    """
    return file_obj if isinstance(file_obj, str) else file_obj.name


def read_molecule_content(path):
    """Return the full text content of the file at ``path``."""
    with open(path, "r") as f:
        return f.read()


def read_molecule(path):
    """Load the molecule stored at ``path`` with RDKit, picking the reader by extension.

    Molecules are read with ``sanitize=False`` and ``removeHs=True``. For
    ``.sdf`` files only the first record is returned.

    Raises:
        ValueError: if the extension is none of .pdb, .mol, .mol2, .sdf.
    """
    if path.endswith('.pdb'):
        return Chem.MolFromPDBFile(path, sanitize=False, removeHs=True)
    if path.endswith('.mol'):
        return Chem.MolFromMolFile(path, sanitize=False, removeHs=True)
    if path.endswith('.mol2'):
        return Chem.MolFromMol2File(path, sanitize=False, removeHs=True)
    if path.endswith('.sdf'):
        return Chem.SDMolSupplier(path, sanitize=False, removeHs=True)[0]
    raise ValueError('Unknown file extension')


def show_input(input_file):
    """Build the HTML that renders the uploaded fragments.

    Args:
        input_file: ``None``, a path string, or an object with a ``.name`` path.

    Returns:
        An empty string for ``None``; otherwise HTML with either the rendered
        molecule, an invalid-format notice, or a plain read-error message.
    """
    if input_file is None:
        return ''

    path = _resolve_path(input_file)
    extension = path.split('.')[-1]
    if extension not in SUPPORTED_EXTENSIONS:
        msg = output.INVALID_FORMAT_MSG.format(extension=extension)
        return output.IFRAME_TEMPLATE.format(html=msg)

    try:
        molecule = read_molecule_content(path)
    except Exception as e:
        return f'Could not read the molecule: {e}'

    html = output.INITIAL_RENDERING_TEMPLATE.format(molecule=molecule, fmt=extension)
    return output.IFRAME_TEMPLATE.format(html=html)


def draw_sample(idx, out_files):
    """Build the HTML that renders generated sample ``idx`` over the input fragments.

    Args:
        idx: zero-based index of the sample to show.
        out_files: sequence whose first element is the input fragments SDF and
            whose element ``idx + 1`` is the SDF of sample ``idx``. Elements
            are path strings or objects with a ``.name`` path.
    """
    in_sdf = _resolve_path(out_files[0])
    out_sdf = _resolve_path(out_files[idx + 1])

    html = output.SAMPLES_RENDERING_TEMPLATE.format(
        fragments=read_molecule_content(in_sdf),
        fragments_fmt='sdf',
        molecule=read_molecule_content(out_sdf),
        molecule_fmt='sdf',
    )
    return output.IFRAME_TEMPLATE.format(html=html)


def _generation_failed(html):
    """Return the three handler outputs for a request that produced no samples.

    The click handler is wired to (visualization, output files, sample
    selector), so every exit path has to provide all three values.
    """
    return [html, None, gr.update(visible=False)]


def generate(input_file, n_steps):
    """Generate ``N_SAMPLES`` linkers for the uploaded fragments.

    Args:
        input_file: uploaded file (``None``, a path string, or an object with
            a ``.name`` path).
        n_steps: number of diffusion steps; assigned to ``ddpm.edm.T``.

    Returns:
        A list with three entries matching the click handler outputs: the
        visualization HTML, the list of downloadable files (input SDF followed
        by the generated SDFs) and an update for the sample selector.
    """
    if input_file is None:
        return _generation_failed('')

    path = _resolve_path(input_file)
    extension = path.split('.')[-1]
    if extension not in SUPPORTED_EXTENSIONS:
        msg = output.INVALID_FORMAT_MSG.format(extension=extension)
        return _generation_failed(output.IFRAME_TEMPLATE.format(html=msg))

    try:
        molecule = read_molecule(path)
        if molecule is None:
            # The RDKit readers return None rather than raising on bad input.
            raise ValueError('RDKit could not parse the file')
        molecule = Chem.RemoveAllHs(molecule)
    except Exception as e:
        return _generation_failed(f'Could not read the molecule: {e}')

    if molecule.GetNumAtoms() > MAX_HEAVY_ATOMS:
        return _generation_failed(
            f'Too large molecule: upper limit is {MAX_HEAVY_ATOMS} heavy atoms'
        )

    name = os.path.splitext(os.path.basename(path))[0]
    inp_sdf = f'results/input_{name}.sdf'
    with Chem.SDWriter(inp_sdf) as w:
        w.write(molecule)

    positions, one_hot, charges = parse_molecule(molecule, is_geom=True)
    # Every input atom is marked as a fragment atom; no anchors or linker
    # atoms are specified up front.
    anchors = np.zeros_like(charges)
    fragment_mask = np.ones_like(charges)
    linker_mask = np.zeros_like(charges)
    print('Read and parsed molecule')

    def as_tensor(array):
        return torch.tensor(array, dtype=const.TORCH_FLOAT, device=device)

    # The same entry is repeated so that one batch yields N_SAMPLES molecules.
    dataset = [{
        'uuid': '0',
        'name': '0',
        'positions': as_tensor(positions),
        'one_hot': as_tensor(one_hot),
        'charges': as_tensor(charges),
        'anchors': as_tensor(anchors),
        'fragment_mask': as_tensor(fragment_mask),
        'linker_mask': as_tensor(linker_mask),
        'num_atoms': len(positions),
    }] * N_SAMPLES
    dataloader = get_dataloader(dataset, batch_size=N_SAMPLES, collate_fn=collate_with_fragment_edges)
    print('Created dataloader')

    ddpm.edm.T = int(n_steps)
    assert ddpm.center_of_mass == 'fragments'

    # Only the first batch is used; it already holds all N_SAMPLES copies.
    data = next(iter(dataloader))
    chain, node_mask = ddpm.sample_chain(data, sample_fn=sample_fn, keep_frames=1)
    print('Generated linker')
    x = chain[0][:, :, :ddpm.n_dims]
    h = chain[0][:, :, ddpm.n_dims:]

    # Put the molecule back to the initial orientation: add the mean position
    # of the input fragment atoms to every unmasked node.
    pos_masked = data['positions'] * data['fragment_mask']
    n_fragment_atoms = data['fragment_mask'].sum(1, keepdims=True)
    mean = torch.sum(pos_masked, dim=1, keepdim=True) / n_fragment_atoms
    x = x + mean * node_mask

    names = [f'output_{i+1}_{name}' for i in range(N_SAMPLES)]
    save_xyz_file('results', h, x, node_mask, names=names, is_geom=True, suffix='')
    print('Saved XYZ files')

    out_files = []
    for i in range(N_SAMPLES):
        out_xyz = f'results/output_{i+1}_{name}_.xyz'
        out_sdf = f'results/output_{i+1}_{name}_.sdf'
        # `name` comes from the uploaded file, so it must never reach a shell.
        subprocess.run(['obabel', out_xyz, '-O', out_sdf])
        out_files.append(out_sdf)
    print('Converted to SDF')

    # draw_sample expects the input file first, followed by the samples.
    all_files = [inp_sdf] + out_files
    return [
        draw_sample(0, all_files),
        all_files,
        gr.update(visible=True, value=SAMPLE_LABELS[0]),
    ]


def _load_example(idx):
    """Return the handler outputs that load bundled example number ``idx``."""
    example_path = f'examples/example_{idx+1}.sdf'
    return [
        example_path,
        10,
        show_input(example_path),
        gr.update(value=SAMPLE_LABELS[0], visible=False),
    ]


def _reset_view():
    """Return the handler outputs that blank the viewer and hide the selector."""
    return ['', gr.update(value=SAMPLE_LABELS[0], visible=False)]


demo = gr.Blocks()
with demo:
    gr.Markdown('# DiffLinker: Equivariant 3D-Conditional Diffusion Model for Molecular Linker Design')
    gr.Markdown(
        'Given a set of disconnected fragments in 3D, '
        'DiffLinker places missing atoms in between and designs a molecule incorporating all the initial fragments. '
        'Our method can link an arbitrary number of fragments, requires no information on the attachment atoms '
        'and linker size, and can be conditioned on the protein pockets.'
    )
    gr.Markdown(
        '[**[Paper]**](https://arxiv.org/abs/2210.05274) '
        '[**[Code]**](https://github.com/igashov/DiffLinker)'
    )
    with gr.Box():
        with gr.Row():
            with gr.Column():
                gr.Markdown('## Input Fragments')
                gr.Markdown(
                    'Upload the file with 3D-coordinates of the input fragments '
                    'in .pdb, .mol, .mol2 or .sdf format:'
                )
                input_file = gr.File(file_count='single', label='Input Fragments')
                n_steps = gr.Slider(minimum=10, maximum=500, label="Number of Diffusion Steps", step=10)
                examples = gr.Dataset(
                    components=[gr.File(visible=False)],
                    samples=[['examples/example_1.sdf'], ['examples/example_2.sdf']],
                    type='index',
                )

                button = gr.Button('Generate Linker!')
                gr.Markdown('')
                gr.Markdown('## Output Files')
                gr.Markdown('Download files with the generated molecules here:')
                output_files = gr.File(file_count='multiple', label='Output Files', interactive=False)
            with gr.Column():
                gr.Markdown('## Visualization')
                samples = gr.Radio(
                    choices=SAMPLE_LABELS,
                    value=SAMPLE_LABELS[0],
                    type='index',
                    show_label=False,
                    visible=False,
                    interactive=True,
                )
                visualization = gr.HTML()

    input_file.change(
        fn=show_input,
        inputs=[input_file],
        outputs=[visualization],
    )
    examples.click(
        fn=_load_example,
        inputs=[examples],
        outputs=[input_file, n_steps, visualization, samples],
    )
    button.click(
        fn=generate,
        inputs=[input_file, n_steps],
        outputs=[visualization, output_files, samples],
    )
    samples.change(
        fn=draw_sample,
        inputs=[samples, output_files],
        outputs=[visualization],
    )
    input_file.clear(
        fn=_reset_view,
        inputs=[],
        outputs=[visualization, samples],
    )

demo.launch(server_name=args.ip)