Spaces:
Build error
Build error
File size: 4,964 Bytes
d32adcb 8b300d9 05e5f88 8b300d9 5067d2d 8b300d9 d30851e 8b300d9 5067d2d 8b300d9 d30851e e466d0e d30851e d32adcb 7240187 d30851e 5067d2d 8b300d9 5067d2d cefbfeb 8b300d9 cefbfeb 8b300d9 d30851e 5067d2d 8b300d9 5067d2d d30851e 8b300d9 7240187 8b300d9 5067d2d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 |
import os
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "False"
import numpy as np
import torch
from PIL import Image
import matplotlib.pyplot as plt
from fromage import models
from fromage import utils
import gradio as gr
import huggingface_hub
import tempfile
class FromageChatBot:
    """Gradio chat front-end around a pretrained FROMAGe model.

    NOTE(review): ``chat_history`` and ``input_image`` are stored on the
    instance, and the Gradio callbacks are bound methods of that single
    instance, so concurrent browser sessions presumably share one
    conversation. Moving this state into ``gr.State`` would isolate sessions.
    """

    def __init__(self):
        """Download the checkpoint and model args from the HF Hub and load the model."""
        # Download model from HF Hub.
        ckpt_path = huggingface_hub.hf_hub_download(
            repo_id='jykoh/fromage', filename='pretrained_ckpt.pth.tar')
        args_path = huggingface_hub.hf_hub_download(
            repo_id='jykoh/fromage', filename='model_args.json')
        self.model = models.load_fromage('./', args_path, ckpt_path)
        # Running transcript of the form "Q: ...\nA: ...\n" fed to the model.
        self.chat_history = ''
        # Pending uploaded image (RGB, 224x224), consumed by the next prompt.
        self.input_image = None

    def reset(self):
        """Clear the transcript and any pending image.

        Returns:
            Two empty lists, for the ``gr_state`` and ``chatbot`` outputs.
        """
        self.chat_history = ""
        self.input_image = None
        return [], []

    def upload_image(self, state, image_input):
        """Load an uploaded image as the pending model input and log it in the chat.

        Args:
            state: list of (user, bot) message tuples held in ``gr.State``;
                extended in place.
            image_input: the value from ``gr.UploadButton`` -- either a
                tempfile-like object exposing ``.name`` or a plain path string.

        Returns:
            ``(state, state)`` for the ``gr_state`` and ``chatbot`` outputs.
        """
        # Older Gradio passes a tempfile wrapper, newer versions a path string.
        path = getattr(image_input, 'name', image_input)
        # Convert before resizing so palette / 1-bit uploads are resampled in
        # RGB (Pillow forces nearest-neighbour for those modes), and close the
        # source file handle once the pixels are loaded.
        with Image.open(path) as img:
            self.input_image = img.convert('RGB').resize((224, 224))
        # Only announce the image after it has decoded successfully.
        state += [(f"![](/file={path})", "(Image received. Type or ask something to continue.)")]
        return state, state

    def save_image_to_local(self, image: "Image.Image"):
        """Write ``image`` as a uniquely named PNG in the working directory.

        Returns:
            The bare file name, which the chat HTML references via ``/file=``.
        """
        # TODO(jykoh): Update so the url path is used, to prevent repeat saving.
        # mkstemp atomically reserves a fresh name; the private
        # tempfile._get_candidate_names helper never checked for collisions.
        fd, path = tempfile.mkstemp(suffix='.png', dir='.')
        os.close(fd)
        image.save(path)
        return os.path.basename(path)

    def generate_for_prompt(self, input_text, state, ret_scale_factor, max_nm_rets, num_words, temperature):
        """Run the model on the accumulated transcript and append its reply.

        Args:
            input_text: the user's new message.
            state: list of (user, bot) message tuples held in ``gr.State``;
                appended to in place.
            ret_scale_factor: multiplier for how often images are returned.
            max_nm_rets: maximum number of images to return (cast to int).
            num_words: maximum number of words to generate (cast to int).
            temperature: sampling temperature; non-zero also switches
                ``top_p`` from 1.0 to 0.95.

        Returns:
            ``(state, state)`` for the ``gr_state`` and ``chatbot`` outputs.
        """
        self.chat_history += 'Q: ' + input_text + '\nA:'
        print('Generating for', self.chat_history, flush=True)

        # If an image was uploaded, prepend it to the model.
        if self.input_image is not None:
            model_inputs = [self.input_image, self.chat_history]
        else:
            model_inputs = [self.chat_history]

        top_p = 0.95 if temperature != 0.0 else 1.0

        print('Running model.generate_for_images_and_texts', flush=True)
        # gr.Number yields floats unless precision=0; these are counts, so
        # cast them (presumably the model uses them with range()/slicing).
        model_outputs = self.model.generate_for_images_and_texts(
            model_inputs,
            num_words=int(num_words),
            ret_scale_factor=ret_scale_factor,
            top_p=top_p,
            temperature=temperature,
            max_num_rets=int(max_nm_rets))
        print('model_outputs', model_outputs, flush=True)

        response = ''
        text_outputs = []
        for output in model_outputs:
            if isinstance(output, str):
                text_outputs.append(output)
                response += output
            elif isinstance(output, list):
                for image in output:
                    filename = self.save_image_to_local(image)
                    response += f'<img src="/file={filename}">'
            elif isinstance(output, Image.Image):
                # isinstance also matches subclasses such as PngImageFile.
                filename = self.save_image_to_local(output)
                response += f'<img src="/file={filename}">'

        # TODO(jykoh): Persist image inputs.
        self.chat_history += ' '.join(text_outputs)
        if not self.chat_history.endswith('\n'):
            self.chat_history += '\n'
        # The uploaded image is used for a single turn only.
        self.input_image = None

        state.append((input_text, response))
        return state, state

    def launch(self):
        """Build the Gradio Blocks UI, wire the callbacks, and start the server."""
        with gr.Blocks() as demo:
            gr.Markdown(
                '### Grounding Language Models to Images for Multimodal Generation'
            )

            chatbot = gr.Chatbot()
            gr_state = gr.State([])

            with gr.Row():
                # Integer scales: 3:7 keeps the original 0.3:0.7 width split,
                # and Gradio expects `scale` to be an int.
                with gr.Column(scale=3, min_width=0):
                    ret_scale_factor = gr.Slider(
                        minimum=0.0, maximum=3.0, value=1.0, step=0.1, interactive=True,
                        label="Multiplier for returning images (higher means more frequent)")
                    # precision=0 makes gr.Number return ints; precision=1
                    # returned floats such as 32.0 for these counts.
                    max_ret_images = gr.Number(
                        minimum=0, maximum=3, value=1, precision=0, interactive=True,
                        label="Max images to return")
                    gr_max_len = gr.Number(
                        value=32, precision=0, label="Max # of words returned", interactive=True)
                    gr_temperature = gr.Number(value=0.0, label="Temperature", interactive=True)

                with gr.Column(scale=7, min_width=0):
                    image_btn = gr.UploadButton("Image Input", file_types=["image"])
                    text_input = gr.Textbox(
                        label="Text Input", lines=1,
                        placeholder="Upload an image above [optional]. Then enter a text prompt, and press enter!")
                    clear_btn = gr.Button("Clear History")

            text_input.submit(
                self.generate_for_prompt,
                [text_input, gr_state, ret_scale_factor, max_ret_images, gr_max_len, gr_temperature],
                [gr_state, chatbot])
            image_btn.upload(self.upload_image, [gr_state, image_btn], [gr_state, chatbot])
            clear_btn.click(self.reset, [], [gr_state, chatbot])

        demo.launch(share=False, debug=True, server_name="0.0.0.0")
def main():
    """Script entry point: construct the chatbot (which loads the model) and serve the UI."""
    FromageChatBot().launch()
# Delegate to main() rather than duplicating its body.
if __name__ == "__main__":
    main()