# Flux_Gemma2 / app.py
import spaces
from llama_cpp import Llama
from llama_cpp_agent import LlamaCppAgent, MessagesFormatterType
from llama_cpp_agent.providers import LlamaCppPythonProvider
from llama_cpp_agent.chat_history import BasicChatHistory
from llama_cpp_agent.chat_history.messages import Roles
import gradio as gr
from huggingface_hub import hf_hub_download
import torch
from diffusers import FluxPipeline
import numpy as np
import random

# Download the quantized Gemma 2 chat model (GGUF) for llama.cpp
hf_hub_download(
    repo_id="bartowski/gemma-2-9b-it-GGUF",
    filename="gemma-2-9b-it-Q5_K_M.gguf",
    local_dir="./models"
)
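# The weights land in ./models/gemma-2-9b-it-Q5_K_M.gguf, which matches the
# default choice of the model dropdown defined further down.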

# Globals for lazy-loading the LLM: created on first request, reused afterwards
llm = None
llm_model = None

# Set up the FLUX.1-schnell image-generation pipeline
dtype = torch.bfloat16
device = "cuda" if torch.cuda.is_available() else "cpu"
pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=dtype)
pipe.enable_model_cpu_offload()  # Offloads submodules to CPU between calls; remove if the GPU has enough VRAM

MAX_SEED = np.iinfo(np.int32).max
MAX_IMAGE_SIZE = 2048
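
# On Hugging Face ZeroGPU Spaces, @spaces.GPU requests a GPU for each
# decorated call (up to 120 s here); outside ZeroGPU the decorator is
# effectively a no-op.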
@spaces.GPU(duration=120)
def respond_and_generate_image(
    message,
    history,
    model,
    system_message,
    max_tokens,
    temperature,
    top_p,
    top_k,
    repeat_penalty,
    seed=42,
    randomize_seed=False,
    width=1024,
    height=1024,
    num_inference_steps=4,
):
    global llm, llm_model

    # Chatbot response
    chat_template = MessagesFormatterType.GEMMA_2  # Gemma 2 chat prompt format
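
    # Load the GGUF model lazily and cache it; it is reloaded only when a
    # different file is selected in the model dropdown.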
    if llm is None or llm_model != model:
        llm = Llama(
            model_path=f"models/{model}",
            flash_attn=True,
            n_gpu_layers=81,  # more layers than the 9B model has, so all of them are offloaded to the GPU
            n_batch=1024,
            n_ctx=8192,
        )
        llm_model = model

    provider = LlamaCppPythonProvider(llm)
    agent = LlamaCppAgent(
        provider,
        system_prompt=f"{system_message}\nYou can also describe images for generation.",
        predefined_messages_formatter_type=chat_template,
        debug_output=True
    )

    settings = provider.get_provider_default_settings()
    settings.temperature = temperature
    settings.top_k = top_k
    settings.top_p = top_p
    settings.max_tokens = max_tokens
    settings.repeat_penalty = repeat_penalty
    settings.stream = True
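    # With stream=True and returns_streaming_generator=True, the agent yields
    # the reply token by token; the pieces are concatenated below.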

    # Rebuild the prior conversation in the format llama_cpp_agent expects
    messages = BasicChatHistory()
    for user_text, assistant_text in history:
        messages.add_message({'role': Roles.user, 'content': user_text})
        messages.add_message({'role': Roles.assistant, 'content': assistant_text})

    stream = agent.get_chat_response(
        message,
        llm_sampling_settings=settings,
        chat_history=messages,
        returns_streaming_generator=True,
        print_output=False
    )
    response = ""
    for output in stream:
        response += output

    # Image generation
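    # The chat reply is reused verbatim as the image prompt; there is no
    # separate prompt-extraction step.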
    if randomize_seed:
        seed = random.randint(0, MAX_SEED)
    generator = torch.Generator("cpu").manual_seed(seed)

    # Rough truncation: keep the first 77 whitespace-separated words as a cheap
    # proxy for CLIP's 77-token prompt limit
    max_length = 77
    truncated_response = ' '.join(response.split()[:max_length])
    try:
        image = pipe(
            prompt=truncated_response,
            guidance_scale=0.0,  # FLUX.1-schnell is distilled to run without classifier-free guidance
            output_type="pil",
            num_inference_steps=num_inference_steps,
            max_sequence_length=256,
            generator=generator,
            width=width,
            height=height
        ).images[0]
    except Exception as e:
        print(f"Error generating image: {e}")
        image = None

    history.append((message, response))
    # history is returned twice: once for the Chatbot display and once for the State
    return history, image, seed, history

description = """
<p><center>
<a href="https://huggingface.co/google/gemma-2-9b-it" target="_blank">[9B it Model]</a>
<a href="https://huggingface.co/bartowski/gemma-2-9b-it-GGUF" target="_blank">[9B it Model GGUF]</a>
<a href="https://huggingface.co/black-forest-labs/FLUX.1-schnell" target="_blank">[FLUX.1-shnell Model]</a>
</center></p>
"""
demo = gr.Blocks()

with demo:
    gr.Markdown("# Chat with Gemma 2 and Generate Images with FLUX.1")
    gr.Markdown(description)

    with gr.Row():
        with gr.Column(scale=1):
            chatbot = gr.Chatbot()
            msg = gr.Textbox(label="Message")
            clear = gr.Button("Clear")
        with gr.Column(scale=1):
            generated_image = gr.Image(label="Generated Image")
            used_seed = gr.Number(label="Seed Used")
    with gr.Accordion("Advanced Settings", open=False):
        model = gr.Dropdown(
            ['gemma-2-9b-it-Q5_K_M.gguf'],
            value="gemma-2-9b-it-Q5_K_M.gguf",
            label="Model"
        )
        system_message = gr.Textbox(value="You are a helpful assistant.", label="System message")
        max_tokens = gr.Slider(minimum=1, maximum=4096, value=2048, step=1, label="Max tokens")
        temperature = gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature")
        top_p = gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p")
        top_k = gr.Slider(minimum=0, maximum=100, value=40, step=1, label="Top-k")
        repeat_penalty = gr.Slider(minimum=0.0, maximum=2.0, value=1.1, step=0.1, label="Repetition penalty")
        seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=42)
        randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
        width = gr.Slider(label="Width", minimum=256, maximum=MAX_IMAGE_SIZE, step=32, value=1024)
        height = gr.Slider(label="Height", minimum=256, maximum=MAX_IMAGE_SIZE, step=32, value=1024)
        num_inference_steps = gr.Slider(label="Number of inference steps", minimum=1, maximum=50, step=1, value=4)

    # Hidden state holding the (user, assistant) message pairs
    state = gr.State([])
    # Wire events: submitting a message runs the chat + image pipeline;
    # Clear resets the chat, image, seed, and state
    msg.submit(
        respond_and_generate_image,
        inputs=[msg, state, model, system_message, max_tokens, temperature, top_p, top_k,
                repeat_penalty, seed, randomize_seed, width, height, num_inference_steps],
        outputs=[chatbot, generated_image, used_seed, state]
    )
    clear.click(lambda: ([], None, None, []), outputs=[chatbot, generated_image, used_seed, state])

if __name__ == "__main__":
    demo.launch()