RRoundTable commited on
Commit
65137b3
·
1 Parent(s): 878e93e
Files changed (1) hide show
  1. app.py +2 -2
app.py CHANGED
@@ -33,7 +33,7 @@ transform = T.Compose([
33
  height, width = patch_size * patch_height_nums, patch_size * patch_width_nums
34
 
35
  # DINOV2
36
- model = torch.hub.load(DINOV2_REPO, DINOV2_MODEL)
37
 
38
  # faiss
39
  K= 5
@@ -46,7 +46,7 @@ def read_image(image_path: str) -> np.ndarray:
46
 
47
  def infer(image: np.ndarray) -> np.ndarray:
48
  image = np.transpose(image, (2, 0, 1)) // 255
49
- transformed_image = transform(torch.Tensor(image))
50
  embedding = model(torch.unsqueeze(transformed_image, 0))
51
  return embedding.detach().cpu().numpy()
52
 
 
33
  height, width = patch_size * patch_height_nums, patch_size * patch_width_nums
34
 
35
  # DINOV2
36
+ model = torch.hub.load(DINOV2_REPO, DINOV2_MODEL).to(DEVICE)
37
 
38
  # faiss
39
  K= 5
 
46
 
47
  def infer(image: np.ndarray) -> np.ndarray:
48
  image = np.transpose(image, (2, 0, 1)) // 255
49
+ transformed_image = transform(torch.Tensor(image)).to(DEVICE)
50
  embedding = model(torch.unsqueeze(transformed_image, 0))
51
  return embedding.detach().cpu().numpy()
52