"""Gradio demo: training-free zero-shot semantic segmentation with LLM refinement."""
import os

import cv2
import spaces
import gradio as gr
from PIL import Image
from omegaconf import OmegaConf

# set up environment
from utils.env_utils import set_random_seed, use_lower_vram
from utils.timer_utils import Timer

set_random_seed(1024)
timer = Timer()
timer.start()
# use_lower_vram()

# import functions
# (kept after the seeding / timer setup above, matching the original import order)
from utils.labels_utils import Labels
from utils.ram_utils import ram_inference
from utils.blip2_utils import blip2_caption
from utils.llms_utils import pre_refinement, make_prompt, init_model
from utils.grounded_sam_utils import run_grounded_sam

# hardcode parameters for G-SAM
box_threshold = 0.18
text_threshold = 0.15
iou_threshold = 0.8

# load Llama-3 here to avoid loading it during the inference.
llm = init_model("Meta-Llama-3-8B-Instruct")

# Module-level state reassigned by `process` (via `global`) whenever the selected
# dataset setting changes, so the config is only reloaded on change.
current_config = ""  # name of the config currently loaded into `L` / `system_prompt`
L = None  # Labels object for `current_config`; set lazily by `process`
system_prompt = None  # LLM system prompt built from `L.LABELS`; set lazily by `process`


def load_config(config_type):
    """Load a dataset setting and build its label set and LLM system prompt.

    Args:
        config_type: Name of the dataset setting; the file
            ``configs/<config_type>.yaml`` next to this script is loaded.

    Returns:
        A ``(labels, system_prompt)`` tuple: the ``Labels`` object built from
        the config, and the prompt produced by ``make_prompt`` from the
        comma-joined ``labels.LABELS``.
    """
    config_path = os.path.join(os.path.dirname(__file__), f"configs/{config_type}.yaml")
    config = OmegaConf.load(config_path)
    labels = Labels(config=config)
    # init labels and llm prompt, only Meta-Llama-3-8B-Instruct is supported for online demo,
    # but you can use any model in your local environment using our released code
    prompt = make_prompt(", ".join(labels.LABELS))
    return labels, prompt


@spaces.GPU(duration=120)
def process(image_ori, config_type):
    """Segment one image with the RAM/BLIP-2 -> LLM -> Grounded-SAM pipeline.

    Args:
        image_ori: Image array from the Gradio image input. Assumed to be an
            RGB numpy array (Gradio's numpy convention); the input component
            is not defined in the visible code -- TODO confirm.
        config_type: Dataset setting name; selects ``configs/<config_type>.yaml``.

    Returns:
        The overlay produced by ``L.draw_mask`` on the BGR image, converted
        back to RGB for display.

    Raises:
        gr.Error: If no image was provided.
    """
    global current_config, L, system_prompt

    if image_ori is None:
        raise gr.Error("Please upload or select an image first.")

    # Reload labels / prompt only when the dataset setting changes.
    if current_config != config_type:
        L, system_prompt = load_config(config_type)
        current_config = config_type

    # The models receive the image in its incoming (RGB) channel order; the
    # BGR copy is used only for drawing, and the drawn result is converted
    # back to RGB before returning.
    image_bgr = cv2.cvtColor(image_ori, cv2.COLOR_RGB2BGR)
    image_pil = Image.fromarray(image_ori)

    # RAM output followed by the BLIP-2 caption, joined with ": ".
    labels_ram = ram_inference(image_pil) + ": " + blip2_caption(image_pil)
    # Refine with the LLM using the dataset-specific system prompt, then check
    # against the label set; a batch of one is passed, hence [labels_ram] / [0].
    converted_labels, llm_output = pre_refinement([labels_ram], system_prompt, llm=llm)
    labels_llm = L.check_labels(converted_labels)[0]
    print("labels_ram: ", labels_ram)
    print("llm_output: ", llm_output)
    print("labels_llm: ", labels_llm)

    # run sam (only the label result is used; the other outputs are ignored)
    label_res, _bboxes, _output_labels, _prob_maps, _points = run_grounded_sam(
        input_image={"image": image_pil, "mask": None},
        text_prompt=labels_llm,
        box_threshold=box_threshold,
        text_threshold=text_threshold,
        iou_threshold=iou_threshold,
        LABELS=L.LABELS,
        IDS=L.IDS,
        llm=llm,
        timer=timer,
    )

    # draw mask on the BGR image and convert the result back to RGB
    ours = L.draw_mask(label_res, image_bgr, print_label=True, tag="Ours")
    return cv2.cvtColor(ours, cv2.COLOR_BGR2RGB)


if __name__ == "__main__":
    # options for different settings
    dropdown_options = ["COCO-81", "Cityscapes", "DRAM", "VOC2012"]
    default_option = "COCO-81"

    with gr.Blocks() as demo:
        # NOTE(review): `allow_flagging`, `examples` and `cache_examples` are not
        # documented `gr.HTML` parameters; they look like `gr.Interface(fn=process, ...)`
        # arguments. Part of the UI definition (image/dropdown inputs, output wired to
        # `process`, which would also use `dropdown_options` / `default_option`) appears
        # to be missing here, and this call will presumably be rejected or ignored
        # -- TODO restore.
        gr.HTML(
            """
This is an online demo for the paper "Training-Free Zero-Shot Semantic Segmentation with LLM Refinement" (BMVC 2024).
Usage: Please select or upload an image and choose a dataset setting for semantic segmentation refinement.
""",
            allow_flagging='never',
            examples=[
                ["examples/Cityscapes_eg.jpg", "Cityscapes"],
                ["examples/DRAM_eg.jpg", "DRAM"],
                ["examples/COCO-81_eg.jpg", "COCO-81"],
                ["examples/VOC2012_eg.jpg", "VOC2012"],
            ],
            cache_examples=True,
        )

    demo.queue(max_size=10).launch()