import spaces
import json
import subprocess
from llama_cpp import Llama
from llama_cpp_agent import LlamaCppAgent, MessagesFormatterType
from llama_cpp_agent.providers import LlamaCppPythonProvider
from llama_cpp_agent.chat_history import BasicChatHistory
from llama_cpp_agent.chat_history.messages import Roles
import gradio as gr
from huggingface_hub import hf_hub_download
import torch
from diffusers import FluxPipeline
import numpy as np
import random
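
# This Space chains two models in one Gradio callback: a GGUF build of
# Gemma 2 9B (served via llama-cpp-python / llama_cpp_agent) writes the chat
# reply, then FLUX.1-schnell renders that reply as an image.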

# Download models
hf_hub_download(
    repo_id="bartowski/gemma-2-9b-it-GGUF",
    filename="gemma-2-9b-it-Q5_K_M.gguf",
    local_dir="./models"
)
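
# hf_hub_download skips the transfer when an up-to-date copy is already present
# locally, so Space restarts reuse the cached GGUF instead of re-downloading it.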

# Initialize global variables
llm = None
llm_model = None

# Set up image generation
dtype = torch.bfloat16
device = "cuda" if torch.cuda.is_available() else "cpu"
pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=dtype)
pipe.enable_model_cpu_offload()  # Remove this line if you have enough GPU power
MAX_SEED = np.iinfo(np.int32).max
MAX_IMAGE_SIZE = 2048
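
# MAX_SEED caps the seed slider at the largest 32-bit signed integer;
# MAX_IMAGE_SIZE bounds the width/height sliders in the UI defined below.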

# spaces.GPU requests a ZeroGPU device for the duration of each call; the Space
# runs on Zero, which is why `spaces` is imported above.
@spaces.GPU
def respond_and_generate_image(
    message,
    history,
    model,
    system_message,
    max_tokens,
    temperature,
    top_p,
    top_k,
    repeat_penalty,
    seed=42,
    randomize_seed=False,
    width=1024,
    height=1024,
    num_inference_steps=4,
):
    global llm, llm_model

    # Chatbot response
    chat_template = MessagesFormatterType.GEMMA_2

    if llm is None or llm_model != model:
        llm = Llama(
            model_path=f"models/{model}",
            flash_attn=True,
            n_gpu_layers=81,
            n_batch=1024,
            n_ctx=8192,
        )
        llm_model = model
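
    # The Llama instance is cached in the module-level globals, so the GGUF
    # file is only (re)loaded when the selected model changes. In llama.cpp,
    # an n_gpu_layers value above the model's layer count offloads every layer.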

    provider = LlamaCppPythonProvider(llm)
    agent = LlamaCppAgent(
        provider,
        system_prompt=f"{system_message}\nYou can also describe images for generation.",
        predefined_messages_formatter_type=chat_template,
        debug_output=True
    )

    settings = provider.get_provider_default_settings()
    settings.temperature = temperature
    settings.top_k = top_k
    settings.top_p = top_p
    settings.max_tokens = max_tokens
    settings.repeat_penalty = repeat_penalty
    settings.stream = True

    # Rebuild prior turns in the format llama_cpp_agent expects
    messages = BasicChatHistory()
    for user_text, assistant_text in history:
        messages.add_message({'role': Roles.user, 'content': user_text})
        messages.add_message({'role': Roles.assistant, 'content': assistant_text})

    stream = agent.get_chat_response(
        message,
        llm_sampling_settings=settings,
        chat_history=messages,
        returns_streaming_generator=True,
        print_output=False
    )

    response = ""
    for output in stream:
        response += output
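
    # The streaming generator is drained into one string: the image prompt
    # needs the complete reply, so nothing is streamed to the UI here.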

    # Image generation
    if randomize_seed:
        seed = random.randint(0, MAX_SEED)
    generator = torch.Generator("cpu").manual_seed(seed)

    # Crudely cap the prompt length: 77 is CLIP's token limit, approximated
    # here by keeping at most 77 whitespace-separated words (words != tokens)
    max_length = 77
    truncated_response = ' '.join(response.split()[:max_length])

    try:
        image = pipe(
            prompt=truncated_response,
            guidance_scale=0.0,
            output_type="pil",
            num_inference_steps=num_inference_steps,
            max_sequence_length=256,
            generator=generator,
            width=width,
            height=height
        ).images[0]
    except Exception as e:
        print(f"Error generating image: {e}")
        image = None

    history.append((message, response))
    # history is returned twice: once for the Chatbot display, once for the session state
    return history, image, seed, history
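
# Illustrative only: calling the handler directly, outside Gradio. The argument
# values below are arbitrary examples, not defaults taken from this app.
#
#   chat, image, used_seed, _ = respond_and_generate_image(
#       "A lighthouse at dawn", [], "gemma-2-9b-it-Q5_K_M.gguf",
#       "You are a helpful assistant.", 512, 0.7, 0.95, 40, 1.1,
#       seed=0, randomize_seed=True,
#   )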

description = """
<p><center>
<a href="https://huggingface.co/google/gemma-2-9b-it" target="_blank">[9B it Model]</a>
<a href="https://huggingface.co/bartowski/gemma-2-9b-it-GGUF" target="_blank">[9B it Model GGUF]</a>
<a href="https://huggingface.co/black-forest-labs/FLUX.1-schnell" target="_blank">[FLUX.1-schnell Model]</a>
</center></p>
"""

demo = gr.Blocks()

with demo:
    gr.Markdown("# Chat with Gemma 2 and Generate Images with FLUX.1")
    gr.Markdown(description)

    with gr.Row():
        with gr.Column(scale=1):
            chatbot = gr.Chatbot()
            msg = gr.Textbox(label="Message")
            clear = gr.Button("Clear")
        with gr.Column(scale=1):
            generated_image = gr.Image(label="Generated Image")
            used_seed = gr.Number(label="Seed Used")

    with gr.Accordion("Advanced Settings", open=False):
        model = gr.Dropdown(
            ['gemma-2-9b-it-Q5_K_M.gguf'],
            value="gemma-2-9b-it-Q5_K_M.gguf",
            label="Model"
        )
        system_message = gr.Textbox(value="You are a helpful assistant.", label="System message")
        max_tokens = gr.Slider(minimum=1, maximum=4096, value=2048, step=1, label="Max tokens")
        temperature = gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature")
        top_p = gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p")
        top_k = gr.Slider(minimum=0, maximum=100, value=40, step=1, label="Top-k")
        repeat_penalty = gr.Slider(minimum=0.0, maximum=2.0, value=1.1, step=0.1, label="Repetition penalty")
        seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=42)
        randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
        width = gr.Slider(label="Width", minimum=256, maximum=MAX_IMAGE_SIZE, step=32, value=1024)
        height = gr.Slider(label="Height", minimum=256, maximum=MAX_IMAGE_SIZE, step=32, value=1024)
        num_inference_steps = gr.Slider(label="Number of inference steps", minimum=1, maximum=50, step=1, value=4)

    state = gr.State([])
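    # state carries the (user, assistant) tuples between calls: it is the
    # second input to the handler and its fourth output, keeping it in sync
    # with what the Chatbot component displays.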

    msg.submit(
        respond_and_generate_image,
        inputs=[msg, state, model, system_message, max_tokens, temperature, top_p, top_k,
                repeat_penalty, seed, randomize_seed, width, height, num_inference_steps],
        outputs=[chatbot, generated_image, used_seed, state]
    )
    clear.click(lambda: ([], None, None, []), outputs=[chatbot, generated_image, used_seed, state])

if __name__ == "__main__":
    demo.launch()