Ana Sanchez commited on
Commit
78d5260
·
1 Parent(s): 8dd5d91

Allow model selection

Browse files
app.py CHANGED
@@ -43,7 +43,6 @@ imgname = "I1"
43
 
44
  device = "cuda" if torch.cuda.is_available() else "cpu"
45
  model_type = "RN50"
46
- image_resolution = 520
47
 
48
  ######### CLOOME FUNCTIONS #########
49
  def convert_models_to_fp32(model):
@@ -410,9 +409,11 @@ def molecules_from_image(top_n):
410
  print(mol_probs.sum(dim=-1))
411
  print((top_probs, top_labels))
412
 
413
- def images_from_molecule(top_n):
414
  #st.markdown("Enter a query molecule in [SMILES](https://en.wikipedia.org/wiki/Simplified_molecular-input_line-entry_system) format",)
415
  smiles = st.text_input("Enter a query molecule in SMILES format", value="CC(=O)OC1=CC=CC=C1C(=O)O", placeholder="CC(=O)OC1=CC=CC=C1C(=O)O")
 
 
416
  if smiles:
417
  smiles = [smiles]
418
  morgan = [morgan_from_smiles(s) for s in smiles]
@@ -423,7 +424,7 @@ def images_from_molecule(top_n):
423
  fps_fname = save_hdf(morgan, molnames, molpath)
424
  mol_imgs = draw_molecules(smiles)
425
 
426
- mol_features, mol_ids = main(mol_index, MODEL_PATH, model_type, mol_path=molpath, image_resolution=image_resolution)
427
 
428
  col1, col2, col3 = st.columns(3)
429
 
@@ -494,5 +495,27 @@ n_objects = st.sidebar.selectbox(
494
  "How many objects would you like to retrieve?",
495
  ("5", "10", "20"))
496
 
497
- page_names_to_funcs[selected_page](n_objects)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
498
 
 
43
 
44
  device = "cuda" if torch.cuda.is_available() else "cpu"
45
  model_type = "RN50"
 
46
 
47
  ######### CLOOME FUNCTIONS #########
48
  def convert_models_to_fp32(model):
 
409
  print(mol_probs.sum(dim=-1))
410
  print((top_probs, top_labels))
411
 
412
+ def images_from_molecule(top_n, model_path):
413
  #st.markdown("Enter a query molecule in [SMILES](https://en.wikipedia.org/wiki/Simplified_molecular-input_line-entry_system) format",)
414
  smiles = st.text_input("Enter a query molecule in SMILES format", value="CC(=O)OC1=CC=CC=C1C(=O)O", placeholder="CC(=O)OC1=CC=CC=C1C(=O)O")
415
+
416
+
417
  if smiles:
418
  smiles = [smiles]
419
  morgan = [morgan_from_smiles(s) for s in smiles]
 
424
  fps_fname = save_hdf(morgan, molnames, molpath)
425
  mol_imgs = draw_molecules(smiles)
426
 
427
+ mol_features, mol_ids = main(mol_index, model_path, model_type, mol_path=molpath, image_resolution=image_resolution)
428
 
429
  col1, col2, col3 = st.columns(3)
430
 
 
495
  "How many objects would you like to retrieve?",
496
  ("5", "10", "20"))
497
 
498
+
499
+ selected_model = st.sidebar.selectbox(
500
+ "Select a CLOOME model to load",
501
+ ("CLOOME (default)", "CLOOME (batch size 128)", "CLOOME (fullres)", "CLOOME (imgres 320)"))
502
+
503
+ model_dict = {
504
+ "CLOOME (default)" : "cloome_default.pt",
505
+ "CLOOME (CLIP imgres 320)" : "cloome_cli_imres320.pt",
506
+ "CLOOME (fullres)" : "cloome_fullres.pt",
507
+ "CLOOME (CLOOB imgres 320)" : "cloome_imres320.pt"
508
+
509
+ }
510
+
511
+ model_file = model_dict[selected_model]
512
+ model_path = os.path.join(datapath, model_file)
513
+
514
+ if model_path.endswith("320.pt"):
515
+ image_resolution = 320
516
+ else:
517
+ image_resolution = 520
518
+
519
+
520
+ page_names_to_funcs[selected_page](n_objects, selected_model)
521
 
data/cloome_cli_imres320.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:91d042823294919768aae7bf2a489c6c6e0c9535b8adc1be19adea8fce03cb5a
3
+ size 352014131
data/cloome_default.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:39e0c98d47b18ce913f4bcb1a1bc89d26ca9938ee74646c32a46ad236cddbc38
3
+ size 352014131
data/cloome_fullres.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f412fe015caef7d9b9012a7ab8885081b089109abbbee76f5c5c609935151ec5
3
+ size 352013623
data/cloome_imres320.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:5eee478ca947ae9708d5432d427fb086e2141f1b0a81d6d9e8fe11fc4f29a1da
3
+ size 352013623