"""FLUX.1 [schnell] text-to-image Space wrapped in a Dataset Viber collector.

The Gradio UI is built by dataset_viber's ``CollectorInterface`` around
``infer``; ``dataset_name`` below names the dataset it targets (presumably
where prompt/image pairs are collected — confirm against dataset_viber docs).
"""
import random

import gradio as gr
import numpy as np
import spaces
import torch
from dataset_viber import CollectorInterface
from diffusers import DiffusionPipeline

dtype = torch.bfloat16
device = "cuda" if torch.cuda.is_available() else "cpu"

# Loaded once at import time so every request reuses the same weights.
pipe = DiffusionPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-schnell", torch_dtype=dtype
).to(device)

MAX_SEED = np.iinfo(np.int32).max
MAX_IMAGE_SIZE = 2048


@spaces.GPU()
def infer(
    prompt,
    seed=42,
    randomize_seed=False,
    width=1024,
    height=1024,
    num_inference_steps=4,
    progress=gr.Progress(track_tqdm=True),
):
    """Generate a single image for *prompt* with FLUX.1 [schnell].

    Args:
        prompt: Text prompt passed to the pipeline.
        seed: Seed for the torch generator; ignored when ``randomize_seed``
            is true.
        randomize_seed: If true, draw a fresh seed in ``[0, MAX_SEED]``
            instead of using ``seed``.
        width: Output width in pixels.
        height: Output height in pixels.
        num_inference_steps: Number of denoising steps.
        progress: Gradio progress tracker; ``track_tqdm=True`` mirrors the
            pipeline's tqdm progress bar in the UI.

    Returns:
        The first image from the pipeline's ``images`` output.
    """
    if randomize_seed:
        seed = random.randint(0, MAX_SEED)
    # UI components may deliver numbers as floats; manual_seed requires an
    # int and the step count is used as a count, so coerce explicitly.
    generator = torch.Generator().manual_seed(int(seed))
    image = pipe(
        prompt=prompt,
        width=int(width),
        height=int(height),
        num_inference_steps=int(num_inference_steps),
        generator=generator,
        guidance_scale=0.0,  # value used in the FLUX.1-schnell usage example
    ).images[0]
    return image


# Each row: prompt, seed, randomize_seed, width, height, num_inference_steps.
examples = [
    ["a tiny astronaut hatching from an egg on the moon", 0, True, 1024, 1024, 4],
    ["a cat holding a sign that says hello world", 0, True, 1024, 1024, 4],
    ["an anime illustration of a wiener schnitzel", 0, True, 1024, 1024, 4],
]

css = """
#col-container {
    margin: 0 auto;
    max-width: 520px;
}
"""

description = """# FLUX.1 [schnell]
12B param rectified flow transformer distilled from [FLUX.1 [pro]](https://blackforestlabs.ai/) for 4 step generation
[[blog](https://blackforestlabs.ai/announcing-black-forest-labs/)] [[model](https://huggingface.co/black-forest-labs/FLUX.1-schnell)]
"""

interface = CollectorInterface(
    fn=infer,
    inputs=[
        gr.Textbox(label="Prompt", placeholder="Enter your prompt"),
    ],
    outputs=[
        gr.Image(label="Result"),
    ],
    # Order must match infer's positional parameters after `prompt`.
    additional_inputs=[
        gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=0),
        gr.Checkbox(label="Randomize seed", value=True),
        gr.Slider(label="Width", minimum=256, maximum=MAX_IMAGE_SIZE, step=32, value=1024),
        gr.Slider(label="Height", minimum=256, maximum=MAX_IMAGE_SIZE, step=32, value=1024),
        gr.Slider(label="Number of inference steps", minimum=1, maximum=50, step=1, value=4),
    ],
    title="FLUX.1 [schnell] - with Dataset Viber data collection",
    description=description,
    examples=examples,
    css=css,
    dataset_name="image-generation-flux1-schnell",
)
interface.launch()