ashhadahsan's picture
Upload 21 files
8e5dadf verified
raw
history blame
3.09 kB
import os
from io import BytesIO
import gradio as gr
import requests
from PIL import Image
from lang_sam import SAM_MODELS
from lang_sam.server import PORT, server
def inference(sam_type, box_threshold, text_threshold, image, text_prompt):
    """Send one prediction request to the /predict LitServe endpoint.

    The image file and the remaining arguments are posted as multipart form
    data to ``http://localhost:{PORT}/predict``.

    Args:
        sam_type: Value sent as the "sam_type" form field.
        box_threshold: Sent as the "box_threshold" form field, converted
            with ``str()``.
        text_threshold: Sent as the "text_threshold" form field, converted
            with ``str()``.
        image: Filesystem path of the image to upload. May be ``None`` or
            empty when the callback runs without an input image.
        text_prompt: Value sent as the "text_prompt" form field.

    Returns:
        The response body decoded as an RGB ``PIL.Image.Image``, or ``None``
        when no image was given, the image file cannot be opened, the request
        fails, the server answers with a non-200 status, or the response body
        cannot be decoded as an image. These failures are printed rather than
        raised.
    """
    # Without this guard open(None, "rb") raises TypeError, which would escape
    # the function instead of following its "print and return None" contract.
    if not image:
        print("No input image provided.")
        return None

    # Open the upload up front so a missing/unreadable path is reported the
    # same way as every other failure in this function.
    try:
        img_file = open(image, "rb")
    except OSError as e:
        print(f"Failed to open input image: {e}")
        return None

    url = f"http://localhost:{PORT}/predict"  # Adjust port if needed

    # Prepare the multipart form data; thresholds travel as strings.
    data = {
        "sam_type": sam_type,
        "box_threshold": str(box_threshold),
        "text_threshold": str(text_threshold),
        "text_prompt": text_prompt,
    }

    # The file only needs to stay open for the duration of the upload.
    with img_file:
        files = {
            "image": img_file,
        }
        try:
            response = requests.post(url, files=files, data=data)
        except Exception as e:  # UI callback boundary: report, do not crash
            print(f"Request failed: {e}")
            return None

    if response.status_code != 200:
        print(f"Request failed with status code {response.status_code}: {response.text}")
        return None

    try:
        return Image.open(BytesIO(response.content)).convert("RGB")
    except Exception as e:  # UI callback boundary: report, do not crash
        print(f"Failed to process response image: {e}")
        return None
# Gradio front end: model/threshold controls, input and output images, a text
# prompt, a run button wired to inference(), and a table of ready-made examples.
# NOTE(review): the reviewed copy of this file had lost its indentation; the
# row grouping below is reconstructed from statement order -- confirm layout.
with gr.Blocks(title="lang-sam") as blocks:
    # First row: model selector and the two thresholds.
    with gr.Row():
        sam_model_choices = gr.Dropdown(
            label="SAM Model",
            choices=list(SAM_MODELS.keys()),
            value="sam2.1_hiera_small",
        )
        box_threshold = gr.Slider(
            label="Box Threshold",
            minimum=0.0,
            maximum=1.0,
            value=0.3,
        )
        text_threshold = gr.Slider(
            label="Text Threshold",
            minimum=0.0,
            maximum=1.0,
            value=0.25,
        )

    # Second row: input image (handed to the callback as a file path) next to
    # the output image (a PIL image returned by the callback).
    with gr.Row():
        image_input = gr.Image(label="Input Image", type="filepath")
        output_image = gr.Image(label="Output Image", type="pil")

    text_prompt = gr.Textbox(label="Text Prompt", lines=1)
    submit_btn = gr.Button("Run Prediction")

    # The same components, in inference()'s argument order, feed both the
    # button click and the examples table.
    prediction_inputs = [
        sam_model_choices,
        box_threshold,
        text_threshold,
        image_input,
        text_prompt,
    ]

    submit_btn.click(fn=inference, inputs=prediction_inputs, outputs=output_image)

    # Each example row lines up with prediction_inputs; image files are
    # resolved inside the "assets" directory next to this file.
    examples = [
        [
            model_name,
            box_value,
            text_value,
            os.path.join(os.path.dirname(__file__), "assets", asset_name),
            prompt,
        ]
        for model_name, box_value, text_value, asset_name, prompt in (
            ("sam2.1_hiera_small", 0.32, 0.25, "fruits.jpg", "kiwi. watermelon. blueberry."),
            ("sam2.1_hiera_small", 0.3, 0.25, "car.jpeg", "wheel."),
            ("sam2.1_hiera_small", 0.3, 0.25, "food.jpg", "food."),
        )
    ]

    gr.Examples(examples=examples, inputs=prediction_inputs, outputs=output_image)
# Attach the Gradio UI to the server's existing app under the "/gradio" path;
# the result replaces server.app so it is what server.run() serves.
server.app = gr.mount_gradio_app(server.app, blocks, path="/gradio")

if __name__ == "__main__":
    # Script entry point: run the server on PORT, the same port that
    # inference() targets for its /predict requests.
    print(f"Starting LitServe and Gradio server on port {PORT}...")
    server.run(port=PORT)