tanthinhdt committed
Commit ceafbcf · verified · 1 Parent(s): f85ca71

Upload app.py

Files changed (1)
  1. app.py +190 -199
app.py CHANGED
@@ -1,199 +1,190 @@
- import streamlit as st
- from time import time
- from PIL import Image
- from transformers import AutoModelForVision2Seq, AutoProcessor
-
-
- def load_model_and_processor() -> None:
-     """
-     Load the model and processor.
-     """
-     st.session_state.model = AutoModelForVision2Seq.from_pretrained(
-         st.session_state.model_id,
-         cache_dir="models/huggingface",
-     )
-     st.session_state.model.eval()
-
-     st.session_state.processor = AutoProcessor.from_pretrained(
-         st.session_state.model_id,
-         cache_dir="models/huggingface",
-     )
-
-
- def to_device() -> None:
-     """
-     Move the model to the selected device.
-     """
-     st.session_state.model.to(st.session_state.device.lower())
-
-
- def scale_image(image: Image.Image, target_height: int = 500) -> Image.Image:
-     """
-     Scale an image to a target height while maintaining the aspect ratio.
-
-     Parameters
-     ----------
-     image : Image.Image
-         The image to scale.
-     target_height : int, optional (default=500)
-         The target height of the image.
-
-     Returns
-     -------
-     Image.Image
-         The scaled image.
-     """
-     width, height = image.size
-     aspect_ratio = width / height
-     target_width = int(aspect_ratio * target_height)
-     return image.resize((target_width, target_height))
-
-
- def upload_image() -> None:
-     """
-     Upload an image.
-     """
-     if st.session_state.file_uploader is not None:
-         st.session_state.image = Image.open(st.session_state.file_uploader)
-
-
- def inference() -> None:
-     """
-     Perform inference on an image and generate a caption.
-     """
-     start_time = time()
-     outputs = st.session_state.processor(
-         images=st.session_state.image,
-         return_tensors="pt",
-     )
-     outputs = {k: v.to(st.session_state.device.lower()) for k, v in outputs.items()}
-     logits = st.session_state.model.generate(
-         **outputs,
-         max_length=st.session_state.max_length,
-         num_beams=st.session_state.num_beams,
-     )
-     caption = st.session_state.processor.decode(
-         logits[0], skip_special_tokens=True
-     )
-     end_time = time()
-     st.session_state.inference_time = round(end_time - start_time, 2)
-     st.session_state.caption = caption
-
-
- def main() -> None:
-     """
-     Main function for the Streamlit app.
-     """
-     if "model" not in st.session_state:
-         st.session_state.model = AutoModelForVision2Seq.from_pretrained(
-             "Salesforce/blip-image-captioning-base",
-             cache_dir="models/huggingface",
-         )
-         st.session_state.model.eval().to("cpu")
-     if "processor" not in st.session_state:
-         st.session_state.processor = AutoProcessor.from_pretrained(
-             "Salesforce/blip-image-captioning-base",
-             cache_dir="models/huggingface",
-         )
-     if "image" not in st.session_state:
-         st.session_state.image = None
-     if "caption" not in st.session_state:
-         st.session_state.caption = None
-     if "inference_time" not in st.session_state:
-         st.session_state.inference_time = 0.0
-
-     # Set page configuration
-     st.set_page_config(
-         page_title="Image Captioning App",
-         page_icon="📸",
-         initial_sidebar_state="expanded",
-     )
-
-     # Set sidebar layout
-     st.sidebar.header("Workspace")
-     st.sidebar.file_uploader(
-         "Upload an image",
-         type=["jpg", "jpeg", "png"],
-         accept_multiple_files=False,
-         on_change=upload_image,
-         key="file_uploader",
-         help="Upload an image to generate a caption.",
-     )
-     st.sidebar.divider()
-     st.sidebar.header("Settings")
-     st.sidebar.selectbox(
-         label="Model ID",
-         options=["Salesforce/blip-image-captioning-base"],
-         index=0,
-         on_change=load_model_and_processor,
-         key="model_id",
-         help="The model to use for image captioning.",
-     )
-     st.sidebar.selectbox(
-         label="Device",
-         options=["CPU", "CUDA"],
-         index=0,
-         on_change=to_device,
-         key="device",
-         help="The device to use for inference.",
-     )
-     st.sidebar.number_input(
-         label="Max length",
-         min_value=32,
-         max_value=128,
-         value=128,
-         step=1,
-         key="max_length",
-         help="The maximum length of the generated caption.",
-     )
-     st.sidebar.number_input(
-         label="Number of beams",
-         min_value=1,
-         max_value=8,
-         value=4,
-         step=1,
-         key="num_beams",
-         help="The number of beams to use during decoding.",
-     )
-
-     # Set main layout
-     st.markdown(
-         """
-         <h1 style='text-align: center;'>
-             Image Captioning
-         </h1>
-         """,
-         unsafe_allow_html=True,
-     )
-     st.divider()
-     image_container = st.container(height=450)
-     st.divider()
-     col_1, col_2, col_3 = st.columns([1, 1, 2])
-     resolution_display = col_1.empty()
-     runtime_display = col_2.empty()
-     caption_display = col_3.empty()
-
-     # Display the image and generate a caption
-     if st.session_state.image is not None:
-         image_container.image(scale_image(st.session_state.image, target_height=400))
-
-         resolution_display.metric(
-             label="Image Resolution",
-             value=f"{st.session_state.image.width}x{st.session_state.image.height}",
-         )
-
-         with st.spinner("Generating caption..."):
-             inference()
-
-         caption_display.text_area(
-             label="Caption",
-             value=st.session_state.caption,
-         )
-         runtime_display.metric(
-             label="Inference Time",
-             value=f"{st.session_state.inference_time}s",
-         )
-
-
- if __name__ == "__main__":
-     main()
 
+ import torch
+ import urllib.request
+ import streamlit as st
+ from io import BytesIO
+ from time import time
+ from PIL import Image
+ from transformers import AutoModelForVision2Seq, AutoProcessor
+
+
+ def scale_image(image: Image.Image, target_height: int = 500) -> Image.Image:
+     """
+     Scale an image to a target height while maintaining the aspect ratio.
+
+     Parameters
+     ----------
+     image : Image.Image
+         The image to scale.
+     target_height : int, optional (default=500)
+         The target height of the image.
+
+     Returns
+     -------
+     Image.Image
+         The scaled image.
+     """
+     width, height = image.size
+     aspect_ratio = width / height
+     target_width = int(aspect_ratio * target_height)
+     return image.resize((target_width, target_height))
+
+
+ def upload_image() -> None:
+     """
+     Upload an image.
+     """
+     if st.session_state.file_uploader is not None:
+         st.session_state.image = Image.open(st.session_state.file_uploader)
+
+
+ def read_image_from_url() -> None:
+     """
+     Read an image from a URL.
+     """
+     if st.session_state.image_url:
+         with urllib.request.urlopen(st.session_state.image_url) as response:
+             st.session_state.image = Image.open(BytesIO(response.read()))
+
+
+ def inference() -> None:
+     """
+     Perform inference on an image and generate a caption.
+     """
+     start_time = time()
+     outputs = st.session_state.processor(
+         images=st.session_state.image,
+         return_tensors="pt",
+     )
+     outputs = {k: v.to(st.session_state.device.lower()) for k, v in outputs.items()}
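+     # Move the model to the selected device only for this generation pass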
+     st.session_state.model.to(st.session_state.device.lower())
+     logits = st.session_state.model.generate(
+         **outputs,
+         max_length=st.session_state.max_length,
+         num_beams=st.session_state.num_beams,
+     )
+     caption = st.session_state.processor.decode(
+         logits[0], skip_special_tokens=True
+     )
+     end_time = time()
+
+     st.session_state.inference_time = round(end_time - start_time, 2)
+     st.session_state.caption = caption
+
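+     # Park the model back on the CPU and release any cached GPU memory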
+     st.session_state.model.to("cpu")
+     torch.cuda.empty_cache()
+
+
+ def main() -> None:
+     """
+     Main function for the Streamlit app.
+     """
+     if "model" not in st.session_state:
+         st.session_state.model = AutoModelForVision2Seq.from_pretrained(
+             "tanthinhdt/blip-base_with-pretrained_flickr30k",
+             cache_dir="models/huggingface",
+         )
+         st.session_state.model.eval()
+     if "processor" not in st.session_state:
+         st.session_state.processor = AutoProcessor.from_pretrained(
+             "Salesforce/blip-image-captioning-base",
+             cache_dir="models/huggingface",
+         )
+     if "image" not in st.session_state:
+         st.session_state.image = None
+     if "caption" not in st.session_state:
+         st.session_state.caption = None
+     if "inference_time" not in st.session_state:
+         st.session_state.inference_time = 0.0
+
+     # Set page configuration
+     st.set_page_config(
+         page_title="Image Captioning App",
+         page_icon="📸",
+         initial_sidebar_state="expanded",
+     )
+
+     # Set sidebar layout
+     st.sidebar.header("Workspace")
+     st.sidebar.file_uploader(
+         "Upload an image",
+         type=["jpg", "jpeg", "png"],
+         accept_multiple_files=False,
+         on_change=upload_image,
+         key="file_uploader",
+         help="Upload an image to generate a caption.",
+     )
+     st.sidebar.text_input(
+         "Image URL",
+         on_change=read_image_from_url,
+         key="image_url",
+         help="Enter the URL of an image to generate a caption.",
+     )
+     st.sidebar.divider()
+     st.sidebar.header("Settings")
+     st.sidebar.selectbox(
+         label="Device",
+         options=["CPU", "CUDA"],
+         index=1,
+         key="device",
+         help="The device to use for inference.",
+     )
+     st.sidebar.number_input(
+         label="Max length",
+         min_value=32,
+         max_value=128,
+         value=64,
+         step=1,
+         key="max_length",
+         help="The maximum length of the generated caption.",
+     )
+     st.sidebar.number_input(
+         label="Number of beams",
+         min_value=1,
+         max_value=10,
+         value=4,
+         step=1,
+         key="num_beams",
+         help="The number of beams to use during decoding.",
+     )
+
+     # Set main layout
+     st.markdown(
+         """
+         <h1 style='text-align: center;'>
+             Image Captioning
+         </h1>
+         """,
+         unsafe_allow_html=True,
+     )
+     st.divider()
+     image_container = st.container(height=450)
+     st.divider()
+     col_1, col_2, col_3 = st.columns([1, 1, 2])
+     resolution_display = col_1.empty()
+     runtime_display = col_2.empty()
+     caption_display = col_3.empty()
+
+     # Display the image and generate a caption
+     if st.session_state.image is not None:
+         image_container.image(scale_image(st.session_state.image, target_height=400))
+
+         resolution_display.metric(
+             label="Image Resolution",
+             value=f"{st.session_state.image.width}x{st.session_state.image.height}",
+         )
+
+         with st.spinner("Generating caption..."):
+             inference()
+
+         caption_display.text_area(
+             label="Caption",
+             value=st.session_state.caption,
+         )
+         runtime_display.metric(
+             label="Inference Time",
+             value=f"{st.session_state.inference_time}s",
+         )
+
+
+ if __name__ == "__main__":
+     main()
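
For quick verification outside Streamlit, the same inference path can be run as a script. A minimal sketch, assuming the checkpoint above loads via `AutoModelForVision2Seq` and is paired with the base BLIP processor exactly as in `app.py`; `example.jpg` is a placeholder path:

```python
import torch
from PIL import Image
from transformers import AutoModelForVision2Seq, AutoProcessor

# Model/processor pairing taken from app.py above; "example.jpg" is a
# hypothetical local image path.
model = AutoModelForVision2Seq.from_pretrained(
    "tanthinhdt/blip-base_with-pretrained_flickr30k"
).eval()
processor = AutoProcessor.from_pretrained("Salesforce/blip-image-captioning-base")

image = Image.open("example.jpg").convert("RGB")
inputs = processor(images=image, return_tensors="pt")
with torch.no_grad():
    output_ids = model.generate(**inputs, max_length=64, num_beams=4)
print(processor.decode(output_ids[0], skip_special_tokens=True))
```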