Spaces:
Runtime error
Runtime error
import os | |
import re | |
import pickle | |
import base64 | |
import requests | |
import argparse | |
import numpy as np | |
import gradio as gr | |
from functools import partial | |
from PIL import Image | |
# Endpoint of the remote backend that every request in this module is POSTed
# to. Read from the environment at import time; None when the variable is unset.
SERVER_URL = os.environ.get('SERVER_URL')
def get_images(state):
    """Download backend images referenced in the chat history to local disk.

    Every string message in ``state`` is scanned for relative paths of the
    form ``image/<name>.png``. Each such path that does not exist locally is
    fetched from the backend through its ``get_image`` method and saved to
    that same path.

    Args:
        state: Chat history as a sequence of message sequences (turns).
            Entries that are not strings (e.g. ``None``) are skipped.

    Returns:
        None. Images are written to disk as a side effect.

    Raises:
        requests.HTTPError: If the backend answers with an error status.
        requests.ConnectTimeout: If the backend cannot be reached in 10 s.
    """
    # Gather every text message once; joining avoids repeated string
    # concatenation, and non-str entries would otherwise raise TypeError.
    history = ''.join(
        message + '\n'
        for turn in state
        for message in turn
        if isinstance(message, str)
    )
    # NOTE: the character class also admits ',' -- kept for compatibility.
    for image_path in re.findall(r'image/[0-9,a-z]+\.png', history):
        if os.path.exists(image_path):
            # Already present (possibly fetched earlier in this same loop).
            continue
        payload = {'method': 'get_image', 'args': [image_path], 'kwargs': {}}
        payload = base64.b64encode(pickle.dumps(payload)).decode('utf-8')
        # Bound the connect phase so an unreachable backend cannot hang the
        # worker forever; the read phase stays unbounded.
        response = requests.post(SERVER_URL, json=payload, timeout=(10, None))
        response.raise_for_status()
        # SECURITY: pickle.loads can execute arbitrary code embedded in the
        # payload -- only acceptable if SERVER_URL is a fully trusted endpoint.
        image = pickle.loads(base64.b64decode(response.json().encode('utf-8')))
        os.makedirs(os.path.dirname(image_path), exist_ok=True)
        # Presumably a PIL image (it is .save()d to a path) -- verify server side.
        image.save(image_path)
def bot_request(method, *args, **kwargs):
    """Call ``method`` on the remote backend and return its result.

    The call is encoded as a pickled ``{'method', 'args', 'kwargs'}`` dict,
    base64-encoded, and POSTed as JSON to ``SERVER_URL``. The reply body is
    decoded the same way. When the result is not ``None``, its first element
    is treated as the chat state and any images it references are fetched
    locally via ``get_images``.

    Args:
        method: Name of the backend method to invoke (e.g. ``'run_text'``).
        *args: Positional arguments forwarded verbatim to the backend method.
        **kwargs: Keyword arguments forwarded verbatim to the backend method.

    Returns:
        The unpickled backend result: ``None`` or a sequence whose first item
        is the chat state.

    Raises:
        requests.HTTPError: If the backend answers with an error status.
        requests.ConnectTimeout: If the backend cannot be reached in 10 s.
    """
    payload = {'method': method, 'args': args, 'kwargs': kwargs}
    payload = base64.b64encode(pickle.dumps(payload)).decode('utf-8')
    # Bound only the connect phase: backend calls (text generation, image
    # tools) may run for a long time, so the read timeout stays unbounded.
    reply = requests.post(SERVER_URL, json=payload, timeout=(10, None))
    # Surface HTTP errors directly instead of failing later in .json().
    reply.raise_for_status()
    # SECURITY: pickle.loads can execute arbitrary code embedded in the
    # payload -- only acceptable if SERVER_URL is a fully trusted endpoint.
    result = pickle.loads(base64.b64decode(reply.json().encode('utf-8')))
    if result is not None:
        # The updated chat state may mention newly generated images.
        get_images(result[0])
    return result
def run_image(image, *args, **kwargs):
    """Normalize an uploaded image and forward it to the backend's ``run_image``.

    A non-``None`` image is scaled (up or down, aspect ratio kept) so its
    longer side is 512 px. Each side is then snapped to a multiple of 64, but
    never below 64, and the result is converted to RGB.

    Args:
        image: Uploaded PIL image, or ``None`` when no image was provided.
        *args: Forwarded to ``bot_request('run_image', image, ...)``.
        **kwargs: Forwarded to ``bot_request('run_image', image, ...)``.

    Returns:
        Whatever ``bot_request`` returns for the ``run_image`` call.
    """
    if image is not None:
        width, height = image.size
        ratio = min(512 / width, 512 / height)
        width_new, height_new = round(width * ratio), round(height * ratio)
        # Snap to multiples of 64. Clamp to >= 64: a thin image whose short
        # side scales to 32 px or less would otherwise round to 0 and
        # request an invalid zero-sized resize.
        width_new = max(64, int(np.round(width_new / 64.0)) * 64)
        height_new = max(64, int(np.round(height_new / 64.0)) * 64)
        image = image.resize((width_new, height_new))
        image = image.convert('RGB')
    return bot_request('run_image', image, *args, **kwargs)
def predict_example(temperature, top_p, max_new_token, keep_last_n_paragraphs, image, text):
    """Run one example from a fresh, empty conversation.

    The image is sent first via ``run_image``. The text prompt is then sent
    to the backend's ``run_text`` together with the generation settings.

    Returns:
        ``(chatbot, state, text, None, buffer)``. The ``None`` occupies the
        image output slot.
    """
    chat, history, prompt, scratch = run_image(image, [], text, '')
    generation_settings = (temperature, top_p, max_new_token, keep_last_n_paragraphs)
    chat, history, prompt, scratch = bot_request(
        'run_text', prompt, history, *generation_settings, scratch)
    return chat, history, prompt, None, scratch
if __name__ == '__main__':
    # Command-line defaults for the generation controls built below.
    parser = argparse.ArgumentParser()
    parser.add_argument('--temperature', type=float, default=0.0, help='temperature for the llm model')
    parser.add_argument('--max_new_tokens', type=int, default=256, help='max number of new tokens to generate')
    parser.add_argument('--top_p', type=float, default=1.0, help='top_p for the llm model')
    # NOTE(review): --top_k is parsed but never used below; kept so existing
    # command lines keep working.
    parser.add_argument('--top_k', type=int, default=40, help='top_k for the llm model')
    parser.add_argument('--keep_last_n_paragraphs', type=int, default=0, help='keep last n paragraphs in the memory')
    args = parser.parse_args()

    # (image path, prompt) pairs offered as clickable examples.
    examples = [
        ['images/example-1.jpg', 'What is unusual about this image?'],
        ['images/example-2.jpg', 'Make the image look like a cartoon.'],
        ['images/example-3.jpg', 'Segment the tie in the image.'],
        ['images/example-4.jpg', 'Generate a man watching a sea based on the pose of the woman.'],
        ['images/example-5.jpg', 'Replace the dog with a monkey.'],
    ]

    # Local directory where get_images() stores images fetched from the
    # backend. exist_ok avoids the check-then-create race.
    os.makedirs('image', exist_ok=True)

    # NOTE(review): `.style(...)` and `queue(concurrency_count=...)` are
    # Gradio 3.x APIs that were removed in Gradio 4 -- pin gradio<4 or migrate.
    with gr.Blocks() as demo:
        # Hidden components carrying conversation state between events.
        state = gr.Chatbot([], visible=False)
        buffer = gr.Textbox('', visible=False)
        with gr.Row():
            # Integer scales give the same 3:7 split as 0.3/0.7 (the ratios
            # sum to 1 either way); Gradio documents `scale` as an int.
            with gr.Column(scale=3):
                with gr.Row():
                    image = gr.Image(type='pil', label='input image')
                with gr.Row():
                    txt = gr.Textbox(lines=7, show_label=False, elem_id='textbox',
                                     placeholder='Enter text and press submit, or upload an image').style(container=False)
                with gr.Row():
                    submit = gr.Button('Submit')
                with gr.Row():
                    clear = gr.Button('Clear')
                with gr.Row():
                    llm_name = gr.Radio(
                        ["Vicuna-13B"],
                        label="LLM Backend",
                        value="Vicuna-13B",
                        interactive=True)
                    keep_last_n_paragraphs = gr.Slider(
                        minimum=0,
                        maximum=3,
                        value=args.keep_last_n_paragraphs,
                        step=1,
                        interactive=True,
                        label='Remember Last N Paragraphs')
                    max_new_token = gr.Slider(
                        minimum=64,
                        maximum=512,
                        value=args.max_new_tokens,
                        step=1,
                        interactive=True,
                        label='Max New Tokens')
                    # Hidden: values come from the command-line defaults.
                    temperature = gr.Slider(
                        minimum=0.0,
                        maximum=1.0,
                        value=args.temperature,
                        step=0.1,
                        interactive=True,
                        visible=False,
                        label='Temperature')
                    top_p = gr.Slider(
                        minimum=0.0,
                        maximum=1.0,
                        value=args.top_p,
                        step=0.1,
                        interactive=True,
                        visible=False,
                        label='Top P')
            with gr.Column(scale=7):
                # NOTE(review): the label's leading characters look mis-encoded
                # in this copy (likely an emoji) -- verify the intended text.
                chatbot = gr.Chatbot(elem_id='chatbot', label='π¦ GPT4Tools').style(height=690)

        # Uploading a new image clears the prompt box.
        image.upload(lambda: '', None, txt)
        # Submit: send the image, then the text prompt with the current
        # generation settings, then clear the image input.
        submit.click(run_image,
                     [image, state, txt, buffer],
                     [chatbot, state, txt, buffer]).then(
            partial(bot_request, 'run_text'),
            [txt, state, temperature, top_p, max_new_token, keep_last_n_paragraphs, buffer],
            [chatbot, state, txt, buffer]).then(
            lambda: None, None, image)
        # Clear: ask the backend to reset, and reset the local components.
        clear.click(partial(bot_request, 'clear'))
        clear.click(lambda: [[], [], '', ''], None, [chatbot, state, txt, buffer])

        with gr.Row():
            gr.Examples(
                examples=examples,
                fn=partial(predict_example, args.temperature, args.top_p,
                           args.max_new_tokens, args.keep_last_n_paragraphs),
                inputs=[image, txt],
                outputs=[chatbot, state, txt, image, buffer],
                # NOTE(review): caching runs every example through
                # predict_example (and thus SERVER_URL) at startup, so launch
                # presumably fails when the backend is unreachable.
                cache_examples=True,
            )

    demo.queue(concurrency_count=6)
    demo.launch()