File size: 3,126 Bytes
82ad0f2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
dbcbdf5
 
 
82ad0f2
dbcbdf5
 
82ad0f2
 
 
 
 
a05e9a7
 
27f154c
 
a05e9a7
27f154c
82ad0f2
a05e9a7
 
82ad0f2
 
9d761af
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
from functools import partial
from random import randint

import gradio as gr
import torch
from tqdm import tqdm

from NestedPipeline import NestedStableDiffusionPipeline
from NestedScheduler import NestedScheduler


def _progress_html(outer_bar, inner_bar):
    """Render the outer/inner progress bars as two monospace HTML paragraphs.

    Spaces are replaced with ``&nbsp;`` so the browser does not collapse the
    padding inside the textual bars.
    """
    monospace_s, monospace_e = "<p style=\"font-family:'Lucida Console', monospace\">", "</p>"
    outer_text = str(outer_bar).replace(' ', '&nbsp;')
    inner_text = str(inner_bar).replace(' ', '&nbsp;')
    return f"{monospace_s}{outer_text}{monospace_e} \n {monospace_s}{inner_text}{monospace_e}"


def run(prompt, outer, inner, random_seed, pipe):
    """Stream Nested Diffusion progress and intermediate images for the Gradio UI.

    Args:
        prompt: Text prompt forwarded to ``pipe``.
        outer: Number of outer diffusion steps (``num_inference_steps``).
        inner: Number of inner diffusion steps (``num_inner_steps``).
        random_seed: When truthy, draw a seed from ``randint(0, 10000)``;
            otherwise use the fixed seed 24 for reproducible output.
        pipe: The diffusion pipeline. Calling it is assumed to yield
            ``(outer_index, inner_index, images)`` tuples, per the unpacking
            below; ``images[0]`` is what gets displayed.

    Yields:
        ``(html, image)`` pairs: HTML with both progress bars, and the current
        image.
    """
    seed = randint(0, 10000) if random_seed else 24
    # Seed the generator on the pipeline's own device. The original code read a
    # module-level ``device`` that only exists when this file runs as a script.
    device = getattr(pipe, "device", "cuda" if torch.cuda.is_available() else "cpu")
    generator = torch.Generator(device).manual_seed(seed)

    outer_diffusion = tqdm(total=outer, desc="Outer Diffusion")
    inner_diffusion = tqdm(total=inner, desc="Inner Diffusion")
    cur_outer, cur_inner = 0, 0
    try:
        for i, j, im in pipe(prompt, num_inference_steps=outer, num_inner_steps=inner, generator=generator):
            if cur_inner != j:
                inner_diffusion.update()
                cur_inner = j
            if cur_outer != i:
                # Record the new outer index in every case, including the final
                # one (i == outer), so a repeated final yield cannot overshoot
                # the outer bar.
                cur_outer = i
                outer_diffusion.update()
                if i != outer:
                    # A new outer step starts: restart the inner bar, closing the
                    # previous one instead of leaking it.
                    cur_inner = 0
                    inner_diffusion.close()
                    inner_diffusion = tqdm(total=inner, desc="Inner Diffusion")
            yield _progress_html(outer_diffusion, inner_diffusion), im[0]
    finally:
        # Runs when the generator is exhausted or closed (e.g. the Gradio job is cancelled).
        inner_diffusion.close()
        outer_diffusion.close()

if __name__ == "__main__":
    # Script entry point: build the Nested Diffusion pipeline and serve it through Gradio.
    # ``device`` is intentionally a module-level name here because ``run`` may read it.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    scheduler = NestedScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear",
                                prediction_type='sample', clip_sample=False, set_alpha_to_one=False)
    fp16 = False
    # Only request half-precision weights on CUDA; float16 inference is generally
    # unsupported or very slow on CPU.
    if fp16 and device == "cuda":
        pipe = NestedStableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", revision="fp16",
                                                             torch_dtype=torch.float16, scheduler=scheduler)
    else:
        pipe = NestedStableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", scheduler=scheduler)
    pipe.to(device)
    # Bind the pipeline so Gradio only has to supply the four UI inputs.
    interface = partial(run, pipe=pipe)
    demo = gr.Interface(
        fn=interface,
        title="Nested Diffusion",
        description=(
            "<h3 style=\"text-align: center;\">Anytime text-to-image generation with Stable Diffusion v1.5</h3>\n"
            "<p style=\"text-align: center;\"><b>Help: </b>Type the desired prompt in the prompt box, and adjust "
            "the number of outer and inner steps to use. Using more steps takes more time, but should create a "
            "better image.<br>For more information on Nested Diffusion: "
            "<a href=\"https://github.com/noamelata/NestedDiffusion\">Github</a>, "
            "<a href=\"https://arxiv.org/abs/2305.19066\">arXiv</a></p>"
        ),
        # Input order must match run(prompt, outer, inner, random_seed, ...).
        inputs=[gr.Textbox(value="a photograph of a nest with a blue egg inside", label="Prompt"),
                gr.Slider(minimum=1, maximum=10, value=4, step=1, label="Outer Steps"),
                gr.Slider(minimum=5, maximum=50, value=10, step=1, label="Inner Steps"),
                gr.Checkbox(label="Random Seed")],
        # NOTE(review): ``shape=`` and ``.style()`` are Gradio 3.x APIs; confirm the pinned Gradio version.
        outputs=[gr.HTML(), gr.Image(shape=[512, 512], elem_id="output_image").style(width=512, height=512)],
        allow_flagging="never",
        thumbnail="figures/Nested_Egg.png"
    )
    # Queueing is needed so the generator-based ``run`` can stream partial results.
    demo.queue()
    demo.launch()