RRoundTable committed on
Commit
3d8379d
·
1 Parent(s): 2b3fda0
Files changed (1)
  1. app.py +150 -0
app.py ADDED
@@ -0,0 +1,150 @@
+ import os
+ import time
+ import gradio as gr
+ import numpy as np
+ import torch
+ import torchvision.transforms as T
+ from typing import List
+ import glob
+ import cv2
+ import pickle
+ import zipfile
+ import faiss
+
+
+ DINOV2_REPO = "facebookresearch/dinov2"
+ DINOV2_MODEL = "dinov2_vitl14"
+
+ DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+ # Transforms: resize to a multiple of the 14-pixel ViT patch size, then
+ # normalize with ImageNet statistics.
+ patch_height_nums = 40
+ patch_width_nums = 40
+ patch_size = 14
+ height = patch_height_nums * patch_size
+ width = patch_width_nums * patch_size
+
+ transform = T.Compose([
+     T.Resize((height, width)),
+     T.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
+ ])
+
+ # DINOV2
+ model = torch.hub.load(DINOV2_REPO, DINOV2_MODEL)
+ model = model.to(DEVICE).eval()
+
+
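+ # read_image / infer: load an image with OpenCV (BGR -> RGB) and run a single
+ # forward pass through DINOv2 to obtain its embedding vector.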
+ def read_image(image_path: str) -> np.ndarray:
+     image = cv2.imread(image_path, cv2.IMREAD_COLOR)
+     image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
+     return image
+
+
+ def infer(image: np.ndarray) -> np.ndarray:
+     # Scale to [0, 1] floats (integer division by 255 would zero out the image)
+     image = np.transpose(image, (2, 0, 1)) / 255.0
+     transformed_image = transform(torch.Tensor(image)).to(DEVICE)
+     with torch.no_grad():
+         embedding = model(torch.unsqueeze(transformed_image, 0))
+     return embedding.detach().cpu().numpy()
+
+
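+ # unzip: extract only the image files (.jpg / .jpeg / .png) from the uploaded
+ # archive into a local directory named after the zip file.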
+ def unzip(zipfile_path: str) -> str:
+     output_dir = zipfile_path.split("/")[-1].split(".")[0]
+     with zipfile.ZipFile(zipfile_path, "r") as myzip:
+         # Loop through each file in the zip archive
+         for name in myzip.namelist():
+             # Extract only image files to disk
+             if name.endswith((".jpg", ".jpeg", ".png")):
+                 myzip.extract(name, output_dir)
+     return output_dir
+
+
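+ # calculate_embedding: embed every image in the extracted directory and pickle
+ # the resulting (image_path, embedding) pairs so they can be downloaded and
+ # reused as the search database.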
+ def calculate_embedding(
+     zip_file,
+ ) -> str:
+     filedir = unzip(zip_file.name)
+     database = []
+     start = time.time()
+     for img_path in glob.glob(os.path.join(filedir, "*")):
+         if img_path.split(".")[-1] not in ["jpg", "png", "jpeg"]:
+             continue
+         image = read_image(img_path)
+         embedding = infer(image)
+         database.append((img_path, embedding))
+     print(f"Embedding Calculation: {time.time() - start:.2f}s")
+     filepath = "database.pkl"
+     with open(filepath, "wb") as f:
+         pickle.dump(database, f)
+     return filepath
+
+
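+ # instance_recognition: load the pickled database, index the embeddings with a
+ # FAISS inner-product index, and return the five most similar images together
+ # with their similarity scores for the query image.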
+ def instance_recognition(
+     embedding_filepath,
+     image_path: str,
+ ) -> List[np.ndarray]:
+     with open(embedding_filepath.name, "rb") as f:
+         embeddings = pickle.load(f)
+
+     embedding_vectors = []
+     image_paths = []
+     for img_path, embedding in embeddings:
+         embedding_vectors.append(embedding)
+         image_paths.append(img_path)
+
+     embedding_vectors = np.squeeze(np.array(embedding_vectors), axis=1).astype(np.float32)
+     d = embedding_vectors.shape[-1]
+
+     # Build a FAISS inner-product index over the database embeddings
+     index = faiss.IndexFlatIP(d)
+     index.add(embedding_vectors)
+
+     # Embed the query image
+     image = read_image(image_path)
+     image_embedding = infer(image).astype(np.float32)
+
+     # Search: IndexFlatIP returns the k nearest neighbors in descending
+     # similarity, so the first result is the best match
+     k = 5
+     distances, indices = index.search(image_embedding, k)
+
+     res = []
+     for i in indices[0]:
+         res.append(read_image(image_paths[i]))
+     return res + distances[0].tolist()
+
+
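+ # Gradio UI: the first tab builds the embedding database from a zip of images;
+ # the second tab queries that database with a single input image.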
+ with gr.Blocks() as demo:
+     gr.Markdown("# Instance Recognition with DINOV2")
+
+     with gr.Tab("Image Embedding with database"):
+         with gr.Row():
+             embedding_btn = gr.Button(value="Image Embedding")
+             image_zip_file = gr.File(type="file", label="Image Zip File")
+             image_embedding_file = gr.File(type="binary", label="Image Embedding with DINOV2")
+         embedding_btn.click(
+             calculate_embedding, inputs=image_zip_file, outputs=image_embedding_file,
+         )
+
+     with gr.Tab("Instance Recognition"):
+
+         with gr.Row():
+             infer_btn = gr.Button(value="Inference")
+             input_image = gr.Image(type="filepath", label="Input Image")
+             image_embedding_file = gr.File(type="file", label="Image Embeddings")
+
+         with gr.Row():
+             output_images = [
+                 gr.Image(label=f"Similar {i + 1}") for i in range(5)
+             ]
+             distances = [
+                 gr.Text(label=f"Similar {i + 1} Distance") for i in range(5)
+             ]
+
+         infer_btn.click(
+             instance_recognition,
+             inputs=[image_embedding_file, input_image],
+             outputs=output_images + distances,
+         )
+
+ if __name__ == "__main__":
+     demo.launch(server_name="0.0.0.0")