|
import os
import subprocess
import sys
|
|
|
|
|
# Pin fastapi before gradio is imported below (presumably for compatibility
# with the installed gradio version -- confirm). Invoke pip through the
# running interpreter so the package lands in the same environment, and fail
# loudly like the checkpoint downloads do instead of continuing silently.
subprocess.run([sys.executable, "-m", "pip", "install", "fastapi==0.108.0"], check=True)
|
|
|
|
|
import gradio as gr |
|
|
|
from UniVAD.tools import process_image |
|
|
|
# Fetch the RAM++ and HQ-SAM checkpoints into the working directory, where the
# model loaders read them as "./<name>.pth". Files already present are skipped:
# a bare `wget` would otherwise re-download them on every restart and save the
# new copy as "<name>.pth.1", which nothing reads.
_CHECKPOINT_URLS = (
    "https://huggingface.co/xinyu1205/recognize-anything-plus-model/resolve/main/ram_plus_swin_large_14m.pth",
    "https://huggingface.co/lkeab/hq-sam/resolve/main/sam_hq_vit_h.pth",
)
for _checkpoint_url in _CHECKPOINT_URLS:
    # NOTE(review): a partial file left by an interrupted download also counts
    # as present; delete it by hand to force a fresh download.
    if not os.path.isfile(_checkpoint_url.rsplit("/", 1)[-1]):
        subprocess.run(["wget", "-q", _checkpoint_url], check=True)
|
|
|
from transformers import AutoModelForZeroShotObjectDetection, AutoProcessor |
|
|
|
import torch |
|
import torchvision.transforms as transforms |
|
from UniVAD.univad import UniVAD |
|
|
|
from ram.models import ram_plus |
|
|
|
from UniVAD.models.segment_anything import ( |
|
sam_hq_model_registry, |
|
SamPredictor, |
|
) |
|
|
|
import spaces |
|
|
|
|
|
|
|
# Run every model on the GPU when one is available, otherwise on the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Side length (pixels) of the square UniVAD input; also forwarded to
# process_image by `ad`.
image_size = 448

univad_model = UniVAD(image_size=image_size).to(device)

# Resize to image_size x image_size, then convert to a tensor.
transform = transforms.Compose(
    [
        transforms.Resize((image_size, image_size)),
        transforms.ToTensor(),
    ]
)

# Recognize Anything Plus (RAM++) model from the checkpoint downloaded above.
ram_model = ram_plus(
    pretrained="./ram_plus_swin_large_14m.pth",
    image_size=384,
    vit="swin_l",
)
ram_model.eval()
ram_model = ram_model.to(device)

# Grounding DINO (tiny) zero-shot object detector and its processor.
grounding_processor = AutoProcessor.from_pretrained("IDEA-Research/grounding-dino-tiny")
# Move to `device` like every other model, so the app also starts on
# CPU-only hosts (a hard-coded "cuda" fails there).
grounding_model = AutoModelForZeroShotObjectDetection.from_pretrained(
    "IDEA-Research/grounding-dino-tiny"
).to(device)

# HQ-SAM ViT-H model wrapped in a SamPredictor.
sam = sam_hq_model_registry["vit_h"]("./sam_hq_vit_h.pth").to(device)
sam_predictor = SamPredictor(sam)
|
|
|
|
|
def preprocess_image(img, size=448):
    """Return *img* resized to a ``size`` x ``size`` square.

    Args:
        img: Image object exposing ``resize``; the Gradio image inputs that
            feed this (via ``update_image``) are configured with ``type="pil"``.
        size: Target side length in pixels. The default of 448 matches the
            module-level ``image_size`` used by the models.

    Returns:
        The resized image. Aspect ratio is not preserved.
    """
    return img.resize((size, size))
|
|
|
|
|
|
|
def update_image(image):
    """Gradio ``change`` handler: square-resize a newly set image.

    A cleared component delivers ``None``, which is passed straight back.
    """
    if image is None:
        return None
    return preprocess_image(image)
|
|
|
@spaces.GPU
def ad(image_pil, normal_image, box_threshold, text_threshold, text_prompt, background_prompt, cluster_num):
    """Submit handler: run the UniVAD pipeline on a query/normal image pair.

    The UI arguments are forwarded to ``process_image`` together with the
    module-level models, ``image_size`` and the Grounding DINO processor.
    The submit button wires the result to four outputs (query mask, normal
    mask, localization map, result HTML), so ``process_image`` is expected
    to return a 4-tuple in that order.
    """
    loaded_models = (sam_predictor, grounding_model, univad_model, ram_model)
    return process_image(
        image_pil,
        normal_image,
        box_threshold,
        text_threshold,
        *loaded_models,
        text_prompt,
        background_prompt,
        cluster_num,
        image_size,
        grounding_processor,
    )
|
|
|
|
|
|
|
# Initial and post-reset content of the detection-result panel.
_DETECTION_RESULT_PLACEHOLDER = "<span style='font-size: 30px;'>Detection Result:</span>"

with gr.Blocks() as demo:
    gr.HTML("""<h1 align="center" style='margin-top: 30px;'>Demo of UniVAD</h1>""")

    with gr.Row():
        # Left column: normal (reference) image, its component mask, Clear.
        with gr.Column():
            with gr.Row():
                gr.Markdown("### Upload normal image here for reference.")
            with gr.Row():
                normal_img = gr.Image(type="pil", label="Normal Image", value=None, height=475, width=440)
                # Square-resize every new image. NOTE(review): `.change` also
                # fires on programmatic updates, so this handler's own output
                # may re-trigger it; `.upload` may be intended -- confirm.
                normal_img.change(fn=update_image, inputs=normal_img, outputs=normal_img)
            with gr.Row():
                normal_mask = gr.Image(type="filepath", label="Normal Component Mask", value=None, height=450, interactive=False)
            with gr.Row():
                clearBtn = gr.Button("Clear", variant="secondary")

        # Middle column: query image, its component mask, Submit.
        with gr.Column():
            with gr.Row():
                gr.Markdown("### Upload query image here for anomaly detection.")
            with gr.Row():
                query_img = gr.Image(type="pil", label="Query Image", value=None, height=475, width=440)
                query_img.change(fn=update_image, inputs=query_img, outputs=query_img)
            with gr.Row():
                query_mask = gr.Image(type="filepath", label="Query Component Mask", value=None, height=450)
            with gr.Row():
                submitBtn = gr.Button("Submit", variant="primary")

        # Right column: detection settings followed by the results.
        with gr.Column():
            with gr.Row():
                gr.Markdown("### Settings:")
            with gr.Row():
                box_threshold = gr.Slider(minimum=0.0, maximum=1.0, value=0.25, label="Box Threshold")
            with gr.Row():
                text_threshold = gr.Slider(minimum=0.0, maximum=1.0, value=0.25, label="Text Threshold")
            with gr.Row():
                text_prompt = gr.Textbox(label="Specify what should be in the image. Separate them with periods (.)", placeholder="(optional)")
            with gr.Row():
                background_prompt = gr.Textbox(label="Specify what should be IGNORED in the image. Separate them with periods (.)", placeholder="(optional)")
            with gr.Row():
                # Free text; forwarded unparsed (as a string) to process_image.
                cluster_num = gr.Textbox(label="Number of Clusters", placeholder="(optional)")

            with gr.Row():
                anomaly_map_raw = gr.Image(type="filepath", label="Localization Result", value=None, height=450)

            with gr.Row():
                anomaly_score = gr.HTML(value=_DETECTION_RESULT_PLACEHOLDER)

    # Run the pipeline on (query, normal) with the current settings.
    submitBtn.click(
        ad,
        [
            query_img,
            normal_img,
            box_threshold,
            text_threshold,
            text_prompt,
            background_prompt,
            cluster_num,
        ],
        [
            query_mask,
            normal_mask,
            anomaly_map_raw,
            anomaly_score,
        ],
        show_progress=True,
    )

    # Reset both inputs, all result images and the result panel.
    clearBtn.click(
        lambda: (None, None, None, None, None, _DETECTION_RESULT_PLACEHOLDER),
        outputs=[query_img, normal_img, query_mask, normal_mask, anomaly_map_raw, anomaly_score],
    )

demo.queue().launch()
|
|
|
|