# ashhadahsan's picture
# Upload 21 files
# 8e5dadf verified
from io import BytesIO
import litserve as ls
import numpy as np
from fastapi import Response, UploadFile
from PIL import Image
from lang_sam import LangSAM
from lang_sam.utils import draw_image
# TCP port handed to `server.run(...)` when this module is executed as a script.
PORT: int = 8000
class LangSAMAPI(ls.LitAPI):
    """LitServe API that runs text-prompted segmentation with LangSAM.

    Request flow: ``decode_request`` turns the form fields into a plain dict,
    ``predict`` runs the model and draws the detections on the image, and
    ``encode_response`` serializes the resulting image as PNG.
    """

    def setup(self, device: str) -> None:
        """Initialize or load the LangSAM model.

        Args:
            device: Accepted to satisfy the ``setup`` hook signature; it is not
                used by this implementation.
        """
        self.model = LangSAM(sam_type="sam2.1_hiera_small")
        print("LangSAM model initialized.")

    def decode_request(self, request) -> dict:
        """Decode the incoming request to extract parameters and image bytes.

        Assumes the request is sent as multipart/form-data with fields:
        - sam_type: str
        - box_threshold: float
        - text_threshold: float
        - text_prompt: str
        - image: UploadFile

        Returns:
            dict: ``sam_type`` (``None`` when the field is absent),
            ``box_threshold``, ``text_threshold``, ``image_bytes`` and
            ``text_prompt``.

        Raises:
            ValueError: If no image is provided, or a threshold cannot be
                converted to ``float``.
        """
        # Extract form data
        sam_type = request.get("sam_type")
        box_threshold = float(request.get("box_threshold", 0.3))
        text_threshold = float(request.get("text_threshold", 0.25))
        text_prompt = request.get("text_prompt", "")

        # Extract image file
        image_file: UploadFile = request.get("image")
        if image_file is None:
            raise ValueError("No image file provided in the request.")
        image_bytes = image_file.file.read()

        return {
            "sam_type": sam_type,
            "box_threshold": box_threshold,
            "text_threshold": text_threshold,
            "image_bytes": image_bytes,
            "text_prompt": text_prompt,
        }

    def predict(self, inputs: dict) -> dict:
        """Perform prediction using the LangSAM model.

        Args:
            inputs: The dict produced by ``decode_request``.

        Returns:
            dict: ``{"output_image": PIL.Image}`` -- the image with detections
            drawn on it, or the unmodified input image when no masks are found.

        Raises:
            ValueError: If the image bytes cannot be decoded.
        """
        sam_type = inputs["sam_type"]

        print("Starting prediction with parameters:")
        # Implicit string concatenation (rather than backslash continuations
        # inside the literal) keeps source indentation out of the message.
        print(
            f"sam_type: {sam_type}, "
            f"box_threshold: {inputs['box_threshold']}, "
            f"text_threshold: {inputs['text_threshold']}, "
            f"text_prompt: {inputs['text_prompt']}"
        )

        # A missing/empty sam_type means "keep the currently loaded model";
        # only rebuild when a different type is explicitly requested.
        if sam_type and sam_type != self.model.sam_type:
            print(f"Updating SAM model type to {sam_type}")
            self.model.sam.build_model(sam_type)
            # Record the loaded type so later requests compare against the
            # model that is actually in memory, instead of rebuilding every
            # time (or silently using the wrong model).
            self.model.sam_type = sam_type

        try:
            image_pil = Image.open(BytesIO(inputs["image_bytes"])).convert("RGB")
        except Exception as e:
            raise ValueError(f"Invalid image data: {e}") from e

        # The model API is batched; a single-image batch is sent and the
        # first (only) result is taken.
        results = self.model.predict(
            images_pil=[image_pil],
            texts_prompt=[inputs["text_prompt"]],
            box_threshold=inputs["box_threshold"],
            text_threshold=inputs["text_threshold"],
        )[0]

        # len() instead of truthiness: presumably masks may be an array, for
        # which bool() is ambiguous -- TODO confirm against lang_sam.
        if len(results["masks"]) == 0:
            print("No masks detected. Returning original image.")
            return {"output_image": image_pil}

        # Draw results on the image
        image_array = np.asarray(image_pil)
        output_image = draw_image(
            image_array,
            results["masks"],
            results["boxes"],
            results["scores"],
            results["labels"],
        )
        output_image = Image.fromarray(np.uint8(output_image)).convert("RGB")
        return {"output_image": output_image}

    def encode_response(self, output: dict) -> Response:
        """Encode the prediction result into an HTTP response.

        Args:
            output: The dict returned by ``predict``.

        Returns:
            Response: Contains the processed image in PNG format.

        Raises:
            ValueError: If ``output`` carries no ``output_image``.
        """
        image = output.get("output_image")
        if image is None:
            raise ValueError("No output generated by the prediction.")

        buffer = BytesIO()
        image.save(buffer, format="PNG")
        # getvalue() returns the whole buffer regardless of stream position.
        return Response(content=buffer.getvalue(), media_type="image/png")
# API instance and the LitServe server wrapping it, created at import time.
lit_api = LangSAMAPI()
server = ls.LitServer(lit_api)


def _serve() -> None:
    """Print the startup banner and run the server on ``PORT``."""
    print(f"Starting LitServe and Gradio server on port {PORT}...")
    server.run(port=PORT)


if __name__ == "__main__":
    _serve()