# app.py — last updated by multimodalart ("Update app.py", commit 477a209, verified), 2.88 kB
import gradio as gr
from diffusers import StableDiffusionXLPipeline
import numpy as np
import math
import spaces
import torch
import sys
import random
from gradio_imageslider import ImageSlider
# Base Gradio theme: Libre Franklin, then Public Sans (both Google Fonts),
# falling back to the system UI font and finally any sans-serif face.
theme = gr.themes.Base(
    font=[
        gr.themes.GoogleFont("Libre Franklin"),
        gr.themes.GoogleFont("Public Sans"),
        "system-ui",
        "sans-serif",
    ],
)
# Load SDXL base weights in fp16 through the community
# "sdxl_perturbed_attention_guidance" pipeline (presumably what provides the
# `pag_scale` / `pag_applied_layers` call arguments that `run` passes — confirm
# against the pipeline source), then move it to the GPU once at import time.
device = "cuda"
pipe = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    custom_pipeline="multimodalart/sdxl_perturbed_attention_guidance",
    torch_dtype=torch.float16,
).to(device)
@spaces.GPU
def run(
    prompt,
    negative_prompt="",
    guidance_scale=7.0,
    pag_scale=3.0,
    randomize_seed=True,
    seed=42,
    progress=gr.Progress(track_tqdm=True),
):
    """Render one prompt twice, with and without Perturbed-Attention Guidance.

    Both images use the same prompt, guidance scale, step count and seed, so
    the only difference between them is PAG.

    Args:
        prompt: Text prompt. Surrounding whitespace is stripped. An empty (or
            None) prompt sets ``guidance_scale`` to 0 to test unconditional
            generation.
        negative_prompt: Optional negative prompt applied to both images. An
            empty value is passed as ``None`` (the pipeline's default), so it
            behaves as if no negative prompt were given.
        guidance_scale: Classifier-free guidance scale for both images.
        pag_scale: Perturbed-attention guidance scale for the PAG image.
        randomize_seed: If true, ``seed`` is replaced by a random value in
            ``[0, sys.maxsize]``.
        seed: Seed for both generations when not randomized.
        progress: Gradio progress tracker. ``track_tqdm=True`` lets Gradio
            follow tqdm progress bars.

    Returns:
        ``((image_pag, image_normal), seed)``: the image pair for the slider
        and the integer seed actually used, so the UI can show it.
    """
    prompt = (prompt or "").strip()
    if randomize_seed:
        seed = random.randint(0, sys.maxsize)
    # UI sliders may deliver the seed as a float; manual_seed requires an int.
    seed = int(seed)
    if not prompt:
        guidance_scale = 0.0

    # Settings shared by both generations so PAG is the only variable.
    shared = {
        "negative_prompt": negative_prompt or None,
        "guidance_scale": guidance_scale,
        "num_inference_steps": 25,
    }

    # A fresh generator per call, seeded identically, gives both runs the
    # same initial noise.
    image_pag = pipe(
        prompt,
        pag_scale=pag_scale,
        pag_applied_layers=["mid"],
        generator=torch.Generator(device=device).manual_seed(seed),
        **shared,
    ).images[0]
    image_normal = pipe(
        prompt,
        generator=torch.Generator(device=device).manual_seed(seed),
        **shared,
    ).images[0]
    return (image_pag, image_normal), seed
# Center the app container and cap its width at 768px.
css = (
    "\n"
    ".gradio-container{\n"
    "max-width: 768px !important;\n"
    "margin: 0 auto;\n"
    "}\n"
)
# Gradio UI: a prompt box, a PAG vs. no-PAG before/after slider, advanced
# settings, cached examples, and one event wiring button click and Enter to run.
with gr.Blocks(css=css, theme=theme) as demo:
    gr.Markdown('''# Perturbed Attention Guidance SDXL
SDXL 🧨 [diffusers implementation](https://huggingface.co/multimodalart/sdxl_perturbed_attention_guidance) of [Perturbed-Attention Guidance](https://ku-cvlab.github.io/Perturbed-Attention-Guidance/)
''')
    with gr.Group():
        with gr.Row():
            prompt = gr.Textbox(
                show_label=False,
                scale=4,
                placeholder="Your prompt",
                info="Leave blank to test unconditional generation",
            )
            button = gr.Button("Generate", min_width=120)
    output = ImageSlider(label="Left: PAG, Right: No PAG", interactive=False)
    with gr.Accordion("Advanced Settings", open=False):
        negative_prompt = gr.Textbox(label="Negative prompt")
        guidance_scale = gr.Number(label="Guidance Scale", value=7.0)
        pag_scale = gr.Number(label="Pag Scale", value=3.0)
        randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
        # minimum=0 because run() can return a randomized seed of 0.
        seed = gr.Slider(
            label="Seed",
            minimum=0,
            maximum=18446744073709551615,
            step=1,
            randomize=True,
        )
    gr.Examples(
        fn=run,
        examples=[" ", "an insect robot preparing a delicious meal, anime style", "a photo of a group of friends at an amusement park"],
        inputs=prompt,
        outputs=[output, seed],
        cache_examples=True,
    )
    gr.on(
        triggers=[
            button.click,
            prompt.submit,
        ],
        fn=run,
        # Gradio passes inputs positionally, so this list must follow run()'s
        # parameter order: prompt, negative_prompt, guidance_scale, pag_scale,
        # randomize_seed, seed.
        inputs=[prompt, negative_prompt, guidance_scale, pag_scale, randomize_seed, seed],
        outputs=[output, seed],
    )
# Script entry point: start the Gradio server for the `demo` app.
if __name__ == "__main__":
    demo.launch(share=True)  # share=True also asks Gradio for a public share link