import argparse
import os
import subprocess

import gradio as gr
import numpy as np
import torch
from rdkit import Chem

import output
from src import const
from src.datasets import collate_with_fragment_edges, get_dataloader, parse_molecule
from src.lightning import DDPM
from src.linker_size_lightning import SizeClassifier
from src.visualizer import save_xyz_file

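# Number of linker candidates generated for each input.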
N_SAMPLES = 5

parser = argparse.ArgumentParser()
parser.add_argument('--ip', type=str, default=None)
args = parser.parse_args()

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
os.makedirs("results", exist_ok=True)
os.makedirs("models", exist_ok=True)

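# Download pretrained checkpoints from Zenodo if they are not already cached in models/.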
size_gnn_path = 'models/geom_size_gnn.ckpt'
if not os.path.exists(size_gnn_path):
    print('Downloading SizeGNN model...')
    link = 'https://zenodo.org/record/7121300/files/geom_size_gnn.ckpt?download=1'
    subprocess.run(f'wget {link} -O {size_gnn_path}', shell=True)
size_nn = SizeClassifier.load_from_checkpoint(size_gnn_path, map_location=device).eval().to(device)
print('Loaded SizeGNN model')

diffusion_path = 'models/geom_difflinker.ckpt'
if not os.path.exists(diffusion_path):
    print('Downloading Diffusion model...')
    link = 'https://zenodo.org/record/7121300/files/geom_difflinker.ckpt?download=1'
    subprocess.run(f'wget {link} -O {diffusion_path}', shell=True)
ddpm = DDPM.load_from_checkpoint(diffusion_path, map_location=device).eval().to(device)
print('Loaded diffusion model')


def sample_fn(_data):
    # Predict a categorical distribution over linker sizes and sample one size per graph in the batch.
    logits, _ = size_nn.forward(_data, return_loss=False)
    probabilities = torch.softmax(logits, dim=1)
    distribution = torch.distributions.Categorical(probs=probabilities)
    samples = distribution.sample()
    sizes = [size_nn.linker_id2size[label] for label in samples.detach().cpu().numpy()]
    return torch.tensor(sizes, device=samples.device, dtype=torch.long)


def read_molecule_content(path):
    with open(path, "r") as f:
        return f.read()


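# Load a molecule from a .pdb, .mol, .mol2 or .sdf file without sanitization and with hydrogens removed.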
def read_molecule(path):
    if path.endswith('.pdb'):
        return Chem.MolFromPDBFile(path, sanitize=False, removeHs=True)
    elif path.endswith('.mol'):
        return Chem.MolFromMolFile(path, sanitize=False, removeHs=True)
    elif path.endswith('.mol2'):
        return Chem.MolFromMol2File(path, sanitize=False, removeHs=True)
    elif path.endswith('.sdf'):
        return Chem.SDMolSupplier(path, sanitize=False, removeHs=True)[0]
    raise ValueError('Unknown file extension')


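# Render the uploaded fragments in the embedded 3D viewer.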
def show_input(input_file):
    if input_file is None:
        return ''
    if isinstance(input_file, str):
        path = input_file
    else:
        path = input_file.name
    extension = path.split('.')[-1]
    if extension not in ['sdf', 'pdb', 'mol', 'mol2']:
        msg = output.INVALID_FORMAT_MSG.format(extension=extension)
        return output.IFRAME_TEMPLATE.format(html=msg)

    try:
        molecule = read_molecule_content(path)
    except Exception as e:
        return f'Could not read the molecule: {e}'

    html = output.INITIAL_RENDERING_TEMPLATE.format(molecule=molecule, fmt=extension)
    return output.IFRAME_TEMPLATE.format(html=html)


def draw_sample(idx, output_files):
    # output_files[0] is the input fragments SDF; generated samples follow in radio-choice order.
    fragments_content = read_molecule_content(output_files[0].name)
    molecule_content = read_molecule_content(output_files[idx + 1].name)
    html = output.SAMPLES_RENDERING_TEMPLATE.format(
        fragments=fragments_content, fragments_fmt='sdf',
        molecule=molecule_content, molecule_fmt='sdf',
    )
    return output.IFRAME_TEMPLATE.format(html=html)


def generate(input_file):
    if input_file is None:
        return ['', None, None]

    path = input_file if isinstance(input_file, str) else input_file.name
    extension = path.split('.')[-1]
    if extension not in ['sdf', 'pdb', 'mol', 'mol2']:
        msg = output.INVALID_FORMAT_MSG.format(extension=extension)
        return [output.IFRAME_TEMPLATE.format(html=msg), None, None]

    try:
        molecule = read_molecule(path)
        molecule = Chem.RemoveAllHs(molecule)
        name = '.'.join(path.split('/')[-1].split('.')[:-1])
        inp_sdf = f'results/input_{name}.sdf'
        inp_xyz = f'results/input_{name}.xyz'
    except Exception as e:
        return [f'Could not read the molecule: {e}', None, None]

    if molecule.GetNumAtoms() > 50:
        return ['Molecule is too large: the upper limit is 50 heavy atoms', None, None]

    with Chem.SDWriter(inp_sdf) as w:
        w.write(molecule)
    Chem.MolToXYZFile(molecule, inp_xyz)

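    # Treat the whole input as fragment context: no anchors and an empty linker mask; linker sizes are sampled by the SizeGNN.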
    positions, one_hot, charges = parse_molecule(molecule, is_geom=True)
    anchors = np.zeros_like(charges)
    fragment_mask = np.ones_like(charges)
    linker_mask = np.zeros_like(charges)
    print('Read and parsed molecule')

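    # Replicate the parsed molecule N_SAMPLES times so the diffusion model generates several linkers in one batch.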
    dataset = [{
        'uuid': '0',
        'name': '0',
        'positions': torch.tensor(positions, dtype=const.TORCH_FLOAT, device=device),
        'one_hot': torch.tensor(one_hot, dtype=const.TORCH_FLOAT, device=device),
        'charges': torch.tensor(charges, dtype=const.TORCH_FLOAT, device=device),
        'anchors': torch.tensor(anchors, dtype=const.TORCH_FLOAT, device=device),
        'fragment_mask': torch.tensor(fragment_mask, dtype=const.TORCH_FLOAT, device=device),
        'linker_mask': torch.tensor(linker_mask, dtype=const.TORCH_FLOAT, device=device),
        'num_atoms': len(positions),
    }] * N_SAMPLES
    dataloader = get_dataloader(dataset, batch_size=N_SAMPLES, collate_fn=collate_with_fragment_edges)
    print('Created dataloader')

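    # Generate linkers with the diffusion model; only a single frame of the sampling chain is kept.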
    for data in dataloader:
        chain, node_mask = ddpm.sample_chain(data, sample_fn=sample_fn, keep_frames=1)
        print('Generated linker')
        x = chain[0][:, :, :ddpm.n_dims]
        h = chain[0][:, :, ddpm.n_dims:]
        names = [f'output_{i+1}_{name}' for i in range(N_SAMPLES)]
        save_xyz_file('results', h, x, node_mask, names=names, is_geom=True, suffix='')
        print('Saved XYZ files')
        break

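    # Convert the generated XYZ files to SDF with Open Babel.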
    out_files = []
    for i in range(N_SAMPLES):
        out_xyz = f'results/output_{i+1}_{name}_.xyz'
        out_sdf = f'results/output_{i+1}_{name}_.sdf'
        subprocess.run(f'obabel {out_xyz} -O {out_sdf}', shell=True)
        out_files.append(out_sdf)
    print('Converted to SDF')

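    # Show the input fragments together with the first generated sample in the 3D viewer.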
    out_sdf = f'results/output_1_{name}_.sdf'
    input_fragments_content = read_molecule_content(inp_sdf)
    generated_molecule_content = read_molecule_content(out_sdf)
    html = output.SAMPLES_RENDERING_TEMPLATE.format(
        fragments=input_fragments_content,
        fragments_fmt='sdf',
        molecule=generated_molecule_content,
        molecule_fmt='sdf',
    )
    return [
        output.IFRAME_TEMPLATE.format(html=html),
        [inp_sdf] + out_files,
        gr.Radio.update(
            choices=['Sample 1', 'Sample 2', 'Sample 3', 'Sample 4', 'Sample 5'],
            value='Sample 1',
        )
    ]


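# Gradio UI: fragment upload and examples on the left, 3D visualization and a sample selector on the right.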
demo = gr.Blocks()
with demo:
    gr.Markdown('# DiffLinker: Equivariant 3D-Conditional Diffusion Model for Molecular Linker Design')
    with gr.Box():
        with gr.Row():
            with gr.Column():
                gr.Markdown('## Input Fragments')
                gr.Markdown('Upload a file with the 3D coordinates of the input fragments in .pdb, .mol, .mol2 or .sdf format:')
                input_file = gr.File(file_count='single', label='Input Fragments')
                examples = gr.Dataset(
                    components=[gr.File(visible=False)],
                    samples=[['examples/example_1.sdf'], ['examples/example_2.sdf']],
                    type='index',
                ) 

                button = gr.Button('Generate Linker!')
                gr.Markdown('')
                gr.Markdown('## Output Files')
                gr.Markdown('Download files with the generated molecules here:')
                output_files = gr.File(file_count='multiple', label='Output Files')
            with gr.Column():
                gr.Markdown('## Visualization')
                visualization = gr.HTML()
                samples = gr.Radio(interactive=True, type='index', label='Samples')

    input_file.change(
        fn=show_input,
        inputs=[input_file],
        outputs=[visualization],
    )
    button.click(
        fn=generate,
        inputs=[input_file],
        outputs=[visualization, output_files, samples],
    )
    examples.click(
        fn=lambda idx: [f'examples/example_{idx+1}.sdf', show_input(f'examples/example_{idx+1}.sdf')],
        inputs=[examples],
        outputs=[input_file, visualization]
    )
    samples.change(
        fn=draw_sample,
        inputs=[samples, output_files],
        outputs=[visualization],
    )

demo.launch(server_name=args.ip)