#!/usr/bin/env python
"""Gradio chat demo for LLaVA: one image, a text box and generation settings."""

from __future__ import annotations

import os
import re
import string

import gradio as gr
import PIL.Image
import torch
from transformers import AutoProcessor, Blip2ForConditionalGeneration

DESCRIPTION = "# LLaVA 🌋"

# One conversation turn: the user part (keeping its "USER:" prefix) followed by
# the assistant reply. The terminator is a look-ahead so that the "USER:" that
# opens the next turn is NOT consumed; a consuming terminator makes
# ``findall`` skip every second turn.
_TURN_PATTERN = re.compile(r"(USER:.*?)ASSISTANT:(.*?)(?=$|USER:)", re.DOTALL)


def extract_response_pairs(text):
    """Split a transcript into ``(user, assistant)`` pairs for the chatbot.

    Args:
        text: Transcript made of ``USER: ... ASSISTANT: ...`` turns.

    Returns:
        A list of ``(user, assistant)`` tuples in transcript order, each part
        stripped of surrounding whitespace. The user part keeps its leading
        ``USER:`` marker. Text that contains no complete turn yields ``[]``.
    """
    return [
        (user.strip(), assistant.strip())
        for user, assistant in _TURN_PATTERN.findall(text)
    ]


def postprocess_output(output: str) -> str:
    """Append a full stop unless ``output`` is empty or already ends in punctuation."""
    if output and output[-1] not in string.punctuation:
        output += "."
    return output


def chat(image, text, temperature, length_penalty,
         repetition_penalty, max_length, min_length, num_beams, top_p,
         history_chat):
    """Run one chat turn and return the updated conversation.

    Args:
        image: Image forwarded to ``pipe`` (the UI supplies a PIL image).
        text: The user's message for this turn.
        temperature, length_penalty, repetition_penalty, max_length,
        min_length, num_beams, top_p: Forwarded unchanged to ``pipe`` as
            ``generate_kwargs``.
        history_chat: List of previous generated texts. It is extended in
            place with this turn's output.

    Returns:
        ``(pairs, history_chat)`` where ``pairs`` is the list of
        ``(user, assistant)`` tuples parsed from the joined history.
    """
    # Only the current message is sent to the model; earlier turns are used
    # solely to rebuild the chat display below.
    # NOTE(review): LLaVA-style prompts usually carry an image placeholder
    # after "USER:" - confirm nothing was lost from this template.
    prompt = f"USER: \n{text}\nASSISTANT:"

    generate_kwargs = {
        "temperature": temperature,
        "length_penalty": length_penalty,
        "repetition_penalty": repetition_penalty,
        "max_length": max_length,
        "min_length": min_length,
        "num_beams": num_beams,
        "top_p": top_p,
    }
    # NOTE(review): ``pipe`` is not defined anywhere in this file as shown; it
    # must be created (or imported) before this handler can run.
    outputs = pipe(image, prompt=prompt, generate_kwargs=generate_kwargs)

    output = postprocess_output(outputs[0]["generated_text"])
    history_chat.append(output)

    chat_val = extract_response_pairs(" ".join(history_chat))
    return chat_val, history_chat


# NOTE(review): this stylesheet string is never passed to ``gr.Blocks`` below,
# which points at "style.css" instead.
css = """ #mkd { height: 500px; overflow: auto; border: 1px solid #ccc; } """

with gr.Blocks(css="style.css") as demo:
    gr.Markdown(DESCRIPTION)
    gr.Markdown("## LLaVA, one of the greatest multimodal chat models is now available in transformers with 4-bit quantization!")
    chatbot = gr.Chatbot(label="Chat", show_label=False)

    with gr.Row():
        image = gr.Image(type="pil")
        text_input = gr.Text(label="Chat Input", show_label=False, max_lines=1, container=False)

    history_chat = gr.State(value=[])

    with gr.Row():
        clear_chat_button = gr.Button("Clear")
        chat_button = gr.Button("Submit", variant="primary")

    with gr.Accordion(label="Advanced settings", open=False):
        temperature = gr.Slider(
            label="Temperature",
            info="Used with nucleus sampling.",
            minimum=0.5,
            maximum=1.0,
            step=0.1,
            value=1.0,
        )
        length_penalty = gr.Slider(
            label="Length Penalty",
            info="Set to larger for longer sequence, used with beam search.",
            minimum=-1.0,
            maximum=2.0,
            step=0.2,
            value=1.0,
        )
        repetition_penalty = gr.Slider(
            label="Repetition Penalty",
            info="Larger value prevents repetition.",
            minimum=1.0,
            maximum=5.0,
            step=0.5,
            value=1.5,
        )
        max_length = gr.Slider(
            label="Max Length",
            minimum=1,
            maximum=512,
            step=1,
            value=50,
        )
        min_length = gr.Slider(
            label="Minimum Length",
            minimum=1,
            maximum=100,
            step=1,
            value=1,
        )
        num_beams = gr.Slider(
            label="Number of Beams",
            minimum=1,
            maximum=10,
            step=1,
            value=5,
        )
        top_p = gr.Slider(
            label="Top P",
            info="Used with nucleus sampling.",
            minimum=0.5,
            maximum=1.0,
            step=0.1,
            value=0.9,
        )

    # Shared by the Submit button and the Enter-key handler so both send the
    # same arguments, in the order ``chat`` expects them.
    chat_inputs = [
        image,
        text_input,
        temperature,
        length_penalty,
        repetition_penalty,
        max_length,
        min_length,
        num_beams,
        top_p,
        history_chat,
    ]
    chat_output = [
        chatbot,
        history_chat,
    ]

    chat_button.click(
        fn=chat,
        inputs=chat_inputs,
        outputs=chat_output,
        api_name="Chat",
    )

    text_input.submit(
        fn=chat,
        inputs=chat_inputs,
        outputs=chat_output,
    ).success(
        # Clear only the text box. The callback returns a single value, so it
        # must target a single component, not the whole inputs list.
        fn=lambda: "",
        outputs=text_input,
        queue=False,
        api_name=False,
    )

    clear_chat_button.click(
        fn=lambda: ([], []),
        inputs=None,
        outputs=[
            chatbot,
            history_chat,
        ],
        queue=False,
        api_name="clear",
    )

    # A new image starts a new conversation.
    image.change(
        fn=lambda: ([], []),
        inputs=None,
        outputs=[
            chatbot,
            history_chat,
        ],
        queue=False,
    )


if __name__ == "__main__":
    demo.queue(max_size=10).launch()