# File size: 2,438 Bytes
# 82ad0f2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
import os
from functools import partial
from random import randint

import gradio as gr
import torch
from tqdm import tqdm

from NestedPipeline import NestedStableDiffusionPipeline
from NestedScheduler import NestedScheduler


def run(prompt, outer, inner, random_seed, pipe, device=None):
    """Stream nested-diffusion progress and intermediate images to the UI.

    Drives ``pipe`` as a generator. For every item it yields, this function
    emits an HTML snippet with two text progress bars (outer and inner
    diffusion steps) together with the image to display.

    Args:
        prompt: Text prompt forwarded to ``pipe``.
        outer: Number of outer steps; passed as ``num_inference_steps`` and
            used as the total of the outer progress bar.
        inner: Number of inner steps; passed as ``num_inner_steps`` and used
            as the total of each inner progress bar.
        random_seed: If falsy, the fixed seed 24 is used; otherwise a seed is
            drawn with ``randint(0, 10000)``.
        pipe: Callable returning an iterable of ``(i, j, im)`` tuples;
            ``im[0]`` is what gets displayed. Presumably ``i``/``j`` are the
            outer/inner step indices, with ``i == outer`` marking the end of
            the outer loop -- confirm against NestedStableDiffusionPipeline.
        device: Device for the seeded ``torch.Generator``. Defaults to
            ``"cuda"`` when available, else ``"cpu"``.

    Yields:
        ``(html, image)`` pairs matching the Gradio interface outputs.
    """
    if device is None:
        # Resolve locally instead of relying on a module global that only
        # exists when this file is executed as a script.
        device = "cuda" if torch.cuda.is_available() else "cpu"

    seed = randint(0, 10000) if random_seed else 24
    generator = torch.Generator(device).manual_seed(seed)

    monospace_s = "<p style=\"font-family:'Lucida Console', monospace\">"
    monospace_e = "</p>"

    def as_html(bar):
        # Non-breaking spaces preserve tqdm's column alignment in HTML.
        return f"{monospace_s}{str(bar).replace(' ', '&nbsp;')}{monospace_e}"

    outer_diffusion = tqdm(total=outer, desc="Outer Diffusion")
    inner_diffusion = tqdm(total=inner, desc="Inner Diffusion")

    cur_outer, cur_inner = 0, 0
    try:
        steps = pipe(prompt, num_inference_steps=outer, num_inner_steps=inner,
                     generator=generator)
        for i, j, im in steps:
            if j != cur_inner:
                inner_diffusion.update()
                cur_inner = j
            if i != cur_outer:
                # Record the new outer index in every case so repeated yields
                # with the same ``i`` (including ``i == outer``) advance the
                # outer bar only once.
                cur_outer = i
                outer_diffusion.update()
                if i != outer:
                    # A new outer step begins: start a fresh inner bar.
                    cur_inner = 0
                    inner_diffusion.close()
                    inner_diffusion = tqdm(total=inner, desc="Inner Diffusion")
            yield f"{as_html(outer_diffusion)} \n {as_html(inner_diffusion)}", im[0]
    finally:
        # Runs on normal completion and when Gradio closes the generator
        # (e.g. the user cancels), so the console bars are always finalized.
        inner_diffusion.close()
        outer_diffusion.close()

if __name__ == "__main__":
    # Pick the device first so the weight dtype can follow it: PyTorch's CPU
    # backend lacks (or poorly supports) many float16 kernels, so half
    # precision is only used on CUDA.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    dtype = torch.float16 if device == "cuda" else torch.float32

    scheduler = NestedScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear",
                                prediction_type='sample', clip_sample=False, set_alpha_to_one=False)
    pipe = NestedStableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", revision="fp16",
                                                         torch_dtype=dtype, scheduler=scheduler)
    pipe.to(device)

    interface = partial(run, pipe=pipe)
    demo = gr.Interface(
        fn=interface,
        inputs=[gr.Textbox(value="a photograph of a nest with a blue egg inside"),
                gr.Slider(minimum=1, maximum=10, value=4, step=1),   # outer steps
                gr.Slider(minimum=5, maximum=50, value=10, step=1),  # inner steps
                "checkbox"],                                         # random seed
        outputs=[gr.HTML(), gr.Image(shape=[512, 512], elem_id="output_image").style(width=512, height=512)],
        allow_flagging="never"
    )
    demo.queue()
    # The bind address/port default to the original machine-specific values
    # but can be overridden via Gradio's standard environment variables.
    demo.launch(share=True,
                server_name=os.environ.get("GRADIO_SERVER_NAME", "132.68.39.164"),
                server_port=int(os.environ.get("GRADIO_SERVER_PORT", "7861")))