BLIP2 / app.py
#!/usr/bin/env python
import os
import string
import gradio as gr
import PIL.Image
import spaces
import torch
from transformers import AutoProcessor, BitsAndBytesConfig, Blip2ForConditionalGeneration
# ์Šคํƒ€์ผ ์ƒ์ˆ˜ ์ •์˜
CUSTOM_CSS = """
.container {
max-width: 1000px;
margin: auto;
padding: 2rem;
background: linear-gradient(to bottom right, #ffffff, #f8f9fa);
border-radius: 15px;
box-shadow: 0 4px 6px rgba(0, 0, 0, 0.1);
}
.title {
font-size: 2.5rem;
color: #1a73e8;
text-align: center;
margin-bottom: 2rem;
font-weight: bold;
}
.tab-nav {
background: #f8f9fa;
border-radius: 10px;
padding: 0.5rem;
margin-bottom: 1rem;
}
.input-box {
border: 2px solid #e0e0e0;
border-radius: 8px;
transition: all 0.3s ease;
}
.input-box:focus {
border-color: #1a73e8;
box-shadow: 0 0 0 2px rgba(26, 115, 232, 0.2);
}
.button-primary {
background: #1a73e8;
color: white;
padding: 0.75rem 1.5rem;
border-radius: 8px;
border: none;
cursor: pointer;
transition: all 0.3s ease;
}
.button-primary:hover {
background: #1557b0;
transform: translateY(-1px);
}
.button-secondary {
background: #f8f9fa;
color: #1a73e8;
border: 1px solid #1a73e8;
padding: 0.75rem 1.5rem;
border-radius: 8px;
cursor: pointer;
transition: all 0.3s ease;
}
.button-secondary:hover {
background: #e8f0fe;
}
.output-box {
background: #ffffff;
border-radius: 8px;
padding: 1rem;
margin-top: 1rem;
border: 1px solid #e0e0e0;
}
.chatbot-message {
padding: 1rem;
margin: 0.5rem 0;
border-radius: 8px;
background: #f8f9fa;
}
.advanced-settings {
background: #ffffff;
border-radius: 8px;
padding: 1rem;
margin-top: 1rem;
}
.slider-container {
padding: 0.5rem;
background: #f8f9fa;
border-radius: 6px;
}
.examples-container {
margin-top: 2rem;
padding: 1rem;
background: #ffffff;
border-radius: 8px;
box-shadow: 0 2px 4px rgba(0, 0, 0, 0.05);
}
"""
DESCRIPTION = """
<div class="title">
๐Ÿ–ผ๏ธ BLIP-2 Visual Intelligence System
</div>
<p style='text-align: center; color: #666;'>
Advanced AI system for image understanding and natural conversation
</p>
"""
if not torch.cuda.is_available():
DESCRIPTION += "\n<p style='color: #dc3545;'>Running on CPU 🥶 This demo requires a GPU to function properly.</p>"
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
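# Available BLIP-2 checkpoints; the active one is selected via the MODEL_ID environment variable
# (defaults to the FLAN-T5-XXL variant below).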
MODEL_ID_OPT_2_7B = "Salesforce/blip2-opt-2.7b"
MODEL_ID_OPT_6_7B = "Salesforce/blip2-opt-6.7b"
MODEL_ID_FLAN_T5_XL = "Salesforce/blip2-flan-t5-xl"
MODEL_ID_FLAN_T5_XXL = "Salesforce/blip2-flan-t5-xxl"
MODEL_ID = os.getenv("MODEL_ID", MODEL_ID_FLAN_T5_XXL)
if MODEL_ID not in [MODEL_ID_OPT_2_7B, MODEL_ID_OPT_6_7B, MODEL_ID_FLAN_T5_XL, MODEL_ID_FLAN_T5_XXL]:
error_message = f"Invalid MODEL_ID: {MODEL_ID}"
raise ValueError(error_message)
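# Load the processor and the BLIP-2 model only when a GPU is available, quantizing the
# model to 8-bit with bitsandbytes to reduce memory usage.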
if torch.cuda.is_available():
processor = AutoProcessor.from_pretrained(MODEL_ID)
model = Blip2ForConditionalGeneration.from_pretrained(
MODEL_ID,
device_map="auto",
quantization_config=BitsAndBytesConfig(load_in_8bit=True)
)
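# Unconditional captioning: only the image is encoded, and the model generates a caption
# with the selected decoding strategy (beam search or nucleus sampling).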
@spaces.GPU
def generate_caption(
image: PIL.Image.Image,
decoding_method: str = "Nucleus sampling",
temperature: float = 1.0,
length_penalty: float = 1.0,
repetition_penalty: float = 1.5,
max_length: int = 50,
min_length: int = 1,
num_beams: int = 5,
top_p: float = 0.9,
) -> str:
inputs = processor(images=image, return_tensors="pt").to(device, torch.float16)
generated_ids = model.generate(
pixel_values=inputs.pixel_values,
do_sample=decoding_method == "Nucleus sampling",
temperature=temperature,
length_penalty=length_penalty,
repetition_penalty=repetition_penalty,
max_length=max_length,
min_length=min_length,
num_beams=num_beams,
top_p=top_p,
)
return processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()
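# Prompted generation: image and text prompt are encoded together, so the model answers
# a question (or continues the prompt) about the image.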
@spaces.GPU
def answer_question(
image: PIL.Image.Image,
prompt: str,
decoding_method: str = "Nucleus sampling",
temperature: float = 1.0,
length_penalty: float = 1.0,
repetition_penalty: float = 1.5,
max_length: int = 50,
min_length: int = 1,
num_beams: int = 5,
top_p: float = 0.9,
) -> str:
inputs = processor(images=image, text=prompt, return_tensors="pt").to(device, torch.float16)
generated_ids = model.generate(
**inputs,
do_sample=decoding_method == "Nucleus sampling",
temperature=temperature,
length_penalty=length_penalty,
repetition_penalty=repetition_penalty,
max_length=max_length,
min_length=min_length,
num_beams=num_beams,
top_p=top_p,
)
return processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()
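# Ensure answers end with punctuation before they are appended to the chat history.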
def postprocess_output(output: str) -> str:
if output and output[-1] not in string.punctuation:
output += "."
return output
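# Multi-turn VQA: history_orig keeps the raw user/assistant turns for display, while
# history_qa keeps the same turns wrapped as "Question: ... Answer:" so the whole
# conversation can be re-fed to the model as a single prompt.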
def chat(
image: PIL.Image.Image,
text: str,
decoding_method: str = "Nucleus sampling",
temperature: float = 1.0,
length_penalty: float = 1.0,
repetition_penalty: float = 1.5,
max_length: int = 50,
min_length: int = 1,
num_beams: int = 5,
top_p: float = 0.9,
history_orig: list[str] | None = None,
history_qa: list[str] | None = None,
) -> tuple[list[tuple[str, str]], list[str], list[str]]:
history_orig = history_orig or []
history_qa = history_qa or []
history_orig.append(text)
text_qa = f"Question: {text} Answer:"
history_qa.append(text_qa)
prompt = " ".join(history_qa)
output = answer_question(
image=image,
prompt=prompt,
decoding_method=decoding_method,
temperature=temperature,
length_penalty=length_penalty,
repetition_penalty=repetition_penalty,
max_length=max_length,
min_length=min_length,
num_beams=num_beams,
top_p=top_p,
)
output = postprocess_output(output)
history_orig.append(output)
history_qa.append(output)
chat_val = list(zip(history_orig[0::2], history_orig[1::2], strict=False))
return chat_val, history_orig, history_qa
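# ZeroGPU hint: `chat` is not decorated with @spaces.GPU itself but calls the
# GPU-decorated answer_question.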
chat.zerogpu = True # type: ignore
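# Example (image, question) pairs surfaced in the Gradio Examples component.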
examples = [
[
"images/house.png",
"How could someone get out of the house?",
],
[
"images/flower.jpg",
"What is this flower and where is it's origin?",
],
[
"images/pizza.jpg",
"What are steps to cook it?",
],
[
"images/sunset.jpg",
"Here is a romantic message going along the photo:",
],
[
"images/forbidden_city.webp",
"In what dynasties was this place built?",
],
]
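# Gradio UI: image input on the left; captioning and visual Q&A tabs on the right,
# plus an accordion of decoding parameters and a panel of clickable examples.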
with gr.Blocks(css=CUSTOM_CSS) as demo:
gr.Markdown(DESCRIPTION)
with gr.Group(elem_classes="container"):
with gr.Row():
with gr.Column(scale=1):
image = gr.Image(
type="pil",
label="Upload Image",
elem_classes="input-box"
)
with gr.Column(scale=2):
with gr.Tabs(elem_classes="tab-nav"):
with gr.Tab(label="✨ Image Captioning"):
caption_button = gr.Button(
"Generate Caption",
elem_classes="button-primary"
)
caption_output = gr.Textbox(
label="Generated Caption",
elem_classes="output-box"
)
with gr.Tab(label="💭 Visual Q&A"):
chatbot = gr.Chatbot(
elem_classes="chatbot-message"
)
history_orig = gr.State(value=[])
history_qa = gr.State(value=[])
vqa_input = gr.Textbox(
placeholder="Ask me anything about the image...",
elem_classes="input-box"
)
with gr.Row():
clear_button = gr.Button(
"Clear Chat",
elem_classes="button-secondary"
)
submit_button = gr.Button(
"Send Message",
elem_classes="button-primary"
)
with gr.Accordion("🛠️ Advanced Settings", open=False, elem_classes="advanced-settings"):
with gr.Row():
with gr.Column():
text_decoding_method = gr.Radio(
choices=["Beam search", "Nucleus sampling"],
value="Nucleus sampling",
label="Decoding Method"
)
temperature = gr.Slider(
minimum=0.5,
maximum=1.0,
value=1.0,
label="Temperature",
info="Used with nucleus sampling",
elem_classes="slider-container"
)
length_penalty = gr.Slider(
minimum=-1.0,
maximum=2.0,
value=1.0,
label="Length Penalty",
info="Set to larger for longer sequence",
elem_classes="slider-container"
)
with gr.Column():
repetition_penalty = gr.Slider(
minimum=1.0,
maximum=5.0,
value=1.5,
label="Repetition Penalty",
info="Larger value prevents repetition",
elem_classes="slider-container"
)
max_length = gr.Slider(
minimum=20,
maximum=512,
value=50,
label="Max Length",
elem_classes="slider-container"
)
min_length = gr.Slider(
minimum=1,
maximum=100,
value=1,
label="Min Length",
elem_classes="slider-container"
)
num_beams = gr.Slider(
minimum=1,
maximum=10,
value=5,
label="Number of Beams",
elem_classes="slider-container"
)
top_p = gr.Slider(
minimum=0.5,
maximum=1.0,
value=0.9,
label="Top P",
info="Used with nucleus sampling",
elem_classes="slider-container"
)
with gr.Group(elem_classes="examples-container"):
gr.Examples(
examples=examples,
inputs=[image, vqa_input],
label="Try these examples"
)
# Event handlers
caption_button.click(
fn=generate_caption,
inputs=[
image,
text_decoding_method,
temperature,
length_penalty,
repetition_penalty,
max_length,
min_length,
num_beams,
top_p,
],
outputs=caption_output,
api_name="caption"
)
chat_inputs = [
image,
vqa_input,
text_decoding_method,
temperature,
length_penalty,
repetition_penalty,
max_length,
min_length,
num_beams,
top_p,
history_orig,
history_qa,
]
chat_outputs = [
chatbot,
history_orig,
history_qa,
]
vqa_input.submit(
fn=chat,
inputs=chat_inputs,
outputs=chat_outputs
).success(
fn=lambda: "",
outputs=vqa_input,
queue=False,
api_name=False
)
submit_button.click(
fn=chat,
inputs=chat_inputs,
outputs=chat_outputs,
api_name="chat"
).success(
fn=lambda: "",
outputs=vqa_input,
queue=False,
api_name=False
)
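# "Clear Chat" empties the question box, the chat display, and both history states.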
clear_button.click(
fn=lambda: ("", [], [], []),
inputs=None,
outputs=[
vqa_input,
chatbot,
history_orig,
history_qa,
],
queue=False,
api_name="clear"
)
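# Changing the image invalidates the caption and any ongoing conversation, so reset them.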
image.change(
fn=lambda: ("", [], [], []),
inputs=None,
outputs=[
caption_output,
chatbot,
history_orig,
history_qa,
],
queue=False
)
if __name__ == "__main__":
demo.queue(max_size=10).launch()