File size: 3,209 Bytes
8400add
 
 
 
 
 
ee95e21
8400add
 
 
 
e5327ee
 
 
 
 
 
 
 
3d139ce
8400add
 
 
3d139ce
8400add
 
 
 
 
 
 
 
 
 
 
 
 
3d139ce
8400add
3d139ce
8400add
 
3d139ce
 
8400add
3d139ce
 
8400add
 
3d139ce
8400add
 
 
 
 
 
 
 
 
 
 
 
db8a6e8
8400add
db8a6e8
8400add
 
db8a6e8
8400add
 
 
 
 
 
 
 
 
3d139ce
8400add
db8a6e8
8400add
3d139ce
8400add
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3d139ce
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
import os
import string

import gradio as gr
import PIL.Image
import torch
from transformers import BitsAndBytesConfig, pipeline
import re

# Markdown heading rendered at the top of the demo page.
DESCRIPTION = "# LLaVA 🌋"

# LLaVA 1.5 7B checkpoint, loaded with 4-bit bitsandbytes quantization;
# matmuls still run in float16 via bnb_4bit_compute_dtype.
model_id = "llava-hf/llava-1.5-7b-hf"
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.float16
)
# NOTE(review): the model is downloaded/loaded at import time — heavy side
# effect; any import of this module pulls the full checkpoint.
pipe = pipeline("image-to-text", model=model_id, model_kwargs={"quantization_config": quantization_config})



def extract_response_pairs(text):
    """Split a chat transcript into (user, assistant) turn pairs.

    Parameters
    ----------
    text : str
        Transcript containing alternating ``USER:`` / ``ASSISTANT:`` markers.

    Returns
    -------
    list[tuple[str, str]]
        One whitespace-stripped ``(user_turn, assistant_turn)`` tuple per
        exchange.  The user element keeps its leading ``USER:`` prefix,
        matching what the Chatbot component displayed before.
    """
    # Use a lookahead (?=USER:|$) instead of the old consuming (?:$|USER:)
    # group: the consuming form swallowed the next "USER:" marker, so
    # findall() silently skipped every other exchange.  The stray debug
    # print() of the matches is removed as well.
    pattern = re.compile(r"(USER:.*?)ASSISTANT:(.*?)(?=USER:|$)", re.DOTALL)

    pairs = [(user.strip(), assistant.strip())
             for user, assistant in pattern.findall(text)]
    return pairs


def postprocess_output(output: str) -> str:
    """Ensure *output* ends with a punctuation mark.

    Appends a trailing period when the (non-empty) string does not already
    end in ASCII punctuation; empty strings pass through unchanged.
    """
    needs_period = bool(output) and not output.endswith(tuple(string.punctuation))
    return output + "." if needs_period else output



def chat(image, text, max_length, history_chat):
    """Run one LLaVA turn and return the updated chat display and history.

    Parameters
    ----------
    image : PIL.Image.Image
        Image the user is chatting about (from the gr.Image component).
    text : str
        The user's message for this turn.
    max_length : int
        Upper bound on generated sequence length, forwarded to the pipeline.
    history_chat : list[str]
        Raw generated transcripts from previous turns (held in gr.State);
        mutated in place by appending this turn's output.

    Returns
    -------
    tuple[list[tuple[str, str]], list[str]]
        (user, assistant) pairs for the Chatbot component, and the updated
        raw history for gr.State.
    """
    # LLaVA prompt format: all prior turns, then the new user turn with the
    # <image> placeholder the pipeline substitutes with image features.
    prompt = " ".join(history_chat) + f"USER: <image>\n{text}\nASSISTANT:"

    outputs = pipe(
        image,
        prompt=prompt,
        generate_kwargs={"max_length": max_length},
    )

    # The pipeline echoes the full prompt plus the new answer; keep the raw
    # text so future prompts contain the whole conversation.
    history_chat.append(outputs[0]["generated_text"])

    chat_val = extract_response_pairs(" ".join(history_chat))
    return chat_val, history_chat


# Inline CSS for the #mkd container.
# NOTE(review): this variable is never used — gr.Blocks below loads the
# "style.css" file from disk instead.  Kept to avoid breaking anything that
# imports it; confirm which stylesheet is actually intended.
css = """
  #mkd {
    height: 500px; 
    overflow: auto; 
    border: 1px solid #ccc; 
  }
  """
with gr.Blocks(css="style.css") as demo:
    gr.Markdown(DESCRIPTION)
    gr.Markdown("LLaVA is now available in transformers with 4-bit quantization ⚡️")
    chatbot = gr.Chatbot(label="Chat", show_label=False)
    gr.Markdown("Input image and text to start chatting 👇 ")
    with gr.Row():
        image = gr.Image(type="pil")
        text_input = gr.Text(label="Chat Input", max_lines=1)

    # Raw generated transcripts, one string per turn; chat() rebuilds the
    # prompt from this on every call.
    history_chat = gr.State(value=[])
    with gr.Row():
        clear_chat_button = gr.Button("Clear")
        chat_button = gr.Button("Submit", variant="primary")
    with gr.Accordion(label="Advanced settings", open=False):
        max_length = gr.Slider(
            label="Max Length",
            minimum=1,
            maximum=200,
            step=1,
            value=150,
        )

    # Shared input/output wiring for both the button and Enter-to-submit.
    chat_inputs = [
        image,
        text_input,
        max_length,
        history_chat,
    ]
    chat_output = [
        chatbot,
        history_chat,
    ]
    chat_button.click(
        fn=chat,
        inputs=chat_inputs,
        outputs=chat_output,
        api_name="Chat",
    )

    text_input.submit(
        fn=chat,
        inputs=chat_inputs,
        outputs=chat_output,
    ).success(
        # Clear only the text box after a successful turn.  The original
        # passed outputs=chat_inputs here, but the lambda returns a single
        # value — Gradio raises when one value is mapped to four components.
        fn=lambda: "",
        outputs=text_input,
        queue=False,
        api_name=False,
    )
    clear_chat_button.click(
        fn=lambda: ([], []),
        inputs=None,
        outputs=[
            chatbot,
            history_chat,
        ],
        queue=False,
        api_name="clear",
    )
    # Changing the image resets the conversation: earlier turns referenced
    # the previous <image> and would mislead the model.
    image.change(
        fn=lambda: ([], []),
        inputs=None,
        outputs=[
            chatbot,
            history_chat,
        ],
        queue=False,
    )


if __name__ == "__main__":
    demo.queue(max_size=10).launch(debug=True)