Spaces:
Sleeping
Sleeping
Commit
·
774a99f
1
Parent(s):
def8bfd
bottom 25 model added to visualize in seperate tab
Browse files
app.py
CHANGED
@@ -11,12 +11,17 @@ from torchvision.transforms import v2
|
|
11 |
from model import MAE_ViT, MAE_Encoder, MAE_Decoder, MAE_Encoder_FeatureExtractor
|
12 |
|
13 |
path = [['images/cat.jpg'], ['images/dog.jpg'], ['images/horse.jpg'], ['images/airplane.jpg'], ['images/truck.jpg']]
|
14 |
-
model_name = "vit-t-mae-pretrain.pt"
|
15 |
-
model = torch.load(model_name, map_location='cpu')
|
16 |
-
|
17 |
-
model.eval()
|
18 |
device = torch.device("cpu")
|
19 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
20 |
|
21 |
transform = v2.Compose([
|
22 |
v2.Resize((32, 32)),
|
@@ -39,13 +44,12 @@ def show_image(img, title):
|
|
39 |
plt.title(title)
|
40 |
|
41 |
# Visualize a Single Image
|
42 |
-
def
|
43 |
img = load_image(image_path, transform).to(device)
|
44 |
|
45 |
# Run inference
|
46 |
-
model.eval()
|
47 |
with torch.no_grad():
|
48 |
-
predicted_img, mask =
|
49 |
|
50 |
# Convert the tensor back to a displayable image
|
51 |
# masked image
|
@@ -82,6 +86,47 @@ def visualize_single_image(image_path):
|
|
82 |
|
83 |
return np.array(plt.imread("output.png"))
|
84 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
85 |
inputs_image = [
|
86 |
gr.components.Image(type="filepath", label="Input Image"),
|
87 |
]
|
@@ -90,11 +135,26 @@ outputs_image = [
|
|
90 |
gr.components.Image(type="numpy", label="Output Image"),
|
91 |
]
|
92 |
|
93 |
-
gr.Interface(
|
94 |
-
fn=
|
95 |
inputs=inputs_image,
|
96 |
outputs=outputs_image,
|
97 |
examples=path,
|
|
|
98 |
title="MAE-ViT Image Reconstruction",
|
99 |
description="This is a demo of the MAE-ViT model for image reconstruction.",
|
100 |
-
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
11 |
from model import MAE_ViT, MAE_Encoder, MAE_Decoder, MAE_Encoder_FeatureExtractor
|
12 |
|
13 |
path = [['images/cat.jpg'], ['images/dog.jpg'], ['images/horse.jpg'], ['images/airplane.jpg'], ['images/truck.jpg']]
|
|
|
|
|
|
|
|
|
14 |
device = torch.device("cpu")
|
15 |
+
|
16 |
+
model_name = "model/no_mode/vit-t-mae-pretrain.pt"
|
17 |
+
model_no_mode = torch.load(model_name, map_location='cpu')
|
18 |
+
model_no_mode.eval()
|
19 |
+
model_no_mode.to(device)
|
20 |
+
|
21 |
+
model_name = "model/bottom_256/vit-t-mae-pretrain.pt"
|
22 |
+
model_pca_mode = torch.load(model_name, map_location='cpu')
|
23 |
+
model_pca_mode.eval()
|
24 |
+
model_pca_mode.to(device)
|
25 |
|
26 |
transform = v2.Compose([
|
27 |
v2.Resize((32, 32)),
|
|
|
44 |
plt.title(title)
|
45 |
|
46 |
# Visualize a Single Image
|
47 |
+
def visualize_single_image_no_mode(image_path):
|
48 |
img = load_image(image_path, transform).to(device)
|
49 |
|
50 |
# Run inference
|
|
|
51 |
with torch.no_grad():
|
52 |
+
predicted_img, mask = model_no_mode(img)
|
53 |
|
54 |
# Convert the tensor back to a displayable image
|
55 |
# masked image
|
|
|
86 |
|
87 |
return np.array(plt.imread("output.png"))
|
88 |
|
89 |
+
def visualize_single_image_pca_mode(image_path):
|
90 |
+
img = load_image(image_path, transform).to(device)
|
91 |
+
|
92 |
+
# Run inference
|
93 |
+
with torch.no_grad():
|
94 |
+
predicted_img, mask = model_pca_mode(img)
|
95 |
+
|
96 |
+
# Convert the tensor back to a displayable image
|
97 |
+
# masked image
|
98 |
+
im_masked = img * (1 - mask)
|
99 |
+
|
100 |
+
# MAE reconstruction pasted with visible patches
|
101 |
+
im_paste = img * (1 - mask) + predicted_img * mask
|
102 |
+
|
103 |
+
# remove the batch dimension
|
104 |
+
im_masked = im_masked[0]
|
105 |
+
predicted_img = predicted_img[0]
|
106 |
+
im_paste = im_paste[0]
|
107 |
+
|
108 |
+
# make the plt figure larger
|
109 |
+
plt.figure(figsize=(18, 8))
|
110 |
+
|
111 |
+
plt.subplot(1, 4, 1)
|
112 |
+
show_image(img, "original")
|
113 |
+
|
114 |
+
plt.subplot(1, 4, 2)
|
115 |
+
show_image(im_masked, "masked")
|
116 |
+
|
117 |
+
plt.subplot(1, 4, 3)
|
118 |
+
show_image(predicted_img, "reconstruction")
|
119 |
+
|
120 |
+
plt.subplot(1, 4, 4)
|
121 |
+
show_image(im_paste, "reconstruction + visible")
|
122 |
+
|
123 |
+
plt.tight_layout()
|
124 |
+
|
125 |
+
# convert the plt figure to a numpy array
|
126 |
+
plt.savefig("output.png")
|
127 |
+
|
128 |
+
return np.array(plt.imread("output.png"))
|
129 |
+
|
130 |
inputs_image = [
|
131 |
gr.components.Image(type="filepath", label="Input Image"),
|
132 |
]
|
|
|
135 |
gr.components.Image(type="numpy", label="Output Image"),
|
136 |
]
|
137 |
|
138 |
+
inference_no_mode = gr.Interface(
|
139 |
+
fn=visualize_single_image_no_mode,
|
140 |
inputs=inputs_image,
|
141 |
outputs=outputs_image,
|
142 |
examples=path,
|
143 |
+
cache_examples = False,
|
144 |
title="MAE-ViT Image Reconstruction",
|
145 |
description="This is a demo of the MAE-ViT model for image reconstruction.",
|
146 |
+
)
|
147 |
+
|
148 |
+
inference_pca_mode = gr.Interface(
|
149 |
+
fn=visualize_single_image_pca_mode,
|
150 |
+
inputs=inputs_image,
|
151 |
+
outputs=outputs_image,
|
152 |
+
examples=path,
|
153 |
+
title="MAE-ViT Image Reconstruction",
|
154 |
+
description="This is a demo of the MAE-ViT model for image reconstruction.",
|
155 |
+
)
|
156 |
+
|
157 |
+
gr.TabbedInterface(
|
158 |
+
[inference_no_mode, inference_pca_mode],
|
159 |
+
tab_names=['Normal Mode', 'PCA Mode']
|
160 |
+
).queue().launch()
|
model/.DS_Store
ADDED
Binary file (6.15 kB). View file
|
|
model/bottom_25/vit-t-mae-pretrain.pt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:e00a227576d93ed464e783ba281866b7441a37123c649a55e648d6e5553b66b0
|
3 |
+
size 28973864
|
vit-t-mae-pretrain.pt → model/no_mode/vit-t-mae-pretrain.pt
RENAMED
File without changes
|