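"""Gradio demo for FROMAGe (Grounding Language Models to Images for Multimodal Generation).

Downloads the pretrained checkpoint and model args from the jykoh/fromage repo on the
Hugging Face Hub and serves an interactive chatbot that accepts text and optional
image inputs.
"""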
import os

os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "False"

import numpy as np
import torch
from PIL import Image
import matplotlib.pyplot as plt
from fromage import models
from fromage import utils
import gradio as gr
import huggingface_hub
import tempfile

class FromageChatBot:
    def __init__(self):
        # Download the pretrained FROMAGe checkpoint and model args from the HF Hub.
        ckpt_path = huggingface_hub.hf_hub_download(repo_id='jykoh/fromage', filename='pretrained_ckpt.pth.tar')
        args_path = huggingface_hub.hf_hub_download(repo_id='jykoh/fromage', filename='model_args.json')
        self.model = models.load_fromage('./', args_path, ckpt_path)
        self.chat_history = ''
        self.input_image = None
    def reset(self):
        # Clear the text history and any pending image, and empty the Gradio state and chatbot.
        self.chat_history = ''
        self.input_image = None
        return [], []
    def upload_image(self, state, image_input):
        # Show the uploaded image in the chat window, and keep a resized RGB copy for the model.
        state += [(f"![](/file={image_input.name})", "(Image received. Type or ask something to continue.)")]
        self.input_image = Image.open(image_input.name).resize((224, 224)).convert('RGB')
        return state, state
    def save_image_to_local(self, image: Image.Image):
        # TODO(jykoh): Update so the url path is used, to prevent repeat saving.
        # Note: tempfile._get_candidate_names() is a private API; it is used here only to
        # generate a unique filename in the working directory.
        filename = next(tempfile._get_candidate_names()) + '.png'
        image.save(filename)
        return filename
    def generate_for_prompt(self, input_text, state, ret_scale_factor, max_nm_rets, num_words, temperature):
        input_prompt = 'Q: ' + input_text + '\nA:'
        self.chat_history += input_prompt
        print('Generating for', self.chat_history, flush=True)

        # If an image was uploaded, prepend it to the model inputs.
        if self.input_image is not None:
            model_inputs = [self.input_image, self.chat_history]
        else:
            model_inputs = [self.chat_history]

        # Greedy decoding when temperature is 0; otherwise sample with nucleus sampling.
        top_p = 1.0
        if temperature != 0.0:
            top_p = 0.95

        print('Running model.generate_for_images_and_texts', flush=True)
        model_outputs = self.model.generate_for_images_and_texts(
            model_inputs, num_words=num_words, ret_scale_factor=ret_scale_factor,
            top_p=top_p, temperature=temperature, max_num_rets=max_nm_rets)
        print('model_outputs', model_outputs, flush=True)

        # Assemble the chatbot response: text is appended directly, and any returned
        # images are saved locally and embedded as <img> tags.
        response = ''
        text_outputs = []
        for output in model_outputs:
            if isinstance(output, str):
                text_outputs.append(output)
                response += output
            elif isinstance(output, list):
                for image in output:
                    filename = self.save_image_to_local(image)
                    response += f'<img src="/file={filename}">'
            elif isinstance(output, Image.Image):
                filename = self.save_image_to_local(output)
                response += f'<img src="/file={filename}">'

        # TODO(jykoh): Persist image inputs.
        self.chat_history += ' '.join(text_outputs)
        if self.chat_history and self.chat_history[-1] != '\n':
            self.chat_history += '\n'
        self.input_image = None

        state.append((input_text, response))
        return state, state
    def launch(self):
        # Build the Gradio UI: a chatbot pane, generation controls, and image/text inputs.
        with gr.Blocks() as demo:
            gr.Markdown(
                '### Grounding Language Models to Images for Multimodal Generation'
            )
            chatbot = gr.Chatbot()
            gr_state = gr.State([])

            with gr.Row():
                with gr.Column(scale=0.3, min_width=0):
                    ret_scale_factor = gr.Slider(
                        minimum=0.0, maximum=3.0, value=1.0, step=0.1, interactive=True,
                        label="Multiplier for returning images (higher means more frequent)")
                    max_ret_images = gr.Number(
                        minimum=0, maximum=3, value=1, precision=1, interactive=True,
                        label="Max images to return")
                    gr_max_len = gr.Number(
                        value=32, precision=1, label="Max # of words returned", interactive=True)
                    gr_temperature = gr.Number(value=0.0, label="Temperature", interactive=True)
                with gr.Column(scale=0.7, min_width=0):
                    image_btn = gr.UploadButton("Image Input", file_types=["image"])
                    text_input = gr.Textbox(
                        label="Text Input", lines=1,
                        placeholder="Upload an image above [optional]. Then enter a text prompt, and press enter!")
                    clear_btn = gr.Button("Clear History")

            text_input.submit(
                self.generate_for_prompt,
                [text_input, gr_state, ret_scale_factor, max_ret_images, gr_max_len, gr_temperature],
                [gr_state, chatbot])
            image_btn.upload(self.upload_image, [gr_state, image_btn], [gr_state, chatbot])
            clear_btn.click(self.reset, [], [gr_state, chatbot])

        demo.launch(share=False, debug=True, server_name="0.0.0.0")
def main():
    chatbot = FromageChatBot()
    chatbot.launch()


if __name__ == "__main__":
    main()
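
# A minimal sketch of running this demo locally (assuming the file is saved as app.py,
# the usual Hugging Face Spaces convention, and that the FROMAGe dependencies are installed):
#
#   python app.py
#
# The server binds to 0.0.0.0 and uses Gradio's default port (7860) unless overridden.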