import os
from io import BytesIO

import gradio as gr
import requests
from PIL import Image

from lang_sam import SAM_MODELS
from lang_sam.server import PORT, server
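
# Gradio front end for lang-sam: the UI collects an image and a text prompt, posts them
# to the LitServe /predict endpoint, and displays the image returned by the API. The
# interface is mounted on the same server under /gradio (see the bottom of this file),
# so one process serves both the API and the UI.
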
def inference(sam_type, box_threshold, text_threshold, image, text_prompt):
    """Gradio function that makes a request to the /predict LitServe endpoint."""
    url = f"http://localhost:{PORT}/predict"  # Adjust port if needed

    # Prepare the multipart form data
    with open(image, "rb") as img_file:
        files = {
            "image": img_file,
        }
        data = {
            "sam_type": sam_type,
            "box_threshold": str(box_threshold),
            "text_threshold": str(text_threshold),
            "text_prompt": text_prompt,
        }

        try:
            response = requests.post(url, files=files, data=data)
        except Exception as e:
            print(f"Request failed: {e}")
            return None

    if response.status_code == 200:
        try:
            output_image = Image.open(BytesIO(response.content)).convert("RGB")
            return output_image
        except Exception as e:
            print(f"Failed to process response image: {e}")
            return None
    else:
        print(f"Request failed with status code {response.status_code}: {response.text}")
        return None
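
# Build the Gradio interface: a dropdown for the SAM checkpoint, sliders for the box and
# text thresholds, an input/output image pair, and a text prompt box, all wired to the
# inference() function above.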
with gr.Blocks(title="lang-sam") as blocks:
    with gr.Row():
        sam_model_choices = gr.Dropdown(
            choices=list(SAM_MODELS.keys()), label="SAM Model", value="sam2.1_hiera_small"
        )
        box_threshold = gr.Slider(minimum=0.0, maximum=1.0, value=0.3, label="Box Threshold")
        text_threshold = gr.Slider(minimum=0.0, maximum=1.0, value=0.25, label="Text Threshold")

    with gr.Row():
        image_input = gr.Image(type="filepath", label="Input Image")
        output_image = gr.Image(type="pil", label="Output Image")

    text_prompt = gr.Textbox(lines=1, label="Text Prompt")
    submit_btn = gr.Button("Run Prediction")

    submit_btn.click(
        fn=inference,
        inputs=[sam_model_choices, box_threshold, text_threshold, image_input, text_prompt],
        outputs=output_image,
    )
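
    # Pre-filled example inputs for the UI; the paths assume an assets/ directory
    # shipped next to this script.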
    examples = [
        [
            "sam2.1_hiera_small",
            0.32,
            0.25,
            os.path.join(os.path.dirname(__file__), "assets", "fruits.jpg"),
            "kiwi. watermelon. blueberry.",
        ],
        [
            "sam2.1_hiera_small",
            0.3,
            0.25,
            os.path.join(os.path.dirname(__file__), "assets", "car.jpeg"),
            "wheel.",
        ],
        [
            "sam2.1_hiera_small",
            0.3,
            0.25,
            os.path.join(os.path.dirname(__file__), "assets", "food.jpg"),
            "food.",
        ],
    ]

    gr.Examples(
        examples=examples,
        inputs=[sam_model_choices, box_threshold, text_threshold, image_input, text_prompt],
        outputs=output_image,
    )
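
# Mount the Gradio Blocks app onto the LitServe server's underlying FastAPI app so the
# UI is reachable at /gradio from the same process that exposes /predict. (This assumes
# server.app from lang_sam.server is the FastAPI instance used by LitServe.)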
server.app = gr.mount_gradio_app(server.app, blocks, path="/gradio")

if __name__ == "__main__":
    print(f"Starting LitServe and Gradio server on port {PORT}...")
    server.run(port=PORT)
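
# Usage: run this script directly; the Gradio UI is then reachable at
# http://localhost:<PORT>/gradio and the prediction API at http://localhost:<PORT>/predict,
# where PORT is the value imported from lang_sam.server.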