# UniVAD anomaly detection demo (app.py)
import subprocess

# Pin FastAPI to the version this Space expects before importing Gradio.
subprocess.run(["pip", "install", "fastapi==0.108.0"], check=True)
import gradio as gr
from UniVAD.tools import process_image
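# Fetch pretrained checkpoints: RAM++ (image tagging) and SAM-HQ (high-quality segmentation).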
subprocess.run(["wget", "-q","https://huggingface.co/xinyu1205/recognize-anything-plus-model/resolve/main/ram_plus_swin_large_14m.pth"], check=True)
subprocess.run(["wget", "-q","https://huggingface.co/lkeab/hq-sam/resolve/main/sam_hq_vit_h.pth"], check=True)
from transformers import AutoModelForZeroShotObjectDetection, AutoProcessor
import torch
import torchvision.transforms as transforms
from UniVAD.univad import UniVAD
from ram.models import ram_plus
from UniVAD.models.segment_anything import (
sam_hq_model_registry,
SamPredictor,
)
import spaces
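# Pick the compute device and build UniVAD with its 448x448 input transform.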
device = "cuda" if torch.cuda.is_available() else "cpu"
image_size = 448
univad_model = UniVAD(image_size=image_size).to(device)
transform = transforms.Compose(
[
transforms.Resize((image_size, image_size)),
transforms.ToTensor(),
]
)
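# RAM++ (Recognize Anything Plus) model for open-set image tagging.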
ram_model = ram_plus(
pretrained="./ram_plus_swin_large_14m.pth",
image_size=384,
vit="swin_l",
)
ram_model.eval()
ram_model = ram_model.to(device)
# Grounding DINO (tiny) for text-prompted, zero-shot object detection.
grounding_processor = AutoProcessor.from_pretrained("IDEA-Research/grounding-dino-tiny")
grounding_model = AutoModelForZeroShotObjectDetection.from_pretrained(
    "IDEA-Research/grounding-dino-tiny"
).to(device)
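# SAM-HQ mask predictor used for component segmentation.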
sam = sam_hq_model_registry["vit_h"]("./sam_hq_vit_h.pth").to(device)
sam_predictor = SamPredictor(sam)
def preprocess_image(img):
    # Resize uploads to the 448x448 resolution used by the rest of the pipeline.
    return img.resize((448, 448))

def update_image(image):
    # Change handler: resize newly uploaded images; pass cleared inputs through unchanged.
    if image is None:
        return image
    return preprocess_image(image)
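# Main inference entry point. @spaces.GPU requests a GPU from the Spaces runtime for this call;
# grounding, segmentation, and UniVAD scoring are all delegated to process_image.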
@spaces.GPU
def ad(image_pil, normal_image, box_threshold, text_threshold, text_prompt, background_prompt, cluster_num):
    return process_image(image_pil, normal_image, box_threshold, text_threshold, sam_predictor, grounding_model, univad_model, ram_model, text_prompt, background_prompt, cluster_num, image_size, grounding_processor)
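# Gradio UI: a reference (normal) image column, a query image column, and a settings/results column.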
with gr.Blocks() as demo:
    gr.HTML("""<h1 align="center" style='margin-top: 30px;'>Demo of UniVAD</h1>""")
    with gr.Row():
        with gr.Column():
            with gr.Row():
                gr.Markdown("### Upload normal image here for reference.")
            with gr.Row():
                normal_img = gr.Image(type="pil", label="Normal Image", value=None, height=475, width=440)
                normal_img.change(fn=update_image, inputs=normal_img, outputs=normal_img)
            with gr.Row():
                normal_mask = gr.Image(type="filepath", label="Normal Component Mask", value=None, height=450, interactive=False)
            with gr.Row():
                clearBtn = gr.Button("Clear", variant="secondary")
        with gr.Column():
            with gr.Row():
                gr.Markdown("### Upload query image here for anomaly detection.")
            with gr.Row():
                query_img = gr.Image(type="pil", label="Query Image", value=None, height=475, width=440)
                query_img.change(fn=update_image, inputs=query_img, outputs=query_img)
            with gr.Row():
                query_mask = gr.Image(type="filepath", label="Query Component Mask", value=None, height=450)
            with gr.Row():
                submitBtn = gr.Button("Submit", variant="primary")
        with gr.Column():
            with gr.Row():
                gr.Markdown("### Settings:")
            with gr.Row():
                box_threshold = gr.Slider(minimum=0.0, maximum=1.0, value=0.25, label="Box Threshold")
            with gr.Row():
                text_threshold = gr.Slider(minimum=0.0, maximum=1.0, value=0.25, label="Text Threshold")
            with gr.Row():
                text_prompt = gr.Textbox(label="Specify what should be in the image. Separate them with periods (.)", placeholder="(optional)")
            with gr.Row():
                background_prompt = gr.Textbox(label="Specify what should be IGNORED in the image. Separate them with periods (.)", placeholder="(optional)")
            with gr.Row():
                cluster_num = gr.Textbox(label="Number of Clusters", placeholder="(optional)")
            with gr.Row():
                anomaly_map_raw = gr.Image(type="filepath", label="Localization Result", value=None, height=450)
            with gr.Row():
                anomaly_score = gr.HTML(value="<span style='font-size: 30px;'>Detection Result:</span>")
    gr.State()
    submitBtn.click(
        ad,
        [
            query_img,
            normal_img,
            box_threshold,
            text_threshold,
            text_prompt,
            background_prompt,
            cluster_num,
        ],
        [
            query_mask,
            normal_mask,
            anomaly_map_raw,
            anomaly_score,
        ],
        show_progress=True,
    )
    clearBtn.click(
        lambda: (None, None, None, None, None, "<span style='font-size: 30px;'>Detection Result:</span>"),
        outputs=[query_img, normal_img, query_mask, normal_mask, anomaly_map_raw, anomaly_score],
    )
demo.queue().launch()