"""Gradio demo for instance recognition with DINOv2 embeddings and Faiss.

The "Image Embedding with database" tab embeds every image in an uploaded zip
archive and returns a pickled database of ``(image_path, embedding)`` pairs.
The "Instance Recognition" tab loads such a database, re-extracts the matching
zip so the stored image paths resolve, and shows the ``K`` images whose
embeddings have the highest inner product with the query image's embedding.
"""
import glob
import os
import pickle
import tempfile
import time
import zipfile
from typing import List

import cv2
import faiss
import gradio as gr
import numpy as np
import torch
import torchvision.transforms as T

from examples import examples

DINOV2_REPO = "facebookresearch/dinov2"
DINOV2_MODEL = "dinov2_vitl14"
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Transforms
# Input side lengths are whole multiples of the patch size
# (40 patches x 14 px = 560 px per side).
patch_height_nums = 40
patch_width_nums = 40
patch_size = 14
height = patch_height_nums * patch_size
width = patch_width_nums * patch_size
transform = T.Compose([
    # torchvision's Resize takes (height, width).
    T.Resize((height, width)),
    # Expects float tensors already scaled to [0, 1].
    T.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
])

# Extensions treated as images, both when extracting and when embedding.
IMAGE_EXTENSIONS = (".jpg", ".jpeg", ".png")

# DINOV2
# eval() switches the module out of its default training mode for inference.
model = torch.hub.load(DINOV2_REPO, DINOV2_MODEL).to(DEVICE).eval()

# faiss: number of nearest neighbours returned per query.
K = 5


def _is_image_file(path: str) -> bool:
    """Return True if *path* ends with one of IMAGE_EXTENSIONS (case-insensitive)."""
    return path.lower().endswith(IMAGE_EXTENSIONS)


def read_image(image_path: str) -> np.ndarray:
    """Load the image at *image_path* as an RGB array of shape (H, W, 3).

    Raises:
        ValueError: if OpenCV cannot read the file.
    """
    image = cv2.imread(image_path, cv2.IMREAD_COLOR)
    if image is None:
        # cv2.imread signals failure by returning None rather than raising.
        raise ValueError(f"Could not read image: {image_path}")
    return cv2.cvtColor(image, cv2.COLOR_BGR2RGB)


@torch.no_grad()
def infer(image: np.ndarray) -> np.ndarray:
    """Return the DINOv2 embedding of one RGB image as a numpy array.

    Args:
        image: array of shape (H, W, 3) with values in [0, 255].

    Returns:
        The model output for a batch of one; the database code below expects
        shape (1, embedding_dim).
    """
    # HWC -> CHW float in [0, 1]. True division is required here: floor
    # division (``// 255``) would collapse every pixel to 0 or 1.
    tensor = torch.as_tensor(image, dtype=torch.float32).permute(2, 0, 1) / 255.0
    transformed_image = transform(tensor).to(DEVICE)
    embedding = model(torch.unsqueeze(transformed_image, 0))
    return embedding.cpu().numpy()


def unzip(zipfile_path: str) -> str:
    """Extract the image members of a zip archive; return the output directory.

    The directory is the archive's base name up to its first ``.``, created
    relative to the current working directory. Re-extracting the same archive
    therefore reproduces the same image paths, which the embedding database
    stores.
    """
    output_dir = os.path.basename(zipfile_path).split(".")[0]
    with zipfile.ZipFile(zipfile_path, "r") as archive:
        for name in archive.namelist():
            if _is_image_file(name):
                archive.extract(name, output_dir)
    return output_dir


def calculate_embedding(
    zipfile,
) -> str:
    """Embed every image in an uploaded zip archive and pickle the results.

    Args:
        zipfile: uploaded file object from ``gr.File(type="file")``; only its
            ``.name`` (a filesystem path) is used. This parameter shadows the
            ``zipfile`` module inside this function only.

    Returns:
        Path to a ``database.pickle`` file containing a list of
        ``(image_path, embedding)`` tuples.

    Raises:
        ValueError: if the archive contains no images.
    """
    filedir = unzip(zipfile.name)
    database = []
    start = time.time()
    # Recurse so archives that wrap their images in a folder are handled too;
    # sort for a deterministic database order.
    pattern = os.path.join(filedir, "**", "*")
    for img_path in sorted(glob.glob(pattern, recursive=True)):
        if not (os.path.isfile(img_path) and _is_image_file(img_path)):
            continue
        image = read_image(img_path)
        embedding = infer(image)
        database.append((img_path, embedding))
    print(f"Embedding Calculation: {time.time() - start}")
    if not database:
        raise ValueError(f"No .jpg/.jpeg/.png images found in {zipfile.name}")

    # A fresh directory per call keeps concurrent users from overwriting each
    # other's database, while keeping the same file name.
    filepath = os.path.join(tempfile.mkdtemp(), "database.pickle")
    with open(filepath, "wb") as f:
        pickle.dump(database, f)
    return filepath


def instance_recognition(
    embedding_file,
    zipfile,
    image_path: str,
) -> List[np.ndarray]:
    """Return the K database images most similar to the query image.

    Args:
        embedding_file: uploaded pickle produced by ``calculate_embedding``.
        zipfile: the uploaded zip archive the database was built from. It is
            re-extracted so the image paths stored in the database resolve.
        image_path: filesystem path of the query image.

    Returns:
        A flat list of ``2 * K`` values: ``K`` RGB images ordered from most to
        least similar, followed by their ``K`` inner-product scores. Slots
        beyond the size of the database are ``None``.

    Raises:
        ValueError: if the embedding database is empty.
    """
    # SECURITY(review): pickle.load can execute arbitrary code embedded in the
    # file, and this file is a user upload. Only load databases from trusted
    # sources, or move to a data-only format (e.g. np.savez + a path list).
    with open(embedding_file.name, "rb") as f:
        embeddings = pickle.load(f)
    if not embeddings:
        raise ValueError("The embedding database is empty.")
    unzip(zipfile.name)

    image_paths = [img_path for img_path, _ in embeddings]
    # Each stored embedding is expected to be (1, d); stacking gives (N, d).
    # Faiss requires a C-contiguous float32 matrix.
    embedding_vectors = np.ascontiguousarray(
        np.concatenate([embedding for _, embedding in embeddings], axis=0),
        dtype=np.float32,
    )
    d = embedding_vectors.shape[-1]

    # train faiss
    index = faiss.IndexFlatIP(d)
    index.add(embedding_vectors)

    # infer image
    query_embedding = np.ascontiguousarray(
        infer(read_image(image_path)), dtype=np.float32
    )

    # search
    distances, indices = index.search(query_embedding, K)

    # Faiss already returns hits in decreasing inner-product order, so no
    # reversal is needed. It pads with index -1 when the database holds
    # fewer than K vectors.
    images = []
    scores = []
    for i, score in zip(indices[0], distances[0]):
        if i < 0:
            images.append(None)
            scores.append(None)
        else:
            images.append(read_image(image_paths[i]))
            scores.append(float(score))
    return images + scores


with gr.Blocks() as demo:
    gr.Markdown("# Instance Recognition with DINOV2")
    with gr.Tab("Instance Recognition"):
        with gr.Row():
            infer_btn = gr.Button(value="Inference")
            image_embedding_file = gr.File(type="file", label="Image Embeddings")
            image_zip_file = gr.File(type="file", label="Image Zip File")
        with gr.Row():
            input_image = gr.Image(type="filepath", label="Input Image")
        with gr.Row():
            output_images = [
                gr.Image(label=f"Similar {i + 1}") for i in range(K)
            ]
            distance_texts = [
                gr.Text(label=f"Similar {i + 1} Distances") for i in range(K)
            ]
        infer_btn.click(
            instance_recognition,
            inputs=[image_embedding_file, image_zip_file, input_image],
            outputs=output_images + distance_texts,
        )
        gr.Examples(
            examples=examples,
            inputs=[
                image_embedding_file,
                image_zip_file,
                input_image,
            ],
            outputs=output_images + distance_texts,
            fn=instance_recognition,
            run_on_click=True,
        )
    with gr.Tab("Image Embedding with database"):
        with gr.Row():
            embedding_btn = gr.Button(value="Image Embedding")
            database_zip_file = gr.File(type="file", label="Image Zip File")
            database_embedding_file = gr.File(
                type="binary", label="Image Embedding with DINOV2"
            )
        embedding_btn.click(
            calculate_embedding,
            inputs=database_zip_file,
            outputs=database_embedding_file,
        )


if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0")