Spaces:
Runtime error
Runtime error
import gradio as gr | |
import numpy as np | |
from sklearn.metrics.pairwise import cosine_similarity | |
import pandas as pd | |
from PIL import Image | |
import matplotlib.pyplot as plt | |
from pipelines.detection.yolo_v8 import Yolov8Pipeline | |
from pipelines.detection.yolo_stamp import YoloStampPipeline | |
from pipelines.segmentation.deeplabv3 import DeepLabv3Pipeline | |
from pipelines.feature_extraction.vae import VaePipeline | |
from pipelines.feature_extraction.vits8 import Vits8Pipeline | |
from utils import * | |
# Every model pipeline is constructed once, at import time, and shared by all
# requests. The YOLOv8 detector is loaded from a local checkpoint file; the
# other pipelines receive a repository id plus a weights filename (presumably
# Hugging Face Hub repos, given the 'stamps-labs/...' ids -- confirm against
# the pipelines package). Load order is kept as-is.
_WEIGHTS_FILE = 'weights.pt'

# Detection pipelines (pipelines.detection)
yolov8 = Yolov8Pipeline.from_pretrained(local_model_path='yolov8_old_backup.pt')
yolo_stamp = YoloStampPipeline.from_pretrained('stamps-labs/yolo-stamp', _WEIGHTS_FILE)

# Feature-extraction (embedding) pipelines (pipelines.feature_extraction)
vae = VaePipeline.from_pretrained('stamps-labs/vae-encoder', _WEIGHTS_FILE)
vits8 = Vits8Pipeline.from_pretrained('stamps-labs/vits8-stamp', _WEIGHTS_FILE)

# Segmentation pipeline (pipelines.segmentation)
dlv3 = DeepLabv3Pipeline.from_pretrained('stamps-labs/deeplabv3-finetuned', _WEIGHTS_FILE)
def _concat_horizontally(images):
    """Paste ``images`` left to right onto a single RGB canvas and return it.

    The canvas width is the sum of the image widths and its height is the
    tallest image's height; every image is pasted at the top edge.
    ``images`` must be non-empty.
    """
    widths, heights = zip(*(img.size for img in images))
    canvas = Image.new('RGB', (sum(widths), max(heights)))
    x_offset = 0
    for img in images:
        canvas.paste(img, (x_offset, 0))
        x_offset += img.size[0]
    return canvas


def _similarity_heatmap(embeddings):
    """Return a matplotlib figure of the pairwise cosine similarities of ``embeddings``.

    Labels 1..N (stamp order) are passed to ``heatmap`` for rows and columns,
    and each cell is annotated with its value to three decimals.
    """
    similarities = cosine_similarity(embeddings)
    labels = range(1, len(embeddings) + 1)
    fig, ax = plt.subplots()
    im, _ = heatmap(similarities, labels, labels, ax=ax,
                    cmap="YlGn", cbarlabel="Embeddings similarities")
    annotate_heatmap(im, valfmt="{x:.3f}")
    # Drop the figure from pyplot's global registry so a long-running server
    # does not accumulate one open figure per request; the Figure object
    # itself remains usable for Gradio to render.
    plt.close(fig)
    return fig


def doc_predict(image, det_choice, seg_choice, emb_choice):
    """Detect, optionally segment, and embed the stamps in a document image.

    Args:
        image: Document image as a PIL image (the Gradio input uses type="pil").
        det_choice: Detection model name, 'yolov8' or 'yolo-stamp'.
        seg_choice: If truthy, each cropped stamp is replaced by the output of
            the DeepLabv3 pipeline (``dlv3``).
        emb_choice: Embedding model name, 'vits8' or 'vae-encoder'.

    Returns:
        A 5-tuple matching the Interface outputs: the image with boxes drawn,
        a DataFrame of box coordinates (x1, y1, x2, y2), the stamp crops
        concatenated horizontally, the stacked embeddings array, and a
        cosine-similarity heatmap figure. When no stamp is detected, the last
        three items are None, which clears those output components.

    Raises:
        ValueError: if no image is given or a model name is not recognised.
    """
    if image is None:
        raise ValueError('Please upload a document image.')

    detectors = {'yolov8': yolov8, 'yolo-stamp': yolo_stamp}
    embedders = {'vits8': vits8, 'vae-encoder': vae}
    detect = detectors.get(det_choice)
    embed = embedders.get(emb_choice)
    if detect is None:
        raise ValueError(f'Unknown detection model: {det_choice!r}')
    if embed is None:
        raise ValueError(f'Unknown embedding model: {emb_choice!r}')

    image = image.convert('RGB')
    boxes = detect(image)
    image_with_boxes = visualize_bbox(image, boxes)

    # Plain Python coordinates, reused for both the table and the crops;
    # an empty detection yields an empty 4-column DataFrame.
    box_coords = [box.tolist() for box in boxes]
    df_boxes = pd.DataFrame(box_coords, columns=['x1', 'y1', 'x2', 'y2'])

    stamps = []
    for coords in box_coords:
        crop = image.crop(coords)
        stamps.append(dlv3(crop) if seg_choice else crop)

    if not stamps:
        # Nothing detected: np.stack([]) would raise and there is nothing to
        # concatenate or compare, so clear those outputs instead.
        return image_with_boxes, df_boxes, None, None, None

    embeddings = np.stack([embed(stamp) for stamp in stamps])
    return (
        image_with_boxes,
        df_boxes,
        _concat_horizontally(stamps),
        embeddings,
        _similarity_heatmap(embeddings),
    )
# Example rows for the Gradio interface, one per example, with fields in the
# same order as doc_predict's parameters:
# (document image path, detection model, use segmentation, embedding model).
doc_examples = [
    ['examples/1.jpg', 'yolov8', True, 'vits8'],
    ['examples/2.jpg', 'yolo-stamp', False, 'vae-encoder'],
    ['examples/3.jpg', 'yolov8', True, 'vits8'],
]
# Input components, listed in the same order as doc_predict's parameters:
# image, det_choice, seg_choice, emb_choice.
doc_inputs = [
    gr.Image(type="pil", label="Document image"),
    gr.Dropdown(
        label='Detection model',
        choices=['yolov8', 'yolo-stamp'],
        value='yolov8',
    ),
    gr.Checkbox(label="Use segmentation model"),
    gr.Dropdown(
        label='Embedding model',
        choices=['vits8', 'vae-encoder'],
        value='vits8',
    ),
]
# Output components, one per element of the tuple doc_predict returns.
doc_outputs = [
    gr.Image(type="pil", label="Document with bounding boxes"),
    gr.DataFrame(label="Bounding boxes", type='pandas'),
    gr.Image(type="pil", label="Segmented stamps"),
    gr.DataFrame(label="Embeddings", type='numpy'),
    gr.Plot(label="Cosine Similarities"),
]
# Gradio app: one tab wrapping doc_predict in an Interface with the input,
# output, and example definitions above, then launched at import time.
with gr.Blocks() as demo:
    # Tab label typo fixed ("Signle" -> "Single").
    with gr.Tab("Single document"):
        gr.Interface(doc_predict, doc_inputs, doc_outputs, examples=doc_examples)
demo.launch(inline=False)