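"""Gradio front end for lang-sam.

Mounts a Gradio UI on the LitServe server defined in ``lang_sam.server`` and
forwards predictions to its ``/predict`` endpoint.
"""
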
import os
from io import BytesIO

import gradio as gr
import requests
from PIL import Image

from lang_sam import SAM_MODELS
from lang_sam.server import PORT, server


def inference(sam_type, box_threshold, text_threshold, image, text_prompt):
    """Gradio function that makes a request to the /predict LitServe endpoint."""
    url = f"http://localhost:{PORT}/predict"  # Adjust port if needed

    # Prepare the multipart form data
    with open(image, "rb") as img_file:
        files = {
            "image": img_file,
        }
        data = {
            "sam_type": sam_type,
            "box_threshold": str(box_threshold),
            "text_threshold": str(text_threshold),
            "text_prompt": text_prompt,
        }

        try:
            response = requests.post(url, files=files, data=data)
        except Exception as e:
            print(f"Request failed: {e}")
            return None

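    # A 200 response carries the annotated image as raw bytes.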
    if response.status_code == 200:
        try:
            output_image = Image.open(BytesIO(response.content)).convert("RGB")
            return output_image
        except Exception as e:
            print(f"Failed to process response image: {e}")
            return None
    else:
        print(f"Request failed with status code {response.status_code}: {response.text}")
        return None


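# Build the Gradio UI: model and threshold controls, image input/output, a text prompt, and example inputs.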
with gr.Blocks(title="lang-sam") as blocks:
    with gr.Row():
        sam_model_choices = gr.Dropdown(choices=list(SAM_MODELS.keys()), label="SAM Model", value="sam2.1_hiera_small")
        box_threshold = gr.Slider(minimum=0.0, maximum=1.0, value=0.3, label="Box Threshold")
        text_threshold = gr.Slider(minimum=0.0, maximum=1.0, value=0.25, label="Text Threshold")
    with gr.Row():
        image_input = gr.Image(type="filepath", label="Input Image")
        output_image = gr.Image(type="pil", label="Output Image")
    text_prompt = gr.Textbox(lines=1, label="Text Prompt")

    submit_btn = gr.Button("Run Prediction")

    submit_btn.click(
        fn=inference,
        inputs=[sam_model_choices, box_threshold, text_threshold, image_input, text_prompt],
        outputs=output_image,
    )

    examples = [
        [
            "sam2.1_hiera_small",
            0.32,
            0.25,
            os.path.join(os.path.dirname(__file__), "assets", "fruits.jpg"),
            "kiwi. watermelon. blueberry.",
        ],
        [
            "sam2.1_hiera_small",
            0.3,
            0.25,
            os.path.join(os.path.dirname(__file__), "assets", "car.jpeg"),
            "wheel.",
        ],
        [
            "sam2.1_hiera_small",
            0.3,
            0.25,
            os.path.join(os.path.dirname(__file__), "assets", "food.jpg"),
            "food.",
        ],
    ]

    gr.Examples(
        examples=examples,
        inputs=[sam_model_choices, box_threshold, text_threshold, image_input, text_prompt],
        outputs=output_image,
    )

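# Mount the Gradio UI at /gradio on the LitServe app so both are served from the same port.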
server.app = gr.mount_gradio_app(server.app, blocks, path="/gradio")

if __name__ == "__main__":
    print(f"Starting LitServe and Gradio server on port {PORT}...")
    server.run(port=PORT)