RRoundTable commited on
Commit
950d956
·
1 Parent(s): fdbda3f
Files changed (1) hide show
  1. app.py +16 -5
app.py CHANGED
@@ -36,12 +36,16 @@ imgs_tensor = torch.zeros(4, 3, patch_h * 14, patch_w * 14)
36
  # PCA
37
  pca = PCA(n_components=3)
38
 
39
- def query_image(img1, img2, img3, img4) -> List[np.ndarray]:
 
 
 
 
40
 
41
  # Transform
42
  imgs = [img1, img2, img3, img4]
43
  for i, img in enumerate(imgs):
44
- img = np.transpose(img, (2, 0, 1))
45
  imgs_tensor[i] = transform(torch.Tensor(img))
46
 
47
  # Get feature from patches
@@ -56,7 +60,10 @@ def query_image(img1, img2, img3, img4) -> List[np.ndarray]:
56
  pca_feature = sklearn.preprocessing.minmax_scale(pca_features)
57
 
58
  # Foreground/Background
59
- pca_features_bg = pca_features[:, 0] < 0
 
 
 
60
  pca_features_fg = ~pca_features_bg
61
 
62
  # PCA with only foreground
@@ -79,10 +86,14 @@ DINOV2 PCA
79
  """
80
  demo = gr.Interface(
81
  query_image,
82
- inputs=[gr.Image(), gr.Image(), gr.Image(), gr.Image()],
83
  outputs=[gr.Image(), gr.Image(), gr.Image(), gr.Image()],
84
  title="DINOV2 PCA",
85
  description=description,
86
- examples=[],
 
 
 
 
87
  )
88
  demo.launch()
 
36
  # PCA
37
  pca = PCA(n_components=3)
38
 
39
+ def query_image(
40
+ img1, img2, img3, img4,
41
+ background_threshold,
42
+ is_foreground_larger_than_threshold,
43
+ ) -> List[np.ndarray]:
44
 
45
  # Transform
46
  imgs = [img1, img2, img3, img4]
47
  for i, img in enumerate(imgs):
48
+ img = np.transpose(img, (2, 0, 1)) / 255
49
  imgs_tensor[i] = transform(torch.Tensor(img))
50
 
51
  # Get feature from patches
 
60
  pca_feature = sklearn.preprocessing.minmax_scale(pca_features)
61
 
62
  # Foreground/Background
63
+ if is_foreground_larger_than_threshold:
64
+ pca_features_bg = pca_features[:, 0] < background_threshold
65
+ else:
66
+ pca_features_bg = pca_features[:, 0] > background_threshold
67
  pca_features_fg = ~pca_features_bg
68
 
69
  # PCA with only foreground
 
86
  """
87
  demo = gr.Interface(
88
  query_image,
89
+ inputs=[gr.Image(), gr.Image(), gr.Image(), gr.Image(), gr.Slider(-1, 1, value=0.1), gr.Checkbox(label="foreground is larger than threshold", value=True) ],
90
  outputs=[gr.Image(), gr.Image(), gr.Image(), gr.Image()],
91
  title="DINOV2 PCA",
92
  description=description,
93
+ examples=[
94
+ ["assets/1.png", "assets/2.png","assets/3.png","assets/4.png", 0.9, True],
95
+ ["assets/5.png", "assets/6.png","assets/7.png","assets/8.png", 0.6, True],
96
+ ["assets/9.png", "assets/10.png","assets/11.png","assets/12.png", 0.6, True],
97
+ ]
98
  )
99
  demo.launch()