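# Gradio demo app for UniVAD visual anomaly detection: a query image is compared
# against a single normal reference image, and the UniVAD process_image pipeline
# (together with RAM++, Grounding DINO, and SAM-HQ) returns component masks, an
# anomaly localization map, and a detection score.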
import subprocess

# Pin the FastAPI version this demo was built against (installed at startup).
subprocess.run(["pip", "install", "fastapi==0.108.0"], check=True)

import gradio as gr
from UniVAD.tools import process_image

# Download the RAM++ and SAM-HQ checkpoints.
subprocess.run(["wget", "-q", "https://huggingface.co/xinyu1205/recognize-anything-plus-model/resolve/main/ram_plus_swin_large_14m.pth"], check=True)
subprocess.run(["wget", "-q", "https://huggingface.co/lkeab/hq-sam/resolve/main/sam_hq_vit_h.pth"], check=True)
from transformers import AutoModelForZeroShotObjectDetection, AutoProcessor
import torch
import torchvision.transforms as transforms
from UniVAD.univad import UniVAD
from ram.models import ram_plus
from UniVAD.models.segment_anything import (
    sam_hq_model_registry,
    SamPredictor,
)
import spaces
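
# Model setup: UniVAD itself plus the auxiliary models handed to process_image
# (RAM++ tagging, Grounding DINO zero-shot detection, SAM-HQ segmentation).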
device = "cuda" if torch.cuda.is_available() else "cpu"
image_size = 448

univad_model = UniVAD(image_size=image_size).to(device)

transform = transforms.Compose(
    [
        transforms.Resize((image_size, image_size)),
        transforms.ToTensor(),
    ]
)

# RAM++ expects 384x384 inputs, independent of the 448x448 UniVAD resolution.
ram_model = ram_plus(
    pretrained="./ram_plus_swin_large_14m.pth",
    image_size=384,
    vit="swin_l",
)
ram_model.eval()
ram_model = ram_model.to(device)

grounding_processor = AutoProcessor.from_pretrained("IDEA-Research/grounding-dino-tiny")
# Use the shared `device` instead of hard-coding "cuda" so the app also starts on CPU-only machines.
grounding_model = AutoModelForZeroShotObjectDetection.from_pretrained("IDEA-Research/grounding-dino-tiny").to(device)

sam = sam_hq_model_registry["vit_h"]("./sam_hq_vit_h.pth").to(device)
sam_predictor = SamPredictor(sam)
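
# Gradio callbacks: image preprocessing plus the GPU-backed inference entry point.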
def preprocess_image(img):
    # Resize uploads to the 448x448 resolution used by the demo.
    return img.resize((448, 448))

def update_image(image):
    if image is not None:
        return preprocess_image(image)
    return None

@spaces.GPU
def ad(image_pil, normal_image, box_threshold, text_threshold, text_prompt, background_prompt, cluster_num):
    # Returns (query component mask, normal component mask, anomaly map, detection-result HTML).
    return process_image(image_pil, normal_image, box_threshold, text_threshold, sam_predictor, grounding_model, univad_model, ram_model, text_prompt, background_prompt, cluster_num, image_size, grounding_processor)
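
# Gradio UI: a normal-reference column, a query-image column, and a settings /
# results column, wired to the callbacks above.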
with gr.Blocks() as demo:
    gr.HTML("""<h1 align="center" style='margin-top: 30px;'>Demo of UniVAD</h1>""")
    with gr.Row():
        # Column 1: normal reference image and its component mask.
        with gr.Column():
            with gr.Row():
                gr.Markdown("### Upload a normal image here for reference.")
            with gr.Row():
                normal_img = gr.Image(type="pil", label="Normal Image", value=None, height=475, width=440)
                normal_img.change(fn=update_image, inputs=normal_img, outputs=normal_img)
            with gr.Row():
                normal_mask = gr.Image(type="filepath", label="Normal Component Mask", value=None, height=450, interactive=False)
            with gr.Row():
                clearBtn = gr.Button("Clear", variant="secondary")
        # Column 2: query image and its component mask.
        with gr.Column():
            with gr.Row():
                gr.Markdown("### Upload a query image here for anomaly detection.")
            with gr.Row():
                query_img = gr.Image(type="pil", label="Query Image", value=None, height=475, width=440)
                query_img.change(fn=update_image, inputs=query_img, outputs=query_img)
            with gr.Row():
                query_mask = gr.Image(type="filepath", label="Query Component Mask", value=None, height=450)
            with gr.Row():
                submitBtn = gr.Button("Submit", variant="primary")
        # Column 3: detection settings and results.
        with gr.Column():
            with gr.Row():
                gr.Markdown("### Settings:")
            with gr.Row():
                box_threshold = gr.Slider(minimum=0.0, maximum=1.0, value=0.25, label="Box Threshold")
            with gr.Row():
                text_threshold = gr.Slider(minimum=0.0, maximum=1.0, value=0.25, label="Text Threshold")
            with gr.Row():
                text_prompt = gr.Textbox(label="Specify what should be in the image. Separate them with periods (.)", placeholder="(optional)")
            with gr.Row():
                background_prompt = gr.Textbox(label="Specify what should be IGNORED in the image. Separate them with periods (.)", placeholder="(optional)")
            with gr.Row():
                cluster_num = gr.Textbox(label="Number of Clusters", placeholder="(optional)")
            with gr.Row():
                anomaly_map_raw = gr.Image(type="filepath", label="Localization Result", value=None, height=450)
            with gr.Row():
                anomaly_score = gr.HTML(value="<span style='font-size: 30px;'>Detection Result:</span>")
    gr.State()

    # Run detection and fill in the component masks, anomaly map, and score.
    submitBtn.click(
        ad,
        [
            query_img,
            normal_img,
            box_threshold,
            text_threshold,
            text_prompt,
            background_prompt,
            cluster_num,
        ],
        [
            query_mask,
            normal_mask,
            anomaly_map_raw,
            anomaly_score,
        ],
        show_progress=True,
    )
    # Reset all inputs and outputs.
    clearBtn.click(
        lambda: (None, None, None, None, None, "<span style='font-size: 30px;'>Detection Result:</span>"),
        outputs=[query_img, normal_img, query_mask, normal_mask, anomaly_map_raw, anomaly_score],
    )

demo.queue().launch()