import os
import subprocess

import gradio as gr
import numpy as np
import torch
from rdkit import Chem

from src import const
from src.visualizer import save_xyz_file
from src.datasets import get_dataloader, collate_with_fragment_edges, parse_molecule
from src.lightning import DDPM
from src.linker_size_lightning import SizeClassifier

HTML_TEMPLATE = """
"""

IFRAME_TEMPLATE = """"""
# The viewer markup is omitted here: HTML_TEMPLATE is expected to expose
# {molecule} and {fmt} placeholders, and IFRAME_TEMPLATE an {html} placeholder,
# as used in generate() below.

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
os.makedirs("results", exist_ok=True)
os.makedirs("models", exist_ok=True)

# Download the pretrained linker size predictor (SizeGNN) checkpoint
subprocess.run(
    'wget https://zenodo.org/record/7121300/files/geom_size_gnn.ckpt?download=1 -O models/geom_size_gnn.ckpt',
    shell=True
)
size_nn = SizeClassifier.load_from_checkpoint('models/geom_size_gnn.ckpt', map_location=device).eval().to(device)
print('Loaded SizeGNN model')

# Download the pretrained DiffLinker diffusion model checkpoint
subprocess.run(
    'wget https://zenodo.org/record/7121300/files/geom_difflinker.ckpt?download=1 -O models/geom_difflinker.ckpt',
    shell=True
)
ddpm = DDPM.load_from_checkpoint('models/geom_difflinker.ckpt', map_location=device).eval().to(device)
print('Loaded diffusion model')


def sample_fn(_data):
    # Sample a linker size for each input by drawing from the SizeGNN class distribution
    output, _ = size_nn.forward(_data, return_loss=False)
    probabilities = torch.softmax(output, dim=1)
    distribution = torch.distributions.Categorical(probs=probabilities)
    samples = distribution.sample()
    sizes = []
    for label in samples.detach().cpu().numpy():
        sizes.append(size_nn.linker_id2size[label])
    sizes = torch.tensor(sizes, device=samples.device, dtype=torch.long)
    return sizes


def read_molecule_content(path):
    with open(path, "r") as f:
        return "".join(f.readlines())


def read_molecule(path):
    if path.endswith('.pdb'):
        return Chem.MolFromPDBFile(path, sanitize=False, removeHs=True)
    elif path.endswith('.mol'):
        return Chem.MolFromMolFile(path, sanitize=False, removeHs=True)
    elif path.endswith('.mol2'):
        return Chem.MolFromMol2File(path, sanitize=False, removeHs=True)
    elif path.endswith('.sdf'):
        return Chem.SDMolSupplier(path, sanitize=False, removeHs=True)[0]
    raise Exception('Unknown file extension')


def generate(input_file):
    try:
        path = input_file.name
        molecule = read_molecule(path)
        name = '.'.join(path.split('/')[-1].split('.')[:-1])
        out_sdf = f'results/{name}_generated.sdf'
        print(f'Input path={path}, name={name}')
    except Exception as e:
        return f'Could not read the molecule: {e}'

    if molecule.GetNumAtoms() > 50:
        return 'Molecule is too large: the upper limit is 50 heavy atoms'

    # Treat all input atoms as fragment atoms; the linker atoms will be generated
    positions, one_hot, charges = parse_molecule(molecule, is_geom=True)
    anchors = np.zeros_like(charges)
    fragment_mask = np.ones_like(charges)
    linker_mask = np.zeros_like(charges)
    print('Read and parsed molecule')

    dataset = [{
        'uuid': '0',
        'name': '0',
        'positions': torch.tensor(positions, dtype=const.TORCH_FLOAT, device=device),
        'one_hot': torch.tensor(one_hot, dtype=const.TORCH_FLOAT, device=device),
        'charges': torch.tensor(charges, dtype=const.TORCH_FLOAT, device=device),
        'anchors': torch.tensor(anchors, dtype=const.TORCH_FLOAT, device=device),
        'fragment_mask': torch.tensor(fragment_mask, dtype=const.TORCH_FLOAT, device=device),
        'linker_mask': torch.tensor(linker_mask, dtype=const.TORCH_FLOAT, device=device),
        'num_atoms': len(positions),
    }]
    dataloader = get_dataloader(dataset, batch_size=1, collate_fn=collate_with_fragment_edges)
    print('Created dataloader')

    for data in dataloader:
        # Sample the full molecule (fragments + generated linker) and keep only the final frame
        chain, node_mask = ddpm.sample_chain(data, sample_fn=sample_fn, keep_frames=1)
        print('Generated linker')
        x = chain[0][:, :, :ddpm.n_dims]
        h = chain[0][:, :, ddpm.n_dims:]
        save_xyz_file('results', h, x, node_mask, names=[name], is_geom=True, suffix='generated')
        print('Saved XYZ file')
        subprocess.run(f'obabel results/{name}_generated.xyz -O {out_sdf}', shell=True)
        print('Converted to SDF')
        break

    generated_molecule = read_molecule_content(out_sdf)
    html = HTML_TEMPLATE.format(molecule=generated_molecule, fmt='sdf')
    return IFRAME_TEMPLATE.format(html=html)


demo = gr.Blocks()
with demo:
    gr.Markdown('# DiffLinker: Equivariant 3D-Conditional Diffusion Model for Molecular Linker Design')
    with gr.Box():
        with gr.Row():
            with gr.Column():
                gr.Markdown('## Input Fragments')
                gr.Markdown('Upload a file with the 3D coordinates of the input fragments in .pdb, .mol, .mol2 or .sdf format')
                input_file = gr.File(file_count='single', label='Input fragments')
                button = gr.Button('Generate Linker!')
                gr.Markdown('')
                visualization = gr.HTML()

    button.click(
        fn=generate,
        inputs=[input_file],
        outputs=[visualization],
    )

demo.launch()