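"""
Multimodal chat extension: lets the user attach an image to a chat message.
The image is inlined into the prompt as a base64 <img> tag, which
tokenizer_modifier later replaces with vision embeddings via MultimodalEmbedder.
"""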
import base64
import re
import time
from functools import partial
from io import BytesIO
from typing import Optional

import gradio as gr
import torch

from extensions.multimodal.multimodal_embedder import MultimodalEmbedder
from modules import shared
from modules.logging_colors import logger

params = {
    "add_all_images_to_prompt": False,
    # device to run vision encoder on
    "vision_device": None,
    # bits to load vision encoder in, either 16 or 32
    "vision_bits": 32,
    # device to run multimodal projector on
    "projector_device": None,
    # multimodal projector bits, either 32 or 16
    "projector_bits": 32
}
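
# Example override (a sketch; assuming the device strings are valid
# torch.device identifiers): run the vision encoder in fp16 on a second GPU:
#     params.update({"vision_device": "cuda:1", "vision_bits": 16})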


# If 'state' is True, the next chat generation is hijacked; 'value' is set on
# image upload to a callable that rewrites (text, visible_text)
input_hijack = {
    'state': False,
    'value': ["", ""]
}


# Initialized in ui(), so that params have already been loaded from settings
multimodal_embedder: Optional[MultimodalEmbedder] = None


def chat_input_modifier(text, visible_text, state):
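    """
    One-shot input hook: if an uploaded image has armed the hijack, disarm it
    and let the stored callable (a partial of add_chat_picture) rewrite the
    user's (text, visible_text) pair.
    """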
    global input_hijack
    if input_hijack['state']:
        input_hijack['state'] = False
        return input_hijack['value'](text, visible_text)
    else:
        return text, visible_text


def add_chat_picture(picture, text, visible_text):
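    """Inline `picture` into the prompt (and its visible copy) as a base64 JPEG <img> tag."""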
    # Resize so the shortest edge is at least 224 px (CLIP input size) and the longest
    # edge is at most 300 px (to keep chat history manageable); if the aspect ratio
    # makes both impossible, the 224 px minimum wins
    max_hw, min_hw = max(picture.size), min(picture.size)
    aspect_ratio = max_hw / min_hw
    shortest_edge = int(max(300 / aspect_ratio, 224))
    longest_edge = int(shortest_edge * aspect_ratio)
    w = shortest_edge if picture.width < picture.height else longest_edge
    h = shortest_edge if picture.width >= picture.height else longest_edge
    picture = picture.resize((w, h))

    buffer = BytesIO()
    # Convert to RGB first: JPEG cannot encode an alpha channel (e.g. RGBA PNG uploads)
    picture.convert('RGB').save(buffer, format="JPEG")
    img_str = base64.b64encode(buffer.getvalue()).decode('utf-8')
    image = f'<img src="data:image/jpeg;base64,{img_str}">'

    if '<image>' in text:
        text = text.replace('<image>', image)
    else:
        text = text + '\n' + image

    if not visible_text:
        visible_text = text
    elif '<image>' in visible_text:
        visible_text = visible_text.replace('<image>', image)
    else:
        visible_text = visible_text + '\n' + image

    return text, visible_text


def custom_tokenized_length(prompt):
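    """Report prompt length in tokens, letting the embedder account for embedded images."""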
    return multimodal_embedder.len_in_tokens(prompt)


def tokenizer_modifier(state, prompt, input_ids, input_embeds):
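    """
    If the prompt contains a base64 <img> tag, replace it with image embeddings
    computed by the multimodal embedder; otherwise pass everything through unchanged.
    """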
    start_ts = time.time()
    image_match = re.search(r'<img src="data:image/jpeg;base64,[A-Za-z0-9+/=]+">', prompt)

    if image_match is None:
        return prompt, input_ids, input_embeds

    prompt, input_ids, input_embeds, total_embedded = multimodal_embedder.forward(prompt, state, params)
    logger.info(f'Embedded {total_embedded} image(s) in {time.time()-start_ts:.2f}s')
    return (prompt,
            input_ids.unsqueeze(0).to(shared.model.device, dtype=torch.int64),
            input_embeds.unsqueeze(0).to(shared.model.device, dtype=shared.model.dtype))


def ui():
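    """Build the extension UI; the embedder is created here so params reflect loaded settings."""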
    global multimodal_embedder
    multimodal_embedder = MultimodalEmbedder(params)
    with gr.Column():
        picture_select = gr.Image(label='Send a picture', type='pil')
        # Most models don't seem to handle multiple images well, so only the last one is embedded by default
        embed_all_checkbox = gr.Checkbox(False, label='Embed all images, not only the last one')
    # Prepare the input hijack
    picture_select.upload(
        lambda picture: input_hijack.update({"state": True, "value": partial(add_chat_picture, picture)}),
        [picture_select],
        None
    )
    picture_select.clear(lambda: input_hijack.update({"state": False, "value": ["", ""]}), None, None)
    embed_all_checkbox.change(lambda x: params.update({"add_all_images_to_prompt": x}), embed_all_checkbox, None)
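    # Clear the uploaded picture once a generation has been triggered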
    shared.gradio['Generate'].click(lambda: None, None, picture_select)
    shared.gradio['textbox'].submit(lambda: None, None, picture_select)