Spaces:
Runtime error
Runtime error
from io import BytesIO | |
import litserve as ls | |
import numpy as np | |
from fastapi import Response, UploadFile | |
from PIL import Image | |
from lang_sam import LangSAM | |
from lang_sam.utils import draw_image | |
PORT = 8000 | |
class LangSAMAPI(ls.LitAPI): | |
def setup(self, device: str) -> None: | |
"""Initialize or load the LangSAM model.""" | |
self.model = LangSAM(sam_type="sam2.1_hiera_small") | |
print("LangSAM model initialized.") | |
def decode_request(self, request) -> dict: | |
"""Decode the incoming request to extract parameters and image bytes. | |
Assumes the request is sent as multipart/form-data with fields: | |
- sam_type: str | |
- box_threshold: float | |
- text_threshold: float | |
- text_prompt: str | |
- image: UploadFile | |
""" | |
# Extract form data | |
sam_type = request.get("sam_type") | |
box_threshold = float(request.get("box_threshold", 0.3)) | |
text_threshold = float(request.get("text_threshold", 0.25)) | |
text_prompt = request.get("text_prompt", "") | |
# Extract image file | |
image_file: UploadFile = request.get("image") | |
if image_file is None: | |
raise ValueError("No image file provided in the request.") | |
image_bytes = image_file.file.read() | |
return { | |
"sam_type": sam_type, | |
"box_threshold": box_threshold, | |
"text_threshold": text_threshold, | |
"image_bytes": image_bytes, | |
"text_prompt": text_prompt, | |
} | |
def predict(self, inputs: dict) -> dict: | |
"""Perform prediction using the LangSAM model. | |
Yields: | |
dict: Contains the processed output image. | |
""" | |
print("Starting prediction with parameters:") | |
print( | |
f"sam_type: {inputs['sam_type']}, \ | |
box_threshold: {inputs['box_threshold']}, \ | |
text_threshold: {inputs['text_threshold']}, \ | |
text_prompt: {inputs['text_prompt']}" | |
) | |
if inputs["sam_type"] != self.model.sam_type: | |
print(f"Updating SAM model type to {inputs['sam_type']}") | |
self.model.sam.build_model(inputs["sam_type"]) | |
try: | |
image_pil = Image.open(BytesIO(inputs["image_bytes"])).convert("RGB") | |
except Exception as e: | |
raise ValueError(f"Invalid image data: {e}") | |
results = self.model.predict( | |
images_pil=[image_pil], | |
texts_prompt=[inputs["text_prompt"]], | |
box_threshold=inputs["box_threshold"], | |
text_threshold=inputs["text_threshold"], | |
) | |
results = results[0] | |
if not len(results["masks"]): | |
print("No masks detected. Returning original image.") | |
return {"output_image": image_pil} | |
# Draw results on the image | |
image_array = np.asarray(image_pil) | |
output_image = draw_image( | |
image_array, | |
results["masks"], | |
results["boxes"], | |
results["scores"], | |
results["labels"], | |
) | |
output_image = Image.fromarray(np.uint8(output_image)).convert("RGB") | |
return {"output_image": output_image} | |
def encode_response(self, output: dict) -> Response: | |
"""Encode the prediction result into an HTTP response. | |
Returns: | |
Response: Contains the processed image in PNG format. | |
""" | |
try: | |
image = output["output_image"] | |
buffer = BytesIO() | |
image.save(buffer, format="PNG") | |
buffer.seek(0) | |
return Response(content=buffer.getvalue(), media_type="image/png") | |
except StopIteration: | |
raise ValueError("No output generated by the prediction.") | |
lit_api = LangSAMAPI() | |
server = ls.LitServer(lit_api) | |
if __name__ == "__main__": | |
print(f"Starting LitServe and Gradio server on port {PORT}...") | |
server.run(port=PORT) | |