Ana Sanchez commited on
Commit
e9648db
·
1 Parent(s): e5b0269

Fix cuda bug

Browse files
Files changed (1) hide show
  1. app.py +3 -3
app.py CHANGED
@@ -52,7 +52,7 @@ def convert_models_to_fp32(model):
52
 
53
 
54
  def load(model_path, device, model, image_resolution):
55
- state_dict = torch.load(model_path, map_location="cpu")
56
  state_dict = state_dict["state_dict"]
57
 
58
  model_config_file = f"{model.replace('/', '-')}.json"
@@ -383,7 +383,7 @@ def molecules_from_image():
383
  predefined_features = False
384
  else:
385
  mol_index = pd.read_csv("cellpainting-unique-molecule.csv")
386
- mol_features_torch = torch.load("all_molecule_cellpainting_features.pkl")
387
  mol_features = mol_features_torch["mol_features"]
388
  mol_ids = mol_features_torch["mol_ids"]
389
  print(len(mol_ids))
@@ -453,7 +453,7 @@ def images_from_molecule():
453
  st.write("")
454
 
455
 
456
- img_features_torch = torch.load(image_features)
457
  img_features = img_features_torch["img_features"]
458
  img_ids = img_features_torch["img_ids"]
459
 
 
52
 
53
 
54
  def load(model_path, device, model, image_resolution):
55
+ state_dict = torch.load(model_path, map_location=device)
56
  state_dict = state_dict["state_dict"]
57
 
58
  model_config_file = f"{model.replace('/', '-')}.json"
 
383
  predefined_features = False
384
  else:
385
  mol_index = pd.read_csv("cellpainting-unique-molecule.csv")
386
+ mol_features_torch = torch.load("all_molecule_cellpainting_features.pkl", map_location=device)
387
  mol_features = mol_features_torch["mol_features"]
388
  mol_ids = mol_features_torch["mol_ids"]
389
  print(len(mol_ids))
 
453
  st.write("")
454
 
455
 
456
+ img_features_torch = torch.load(image_features, map_location=device)
457
  img_features = img_features_torch["img_features"]
458
  img_ids = img_features_torch["img_ids"]
459