igashov commited on
Commit
ff9d86b
1 Parent(s): 52bf9df
Files changed (1) hide show
  1. app.py +11 -9
app.py CHANGED
@@ -17,17 +17,19 @@ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
17
  os.makedirs("results", exist_ok=True)
18
  os.makedirs("models", exist_ok=True)
19
 
20
- subprocess.run(
21
- 'wget https://zenodo.org/record/7121300/files/geom_size_gnn.ckpt?download=1 -O models/geom_size_gnn.ckpt',
22
- shell=True
23
- )
 
24
  size_nn = SizeClassifier.load_from_checkpoint('models/geom_size_gnn.ckpt', map_location=device).eval().to(device)
25
  print('Loaded SizeGNN model')
26
 
27
- subprocess.run(
28
- 'wget https://zenodo.org/record/7121300/files/geom_difflinker.ckpt?download=1 -O models/geom_difflinker.ckpt',
29
- shell=True
30
- )
 
31
  ddpm = DDPM.load_from_checkpoint('models/geom_difflinker.ckpt', map_location=device).eval().to(device)
32
  print('Loaded diffusion model')
33
 
@@ -174,4 +176,4 @@ with demo:
174
  outputs=[visualization, output_files],
175
  )
176
 
177
- demo.launch()
 
17
  os.makedirs("results", exist_ok=True)
18
  os.makedirs("models", exist_ok=True)
19
 
20
+ size_gnn_path = 'models/geom_size_gnn.ckpt'
21
+ if not os.path.exists(size_gnn_path):
22
+ print('Downloading SizeGNN model...')
23
+ link = 'https://zenodo.org/record/7121300/files/geom_size_gnn.ckpt?download=1'
24
+ subprocess.run(f'wget {link} -O {size_gnn_path}', shell=True)
25
  size_nn = SizeClassifier.load_from_checkpoint('models/geom_size_gnn.ckpt', map_location=device).eval().to(device)
26
  print('Loaded SizeGNN model')
27
 
28
+ diffusion_path = 'models/geom_difflinker.ckpt'
29
+ if not os.path.exists(diffusion_path):
30
+ print('Downloading Diffusion model...')
31
+ link = 'https://zenodo.org/record/7121300/files/geom_difflinker.ckpt?download=1'
32
+ subprocess.run(f'wget {link} -O {diffusion_path}', shell=True)
33
  ddpm = DDPM.load_from_checkpoint('models/geom_difflinker.ckpt', map_location=device).eval().to(device)
34
  print('Loaded diffusion model')
35
 
 
176
  outputs=[visualization, output_files],
177
  )
178
 
179
+ demo.launch(share=True)