svjack committed on
Commit 45e2d99 · verified · 1 Parent(s): 69414d8

Create depth_app.py

Files changed (1)
  1. depth_app.py +218 -0
depth_app.py ADDED
@@ -0,0 +1,218 @@
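+ # Gradio demo for Kwai-Kolors: image-to-image generation guided by a Midas depth map
+ # (ControlNet) and an IP-Adapter-Plus reference image.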
+ import random
+ import torch
+ import cv2
+ import gradio as gr
+ import numpy as np
+ from huggingface_hub import snapshot_download
+ from transformers import CLIPVisionModelWithProjection, CLIPImageProcessor
+ from diffusers.utils import load_image
+ from kolors.pipelines.pipeline_controlnet_xl_kolors_img2img import StableDiffusionXLControlNetImg2ImgPipeline
+ from kolors.models.modeling_chatglm import ChatGLMModel
+ from kolors.models.tokenization_chatglm import ChatGLMTokenizer
+ from kolors.models.controlnet import ControlNetModel
+ from diffusers import AutoencoderKL
+ from kolors.models.unet_2d_condition import UNet2DConditionModel
+ from diffusers import EulerDiscreteScheduler
+ from PIL import Image
+ from annotator.midas import MidasDetector
+ from annotator.util import resize_image, HWC3
+
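+ # Download the Kolors base model, IP-Adapter-Plus, and ControlNet-Depth checkpoints from
+ # the Hugging Face Hub, then load every sub-module in fp16 on the GPU.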
+ device = "cuda"
+ ckpt_dir = snapshot_download(repo_id="Kwai-Kolors/Kolors")
+ ckpt_dir_ipa = snapshot_download(repo_id="Kwai-Kolors/Kolors-IP-Adapter-Plus")
+ ckpt_dir_depth = snapshot_download(repo_id="Kwai-Kolors/Kolors-ControlNet-Depth")
+
+ text_encoder = ChatGLMModel.from_pretrained(f'{ckpt_dir}/text_encoder', torch_dtype=torch.float16).half().to(device)
+ tokenizer = ChatGLMTokenizer.from_pretrained(f'{ckpt_dir}/text_encoder')
+ vae = AutoencoderKL.from_pretrained(f"{ckpt_dir}/vae", revision=None).half().to(device)
+ scheduler = EulerDiscreteScheduler.from_pretrained(f"{ckpt_dir}/scheduler")
+ unet = UNet2DConditionModel.from_pretrained(f"{ckpt_dir}/unet", revision=None).half().to(device)
+
+ controlnet_depth = ControlNetModel.from_pretrained(f"{ckpt_dir_depth}", revision=None).half().to(device)
+
+ image_encoder = CLIPVisionModelWithProjection.from_pretrained(f'{ckpt_dir_ipa}/image_encoder', ignore_mismatched_sizes=True).to(dtype=torch.float16, device=device)
+ ip_img_size = 336
+ clip_image_processor = CLIPImageProcessor(size=ip_img_size, crop_size=ip_img_size)
+
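+ # Assemble the ControlNet img2img pipeline from the components loaded above, attach the
+ # IP-Adapter-Plus weights, and create the Midas detector used for depth conditioning.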
+ pipe_depth = StableDiffusionXLControlNetImg2ImgPipeline(
+     vae=vae,
+     controlnet=controlnet_depth,
+     text_encoder=text_encoder,
+     tokenizer=tokenizer,
+     unet=unet,
+     scheduler=scheduler,
+     image_encoder=image_encoder,
+     feature_extractor=clip_image_processor,
+     force_zeros_for_empty_prompt=False
+ )
+
+ pipe_depth.load_ip_adapter(f'{ckpt_dir_ipa}', subfolder="", weight_name=["ip_adapter_plus_general.bin"])
+
+ model_midas = MidasDetector()
+
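+ # Turn an RGB array into a Midas depth map and resize it back to the original
+ # resolution; the result is used as the ControlNet conditioning image.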
+ def process_depth_condition_midas(img, res=1024):
+     h, w, _ = img.shape
+     img = resize_image(HWC3(img), res)
+     result = HWC3(model_midas(img))
+     result = cv2.resize(result, (w, h))
+     return Image.fromarray(result)
+
+ MAX_SEED = np.iinfo(np.int32).max
+ MAX_IMAGE_SIZE = 1024
+
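+ # Run one generation: optionally randomize the seed, derive the depth condition from the
+ # input image, and call the pipeline with the IP-Adapter reference image. The default
+ # negative prompt is Chinese (roughly: "nsfw, face shadows, low resolution, bad anatomy,
+ # bad hands, missing fingers, worst quality, low quality, jpeg artifacts, blurry, bad,
+ # black face, neon lights").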
+ def infer_depth(prompt,
+                 image=None,
+                 ipa_img=None,
+                 negative_prompt="nsfw,脸部阴影,低分辨率,糟糕的解剖结构、糟糕的手,缺失手指、质量最差、低质量、jpeg伪影、模糊、糟糕,黑脸,霓虹灯",
+                 seed=66,
+                 randomize_seed=False,
+                 guidance_scale=5.0,
+                 num_inference_steps=50,
+                 controlnet_conditioning_scale=0.5,
+                 control_guidance_end=0.9,
+                 strength=1.0,
+                 ip_scale=0.5,
+                 ):
+     if randomize_seed:
+         seed = random.randint(0, MAX_SEED)
+     generator = torch.Generator().manual_seed(seed)
+     init_image = resize_image(image, MAX_IMAGE_SIZE)
+     pipe = pipe_depth.to("cuda")
+     pipe.set_ip_adapter_scale([ip_scale])
+     condi_img = process_depth_condition_midas(np.array(init_image), MAX_IMAGE_SIZE)
+     image = pipe(
+         prompt=prompt,
+         image=init_image,
+         controlnet_conditioning_scale=controlnet_conditioning_scale,
+         control_guidance_end=control_guidance_end,
+         ip_adapter_image=[ipa_img],
+         strength=strength,
+         control_image=condi_img,
+         negative_prompt=negative_prompt,
+         num_inference_steps=num_inference_steps,
+         guidance_scale=guidance_scale,
+         num_images_per_prompt=1,
+         generator=generator,
+     ).images[0]
+     return [condi_img, image], seed
+
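+ # One bundled example: a Chinese prompt (roughly "a beautiful girl, best quality,
+ # ultra-detailed, 8K") plus a content image and an IP-Adapter reference image.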
+ depth_examples = [
+     ["一个漂亮的女孩,最好的质量,超细节,8K画质",
+      "image/1.png", "image/woman_1.png"],
+ ]
+
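+ # Page layout: centered left/right columns and a blue action button.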
+ css = """
+ #col-left {
+     margin: 0 auto;
+     max-width: 600px;
+ }
+ #col-right {
+     margin: 0 auto;
+     max-width: 750px;
+ }
+ #button {
+     color: blue;
+ }
+ """
+
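+ # Read a text/markdown snippet from disk (used below for the page title in assets/title.md).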
+ def load_description(fp):
+     with open(fp, 'r', encoding='utf-8') as f:
+         content = f.read()
+     return content
+
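+ # Build the Gradio UI: prompt, input/reference images, and advanced sliders on the left;
+ # the depth map and generated image are shown in a gallery on the right.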
+ with gr.Blocks(css=css) as DepthApp:
+     gr.HTML(load_description("assets/title.md"))
+     with gr.Row():
+         with gr.Column(elem_id="col-left"):
+             with gr.Row():
+                 prompt = gr.Textbox(
+                     label="Prompt",
+                     placeholder="Enter your prompt",
+                     lines=2
+                 )
+             with gr.Row():
+                 image = gr.Image(label="Image", type="pil")
+                 ipa_image = gr.Image(label="IP-Adapter-Image", type="pil")
+             with gr.Accordion("Advanced Settings", open=False):
+                 negative_prompt = gr.Textbox(
+                     label="Negative prompt",
+                     placeholder="Enter a negative prompt",
+                     visible=True,
+                     value="nsfw,脸部阴影,低分辨率,糟糕的解剖结构、糟糕的手,缺失手指、质量最差、低质量、jpeg伪影、模糊、糟糕,黑脸,霓虹灯"
+                 )
+                 seed = gr.Slider(
+                     label="Seed",
+                     minimum=0,
+                     maximum=MAX_SEED,
+                     step=1,
+                     value=0,
+                 )
+                 randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
+                 with gr.Row():
+                     guidance_scale = gr.Slider(
+                         label="Guidance scale",
+                         minimum=0.0,
+                         maximum=10.0,
+                         step=0.1,
+                         value=5.0,
+                     )
+                     num_inference_steps = gr.Slider(
+                         label="Number of inference steps",
+                         minimum=10,
+                         maximum=50,
+                         step=1,
+                         value=30,
+                     )
+                 with gr.Row():
+                     controlnet_conditioning_scale = gr.Slider(
+                         label="Controlnet Conditioning Scale",
+                         minimum=0.0,
+                         maximum=1.0,
+                         step=0.1,
+                         value=0.5,
+                     )
+                     control_guidance_end = gr.Slider(
+                         label="Control Guidance End",
+                         minimum=0.0,
+                         maximum=1.0,
+                         step=0.1,
+                         value=0.9,
+                     )
+                 with gr.Row():
+                     strength = gr.Slider(
+                         label="Strength",
+                         minimum=0.0,
+                         maximum=1.0,
+                         step=0.1,
+                         value=1.0,
+                     )
+                     ip_scale = gr.Slider(
+                         label="IP_Scale",
+                         minimum=0.0,
+                         maximum=1.0,
+                         step=0.1,
+                         value=0.5,
+                     )
+             with gr.Row():
+                 depth_button = gr.Button("Depth", elem_id="button")
+
+         with gr.Column(elem_id="col-right"):
+             result = gr.Gallery(label="Result", show_label=False, columns=2)
+             seed_used = gr.Number(label="Seed Used")
+
+     with gr.Row():
+         gr.Examples(
+             fn=infer_depth,
+             examples=depth_examples,
+             inputs=[prompt, image, ipa_image],
+             outputs=[result, seed_used],
+             label="Depth"
+         )
+
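+     # Clicking "Depth" runs infer_depth with every control above as input and shows
+     # the depth map, the generated image, and the seed that was used.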
+     depth_button.click(
+         fn=infer_depth,
+         inputs=[prompt, image, ipa_image, negative_prompt, seed, randomize_seed, guidance_scale, num_inference_steps, controlnet_conditioning_scale, control_guidance_end, strength, ip_scale],
+         outputs=[result, seed_used]
+     )
+
+ DepthApp.queue().launch(debug=True, share=True)