shgao commited on
Commit
f14200d
Β·
1 Parent(s): ecdaa2c

update new demo

Browse files
app.py CHANGED
@@ -1,9 +1,12 @@
1
  import gradio as gr
2
-
3
 
4
  from sam2edit import create_demo as create_demo_edit_anything
5
- # from sam2image import create_demo as create_demo_generate_anything
6
-
 
 
 
7
 
8
  DESCRIPTION = f'''# [Edit Anything](https://github.com/sail-sg/EditAnything)
9
  **Edit anything and keep the layout by segmenting anything in the image.**
@@ -12,13 +15,45 @@ SHARED_UI_WARNING = f'''### [NOTE] Inference may be slow in this shared UI.
12
  You can duplicate and use it with a paid private GPU.
13
  <a class="duplicate-button" style="display:inline-block" target="_blank" href="https://huggingface.co/spaces/jyseo/3DFuse?duplicate=true"><img style="margin-top:0;margin-bottom:0" src="https://huggingface.co/datasets/huggingface/badges/raw/main/duplicate-this-space-xl-dark.svg" alt="Duplicate Space"></a>
14
  '''
 
 
 
 
 
 
 
 
15
  with gr.Blocks() as demo:
16
  gr.Markdown(DESCRIPTION)
17
- gr.Markdown(SHARED_UI_WARNING)
18
  with gr.Tabs():
19
- with gr.TabItem('Edit Anything'):
20
- create_demo_edit_anything()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21
  # with gr.TabItem('Generate Anything'):
22
  # create_demo_generate_anything()
 
 
23
 
24
  demo.queue(api_open=False).launch()
 
1
  import gradio as gr
2
+ import os
3
 
4
  from sam2edit import create_demo as create_demo_edit_anything
5
+ from sam2image import create_demo as create_demo_generate_anything
6
+ from sam2edit_beauty import create_demo as create_demo_beauty
7
+ from sam2edit_handsome import create_demo as create_demo_handsome
8
+ from sam2edit_lora import EditAnythingLoraModel, init_sam_model, init_blip_processor, init_blip_model
9
+ from huggingface_hub import hf_hub_download, snapshot_download
10
 
11
  DESCRIPTION = f'''# [Edit Anything](https://github.com/sail-sg/EditAnything)
12
  **Edit anything and keep the layout by segmenting anything in the image.**
 
15
  You can duplicate and use it with a paid private GPU.
16
  <a class="duplicate-button" style="display:inline-block" target="_blank" href="https://huggingface.co/spaces/jyseo/3DFuse?duplicate=true"><img style="margin-top:0;margin-bottom:0" src="https://huggingface.co/datasets/huggingface/badges/raw/main/duplicate-this-space-xl-dark.svg" alt="Duplicate Space"></a>
17
  '''
18
+
19
+ #
20
+ sam_generator = init_sam_model()
21
+ blip_processor = init_blip_processor()
22
+ blip_model = init_blip_model()
23
+
24
+ sd_models_path = snapshot_download("shgao/sdmodels")
25
+
26
  with gr.Blocks() as demo:
27
  gr.Markdown(DESCRIPTION)
 
28
  with gr.Tabs():
29
+ with gr.TabItem('πŸ–ŒEdit Anything'):
30
+ model = EditAnythingLoraModel(base_model_path="stabilityai/stable-diffusion-2-inpainting",
31
+ controlmodel_name='LAION Pretrained(v0-4)-SD21',
32
+ lora_model_path=None, use_blip=True, extra_inpaint=False,
33
+ sam_generator=sam_generator,
34
+ blip_processor=blip_processor,
35
+ blip_model=blip_model)
36
+ create_demo_edit_anything(model.process)
37
+ with gr.TabItem(' πŸ‘©β€πŸ¦°Beauty Edit/Generation'):
38
+ lora_model_path = hf_hub_download(
39
+ "mlida/Cute_girl_mix4", "cuteGirlMix4_v10.safetensors")
40
+ model = EditAnythingLoraModel(base_model_path=os.path.join(sd_models_path, "chilloutmix_NiPrunedFp32Fix"),
41
+ lora_model_path=lora_model_path, use_blip=True, extra_inpaint=True,
42
+ sam_generator=sam_generator,
43
+ blip_processor=blip_processor,
44
+ blip_model=blip_model
45
+ )
46
+ create_demo_beauty(model.process)
47
+ with gr.TabItem(' πŸ‘¨β€πŸŒΎHandsome Edit/Generation'):
48
+ model = EditAnythingLoraModel(base_model_path=os.path.join(sd_models_path, "Realistic_Vision_V2.0"),
49
+ lora_model_path=None, use_blip=True, extra_inpaint=True,
50
+ sam_generator=sam_generator,
51
+ blip_processor=blip_processor,
52
+ blip_model=blip_model)
53
+ create_demo_handsome(model.process)
54
  # with gr.TabItem('Generate Anything'):
55
  # create_demo_generate_anything()
56
+ with gr.Tabs():
57
+ gr.Markdown(SHARED_UI_WARNING)
58
 
59
  demo.queue(api_open=False).launch()
sam2edit.py CHANGED
@@ -1,321 +1,85 @@
1
  # Edit Anything trained with Stable Diffusion + ControlNet + SAM + BLIP2
2
- from torchvision.utils import save_image
3
- from PIL import Image
4
- from pytorch_lightning import seed_everything
5
- import subprocess
6
- from collections import OrderedDict
7
-
8
- import cv2
9
- import einops
10
  import gradio as gr
11
- import numpy as np
12
- import torch
13
- import random
14
- import os
15
- import requests
16
- from io import BytesIO
17
- from annotator.util import resize_image, HWC3
18
-
19
- def create_demo():
20
- device = "cuda" if torch.cuda.is_available() else "cpu"
21
- use_blip = True
22
- use_gradio = True
23
-
24
- # Diffusion init using diffusers.
25
-
26
- # diffusers==0.14.0 required.
27
- from diffusers import StableDiffusionInpaintPipeline
28
- from diffusers import ControlNetModel, UniPCMultistepScheduler
29
- from utils.stable_diffusion_controlnet_inpaint import StableDiffusionControlNetInpaintPipeline
30
- from diffusers.utils import load_image
31
-
32
- base_model_path = "stabilityai/stable-diffusion-2-inpainting"
33
- config_dict = OrderedDict([('SAM Pretrained(v0-1): Good Natural Sense', 'shgao/edit-anything-v0-1-1'),
34
- ('LAION Pretrained(v0-3): Good Face', 'shgao/edit-anything-v0-3'),
35
- ('SD Inpainting: Not keep position', 'stabilityai/stable-diffusion-2-inpainting')
36
- ])
37
- def obtain_generation_model(controlnet_path):
38
- if controlnet_path=='stabilityai/stable-diffusion-2-inpainting':
39
- pipe = StableDiffusionInpaintPipeline.from_pretrained(
40
- "stabilityai/stable-diffusion-2-inpainting",
41
- torch_dtype=torch.float16,
42
- )
43
- else:
44
- controlnet = ControlNetModel.from_pretrained(controlnet_path, torch_dtype=torch.float16)
45
- pipe = StableDiffusionControlNetInpaintPipeline.from_pretrained(
46
- base_model_path, controlnet=controlnet, torch_dtype=torch.float16
47
- )
48
- # speed up diffusion process with faster scheduler and memory optimization
49
- pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
50
- # remove following line if xformers is not installed
51
- pipe.enable_xformers_memory_efficient_attention()
52
-
53
- pipe.enable_model_cpu_offload() # disable for now because of unknow bug in accelerate
54
- # pipe.to(device)
55
- return pipe
56
- global default_controlnet_path
57
- global pipe
58
- default_controlnet_path = config_dict['LAION Pretrained(v0-3): Good Face']
59
- pipe = obtain_generation_model(default_controlnet_path)
60
-
61
- # Segment-Anything init.
62
- # pip install git+https://github.com/facebookresearch/segment-anything.git
63
-
64
- try:
65
- from segment_anything import sam_model_registry, SamAutomaticMaskGenerator
66
- except ImportError:
67
- print('segment_anything not installed')
68
- result = subprocess.run(['pip', 'install', 'git+https://github.com/facebookresearch/segment-anything.git'], check=True)
69
- print(f'Install segment_anything {result}')
70
- from segment_anything import sam_model_registry, SamAutomaticMaskGenerator
71
- if not os.path.exists('./models/sam_vit_h_4b8939.pth'):
72
- result = subprocess.run(['wget', 'https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth', '-P', 'models'], check=True)
73
- print(f'Download sam_vit_h_4b8939.pth {result}')
74
- sam_checkpoint = "models/sam_vit_h_4b8939.pth"
75
- model_type = "default"
76
- sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
77
- sam.to(device=device)
78
- mask_generator = SamAutomaticMaskGenerator(sam)
79
-
80
-
81
- # BLIP2 init.
82
- if use_blip:
83
- # need the latest transformers
84
- # pip install git+https://github.com/huggingface/transformers.git
85
- from transformers import AutoProcessor, Blip2ForConditionalGeneration
86
-
87
- processor = AutoProcessor.from_pretrained("Salesforce/blip2-opt-2.7b")
88
- blip_model = Blip2ForConditionalGeneration.from_pretrained(
89
- "Salesforce/blip2-opt-2.7b", torch_dtype=torch.float16, device_map="auto")
90
-
91
-
92
- def get_blip2_text(image):
93
- inputs = processor(image, return_tensors="pt").to(device, torch.float16)
94
- generated_ids = blip_model.generate(**inputs, max_new_tokens=50)
95
- generated_text = processor.batch_decode(
96
- generated_ids, skip_special_tokens=True)[0].strip()
97
- return generated_text
98
-
99
-
100
- def show_anns(anns):
101
- if len(anns) == 0:
102
- return
103
- sorted_anns = sorted(anns, key=(lambda x: x['area']), reverse=True)
104
- full_img = None
105
-
106
- # for ann in sorted_anns:
107
- for i in range(len(sorted_anns)):
108
- ann = anns[i]
109
- m = ann['segmentation']
110
- if full_img is None:
111
- full_img = np.zeros((m.shape[0], m.shape[1], 3))
112
- map = np.zeros((m.shape[0], m.shape[1]), dtype=np.uint16)
113
- map[m != 0] = i + 1
114
- color_mask = np.random.random((1, 3)).tolist()[0]
115
- full_img[m != 0] = color_mask
116
- full_img = full_img*255
117
- # anno encoding from https://github.com/LUSSeg/ImageNet-S
118
- res = np.zeros((map.shape[0], map.shape[1], 3))
119
- res[:, :, 0] = map % 256
120
- res[:, :, 1] = map // 256
121
- res.astype(np.float32)
122
- full_img = Image.fromarray(np.uint8(full_img))
123
- return full_img, res
124
-
125
-
126
- def get_sam_control(image):
127
- masks = mask_generator.generate(image)
128
- full_img, res = show_anns(masks)
129
- return full_img, res
130
-
131
-
132
- def process(condition_model, source_image, enable_all_generate, mask_image, control_scale, enable_auto_prompt, prompt, a_prompt, n_prompt, num_samples, image_resolution, detect_resolution, ddim_steps, guess_mode, strength, scale, seed, eta):
133
-
134
- input_image = source_image["image"]
135
- if mask_image is None:
136
- if enable_all_generate:
137
- print("source_image", source_image["mask"].shape, input_image.shape,)
138
- print(source_image["mask"].max())
139
- mask_image = np.ones((input_image.shape[0], input_image.shape[1], 3))*255
140
- else:
141
- mask_image = source_image["mask"]
142
- global default_controlnet_path
143
- print("To Use:", config_dict[condition_model], "Current:", default_controlnet_path)
144
- if default_controlnet_path!=config_dict[condition_model]:
145
- print("Change condition model to:", config_dict[condition_model])
146
- global pipe
147
- pipe = obtain_generation_model(config_dict[condition_model])
148
- default_controlnet_path = config_dict[condition_model]
149
- torch.cuda.empty_cache()
150
-
151
- with torch.no_grad():
152
- if use_blip and (enable_auto_prompt or len(prompt) == 0):
153
- print("Generating text:")
154
- blip2_prompt = get_blip2_text(input_image)
155
- print("Generated text:", blip2_prompt)
156
- if len(prompt)>0:
157
- prompt = blip2_prompt + ',' + prompt
158
- else:
159
- prompt = blip2_prompt
160
- print("All text:", prompt)
161
-
162
- input_image = HWC3(input_image)
163
-
164
- img = resize_image(input_image, image_resolution)
165
- H, W, C = img.shape
166
-
167
- print("Generating SAM seg:")
168
- # the default SAM model is trained with 1024 size.
169
- full_segmask, detected_map = get_sam_control(
170
- resize_image(input_image, detect_resolution))
171
-
172
- detected_map = HWC3(detected_map.astype(np.uint8))
173
- detected_map = cv2.resize(
174
- detected_map, (W, H), interpolation=cv2.INTER_LINEAR)
175
-
176
- control = torch.from_numpy(
177
- detected_map.copy()).float().cuda()
178
- control = torch.stack([control for _ in range(num_samples)], dim=0)
179
- control = einops.rearrange(control, 'b h w c -> b c h w').clone()
180
-
181
- mask_image = HWC3(mask_image.astype(np.uint8))
182
- mask_image = cv2.resize(
183
- mask_image, (W, H), interpolation=cv2.INTER_LINEAR)
184
- mask_image = Image.fromarray(mask_image)
185
-
186
-
187
- if seed == -1:
188
- seed = random.randint(0, 65535)
189
- seed_everything(seed)
190
- generator = torch.manual_seed(seed)
191
- if condition_model=='SD Inpainting: Not keep position':
192
- x_samples = pipe(
193
- image=img,
194
- mask_image=mask_image,
195
- prompt=[prompt + ', ' + a_prompt] * num_samples,
196
- negative_prompt=[n_prompt] * num_samples,
197
- num_images_per_prompt=num_samples,
198
- num_inference_steps=ddim_steps,
199
- generator=generator,
200
- height=H,
201
- width=W,
202
- ).images
203
- else:
204
- x_samples = pipe(
205
- image=img,
206
- mask_image=mask_image,
207
- prompt=[prompt + ', ' + a_prompt] * num_samples,
208
- negative_prompt=[n_prompt] * num_samples,
209
- num_images_per_prompt=num_samples,
210
- num_inference_steps=ddim_steps,
211
- generator=generator,
212
- controlnet_conditioning_image=control.type(torch.float16),
213
- height=H,
214
- width=W,
215
- controlnet_conditioning_scale=control_scale,
216
- ).images
217
-
218
-
219
- results = [x_samples[i] for i in range(num_samples)]
220
- return [full_segmask, mask_image] + results, prompt
221
-
222
-
223
- def download_image(url):
224
- response = requests.get(url)
225
- return Image.open(BytesIO(response.content)).convert("RGB")
226
-
227
- # disable gradio when not using GUI.
228
- if not use_gradio:
229
- # This part is not updated, it's just a example to use it without GUI.
230
- image_path = "../data/samples/sa_223750.jpg"
231
- mask_path = "../data/samples/sa_223750inpaint.png"
232
- input_image = Image.open(image_path)
233
- mask_image = Image.open(mask_path)
234
- enable_auto_prompt = True
235
- input_image = np.array(input_image, dtype=np.uint8)
236
- mask_image = np.array(mask_image, dtype=np.uint8)
237
- prompt = "esplendent sunset sky, red brick wall"
238
- a_prompt = 'best quality, extremely detailed'
239
- n_prompt = 'longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality'
240
- num_samples = 3
241
- image_resolution = 512
242
- detect_resolution = 512
243
- ddim_steps = 30
244
- guess_mode = False
245
- strength = 1.0
246
- scale = 9.0
247
- seed = -1
248
- eta = 0.0
249
-
250
- outputs = process(condition_model, input_image, mask_image, enable_auto_prompt, prompt, a_prompt, n_prompt, num_samples, image_resolution,
251
- detect_resolution, ddim_steps, guess_mode, strength, scale, seed, eta)
252
-
253
- image_list = []
254
- input_image = resize_image(input_image, 512)
255
- image_list.append(torch.tensor(input_image))
256
- for i in range(len(outputs)):
257
- each = outputs[i]
258
- if type(each) is not np.ndarray:
259
- each = np.array(each, dtype=np.uint8)
260
- each = resize_image(each, 512)
261
- print(i, each.shape)
262
- image_list.append(torch.tensor(each))
263
-
264
- image_list = torch.stack(image_list).permute(0, 3, 1, 2)
265
-
266
- save_image(image_list, "sample.jpg", nrow=3,
267
- normalize=True, value_range=(0, 255))
268
- else:
269
- print("The GUI is not fully tested yet. Please open an issue if you find bugs.")
270
- block = gr.Blocks()
271
- with block as demo:
272
- with gr.Row():
273
- gr.Markdown(
274
- "## Edit Anything")
275
- with gr.Row():
276
- with gr.Column():
277
- source_image = gr.Image(source='upload',label="Image (Upload an image and cover the region you want to edit with sketch)", type="numpy", tool="sketch")
278
- enable_all_generate = gr.Checkbox(label='Auto generation on all region.', value=False)
279
- prompt = gr.Textbox(label="Prompt (Text in the expected things of edited region)")
280
- enable_auto_prompt = gr.Checkbox(label='Auto generate text prompt from input image with BLIP2: Warning: Enable this may makes your prompt not working.', value=True)
281
- control_scale = gr.Slider(
282
- label="Mask Align strength (Large value means more strict alignment with SAM mask)", minimum=0, maximum=1, value=1, step=0.1)
283
- run_button = gr.Button(label="Run")
284
  condition_model = gr.Dropdown(choices=list(config_dict.keys()),
285
- value=list(config_dict.keys())[1],
286
- label='Model',
287
- multiselect=False)
288
- num_samples = gr.Slider(
289
- label="Images", minimum=1, maximum=12, value=2, step=1)
290
- with gr.Accordion("Advanced options", open=False):
291
- mask_image = gr.Image(source='upload', label="(Optional) Upload a predefined mask of edit region if you do not want to write your prompt.", type="numpy", value=None)
292
- image_resolution = gr.Slider(
293
- label="Image Resolution", minimum=256, maximum=768, value=512, step=64)
294
- strength = gr.Slider(
295
- label="Control Strength", minimum=0.0, maximum=2.0, value=1.0, step=0.01)
296
- guess_mode = gr.Checkbox(label='Guess Mode', value=False)
297
- detect_resolution = gr.Slider(
298
- label="SAM Resolution", minimum=128, maximum=2048, value=1024, step=1)
299
- ddim_steps = gr.Slider(
300
- label="Steps", minimum=1, maximum=100, value=30, step=1)
301
- scale = gr.Slider(
302
- label="Guidance Scale", minimum=0.1, maximum=30.0, value=9.0, step=0.1)
303
- seed = gr.Slider(label="Seed", minimum=-1,
304
- maximum=2147483647, step=1, randomize=True)
305
- eta = gr.Number(label="eta (DDIM)", value=0.0)
306
- a_prompt = gr.Textbox(
307
- label="Added Prompt", value='best quality, extremely detailed')
308
- n_prompt = gr.Textbox(label="Negative Prompt",
309
- value='longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality')
310
- with gr.Column():
311
- result_gallery = gr.Gallery(
312
- label='Output', show_label=False, elem_id="gallery").style(grid=2, height='auto')
313
- result_text = gr.Text(label='BLIP2+Human Prompt Text')
314
- ips = [condition_model, source_image, enable_all_generate, mask_image, control_scale, enable_auto_prompt, prompt, a_prompt, n_prompt, num_samples, image_resolution,
315
- detect_resolution, ddim_steps, guess_mode, strength, scale, seed, eta]
316
- run_button.click(fn=process, inputs=ips, outputs=[result_gallery, result_text])
317
- return demo
 
 
 
318
 
319
  if __name__ == '__main__':
320
- demo = create_demo()
 
 
 
321
  demo.queue().launch(server_name='0.0.0.0')
 
1
  # Edit Anything trained with Stable Diffusion + ControlNet + SAM + BLIP2
 
 
 
 
 
 
 
 
2
  import gradio as gr
3
+ from diffusers.utils import load_image
4
+ from sam2edit_lora import EditAnythingLoraModel, config_dict
5
+
6
+
7
+ def create_demo(process):
8
+
9
+
10
+
11
+ print("The GUI is not fully tested yet. Please open an issue if you find bugs.")
12
+ WARNING_INFO = f'''### [NOTE] the model is collected from the Internet for demo only, please do not use it for commercial purposes.
13
+ We are not responsible for possible risks using this model.
14
+ '''
15
+ block = gr.Blocks()
16
+ with block as demo:
17
+ with gr.Row():
18
+ gr.Markdown(
19
+ "## Generate Your Beauty powered by EditAnything https://github.com/sail-sg/EditAnything ")
20
+ with gr.Row():
21
+ with gr.Column():
22
+ source_image = gr.Image(
23
+ source='upload', label="Image (Upload an image and cover the region you want to edit with sketch)", type="numpy", tool="sketch")
24
+ enable_all_generate = gr.Checkbox(
25
+ label='Auto generation on all region.', value=False)
26
+ prompt = gr.Textbox(
27
+ label="Prompt (Text in the expected things of edited region)")
28
+ enable_auto_prompt = gr.Checkbox(
29
+ label='Auto generate text prompt from input image with BLIP2: Warning: Enable this may makes your prompt not working.', value=False)
30
+ a_prompt = gr.Textbox(
31
+ label="Added Prompt", value='best quality, extremely detailed')
32
+ n_prompt = gr.Textbox(label="Negative Prompt",
33
+ value='longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality')
34
+ control_scale = gr.Slider(
35
+ label="Mask Align strength (Large value means more strict alignment with SAM mask)", minimum=0, maximum=1, value=1, step=0.1)
36
+ run_button = gr.Button(label="Run")
37
+ num_samples = gr.Slider(
38
+ label="Images", minimum=1, maximum=12, value=2, step=1)
39
+ seed = gr.Slider(label="Seed", minimum=-1,
40
+ maximum=2147483647, step=1, randomize=True)
41
+ with gr.Accordion("Advanced options", open=False):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
42
  condition_model = gr.Dropdown(choices=list(config_dict.keys()),
43
+ value=list(
44
+ config_dict.keys())[1],
45
+ label='Model',
46
+ multiselect=False)
47
+ mask_image = gr.Image(
48
+ source='upload', label="(Optional) Upload a predefined mask of edit region if you do not want to write your prompt.", type="numpy", value=None)
49
+ image_resolution = gr.Slider(
50
+ label="Image Resolution", minimum=256, maximum=768, value=512, step=64)
51
+ strength = gr.Slider(
52
+ label="Control Strength", minimum=0.0, maximum=2.0, value=1.0, step=0.01)
53
+ guess_mode = gr.Checkbox(
54
+ label='Guess Mode', value=False)
55
+ detect_resolution = gr.Slider(
56
+ label="SAM Resolution", minimum=128, maximum=2048, value=1024, step=1)
57
+ ddim_steps = gr.Slider(
58
+ label="Steps", minimum=1, maximum=100, value=30, step=1)
59
+ scale = gr.Slider(
60
+ label="Guidance Scale", minimum=0.1, maximum=30.0, value=9.0, step=0.1)
61
+ eta = gr.Number(label="eta (DDIM)", value=0.0)
62
+ with gr.Column():
63
+ result_gallery = gr.Gallery(
64
+ label='Output', show_label=False, elem_id="gallery").style(grid=2, height='auto')
65
+ result_text = gr.Text(label='BLIP2+Human Prompt Text')
66
+ ips = [condition_model, source_image, enable_all_generate, mask_image, control_scale, enable_auto_prompt, prompt, a_prompt, n_prompt, num_samples, image_resolution,
67
+ detect_resolution, ddim_steps, guess_mode, strength, scale, seed, eta]
68
+ run_button.click(fn=process, inputs=ips, outputs=[
69
+ result_gallery, result_text])
70
+ # with gr.Row():
71
+ # ex = gr.Examples(examples=examples, fn=process,
72
+ # inputs=[a_prompt, n_prompt, scale],
73
+ # outputs=[result_gallery],
74
+ # cache_examples=False)
75
+ with gr.Row():
76
+ gr.Markdown(WARNING_INFO)
77
+ return demo
78
+
79
 
80
  if __name__ == '__main__':
81
+ model = EditAnythingLoraModel(base_model_path="stabilityai/stable-diffusion-2-inpainting",
82
+ controlmodel_name='LAION Pretrained(v0-4)-SD21', extra_inpaint=False,
83
+ lora_model_path=None, use_blip=True)
84
+ demo = create_demo(model.process)
85
  demo.queue().launch(server_name='0.0.0.0')
sam2edit_beauty.py ADDED
@@ -0,0 +1,95 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Edit Anything trained with Stable Diffusion + ControlNet + SAM + BLIP2
2
+ import gradio as gr
3
+ from diffusers.utils import load_image
4
+ from sam2edit_lora import EditAnythingLoraModel, config_dict
5
+
6
+
7
+ def create_demo(process):
8
+
9
+ examples = [
10
+ ["dudou,1girl, beautiful face, solo, candle, brown hair, long hair, <lora:flowergirl:0.9>,ulzzang-6500-v1.1,(raw photo:1.2),((photorealistic:1.4))best quality ,masterpiece, illustration, an extremely delicate and beautiful, extremely detailed ,CG ,unity ,8k wallpaper, Amazing, finely detail, masterpiece,best quality,official art,extremely detailed CG unity 8k wallpaper,absurdres, incredibly absurdres, huge filesize, ultra-detailed, highres, extremely detailed,beautiful detailed girl, extremely detailed eyes and face, beautiful detailed eyes,cinematic lighting,1girl,see-through,looking at viewer,full body,full-body shot,outdoors,arms behind back,(chinese clothes) <lora:cuteGirlMix4_v10:1>",
11
+ "(((mole))),sketches, (worst quality:2), (low quality:2), (normal quality:2), lowres, normal quality, ((monochrome)), ((grayscale)), skin spots, acnes, skin blemishes, bad anatomy,(long hair:1.4),DeepNegative,(fat:1.2),facing away, looking away,tilted head, lowres,bad anatomy,bad hands, text, error, missing fingers,extra digit, fewer digits, cropped, worstquality, low quality, normal quality,jpegartifacts,signature, watermark, username,blurry,bad feet,cropped,poorly drawn hands,poorly drawn face,mutation,deformed,worst quality,low quality,normal quality,jpeg artifacts,signature,watermark,extra fingers,fewer digits,extra limbs,extra arms,extra legs,malformed limbs,fused fingers,too many fingers,long neck,cross-eyed,mutated hands,polar lowres,bad body,bad proportions,gross proportions,text,error,missing fingers,missing arms,missing legs,extra digit, extra arms, extra leg, extra foot,(freckles),(mole:2)", 5],
12
+ ["best quality, ultra high res, (photorealistic:1.4), (detailed beautiful girl:1.4), (medium breasts:0.8), looking_at_viewer, Detailed facial details, beautiful detailed eyes, (multicolored|blue|pink hair: 1.2), green eyes, slender, haunting smile, (makeup:0.3), red lips, <lora:cuteGirlMix4_v10:0.7>, highly detailed clothes, (ulzzang-6500-v1.1:0.3)",
13
+ "EasyNegative, paintings, sketches, ugly, 3d, (worst quality:2), (low quality:2), (normal quality:2), lowres, normal quality, ((monochrome)), ((grayscale)), skin spots, acnes, skin blemishes, age spot, manboobs, backlight,(ugly:1.3), (duplicate:1.3), (morbid:1.2), (mutilated:1.2), (tranny:1.3), mutated hands, (poorly drawn hands:1.3), blurry, (bad anatomy:1.2), (bad proportions:1.3), extra limbs, (disfigured:1.3), (more than 2 nipples:1.3), (more than 1 navel:1.3), (missing arms:1.3), (extra legs:1.3), (fused fingers:1.6), (too many fingers:1.6), (unclear eyes:1.3), bad hands, missing fingers, extra digit, (futa:1.1), bad body, double navel, mutad arms, hused arms, (puffy nipples, dark areolae, dark nipples, rei no himo, inverted nipples, long nipples), NG_DeepNegative_V1_75t, pubic hair, fat rolls, obese, bad-picture-chill-75v", 8],
14
+ ["best quality, ultra high res, (photorealistic:1.4), (detailed beautiful girl:1.4), (medium breasts:0.8), looking_at_viewer, Detailed facial details, beautiful detailed eyes, (blue|pink hair), green eyes, slender, smile, (makeup:0.4), red lips, (full body, sitting, beach), <lora:cuteGirlMix4_v10:0.7>, highly detailed clothes, (ulzzang-6500-v1.1:0.3)",
15
+ "asyNegative, paintings, sketches, ugly, 3d, (worst quality:2), (low quality:2), (normal quality:2), lowres, normal quality, ((monochrome)), ((grayscale)), skin spots, acnes, skin blemishes, age spot, manboobs, backlight,(ugly:1.3), (duplicate:1.3), (morbid:1.2), (mutilated:1.2), (tranny:1.3), mutated hands, (poorly drawn hands:1.3), blurry, (bad anatomy:1.2), (bad proportions:1.3), extra limbs, (disfigured:1.3), (more than 2 nipples:1.3), (more than 1 navel:1.3), (missing arms:1.3), (extra legs:1.3), (fused fingers:1.6), (too many fingers:1.6), (unclear eyes:1.3), bad hands, missing fingers, extra digit, (futa:1.1), bad body, double navel, mutad arms, hused arms, (puffy nipples, dark areolae, dark nipples, rei no himo, inverted nipples, long nipples), NG_DeepNegative_V1_75t, pubic hair, fat rolls, obese, bad-picture-chill-75v", 7],
16
+ ["mix4, whole body shot, ((8k, RAW photo, highest quality, masterpiece), High detail RAW color photo professional close-up photo, shy expression, cute, beautiful detailed girl, detailed fingers, extremely detailed eyes and face, beautiful detailed nose, beautiful detailed eyes, long eyelashes, light on face, looking at viewer, (closed mouth:1.2), 1girl, cute, young, mature face, (full body:1.3), ((small breasts)), realistic face, realistic body, beautiful detailed thigh,s, same eyes color, (realistic, photo realism:1. 37), (highest quality), (best shadow), (best illustration), ultra high resolution, physics-based rendering, cinematic lighting), solo, 1girl, highly detailed, in office, detailed office, open cardigan, ponytail contorted, beautiful eyes ,sitting in office,dating, business suit, cross-laced clothes, collared shirt, beautiful breast, small breast, Chinese dress, white pantyhose, natural breasts, pink and white hair, <lora:cuteGirlMix4_v10:1>",
17
+ "paintings, sketches, (worst quality:2), (low quality:2), (normal quality:2), cloth, underwear, bra, low-res, normal quality, ((monochrome)), ((grayscale)), skin spots, acne, skin blemishes, age spots, glans, bad nipples, long nipples, bad vagina, extra fingers,fewer fingers,strange fingers,bad hand, ng_deepnegative_v1_75t, bad-picture-chill-75v", 7]
18
+ ]
19
+
20
+ print("The GUI is not fully tested yet. Please open an issue if you find bugs.")
21
+ WARNING_INFO = f'''### [NOTE] the model is collected from the Internet for demo only, please do not use it for commercial purposes.
22
+ We are not responsible for possible risks using this model.
23
+
24
+ Lora model from https://civitai.com/models/14171/cutegirlmix4 Thanks!
25
+ '''
26
+ block = gr.Blocks()
27
+ with block as demo:
28
+ with gr.Row():
29
+ gr.Markdown(
30
+ "## Generate Your Beauty powered by EditAnything https://github.com/sail-sg/EditAnything ")
31
+ with gr.Row():
32
+ with gr.Column():
33
+ source_image = gr.Image(
34
+ source='upload', label="Image (Upload an image and cover the region you want to edit with sketch)", type="numpy", tool="sketch")
35
+ enable_all_generate = gr.Checkbox(
36
+ label='Auto generation on all region.', value=False)
37
+ prompt = gr.Textbox(
38
+ label="Prompt (Text in the expected things of edited region)")
39
+ enable_auto_prompt = gr.Checkbox(
40
+ label='Auto generate text prompt from input image with BLIP2: Warning: Enable this may makes your prompt not working.', value=False)
41
+ a_prompt = gr.Textbox(
42
+ label="Added Prompt", value='best quality, extremely detailed')
43
+ n_prompt = gr.Textbox(label="Negative Prompt",
44
+ value='longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality')
45
+ control_scale = gr.Slider(
46
+ label="Mask Align strength (Large value means more strict alignment with SAM mask)", minimum=0, maximum=1, value=1, step=0.1)
47
+ run_button = gr.Button(label="Run")
48
+ num_samples = gr.Slider(
49
+ label="Images", minimum=1, maximum=12, value=2, step=1)
50
+ seed = gr.Slider(label="Seed", minimum=-1,
51
+ maximum=2147483647, step=1, randomize=True)
52
+ with gr.Accordion("Advanced options", open=False):
53
+ condition_model = gr.Dropdown(choices=list(config_dict.keys()),
54
+ value=list(
55
+ config_dict.keys())[0],
56
+ label='Model',
57
+ multiselect=False)
58
+ mask_image = gr.Image(
59
+ source='upload', label="(Optional) Upload a predefined mask of edit region if you do not want to write your prompt.", type="numpy", value=None)
60
+ image_resolution = gr.Slider(
61
+ label="Image Resolution", minimum=256, maximum=768, value=512, step=64)
62
+ strength = gr.Slider(
63
+ label="Control Strength", minimum=0.0, maximum=2.0, value=1.0, step=0.01)
64
+ guess_mode = gr.Checkbox(
65
+ label='Guess Mode', value=False)
66
+ detect_resolution = gr.Slider(
67
+ label="SAM Resolution", minimum=128, maximum=2048, value=1024, step=1)
68
+ ddim_steps = gr.Slider(
69
+ label="Steps", minimum=1, maximum=100, value=30, step=1)
70
+ scale = gr.Slider(
71
+ label="Guidance Scale", minimum=0.1, maximum=30.0, value=9.0, step=0.1)
72
+ eta = gr.Number(label="eta (DDIM)", value=0.0)
73
+ with gr.Column():
74
+ result_gallery = gr.Gallery(
75
+ label='Output', show_label=False, elem_id="gallery").style(grid=2, height='auto')
76
+ result_text = gr.Text(label='BLIP2+Human Prompt Text')
77
+ ips = [condition_model, source_image, enable_all_generate, mask_image, control_scale, enable_auto_prompt, prompt, a_prompt, n_prompt, num_samples, image_resolution,
78
+ detect_resolution, ddim_steps, guess_mode, strength, scale, seed, eta]
79
+ run_button.click(fn=process, inputs=ips, outputs=[
80
+ result_gallery, result_text])
81
+ with gr.Row():
82
+ ex = gr.Examples(examples=examples, fn=process,
83
+ inputs=[a_prompt, n_prompt, scale],
84
+ outputs=[result_gallery],
85
+ cache_examples=False)
86
+ with gr.Row():
87
+ gr.Markdown(WARNING_INFO)
88
+ return demo
89
+
90
+
91
+ if __name__ == '__main__':
92
+ model = EditAnythingLoraModel(base_model_path='../chilloutmix_NiPrunedFp32Fix',
93
+ lora_model_path='../40806/mix4', use_blip=True)
94
+ demo = create_demo(model.process)
95
+ demo.queue().launch(server_name='0.0.0.0')
sam2edit_handsome.py ADDED
@@ -0,0 +1,90 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Edit Anything trained with Stable Diffusion + ControlNet + SAM + BLIP2
2
+ import gradio as gr
3
+ from diffusers.utils import load_image
4
+ from sam2edit_lora import EditAnythingLoraModel, config_dict
5
+
6
+
7
+
8
+ def create_demo(process):
9
+
10
+ examples = [
11
+ ["1man, muscle,full body, vest, short straight hair, glasses, Gym, barbells, dumbbells, treadmills, boxing rings, squat racks, plates, dumbbell racks soft lighting, masterpiece, best quality, 8k uhd, film grain, Fujifilm XT3 photorealistic painting art by midjourney and greg rutkowski <lora:asianmale_v10:0.6>", "(deformed iris, deformed pupils, semi-realistic, cgi, 3d, render, sketch, cartoon, drawing, anime:1.4), text, close up, cropped, out of frame, worst quality, low quality, jpeg artifacts, ugly, duplicate, morbid, mutilated, extra fingers, mutated hands, poorly drawn hands, poorly drawn face, mutation, deformed, blurry, dehydrated, bad anatomy, bad proportions, extra limbs, cloned face, disfigured, gross proportions, malformed limbs, missing arms, missing legs, extra arms, extra legs, fused fingers, too many fingers, long neck", 6],
12
+ ["1man, 25 years- old, full body, wearing long-sleeve white shirt and tie, muscular rand black suit, soft lighting, masterpiece, best quality, 8k uhd, dslr, film grain, Fujifilm XT3 photorealistic painting art by midjourney and greg rutkowski <lora:asianmale_v10:0.6> <lora:uncutPenisLora_v10:0.6>","(deformed iris, deformed pupils, semi-realistic, cgi, 3d, render, sketch, cartoon, drawing, anime:1.4), text, close up, cropped, out of frame, worst quality, low quality, jpeg artifacts, ugly, duplicate, morbid, mutilated, extra fingers, mutated hands, poorly drawn hands, poorly drawn face, mutation, deformed, blurry, dehydrated, bad anatomy, bad proportions, extra limbs, cloned face, disfigured, gross proportions, malformed limbs, missing arms, missing legs, extra arms, extra legs, fused fingers, too many fingers, long neck",6],
13
+ ]
14
+
15
+ print("The GUI is not fully tested yet. Please open an issue if you find bugs.")
16
+ WARNING_INFO = f'''### [NOTE] the model is collected from the Internet for demo only, please do not use it for commercial purposes.
17
+ We are not responsible for possible risks using this model.
18
+ Base model from https://huggingface.co/SG161222/Realistic_Vision_V2.0 Thanks!
19
+ '''
20
+ block = gr.Blocks()
21
+ with block as demo:
22
+ with gr.Row():
23
+ gr.Markdown(
24
+ "## Generate Your Handsome powered by EditAnything https://github.com/sail-sg/EditAnything ")
25
+ with gr.Row():
26
+ with gr.Column():
27
+ source_image = gr.Image(
28
+ source='upload', label="Image (Upload an image and cover the region you want to edit with sketch)", type="numpy", tool="sketch")
29
+ enable_all_generate = gr.Checkbox(
30
+ label='Auto generation on all region.', value=False)
31
+ prompt = gr.Textbox(
32
+ label="Prompt (Text in the expected things of edited region)")
33
+ enable_auto_prompt = gr.Checkbox(
34
+ label='Auto generate text prompt from input image with BLIP2: Warning: Enable this may makes your prompt not working.', value=False)
35
+ a_prompt = gr.Textbox(
36
+ label="Added Prompt", value='best quality, extremely detailed')
37
+ n_prompt = gr.Textbox(label="Negative Prompt",
38
+ value='longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality')
39
+ control_scale = gr.Slider(
40
+ label="Mask Align strength (Large value means more strict alignment with SAM mask)", minimum=0, maximum=1, value=1, step=0.1)
41
+ run_button = gr.Button(label="Run")
42
+ num_samples = gr.Slider(
43
+ label="Images", minimum=1, maximum=12, value=2, step=1)
44
+ seed = gr.Slider(label="Seed", minimum=-1,
45
+ maximum=2147483647, step=1, randomize=True)
46
+ with gr.Accordion("Advanced options", open=False):
47
+ condition_model = gr.Dropdown(choices=list(config_dict.keys()),
48
+ value=list(
49
+ config_dict.keys())[0],
50
+ label='Model',
51
+ multiselect=False)
52
+ mask_image = gr.Image(
53
+ source='upload', label="(Optional) Upload a predefined mask of edit region if you do not want to write your prompt.", type="numpy", value=None)
54
+ image_resolution = gr.Slider(
55
+ label="Image Resolution", minimum=256, maximum=768, value=512, step=64)
56
+ strength = gr.Slider(
57
+ label="Control Strength", minimum=0.0, maximum=2.0, value=1.0, step=0.01)
58
+ guess_mode = gr.Checkbox(
59
+ label='Guess Mode', value=False)
60
+ detect_resolution = gr.Slider(
61
+ label="SAM Resolution", minimum=128, maximum=2048, value=1024, step=1)
62
+ ddim_steps = gr.Slider(
63
+ label="Steps", minimum=1, maximum=100, value=30, step=1)
64
+ scale = gr.Slider(
65
+ label="Guidance Scale", minimum=0.1, maximum=30.0, value=9.0, step=0.1)
66
+ eta = gr.Number(label="eta (DDIM)", value=0.0)
67
+ with gr.Column():
68
+ result_gallery = gr.Gallery(
69
+ label='Output', show_label=False, elem_id="gallery").style(grid=2, height='auto')
70
+ result_text = gr.Text(label='BLIP2+Human Prompt Text')
71
+ ips = [condition_model, source_image, enable_all_generate, mask_image, control_scale, enable_auto_prompt, prompt, a_prompt, n_prompt, num_samples, image_resolution,
72
+ detect_resolution, ddim_steps, guess_mode, strength, scale, seed, eta]
73
+ run_button.click(fn=process, inputs=ips, outputs=[
74
+ result_gallery, result_text])
75
+ with gr.Row():
76
+ ex = gr.Examples(examples=examples, fn=process,
77
+ inputs=[a_prompt, n_prompt, scale],
78
+ outputs=[result_gallery],
79
+ cache_examples=False)
80
+ with gr.Row():
81
+ gr.Markdown(WARNING_INFO)
82
+ return demo
83
+
84
+
85
+
86
+ if __name__ == '__main__':
87
+ model = EditAnythingLoraModel(base_model_path= '../../gradio-rel/EditAnything/models/Realistic_Vision_V2.0',
88
+ lora_model_path= '../../gradio-rel/EditAnything/models/asianmale', use_blip=True)
89
+ demo = create_demo(model.process)
90
+ demo.queue().launch(server_name='0.0.0.0')
sam2edit_lora.py ADDED
@@ -0,0 +1,478 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Edit Anything trained with Stable Diffusion + ControlNet + SAM + BLIP2
2
+ from torchvision.utils import save_image
3
+ from PIL import Image
4
+ from pytorch_lightning import seed_everything
5
+ import subprocess
6
+ from collections import OrderedDict
7
+ import re
8
+ import cv2
9
+ import einops
10
+ import gradio as gr
11
+ import numpy as np
12
+ import torch
13
+ import random
14
+ import os
15
+ import requests
16
+ from io import BytesIO
17
+ from annotator.util import resize_image, HWC3
18
+
19
+ import torch
20
+ from safetensors.torch import load_file
21
+ from collections import defaultdict
22
+ from diffusers import StableDiffusionControlNetPipeline
23
+ from diffusers import ControlNetModel, UniPCMultistepScheduler
24
+ from utils.stable_diffusion_controlnet_inpaint import StableDiffusionControlNetInpaintPipeline
25
+ # from utils.tmp import StableDiffusionControlNetInpaintPipeline
26
+ # need the latest transformers
27
+ # pip install git+https://github.com/huggingface/transformers.git
28
+ from transformers import AutoProcessor, Blip2ForConditionalGeneration
29
+
30
+ # Segment-Anything init.
31
+ # pip install git+https://github.com/facebookresearch/segment-anything.git
32
+ try:
33
+ from segment_anything import sam_model_registry, SamAutomaticMaskGenerator
34
+ except ImportError:
35
+ print('segment_anything not installed')
36
+ result = subprocess.run(
37
+ ['pip', 'install', 'git+https://github.com/facebookresearch/segment-anything.git'], check=True)
38
+ print(f'Install segment_anything {result}')
39
+ from segment_anything import sam_model_registry, SamAutomaticMaskGenerator
40
+ if not os.path.exists('./models/sam_vit_h_4b8939.pth'):
41
+ result = subprocess.run(
42
+ ['wget', 'https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth', '-P', 'models'], check=True)
43
+ print(f'Download sam_vit_h_4b8939.pth {result}')
44
+
45
+ device = "cuda" if torch.cuda.is_available() else "cpu"
46
+
47
+ config_dict = OrderedDict([
48
+ ('LAION Pretrained(v0-4)-SD15', 'shgao/edit-anything-v0-4-sd15'),
49
+ ('LAION Pretrained(v0-4)-SD21', 'shgao/edit-anything-v0-4-sd21'),
50
+ ])
51
+
52
+
53
+ def init_sam_model():
54
+ sam_checkpoint = "models/sam_vit_h_4b8939.pth"
55
+ model_type = "default"
56
+ sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
57
+ sam.to(device=device)
58
+ sam_generator = SamAutomaticMaskGenerator(sam)
59
+ return sam_generator
60
+
61
+
62
+ def init_blip_processor():
63
+ blip_processor = AutoProcessor.from_pretrained("Salesforce/blip2-opt-2.7b")
64
+ return blip_processor
65
+
66
+
67
+ def init_blip_model():
68
+ blip_model = Blip2ForConditionalGeneration.from_pretrained(
69
+ "Salesforce/blip2-opt-2.7b", torch_dtype=torch.float16, device_map="auto")
70
+ return blip_model
71
+
72
+
73
+ def get_pipeline_embeds(pipeline, prompt, negative_prompt, device):
74
+ # https://github.com/huggingface/diffusers/issues/2136
75
+ """ Get pipeline embeds for prompts bigger than the maxlength of the pipe
76
+ :param pipeline:
77
+ :param prompt:
78
+ :param negative_prompt:
79
+ :param device:
80
+ :return:
81
+ """
82
+ max_length = pipeline.tokenizer.model_max_length
83
+
84
+ # simple way to determine length of tokens
85
+ count_prompt = len(re.split(r', ', prompt))
86
+ count_negative_prompt = len(re.split(r', ', negative_prompt))
87
+
88
+ # create the tensor based on which prompt is longer
89
+ if count_prompt >= count_negative_prompt:
90
+ input_ids = pipeline.tokenizer(
91
+ prompt, return_tensors="pt", truncation=False).input_ids.to(device)
92
+ shape_max_length = input_ids.shape[-1]
93
+ negative_ids = pipeline.tokenizer(negative_prompt, truncation=False, padding="max_length",
94
+ max_length=shape_max_length, return_tensors="pt").input_ids.to(device)
95
+ else:
96
+ negative_ids = pipeline.tokenizer(
97
+ negative_prompt, return_tensors="pt", truncation=False).input_ids.to(device)
98
+ shape_max_length = negative_ids.shape[-1]
99
+ input_ids = pipeline.tokenizer(prompt, return_tensors="pt", truncation=False, padding="max_length",
100
+ max_length=shape_max_length).input_ids.to(device)
101
+
102
+ concat_embeds = []
103
+ neg_embeds = []
104
+ for i in range(0, shape_max_length, max_length):
105
+ concat_embeds.append(pipeline.text_encoder(
106
+ input_ids[:, i: i + max_length])[0])
107
+ neg_embeds.append(pipeline.text_encoder(
108
+ negative_ids[:, i: i + max_length])[0])
109
+
110
+ return torch.cat(concat_embeds, dim=1), torch.cat(neg_embeds, dim=1)
111
+
112
+
113
+ def load_lora_weights(pipeline, checkpoint_path, multiplier, device, dtype):
114
+ LORA_PREFIX_UNET = "lora_unet"
115
+ LORA_PREFIX_TEXT_ENCODER = "lora_te"
116
+ # load LoRA weight from .safetensors
117
+ if isinstance(checkpoint_path, str):
118
+
119
+ state_dict = load_file(checkpoint_path, device=device)
120
+
121
+ updates = defaultdict(dict)
122
+ for key, value in state_dict.items():
123
+ # it is suggested to print out the key, it usually will be something like below
124
+ # "lora_te_text_model_encoder_layers_0_self_attn_k_proj.lora_down.weight"
125
+
126
+ layer, elem = key.split('.', 1)
127
+ updates[layer][elem] = value
128
+
129
+ # directly update weight in diffusers model
130
+ for layer, elems in updates.items():
131
+
132
+ if "text" in layer:
133
+ layer_infos = layer.split(
134
+ LORA_PREFIX_TEXT_ENCODER + "_")[-1].split("_")
135
+ curr_layer = pipeline.text_encoder
136
+ else:
137
+ layer_infos = layer.split(
138
+ LORA_PREFIX_UNET + "_")[-1].split("_")
139
+ curr_layer = pipeline.unet
140
+
141
+ # find the target layer
142
+ temp_name = layer_infos.pop(0)
143
+ while len(layer_infos) > -1:
144
+ try:
145
+ curr_layer = curr_layer.__getattr__(temp_name)
146
+ if len(layer_infos) > 0:
147
+ temp_name = layer_infos.pop(0)
148
+ elif len(layer_infos) == 0:
149
+ break
150
+ except Exception:
151
+ if len(temp_name) > 0:
152
+ temp_name += "_" + layer_infos.pop(0)
153
+ else:
154
+ temp_name = layer_infos.pop(0)
155
+
156
+ # get elements for this layer
157
+ weight_up = elems['lora_up.weight'].to(dtype)
158
+ weight_down = elems['lora_down.weight'].to(dtype)
159
+ alpha = elems['alpha']
160
+ if alpha:
161
+ alpha = alpha.item() / weight_up.shape[1]
162
+ else:
163
+ alpha = 1.0
164
+
165
+ # update weight
166
+ if len(weight_up.shape) == 4:
167
+ curr_layer.weight.data += multiplier * alpha * torch.mm(weight_up.squeeze(
168
+ 3).squeeze(2), weight_down.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3)
169
+ else:
170
+ curr_layer.weight.data += multiplier * \
171
+ alpha * torch.mm(weight_up, weight_down)
172
+ else:
173
+ for ckptpath in checkpoint_path:
174
+ state_dict = load_file(ckptpath, device=device)
175
+
176
+ updates = defaultdict(dict)
177
+ for key, value in state_dict.items():
178
+ # it is suggested to print out the key, it usually will be something like below
179
+ # "lora_te_text_model_encoder_layers_0_self_attn_k_proj.lora_down.weight"
180
+
181
+ layer, elem = key.split('.', 1)
182
+ updates[layer][elem] = value
183
+
184
+ # directly update weight in diffusers model
185
+ for layer, elems in updates.items():
186
+ if "text" in layer:
187
+ layer_infos = layer.split(
188
+ LORA_PREFIX_TEXT_ENCODER + "_")[-1].split("_")
189
+ curr_layer = pipeline.text_encoder
190
+ else:
191
+ layer_infos = layer.split(
192
+ LORA_PREFIX_UNET + "_")[-1].split("_")
193
+ curr_layer = pipeline.unet
194
+
195
+ # find the target layer
196
+ temp_name = layer_infos.pop(0)
197
+ while len(layer_infos) > -1:
198
+ try:
199
+ curr_layer = curr_layer.__getattr__(temp_name)
200
+ if len(layer_infos) > 0:
201
+ temp_name = layer_infos.pop(0)
202
+ elif len(layer_infos) == 0:
203
+ break
204
+ except Exception:
205
+ if len(temp_name) > 0:
206
+ temp_name += "_" + layer_infos.pop(0)
207
+ else:
208
+ temp_name = layer_infos.pop(0)
209
+
210
+ # get elements for this layer
211
+ weight_up = elems['lora_up.weight'].to(dtype)
212
+ weight_down = elems['lora_down.weight'].to(dtype)
213
+ alpha = elems['alpha']
214
+ if alpha:
215
+ alpha = alpha.item() / weight_up.shape[1]
216
+ else:
217
+ alpha = 1.0
218
+
219
+ # update weight
220
+ if len(weight_up.shape) == 4:
221
+ curr_layer.weight.data += multiplier * alpha * torch.mm(weight_up.squeeze(
222
+ 3).squeeze(2), weight_down.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3)
223
+ else:
224
+ curr_layer.weight.data += multiplier * \
225
+ alpha * torch.mm(weight_up, weight_down)
226
+ return pipeline
227
+
228
+
229
+ def make_inpaint_condition(image, image_mask):
230
+ # image = np.array(image.convert("RGB")).astype(np.float32) / 255.0
231
+ image = image / 255.0
232
+ print("img", image.max(), image.min(), image_mask.max(), image_mask.min())
233
+ # image_mask = np.array(image_mask.convert("L"))
234
+ assert image.shape[0:1] == image_mask.shape[0:
235
+ 1], "image and image_mask must have the same image size"
236
+ image[image_mask > 128] = -1.0 # set as masked pixel
237
+ image = np.expand_dims(image, 0).transpose(0, 3, 1, 2)
238
+ image = torch.from_numpy(image)
239
+ return image
240
+
241
+
242
+ def obtain_generation_model(base_model_path, lora_model_path, controlnet_path, generation_only=False, extra_inpaint=True):
243
+ if generation_only and extra_inpaint:
244
+ controlnet = ControlNetModel.from_pretrained(
245
+ controlnet_path, torch_dtype=torch.float16)
246
+ pipe = StableDiffusionControlNetPipeline.from_pretrained(
247
+ base_model_path, controlnet=controlnet, torch_dtype=torch.float16, safety_checker=None
248
+ )
249
+ elif extra_inpaint:
250
+ print("Warning: ControlNet based inpainting model only support SD1.5 for now.")
251
+ controlnet = [
252
+ ControlNetModel.from_pretrained(
253
+ controlnet_path, torch_dtype=torch.float16),
254
+ ControlNetModel.from_pretrained(
255
+ 'lllyasviel/control_v11p_sd15_inpaint', torch_dtype=torch.float16), # inpainting controlnet
256
+ ]
257
+ pipe = StableDiffusionControlNetInpaintPipeline.from_pretrained(
258
+ base_model_path, controlnet=controlnet, torch_dtype=torch.float16, safety_checker=None
259
+ )
260
+ else:
261
+ controlnet = ControlNetModel.from_pretrained(
262
+ controlnet_path, torch_dtype=torch.float16)
263
+ pipe = StableDiffusionControlNetInpaintPipeline.from_pretrained(
264
+ base_model_path, controlnet=controlnet, torch_dtype=torch.float16, safety_checker=None
265
+ )
266
+ if lora_model_path is not None:
267
+ pipe = load_lora_weights(
268
+ pipe, [lora_model_path], 1.0, 'cpu', torch.float32)
269
+ # speed up diffusion process with faster scheduler and memory optimization
270
+ pipe.scheduler = UniPCMultistepScheduler.from_config(
271
+ pipe.scheduler.config)
272
+ # remove following line if xformers is not installed
273
+ pipe.enable_xformers_memory_efficient_attention()
274
+
275
+ pipe.enable_model_cpu_offload()
276
+ return pipe
277
+
278
+
279
+ def show_anns(anns):
280
+ if len(anns) == 0:
281
+ return
282
+ sorted_anns = sorted(anns, key=(lambda x: x['area']), reverse=True)
283
+ full_img = None
284
+
285
+ # for ann in sorted_anns:
286
+ for i in range(len(sorted_anns)):
287
+ ann = anns[i]
288
+ m = ann['segmentation']
289
+ if full_img is None:
290
+ full_img = np.zeros((m.shape[0], m.shape[1], 3))
291
+ map = np.zeros((m.shape[0], m.shape[1]), dtype=np.uint16)
292
+ map[m != 0] = i + 1
293
+ color_mask = np.random.random((1, 3)).tolist()[0]
294
+ full_img[m != 0] = color_mask
295
+ full_img = full_img*255
296
+ # anno encoding from https://github.com/LUSSeg/ImageNet-S
297
+ res = np.zeros((map.shape[0], map.shape[1], 3))
298
+ res[:, :, 0] = map % 256
299
+ res[:, :, 1] = map // 256
300
+ res.astype(np.float32)
301
+ full_img = Image.fromarray(np.uint8(full_img))
302
+ return full_img, res
303
+
304
+
305
+ class EditAnythingLoraModel:
306
+ def __init__(self,
307
+ base_model_path='../chilloutmix_NiPrunedFp32Fix',
308
+ lora_model_path='../40806/mix4', use_blip=True,
309
+ blip_processor=None,
310
+ blip_model=None,
311
+ sam_generator=None,
312
+ controlmodel_name='LAION Pretrained(v0-4)-SD15',
313
+ # used when the base model is not an inpainting model.
314
+ extra_inpaint=True,
315
+ ):
316
+ self.device = device
317
+ self.use_blip = use_blip
318
+
319
+ # Diffusion init using diffusers.
320
+ self.default_controlnet_path = config_dict[controlmodel_name]
321
+ self.base_model_path = base_model_path
322
+ self.lora_model_path = lora_model_path
323
+ self.defalut_enable_all_generate = False
324
+ self.extra_inpaint = extra_inpaint
325
+ self.pipe = obtain_generation_model(
326
+ base_model_path, lora_model_path, self.default_controlnet_path, generation_only=False, extra_inpaint=extra_inpaint)
327
+
328
+ # Segment-Anything init.
329
+ if sam_generator is not None:
330
+ self.sam_generator = sam_generator
331
+ else:
332
+ self.sam_generator = init_sam_model()
333
+
334
+ # BLIP2 init.
335
+ if use_blip:
336
+ if blip_processor is not None:
337
+ self.blip_processor = blip_processor
338
+ else:
339
+ self.blip_processor = init_blip_processor()
340
+
341
+ if blip_model is not None:
342
+ self.blip_model = blip_model
343
+ else:
344
+ self.blip_model = init_blip_model()
345
+
346
+ def get_blip2_text(self, image):
347
+ inputs = self.blip_processor(image, return_tensors="pt").to(
348
+ self.device, torch.float16)
349
+ generated_ids = self.blip_model.generate(**inputs, max_new_tokens=50)
350
+ generated_text = self.blip_processor.batch_decode(
351
+ generated_ids, skip_special_tokens=True)[0].strip()
352
+ return generated_text
353
+
354
+ def get_sam_control(self, image):
355
+ masks = self.sam_generator.generate(image)
356
+ full_img, res = show_anns(masks)
357
+ return full_img, res
358
+
359
+ @torch.inference_mode()
360
+ def process(self, condition_model, source_image, enable_all_generate, mask_image, control_scale, enable_auto_prompt, prompt, a_prompt, n_prompt, num_samples, image_resolution, detect_resolution, ddim_steps, guess_mode, strength, scale, seed, eta):
361
+
362
+ input_image = source_image["image"]
363
+ if mask_image is None:
364
+ if enable_all_generate != self.defalut_enable_all_generate:
365
+ self.pipe = obtain_generation_model(
366
+ self.base_model_path, self.lora_model_path, config_dict[condition_model], enable_all_generate, self.extra_inpaint)
367
+ self.defalut_enable_all_generate = enable_all_generate
368
+ if enable_all_generate:
369
+ print("source_image",
370
+ source_image["mask"].shape, input_image.shape,)
371
+ mask_image = np.ones(
372
+ (input_image.shape[0], input_image.shape[1], 3))*255
373
+ else:
374
+ mask_image = source_image["mask"]
375
+ if self.default_controlnet_path != config_dict[condition_model]:
376
+ print("To Use:", config_dict[condition_model],
377
+ "Current:", self.default_controlnet_path)
378
+ print("Change condition model to:", config_dict[condition_model])
379
+ self.pipe = obtain_generation_model(
380
+ self.base_model_path, self.lora_model_path, config_dict[condition_model], enable_all_generate, self.extra_inpaint)
381
+ self.default_controlnet_path = config_dict[condition_model]
382
+ torch.cuda.empty_cache()
383
+
384
+ with torch.no_grad():
385
+ if self.use_blip and enable_auto_prompt:
386
+ print("Generating text:")
387
+ blip2_prompt = self.get_blip2_text(input_image)
388
+ print("Generated text:", blip2_prompt)
389
+ if len(prompt) > 0:
390
+ prompt = blip2_prompt + ',' + prompt
391
+ else:
392
+ prompt = blip2_prompt
393
+
394
+ input_image = HWC3(input_image)
395
+
396
+ img = resize_image(input_image, image_resolution)
397
+ H, W, C = img.shape
398
+
399
+ print("Generating SAM seg:")
400
+ # the default SAM model is trained with 1024 size.
401
+ full_segmask, detected_map = self.get_sam_control(
402
+ resize_image(input_image, detect_resolution))
403
+
404
+ detected_map = HWC3(detected_map.astype(np.uint8))
405
+ detected_map = cv2.resize(
406
+ detected_map, (W, H), interpolation=cv2.INTER_LINEAR)
407
+
408
+ control = torch.from_numpy(
409
+ detected_map.copy()).float().cuda()
410
+ control = torch.stack([control for _ in range(num_samples)], dim=0)
411
+ control = einops.rearrange(control, 'b h w c -> b c h w').clone()
412
+
413
+ mask_image = HWC3(mask_image.astype(np.uint8))
414
+ mask_image = cv2.resize(
415
+ mask_image, (W, H), interpolation=cv2.INTER_LINEAR)
416
+ if self.extra_inpaint:
417
+ inpaint_image = make_inpaint_condition(img, mask_image)
418
+ mask_image = Image.fromarray(mask_image)
419
+
420
+ if seed == -1:
421
+ seed = random.randint(0, 65535)
422
+ seed_everything(seed)
423
+ generator = torch.manual_seed(seed)
424
+ postive_prompt = prompt + ', ' + a_prompt
425
+ negative_prompt = n_prompt
426
+ prompt_embeds, negative_prompt_embeds = get_pipeline_embeds(
427
+ self.pipe, postive_prompt, negative_prompt, "cuda")
428
+ prompt_embeds = torch.cat([prompt_embeds] * num_samples, dim=0)
429
+ negative_prompt_embeds = torch.cat(
430
+ [negative_prompt_embeds] * num_samples, dim=0)
431
+ if enable_all_generate and self.extra_inpaint:
432
+ print(control.shape, control_scale)
433
+ self.pipe.safety_checker = lambda images, clip_input: (
434
+ images, False)
435
+ x_samples = self.pipe(
436
+ prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_prompt_embeds,
437
+ num_images_per_prompt=num_samples,
438
+ num_inference_steps=ddim_steps,
439
+ generator=generator,
440
+ height=H,
441
+ width=W,
442
+ image=control.type(torch.float16),
443
+ controlnet_conditioning_scale=float(control_scale),
444
+ ).images
445
+ elif self.extra_inpaint:
446
+ x_samples = self.pipe(
447
+ image=img,
448
+ mask_image=mask_image,
449
+ prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_prompt_embeds,
450
+ num_images_per_prompt=num_samples,
451
+ num_inference_steps=ddim_steps,
452
+ generator=generator,
453
+ controlnet_conditioning_image=[control.type(
454
+ torch.float16), inpaint_image.type(torch.float16)],
455
+ height=H,
456
+ width=W,
457
+ controlnet_conditioning_scale=(float(control_scale), 1.0),
458
+ ).images
459
+ else:
460
+ x_samples = self.pipe(
461
+ image=img,
462
+ mask_image=mask_image,
463
+ prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_prompt_embeds,
464
+ num_images_per_prompt=num_samples,
465
+ num_inference_steps=ddim_steps,
466
+ generator=generator,
467
+ controlnet_conditioning_image=control.type(torch.float16),
468
+ height=H,
469
+ width=W,
470
+ controlnet_conditioning_scale=float(control_scale),
471
+ ).images
472
+
473
+ results = [x_samples[i] for i in range(num_samples)]
474
+ return [full_segmask, mask_image] + results, prompt
475
+
476
+ def download_image(url):
477
+ response = requests.get(url)
478
+ return Image.open(BytesIO(response.content)).convert("RGB")
utils/stable_diffusion_controlnet_inpaint.py CHANGED
@@ -1,7 +1,7 @@
1
  # Inspired by: https://github.com/haofanwang/ControlNet-for-Diffusers/
2
  # From https://raw.githubusercontent.com/huggingface/diffusers/53377ef83c6446033f3ee506e3ef718db817b293/examples/community/stable_diffusion_controlnet_inpaint.py
3
  import inspect
4
- from typing import Any, Callable, Dict, List, Optional, Union
5
 
6
  import numpy as np
7
  import PIL.Image
@@ -11,6 +11,7 @@ from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
11
 
12
  from diffusers import AutoencoderKL, ControlNetModel, DiffusionPipeline, UNet2DConditionModel, logging
13
  from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput, StableDiffusionSafetyChecker
 
14
  from diffusers.schedulers import KarrasDiffusionSchedulers
15
  from diffusers.utils import (
16
  PIL_INTERPOLATION,
@@ -19,7 +20,7 @@ from diffusers.utils import (
19
  randn_tensor,
20
  replace_example_docstring,
21
  )
22
-
23
 
24
  logger = logging.get_logger(__name__) # pylint: disable=invalid-name
25
 
@@ -184,7 +185,7 @@ def prepare_mask_image(mask_image):
184
 
185
 
186
  def prepare_controlnet_conditioning_image(
187
- controlnet_conditioning_image, width, height, batch_size, num_images_per_prompt, device, dtype
188
  ):
189
  if not isinstance(controlnet_conditioning_image, torch.Tensor):
190
  if isinstance(controlnet_conditioning_image, PIL.Image.Image):
@@ -214,10 +215,13 @@ def prepare_controlnet_conditioning_image(
214
 
215
  controlnet_conditioning_image = controlnet_conditioning_image.to(device=device, dtype=dtype)
216
 
 
 
 
217
  return controlnet_conditioning_image
218
 
219
 
220
- class StableDiffusionControlNetInpaintPipeline(DiffusionPipeline):
221
  """
222
  Inspired by: https://github.com/haofanwang/ControlNet-for-Diffusers/
223
  """
@@ -230,7 +234,7 @@ class StableDiffusionControlNetInpaintPipeline(DiffusionPipeline):
230
  text_encoder: CLIPTextModel,
231
  tokenizer: CLIPTokenizer,
232
  unet: UNet2DConditionModel,
233
- controlnet: ControlNetModel,
234
  scheduler: KarrasDiffusionSchedulers,
235
  safety_checker: StableDiffusionSafetyChecker,
236
  feature_extractor: CLIPImageProcessor,
@@ -253,7 +257,8 @@ class StableDiffusionControlNetInpaintPipeline(DiffusionPipeline):
253
  "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
254
  " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
255
  )
256
-
 
257
  self.register_modules(
258
  vae=vae,
259
  text_encoder=text_encoder,
@@ -522,6 +527,42 @@ class StableDiffusionControlNetInpaintPipeline(DiffusionPipeline):
522
  extra_step_kwargs["generator"] = generator
523
  return extra_step_kwargs
524
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
525
  def check_inputs(
526
  self,
527
  prompt,
@@ -534,6 +575,7 @@ class StableDiffusionControlNetInpaintPipeline(DiffusionPipeline):
534
  negative_prompt=None,
535
  prompt_embeds=None,
536
  negative_prompt_embeds=None,
 
537
  ):
538
  if height % 8 != 0 or width % 8 != 0:
539
  raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
@@ -572,45 +614,35 @@ class StableDiffusionControlNetInpaintPipeline(DiffusionPipeline):
572
  f" {negative_prompt_embeds.shape}."
573
  )
574
 
575
- controlnet_cond_image_is_pil = isinstance(controlnet_conditioning_image, PIL.Image.Image)
576
- controlnet_cond_image_is_tensor = isinstance(controlnet_conditioning_image, torch.Tensor)
577
- controlnet_cond_image_is_pil_list = isinstance(controlnet_conditioning_image, list) and isinstance(
578
- controlnet_conditioning_image[0], PIL.Image.Image
579
- )
580
- controlnet_cond_image_is_tensor_list = isinstance(controlnet_conditioning_image, list) and isinstance(
581
- controlnet_conditioning_image[0], torch.Tensor
582
- )
583
-
584
- if (
585
- not controlnet_cond_image_is_pil
586
- and not controlnet_cond_image_is_tensor
587
- and not controlnet_cond_image_is_pil_list
588
- and not controlnet_cond_image_is_tensor_list
589
- ):
590
- raise TypeError(
591
- "image must be passed and be one of PIL image, torch tensor, list of PIL images, or list of torch tensors"
592
- )
593
-
594
- if controlnet_cond_image_is_pil:
595
- controlnet_cond_image_batch_size = 1
596
- elif controlnet_cond_image_is_tensor:
597
- controlnet_cond_image_batch_size = controlnet_conditioning_image.shape[0]
598
- elif controlnet_cond_image_is_pil_list:
599
- controlnet_cond_image_batch_size = len(controlnet_conditioning_image)
600
- elif controlnet_cond_image_is_tensor_list:
601
- controlnet_cond_image_batch_size = len(controlnet_conditioning_image)
602
-
603
- if prompt is not None and isinstance(prompt, str):
604
- prompt_batch_size = 1
605
- elif prompt is not None and isinstance(prompt, list):
606
- prompt_batch_size = len(prompt)
607
- elif prompt_embeds is not None:
608
- prompt_batch_size = prompt_embeds.shape[0]
609
-
610
- if controlnet_cond_image_batch_size != 1 and controlnet_cond_image_batch_size != prompt_batch_size:
611
- raise ValueError(
612
- f"If image batch size is not 1, image batch size must be same as prompt batch size. image batch size: {controlnet_cond_image_batch_size}, prompt batch size: {prompt_batch_size}"
613
- )
614
 
615
  if isinstance(image, torch.Tensor) and not isinstance(mask_image, torch.Tensor):
616
  raise TypeError("if `image` is a tensor, `mask_image` must also be a tensor")
@@ -630,6 +662,8 @@ class StableDiffusionControlNetInpaintPipeline(DiffusionPipeline):
630
  image_channels, image_height, image_width = image.shape
631
  elif image.ndim == 4:
632
  image_batch_size, image_channels, image_height, image_width = image.shape
 
 
633
 
634
  if mask_image.ndim == 2:
635
  mask_image_batch_size = 1
@@ -664,8 +698,11 @@ class StableDiffusionControlNetInpaintPipeline(DiffusionPipeline):
664
 
665
  single_image_latent_channels = self.vae.config.latent_channels
666
 
667
- total_latent_channels = single_image_latent_channels * 2 + mask_image_channels
668
-
 
 
 
669
  if total_latent_channels != self.unet.config.in_channels:
670
  raise ValueError(
671
  f"The config of `pipeline.unet` expects {self.unet.config.in_channels} but received"
@@ -797,7 +834,7 @@ class StableDiffusionControlNetInpaintPipeline(DiffusionPipeline):
797
  callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
798
  callback_steps: int = 1,
799
  cross_attention_kwargs: Optional[Dict[str, Any]] = None,
800
- controlnet_conditioning_scale: float = 1.0,
801
  ):
802
  r"""
803
  Function invoked when calling the pipeline for generation.
@@ -897,6 +934,7 @@ class StableDiffusionControlNetInpaintPipeline(DiffusionPipeline):
897
  negative_prompt,
898
  prompt_embeds,
899
  negative_prompt_embeds,
 
900
  )
901
 
902
  # 2. Define call parameters
@@ -913,6 +951,9 @@ class StableDiffusionControlNetInpaintPipeline(DiffusionPipeline):
913
  # corresponds to doing no classifier free guidance.
914
  do_classifier_free_guidance = guidance_scale > 1.0
915
 
 
 
 
916
  # 3. Encode input prompt
917
  prompt_embeds = self._encode_prompt(
918
  prompt,
@@ -929,15 +970,37 @@ class StableDiffusionControlNetInpaintPipeline(DiffusionPipeline):
929
 
930
  mask_image = prepare_mask_image(mask_image)
931
 
932
- controlnet_conditioning_image = prepare_controlnet_conditioning_image(
933
- controlnet_conditioning_image,
934
- width,
935
- height,
936
- batch_size * num_images_per_prompt,
937
- num_images_per_prompt,
938
- device,
939
- self.controlnet.dtype,
940
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
941
 
942
  masked_image = image * (mask_image < 0.5)
943
 
@@ -958,29 +1021,45 @@ class StableDiffusionControlNetInpaintPipeline(DiffusionPipeline):
958
  latents,
959
  )
960
 
961
- mask_image_latents = self.prepare_mask_latents(
962
- mask_image,
963
- batch_size * num_images_per_prompt,
964
- height,
965
- width,
966
- prompt_embeds.dtype,
967
- device,
968
- do_classifier_free_guidance,
969
- )
970
 
971
- masked_image_latents = self.prepare_masked_image_latents(
972
- masked_image,
973
- batch_size * num_images_per_prompt,
974
- height,
975
- width,
976
- prompt_embeds.dtype,
977
- device,
978
- generator,
979
- do_classifier_free_guidance,
980
- )
981
 
982
- if do_classifier_free_guidance:
983
- controlnet_conditioning_image = torch.cat([controlnet_conditioning_image] * 2)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
984
 
985
  # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
986
  extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
@@ -997,25 +1076,22 @@ class StableDiffusionControlNetInpaintPipeline(DiffusionPipeline):
997
  non_inpainting_latent_model_input = self.scheduler.scale_model_input(
998
  non_inpainting_latent_model_input, t
999
  )
1000
-
1001
- inpainting_latent_model_input = torch.cat(
1002
- [non_inpainting_latent_model_input, mask_image_latents, masked_image_latents], dim=1
1003
- )
 
 
1004
 
1005
  down_block_res_samples, mid_block_res_sample = self.controlnet(
1006
  non_inpainting_latent_model_input,
1007
  t,
1008
  encoder_hidden_states=prompt_embeds,
1009
  controlnet_cond=controlnet_conditioning_image,
 
1010
  return_dict=False,
1011
  )
1012
 
1013
- down_block_res_samples = [
1014
- down_block_res_sample * controlnet_conditioning_scale
1015
- for down_block_res_sample in down_block_res_samples
1016
- ]
1017
- mid_block_res_sample *= controlnet_conditioning_scale
1018
-
1019
  # predict the noise residual
1020
  noise_pred = self.unet(
1021
  inpainting_latent_model_input,
@@ -1039,6 +1115,14 @@ class StableDiffusionControlNetInpaintPipeline(DiffusionPipeline):
1039
  progress_bar.update()
1040
  if callback is not None and i % callback_steps == 0:
1041
  callback(i, t, latents)
 
 
 
 
 
 
 
 
1042
 
1043
  # If we do sequential model offloading, let's offload unet and controlnet
1044
  # manually for max memory savings
 
1
  # Inspired by: https://github.com/haofanwang/ControlNet-for-Diffusers/
2
  # From https://raw.githubusercontent.com/huggingface/diffusers/53377ef83c6446033f3ee506e3ef718db817b293/examples/community/stable_diffusion_controlnet_inpaint.py
3
  import inspect
4
+ from typing import Any, Callable, Dict, List, Optional, Union, Tuple
5
 
6
  import numpy as np
7
  import PIL.Image
 
11
 
12
  from diffusers import AutoencoderKL, ControlNetModel, DiffusionPipeline, UNet2DConditionModel, logging
13
  from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput, StableDiffusionSafetyChecker
14
+ from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_controlnet import MultiControlNetModel
15
  from diffusers.schedulers import KarrasDiffusionSchedulers
16
  from diffusers.utils import (
17
  PIL_INTERPOLATION,
 
20
  randn_tensor,
21
  replace_example_docstring,
22
  )
23
+ from diffusers.loaders import LoraLoaderMixin
24
 
25
  logger = logging.get_logger(__name__) # pylint: disable=invalid-name
26
 
 
185
 
186
 
187
  def prepare_controlnet_conditioning_image(
188
+ controlnet_conditioning_image, width, height, batch_size, num_images_per_prompt, device, dtype, do_classifier_free_guidance,
189
  ):
190
  if not isinstance(controlnet_conditioning_image, torch.Tensor):
191
  if isinstance(controlnet_conditioning_image, PIL.Image.Image):
 
215
 
216
  controlnet_conditioning_image = controlnet_conditioning_image.to(device=device, dtype=dtype)
217
 
218
+ if do_classifier_free_guidance:
219
+ controlnet_conditioning_image = torch.cat([controlnet_conditioning_image] * 2)
220
+
221
  return controlnet_conditioning_image
222
 
223
 
224
+ class StableDiffusionControlNetInpaintPipeline(DiffusionPipeline, LoraLoaderMixin):
225
  """
226
  Inspired by: https://github.com/haofanwang/ControlNet-for-Diffusers/
227
  """
 
234
  text_encoder: CLIPTextModel,
235
  tokenizer: CLIPTokenizer,
236
  unet: UNet2DConditionModel,
237
+ controlnet: Union[ControlNetModel, List[ControlNetModel], Tuple[ControlNetModel], MultiControlNetModel],
238
  scheduler: KarrasDiffusionSchedulers,
239
  safety_checker: StableDiffusionSafetyChecker,
240
  feature_extractor: CLIPImageProcessor,
 
257
  "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
258
  " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
259
  )
260
+ if isinstance(controlnet, (list, tuple)):
261
+ controlnet = MultiControlNetModel(controlnet)
262
  self.register_modules(
263
  vae=vae,
264
  text_encoder=text_encoder,
 
527
  extra_step_kwargs["generator"] = generator
528
  return extra_step_kwargs
529
 
530
+ def check_controlnet_conditioning_image(self, image, prompt, prompt_embeds):
531
+ image_is_pil = isinstance(image, PIL.Image.Image)
532
+ image_is_tensor = isinstance(image, torch.Tensor)
533
+ image_is_pil_list = isinstance(image, list) and isinstance(image[0], PIL.Image.Image)
534
+ image_is_tensor_list = isinstance(image, list) and isinstance(image[0], torch.Tensor)
535
+
536
+ if not image_is_pil and not image_is_tensor and not image_is_pil_list and not image_is_tensor_list:
537
+ raise TypeError(
538
+ "image must be passed and be one of PIL image, torch tensor, list of PIL images, or list of torch tensors"
539
+ )
540
+
541
+ if image_is_pil:
542
+ image_batch_size = 1
543
+ elif image_is_tensor:
544
+ image_batch_size = image.shape[0]
545
+ elif image_is_pil_list:
546
+ image_batch_size = len(image)
547
+ elif image_is_tensor_list:
548
+ image_batch_size = len(image)
549
+ else:
550
+ raise ValueError("controlnet condition image is not valid")
551
+
552
+ if prompt is not None and isinstance(prompt, str):
553
+ prompt_batch_size = 1
554
+ elif prompt is not None and isinstance(prompt, list):
555
+ prompt_batch_size = len(prompt)
556
+ elif prompt_embeds is not None:
557
+ prompt_batch_size = prompt_embeds.shape[0]
558
+ else:
559
+ raise ValueError("prompt or prompt_embeds are not valid")
560
+
561
+ if image_batch_size != 1 and image_batch_size != prompt_batch_size:
562
+ raise ValueError(
563
+ f"If image batch size is not 1, image batch size must be same as prompt batch size. image batch size: {image_batch_size}, prompt batch size: {prompt_batch_size}"
564
+ )
565
+
566
  def check_inputs(
567
  self,
568
  prompt,
 
575
  negative_prompt=None,
576
  prompt_embeds=None,
577
  negative_prompt_embeds=None,
578
+ controlnet_conditioning_scale=None,
579
  ):
580
  if height % 8 != 0 or width % 8 != 0:
581
  raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
 
614
  f" {negative_prompt_embeds.shape}."
615
  )
616
 
617
+ # check controlnet condition image
618
+ if isinstance(self.controlnet, ControlNetModel):
619
+ self.check_controlnet_conditioning_image(controlnet_conditioning_image, prompt, prompt_embeds)
620
+ elif isinstance(self.controlnet, MultiControlNetModel):
621
+ if not isinstance(controlnet_conditioning_image, list):
622
+ raise TypeError("For multiple controlnets: `image` must be type `list`")
623
+ if len(controlnet_conditioning_image) != len(self.controlnet.nets):
624
+ raise ValueError(
625
+ "For multiple controlnets: `image` must have the same length as the number of controlnets."
626
+ )
627
+ for image_ in controlnet_conditioning_image:
628
+ self.check_controlnet_conditioning_image(image_, prompt, prompt_embeds)
629
+ else:
630
+ assert False
631
+
632
+ # Check `controlnet_conditioning_scale`
633
+ if isinstance(self.controlnet, ControlNetModel):
634
+ if not isinstance(controlnet_conditioning_scale, float):
635
+ raise TypeError("For single controlnet: `controlnet_conditioning_scale` must be type `float`.")
636
+ elif isinstance(self.controlnet, MultiControlNetModel):
637
+ if isinstance(controlnet_conditioning_scale, list) and len(controlnet_conditioning_scale) != len(
638
+ self.controlnet.nets
639
+ ):
640
+ raise ValueError(
641
+ "For multiple controlnets: When `controlnet_conditioning_scale` is specified as `list`, it must have"
642
+ " the same length as the number of controlnets"
643
+ )
644
+ else:
645
+ assert False
 
 
 
 
 
 
 
 
 
 
646
 
647
  if isinstance(image, torch.Tensor) and not isinstance(mask_image, torch.Tensor):
648
  raise TypeError("if `image` is a tensor, `mask_image` must also be a tensor")
 
662
  image_channels, image_height, image_width = image.shape
663
  elif image.ndim == 4:
664
  image_batch_size, image_channels, image_height, image_width = image.shape
665
+ else:
666
+ assert False
667
 
668
  if mask_image.ndim == 2:
669
  mask_image_batch_size = 1
 
698
 
699
  single_image_latent_channels = self.vae.config.latent_channels
700
 
701
+ if self.unet.config.in_channels==4:
702
+ total_latent_channels = single_image_latent_channels # support base model without inpainting ability.
703
+ else:
704
+ total_latent_channels = single_image_latent_channels * 2 + mask_image_channels
705
+
706
  if total_latent_channels != self.unet.config.in_channels:
707
  raise ValueError(
708
  f"The config of `pipeline.unet` expects {self.unet.config.in_channels} but received"
 
834
  callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
835
  callback_steps: int = 1,
836
  cross_attention_kwargs: Optional[Dict[str, Any]] = None,
837
+ controlnet_conditioning_scale: Union[float, List[float]] = 1.0,
838
  ):
839
  r"""
840
  Function invoked when calling the pipeline for generation.
 
934
  negative_prompt,
935
  prompt_embeds,
936
  negative_prompt_embeds,
937
+ controlnet_conditioning_scale,
938
  )
939
 
940
  # 2. Define call parameters
 
951
  # corresponds to doing no classifier free guidance.
952
  do_classifier_free_guidance = guidance_scale > 1.0
953
 
954
+ if isinstance(self.controlnet, MultiControlNetModel) and isinstance(controlnet_conditioning_scale, float):
955
+ controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(self.controlnet.nets)
956
+
957
  # 3. Encode input prompt
958
  prompt_embeds = self._encode_prompt(
959
  prompt,
 
970
 
971
  mask_image = prepare_mask_image(mask_image)
972
 
973
+ # condition image(s)
974
+ if isinstance(self.controlnet, ControlNetModel):
975
+ controlnet_conditioning_image = prepare_controlnet_conditioning_image(
976
+ controlnet_conditioning_image=controlnet_conditioning_image,
977
+ width=width,
978
+ height=height,
979
+ batch_size=batch_size * num_images_per_prompt,
980
+ num_images_per_prompt=num_images_per_prompt,
981
+ device=device,
982
+ dtype=self.controlnet.dtype,
983
+ do_classifier_free_guidance=do_classifier_free_guidance,
984
+ )
985
+ elif isinstance(self.controlnet, MultiControlNetModel):
986
+ controlnet_conditioning_images = []
987
+
988
+ for image_ in controlnet_conditioning_image:
989
+ image_ = prepare_controlnet_conditioning_image(
990
+ controlnet_conditioning_image=image_,
991
+ width=width,
992
+ height=height,
993
+ batch_size=batch_size * num_images_per_prompt,
994
+ num_images_per_prompt=num_images_per_prompt,
995
+ device=device,
996
+ dtype=self.controlnet.dtype,
997
+ do_classifier_free_guidance=do_classifier_free_guidance,
998
+ )
999
+ controlnet_conditioning_images.append(image_)
1000
+
1001
+ controlnet_conditioning_image = controlnet_conditioning_images
1002
+ else:
1003
+ assert False
1004
 
1005
  masked_image = image * (mask_image < 0.5)
1006
 
 
1021
  latents,
1022
  )
1023
 
1024
+ noise = latents
 
 
 
 
 
 
 
 
1025
 
1026
+ if self.unet.config.in_channels!=4:
1027
+ mask_image_latents = self.prepare_mask_latents(
1028
+ mask_image,
1029
+ batch_size * num_images_per_prompt,
1030
+ height,
1031
+ width,
1032
+ prompt_embeds.dtype,
1033
+ device,
1034
+ do_classifier_free_guidance,
1035
+ )
1036
 
1037
+ masked_image_latents = self.prepare_masked_image_latents(
1038
+ masked_image,
1039
+ batch_size * num_images_per_prompt,
1040
+ height,
1041
+ width,
1042
+ prompt_embeds.dtype,
1043
+ device,
1044
+ generator,
1045
+ do_classifier_free_guidance,
1046
+ )
1047
+ if self.unet.config.in_channels==4:
1048
+ init_masked_image_latents, _ = self.prepare_masked_image_latents(
1049
+ image,
1050
+ batch_size * num_images_per_prompt,
1051
+ height,
1052
+ width,
1053
+ prompt_embeds.dtype,
1054
+ device,
1055
+ generator,
1056
+ do_classifier_free_guidance,
1057
+ ).chunk(2)
1058
+ print(type(mask_image), mask_image.shape)
1059
+ _, _, w, h = mask_image.shape
1060
+ mask_image = torch.nn.functional.interpolate(mask_image, ((w // 8, h // 8)), mode='nearest')
1061
+ mask_image = mask_image.to(latents.device).type_as(latents)
1062
+ mask_image = 1 - mask_image
1063
 
1064
  # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
1065
  extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
 
1076
  non_inpainting_latent_model_input = self.scheduler.scale_model_input(
1077
  non_inpainting_latent_model_input, t
1078
  )
1079
+ if self.unet.config.in_channels!=4:
1080
+ inpainting_latent_model_input = torch.cat(
1081
+ [non_inpainting_latent_model_input, mask_image_latents, masked_image_latents], dim=1
1082
+ )
1083
+ else:
1084
+ inpainting_latent_model_input = non_inpainting_latent_model_input
1085
 
1086
  down_block_res_samples, mid_block_res_sample = self.controlnet(
1087
  non_inpainting_latent_model_input,
1088
  t,
1089
  encoder_hidden_states=prompt_embeds,
1090
  controlnet_cond=controlnet_conditioning_image,
1091
+ conditioning_scale=controlnet_conditioning_scale,
1092
  return_dict=False,
1093
  )
1094
 
 
 
 
 
 
 
1095
  # predict the noise residual
1096
  noise_pred = self.unet(
1097
  inpainting_latent_model_input,
 
1115
  progress_bar.update()
1116
  if callback is not None and i % callback_steps == 0:
1117
  callback(i, t, latents)
1118
+ # if self.unet.config.in_channels==4:
1119
+ # # masking for non-inpainting models
1120
+ # init_latents_proper = self.scheduler.add_noise(init_masked_image_latents, noise, t)
1121
+ # latents = (init_latents_proper * mask_image) + (latents * (1 - mask_image))
1122
+
1123
+ if self.unet.config.in_channels==4:
1124
+ # fill the unmasked part with original image
1125
+ latents = (init_masked_image_latents * mask_image) + (latents * (1 - mask_image))
1126
 
1127
  # If we do sequential model offloading, let's offload unet and controlnet
1128
  # manually for max memory savings