File size: 6,983 Bytes
d4ab1b2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
275d668
43a8a9b
d4ab1b2
e3761d9
aa9bed2
b5259ef
 
05afdd5
 
e3761d9
d4ab1b2
511d4d6
d4ab1b2
 
 
 
 
 
aa9bed2
d4ab1b2
 
 
b6d1bbe
d4ab1b2
 
 
 
 
 
 
53ffbf2
d4ab1b2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
41df45d
d4ab1b2
 
 
 
 
 
 
 
 
 
b6d1bbe
d4ab1b2
 
 
 
 
 
41df45d
d4ab1b2
 
 
 
 
 
 
 
 
 
 
 
4f97608
 
 
 
d4ab1b2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4f97608
d4ab1b2
 
 
 
4be6f60
d4ab1b2
 
41df45d
 
 
 
4f97608
41df45d
d4ab1b2
 
 
 
 
 
 
 
 
43a8a9b
d4ab1b2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
41df45d
d4ab1b2
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
#!/usr/bin/env python
from __future__ import annotations

import os
import random
import time

import gradio as gr
import numpy as np
import PIL.Image

from huggingface_hub import snapshot_download
from diffusers import DiffusionPipeline

from optimum.intel.openvino.modeling_diffusion import OVModelVaeDecoder, OVBaseModel, OVStableDiffusionPipeline

import os
from tqdm import tqdm
import gradio_user_history as gr_user_history

from concurrent.futures import ThreadPoolExecutor
import uuid

DESCRIPTION = '''# Latent Consistency Model OpenVINO CPU TAESD
Based on [Latent Consistency Model OpenVINO CPU](https://huggingface.co/spaces/deinferno/Latent_Consistency_Model_OpenVino_CPU) HF space

Converted from [SoteMix](https://huggingface.co/Disty0/SoteMix) to [LCM_SoteMix](https://huggingface.co/Disty0/LCM_SoteMix) and then to OpenVINO

This model is for Anime art style.

Slower but higher quality version with Full VAE: [LCM_SoteMix_OpenVINO_CPU_Space](https://huggingface.co/spaces/Disty0/LCM_SoteMix_OpenVINO_CPU_Space)

[LCM Project page](https://latent-consistency-models.github.io)

<p>Running on a Dual Core CPU with OpenVINO Acceleration</p>
'''

# Upper bound (inclusive) for seeds: used by the seed slider and randomize_seed_fn.
MAX_SEED = np.iinfo(np.int32).max
# Passed to gr.Examples(cache_examples=...); enabled only when CACHE_EXAMPLES=1.
CACHE_EXAMPLES = os.getenv("CACHE_EXAMPLES") == "1"

# Hugging Face Hub repository loaded into the OpenVINO pipeline.
model_id = "Disty0/LCM_SoteMix"
# Batch dimension passed to pipe.reshape().
# NOTE(review): -1 presumably leaves the batch dimension dynamic — confirm in optimum-intel docs.
batch_size = -1
# Generation settings; each can be overridden through the named environment variable.
width = int(os.getenv("IMAGE_WIDTH", "512"))
height = int(os.getenv("IMAGE_HEIGHT", "512"))
num_images = int(os.getenv("NUM_IMAGES", "1"))
guidance_scale = float(os.getenv("GUIDANCE_SCALE", "1.0"))

class CustomOVModelVaeDecoder(OVModelVaeDecoder):
    """VAE decoder part that can be built from a model outside the pipeline's own directory.

    This module uses it to replace ``pipe.vae_decoder`` with the TAESD decoder
    downloaded from the ``deinferno/taesd-openvino`` Hub repository.
    """

    def __init__(
        self,
        # NOTE: `openvino` is not imported in this module; this only works because
        # `from __future__ import annotations` keeps annotations unevaluated.
        model: openvino.runtime.Model,
        parent_model: OVBaseModel,
        ov_config: dict[str, str] | None = None,
        model_dir: str | None = None,
    ) -> None:
        """Wrap *model* as the pipeline's "vae_decoder" part.

        Args:
            model: OpenVINO model to use as the decoder.
            parent_model: Pipeline that owns this part.
            ov_config: Optional OpenVINO config forwarded to the base class.
            model_dir: Directory the model was loaded from, forwarded to the base class.
        """
        # super(OVModelVaeDecoder, self) deliberately skips OVModelVaeDecoder.__init__
        # and calls the next __init__ in the MRO, passing the part name "vae_decoder"
        # and model_dir explicitly.
        # NOTE(review): presumably done because OVModelVaeDecoder.__init__ in the pinned
        # optimum-intel version does not forward model_dir — re-verify on upgrade.
        super(OVModelVaeDecoder, self).__init__(model, parent_model, ov_config, "vae_decoder", model_dir)

# Load the LCM pipeline with compile=False so the VAE decoder can be swapped and the
# models reshaped before the single pipe.compile() call at the end of this section.
# NOTE(review): "CACHE_DIR": "" presumably disables OpenVINO's on-disk compiled-model
# cache — confirm against the OpenVINO / optimum-intel documentation.
pipe = OVStableDiffusionPipeline.from_pretrained(model_id, compile = False, ov_config = {"CACHE_DIR":""})

# Inject TAESD: download the OpenVINO export of the TAESD decoder from the Hub and use
# it in place of the pipeline's own VAE decoder.

taesd_dir = snapshot_download(repo_id="deinferno/taesd-openvino")
pipe.vae_decoder = CustomOVModelVaeDecoder(model = OVBaseModel.load_model(f"{taesd_dir}/vae_decoder/openvino_model.xml"), parent_model = pipe, model_dir = taesd_dir)

# Fix input shapes to the configured resolution / images per prompt, then compile.
# Order matters: reshape runs after the decoder swap (presumably so the TAESD decoder
# is reshaped too) and before compile().
pipe.reshape(batch_size=batch_size, height=height, width=width, num_images_per_prompt=num_images)
pipe.compile()

def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
    """Pick the seed for a generation run.

    Returns *seed* untouched unless *randomize_seed* is truthy, in which case a
    fresh integer is drawn uniformly from ``[0, MAX_SEED]`` (inclusive).
    """
    if not randomize_seed:
        return seed
    return random.randint(0, MAX_SEED)

def save_image(img, profile: gr.OAuthProfile | None, metadata: dict):
    """Persist one generated image and record it in the user's history.

    The image is written as ``<uuid4>.png`` in the current working directory,
    then handed to ``gr_user_history`` labelled with ``metadata["prompt"]``.
    Returns the file name that was written.
    """
    filename = f"{uuid.uuid4()}.png"
    img.save(filename)
    gr_user_history.save_image(label=metadata["prompt"], image=img, profile=profile, metadata=metadata)
    return filename

def save_images(image_array, profile: gr.OAuthProfile | None, metadata: dict):
    """Save every image in *image_array* via ``save_image`` on a thread pool.

    All images share the same *profile* and *metadata*. Returns the saved file
    names in the same order as *image_array*.
    """
    count = len(image_array)
    profiles = [profile] * count
    metadatas = [metadata] * count
    with ThreadPoolExecutor() as pool:
        saved_paths = list(pool.map(save_image, image_array, profiles, metadatas))
    return saved_paths

def generate(
    prompt: str,
    negative_prompt: str,
    seed: int = 0,
    num_inference_steps: int = 4,
    randomize_seed: bool = False,
    progress = gr.Progress(track_tqdm=True),
    profile: gr.OAuthProfile | None = None,
) -> tuple[list[str], int]:
    """Generate images for *prompt* with the module-level OpenVINO pipeline.

    Resolution, images per prompt and guidance scale come from the module-level
    ``width``, ``height``, ``num_images`` and ``guidance_scale`` settings.

    Args:
        prompt: Text prompt.
        negative_prompt: Negative prompt passed to the pipeline.
        seed: Seed for NumPy's global RNG; replaced by a random one when
            *randomize_seed* is truthy.
        num_inference_steps: Number of denoising steps.
        randomize_seed: Draw a fresh seed via ``randomize_seed_fn``.
        progress: Gradio progress tracker (``track_tqdm=True``); not used in the body.
        profile: OAuth profile of the caller or ``None``; forwarded to ``save_images``.

    Returns:
        ``(paths, seed)``: file names of the saved PNGs (for the gallery) and the
        seed actually used (so the UI seed slider can display it).
    """
    # UI sliders may deliver numbers as floats; np.random.seed and the step
    # count both need ints.
    seed = int(randomize_seed_fn(seed, randomize_seed))
    num_inference_steps = int(num_inference_steps)

    # NOTE(review): presumably the pipeline samples its initial latents from NumPy's
    # global RNG when no generator is passed, which is what makes this seed take
    # effect — confirm against the installed optimum-intel version.
    np.random.seed(seed)

    start_time = time.perf_counter()
    images = pipe(
        prompt=prompt,
        negative_prompt=negative_prompt,
        width=width,
        height=height,
        guidance_scale=guidance_scale,
        num_inference_steps=num_inference_steps,
        num_images_per_prompt=num_images,
        output_type="pil",
    ).images
    paths = save_images(
        images,
        profile,
        metadata={
            "prompt": prompt,
            "seed": seed,
            "width": width,
            "height": height,
            "guidance_scale": guidance_scale,
            "num_inference_steps": num_inference_steps,
        },
    )
    # Elapsed seconds for generation plus saving, printed to the server log.
    print(time.perf_counter() - start_time)
    return paths, seed

examples = [
    "masterpiece, best quality, highres, 1girl, solo,",
    "masterpiece, best quality, highres, 1girl, solo, pov, scenery, wind, petals, rim lighting, shrine, lens flare, light scatter, depth of field, lens refraction,",
    "masterpiece, best quality, highres, 1girl, solo, scenery, wind, petals, rim lighting, shrine, lens flare, light scatter, depth of field, lens refraction, dark red hair, long hair, blue eyes, straight hair, cat ears, medium breasts, mature female, white sweater,",
    "masterpiece, best quality, highres, 1girl, solo, supernova, abstract, abstract background, bloom, swirling lights, light particles, fire, galaxy, floating, romanticized, blush, looking at viewer, dark red hair, long hair, blue eyes, straight hair, cat ears, medium breasts, mature female, white sweater,",
]

# Default negative prompt. It is shared by the "Negative Prompt" textbox and by
# _generate_example, so cached examples use the same value as the UI default.
DEFAULT_NEGATIVE_PROMPT = "worst quality, low quality, lowres, monochrome, realistic,"


def _generate_example(example_prompt: str):
    """Adapt ``generate`` to the signature ``gr.Examples`` expects.

    ``gr.Examples`` below feeds only the prompt box and writes only to the
    gallery. ``generate`` also requires a negative prompt and returns
    ``(paths, seed)``. Passing ``generate`` directly raised ``TypeError`` as soon
    as examples were cached (CACHE_EXAMPLES=1).

    This adapter uses the UI defaults with a fixed seed, so cached outputs are
    reproducible, and returns just the gallery paths.
    """
    paths, _ = generate(
        example_prompt,
        DEFAULT_NEGATIVE_PROMPT,
        seed=0,
        num_inference_steps=4,
        randomize_seed=False,
    )
    return paths


with gr.Blocks(css="style.css") as demo:
    gr.Markdown(DESCRIPTION)
    gr.DuplicateButton(
        value="Duplicate Space for private use",
        elem_id="duplicate-button",
        visible=os.getenv("SHOW_DUPLICATE_BUTTON") == "1",
    )
    with gr.Group():
        with gr.Row():
            prompt = gr.Text(
                label="Prompt",
                show_label=False,
                max_lines=1,
                value="masterpiece, best quality, highres, 1girl, solo,",
                container=False,
            )
            run_button = gr.Button("Run", scale=0)
        result = gr.Gallery(
            label="Generated images", show_label=False, elem_id="gallery"
        )
    with gr.Accordion("Advanced options", open=False):
        with gr.Row():
            negative_prompt = gr.Text(
                label="Negative Prompt",
                max_lines=1,
                value=DEFAULT_NEGATIVE_PROMPT,
            )
        seed = gr.Slider(
            label="Seed",
            minimum=0,
            maximum=MAX_SEED,
            step=1,
            value=0,
            randomize=True,
        )
        randomize_seed = gr.Checkbox(label="Randomize seed across runs", value=True)
        with gr.Row():
            num_inference_steps = gr.Slider(
                label="Number of inference steps for base",
                minimum=1,
                maximum=8,
                step=1,
                value=4,
            )

    with gr.Accordion("Past generations", open=False):
        gr_user_history.render()

    # Inputs/outputs here must match _generate_example's signature and return value.
    gr.Examples(
        examples=examples,
        inputs=prompt,
        outputs=result,
        fn=_generate_example,
        cache_examples=CACHE_EXAMPLES,
    )

    # Pressing Enter in the prompt box or clicking "Run" generates images; the seed
    # actually used is written back into the seed slider.
    gr.on(
        triggers=[
            prompt.submit,
            run_button.click,
        ],
        fn=generate,
        inputs=[
            prompt,
            negative_prompt,
            seed,
            num_inference_steps,
            randomize_seed,
        ],
        outputs=[result, seed],
        api_name="run",
    )

if __name__ == "__main__":
    # Enable Gradio's request queue (with api_open=False) and start the server;
    # Blocks.queue() returns the Blocks instance, so the calls chain.
    demo.queue(api_open=False).launch()