tanthinhdt commited on
Commit
70bbc08
·
verified ·
1 Parent(s): 370f47d

feat: add app

Browse files
Files changed (1) hide show
  1. app.py +199 -0
app.py ADDED
@@ -0,0 +1,199 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ from time import time
3
+ from PIL import Image
4
+ from transformers import AutoModelForVision2Seq, AutoProcessor
5
+
6
+
7
def load_model_and_processor() -> None:
    """
    Load the model and processor for the currently selected model ID.

    Reads ``st.session_state.model_id`` (set by the sidebar selectbox) and
    replaces ``st.session_state.model`` / ``st.session_state.processor``.
    Weights are cached under ``models/huggingface``.

    Bug fix: the original left a freshly loaded model on CPU even when the
    user had already selected CUDA, while ``inference`` moves the input
    tensors to the selected device — a device mismatch at generate time.
    The new model is now moved to the currently selected device.
    """
    st.session_state.model = AutoModelForVision2Seq.from_pretrained(
        st.session_state.model_id,
        cache_dir="models/huggingface",
    )
    st.session_state.model.eval()
    # Keep the model on the device the user picked; the key may not exist
    # yet on the very first run, before the device selectbox is declared.
    if "device" in st.session_state:
        st.session_state.model.to(st.session_state.device.lower())

    st.session_state.processor = AutoProcessor.from_pretrained(
        st.session_state.model_id,
        cache_dir="models/huggingface",
    )
21
+
22
+
23
def to_device() -> None:
    """
    Transfer the current model onto the device chosen in the sidebar.
    """
    target = st.session_state.device.lower()
    st.session_state.model.to(target)
28
+
29
+
30
def scale_image(image: Image.Image, target_height: int = 500) -> Image.Image:
    """
    Resize an image to a given height, keeping its aspect ratio.

    Parameters
    ----------
    image : Image.Image
        The image to resize.
    target_height : int, optional (default=500)
        Height, in pixels, of the returned image.

    Returns
    -------
    Image.Image
        The resized image; its width is the height scaled by the
        original aspect ratio (truncated to an int).
    """
    original_width, original_height = image.size
    ratio = original_width / original_height
    scaled_width = int(ratio * target_height)
    return image.resize((scaled_width, target_height))
50
+
51
+
52
def upload_image() -> None:
    """
    Read the uploaded file (if any) into ``st.session_state.image``.
    """
    uploaded = st.session_state.file_uploader
    if uploaded is None:
        return
    st.session_state.image = Image.open(uploaded)
58
+
59
+
60
def inference() -> None:
    """
    Generate a caption for the current image and record the runtime.

    Reads the image, model, processor, device, and decoding settings from
    ``st.session_state``; writes the decoded caption to
    ``st.session_state.caption`` and the elapsed wall-clock time (seconds,
    rounded to two decimals) to ``st.session_state.inference_time``.
    """
    started = time()

    # Preprocess the image into model inputs, then move every tensor to
    # the device selected in the sidebar.
    inputs = st.session_state.processor(
        images=st.session_state.image,
        return_tensors="pt",
    )
    device = st.session_state.device.lower()
    inputs = {name: tensor.to(device) for name, tensor in inputs.items()}

    # Beam-search decode with the limits configured in the sidebar.
    generated = st.session_state.model.generate(
        **inputs,
        max_length=st.session_state.max_length,
        num_beams=st.session_state.num_beams,
    )
    text = st.session_state.processor.decode(
        generated[0], skip_special_tokens=True
    )

    st.session_state.inference_time = round(time() - started, 2)
    st.session_state.caption = text
81
+
82
+
83
def main() -> None:
    """
    Main function for the Streamlit app.

    Seeds session state with a default model/processor on first run, builds
    the sidebar (file upload + settings) and the main layout, then displays
    the current image and its generated caption.  Streamlit re-runs this
    function on every widget interaction; the ``st.session_state`` guards
    below keep the expensive model load from repeating on each rerun.
    """
    # One-time initialization: only load the default model/processor if a
    # previous rerun has not already stored them in session state.
    if "model" not in st.session_state:
        st.session_state.model = AutoModelForVision2Seq.from_pretrained(
            "Salesforce/blip-image-captioning-base",
            cache_dir="models/huggingface",
        )
        # Inference-only app: eval mode, starting on CPU (matches the
        # device selectbox default below).
        st.session_state.model.eval().to("cpu")
    if "processor" not in st.session_state:
        st.session_state.processor = AutoProcessor.from_pretrained(
            "Salesforce/blip-image-captioning-base",
            cache_dir="models/huggingface",
        )
    if "image" not in st.session_state:
        st.session_state.image = None
    if "caption" not in st.session_state:
        st.session_state.caption = None
    if "inference_time" not in st.session_state:
        st.session_state.inference_time = 0.0

    # Set page configuration
    st.set_page_config(
        page_title="Image Captioning App",
        page_icon="📸",
        initial_sidebar_state="expanded",
    )

    # Set sidebar layout.  Each widget stores its value in session state
    # under its ``key`` and fires its ``on_change`` callback when edited.
    st.sidebar.header("Workspace")
    st.sidebar.file_uploader(
        "Upload an image",
        type=["jpg", "jpeg", "png"],
        accept_multiple_files=False,
        on_change=upload_image,
        key="file_uploader",
        help="Upload an image to generate a caption.",
    )
    st.sidebar.divider()
    st.sidebar.header("Settings")
    st.sidebar.selectbox(
        label="Model ID",
        options=["Salesforce/blip-image-captioning-base"],
        index=0,
        on_change=load_model_and_processor,
        key="model_id",
        help="The model to use for image captioning.",
    )
    st.sidebar.selectbox(
        label="Device",
        options=["CPU", "CUDA"],
        index=0,
        on_change=to_device,
        key="device",
        help="The device to use for inference.",
    )
    st.sidebar.number_input(
        label="Max length",
        min_value=32,
        max_value=128,
        value=128,
        step=1,
        key="max_length",
        help="The maximum length of the generated caption.",
    )
    st.sidebar.number_input(
        label="Number of beams",
        min_value=1,
        max_value=8,
        value=4,
        step=1,
        key="num_beams",
        help="The number of beams to use during decoding.",
    )

    # Set main layout: centered title, image area, then three columns for
    # resolution, runtime, and the caption text.
    st.markdown(
        """
        <h1 style='text-align: center;'>
            Image Captioning
        </h1>
        """,
        unsafe_allow_html=True,
    )
    st.divider()
    image_container = st.container(height=450)
    st.divider()
    col_1, col_2, col_3 = st.columns([1, 1, 2])
    resolution_display = col_1.empty()
    runtime_display = col_2.empty()
    caption_display = col_3.empty()

    # Display the image and generate a caption.  NOTE(review): this runs
    # inference on every rerun while an image is present (e.g. after any
    # settings change), not only when the image changes.
    if st.session_state.image is not None:
        image_container.image(scale_image(st.session_state.image, target_height=400))

        resolution_display.metric(
            label="Image Resolution",
            value=f"{st.session_state.image.width}x{st.session_state.image.height}",
        )

        with st.spinner("Generating caption..."):
            inference()

        caption_display.text_area(
            label="Caption",
            value=st.session_state.caption,
        )
        runtime_display.metric(
            label="Inference Time",
            value=f"{st.session_state.inference_time}s",
        )
196
+
197
+
198
# Script entry point: run the Streamlit app.
if __name__ == "__main__":
    main()