# FLUX.1-schnell / app.py
# Last commit by davidberenstein1957: "Update app.py" (48dd920, verified); file size 2.39 kB.
import random
import gradio as gr
import numpy as np
import spaces
import torch
from dataset_viber import CollectorInterface
from diffusers import DiffusionPipeline
# Precision the diffusion pipeline's weights are loaded in.
dtype = torch.bfloat16

# Run on the GPU when torch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Load FLUX.1-schnell once at import time and move it to the chosen device.
pipe = DiffusionPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-schnell",
    torch_dtype=dtype,
).to(device)

# Largest seed the app will draw or accept: the max signed 32-bit integer.
MAX_SEED = np.iinfo(np.int32).max
# Upper bound (pixels) for the width/height sliders.
MAX_IMAGE_SIZE = 2048
@spaces.GPU()
def infer(
    prompt,
    seed=42,
    randomize_seed=False,
    width=1024,
    height=1024,
    num_inference_steps=4,
    progress=gr.Progress(track_tqdm=True),
):
    """Generate a single image for *prompt* with the module-level pipeline.

    The ``@spaces.GPU()`` decorator comes from the Hugging Face ``spaces``
    package; presumably it requests a GPU for the duration of each call on
    ZeroGPU Spaces — confirm against the ``spaces`` docs.

    Args:
        prompt: Text prompt forwarded to ``pipe``.
        seed: Seed for the torch random generator. Overwritten when
            ``randomize_seed`` is true.
        randomize_seed: When true, draw a fresh seed from ``[0, MAX_SEED]``.
        width: Output width in pixels, forwarded to ``pipe``.
        height: Output height in pixels, forwarded to ``pipe``.
        num_inference_steps: Number of denoising steps, forwarded to ``pipe``.
        progress: Gradio progress tracker. It is not referenced in the body;
            ``track_tqdm=True`` asks Gradio to follow tqdm progress bars
            emitted while the function runs.

    Returns:
        The first image of the pipeline output (``.images[0]``); presumably a
        PIL image with diffusers' default output type — confirm.
    """
    if randomize_seed:
        seed = random.randint(0, MAX_SEED)
    # Values coming from UI sliders or cached examples may be floats
    # (e.g. 42.0). torch.Generator.manual_seed only accepts an int, and pixel
    # sizes / step counts must be integral, so normalise them here.
    seed = int(seed)
    generator = torch.Generator().manual_seed(seed)
    image = pipe(
        prompt=prompt,
        width=int(width),
        height=int(height),
        num_inference_steps=int(num_inference_steps),
        generator=generator,
        # Classifier-free guidance is switched off. NOTE(review): presumably
        # because schnell is a distilled few-step model — confirm against the
        # model card.
        guidance_scale=0.0,
    ).images[0]
    return image
# Example rows for the UI. Each row supplies one value per interface input,
# in order: prompt, seed, randomize_seed, width, height, num_inference_steps.
# Only the prompt varies between rows.
examples = [
    [example_prompt, 0, True, 1024, 1024, 4]
    for example_prompt in (
        "a tiny astronaut hatching from an egg on the moon",
        "a cat holding a sign that says hello world",
        "an anime illustration of a wiener schnitzel",
    )
]
# Stylesheet passed to the interface via ``css=``: centres the element whose
# id is ``col-container`` and caps its width at 520px.
# NOTE(review): no component in this file sets elem_id="col-container";
# confirm the rule actually matches something in the rendered UI.
css = (
    "\n"
    "#col-container {\n"
    "margin: 0 auto;\n"
    "max-width: 520px;\n"
    "}\n"
)
# Markdown shown at the top of the interface: heading, one-line model
# summary, then links to the announcement blog post and the model card.
description = "\n".join(
    [
        "# FLUX.1 [schnell]",
        "12B param rectified flow transformer distilled from [FLUX.1 [pro]](https://blackforestlabs.ai/) for 4 step generation",
        "[[blog](https://blackforestlabs.ai/announcing-black-forest-labs/)] [[model](https://huggingface.co/black-forest-labs/FLUX.1-schnell)]",
        "",
    ]
)
def _build_interface():
    """Assemble the Dataset Viber collector UI wrapped around ``infer``.

    Components are created in the same order as ``infer``'s parameters:
    the prompt textbox is the main input, followed by seed, randomize flag,
    width, height and step count as additional inputs. Presumably
    CollectorInterface passes them to ``infer`` positionally in that order,
    as ``gr.Interface`` does — confirm against dataset_viber.
    """
    prompt_box = gr.Textbox(label="Prompt", placeholder="Enter your prompt")
    result_image = gr.Image(label="Result")
    advanced_controls = [
        gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=0),
        gr.Checkbox(label="Randomize seed", value=True),
        gr.Slider(label="Width", minimum=256, maximum=MAX_IMAGE_SIZE, step=32, value=1024),
        gr.Slider(label="Height", minimum=256, maximum=MAX_IMAGE_SIZE, step=32, value=1024),
        gr.Slider(label="Number of inference steps", minimum=1, maximum=50, step=1, value=4),
    ]
    return CollectorInterface(
        fn=infer,
        inputs=[prompt_box],
        outputs=[result_image],
        additional_inputs=advanced_controls,
        title="FLUX.1 [schnell] - with Dataset Viber data collection",
        description=description,
        examples=examples,
        css=css,
        # Dataset name handed to dataset_viber for the collected records.
        dataset_name="image-generation-flux1-schnell",
    )


interface = _build_interface()
interface.launch()