# Hugging Face Space (status shown on the page when captured: Runtime error)
import gradio as gr | |
import numpy as np | |
from ultralytics import YOLO | |
from torchvision.transforms.functional import to_tensor | |
from huggingface_hub import hf_hub_download | |
import torch | |
import albumentations as A | |
from albumentations.pytorch.transforms import ToTensorV2 | |
import pandas as pd | |
from sklearn.metrics.pairwise import cosine_similarity | |
from utils import * | |
from models import YOLOStamp, Encoder | |
# --- Model setup -----------------------------------------------------------
# Everything below runs once at import time: weights are fetched from the
# Hugging Face Hub, deserialized on the CPU, then moved to `device`.

# Use the GPU when torch can see one, otherwise stay on the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Detector 1: the 'best.torchscript' file from 'stamps-labs/yolov8-finetuned',
# wrapped by ultralytics' YOLO in detection mode.
_yolov8_path = hf_hub_download('stamps-labs/yolov8-finetuned', filename='best.torchscript')
yolov8 = YOLO(_yolov8_path, task='detect')

# Detector 2: the project's YOLOStamp network, restored from a state dict.
_yolo_stamp_path = hf_hub_download('stamps-labs/yolo-stamp', filename='state_dict.pth')
yolo_stamp = YOLOStamp()
yolo_stamp.load_state_dict(torch.load(_yolo_stamp_path, map_location='cpu'))
yolo_stamp = yolo_stamp.to(device)
yolo_stamp.eval()

# Preprocessing applied to the yolo-stamp input: albumentations' Normalize
# with its default arguments, followed by conversion to a torch tensor.
_yolo_stamp_steps = [
    A.Normalize(),
    ToTensorV2(p=1.0),
]
transform = A.Compose(_yolo_stamp_steps)

# Embedder 1: a TorchScript module loaded from 'stamps-labs/vits8-stamp'.
_vits8_path = hf_hub_download('stamps-labs/vits8-stamp', filename='vits8stamp-torchscript.pth')
vits8 = torch.jit.load(_vits8_path, map_location='cpu')
vits8 = vits8.to(device)
vits8.eval()

# Embedder 2: the project's Encoder, restored from 'stamps-labs/vae-encoder'.
_encoder_path = hf_hub_download('stamps-labs/vae-encoder', filename='encoder.pth')
encoder = Encoder()
encoder.load_state_dict(torch.load(_encoder_path, map_location='cpu'))
encoder = encoder.to(device)
encoder.eval()
# Column order of the bounding-box table returned by predict().
_BOX_COLUMNS = ['x1', 'y1', 'x2', 'y2']


def _detect_stamps(image, det_choice):
    """Run the selected detector on a square-resized copy of *image*.

    Returns ``(resized_image, boxes, side)`` where ``boxes`` holds xyxy
    coordinates in the resized image's pixel space and ``side`` is the edge
    length the image was resized to, or ``None`` when *det_choice* is not a
    known detector.
    """
    if det_choice == 'yolov8':
        side = 640
        resized = image.resize((side, side))
        boxes = yolov8(resized)[0].boxes.xyxy.cpu()
    elif det_choice == 'yolo-stamp':
        side = 448
        resized = image.resize((side, side))
        image_tensor = transform(image=np.array(resized))['image']
        # Inference only: skip autograd bookkeeping.
        with torch.no_grad():
            output = yolo_stamp(image_tensor.unsqueeze(0).to(device))
        candidates = output_tensor_to_boxes(output[0].detach().cpu())
        candidates = nonmax_suppression(candidates)
        if len(candidates) == 0:
            # Slicing an empty 1-D tensor with [:, :4] raises IndexError, so
            # build the empty (0, 4) result explicitly.
            boxes = torch.empty((0, 4))
        else:
            boxes = xywh2xyxy(torch.tensor(candidates)[:, :4])
    else:
        return None
    return resized, boxes, side


def _embed_stamps(image, boxes, emb_choice):
    """Crop every box out of *image* and embed it with the selected model.

    Returns a 2-D array with one embedding per box. Raises ``ValueError``
    when *emb_choice* is not a known embedding model.
    """
    if emb_choice not in ('vits8', 'vae-encoder'):
        raise ValueError(f"Unknown embedding model: {emb_choice!r}")

    vectors = []
    with torch.no_grad():
        for box in boxes:
            crop = image.crop(box.tolist())
            if emb_choice == 'vits8':
                batch = to_tensor(crop).unsqueeze(0).to(device)
                vector = vits8(batch)[0]
            else:
                # The encoder branch resizes each crop to 118x118 first.
                batch = to_tensor(crop.resize((118, 118))).unsqueeze(0).to(device)
                vector = encoder(batch)[0][0]
            vectors.append(vector.detach().cpu().numpy())
    return np.stack(vectors)


def predict(image, det_choice, emb_choice):
    """Detect stamps in *image*, embed each one and compare the embeddings.

    Args:
        image: PIL image supplied by the Gradio input component.
        det_choice: ``'yolov8'`` or ``'yolo-stamp'``.
        emb_choice: ``'vits8'`` or ``'vae-encoder'``.

    Returns:
        ``(image_with_boxes, df_boxes, embeddings, fig)``: the annotated
        resized image, a DataFrame of boxes rescaled to the original image
        size, the embedding matrix, and a matplotlib figure with the pairwise
        cosine-similarity heatmap. When nothing is detected the last two
        entries are ``None`` and the DataFrame is empty. Returns ``None``
        when *det_choice* is not recognised.

    Raises:
        ValueError: if stamps were detected but *emb_choice* is unknown.
    """
    # Captured before resizing so boxes can be mapped back to the original.
    original_size = torch.tensor(image.size)
    image = image.convert('RGB')

    detection = _detect_stamps(image, det_choice)
    if detection is None:
        return None
    resized, boxes, side = detection
    # NOTE(review): visualize_bbox comes from `utils`; the crops below are
    # taken from the same image object after this call, as before.
    image_with_boxes = visualize_bbox(resized, boxes)

    if len(boxes) == 0:
        # No stamps found: stacking zero embeddings would raise, so return
        # empty results instead.
        return image_with_boxes, pd.DataFrame(columns=_BOX_COLUMNS), None, None

    embeddings = _embed_stamps(resized, boxes, emb_choice)
    similarities = cosine_similarity(embeddings)

    # Scale (x1, y1, x2, y2) from the resized square back to original pixels.
    coef = torch.hstack((original_size, original_size)) / side
    df_boxes = pd.DataFrame(np.asarray(boxes * coef), columns=_BOX_COLUMNS)

    # NOTE(review): plt, heatmap and annotate_heatmap are not imported by
    # name here; presumably they arrive via `from utils import *` - confirm.
    fig, ax = plt.subplots()
    labels = range(1, len(embeddings) + 1)
    im, _cbar = heatmap(similarities, labels, labels, ax=ax,
                        cmap="YlGn", cbarlabel="Embeddings similarities")
    annotate_heatmap(im, valfmt="{x:.3f}")
    return image_with_boxes, df_boxes, embeddings, fig
# --- Gradio UI ---------------------------------------------------------------
# Each example row is [image path, detection model, embedding model], in the
# same order as the input components below.
examples = [
    ['./examples/1.jpg', 'yolov8', 'vits8'],
    ['./examples/2.jpg', 'yolov8', 'vae-encoder'],
    ['./examples/3.jpg', 'yolo-stamp', 'vits8'],
]

# Input components, in the order of predict's parameters.
_image_input = gr.Image(type="pil")
_detector_dropdown = gr.Dropdown(choices=['yolov8', 'yolo-stamp'], value='yolov8', label='Detection model')
_embedder_dropdown = gr.Dropdown(choices=['vits8', 'vae-encoder'], value='vits8', label='Embedding model')
inputs = [_image_input, _detector_dropdown, _embedder_dropdown]

# Output components, in the order of the tuple predict returns.
_annotated_image = gr.Image(type="pil")
_boxes_table = gr.DataFrame(type='pandas', label="Bounding boxes")
_embeddings_table = gr.DataFrame(type='numpy', label="Embeddings")
_similarity_plot = gr.Plot(label="Cosine Similarities")
outputs = [_annotated_image, _boxes_table, _embeddings_table, _similarity_plot]

# Build the interface, enable the request queue, then start serving.
app = gr.Interface(fn=predict, inputs=inputs, outputs=outputs, examples=examples)
app.queue()
app.launch()