Spaces:
Sleeping
Sleeping
igashov
commited on
Commit
•
aa9b17f
1
Parent(s):
eb031b7
n_steps
Browse files
app.py
CHANGED
@@ -15,7 +15,6 @@ from src.lightning import DDPM
|
|
15 |
from src.linker_size_lightning import SizeClassifier
|
16 |
|
17 |
N_SAMPLES = 5
|
18 |
-
N_STEPS = 10
|
19 |
|
20 |
parser = argparse.ArgumentParser()
|
21 |
parser.add_argument('--ip', type=str, default=None)
|
@@ -39,7 +38,6 @@ if not os.path.exists(diffusion_path):
|
|
39 |
link = 'https://zenodo.org/record/7121300/files/geom_difflinker.ckpt?download=1'
|
40 |
subprocess.run(f'wget {link} -O {diffusion_path}', shell=True)
|
41 |
ddpm = DDPM.load_from_checkpoint('models/geom_difflinker.ckpt', map_location=device).eval().to(device)
|
42 |
-
ddpm.edm.T = N_STEPS
|
43 |
print('Loaded diffusion model')
|
44 |
|
45 |
|
@@ -111,7 +109,7 @@ def draw_sample(idx, out_files):
|
|
111 |
return output.IFRAME_TEMPLATE.format(html=html)
|
112 |
|
113 |
|
114 |
-
def generate(input_file):
|
115 |
if input_file is None:
|
116 |
return ''
|
117 |
|
@@ -155,6 +153,8 @@ def generate(input_file):
|
|
155 |
dataloader = get_dataloader(dataset, batch_size=N_SAMPLES, collate_fn=collate_with_fragment_edges)
|
156 |
print('Created dataloader')
|
157 |
|
|
|
|
|
158 |
for data in dataloader:
|
159 |
chain, node_mask = ddpm.sample_chain(data, sample_fn=sample_fn, keep_frames=1)
|
160 |
print('Generated linker')
|
@@ -188,7 +188,8 @@ with demo:
|
|
188 |
with gr.Column():
|
189 |
gr.Markdown('## Input Fragments')
|
190 |
gr.Markdown('Upload the file with 3D-coordinates of the input fragments in .pdb, .mol2 or .sdf format:')
|
191 |
-
input_file = gr.File(file_count='single', label='Input Fragments')
|
|
|
192 |
examples = gr.Dataset(
|
193 |
components=[gr.File(visible=False)],
|
194 |
samples=[['examples/example_1.sdf'], ['examples/example_2.sdf']],
|
@@ -219,13 +220,13 @@ with demo:
|
|
219 |
outputs=[visualization],
|
220 |
)
|
221 |
examples.click(
|
222 |
-
fn=lambda idx: [f'examples/example_{idx+1}.sdf', show_input(f'examples/example_{idx+1}.sdf')],
|
223 |
inputs=[examples],
|
224 |
-
outputs=[input_file, visualization]
|
225 |
)
|
226 |
button.click(
|
227 |
fn=generate,
|
228 |
-
inputs=[input_file],
|
229 |
outputs=[visualization, output_files, samples],
|
230 |
)
|
231 |
samples.change(
|
|
|
15 |
from src.linker_size_lightning import SizeClassifier
|
16 |
|
17 |
N_SAMPLES = 5
|
|
|
18 |
|
19 |
parser = argparse.ArgumentParser()
|
20 |
parser.add_argument('--ip', type=str, default=None)
|
|
|
38 |
link = 'https://zenodo.org/record/7121300/files/geom_difflinker.ckpt?download=1'
|
39 |
subprocess.run(f'wget {link} -O {diffusion_path}', shell=True)
|
40 |
ddpm = DDPM.load_from_checkpoint('models/geom_difflinker.ckpt', map_location=device).eval().to(device)
|
|
|
41 |
print('Loaded diffusion model')
|
42 |
|
43 |
|
|
|
109 |
return output.IFRAME_TEMPLATE.format(html=html)
|
110 |
|
111 |
|
112 |
+
def generate(input_file, n_steps):
|
113 |
if input_file is None:
|
114 |
return ''
|
115 |
|
|
|
153 |
dataloader = get_dataloader(dataset, batch_size=N_SAMPLES, collate_fn=collate_with_fragment_edges)
|
154 |
print('Created dataloader')
|
155 |
|
156 |
+
ddpm.edm.T = n_steps
|
157 |
+
|
158 |
for data in dataloader:
|
159 |
chain, node_mask = ddpm.sample_chain(data, sample_fn=sample_fn, keep_frames=1)
|
160 |
print('Generated linker')
|
|
|
188 |
with gr.Column():
|
189 |
gr.Markdown('## Input Fragments')
|
190 |
gr.Markdown('Upload the file with 3D-coordinates of the input fragments in .pdb, .mol2 or .sdf format:')
|
191 |
+
input_file = gr.File(file_count='single', label='Input Fragments')
|
192 |
+
n_steps = gr.Slider(minimum=10, maximum=500, label="Number of Diffusion Steps", step=10)
|
193 |
examples = gr.Dataset(
|
194 |
components=[gr.File(visible=False)],
|
195 |
samples=[['examples/example_1.sdf'], ['examples/example_2.sdf']],
|
|
|
220 |
outputs=[visualization],
|
221 |
)
|
222 |
examples.click(
|
223 |
+
fn=lambda idx: [f'examples/example_{idx+1}.sdf', 10, show_input(f'examples/example_{idx+1}.sdf')],
|
224 |
inputs=[examples],
|
225 |
+
outputs=[input_file, n_steps, visualization]
|
226 |
)
|
227 |
button.click(
|
228 |
fn=generate,
|
229 |
+
inputs=[input_file, n_steps],
|
230 |
outputs=[visualization, output_files, samples],
|
231 |
)
|
232 |
samples.change(
|