ashhadahsan's picture
Upload 21 files
8e5dadf verified
raw
history blame
3.92 kB
from io import BytesIO
import litserve as ls
import numpy as np
from fastapi import Response, UploadFile
from PIL import Image
from lang_sam import LangSAM
from lang_sam.utils import draw_image
# TCP port passed to ``server.run`` when this module is executed as a script.
PORT = 8000
class LangSAMAPI(ls.LitAPI):
    """LitServe API that runs LangSAM on an uploaded image and a text prompt.

    The request is decoded into plain parameters plus raw image bytes, the
    LangSAM model predicts masks/boxes for the prompt, and the annotated
    image is sent back as a PNG.
    """

    def setup(self, device: str) -> None:
        """Initialize or load the LangSAM model.

        Args:
            device: Device string supplied by the server.
                NOTE(review): it is not forwarded to ``LangSAM`` here;
                presumably the model selects its own default device - confirm.
        """
        self.model = LangSAM(sam_type="sam2.1_hiera_small")
        print("LangSAM model initialized.")

    def decode_request(self, request) -> dict:
        """Decode the incoming request to extract parameters and image bytes.

        Assumes the request is sent as multipart/form-data with fields:
        - sam_type: str
        - box_threshold: float
        - text_threshold: float
        - text_prompt: str
        - image: UploadFile

        Returns:
            dict: Keys ``sam_type``, ``box_threshold``, ``text_threshold``,
            ``image_bytes`` and ``text_prompt``. ``sam_type`` is ``None`` when
            the field is absent; the thresholds default to 0.3 and 0.25 and
            the prompt defaults to an empty string.

        Raises:
            ValueError: If no ``image`` field is present, or a threshold
                cannot be converted to ``float``.
        """
        # Extract form data
        sam_type = request.get("sam_type")
        box_threshold = float(request.get("box_threshold", 0.3))
        text_threshold = float(request.get("text_threshold", 0.25))
        text_prompt = request.get("text_prompt", "")

        # Extract image file
        image_file: UploadFile = request.get("image")
        if image_file is None:
            raise ValueError("No image file provided in the request.")
        image_bytes = image_file.file.read()

        return {
            "sam_type": sam_type,
            "box_threshold": box_threshold,
            "text_threshold": text_threshold,
            "image_bytes": image_bytes,
            "text_prompt": text_prompt,
        }

    def predict(self, inputs: dict) -> dict:
        """Perform prediction using the LangSAM model.

        Args:
            inputs: Dictionary produced by ``decode_request``.

        Returns:
            dict: ``{"output_image": image}`` where ``image`` is the RGB PIL
            image with the results drawn on it, or the unmodified RGB input
            image when no masks were detected.

        Raises:
            ValueError: If the image bytes cannot be opened as an image.
        """
        print("Starting prediction with parameters:")
        # Adjacent f-strings keep the log line free of source indentation,
        # which a backslash continuation inside the literal would embed.
        print(
            f"sam_type: {inputs['sam_type']}, "
            f"box_threshold: {inputs['box_threshold']}, "
            f"text_threshold: {inputs['text_threshold']}, "
            f"text_prompt: {inputs['text_prompt']}"
        )

        requested_type = inputs["sam_type"]
        # A request without a sam_type field decodes to None; keep the model
        # that is already loaded instead of trying to build a "None" model.
        if requested_type and requested_type != self.model.sam_type:
            print(f"Updating SAM model type to {requested_type}")
            self.model.sam.build_model(requested_type)
            # Remember what is now loaded so later requests are compared
            # against the current model rather than the one from setup().
            self.model.sam_type = requested_type

        try:
            image_pil = Image.open(BytesIO(inputs["image_bytes"])).convert("RGB")
        except Exception as e:
            # Chain the cause so the underlying decoding error stays visible.
            raise ValueError(f"Invalid image data: {e}") from e

        results = self.model.predict(
            images_pil=[image_pil],
            texts_prompt=[inputs["text_prompt"]],
            box_threshold=inputs["box_threshold"],
            text_threshold=inputs["text_threshold"],
        )
        # One image was submitted, so only the first result entry is used.
        results = results[0]

        # len() is used rather than truthiness because the masks container
        # may be an array type whose truth value is ambiguous.
        if not len(results["masks"]):
            print("No masks detected. Returning original image.")
            return {"output_image": image_pil}

        # Draw results on the image
        image_array = np.asarray(image_pil)
        output_image = draw_image(
            image_array,
            results["masks"],
            results["boxes"],
            results["scores"],
            results["labels"],
        )
        output_image = Image.fromarray(np.uint8(output_image)).convert("RGB")
        return {"output_image": output_image}

    def encode_response(self, output: dict) -> Response:
        """Encode the prediction result into an HTTP response.

        Args:
            output: Dictionary returned by ``predict``; must contain
                ``"output_image"``.

        Returns:
            Response: Contains the processed image in PNG format.

        Raises:
            ValueError: If ``output`` has no ``"output_image"`` entry.
        """
        try:
            image = output["output_image"]
        except KeyError:
            # The lookup above is the only step that signals "no output".
            raise ValueError("No output generated by the prediction.") from None

        buffer = BytesIO()
        image.save(buffer, format="PNG")
        # getvalue() returns the whole buffer regardless of stream position.
        return Response(content=buffer.getvalue(), media_type="image/png")
# The API object and its server are created at import time, outside the
# ``__main__`` guard below.
lit_api = LangSAMAPI()
server = ls.LitServer(lit_api)


if __name__ == "__main__":
    # Only start serving when the file is executed directly.
    print(f"Starting LitServe and Gradio server on port {PORT}...")
    server.run(port=PORT)