# llava-4bit / app.py
# Author: merve (HF staff)
# Last commit: db8a6e8 "Update app.py"
# File size: 3.21 kB
import os
import string
import gradio as gr
import PIL.Image
import torch
from transformers import BitsAndBytesConfig, pipeline
import re
# Markdown title rendered at the top of the demo page.
# (Fixed mojibake: the emoji had been mis-decoded as "πŸŒ‹".)
DESCRIPTION = "# LLaVA 🌋"
# Hugging Face Hub checkpoint served by this demo (LLaVA 1.5, 7B).
model_id = "llava-hf/llava-1.5-7b-hf"

# bitsandbytes settings: store weights in 4-bit, run compute in float16.
quantization_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.float16)

# Image-to-text pipeline used by `chat`; model_kwargs is forwarded to the
# model's from_pretrained call so the quantization config takes effect.
pipe = pipeline(
    task="image-to-text",
    model=model_id,
    model_kwargs=dict(quantization_config=quantization_config),
)
# Matches one "USER: ... ASSISTANT: ..." exchange. The lookahead leaves the
# next "USER:" unconsumed so findall can begin the following match on it;
# consuming it (as a plain group would) silently drops every other exchange.
_RESPONSE_PAIR_PATTERN = re.compile(r"(USER:.*?)ASSISTANT:(.*?)(?=USER:|$)", re.DOTALL)


def extract_response_pairs(text):
    """Split a LLaVA transcript into (user, assistant) message pairs.

    Args:
        text: Transcript in the "USER: ... ASSISTANT: ..." format built by `chat`.

    Returns:
        A list of ``(user, assistant)`` tuples with surrounding whitespace
        stripped. The user part keeps its leading "USER:" marker. Returns an
        empty list when no exchange is found.
    """
    return [
        (user.strip(), assistant.strip())
        for user, assistant in _RESPONSE_PAIR_PATTERN.findall(text)
    ]


def postprocess_output(output: str) -> str:
    """Append a period to *output* unless it is empty or already ends in punctuation."""
    if output and output[-1] not in string.punctuation:
        output += "."
    return output


def chat(image, text, max_length, history_chat):
    """Run one chat turn through the LLaVA pipeline.

    Args:
        image: Image sent to the pipeline (a PIL image, per the
            ``gr.Image(type="pil")`` input wired to this function).
        text: The user's message for this turn.
        max_length: Maximum number of tokens to generate for the reply.
        history_chat: Session state list. After a turn it holds a single
            string: the latest full conversation transcript.

    Returns:
        ``(pairs, history)``: the (user, assistant) tuples for ``gr.Chatbot``
        and the new state list.
    """
    # generated_text includes the prompt followed by the reply (this code
    # relies on that: the transcript is parsed for the USER:/ASSISTANT:
    # markers the prompt adds). So the latest transcript already contains
    # every earlier turn; joining all stored transcripts would repeat history.
    previous = history_chat[-1] if history_chat else ""
    # Keep the new turn from being glued onto the end of the previous reply.
    separator = " " if previous else ""
    prompt = f"{previous}{separator}USER: <image>\n{text}\nASSISTANT:"
    outputs = pipe(
        image,
        prompt=prompt,
        # max_new_tokens bounds only the reply; max_length would also count the
        # prompt, which grows every turn and soon leaves no room to answer.
        generate_kwargs={"max_new_tokens": int(max_length)},
    )
    transcript = outputs[0]["generated_text"]
    return extract_response_pairs(transcript), [transcript]
# Styles for an element with elem_id "mkd".
# NOTE(review): this variable is not referenced anywhere in this file. gr.Blocks
# below is given css="style.css" rather than `css`, and no component here sets
# elem_id="mkd". Confirm whether it is still needed.
css = """
#mkd {
height: 500px;
overflow: auto;
border: 1px solid #ccc;
}
"""
with gr.Blocks(css="style.css") as demo:
    gr.Markdown(DESCRIPTION)
    gr.Markdown("LLaVA is now available in transformers with 4-bit quantization ⚑️")
    chatbot = gr.Chatbot(label="Chat", show_label=False)
    gr.Markdown("Input image and text to start chatting 👇 ")
    with gr.Row():
        image = gr.Image(type="pil")
        text_input = gr.Text(label="Chat Input", max_lines=1)
    # Per-session conversation state read and returned by `chat`.
    history_chat = gr.State(value=[])
    with gr.Row():
        clear_chat_button = gr.Button("Clear")
        chat_button = gr.Button("Submit", variant="primary")
    with gr.Accordion(label="Advanced settings", open=False):
        max_length = gr.Slider(
            label="Max Length",
            minimum=1,
            maximum=200,
            step=1,
            value=150,
        )

    # The Submit button and pressing Enter both run `chat` with the same wiring.
    chat_inputs = [image, text_input, max_length, history_chat]
    chat_output = [chatbot, history_chat]

    chat_button.click(
        fn=chat,
        inputs=chat_inputs,
        outputs=chat_output,
        api_name="Chat",
    )
    text_input.submit(
        fn=chat,
        inputs=chat_inputs,
        outputs=chat_output,
    ).success(
        # Clear only the text box. The handler returns a single value, so it
        # must have a single output. Targeting all inputs would also reset the
        # image, which fires image.change and wipes the conversation.
        fn=lambda: "",
        outputs=text_input,
        queue=False,
        api_name=False,
    )

    clear_chat_button.click(
        fn=lambda: ([], []),
        inputs=None,
        outputs=[chatbot, history_chat],
        queue=False,
        api_name="clear",
    )
    # A new image starts a fresh conversation.
    image.change(
        fn=lambda: ([], []),
        inputs=None,
        outputs=[chatbot, history_chat],
        queue=False,
    )
if __name__ == "__main__":
    # Allow at most 10 pending events in Gradio's queue, then start the app
    # with debug=True.
    queued_app = demo.queue(max_size=10)
    queued_app.launch(debug=True)