turhancan97 committed
Commit 774a99f · 1 Parent(s): def8bfd

bottom 25 model added to visualize in a separate tab

app.py CHANGED
@@ -11,12 +11,17 @@ from torchvision.transforms import v2
 from model import MAE_ViT, MAE_Encoder, MAE_Decoder, MAE_Encoder_FeatureExtractor
 
 path = [['images/cat.jpg'], ['images/dog.jpg'], ['images/horse.jpg'], ['images/airplane.jpg'], ['images/truck.jpg']]
-model_name = "vit-t-mae-pretrain.pt"
-model = torch.load(model_name, map_location='cpu')
-
-model.eval()
 device = torch.device("cpu")
-model.to(device)
+
+model_name = "model/no_mode/vit-t-mae-pretrain.pt"
+model_no_mode = torch.load(model_name, map_location='cpu')
+model_no_mode.eval()
+model_no_mode.to(device)
+
+model_name = "model/bottom_25/vit-t-mae-pretrain.pt"
+model_pca_mode = torch.load(model_name, map_location='cpu')
+model_pca_mode.eval()
+model_pca_mode.to(device)
 
 transform = v2.Compose([
     v2.Resize((32, 32)),
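
Both checkpoints are saved as fully pickled modules, so `torch.load` returns the `nn.Module` itself; that is why the `MAE_*` classes must remain importable from `model.py`. A minimal sketch of a shared loader under that assumption (the `load_mae` name is hypothetical, not part of the commit):

```python
import torch

def load_mae(checkpoint_path: str, device: torch.device) -> torch.nn.Module:
    # torch.load unpickles the whole module; the MAE_* classes from
    # model.py must be importable for this to succeed.
    model = torch.load(checkpoint_path, map_location='cpu')
    model.eval()  # inference only
    return model.to(device)

# hypothetical usage mirroring the two loads in this commit
model_no_mode = load_mae("model/no_mode/vit-t-mae-pretrain.pt", device)
model_pca_mode = load_mae("model/bottom_25/vit-t-mae-pretrain.pt", device)
```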
@@ -39,13 +44,12 @@ def show_image(img, title):
     plt.title(title)
 
 # Visualize a Single Image
-def visualize_single_image(image_path):
+def visualize_single_image_no_mode(image_path):
     img = load_image(image_path, transform).to(device)
 
     # Run inference
-    model.eval()
     with torch.no_grad():
-        predicted_img, mask = model(img)
+        predicted_img, mask = model_no_mode(img)
 
     # Convert the tensor back to a displayable image
     # masked image
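
The compositing in this function assumes the model returns a reconstruction plus a binary mask broadcast to image shape (1 = masked patch, 0 = visible), which is the usual MAE convention; a hypothetical smoke test of that contract, not part of the commit:

```python
import torch

dummy = torch.randn(1, 3, 32, 32)  # matches the v2.Resize((32, 32)) transform
with torch.no_grad():
    predicted_img, mask = model_no_mode(dummy)
assert predicted_img.shape == dummy.shape
assert set(mask.unique().tolist()) <= {0.0, 1.0}  # binary: 1 = masked patch
```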
@@ -82,6 +86,47 @@ def visualize_single_image(image_path):
 
     return np.array(plt.imread("output.png"))
 
+def visualize_single_image_pca_mode(image_path):
+    img = load_image(image_path, transform).to(device)
+
+    # Run inference
+    with torch.no_grad():
+        predicted_img, mask = model_pca_mode(img)
+
+    # Convert the tensor back to a displayable image
+    # masked image
+    im_masked = img * (1 - mask)
+
+    # MAE reconstruction pasted with visible patches
+    im_paste = img * (1 - mask) + predicted_img * mask
+
+    # remove the batch dimension
+    im_masked = im_masked[0]
+    predicted_img = predicted_img[0]
+    im_paste = im_paste[0]
+
+    # make the plt figure larger
+    plt.figure(figsize=(18, 8))
+
+    plt.subplot(1, 4, 1)
+    show_image(img, "original")
+
+    plt.subplot(1, 4, 2)
+    show_image(im_masked, "masked")
+
+    plt.subplot(1, 4, 3)
+    show_image(predicted_img, "reconstruction")
+
+    plt.subplot(1, 4, 4)
+    show_image(im_paste, "reconstruction + visible")
+
+    plt.tight_layout()
+
+    # convert the plt figure to a numpy array
+    plt.savefig("output.png")
+
+    return np.array(plt.imread("output.png"))
+
 inputs_image = [
     gr.components.Image(type="filepath", label="Input Image"),
 ]
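
`visualize_single_image_pca_mode` duplicates `visualize_single_image_no_mode` except for the model it calls; a hypothetical refactor sketch (the `reconstruct*` names are not in the commit) that would keep one body for both tabs:

```python
from functools import partial

import torch

def reconstruct(image_path, model):
    # shared inference + compositing; only `model` differs between the two tabs
    img = load_image(image_path, transform).to(device)
    with torch.no_grad():
        predicted_img, mask = model(img)
    im_masked = img * (1 - mask)                        # blank out masked patches
    im_paste = img * (1 - mask) + predicted_img * mask  # visible patches + reconstruction
    return img[0], im_masked[0], predicted_img[0], im_paste[0]  # drop batch dim

# the plotting code from the diff could then consume either variant
reconstruct_no_mode = partial(reconstruct, model=model_no_mode)
reconstruct_pca_mode = partial(reconstruct, model=model_pca_mode)
```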
@@ -90,11 +135,26 @@ outputs_image = [
     gr.components.Image(type="numpy", label="Output Image"),
 ]
 
-gr.Interface(
-    fn=visualize_single_image,
+inference_no_mode = gr.Interface(
+    fn=visualize_single_image_no_mode,
     inputs=inputs_image,
     outputs=outputs_image,
     examples=path,
+    cache_examples=False,
     title="MAE-ViT Image Reconstruction",
     description="This is a demo of the MAE-ViT model for image reconstruction.",
-).launch()
+)
+
+inference_pca_mode = gr.Interface(
+    fn=visualize_single_image_pca_mode,
+    inputs=inputs_image,
+    outputs=outputs_image,
+    examples=path,
+    title="MAE-ViT Image Reconstruction",
+    description="This is a demo of the MAE-ViT model for image reconstruction.",
+)
+
+gr.TabbedInterface(
+    [inference_no_mode, inference_pca_mode],
+    tab_names=['Normal Mode', 'PCA Mode']
+).queue().launch()
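
Two notes on the Gradio wiring: `cache_examples=False` keeps Gradio from pre-running all five examples through the model at startup (only the first tab sets it), and `gr.TabbedInterface` takes the interface list plus `tab_names`. Roughly the same layout can be sketched with `gr.Blocks`, assuming a Gradio version where `Interface.render()` is available:

```python
import gradio as gr

# hypothetical Blocks equivalent of the TabbedInterface above
with gr.Blocks() as demo:
    with gr.Tab('Normal Mode'):
        inference_no_mode.render()
    with gr.Tab('PCA Mode'):
        inference_pca_mode.render()

demo.queue().launch()  # queue() serializes requests, as in the commit
```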
model/.DS_Store ADDED
Binary file (6.15 kB)
 
model/bottom_25/vit-t-mae-pretrain.pt ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:e00a227576d93ed464e783ba281866b7441a37123c649a55e648d6e5553b66b0
+size 28973864
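
The three added lines are a Git LFS pointer: the repository tracks only this stub, while the ~29 MB checkpoint lives in LFS storage. Until the blob is fetched (e.g. via `git lfs pull`), the `.pt` path on disk holds just that text, so `torch.load` on it would fail; a quick hypothetical check:

```python
# print the raw pointer; a real checkpoint would be binary, not three text lines
with open("model/bottom_25/vit-t-mae-pretrain.pt", "r", encoding="utf-8") as f:
    print(f.read())
```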
vit-t-mae-pretrain.pt → model/no_mode/vit-t-mae-pretrain.pt RENAMED
File without changes