File size: 2,388 Bytes
48dd920
 
8ccf632
 
 
 
48dd920
81b26b5
06f0278
 
8ccf632
 
4ea3b6f
8ccf632
 
06f0278
8ccf632
 
 
54192f0
 
8ccf632
 
48dd920
 
 
 
 
06f0278
48dd920
 
 
8ccf632
48dd920
 
 
8ccf632
 
48dd920
8ccf632
 
e2944a6
8ccf632
 
 
48dd920
6ebb7df
4ea3b6f
48dd920
8ccf632
48dd920
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8ccf632
48dd920
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
import random

import gradio as gr
import numpy as np
import spaces
import torch
from dataset_viber import CollectorInterface
from diffusers import DiffusionPipeline

# Precision used for the FLUX weights when loading the pipeline.
dtype = torch.bfloat16

# Prefer CUDA when torch can see a GPU; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Load FLUX.1 [schnell] by repo id, then move it onto the selected device.
pipe = DiffusionPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-schnell",
    torch_dtype=dtype,
)
pipe = pipe.to(device)

# Limits shared by `infer` (random seed draw) and the UI sliders.
MAX_IMAGE_SIZE = 2048  # pixels; upper bound for both width and height sliders
MAX_SEED = int(np.iinfo(np.int32).max)  # largest signed 32-bit integer (2**31 - 1)

@spaces.GPU()
def infer(prompt, seed=42, randomize_seed=False, width=1024, height=1024, num_inference_steps=4, progress=gr.Progress(track_tqdm=True)):
    """Generate a single image for *prompt* with the FLUX.1 [schnell] pipeline.

    Args:
        prompt: Text prompt forwarded to the pipeline.
        seed: Seed for the torch generator; replaced by a random value in
            ``[0, MAX_SEED]`` when ``randomize_seed`` is true.
        randomize_seed: When true, ignore ``seed`` and draw a fresh one.
        width: Output width in pixels.
        height: Output height in pixels.
        num_inference_steps: Number of inference steps passed to the pipeline.
        progress: Gradio progress tracker. It is not referenced in the body;
            ``track_tqdm=True`` presumably lets Gradio mirror the pipeline's
            tqdm bars in the UI.

    Returns:
        The first entry of the pipeline's ``images`` output (presumably a
        PIL image with diffusers' default output type — confirm).
    """
    if randomize_seed:
        seed = random.randint(0, MAX_SEED)

    # Gradio's Slider may hand numeric values over as floats (e.g. 1024.0).
    # torch.Generator.manual_seed only accepts ints, and the pipeline expects
    # integer sizes/step counts, so normalize everything once here.
    seed = int(seed)
    width = int(width)
    height = int(height)
    num_inference_steps = int(num_inference_steps)

    generator = torch.Generator().manual_seed(seed)
    image = pipe(
            prompt=prompt,
            width=width,
            height=height,
            num_inference_steps=num_inference_steps,
            generator=generator,
            # Guidance is fixed at 0.0 and not exposed in the UI.
            guidance_scale=0.0
    ).images[0]
    return image

# One example row per prompt. The trailing columns follow the order of the
# interface's inputs after the prompt: seed, randomize seed, width, height,
# number of inference steps.
examples = [
    [example_prompt, 0, True, 1024, 1024, 4]
    for example_prompt in (
        "a tiny astronaut hatching from an egg on the moon",
        "a cat holding a sign that says hello world",
        "an anime illustration of a wiener schnitzel",
    )
]

# Custom stylesheet passed to the interface: centers the `#col-container`
# element and caps its width at 520px.
css = (
    "\n"
    "#col-container {\n"
    "    margin: 0 auto;\n"
    "    max-width: 520px;\n"
    "}\n"
)

# Markdown header shown above the interface: model name, a one-line summary,
# and links to the announcement blog post and the model card.
description = (
    "# FLUX.1 [schnell]\n"
    "12B param rectified flow transformer distilled from "
    "[FLUX.1 [pro]](https://blackforestlabs.ai/) for 4 step generation\n"
    "[[blog](https://blackforestlabs.ai/announcing-black-forest-labs/)] "
    "[[model](https://huggingface.co/black-forest-labs/FLUX.1-schnell)]\n"
)

# Main input: the free-text prompt.
prompt_input = gr.Textbox(label="Prompt", placeholder="Enter your prompt")

# Sole output: the generated image.
result_output = gr.Image(label="Result")

# Extra controls, in the same order as infer's parameters after `prompt`
# (seed, randomize_seed, width, height, num_inference_steps).
advanced_inputs = [
    gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=0),
    gr.Checkbox(label="Randomize seed", value=True),
    gr.Slider(label="Width", minimum=256, maximum=MAX_IMAGE_SIZE, step=32, value=1024),
    gr.Slider(label="Height", minimum=256, maximum=MAX_IMAGE_SIZE, step=32, value=1024),
    gr.Slider(label="Number of inference steps", minimum=1, maximum=50, step=1, value=4),
]

# NOTE(review): `dataset_name` is presumably where dataset_viber stores the
# collected samples — confirm against the dataset_viber docs.
interface = CollectorInterface(
    fn=infer,
    inputs=[prompt_input],
    outputs=[result_output],
    additional_inputs=advanced_inputs,
    title="FLUX.1 [schnell] - with Dataset Viber data collection",
    description=description,
    examples=examples,
    css=css,
    dataset_name="image-generation-flux1-schnell",
)

# Start the app.
interface.launch()