from io import BytesIO
from typing import Optional

import cv2
import gradio as gr
import numpy as np
import requests
import spaces
import supervision as sv
import torch
from PIL import Image

from utils.florence import (
    FLORENCE_OPEN_VOCABULARY_DETECTION_TASK,
    load_florence_model,
    run_florence_inference,
)
from utils.sam import load_sam_image_model, run_sam_inference
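
# Run everything on the GPU under bfloat16 autocast; on Ampere (compute
# capability 8.x) or newer GPUs, TF32 matmuls are enabled for extra speed.
# Florence-2 and SAM are loaded once at start-up and reused for every request.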
DEVICE = torch.device("cuda")
# DEVICE = torch.device("cpu")

torch.autocast(device_type="cuda", dtype=torch.bfloat16).__enter__()
if torch.cuda.get_device_properties(0).major >= 8:
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True

FLORENCE_MODEL, FLORENCE_PROCESSOR = load_florence_model(device=DEVICE)
SAM_IMAGE_MODEL = load_sam_image_model(device=DEVICE)
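
# Grounded segmentation: Florence-2 proposes regions for the chosen task /
# text prompt, SAM turns those regions into masks, and each mask is returned
# as a binary image (optionally dilated) for display in the gallery.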
@spaces.GPU(duration=20)
@torch.inference_mode()
@torch.autocast(device_type="cuda", dtype=torch.bfloat16)
def process_image(image_input, image_url, task_prompt, text_prompt=None, dilate=0) -> Optional[list[np.ndarray]]:
    if image_input is None and not image_url:
        gr.Info("Please upload an image or provide an image URL.")
        return None

    if not task_prompt:
        gr.Info("Please enter a task prompt.")
        return None

    if image_url:
        # A URL, when given, takes precedence over the uploaded image.
        print("start to fetch image from url", image_url)
        response = requests.get(image_url)
        response.raise_for_status()
        image_input = Image.open(BytesIO(response.content))
        print("fetch image success")

    # Florence-2 returns region proposals for the selected task / text prompt.
    _, result = run_florence_inference(
        model=FLORENCE_MODEL,
        processor=FLORENCE_PROCESSOR,
        device=DEVICE,
        image=image_input,
        task=task_prompt,
        text=text_prompt
    )
    detections = sv.Detections.from_lmm(
        lmm=sv.LMM.FLORENCE_2,
        result=result,
        resolution_wh=image_input.size
    )

    # SAM refines the detected regions into segmentation masks.
    detections = run_sam_inference(SAM_IMAGE_MODEL, image_input, detections)
    if len(detections) == 0:
        gr.Info("No objects detected.")
        return None

    print("mask generated:", len(detections.mask))
    kernel_size = dilate
    kernel = np.ones((kernel_size, kernel_size), np.uint8)

    # Convert each boolean mask to an 8-bit image and optionally dilate it.
    images = []
    for i in range(len(detections.mask)):
        mask = detections.mask[i].astype(np.uint8) * 255
        if dilate > 0:
            mask = cv2.dilate(mask, kernel, iterations=1)
        images.append(mask)

    return images
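
# Gradio UI: upload an image or paste a URL, pick a Florence-2 task prompt,
# optionally add a text prompt, set the mask dilation, and view the results.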
with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column():
            image = gr.Image(type='pil', label='Upload image')
            image_url = gr.Textbox(label='Image URL', placeholder='Enter an image URL (optional)')
            task_prompt = gr.Dropdown(
                ["<CAPTION>", "<DETAILED_CAPTION>", "<MORE_DETAILED_CAPTION>",
                 "<CAPTION_TO_PHRASE_GROUNDING>", "<OPEN_VOCABULARY_DETECTION>",
                 "<DENSE_REGION_CAPTION>"],
                value="<CAPTION_TO_PHRASE_GROUNDING>",
                label="Task Prompt",
                info="Florence-2 task prompt"
            )
            dilate = gr.Slider(label="Dilate mask", minimum=0, maximum=50, value=10, step=1)
            text_prompt = gr.Textbox(label='Text prompt', placeholder='Enter text prompts')
            submit_button = gr.Button(value='Submit', variant='primary')
        with gr.Column():
            image_gallery = gr.Gallery(
                label="Generated images", show_label=False, elem_id="gallery",
                columns=[3], rows=[1], object_fit="contain", height="auto"
            )
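    # Run process_image on submit and show the returned masks in the gallery.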
    submit_button.click(
        fn=process_image,
        inputs=[image, image_url, task_prompt, text_prompt, dilate],
        outputs=[image_gallery],
        show_api=False
    )

demo.launch(debug=True, show_error=True)