Spaces:
Runtime error
Runtime error
import os | |
import time | |
import gradio as gr | |
import numpy as np | |
import torch | |
import torchvision.transforms as T | |
from typing import List | |
import glob | |
import cv2 | |
import pickle | |
import zipfile | |
import faiss | |
from examples import examples | |
# Torch Hub repository and entry point of the backbone used for embeddings.
DINOV2_REPO = "facebookresearch/dinov2"
DINOV2_MODEL = "dinov2_vitl14"
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Transforms
# The input resolution is expressed as a whole number of patches per side, so
# it is always an exact multiple of the patch size.
patch_height_nums = 40
patch_width_nums = 40
patch_size = 14
height = patch_height_nums * patch_size
width = patch_width_nums * patch_size
transform = T.Compose([
    # torchvision's Resize takes its size as (height, width).
    T.Resize((height, width)),
    T.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
])

# DINOV2
# Loaded through Torch Hub when the module is imported.
model = torch.hub.load(DINOV2_REPO, DINOV2_MODEL).to(DEVICE)
# The model is only ever used for inference in this file.
model.eval()

# faiss
# Number of nearest neighbours requested per query; also the number of
# result slots created in the UI.
K = 5
def read_image(image_path: str) -> np.ndarray:
    """Load an image file as an RGB array.

    Args:
        image_path: Path of the image file to read.

    Returns:
        The decoded colour image with its channels converted from OpenCV's
        BGR order to RGB.

    Raises:
        OSError: If OpenCV could not read or decode the file.
    """
    image = cv2.imread(image_path, cv2.IMREAD_COLOR)
    # cv2.imread reports failure by returning None rather than raising, which
    # would otherwise surface as an opaque error inside cvtColor.
    if image is None:
        raise OSError(f"Could not read image: {image_path!r}")
    return cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
def infer(image: np.ndarray) -> np.ndarray:
    """Compute the model embedding of a single image.

    Args:
        image: Image array with three axes ordered (height, width, channels).
            Values are assumed to be 8-bit (0-255), as produced by
            ``read_image`` -- TODO confirm for any other caller.

    Returns:
        The model output for a batch of one image, as a NumPy array on the CPU.
    """
    # Channels-first layout scaled to [0, 1]. This must be true division:
    # floor division (`// 255`) collapses every 8-bit pixel to 0 or 1.
    chw = np.transpose(image, (2, 0, 1)).astype(np.float32) / 255.0
    transformed_image = transform(torch.tensor(chw, dtype=torch.float32)).to(DEVICE)
    # Inference only, so there is no need to record the autograd graph.
    with torch.no_grad():
        embedding = model(torch.unsqueeze(transformed_image, 0))
    return embedding.detach().cpu().numpy()
def unzip(zipfile_path: str) -> str:
    """Extract the image files of a zip archive into a directory named after it.

    The output directory name is the archive's base name up to its first dot,
    and it is created relative to the current working directory. Members keep
    their path inside the archive, so nested folders are reproduced.

    Args:
        zipfile_path: Path of the zip archive.

    Returns:
        The name of the output directory. The directory is only created if at
        least one image member was extracted.
    """
    # Same set of image types that calculate_embedding accepts.
    image_suffixes = (".jpg", ".jpeg", ".png")
    output_dir = os.path.basename(zipfile_path).split(".")[0]
    with zipfile.ZipFile(zipfile_path, "r") as myzip:
        for name in myzip.namelist():
            # Case-insensitive so that e.g. "IMG_0001.JPG" is extracted too.
            if name.lower().endswith(image_suffixes):
                myzip.extract(name, output_dir)
    return output_dir
def calculate_embedding(
    zipfile,
) -> str:
    """Embed every image of an uploaded zip archive and pickle the results.

    Args:
        zipfile: Uploaded-file object; only its ``name`` attribute (the path
            of the archive) is read. Note that this parameter shadows the
            ``zipfile`` module inside this function.

    Returns:
        Path of the pickle file, which holds a list of
        ``(image_path, embedding)`` tuples.
    """
    filedir = unzip(zipfile.name)
    start = time.time()
    database = []
    # Extraction keeps the archive's folder structure, so search recursively;
    # sorting makes the database order reproducible across runs.
    pattern = os.path.join(filedir, "**", "*")
    for img_path in sorted(glob.glob(pattern, recursive=True)):
        extension = os.path.splitext(img_path)[1].lower()
        if extension not in (".jpg", ".png", ".jpeg"):
            continue
        database.append((img_path, infer(read_image(img_path))))
    print(f"Embedding Calculation: {time.time() - start}")
    filepath = "database.pickle"
    with open(filepath, "wb") as f:
        pickle.dump(database, f)
    return filepath
def instance_recognition(
    embedding_file,
    zipfile,
    image_path: str,
) -> list:
    """Find the K database images most similar to a query image.

    Args:
        embedding_file: Uploaded-file object whose ``name`` is the path of a
            pickle holding ``(image_path, embedding)`` pairs.
        zipfile: Uploaded-file object whose ``name`` is the path of the image
            archive. Presumably the same archive the embeddings were computed
            from, so the pickled paths resolve once it is extracted -- verify
            against the caller.
        image_path: Path of the query image.

    Returns:
        A flat list of ``2 * K`` items: first the K retrieved images, then the
        K matching inner-product scores. Both halves are in reverse search
        order. An image slot is ``None`` when the database holds fewer than K
        entries.

    Raises:
        ValueError: If the embedding database is empty.
    """
    # SECURITY: pickle.load can execute arbitrary code. This file comes from a
    # user upload, so it should only be used with trusted databases.
    with open(embedding_file.name, "rb") as f:
        embeddings = pickle.load(f)
    if not embeddings:
        raise ValueError("The embedding database is empty.")

    # Extract the archive so the image paths stored in the database exist.
    unzip(zipfile.name)

    image_paths = [img_path for img_path, _ in embeddings]
    # Each stored embedding carries a leading batch axis of size 1; drop it to
    # obtain one row per image.
    embedding_vectors = np.squeeze(
        np.array([embedding for _, embedding in embeddings]), axis=1
    )

    # Build an exact inner-product index over the database vectors.
    index = faiss.IndexFlatIP(embedding_vectors.shape[-1])
    index.add(embedding_vectors)

    # Embed the query image and search.
    image_embedding = infer(read_image(image_path))
    distances, indices = index.search(image_embedding, K)

    # faiss pads missing neighbours with index -1; without this check that
    # would silently resolve to the last database image.
    res = [read_image(image_paths[i]) if i >= 0 else None for i in indices[0]]
    return res[::-1] + distances[0].tolist()[::-1]
# Two-tab Gradio UI: one tab queries an existing embedding database, the other
# builds such a database from a zip archive of images.
with gr.Blocks() as demo:
    gr.Markdown("# Instance Recognition with DINOV2")
    with gr.Tab("Instance Recognition"):
        with gr.Row():
            infer_btn = gr.Button(value="Inference")
            image_embedding_file = gr.File(type="file", label="Image Embeddings")
            image_zip_file = gr.File(type="file", label="Image Zip File")
        with gr.Row():
            input_image = gr.Image(type="filepath", label="Input Image")
        # NOTE(review): the nesting of `distances` could not be recovered from
        # the flattened source; it is kept in the same row as the images --
        # confirm against the intended layout.
        with gr.Row():
            output_images = [
                gr.Image(label=f"Similar {i + 1}") for i in range(K)
            ]
            distances = [
                gr.Text(label=f"Similar {i + 1} Distances") for i in range(K)
            ]
        # instance_recognition returns K images followed by K scores, matching
        # the order of this output list.
        infer_btn.click(
            instance_recognition,
            inputs=[image_embedding_file, image_zip_file, input_image],
            outputs=output_images + distances,
        )
        gr.Examples(
            examples=examples,
            inputs=[
                image_embedding_file,
                image_zip_file,
                input_image,
            ],
            outputs=output_images + distances,
            fn=instance_recognition,
            run_on_click=True,
        )
    with gr.Tab("Image Embedding with database"):
        with gr.Row():
            embedding_btn = gr.Button(value="Image Embedding")
            # Distinct names so the first tab's components are not shadowed.
            database_zip_file = gr.File(type="file", label="Image Zip File")
            database_embedding_file = gr.File(
                type="binary", label="Image Embedding with DINOV2"
            )
        embedding_btn.click(
            calculate_embedding,
            inputs=database_zip_file,
            outputs=database_embedding_file,
        )

if __name__ == "__main__":
    # Bind to all interfaces so the app is reachable from outside the host.
    demo.launch(server_name="0.0.0.0")