# RRoundTable
# to device
# 65137b3
import os
import time
import gradio as gr
import numpy as np
import torch
import torchvision.transforms as T
from typing import List
import glob
import cv2
import pickle
import zipfile
import faiss
from examples import examples
DINOV2_REPO = "facebookresearch/dinov2"
DINOV2_MODEL = "dinov2_vitl14"
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Transforms
# The network input is sized as a whole number of patches per side.
patch_height_nums = 40
patch_width_nums = 40
patch_size = 14
height = patch_height_nums * patch_size
width = patch_width_nums * patch_size
transform = T.Compose([
    # torchvision's Resize takes its size as (height, width), in that order.
    T.Resize((height, width)),
    # Per-channel mean/std (the standard ImageNet statistics); they assume
    # inputs already scaled to [0, 1].
    T.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
])

# DINOV2
# Loaded once at import time. eval() switches the module to inference mode so
# that any train-only behaviour is disabled while computing embeddings.
model = torch.hub.load(DINOV2_REPO, DINOV2_MODEL).to(DEVICE).eval()

# faiss
# Number of nearest neighbours requested per search; the UI creates one
# image slot and one distance slot per neighbour.
K = 5
def read_image(image_path: str) -> np.ndarray:
    """Load an image file from disk as an RGB array.

    Args:
        image_path: Path of the image file to read.

    Returns:
        The decoded 3-channel image with its channels in RGB order.

    Raises:
        ValueError: If OpenCV cannot open or decode the file.
    """
    image = cv2.imread(image_path, cv2.IMREAD_COLOR)
    # cv2.imread signals failure by returning None rather than raising;
    # fail here with the offending path instead of deep inside cvtColor.
    if image is None:
        raise ValueError(f"Could not read image: {image_path!r}")
    # OpenCV decodes colour images as BGR; the rest of the file works in RGB.
    return cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
def infer(image: np.ndarray) -> np.ndarray:
    """Compute the model embedding of a single image.

    Args:
        image: Channel-last image array (as returned by ``read_image``)
            with pixel values in the 0-255 range.

    Returns:
        The model output for a batch containing only this image, moved to
        the CPU and converted to a NumPy array.
    """
    # Channel-last -> channel-first and scale to [0, 1]. This must be true
    # division: floor division (`// 255`) maps every pixel below 255 to 0,
    # which destroys the image before normalisation.
    # NOTE: embeddings pickled by a version that used floor division are not
    # comparable with the ones produced here; regenerate such databases.
    scaled = np.transpose(image, (2, 0, 1)) / 255.0
    transformed_image = transform(torch.Tensor(scaled)).to(DEVICE)
    # Inference only: skip autograd bookkeeping to save memory and time.
    with torch.no_grad():
        embedding = model(torch.unsqueeze(transformed_image, 0))
    return embedding.detach().cpu().numpy()
def unzip(zipfile_path: str) -> str:
    """Extract the image files of a zip archive into a local directory.

    The output directory is named after the archive: the part of its file
    name before the first ".", relative to the current working directory.

    Args:
        zipfile_path: Path of the zip archive on disk.

    Returns:
        The name of the directory the images were extracted into.
    """
    output_dir = zipfile_path.split("/")[-1].split(".")[0]
    # Same image types that calculate_embedding accepts, matched
    # case-insensitively so that e.g. "IMG_001.JPG" is not silently dropped.
    image_suffixes = (".jpg", ".jpeg", ".png")
    with zipfile.ZipFile(zipfile_path, "r") as archive:
        for name in archive.namelist():
            # Non-image members (and directory entries) are left in the archive.
            if name.lower().endswith(image_suffixes):
                archive.extract(name, output_dir)
    return output_dir
def calculate_embedding(
    zipfile,
) -> str:
    """Embed every image of an uploaded zip archive and pickle the result.

    Args:
        zipfile: Uploaded file object; only its ``name`` attribute (the path
            of the archive on disk) is used. The parameter name shadows the
            ``zipfile`` module inside this function only.

    Returns:
        Path of the pickle file, which holds a list of
        ``(image_path, embedding)`` tuples.
    """
    filedir = unzip(zipfile.name)
    image_extensions = {"jpg", "png", "jpeg"}
    start = time.time()
    # sorted(): glob order depends on the filesystem; sorting makes the
    # database order reproducible. The extension test is case-insensitive.
    image_paths = [
        img_path
        for img_path in sorted(glob.glob(os.path.join(filedir, "*")))
        if img_path.split(".")[-1].lower() in image_extensions
    ]
    database = [
        (img_path, infer(read_image(img_path))) for img_path in image_paths
    ]
    print(f"Embedding Calculation: {time.time() - start}")
    # NOTE(review): fixed file name in the working directory, so concurrent
    # requests overwrite each other's database.
    filepath = "database.pickle"
    with open(filepath, "wb") as f:
        pickle.dump(database, f)
    return filepath
def instance_recognition(
    embedding_file,
    zipfile,
    image_path: str,
) -> list:
    """Find the database images most similar to a query image.

    Args:
        embedding_file: Uploaded pickle file (only ``.name`` is used) holding
            a list of ``(image_path, embedding)`` tuples.
        zipfile: Uploaded zip archive (only ``.name`` is used) with the images
            the stored paths refer to; it is extracted before searching.
        image_path: Path of the query image.

    Returns:
        A list of ``2 * K`` items: ``K`` images followed by their ``K``
        inner-product scores, both in reverse search order. When the database
        holds fewer than ``K`` entries the unused slots are ``None``.

    Raises:
        ValueError: If the embedding database is empty.
    """
    # SECURITY: pickle.load can execute arbitrary code from the file. This is
    # an uploaded file; only load databases from trusted sources.
    with open(embedding_file.name, "rb") as f:
        embeddings = pickle.load(f)
    if not embeddings:
        raise ValueError("Embedding database is empty.")
    unzip(zipfile.name)

    image_paths = [img_path for img_path, _ in embeddings]
    # Each stored embedding has a leading batch axis of size 1; drop it to
    # get one row per image.
    embedding_vectors = np.squeeze(
        np.array([embedding for _, embedding in embeddings]), axis=1
    )
    d = embedding_vectors.shape[-1]

    # Build the faiss index (flat index, no training step).
    index = faiss.IndexFlatIP(d)
    index.add(embedding_vectors)

    # Embed the query image.
    image_embedding = infer(read_image(image_path))

    # Search.
    distances, indices = index.search(image_embedding, K)

    images = []
    scores = []
    for score, i in zip(distances[0].tolist(), indices[0].tolist()):
        if i < 0:
            # faiss pads with label -1 when fewer than K vectors are indexed;
            # indexing image_paths with it would silently pick the last image.
            images.append(None)
            scores.append(None)
        else:
            images.append(read_image(image_paths[i]))
            scores.append(score)
    # NOTE(review): the results are reversed before being returned, so if the
    # search yields best matches first, "Similar 1" is the weakest of the K.
    # Confirm this ordering is intended.
    return images[::-1] + scores[::-1]
# Gradio UI: one tab searches for similar images, the other builds the
# embedding database from a zip of images.
# NOTE(review): the nesting below (which components sit in which Row) was
# reconstructed; confirm the layout, in particular whether the distance
# text boxes belong in the same Row as the result images.
with gr.Blocks() as demo:
    gr.Markdown("# Instance Recogniton with DINOV2")
    with gr.Tab("Instance Recognition"):
        with gr.Row():
            infer_btn = gr.Button(value="Inference")
            image_embedding_file = gr.File(type="file", label="Image Embeddings")
            image_zip_file = gr.File(type="file", label="Image Zip File")
        with gr.Row():
            input_image = gr.Image(type="filepath", label="Input Image")
        with gr.Row():
            # One image slot and one distance slot per neighbour (K of each).
            output_images = [
                gr.Image(label=f"Similar {i + 1}") for i in range(K)
            ]
            distances = [
                gr.Text(label=f"Similar {i + 1} Distances") for i in range(K)
            ]
        # instance_recognition returns K images followed by K distances,
        # matching the order of `output_images + distances`.
        infer_btn.click(
            instance_recognition,
            inputs=[image_embedding_file, image_zip_file, input_image],
            outputs=output_images + distances,
        )
        # Example inputs come from the local `examples` module; clicking one
        # runs the same function as the button.
        gr.Examples(
            examples=examples,
            inputs=[
                image_embedding_file,
                image_zip_file,
                input_image,
            ],
            outputs=output_images + distances,
            fn=instance_recognition,
            run_on_click=True,
        )
    with gr.Tab("Image Embedding with database"):
        with gr.Row():
            embedding_btn = gr.Button(value="Image Embedding")
            # These names rebind the variables used in the first tab; the
            # handlers above were already wired to the earlier components.
            image_zip_file = gr.File(type="file", label="Image Zip File")
            image_embedding_file = gr.File(type="binary", label="Image Embedding with DINOV2")
        # calculate_embedding returns the path of the pickle file it wrote.
        embedding_btn.click(
            calculate_embedding, inputs=image_zip_file, outputs=image_embedding_file,
        )


if __name__ == "__main__":
    # "0.0.0.0" makes the server listen on all network interfaces.
    demo.launch(server_name="0.0.0.0")