"""Gradio demo of UniVAD.

At import time this script loads the UniVAD, RAM++, Grounding DINO and SAM-HQ
models onto the selected device, builds a Blocks UI with a normal (reference)
image, a query image and a settings panel, and launches the app.
"""

import subprocess  # only referenced by the optional setup commands below

# One-time environment setup, kept for reference. Uncomment when the
# dependencies are not already installed in the runtime.
# subprocess.run(["pip", "install", "-e", "./models/GroundingDINO"])
# subprocess.run(["pip", "install", "gradio==4.21.0"])
# subprocess.run(["pip", "install", "fastapi==0.108.0"])

import gradio as gr
from UniVAD.tools import process_image

# Checkpoint downloads, kept for reference. The model loaders below expect
# these three files in the current working directory.
# subprocess.run(["wget", "-q","https://github.com/IDEA-Research/GroundingDINO/releases/download/v0.1.0-alpha/groundingdino_swint_ogc.pth"], check=True)
# subprocess.run(["wget", "-q","https://huggingface.co/xinyu1205/recognize-anything-plus-model/resolve/main/ram_plus_swin_large_14m.pth"], check=True)
# subprocess.run(["wget", "-q","https://huggingface.co/lkeab/hq-sam/resolve/main/sam_hq_vit_h.pth"], check=True)

import torch
import torchvision.transforms as transforms
from UniVAD.univad import UniVAD
from ram.models import ram_plus
from UniVAD.models.segment_anything import (
    sam_hq_model_registry,
    SamPredictor,
)

# Grounding DINO
from UniVAD.models.grounded_sam import (
    load_model,
)

# Single source of truth for the device every model is placed on.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Input resolution handed to UniVAD and forwarded to process_image.
image_size = 224

# Side length (pixels) that uploaded images are resized to in the UI.
_DISPLAY_SIZE = (448, 448)

# Placeholder shown in the score panel before/after a detection run.
_EMPTY_SCORE_HTML = "Detection Result:"

univad_model = UniVAD(image_size=image_size).to(device)

transform = transforms.Compose(
    [
        transforms.Resize((image_size, image_size)),
        transforms.ToTensor(),
    ]
)

ram_model = ram_plus(
    pretrained="./ram_plus_swin_large_14m.pth",
    image_size=384,
    vit="swin_l",
)
ram_model.eval()
ram_model = ram_model.to(device)

# Reuse `device` rather than repeating the CUDA check, so all models are
# guaranteed to be placed on the same device.
grounding_model = load_model(
    "./UniVAD/models/GroundingDINO/groundingdino/config/GroundingDINO_SwinT_OGC.py",
    "./groundingdino_swint_ogc.pth",
    device,
)

sam = sam_hq_model_registry["vit_h"]("./sam_hq_vit_h.pth").to(device)
sam_predictor = SamPredictor(sam)


def preprocess_image(img):
    """Return `img` resized to 448x448 via its `resize` method."""
    return img.resize(_DISPLAY_SIZE)


def update_image(image):
    """Resize an uploaded image for display.

    Returns the resized image, or None when `image` is None (for example
    after the component has been cleared).
    """
    if image is None:
        return None
    return preprocess_image(image)


def ad(image_pil, normal_image, box_threshold, text_threshold, text_prompt, background_prompt, cluster_num):
    """Run detection by forwarding the UI inputs to `process_image`.

    The UI values are passed through unchanged, together with the models
    loaded at module level and the configured `image_size`. The return value
    of `process_image` is returned as-is; the click handler maps it onto the
    query mask, normal mask, localization image and score panel, in that
    order.
    """
    return process_image(
        image_pil,
        normal_image,
        box_threshold,
        text_threshold,
        sam_predictor,
        grounding_model,
        univad_model,
        ram_model,
        text_prompt,
        background_prompt,
        cluster_num,
        image_size,
    )


with gr.Blocks() as demo:
    # NOTE(review): the banner text carries no markup here — confirm whether
    # heading tags were intended.
    gr.HTML("\n\nDemo of UniVAD\n\n")

    with gr.Row():
        # Column 1: reference (normal) image and its component mask.
        with gr.Column():
            with gr.Row():
                gr.Markdown("### Upload normal image here for reference.")
            with gr.Row():
                normal_img = gr.Image(type="pil", label="Normal Image", value=None, height=475, width=440)
                # Write the resized image back into the same component.
                normal_img.change(fn=update_image, inputs=normal_img, outputs=normal_img)
            with gr.Row():
                normal_mask = gr.Image(type="filepath", label="Normal Component Mask", value=None, height=450, interactive=False)
            with gr.Row():
                clearBtn = gr.Button("Clear", variant="secondary")

        # Column 2: query image and its component mask.
        with gr.Column():
            with gr.Row():
                gr.Markdown("### Upload query image here for anomaly detection.")
            with gr.Row():
                query_img = gr.Image(type="pil", label="Query Image", value=None, height=475, width=440)
                # Write the resized image back into the same component.
                query_img.change(fn=update_image, inputs=query_img, outputs=query_img)
            with gr.Row():
                query_mask = gr.Image(type="filepath", label="Query Component Mask", value=None, height=450)
            with gr.Row():
                submitBtn = gr.Button("Submit", variant="primary")

        # Column 3: settings and detection outputs.
        with gr.Column():
            with gr.Row():
                gr.Markdown("### Settings:")
            with gr.Row():
                box_threshold = gr.Slider(minimum=0.0, maximum=1.0, value=0.25, label="Box Threshold")
            with gr.Row():
                text_threshold = gr.Slider(minimum=0.0, maximum=1.0, value=0.25, label="Text Threshold")
            with gr.Row():
                text_prompt = gr.Textbox(label="Specify what should be in the image. Separate them with periods (.)", placeholder="(optional)")
            with gr.Row():
                background_prompt = gr.Textbox(label="Specify what should be IGNORED in the image. Separate them with periods (.)", placeholder="(optional)")
            with gr.Row():
                # Passed to process_image as the raw textbox value.
                cluster_num = gr.Textbox(label="Number of Clusters", placeholder="(optional)")
            with gr.Row():
                anomaly_map_raw = gr.Image(type="filepath", label="Localization Result", value=None, height=450)
            with gr.Row():
                anomaly_score = gr.HTML(value=_EMPTY_SCORE_HTML)

    submitBtn.click(
        ad,
        [
            query_img,
            normal_img,
            box_threshold,
            text_threshold,
            text_prompt,
            background_prompt,
            cluster_num,
        ],
        [
            query_mask,
            normal_mask,
            anomaly_map_raw,
            anomaly_score,
        ],
        show_progress=True,
    )

    # Reset both inputs and all outputs; one value per entry in `outputs`.
    clearBtn.click(
        lambda: (None, None, None, None, None, _EMPTY_SCORE_HTML),
        outputs=[query_img, normal_img, query_mask, normal_mask, anomaly_map_raw, anomaly_score],
    )

demo.queue().launch()