import streamlit as st
import torch
from time import time
from PIL import Image
from transformers import AutoModelForVision2Seq, AutoProcessor


def load_model_and_processor() -> None:
    """
    Load the model and processor for the selected model ID.
    """
    st.session_state.model = AutoModelForVision2Seq.from_pretrained(
        st.session_state.model_id,
        cache_dir="models/huggingface",
    )
    st.session_state.model.eval()
    # Move the freshly loaded model to the device currently selected in the
    # sidebar, so switching models does not silently reset it to CPU.
    st.session_state.model.to(st.session_state.device.lower())
    st.session_state.processor = AutoProcessor.from_pretrained(
        st.session_state.model_id,
        cache_dir="models/huggingface",
    )

def to_device() -> None:
    """
    Move the model to the selected device ("cpu" or "cuda").
    """
    st.session_state.model.to(st.session_state.device.lower())

def scale_image(image: Image.Image, target_height: int = 500) -> Image.Image:
    """
    Scale an image to a target height while maintaining the aspect ratio.

    Parameters
    ----------
    image : Image.Image
        The image to scale.
    target_height : int, optional (default=500)
        The target height of the image.

    Returns
    -------
    Image.Image
        The scaled image.
    """
    width, height = image.size
    aspect_ratio = width / height
    target_width = int(aspect_ratio * target_height)
    return image.resize((target_width, target_height))

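# Illustrative arithmetic (not from the original file): a 1920x1080 upload
# has aspect ratio 1920 / 1080 ≈ 1.778, so scaling it to target_height=400
# yields a 711x400 image, since int(1.778 * 400) == 711.
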
def upload_image() -> None:
    """
    Load the uploaded file into session state, or clear it on removal.
    """
    if st.session_state.file_uploader is not None:
        st.session_state.image = Image.open(st.session_state.file_uploader)
    else:
        # The uploader was cleared, so drop the stale image as well.
        st.session_state.image = None

def inference() -> None:
    """
    Perform inference on an image and generate a caption.
    """
    start_time = time()
    inputs = st.session_state.processor(
        images=st.session_state.image,
        return_tensors="pt",
    )
    inputs = {k: v.to(st.session_state.device.lower()) for k, v in inputs.items()}
    # generate() returns token IDs rather than logits; disable gradient
    # tracking since this is inference only.
    with torch.no_grad():
        output_ids = st.session_state.model.generate(
            **inputs,
            max_length=st.session_state.max_length,
            num_beams=st.session_state.num_beams,
        )
    caption = st.session_state.processor.decode(
        output_ids[0], skip_special_tokens=True
    )
    end_time = time()
    st.session_state.inference_time = round(end_time - start_time, 2)
    st.session_state.caption = caption

def main() -> None:
    """
    Main function for the Streamlit app.
    """
    if "model" not in st.session_state:
        st.session_state.model = AutoModelForVision2Seq.from_pretrained(
            "Salesforce/blip-image-captioning-base",
            cache_dir="models/huggingface",
        )
        st.session_state.model.eval().to("cpu")
    if "processor" not in st.session_state:
        st.session_state.processor = AutoProcessor.from_pretrained(
            "Salesforce/blip-image-captioning-base",
            cache_dir="models/huggingface",
        )
    if "image" not in st.session_state:
        st.session_state.image = None
    if "caption" not in st.session_state:
        st.session_state.caption = None
    if "inference_time" not in st.session_state:
        st.session_state.inference_time = 0.0
    # Set page configuration
    st.set_page_config(
        page_title="Image Captioning App",
        page_icon="📸",
        initial_sidebar_state="expanded",
    )
    # Set sidebar layout
    st.sidebar.header("Workspace")
    st.sidebar.file_uploader(
        "Upload an image",
        type=["jpg", "jpeg", "png"],
        accept_multiple_files=False,
        on_change=upload_image,
        key="file_uploader",
        help="Upload an image to generate a caption.",
    )
    st.sidebar.divider()
    st.sidebar.header("Settings")
    st.sidebar.selectbox(
        label="Model ID",
        options=["Salesforce/blip-image-captioning-base"],
        index=0,
        on_change=load_model_and_processor,
        key="model_id",
        help="The model to use for image captioning.",
    )
    st.sidebar.selectbox(
        label="Device",
        options=["CPU", "CUDA"],
        index=0,
        on_change=to_device,
        key="device",
        help="The device to use for inference.",
    )
    st.sidebar.number_input(
        label="Max length",
        min_value=32,
        max_value=128,
        value=128,
        step=1,
        key="max_length",
        help="The maximum length of the generated caption.",
    )
    st.sidebar.number_input(
        label="Number of beams",
        min_value=1,
        max_value=8,
        value=4,
        step=1,
        key="num_beams",
        help="The number of beams to use during decoding.",
    )
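    # Note: max_length is measured in generated tokens, not characters, and
    # num_beams values above 1 enable beam search, which is slower but
    # usually produces more fluent captions than greedy decoding.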
    # Set main layout
    st.markdown(
        """
        <h1 style='text-align: center;'>
            Image Captioning
        </h1>
        """,
        unsafe_allow_html=True,
    )
    st.divider()
    image_container = st.container(height=450)
    st.divider()
    col_1, col_2, col_3 = st.columns([1, 1, 2])
    resolution_display = col_1.empty()
    runtime_display = col_2.empty()
    caption_display = col_3.empty()
    # Display the image and generate a caption
    if st.session_state.image is not None:
        image_container.image(scale_image(st.session_state.image, target_height=400))
        resolution_display.metric(
            label="Image Resolution",
            value=f"{st.session_state.image.width}x{st.session_state.image.height}",
        )
        with st.spinner("Generating caption..."):
            inference()
        caption_display.text_area(
            label="Caption",
            value=st.session_state.caption,
        )
        runtime_display.metric(
            label="Inference Time",
            value=f"{st.session_state.inference_time}s",
        )
if __name__ == "__main__":
    main()
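# To try the app locally (standard Streamlit workflow; the exact dependency
# list is an assumption, as this file pins nothing):
#
#   pip install streamlit torch transformers pillow
#   streamlit run app.py
#
# The "CUDA" device option only works with a CUDA-enabled torch build and a
# visible GPU; otherwise model.to("cuda") raises a runtime error.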