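"""Gradio demo app for Animagine XL 4.0 (SDXL).

Text-to-image generation with optional two-stage latent upscaling
(img2img refinement), style presets, seed control, and per-image
metadata logging. Intended to run on Hugging Face ZeroGPU Spaces.
"""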

import gc
import json
import logging
import os
import time
from datetime import datetime
from typing import Dict, List, Optional, Tuple

import gradio as gr
import numpy as np
import spaces
import torch
from diffusers import StableDiffusionXLPipeline, StableDiffusionXLImg2ImgPipeline
from diffusers.models import AutoencoderKL
from PIL import Image, PngImagePlugin

import config
import utils
from config import (
    MODEL,
    MIN_IMAGE_SIZE,
    MAX_IMAGE_SIZE,
    USE_TORCH_COMPILE,
    ENABLE_CPU_OFFLOAD,
    OUTPUT_DIR,
    DEFAULT_NEGATIVE_PROMPT,
    DEFAULT_ASPECT_RATIO,
    examples,
    sampler_list,
    aspect_ratios,
    style_list,
)

# Enhanced logging configuration
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    datefmt='%Y-%m-%d %H:%M:%S',
)
logger = logging.getLogger(__name__)

# Constants
IS_COLAB = utils.is_google_colab() or os.getenv("IS_COLAB") == "1"
HF_TOKEN = os.getenv("HF_TOKEN")
CACHE_EXAMPLES = torch.cuda.is_available() and os.getenv("CACHE_EXAMPLES") == "1"

# PyTorch settings for better performance and determinism
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
torch.backends.cuda.matmul.allow_tf32 = True

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
logger.info(f"Using device: {device}")

class GenerationError(Exception):
    """Custom exception for generation errors."""
    pass

def validate_prompt(prompt: str) -> str:
    """Validate and clean up the input prompt."""
    if not isinstance(prompt, str):
        raise GenerationError("Prompt must be a string")
    try:
        # Ensure proper UTF-8 encoding/decoding
        prompt = prompt.encode('utf-8').decode('utf-8')
        # Insert a space between "!" and "," so the two tags stay separable
        prompt = prompt.replace("!,", "! ,")
    except UnicodeError:
        raise GenerationError("Invalid characters in prompt")

    # Only reject prompts that are completely empty or whitespace
    if not prompt or prompt.isspace():
        raise GenerationError("Prompt cannot be empty")

    return prompt.strip()
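
# Illustrative behavior (examples, not from the original source):
#   validate_prompt("  1girl, smile!,outdoors  ")  ->  "1girl, smile! ,outdoors"
#   validate_prompt("   ")                         ->  raises GenerationError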

def validate_dimensions(width: int, height: int) -> None:
    """Validate image dimensions."""
    if not MIN_IMAGE_SIZE <= width <= MAX_IMAGE_SIZE:
        raise GenerationError(f"Width must be between {MIN_IMAGE_SIZE} and {MAX_IMAGE_SIZE}")
    if not MIN_IMAGE_SIZE <= height <= MAX_IMAGE_SIZE:
        raise GenerationError(f"Height must be between {MIN_IMAGE_SIZE} and {MAX_IMAGE_SIZE}")
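
# Illustrative behavior: validate_dimensions(1024, 1024) passes silently, while
# any side outside [MIN_IMAGE_SIZE, MAX_IMAGE_SIZE] raises GenerationError.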

# On ZeroGPU Spaces (this Space runs "on Zero"), GPU-bound functions must be
# wrapped with `spaces.GPU` so a GPU is attached for the duration of the call;
# without the decorator the `spaces` import above goes unused.
@spaces.GPU
def generate(
    prompt: str,
    negative_prompt: str = DEFAULT_NEGATIVE_PROMPT,
    seed: int = 0,
    custom_width: int = 1024,
    custom_height: int = 1024,
    guidance_scale: float = 6.0,
    num_inference_steps: int = 25,
    sampler: str = "Euler a",
    aspect_ratio_selector: str = DEFAULT_ASPECT_RATIO,
    style_selector: str = "(None)",
    use_upscaler: bool = False,
    upscaler_strength: float = 0.55,
    upscale_by: float = 1.5,
    add_quality_tags: bool = True,
    progress: gr.Progress = gr.Progress(track_tqdm=True),
) -> Tuple[List[str], Dict]:
    """Generate images based on the given parameters."""
    start_time = time.time()
    upscaler_pipe = None
    backup_scheduler = None
    try:
        # The pipeline is only created when CUDA is available (see the model
        # initialization below), so fail fast with a readable error on CPU.
        if pipe is None:
            raise GenerationError("Model pipeline is not loaded (no GPU available)")

        # Memory management
        torch.cuda.empty_cache()
        gc.collect()

        # Input validation
        prompt = validate_prompt(prompt)
        if negative_prompt:
            negative_prompt = negative_prompt.encode('utf-8').decode('utf-8')
        validate_dimensions(custom_width, custom_height)

        # Set up generation
        generator = utils.seed_everything(seed)
        width, height = utils.aspect_ratio_handler(
            aspect_ratio_selector,
            custom_width,
            custom_height,
        )

        # Process prompts
        if add_quality_tags:
            prompt = "masterpiece, high score, great score, absurdres, {prompt}".format(prompt=prompt)

        prompt, negative_prompt = utils.preprocess_prompt(
            styles, style_selector, prompt, negative_prompt
        )

        width, height = utils.preprocess_image_dimensions(width, height)

        # Set up pipeline
        backup_scheduler = pipe.scheduler
        pipe.scheduler = utils.get_scheduler(pipe.scheduler.config, sampler)

        if use_upscaler:
            # Reuse the base pipeline's components for img2img refinement
            # instead of loading a second copy of the model.
            upscaler_pipe = StableDiffusionXLImg2ImgPipeline(**pipe.components)
        # Prepare metadata
        metadata = {
            "prompt": prompt,
            "negative_prompt": negative_prompt,
            "resolution": f"{width} x {height}",
            "guidance_scale": guidance_scale,
            "num_inference_steps": num_inference_steps,
            "style_preset": style_selector,
            "seed": seed,
            "sampler": sampler,
            "Model": "Animagine XL 4.0",
            "Model hash": "e3c47aedb0",
        }

        if use_upscaler:
            new_width = int(width * upscale_by)
            new_height = int(height * upscale_by)
            metadata["use_upscaler"] = {
                "upscale_method": "nearest-exact",
                "upscaler_strength": upscaler_strength,
                "upscale_by": upscale_by,
                "new_resolution": f"{new_width} x {new_height}",
            }
        else:
            metadata["use_upscaler"] = None

        logger.info(f"Starting generation with parameters: {json.dumps(metadata, indent=4)}")
        # Generate images
        if use_upscaler:
            # Two-stage generation: render latents at the base resolution,
            # upscale them with nearest-exact interpolation, then refine the
            # result with img2img at the target resolution.
            latents = pipe(
                prompt=prompt,
                negative_prompt=negative_prompt,
                width=width,
                height=height,
                guidance_scale=guidance_scale,
                num_inference_steps=num_inference_steps,
                generator=generator,
                output_type="latent",
            ).images
            upscaled_latents = utils.upscale(latents, "nearest-exact", upscale_by)
            images = upscaler_pipe(
                prompt=prompt,
                negative_prompt=negative_prompt,
                image=upscaled_latents,
                guidance_scale=guidance_scale,
                num_inference_steps=num_inference_steps,
                strength=upscaler_strength,
                generator=generator,
                output_type="pil",
            ).images
        else:
            images = pipe(
                prompt=prompt,
                negative_prompt=negative_prompt,
                width=width,
                height=height,
                guidance_scale=guidance_scale,
                num_inference_steps=num_inference_steps,
                generator=generator,
                output_type="pil",
            ).images
        # Save images
        image_paths = []  # initialized up front so the return below cannot hit a NameError when no images are produced
        if images:
            total = len(images)
            for idx, image in enumerate(images, 1):
                progress(idx / total, desc="Saving images...")
                path = utils.save_image(image, metadata, OUTPUT_DIR, IS_COLAB)
                image_paths.append(path)
                logger.info(f"Image {idx}/{total} saved as {path}")

        generation_time = time.time() - start_time
        logger.info(f"Generation completed successfully in {generation_time:.2f} seconds")
        metadata["generation_time"] = f"{generation_time:.2f}s"

        return image_paths, metadata

    except GenerationError as e:
        logger.warning(f"Generation validation error: {str(e)}")
        raise gr.Error(str(e))

    except Exception as e:
        logger.exception("Unexpected error during generation")
        raise gr.Error(f"Generation failed: {str(e)}")

    finally:
        # Cleanup
        torch.cuda.empty_cache()
        gc.collect()

        if upscaler_pipe is not None:
            del upscaler_pipe
        if backup_scheduler is not None and pipe is not None:
            pipe.scheduler = backup_scheduler

        utils.free_memory()
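
# Minimal programmatic usage sketch (assumes the CUDA pipeline below has been
# loaded; in the app this function is driven by the Gradio UI):
#   image_paths, meta = generate("1girl, solo, looking at viewer", seed=42)
#   print(meta["generation_time"], image_paths)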

# Model initialization
if torch.cuda.is_available():
    try:
        logger.info("Loading VAE and pipeline...")
        vae = AutoencoderKL.from_pretrained(
            "madebyollin/sdxl-vae-fp16-fix",
            torch_dtype=torch.float16,
        )
        pipe = utils.load_pipeline(MODEL, device, vae=vae)
        logger.info("Pipeline loaded successfully on GPU!")
    except Exception as e:
        logger.error(f"Error loading the fp16-fix VAE, falling back to the model's default VAE: {e}")
        pipe = utils.load_pipeline(MODEL, device)
else:
    logger.warning("CUDA not available, running on CPU")
    pipe = None

# Process styles
styles = {k["name"]: (k["prompt"], k["negative_prompt"]) for k in style_list}
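# Each style_list entry is assumed to look like
#   {"name": "(None)", "prompt": "{prompt}", "negative_prompt": ""},
# i.e. `styles` maps a preset name to its (prompt template, negative prompt)
# pair for utils.preprocess_prompt to apply.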

with gr.Blocks(css="style.css", theme="Nymbo/Nymbo_Theme_5") as demo:
    gr.HTML(
        """
        <div class="header">
            <div class="title">ANIM4GINE</div>
            <div class="subtitle">Gradio demo for <a href="https://huggingface.co/CagliostroLab/Animagine-XL-4.0" target="_blank">Animagine XL 4.0</a></div>
        </div>
        """,
    )
    with gr.Row():
        with gr.Column(scale=2):
            with gr.Group():
                prompt = gr.Text(
                    label="Prompt",
                    max_lines=5,
                    placeholder="Describe what you want to generate",
                    info="Enter your image generation prompt here. Be specific and descriptive for better results.",
                )
                negative_prompt = gr.Text(
                    label="Negative Prompt",
                    max_lines=5,
                    placeholder="Describe what you want to avoid",
                    value=DEFAULT_NEGATIVE_PROMPT,
                    info="Specify elements you don't want in the image.",
                )
                add_quality_tags = gr.Checkbox(
                    label="Quality Tags",
                    value=True,
                    info="Add quality-enhancing tags to your prompt automatically.",
                )
            with gr.Accordion(label="More Settings", open=False):
                with gr.Group():
                    aspect_ratio_selector = gr.Radio(
                        label="Aspect Ratio",
                        choices=aspect_ratios,
                        value=DEFAULT_ASPECT_RATIO,
                        container=True,
                        info="Choose the dimensions of your image.",
                    )
                with gr.Group(visible=False) as custom_resolution:
                    with gr.Row():
                        custom_width = gr.Slider(
                            label="Width",
                            minimum=MIN_IMAGE_SIZE,
                            maximum=MAX_IMAGE_SIZE,
                            step=8,
                            value=1024,
                            info=f"Image width (must be between {MIN_IMAGE_SIZE} and {MAX_IMAGE_SIZE})",
                        )
                        custom_height = gr.Slider(
                            label="Height",
                            minimum=MIN_IMAGE_SIZE,
                            maximum=MAX_IMAGE_SIZE,
                            step=8,
                            value=1024,
                            info=f"Image height (must be between {MIN_IMAGE_SIZE} and {MAX_IMAGE_SIZE})",
                        )
                with gr.Group():
                    use_upscaler = gr.Checkbox(
                        label="Use Upscaler",
                        value=False,
                        info="Enable high-resolution upscaling.",
                    )
                    with gr.Row() as upscaler_row:
                        upscaler_strength = gr.Slider(
                            label="Strength",
                            minimum=0,
                            maximum=1,
                            step=0.05,
                            value=0.55,
                            visible=False,
                            info="Control how much the upscaler affects the final image.",
                        )
                        upscale_by = gr.Slider(
                            label="Upscale by",
                            minimum=1,
                            maximum=1.5,
                            step=0.1,
                            value=1.5,
                            visible=False,
                            info="Multiplier for the final image resolution.",
                        )
            with gr.Accordion(label="Advanced Parameters", open=False):
                with gr.Group():
                    style_selector = gr.Dropdown(
                        label="Style Preset",
                        interactive=True,
                        choices=list(styles.keys()),
                        value="(None)",
                        info="Apply a predefined style to your generation.",
                    )
                with gr.Group():
                    sampler = gr.Dropdown(
                        label="Sampler",
                        choices=sampler_list,
                        interactive=True,
                        value="Euler a",
                        info="Different samplers can produce varying results.",
                    )
                with gr.Group():
                    seed = gr.Slider(
                        label="Seed",
                        minimum=0,
                        maximum=utils.MAX_SEED,
                        step=1,
                        value=0,
                        info="Set a specific seed for reproducible results.",
                    )
                    randomize_seed = gr.Checkbox(
                        label="Randomize seed",
                        value=True,
                        info="Generate a new random seed for each image.",
                    )
                with gr.Group():
                    with gr.Row():
                        guidance_scale = gr.Slider(
                            label="Guidance scale",
                            minimum=1,
                            maximum=12,
                            step=0.1,
                            value=6.0,
                            info="Higher values make the image more closely match your prompt.",
                        )
                        num_inference_steps = gr.Slider(
                            label="Number of inference steps",
                            minimum=1,
                            maximum=50,
                            step=1,
                            value=25,
                            info="More steps generally mean higher quality but slower generation.",
                        )
        with gr.Column(scale=3):
            # A nested `gr.Blocks()` is not a layout container; `gr.Group` is
            # the idiomatic wrapper here.
            with gr.Group():
                run_button = gr.Button("Generate", variant="primary", elem_id="generate-button")
                result = gr.Gallery(
                    label="Generated Images",
                    columns=1,
                    height='768px',
                    preview=True,
                    show_label=True,
                )
                with gr.Accordion(label="Generation Parameters", open=False):
                    gr_metadata = gr.JSON(
                        label="Image Metadata",
                        show_label=True,
                    )
            gr.Examples(
                examples=examples,
                inputs=prompt,
                outputs=[result, gr_metadata],
                fn=lambda *args, **kwargs: generate(*args, use_upscaler=True, **kwargs),
                cache_examples=CACHE_EXAMPLES,
            )
    # Discord button in a new full-width row
    with gr.Row():
        gr.HTML(
            """
            <a href="https://discord.com/invite/cqh9tZgbGc" target="_blank" class="discord-btn">
                <svg class="discord-icon" xmlns="http://www.w3.org/2000/svg" viewBox="0 0 127.14 96.36"><path fill="currentColor" d="M107.7,8.07A105.15,105.15,0,0,0,81.47,0a72.06,72.06,0,0,0-3.36,6.83A97.68,97.68,0,0,0,49,6.83,72.37,72.37,0,0,0,45.64,0,105.89,105.89,0,0,0,19.39,8.09C2.79,32.65-1.71,56.6.54,80.21h0A105.73,105.73,0,0,0,32.71,96.36,77.7,77.7,0,0,0,39.6,85.25a68.42,68.42,0,0,1-10.85-5.18c.91-.66,1.8-1.34,2.66-2a75.57,75.57,0,0,0,64.32,0c.87.71,1.76,1.39,2.66,2a68.68,68.68,0,0,1-10.87,5.19,77,77,0,0,0,6.89,11.1A105.25,105.25,0,0,0,126.6,80.22h0C129.24,52.84,122.09,29.11,107.7,8.07ZM42.45,65.69C36.18,65.69,31,60,31,53s5-12.74,11.43-12.74S54,46,53.89,53,48.84,65.69,42.45,65.69Zm42.24,0C78.41,65.69,73.25,60,73.25,53s5-12.74,11.44-12.74S96.23,46,96.12,53,91.08,65.69,84.69,65.69Z"/></svg>
                <span class="discord-text">Join our Discord Server</span>
            </a>
            """
        )
    use_upscaler.change(
        fn=lambda x: [gr.update(visible=x), gr.update(visible=x)],
        inputs=use_upscaler,
        outputs=[upscaler_strength, upscale_by],
        queue=False,
        api_name=False,
    )
    aspect_ratio_selector.change(
        fn=lambda x: gr.update(visible=x == "Custom"),
        inputs=aspect_ratio_selector,
        outputs=custom_resolution,
        queue=False,
        api_name=False,
    )
    # Combine all triggers (button click and Enter-to-submit in either prompt
    # box), then chain: randomize the seed, lock the button, generate, unlock.
    gr.on(
        triggers=[
            prompt.submit,
            negative_prompt.submit,
            run_button.click,
        ],
        fn=utils.randomize_seed_fn,
        inputs=[seed, randomize_seed],
        outputs=seed,
        queue=False,
        api_name=False,
    ).then(
        fn=lambda: gr.update(interactive=False, value="Generating..."),
        outputs=run_button,
    ).then(
        fn=generate,
        inputs=[
            prompt,
            negative_prompt,
            seed,
            custom_width,
            custom_height,
            guidance_scale,
            num_inference_steps,
            sampler,
            aspect_ratio_selector,
            style_selector,
            use_upscaler,
            upscaler_strength,
            upscale_by,
            add_quality_tags,
        ],
        outputs=[result, gr_metadata],
    ).then(
        fn=lambda: gr.update(interactive=True, value="Generate"),
        outputs=run_button,
    )

if __name__ == "__main__":
    demo.queue(max_size=20).launch(debug=IS_COLAB, share=IS_COLAB)