igashov commited on
Commit
aa9b17f
1 Parent(s): eb031b7
Files changed (1) hide show
  1. app.py +8 -7
app.py CHANGED
@@ -15,7 +15,6 @@ from src.lightning import DDPM
15
  from src.linker_size_lightning import SizeClassifier
16
 
17
  N_SAMPLES = 5
18
- N_STEPS = 10
19
 
20
  parser = argparse.ArgumentParser()
21
  parser.add_argument('--ip', type=str, default=None)
@@ -39,7 +38,6 @@ if not os.path.exists(diffusion_path):
39
  link = 'https://zenodo.org/record/7121300/files/geom_difflinker.ckpt?download=1'
40
  subprocess.run(f'wget {link} -O {diffusion_path}', shell=True)
41
  ddpm = DDPM.load_from_checkpoint('models/geom_difflinker.ckpt', map_location=device).eval().to(device)
42
- ddpm.edm.T = N_STEPS
43
  print('Loaded diffusion model')
44
 
45
 
@@ -111,7 +109,7 @@ def draw_sample(idx, out_files):
111
  return output.IFRAME_TEMPLATE.format(html=html)
112
 
113
 
114
- def generate(input_file):
115
  if input_file is None:
116
  return ''
117
 
@@ -155,6 +153,8 @@ def generate(input_file):
155
  dataloader = get_dataloader(dataset, batch_size=N_SAMPLES, collate_fn=collate_with_fragment_edges)
156
  print('Created dataloader')
157
 
 
 
158
  for data in dataloader:
159
  chain, node_mask = ddpm.sample_chain(data, sample_fn=sample_fn, keep_frames=1)
160
  print('Generated linker')
@@ -188,7 +188,8 @@ with demo:
188
  with gr.Column():
189
  gr.Markdown('## Input Fragments')
190
  gr.Markdown('Upload the file with 3D-coordinates of the input fragments in .pdb, .mol2 or .sdf format:')
191
- input_file = gr.File(file_count='single', label='Input Fragments')
 
192
  examples = gr.Dataset(
193
  components=[gr.File(visible=False)],
194
  samples=[['examples/example_1.sdf'], ['examples/example_2.sdf']],
@@ -219,13 +220,13 @@ with demo:
219
  outputs=[visualization],
220
  )
221
  examples.click(
222
- fn=lambda idx: [f'examples/example_{idx+1}.sdf', show_input(f'examples/example_{idx+1}.sdf')],
223
  inputs=[examples],
224
- outputs=[input_file, visualization]
225
  )
226
  button.click(
227
  fn=generate,
228
- inputs=[input_file],
229
  outputs=[visualization, output_files, samples],
230
  )
231
  samples.change(
 
15
  from src.linker_size_lightning import SizeClassifier
16
 
17
  N_SAMPLES = 5
 
18
 
19
  parser = argparse.ArgumentParser()
20
  parser.add_argument('--ip', type=str, default=None)
 
38
  link = 'https://zenodo.org/record/7121300/files/geom_difflinker.ckpt?download=1'
39
  subprocess.run(f'wget {link} -O {diffusion_path}', shell=True)
40
  ddpm = DDPM.load_from_checkpoint('models/geom_difflinker.ckpt', map_location=device).eval().to(device)
 
41
  print('Loaded diffusion model')
42
 
43
 
 
109
  return output.IFRAME_TEMPLATE.format(html=html)
110
 
111
 
112
+ def generate(input_file, n_steps):
113
  if input_file is None:
114
  return ''
115
 
 
153
  dataloader = get_dataloader(dataset, batch_size=N_SAMPLES, collate_fn=collate_with_fragment_edges)
154
  print('Created dataloader')
155
 
156
+ ddpm.edm.T = n_steps
157
+
158
  for data in dataloader:
159
  chain, node_mask = ddpm.sample_chain(data, sample_fn=sample_fn, keep_frames=1)
160
  print('Generated linker')
 
188
  with gr.Column():
189
  gr.Markdown('## Input Fragments')
190
  gr.Markdown('Upload the file with 3D-coordinates of the input fragments in .pdb, .mol2 or .sdf format:')
191
+ input_file = gr.File(file_count='single', label='Input Fragments')
192
+ n_steps = gr.Slider(minimum=10, maximum=500, label="Number of Diffusion Steps", step=10)
193
  examples = gr.Dataset(
194
  components=[gr.File(visible=False)],
195
  samples=[['examples/example_1.sdf'], ['examples/example_2.sdf']],
 
220
  outputs=[visualization],
221
  )
222
  examples.click(
223
+ fn=lambda idx: [f'examples/example_{idx+1}.sdf', 10, show_input(f'examples/example_{idx+1}.sdf')],
224
  inputs=[examples],
225
+ outputs=[input_file, n_steps, visualization]
226
  )
227
  button.click(
228
  fn=generate,
229
+ inputs=[input_file, n_steps],
230
  outputs=[visualization, output_files, samples],
231
  )
232
  samples.change(