Spaces:
Runtime error
Runtime error
File size: 4,906 Bytes
3d8379d e5a474c 3d8379d 65137b3 3d8379d adb0128 3d8379d 65137b3 3d8379d 691eb11 3d8379d 691eb11 3d8379d 691eb11 3d8379d 691eb11 3d8379d adb0128 3d8379d 691eb11 df8e1c7 3d8379d adb0128 3d8379d adb0128 3d8379d 691eb11 3d8379d e5a474c 6956f4b 3d8379d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 |
import os
import time
import gradio as gr
import numpy as np
import torch
import torchvision.transforms as T
from typing import List
import glob
import cv2
import pickle
import zipfile
import faiss
from examples import examples
# torch.hub repository and entry point of the backbone used for embeddings.
DINOV2_REPO = "facebookresearch/dinov2"
DINOV2_MODEL = "dinov2_vitl14"
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Transforms
# Every image is resized to a whole number of patches per side before it is
# fed to the model (40 x 40 patches of 14 pixels -> 560 x 560 pixels).
patch_height_nums = 40
patch_width_nums = 40
patch_size = 14

height = patch_height_nums * patch_size
width = patch_width_nums * patch_size

transform = T.Compose([
    # torchvision's Resize takes its size as (height, width). The two values
    # are equal today, but the order matters as soon as they differ.
    T.Resize((height, width)),
    T.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
])

# DINOv2
# Loaded once at import time (torch.hub fetches the weights on first use).
model = torch.hub.load(DINOV2_REPO, DINOV2_MODEL).to(DEVICE)
# The model is only ever used for inference in this file.
model.eval()

# faiss
# Number of nearest neighbours retrieved and shown in the UI.
K = 5
def read_image(image_path: str) -> np.ndarray:
    """Load an image from disk and return it in RGB channel order.

    Args:
        image_path: Path of the image file to read.

    Returns:
        The decoded 3-channel image, converted from OpenCV's BGR order to RGB.

    Raises:
        ValueError: If OpenCV cannot read or decode the file.
    """
    image = cv2.imread(image_path, cv2.IMREAD_COLOR)
    # cv2.imread signals failure by returning None instead of raising, which
    # would otherwise surface as an opaque error inside cvtColor.
    if image is None:
        raise ValueError(f"Could not read image: {image_path!r}")
    return cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
def infer(image: np.ndarray) -> np.ndarray:
    """Compute the model embedding of a single image.

    Args:
        image: Image in height x width x channel layout. Assumed to hold
            values on a 0-255 scale, as produced by ``read_image``.

    Returns:
        The model output for a batch containing only this image, as a NumPy
        array on the CPU.

    Note:
        Pixel values are scaled with true division. Earlier versions used
        floor division (``// 255``), which turned every pixel into 0 or 1;
        embedding databases built that way are not comparable with
        embeddings produced here and should be regenerated.
    """
    # HWC -> CHW, then scale to [0, 1] as floating point.
    chw = np.transpose(image, (2, 0, 1)).astype(np.float32) / 255.0
    transformed_image = transform(torch.from_numpy(chw)).to(DEVICE)
    # Inference only: skip building the autograd graph.
    with torch.no_grad():
        embedding = model(torch.unsqueeze(transformed_image, 0))
    return embedding.cpu().numpy()
def unzip(zipfile_path: str) -> str:
    """Extract the image files of a zip archive into a directory named after it.

    Only members whose name ends in ``.jpg``, ``.jpeg`` or ``.png`` (in any
    letter case) are extracted. Members keep their relative path inside the
    archive, so images stored in sub-folders end up in sub-folders of the
    output directory.

    Args:
        zipfile_path: Path of the zip archive.

    Returns:
        The output directory, relative to the current working directory. Its
        name is the archive's file name up to the first dot.
    """
    # Keep the "up to the first dot" naming: image paths stored in existing
    # embedding databases start with this directory name.
    output_dir = os.path.basename(zipfile_path).split(".")[0]
    image_suffixes = (".jpg", ".jpeg", ".png")
    with zipfile.ZipFile(zipfile_path, "r") as archive:
        for name in archive.namelist():
            # str.endswith accepts a tuple; lower() makes the match
            # case-insensitive so e.g. "IMG_0001.JPG" is not dropped.
            if name.lower().endswith(image_suffixes):
                archive.extract(name, output_dir)
    return output_dir
def calculate_embedding(
    zipfile,
) -> str:
    """Build an embedding database from a zip archive of images.

    The archive is extracted with ``unzip``, every ``.jpg``/``.jpeg``/``.png``
    file found below the extraction directory is embedded with ``infer``, and
    the resulting list of ``(image_path, embedding)`` tuples is pickled.

    Args:
        zipfile: Uploaded file object; only its ``.name`` attribute (the path
            of the archive on disk) is used. Note that inside this function
            the parameter shadows the ``zipfile`` module.

    Returns:
        Path of the pickle file that was written (``database.pickle`` in the
        current working directory; overwritten on every call).
    """
    filedir = unzip(zipfile.name)
    database = []
    start = time.time()
    # unzip() keeps the archive's folder structure, so search recursively;
    # sorting makes the database order reproducible.
    pattern = os.path.join(filedir, "**", "*")
    for img_path in sorted(glob.glob(pattern, recursive=True)):
        extension = os.path.splitext(img_path)[1].lower()
        if extension not in (".jpg", ".png", ".jpeg"):
            continue
        try:
            image = read_image(img_path)
        except (cv2.error, ValueError) as err:
            # One unreadable file (e.g. a metadata stub with an image
            # extension) should not abort the whole database build.
            print(f"Skipping unreadable image {img_path}: {err}")
            continue
        database.append((img_path, infer(image)))
    print(f"Embedding Calculation: {time.time() - start}")
    filepath = "database.pickle"
    with open(filepath, "wb") as f:
        pickle.dump(database, f)
    return filepath
def instance_recognition(
    embedding_file,
    zipfile,
    image_path: str,
) -> list:
    """Retrieve the database images most similar to a query image.

    Args:
        embedding_file: Uploaded file object whose ``.name`` is the path of a
            pickle written by ``calculate_embedding``.
        zipfile: Uploaded file object whose ``.name`` is the path of the image
            archive. It is re-extracted so that the image paths stored in the
            database exist on disk (assumes it is the archive the database
            was built from).
        image_path: Path of the query image.

    Returns:
        A flat list of ``2 * K`` values: first ``K`` images, then the ``K``
        matching inner-product scores. Within each half the best match comes
        last. When the database holds fewer than ``K`` images, the leading
        slots of each half are ``None``.

    Raises:
        ValueError: If the embedding database is empty.
    """
    # SECURITY: pickle.load can execute arbitrary code. The file comes from a
    # user upload, so only run this with databases from a trusted source.
    with open(embedding_file.name, "rb") as f:
        embeddings = pickle.load(f)
    if not embeddings:
        raise ValueError("The embedding database is empty.")
    unzip(zipfile.name)

    image_paths = [img_path for img_path, _ in embeddings]
    # Each stored embedding has a leading batch axis of size 1; drop it to get
    # one row per image.
    embedding_vectors = np.squeeze(
        np.array([embedding for _, embedding in embeddings]), axis=1
    )
    d = embedding_vectors.shape[-1]

    # Build the faiss index (exact inner-product search).
    index = faiss.IndexFlatIP(d)
    index.add(embedding_vectors)

    # Embed the query image.
    image_embedding = infer(read_image(image_path))

    # Never ask for more neighbours than there are images: faiss fills the
    # missing slots with index -1, which would silently select the last image.
    top_k = min(K, len(image_paths))
    distances, indices = index.search(image_embedding, top_k)

    images = [read_image(image_paths[i]) for i in indices[0]]
    scores = distances[0].tolist()
    # The UI always expects K images followed by K scores.
    padding = [None] * (K - top_k)
    return padding + images[::-1] + padding + scores[::-1]
# NOTE(review): the nesting below was reconstructed from a copy of the file
# that had lost its indentation; confirm which components belong to each Row.
with gr.Blocks() as demo:
    gr.Markdown("# Instance Recogniton with DINOV2")

    # Tab 1: query an existing embedding database with a single image.
    with gr.Tab("Instance Recognition"):
        with gr.Row():
            infer_btn = gr.Button(value="Inference")
            image_embedding_file = gr.File(type="file", label="Image Embeddings")
            image_zip_file = gr.File(type="file", label="Image Zip File")
        with gr.Row():
            input_image = gr.Image(type="filepath", label="Input Image")
        with gr.Row():
            output_images = [
                gr.Image(label=f"Similar {i + 1}") for i in range(K)
            ]
            distances = [
                gr.Text(label=f"Similar {i + 1} Distances") for i in range(K)
            ]

        # The button and the examples share the same wiring:
        # K images followed by K scores come back from instance_recognition.
        recognition_inputs = [image_embedding_file, image_zip_file, input_image]
        recognition_outputs = output_images + distances

        infer_btn.click(
            instance_recognition,
            inputs=recognition_inputs,
            outputs=recognition_outputs,
        )
        gr.Examples(
            examples=examples,
            inputs=recognition_inputs,
            outputs=recognition_outputs,
            fn=instance_recognition,
            run_on_click=True,
        )

    # Tab 2: build an embedding database from a zip archive of images.
    # The two names below are deliberately re-bound; the first tab's
    # components were already wired up above.
    with gr.Tab("Image Embedding with database"):
        with gr.Row():
            embedding_btn = gr.Button(value="Image Embedding")
            image_zip_file = gr.File(type="file", label="Image Zip File")
            image_embedding_file = gr.File(
                type="binary", label="Image Embedding with DINOV2"
            )
        embedding_btn.click(
            calculate_embedding,
            inputs=image_zip_file,
            outputs=image_embedding_file,
        )
# Script entry point: start the Gradio server only when run directly.
if __name__ == "__main__":
    # "0.0.0.0" listens on every network interface rather than localhost only.
    demo.launch(
        server_name="0.0.0.0",
    )
|