igashov commited on
Commit
3c26059
1 Parent(s): aebc0d2

multiple samples

Browse files
Files changed (1) hide show
  1. app.py +20 -11
app.py CHANGED
@@ -14,6 +14,8 @@ from src.datasets import get_dataloader, collate_with_fragment_edges, parse_mole
14
  from src.lightning import DDPM
15
  from src.linker_size_lightning import SizeClassifier
16
 
 
 
17
  parser = argparse.ArgumentParser()
18
  parser.add_argument('--ip', type=str, default=None)
19
  args = parser.parse_args()
@@ -103,10 +105,8 @@ def generate(input_file):
103
  molecule = read_molecule(path)
104
  molecule = Chem.RemoveAllHs(molecule)
105
  name = '.'.join(path.split('/')[-1].split('.')[:-1])
106
- inp_sdf = f'results/{name}_input.sdf'
107
- inp_xyz = f'results/{name}_input.xyz'
108
- out_sdf = f'results/{name}_output.sdf'
109
- out_xyz = f'results/{name}_output.xyz'
110
  except Exception as e:
111
  return f'Could not read the molecule: {e}'
112
 
@@ -133,8 +133,8 @@ def generate(input_file):
133
  'fragment_mask': torch.tensor(fragment_mask, dtype=const.TORCH_FLOAT, device=device),
134
  'linker_mask': torch.tensor(linker_mask, dtype=const.TORCH_FLOAT, device=device),
135
  'num_atoms': len(positions),
136
- }]
137
- dataloader = get_dataloader(dataset, batch_size=1, collate_fn=collate_with_fragment_edges)
138
  print('Created dataloader')
139
 
140
  for data in dataloader:
@@ -142,12 +142,21 @@ def generate(input_file):
142
  print('Generated linker')
143
  x = chain[0][:, :, :ddpm.n_dims]
144
  h = chain[0][:, :, ddpm.n_dims:]
145
- save_xyz_file('results', h, x, node_mask, names=[name], is_geom=True, suffix='output')
146
- print('Saved XYZ file')
147
- subprocess.run(f'obabel {out_xyz} -O {out_sdf}', shell=True)
148
- print('Converted to SDF')
149
  break
150
 
 
 
 
 
 
 
 
 
 
 
151
  input_fragments_content = read_molecule_content(inp_sdf)
152
  generated_molecule_content = read_molecule_content(out_sdf)
153
  html = output.SAMPLES_RENDERING_TEMPLATE.format(
@@ -158,7 +167,7 @@ def generate(input_file):
158
  )
159
  return [
160
  output.IFRAME_TEMPLATE.format(html=html),
161
- [inp_sdf, inp_xyz, out_sdf, out_xyz],
162
  ]
163
 
164
 
 
14
  from src.lightning import DDPM
15
  from src.linker_size_lightning import SizeClassifier
16
 
17
+ N_SAMPLES = 5
18
+
19
  parser = argparse.ArgumentParser()
20
  parser.add_argument('--ip', type=str, default=None)
21
  args = parser.parse_args()
 
105
  molecule = read_molecule(path)
106
  molecule = Chem.RemoveAllHs(molecule)
107
  name = '.'.join(path.split('/')[-1].split('.')[:-1])
108
+ inp_sdf = f'results/input_{name}.sdf'
109
+ inp_xyz = f'results/input_{name}.xyz'
 
 
110
  except Exception as e:
111
  return f'Could not read the molecule: {e}'
112
 
 
133
  'fragment_mask': torch.tensor(fragment_mask, dtype=const.TORCH_FLOAT, device=device),
134
  'linker_mask': torch.tensor(linker_mask, dtype=const.TORCH_FLOAT, device=device),
135
  'num_atoms': len(positions),
136
+ }] * N_SAMPLES
137
+ dataloader = get_dataloader(dataset, batch_size=N_SAMPLES, collate_fn=collate_with_fragment_edges)
138
  print('Created dataloader')
139
 
140
  for data in dataloader:
 
142
  print('Generated linker')
143
  x = chain[0][:, :, :ddpm.n_dims]
144
  h = chain[0][:, :, ddpm.n_dims:]
145
+ names = [f'output_{i+1}_{name}' for i in range(N_SAMPLES)]
146
+ save_xyz_file('results', h, x, node_mask, names=names, is_geom=True, suffix='')
147
+ print('Saved XYZ files')
 
148
  break
149
 
150
+ out_files = []
151
+ for i in range(N_SAMPLES):
152
+ out_xyz = f'results/output_{i+1}_{name}_.xyz'
153
+ out_sdf = f'results/output_{i+1}_{name}_.sdf'
154
+ subprocess.run(f'obabel {out_xyz} -O {out_sdf}', shell=True)
155
+ out_files.append(out_xyz)
156
+ out_files.append(out_sdf)
157
+ print('Converted to SDF')
158
+
159
+ out_sdf = f'results/output_1_{name}_.sdf'
160
  input_fragments_content = read_molecule_content(inp_sdf)
161
  generated_molecule_content = read_molecule_content(out_sdf)
162
  html = output.SAMPLES_RENDERING_TEMPLATE.format(
 
167
  )
168
  return [
169
  output.IFRAME_TEMPLATE.format(html=html),
170
+ [inp_sdf, inp_xyz] + out_files,
171
  ]
172
 
173