File size: 9,251 Bytes
5de53c3
 
7782ac2
b0ab0d5
7782ac2
95ba5bc
 
52bf9df
95ba5bc
 
 
 
 
 
 
7782ac2
3c26059
 
5de53c3
53f22d0
5de53c3
7782ac2
95ba5bc
 
49021fb
95ba5bc
ff9d86b
 
 
 
 
95ba5bc
 
 
ff9d86b
 
 
 
 
95ba5bc
 
 
 
 
d1da608
95ba5bc
 
 
 
 
 
 
 
 
 
 
0673854
 
 
 
95ba5bc
 
 
 
 
 
 
 
 
 
 
 
52bf9df
 
 
b7813c6
 
 
 
52bf9df
 
 
 
 
 
 
 
 
 
a05c989
52bf9df
 
 
eb031b7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6264fac
 
aa9b17f
52bf9df
 
 
 
 
 
 
 
 
7782ac2
0673854
53f22d0
7c181a3
3c26059
95ba5bc
 
 
b0ab0d5
 
 
f9310fd
 
 
95ba5bc
b0ab0d5
 
 
95ba5bc
 
 
 
 
 
 
 
b0ab0d5
 
 
95ba5bc
3c26059
 
95ba5bc
 
aa9b17f
 
95ba5bc
 
 
 
 
3c26059
 
 
95ba5bc
 
3c26059
 
 
 
 
 
 
 
4f94923
eb031b7
6264fac
c95aee1
4f94923
7782ac2
 
 
 
 
52e7c95
 
 
 
 
 
bec2844
 
 
 
7782ac2
 
711f689
 
52bf9df
aa9b17f
 
7a6d6dd
05e91b8
7a6d6dd
05e91b8
 
f1c7e08
52bf9df
 
 
 
 
 
f58a645
eb031b7
7a7c7ad
 
 
 
 
 
 
 
eb031b7
52bf9df
 
 
 
 
 
7a6d6dd
aa9b17f
05e91b8
aa9b17f
0ce499b
6f4a6fd
 
aa9b17f
6f4a6fd
 
6264fac
 
 
eb031b7
6264fac
b7813c6
5de53c3
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
import argparse

import gradio as gr
import numpy as np
import os
import torch
import subprocess
import output

from rdkit import Chem
from src import const
from src.visualizer import save_xyz_file
from src.datasets import get_dataloader, collate_with_fragment_edges, parse_molecule
from src.lightning import DDPM
from src.linker_size_lightning import SizeClassifier

# Number of linkers sampled per request; they are generated as one batch.
N_SAMPLES = 5

parser = argparse.ArgumentParser()
parser.add_argument('--ip', type=str, default=None)  # passed to demo.launch(server_name=...)
args = parser.parse_args()

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
os.makedirs("results", exist_ok=True)
os.makedirs("models", exist_ok=True)


def _ensure_checkpoint(path, link, label):
    """Download ``link`` to ``path`` with wget unless ``path`` already exists.

    The download is written to ``<path>.part`` and moved into place only after
    wget exits successfully. ``wget -O`` creates its output file even when the
    download fails, so writing straight to ``path`` would leave a truncated file
    that the existence check would accept on the next start.

    Args:
        path: destination file of the checkpoint.
        link: URL to download from.
        label: human-readable model name used in the progress message.

    Raises:
        subprocess.CalledProcessError: if wget exits with a non-zero status.
        FileNotFoundError: if the ``wget`` executable is not available.
    """
    if os.path.exists(path):
        return
    print(f'Downloading {label} model...')
    partial_path = f'{path}.part'
    try:
        # Argument list, no shell: the '?' in the URL is never glob-expanded.
        subprocess.run(['wget', link, '-O', partial_path], check=True)
        os.replace(partial_path, path)
    finally:
        # Only still present if wget or the rename failed.
        if os.path.exists(partial_path):
            os.remove(partial_path)


size_gnn_path = 'models/geom_size_gnn.ckpt'
_ensure_checkpoint(
    size_gnn_path,
    'https://zenodo.org/record/7121300/files/geom_size_gnn.ckpt?download=1',
    'SizeGNN',
)
size_nn = SizeClassifier.load_from_checkpoint(size_gnn_path, map_location=device).eval().to(device)
print('Loaded SizeGNN model')

diffusion_path = 'models/geom_difflinker.ckpt'
_ensure_checkpoint(
    diffusion_path,
    'https://zenodo.org/record/7121300/files/geom_difflinker.ckpt?download=1',
    'Diffusion',
)
ddpm = DDPM.load_from_checkpoint(diffusion_path, map_location=device).eval().to(device)
print('Loaded diffusion model')


def sample_fn(_data):
    """Draw one linker size per batch item from the SizeGNN's predicted distribution.

    Used as the ``sample_fn`` callback of ``ddpm.sample_chain``. The classifier
    logits are turned into probabilities with a softmax over dim 1 (assumes
    logits shaped (batch, n_classes) — TODO confirm). A class label is sampled
    per row, and each label is mapped to a linker size through
    ``size_nn.linker_id2size``.

    Returns:
        A ``torch.long`` tensor of sizes, on the same device as the sampled labels.
    """
    logits, _ = size_nn.forward(_data, return_loss=False)
    categorical = torch.distributions.Categorical(probs=torch.softmax(logits, dim=1))
    labels = categorical.sample()
    size_values = [size_nn.linker_id2size[lbl] for lbl in labels.detach().cpu().numpy()]
    return torch.tensor(size_values, device=labels.device, dtype=torch.long)


def read_molecule_content(path):
    """Return the entire text content of the file at ``path``."""
    with open(path, "r") as handle:
        content = handle.read()
    return content


def read_molecule(path):
    """Load the first molecule from a .pdb, .mol, .mol2 or .sdf file with RDKit.

    Files are read with ``sanitize=False`` and ``removeHs=True``. For .sdf
    files only the first record is returned.

    Raises:
        ValueError: if the extension is not supported, or if RDKit could not
            parse a molecule from the file. RDKit signals a parse failure by
            returning ``None``, which would otherwise surface later as an
            opaque error in the caller.
    """
    if path.endswith('.pdb'):
        molecule = Chem.MolFromPDBFile(path, sanitize=False, removeHs=True)
    elif path.endswith('.mol'):
        molecule = Chem.MolFromMolFile(path, sanitize=False, removeHs=True)
    elif path.endswith('.mol2'):
        molecule = Chem.MolFromMol2File(path, sanitize=False, removeHs=True)
    elif path.endswith('.sdf'):
        molecule = Chem.SDMolSupplier(path, sanitize=False, removeHs=True)[0]
    else:
        # ValueError subclasses Exception, so existing `except Exception` callers still work.
        raise ValueError('Unknown file extension')

    if molecule is None:
        raise ValueError(f'RDKit could not parse a molecule from {path}')
    return molecule


def show_input(input_file):
    """Build the iframe HTML that previews an uploaded fragments file.

    Args:
        input_file: ``None``, a path string, or an object exposing ``.name``.

    Returns:
        ``''`` when there is no input. Otherwise returns one of three things:
        - an iframe holding ``output.INVALID_FORMAT_MSG`` for an unsupported extension;
        - a plain error string if the file cannot be read;
        - an iframe rendering the raw file content via
          ``output.INITIAL_RENDERING_TEMPLATE``.
    """
    if input_file is None:
        return ''

    path = input_file if isinstance(input_file, str) else input_file.name
    extension = path.rsplit('.', 1)[-1]
    if extension not in ['sdf', 'pdb', 'mol', 'mol2']:
        invalid_html = output.INVALID_FORMAT_MSG.format(extension=extension)
        return output.IFRAME_TEMPLATE.format(html=invalid_html)

    try:
        raw_content = read_molecule_content(path)
    except Exception as e:
        return f'Could not read the molecule: {e}'

    preview_html = output.INITIAL_RENDERING_TEMPLATE.format(molecule=raw_content, fmt=extension)
    return output.IFRAME_TEMPLATE.format(html=preview_html)


def draw_sample(idx, out_files):
    """Render the input fragments together with generated sample ``idx`` as iframe HTML.

    Args:
        idx: zero-based index of the generated sample to show.
        out_files: sequence whose first entry is the input-fragments SDF and
            whose following entries are the generated SDFs. Each entry is either
            a path string or an object exposing ``.name``.
    """
    def _as_path(entry):
        # Entries may be plain paths or file objects.
        return entry if isinstance(entry, str) else entry.name

    fragments_path = _as_path(out_files[0])
    sample_path = _as_path(out_files[idx + 1])  # +1 skips the input-fragments entry

    rendered = output.SAMPLES_RENDERING_TEMPLATE.format(
        fragments=read_molecule_content(fragments_path),
        fragments_fmt='sdf',
        molecule=read_molecule_content(sample_path),
        molecule_fmt='sdf',
    )
    return output.IFRAME_TEMPLATE.format(html=rendered)


def _generate_failure(html):
    """Return the three Gradio outputs for a ``generate`` call that produced no samples.

    ``generate`` feeds three components (visualization, output files, sample
    selector), so every return path must supply three values. The message goes
    to the visualization pane, the file list is cleared, and the sample selector
    is hidden because there is nothing to browse.
    """
    return [html, None, gr.Radio.update(visible=False)]


def generate(input_file, n_steps):
    """Sample ``N_SAMPLES`` linkers for the uploaded fragments and render the first one.

    Args:
        input_file: uploaded fragments file (.sdf, .pdb, .mol or .mol2), either
            a path string or an object exposing ``.name``.
        n_steps: number of diffusion steps. It is cast to ``int`` before being
            assigned to ``ddpm.edm.T``.

    Returns:
        A list of three values for the (visualization, output_files, samples)
        outputs.
        - On success: the HTML for sample 1, the input SDF path followed by the
          generated SDF paths, and an update that shows the sample selector.
        - On failure: an error message, ``None``, and an update that hides the
          selector.
    """
    if input_file is None:
        return _generate_failure('')

    path = input_file if isinstance(input_file, str) else input_file.name
    extension = path.split('.')[-1]
    if extension not in ['sdf', 'pdb', 'mol', 'mol2']:
        msg = output.INVALID_FORMAT_MSG.format(extension=extension)
        return _generate_failure(output.IFRAME_TEMPLATE.format(html=msg))

    try:
        molecule = read_molecule(path)
        molecule = Chem.RemoveAllHs(molecule)
    except Exception as e:
        return _generate_failure(f'Could not read the molecule: {e}')

    # `name` is derived from the user's upload filename: treat it as untrusted.
    name = '.'.join(path.split('/')[-1].split('.')[:-1])
    inp_sdf = f'results/input_{name}.sdf'

    if molecule.GetNumAtoms() > 50:
        return _generate_failure('Too large molecule: upper limit is 50 heavy atoms')

    with Chem.SDWriter(inp_sdf) as w:
        w.write(molecule)

    positions, one_hot, charges = parse_molecule(molecule, is_geom=True)
    # Every input atom is a fragment atom; no anchors or linker atoms are given.
    anchors = np.zeros_like(charges)
    fragment_mask = np.ones_like(charges)
    linker_mask = np.zeros_like(charges)
    print('Read and parsed molecule')

    def _as_tensor(values):
        return torch.tensor(values, dtype=const.TORCH_FLOAT, device=device)

    # The same dict is repeated N_SAMPLES times; it is only read, never mutated.
    dataset = [{
        'uuid': '0',
        'name': '0',
        'positions': _as_tensor(positions),
        'one_hot': _as_tensor(one_hot),
        'charges': _as_tensor(charges),
        'anchors': _as_tensor(anchors),
        'fragment_mask': _as_tensor(fragment_mask),
        'linker_mask': _as_tensor(linker_mask),
        'num_atoms': len(positions),
    }] * N_SAMPLES
    dataloader = get_dataloader(dataset, batch_size=N_SAMPLES, collate_fn=collate_with_fragment_edges)
    print('Created dataloader')

    # NOTE(review): this mutates the shared global model; concurrent requests
    # with different step counts could interfere.
    ddpm.edm.T = int(n_steps)

    # batch_size equals len(dataset), so only the first batch is needed.
    data = next(iter(dataloader))
    chain, node_mask = ddpm.sample_chain(data, sample_fn=sample_fn, keep_frames=1)
    print('Generated linker')
    x = chain[0][:, :, :ddpm.n_dims]
    h = chain[0][:, :, ddpm.n_dims:]
    names = [f'output_{i+1}_{name}' for i in range(N_SAMPLES)]
    save_xyz_file('results', h, x, node_mask, names=names, is_geom=True, suffix='')
    print('Saved XYZ files')

    out_files = []
    for i in range(N_SAMPLES):
        out_xyz = f'results/output_{i+1}_{name}_.xyz'
        out_sdf = f'results/output_{i+1}_{name}_.sdf'
        # Argument list instead of a shell string: the file name is user-controlled
        # and must never be interpreted by a shell (injection, spaces).
        subprocess.run(['obabel', out_xyz, '-O', out_sdf])
        out_files.append(out_sdf)
    print('Converted to SDF')

    return [
        draw_sample(0, out_files),
        [inp_sdf] + out_files,
        gr.Radio.update(visible=True, value='Sample 1')
    ]


# Example inputs offered under the upload widget. The Dataset is created with
# type='index', so a click delivers the position of the example in this list.
EXAMPLE_FILES = ['examples/example_1.sdf', 'examples/example_2.sdf']


def _load_example(idx):
    """Return (input file, step count, preview HTML) for the clicked example ``idx``."""
    example_path = EXAMPLE_FILES[idx]
    return [example_path, 10, show_input(example_path)]


demo = gr.Blocks()
with demo:
    gr.Markdown('# DiffLinker: Equivariant 3D-Conditional Diffusion Model for Molecular Linker Design')
    gr.Markdown(
        'Given a set of disconnected fragments in 3D, '
        'DiffLinker places missing atoms in between and designs a molecule incorporating all the initial fragments. '
        'Our method can link an arbitrary number of fragments, requires no information on the attachment atoms '
        'and linker size, and can be conditioned on the protein pockets.'
    )
    gr.Markdown(
        '[**[Paper]**](https://arxiv.org/abs/2210.05274)    '
        '[**[Code]**](https://github.com/igashov/DiffLinker)'
    )
    with gr.Box():
        with gr.Row():
            with gr.Column():
                gr.Markdown('## Input Fragments')
                # The accepted formats are sdf, pdb, mol and mol2 (see show_input / generate).
                gr.Markdown('Upload the file with 3D-coordinates of the input fragments in .pdb, .mol, .mol2 or .sdf format:')
                input_file = gr.File(file_count='single', label='Input Fragments')
                n_steps = gr.Slider(minimum=10, maximum=500, label="Number of Diffusion Steps", step=10)
                examples = gr.Dataset(
                    components=[gr.File(visible=False)],
                    samples=[[example_path] for example_path in EXAMPLE_FILES],
                    type='index',
                )

                button = gr.Button('Generate Linker!')
                gr.Markdown('')
                gr.Markdown('## Output Files')
                gr.Markdown('Download files with the generated molecules here:')
                output_files = gr.File(file_count='multiple', label='Output Files')
            with gr.Column():
                gr.Markdown('## Visualization')
                # One choice per generated sample. It stays hidden until a
                # generation succeeds.
                samples = gr.Radio(
                    choices=[f'Sample {i + 1}' for i in range(N_SAMPLES)],
                    value='Sample 1',
                    type='index',
                    show_label=False,
                    visible=False,
                    interactive=True,
                )
                visualization = gr.HTML()

    # Preview the fragments as soon as a file is uploaded.
    input_file.change(
        fn=show_input,
        inputs=[input_file],
        outputs=[visualization],
    )
    # Clicking an example loads it, resets the step count and previews it.
    examples.click(
        fn=_load_example,
        inputs=[examples],
        outputs=[input_file, n_steps, visualization]
    )
    button.click(
        fn=generate,
        inputs=[input_file, n_steps],
        outputs=[visualization, output_files, samples],
    )
    # Re-render when another sample is selected; the index is zero-based.
    samples.change(
        fn=draw_sample,
        inputs=[samples, output_files],
        outputs=[visualization],
    )

demo.launch(server_name=args.ip)