Variable number of samples
- app.py +49 -38
- output.py +2 -2
- src/generation.py +4 -6

app.py
CHANGED
@@ -13,7 +13,7 @@ from src import const
 from src.datasets import get_dataloader, collate_with_fragment_edges, parse_molecule, MOADDataset
 from src.lightning import DDPM
 from src.linker_size_lightning import SizeClassifier
-from src.generation import
+from src.generation import generate_linkers, try_to_convert_to_sdf, get_pocket
 from zipfile import ZipFile
 
 
@@ -125,7 +125,7 @@ def show_input(in_fragments, in_protein):
         vis = show_target(in_protein)
     elif in_fragments is not None and in_protein is not None:
         vis = show_fragments_and_target(in_fragments, in_protein)
-    return [vis, gr.
+    return [vis, gr.Dropdown.update(choices=[], value=None, visible=False), None]
 
 
 def show_fragments(in_fragments):
@@ -167,28 +167,25 @@ def clear_fragments_input(in_protein):
     vis = ''
     if in_protein is not None:
         vis = show_target(in_protein)
-    return [None, vis, gr.
+    return [None, vis, gr.Dropdown.update(choices=[], value=None, visible=False), None]
 
 
 def clear_protein_input(in_fragments):
     vis = ''
     if in_fragments is not None:
         vis = show_fragments(in_fragments)
-    return [None, vis, gr.
+    return [None, vis, gr.Dropdown.update(choices=[], value=None, visible=False), None]
 
 
 def click_on_example(example):
     fragment_fname, target_fname = example
     fragment_path = f'examples/{fragment_fname}' if fragment_fname != '' else None
     target_path = f'examples/{target_fname}' if target_fname != '' else None
-    return [fragment_path, target_path
+    return [fragment_path, target_path] + show_input(fragment_path, target_path)
 
 
-def draw_sample(
-    with_protein = (len(out_files) ==
-
-    if isinstance(idx, str):
-        idx = int(idx.strip().split(' ')[-1]) - 1
+def draw_sample(sample_path, out_files, num_samples):
+    with_protein = (len(out_files) == num_samples + 3)
 
     in_file = out_files[1]
     in_sdf = in_file if isinstance(in_file, str) else in_file.name
@@ -204,8 +201,7 @@ def draw_sample(idx, out_files):
     input_target_content = read_molecule_content(in_pdb)
     target_fmt = in_pdb.split('.')[-1]
 
-
-    out_sdf = out_file if isinstance(out_file, str) else out_file.name
+    out_sdf = sample_path if isinstance(sample_path, str) else sample_path.name
     generated_molecule_content = read_molecule_content(out_sdf)
     molecule_fmt = out_sdf.split('.')[-1]
 
@@ -237,17 +233,17 @@ def compress(output_fnames, name):
     return archive_path
 
 
-def generate(in_fragments, in_protein, n_steps, n_atoms,
+def generate(in_fragments, in_protein, n_steps, n_atoms, num_samples, selected_atoms):
     if in_fragments is None:
         return [None, None, None, None]
 
     if in_protein is None:
-        return generate_without_pocket(in_fragments, n_steps, n_atoms,
+        return generate_without_pocket(in_fragments, n_steps, n_atoms, num_samples, selected_atoms)
     else:
-        return generate_with_pocket(in_fragments, in_protein, n_steps, n_atoms,
+        return generate_with_pocket(in_fragments, in_protein, n_steps, n_atoms, num_samples, selected_atoms)
 
 
-def generate_without_pocket(input_file, n_steps, n_atoms,
+def generate_without_pocket(input_file, n_steps, n_atoms, num_samples, selected_atoms):
     # Parsing selected atoms (javascript output)
     selected_atoms = selected_atoms.strip()
     if selected_atoms == '':
@@ -310,8 +306,8 @@ def generate_without_pocket(input_file, n_steps, n_atoms, radio_samples, selected_atoms):
         'fragment_mask': torch.tensor(fragment_mask, dtype=const.TORCH_FLOAT, device=device),
         'linker_mask': torch.tensor(linker_mask, dtype=const.TORCH_FLOAT, device=device),
         'num_atoms': len(positions),
-    }] *
-    dataloader = get_dataloader(dataset, batch_size=
+    }] * num_samples
+    dataloader = get_dataloader(dataset, batch_size=num_samples, collate_fn=collate_with_fragment_edges)
     print('Created dataloader')
 
     ddpm.edm.T = n_steps
@@ -333,26 +329,33 @@ def generate_without_pocket(input_file, n_steps, n_atoms, radio_samples, selected_atoms):
 
     for data in dataloader:
         try:
-            generate_linkers(
+            generate_linkers(
+                ddpm=ddpm, data=data, num_samples=num_samples, sample_fn=sample_fn, name=name, with_pocket=False
+            )
         except Exception as e:
             e = str(e).replace('\'', '')
             error = f'Caught exception while generating linkers: {e}'
             msg = output.ERROR_FORMAT_MSG.format(message=error)
             return [output.IFRAME_TEMPLATE.format(html=msg), None, None, None]
 
-    out_files = try_to_convert_to_sdf(name)
+    out_files = try_to_convert_to_sdf(name, num_samples)
     out_files = [inp_sdf] + out_files
     out_files = [compress(out_files, name=name)] + out_files
+    choice = out_files[2]
 
     return [
-        draw_sample(
+        draw_sample(choice, out_files, num_samples),
         out_files,
-        gr.
+        gr.Dropdown.update(
+            choices=out_files[2:],
+            value=choice,
+            visible=True,
+        ),
         None
     ]
 
 
-def generate_with_pocket(in_fragments, in_protein, n_steps, n_atoms,
+def generate_with_pocket(in_fragments, in_protein, n_steps, n_atoms, num_samples, selected_atoms):
     # Parsing selected atoms (javascript output)
     selected_atoms = selected_atoms.strip()
     if selected_atoms == '':
@@ -443,11 +446,11 @@ def generate_with_pocket(in_fragments, in_protein, n_steps, n_atoms, radio_samples, selected_atoms):
         'fragment_mask': torch.tensor(fragment_mask, dtype=const.TORCH_FLOAT, device=device),
         'linker_mask': torch.tensor(linker_mask, dtype=const.TORCH_FLOAT, device=device),
         'num_atoms': len(positions),
-    }] *
+    }] * num_samples
     dataset = MOADDataset(data=dataset)
     ddpm.val_dataset = dataset
 
-    dataloader = get_dataloader(dataset, batch_size=
+    dataloader = get_dataloader(dataset, batch_size=num_samples, collate_fn=collate_with_fragment_edges)
     print('Created dataloader')
 
     ddpm.edm.T = n_steps
@@ -469,21 +472,28 @@ def generate_with_pocket(in_fragments, in_protein, n_steps, n_atoms, radio_samples, selected_atoms):
 
     for data in dataloader:
         try:
-            generate_linkers(
+            generate_linkers(
+                ddpm=ddpm, data=data, num_samples=num_samples, sample_fn=sample_fn, name=name, with_pocket=True
+            )
         except Exception as e:
             e = str(e).replace('\'', '')
             error = f'Caught exception while generating linkers: {e}'
             msg = output.ERROR_FORMAT_MSG.format(message=error)
             return [output.IFRAME_TEMPLATE.format(html=msg), None, None, None]
 
-    out_files = try_to_convert_to_sdf(name)
+    out_files = try_to_convert_to_sdf(name, num_samples)
     out_files = [inp_sdf, inp_pdb] + out_files
     out_files = [compress(out_files, name=name)] + out_files
+    choice = out_files[3]
 
     return [
-        draw_sample(
+        draw_sample(choice, out_files, num_samples),
        out_files,
-        gr.
+        gr.Dropdown.update(
+            choices=out_files[3:],
+            value=choice,
+            visible=True,
+        ),
         None
     ]
 
@@ -516,6 +526,7 @@ with demo:
                 label="Linker Size: DiffLinker will predict it if set to 0",
                 step=1
             )
+            n_samples = gr.Slider(minimum=5, maximum=50, label="Number of Samples", step=5)
             examples = gr.Dataset(
                 components=[gr.File(visible=False), gr.File(visible=False)],
                 samples=[
@@ -524,7 +535,6 @@ with demo:
                     ['examples/3hz1_fragments.sdf', 'examples/3hz1_protein.pdb'],
                     ['examples/5ou2_fragments.sdf', 'examples/5ou2_protein.pdb'],
                 ],
-                # headers=['Fragments', 'Target Protein'],
                 type='values',
             )
 
@@ -537,13 +547,14 @@ with demo:
         with gr.Column():
             gr.Markdown('## Visualization')
             gr.Markdown('**Hint:** click on atoms to select anchor points (optionally)')
-            samples = gr.
-                choices=[
-                value=
+            samples = gr.Dropdown(
+                choices=[],
+                value=None,
                 type='value',
-
+                multiselect=False,
                 visible=False,
                 interactive=True,
+                label='Samples'
             )
             visualization = gr.HTML()
 
@@ -570,17 +581,17 @@ with demo:
     examples.click(
         fn=click_on_example,
         inputs=[examples],
-        outputs=[input_fragments_file, input_protein_file,
+        outputs=[input_fragments_file, input_protein_file, visualization, samples, hidden]
     )
     button.click(
         fn=generate,
-        inputs=[input_fragments_file, input_protein_file, n_steps, n_atoms,
+        inputs=[input_fragments_file, input_protein_file, n_steps, n_atoms, n_samples, hidden],
        outputs=[visualization, output_files, samples, hidden],
         _js=output.RETURN_SELECTION_JS,
     )
-    samples.
+    samples.select(
         fn=draw_sample,
-        inputs=[samples, output_files],
+        inputs=[samples, output_files, n_samples],
         outputs=[visualization],
     )
     demo.load(_js=output.STARTUP_JS)
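
The index arithmetic above follows from how out_files is assembled: without a pocket it is [archive, input SDF, sample 1..num_samples], and with a pocket the input PDB is inserted as well, so generated samples start at index 2 or 3 and draw_sample can tell the two cases apart from the list length alone. A minimal sketch of that invariant (file names are placeholders):

# Sketch of the out_files layout assumed by draw_sample; names are made up.
num_samples = 5
samples = [f'results/output_{i + 1}_demo_.sdf' for i in range(num_samples)]

without_pocket = ['archive.zip', 'input.sdf'] + samples            # len == num_samples + 2
with_pocket = ['archive.zip', 'input.sdf', 'input.pdb'] + samples  # len == num_samples + 3

assert len(with_pocket) == num_samples + 3
assert len(without_pocket) != num_samples + 3
# Hence choices=out_files[2:] in generate_without_pocket and out_files[3:] in generate_with_pocket.
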
output.py
CHANGED
@@ -365,7 +365,7 @@ STARTUP_JS = """
 """
 
 RETURN_SELECTION_JS = """
-(input_file, input_protein_file, n_steps, n_atoms,
+(input_file, input_protein_file, n_steps, n_atoms, n_samples, hidden) => {
     let selected = []
     for (const [atom, add] of Object.entries(window.selected_elements)) {
         if (add) {
@@ -378,6 +378,6 @@ RETURN_SELECTION_JS = """
         }
     }
     console.log("Finished parsing");
-    return [input_file, input_protein_file, n_steps, n_atoms,
+    return [input_file, input_protein_file, n_steps, n_atoms, n_samples, selected.join(",")];
 }
 """
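
As used here, the _js hook and the Python callback stay in lockstep: the arrow function receives one value per component in the inputs list of button.click and must return the same number of values, with the last slot (the hidden textbox) replaced by the comma-separated atom selection that generate later parses. A stripped-down sketch of that contract, with a hard-coded selection standing in for the window.selected_elements loop:

# Pass-through version of the selection hook (hypothetical constant name;
# "1,5,9" stands in for the atom ids collected from window.selected_elements).
PASS_THROUGH_JS = """
(input_file, input_protein_file, n_steps, n_atoms, n_samples, hidden) => {
    return [input_file, input_protein_file, n_steps, n_atoms, n_samples, "1,5,9"];
}
"""
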
src/generation.py
CHANGED
@@ -9,10 +9,8 @@ from src.visualizer import save_xyz_file
 from src.utils import FoundNaNException
 from src.datasets import get_one_hot
 
-N_SAMPLES = 5
 
-
-def generate_linkers(ddpm, data, sample_fn, name, with_pocket=False):
+def generate_linkers(ddpm, data, num_samples, sample_fn, name, with_pocket=False):
     chain = node_mask = None
     for i in range(5):
         try:
@@ -39,14 +37,14 @@ def generate_linkers(ddpm, data, sample_fn, name, with_pocket=False):
     if with_pocket:
         node_mask[torch.where(data['pocket_mask'])] = 0
 
-    names = [f'output_{i + 1}_{name}' for i in range(N_SAMPLES)]
+    names = [f'output_{i + 1}_{name}' for i in range(num_samples)]
     save_xyz_file('results', h, x, node_mask, names=names, is_geom=True, suffix='')
     print('Saved XYZ files')
 
 
-def try_to_convert_to_sdf(name):
+def try_to_convert_to_sdf(name, num_samples):
     out_files = []
-    for i in range(N_SAMPLES):
+    for i in range(num_samples):
         out_xyz = f'results/output_{i + 1}_{name}_.xyz'
         out_sdf = f'results/output_{i + 1}_{name}_.sdf'
         subprocess.run(f'obabel {out_xyz} -O {out_sdf}', shell=True)