File size: 4,554 Bytes
d32adcb
 
8b300d9
 
 
 
 
 
 
 
 
 
 
 
f3e34b0
 
 
 
 
 
 
a03fe94
 
 
f3e34b0
a03fe94
f3e34b0
 
cce1831
a03fe94
cce1831
 
f3e34b0
 
 
 
 
 
 
 
 
a03fe94
 
 
f3e34b0
 
 
 
 
a03fe94
f3e34b0
a03fe94
 
 
f3e34b0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19d0ef9
a03fe94
f3e34b0
 
 
a03fe94
 
f3e34b0
 
 
 
 
 
 
 
a03fe94
f3e34b0
 
 
 
 
 
 
 
 
 
 
 
 
 
a03fe94
f3e34b0
 
 
0fdb545
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
import os
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "False"
import numpy as np
import torch
from PIL import Image
import matplotlib.pyplot as plt

from fromage import models
from fromage import utils
import gradio as gr
import huggingface_hub
import tempfile


# Fetch the checkpoint and the model-argument file from the 'jykoh/fromage'
# Hugging Face Hub repo (checkpoint first, then args), then hand both local
# paths to models.load_fromage.
ckpt_path, args_path = [
    huggingface_hub.hf_hub_download(repo_id='jykoh/fromage', filename=hub_file)
    for hub_file in ('pretrained_ckpt.pth.tar', 'model_args.json')
]
model = models.load_fromage('./', args_path, ckpt_path)


def upload_image(state, image_input):
    """Record an uploaded image in the chat and stash it for the next prompt.

    Args:
        state: Three-item list ``[conversation, chat_history, input_image]``.
        image_input: Uploaded file object; only its ``.name`` path is read.

    Returns:
        A ``(new_state, conversation)`` pair. ``new_state`` keeps the same
        conversation and chat history and stores the uploaded image, resized
        to 224x224 and converted to RGB, as the pending input image.
    """
    conversation, chat_history = state[0], state[1]
    image_path = image_input.name

    # Show the image in the chat as a markdown image with an empty reply.
    conversation.append((f"![](/file={image_path})", ""))

    resized = Image.open(image_path).resize((224, 224))
    input_image = resized.convert('RGB')
    return [conversation, chat_history, input_image], conversation


def reset():
    """Return a blank ``(state, conversation)`` pair.

    The state is ``[conversation, chat_history, input_image]`` with two empty
    lists and no image; the second element is a separate empty conversation
    list. Fresh objects are created on every call.
    """
    blank_state = [[], [], None]
    blank_conversation = []
    return blank_state, blank_conversation


def save_image_to_local(image: "Image.Image") -> str:
    """Save ``image`` as a uniquely named PNG in the working directory.

    Args:
        image: Object exposing ``save(path)`` (a PIL image in this app).

    Returns:
        The relative filename (``<32 hex chars>.png``) the image was saved to.
    """
    # Imported locally so this function does not depend on edits elsewhere.
    import uuid

    # TODO(jykoh): Update so the url path is used, to prevent repeat saving.
    # uuid4 is a public API; the previous tempfile._get_candidate_names() is a
    # private, undocumented helper.
    filename = uuid.uuid4().hex + '.png'
    image.save(filename)
    return filename


def generate_for_prompt(input_text, state, ret_scale_factor, max_nm_rets, num_words, temperature):
    """Run the model on the user's prompt and append its reply to the chat.

    Args:
        input_text: Prompt text typed by the user.
        state: Three-item list ``[conversation, chat_history, input_image]``.
            ``conversation`` is the list of (user, bot) pairs shown in the
            chatbot, ``chat_history`` is the list of inputs fed to the model
            so far, and ``input_image`` is a pending uploaded image or None.
        ret_scale_factor: Forwarded to the model as ``ret_scale_factor``.
        max_nm_rets: Forwarded to the model as ``max_num_rets``.
        num_words: Forwarded to the model as ``num_words``.
        temperature: Forwarded to the model; ``0.0`` selects ``top_p=1.0``,
            any other value selects ``top_p=0.95``.

    Returns:
        A ``(new_state, conversation)`` pair, where ``new_state`` is
        ``[conversation, updated_chat_history, None]`` (the pending image is
        cleared after use).
    """
    input_prompt = 'Q: ' + input_text + '\nA:'
    conversation = state[0]
    chat_history = state[1]
    input_image = state[2]
    print('Generating for', chat_history, flush=True)

    # Work on a copy so the history stored in `state` is not modified if the
    # model call below raises.
    model_inputs = list(chat_history)
    # If an image was uploaded, it goes right before the prompt.
    if input_image is not None:
        model_inputs.append(input_image)
    model_inputs.append(input_prompt)

    top_p = 0.95 if temperature != 0.0 else 1.0

    print('Running model.generate_for_images_and_texts with', model_inputs, flush=True)
    model_outputs = model.generate_for_images_and_texts(
        model_inputs,
        num_words=num_words, ret_scale_factor=ret_scale_factor, top_p=top_p,
        temperature=temperature, max_num_rets=max_nm_rets)
    print('model_outputs', model_outputs, flush=True)

    # Build the chatbot reply: text is passed through, images are saved to
    # disk and referenced with <img> tags.
    text_outputs = []
    response_parts = []
    for output in model_outputs:
        if isinstance(output, str):
            text_outputs.append(output)
            response_parts.append(output)
        elif isinstance(output, list):
            for image in output:
                filename = save_image_to_local(image)
                response_parts.append(f'<img src="/file={filename}">')
        elif isinstance(output, Image.Image):
            filename = save_image_to_local(output)
            response_parts.append(f'<img src="/file={filename}">')
    response = ''.join(response_parts)

    # TODO(jykoh): Persist image inputs.
    # The generated text is appended as one new history entry; it must be
    # wrapped in a list because a list cannot be concatenated with a str.
    chat_history = model_inputs + [' '.join(text_outputs) + '\n']
    conversation.append((input_text, response))

    # Set input image to None.
    print('state', state, flush=True)
    print('updated state', [conversation, chat_history, None], flush=True)
    return [conversation, chat_history, None], conversation


# Gradio UI: a chatbot display, generation controls on the left, and the
# image/text inputs on the right.
with gr.Blocks() as demo:
    gr.Markdown('### Grounding Language Models to Images for Multimodal Generation')

    chatbot = gr.Chatbot()
    # State layout: [conversation, chat_history, input_image].
    gr_state = gr.State([[], [], None])

    with gr.Row():
        # Generation settings.
        with gr.Column(scale=0.3, min_width=0):
            ret_scale_factor = gr.Slider(
                minimum=0.0,
                maximum=3.0,
                value=1.0,
                step=0.1,
                interactive=True,
                label="Multiplier for returning images (higher means more frequent)",
            )
            max_ret_images = gr.Number(
                minimum=0,
                maximum=3,
                value=1,
                precision=1,
                interactive=True,
                label="Max images to return",
            )
            gr_max_len = gr.Number(
                value=32,
                precision=1,
                label="Max # of words returned",
                interactive=True,
            )
            gr_temperature = gr.Number(
                value=0.0,
                label="Temperature",
                interactive=True,
            )

        # User inputs.
        with gr.Column(scale=0.7, min_width=0):
            image_btn = gr.UploadButton("Image Input", file_types=["image"])
            text_input = gr.Textbox(
                label="Text Input",
                lines=1,
                placeholder="Upload an image above [optional]. Then enter a text prompt, and press enter!",
            )
            clear_btn = gr.Button("Clear History")

    # Submitting text runs the model and, as a second handler, empties the box.
    text_input.submit(
        fn=generate_for_prompt,
        inputs=[text_input, gr_state, ret_scale_factor, max_ret_images, gr_max_len, gr_temperature],
        outputs=[gr_state, chatbot],
    )
    text_input.submit(fn=lambda: "", inputs=None, outputs=text_input)  # Reset chatbox.
    image_btn.upload(
        fn=upload_image,
        inputs=[gr_state, image_btn],
        outputs=[gr_state, chatbot],
    )
    clear_btn.click(fn=reset, inputs=[], outputs=[gr_state, chatbot])

demo.launch(share=False, debug=True, server_name="0.0.0.0")