# SegmentVision / app.py
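"""Gradio demo combining two vision models:

- CLIP (openai/clip-vit-base-patch32) for zero-shot image/text matching.
- FastSAM (FastSAM-x) for "segment everything" mask generation.

Dependencies (inferred from the imports below): gradio, torch, pillow,
opencv-python, numpy, transformers, ultralytics.
"""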
import gradio as gr
import torch
from PIL import Image
import cv2
import numpy as np
from transformers import CLIPProcessor, CLIPModel
from ultralytics import FastSAM
# Load CLIP model
model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
# Load FastSAM model
fast_sam = FastSAM('FastSAM-x.pt')
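# Note: if 'FastSAM-x.pt' is not present locally, ultralytics should download
# the checkpoint automatically on first use.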
def process_image_clip(image, text_input):
    if image is None or not text_input:
        return "Please provide both an image and a text description."
    # Score the user's description against a generic negative prompt.
    # With a single prompt, softmax over one class always returns 100%,
    # so a contrast prompt is needed to get a meaningful confidence.
    prompts = [f"a photo of {text_input}", "a photo of something else"]
    inputs = processor(
        images=image,
        text=prompts,
        return_tensors="pt",
        padding=True
    )
    # Get model predictions (no gradients needed for inference)
    with torch.no_grad():
        outputs = model(**inputs)
    logits_per_image = outputs.logits_per_image
    probs = logits_per_image.softmax(dim=1)
    confidence = float(probs[0][0])
    return f"Confidence that the image contains '{text_input}': {confidence:.2%}"
def process_image_fastsam(image):
    if image is None:
        return None
    # Convert PIL image to numpy array (RGB)
    image_np = np.array(image)
    # Run FastSAM inference over the whole image ("segment everything")
    everything_results = fast_sam(
        image_np, device='cpu', retina_masks=True, imgsz=1024, conf=0.4, iou=0.9
    )
    # Draw all predicted masks onto the image; Results.plot() returns a
    # BGR numpy array, so convert back to RGB before wrapping in PIL
    annotated = everything_results[0].plot()
    annotated_rgb = cv2.cvtColor(annotated, cv2.COLOR_BGR2RGB)
    return Image.fromarray(annotated_rgb)
# Create Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("# CLIP and FastSAM Demo")

    with gr.Tab("CLIP Zero-Shot Classification"):
        with gr.Row():
            image_input = gr.Image(type="pil", label="Input Image")
            text_input = gr.Textbox(
                label="What do you want to check in the image?",
                placeholder="Type here..."
            )
        output_text = gr.Textbox(label="Result")
        classify_btn = gr.Button("Classify")
        classify_btn.click(fn=process_image_clip, inputs=[image_input, text_input], outputs=output_text)

    with gr.Tab("FastSAM Segmentation"):
        with gr.Row():
            image_input_sam = gr.Image(type="pil", label="Input Image")
            image_output = gr.Image(type="pil", label="Segmentation Result")
        segment_btn = gr.Button("Segment")
        segment_btn.click(fn=process_image_fastsam, inputs=[image_input_sam], outputs=image_output)

if __name__ == "__main__":
    demo.launch()