Spaces:
Runtime error
Runtime error
from functools import partial | |
from random import randint | |
import gradio as gr | |
import torch | |
from tqdm import tqdm | |
from NestedPipeline import NestedStableDiffusionPipeline | |
from NestedScheduler import NestedScheduler | |
# Wrapper used to render tqdm bars as monospace HTML paragraphs.
_MONOSPACE_OPEN = "<p style=\"font-family:'Lucida Console', monospace\">"
_MONOSPACE_CLOSE = "</p>"


def _progress_html(outer_bar, inner_bar):
    """Render the outer and inner tqdm bars as two monospace HTML paragraphs."""
    from html import escape

    def _fmt(bar):
        # Escape the bar text (it contains '<', e.g. "[00:01<00:05]") and keep
        # its column alignment: plain spaces would collapse in HTML.
        return _MONOSPACE_OPEN + escape(str(bar)).replace(" ", "&nbsp;") + _MONOSPACE_CLOSE

    return f"{_fmt(outer_bar)} \n {_fmt(inner_bar)}"


def run(prompt, outer, inner, random_seed, pipe, device=None):
    """Stream nested-diffusion progress and intermediate images for Gradio.

    Args:
        prompt: Text prompt forwarded to ``pipe``.
        outer: Number of outer steps, passed as ``num_inference_steps``.
        inner: Number of inner steps, passed as ``num_inner_steps``.
        random_seed: If truthy, draw the seed with ``randint(0, 10000)``;
            otherwise use the fixed seed 24 so results are reproducible.
        pipe: Callable invoked as ``pipe(prompt, num_inference_steps=...,
            num_inner_steps=..., generator=...)``. It must return an iterable
            of ``(i, j, im)`` tuples.
        device: Device for the ``torch.Generator``. When None, uses ``"cuda"``
            if available, else ``"cpu"``. This is the same rule the
            ``__main__`` block uses for the pipeline.

    Yields:
        ``(html, image)`` pairs. ``html`` shows the outer and inner tqdm bars;
        ``image`` is ``im[0]`` from the current pipeline step.
    """
    # Slider values may arrive as floats; tqdm totals and step counts want ints.
    outer, inner = int(outer), int(inner)
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    seed = randint(0, 10000) if random_seed else 24
    generator = torch.Generator(device).manual_seed(seed)

    outer_bar = tqdm(total=outer, desc="Outer Diffusion")
    inner_bar = tqdm(total=inner, desc="Inner Diffusion")
    cur_outer, cur_inner = 0, 0
    try:
        for i, j, im in pipe(prompt, num_inference_steps=outer, num_inner_steps=inner, generator=generator):
            if j != cur_inner:
                inner_bar.update()
                cur_inner = j
            if i != cur_outer:
                # Record the new outer index in every case so repeated yields
                # with the same index (including i == outer) count only once.
                cur_outer = i
                outer_bar.update()
                if i != outer:
                    # A new outer step starts a fresh inner progress bar.
                    cur_inner = 0
                    inner_bar.close()
                    inner_bar = tqdm(total=inner, desc="Inner Diffusion")
            yield _progress_html(outer_bar, inner_bar), im[0]
    finally:
        # Runs on normal completion and when Gradio closes the generator early.
        inner_bar.close()
        outer_bar.close()
if __name__ == "__main__":
    import os

    # Module-level on purpose: `run` may read this global as its device.
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Half precision only on GPU. Float16 inference on CPU is unsupported for
    # several ops in common PyTorch builds, so fall back to full precision there.
    if device == "cuda":
        precision_kwargs = {"revision": "fp16", "torch_dtype": torch.float16}
    else:
        precision_kwargs = {"torch_dtype": torch.float32}

    scheduler = NestedScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear",
                                prediction_type='sample', clip_sample=False, set_alpha_to_one=False)
    pipe = NestedStableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5",
                                                         scheduler=scheduler, **precision_kwargs)
    pipe.to(device)

    interface = partial(run, pipe=pipe)
    demo = gr.Interface(
        fn=interface,
        inputs=[gr.Textbox(value="a photograph of a nest with a blue egg inside"),
                gr.Slider(minimum=1, maximum=10, value=4, step=1),
                gr.Slider(minimum=5, maximum=50, value=10, step=1),
                "checkbox"],
        outputs=[gr.HTML(), gr.Image(shape=[512, 512], elem_id="output_image").style(width=512, height=512)],
        allow_flagging="never"
    )
    demo.queue()
    # Bind address and port come from the environment rather than a hard-coded
    # host IP, which fails on any machine lacking that interface. If
    # GRADIO_SERVER_NAME is unset, server_name is None and Gradio uses its own
    # default address.
    demo.launch(
        share=True,
        server_name=os.environ.get("GRADIO_SERVER_NAME"),
        server_port=int(os.environ.get("GRADIO_SERVER_PORT", "7861")),
    )