Spaces:
Runtime error
Runtime error
update new demo
Browse files- app.py +41 -6
- sam2edit.py +79 -315
- sam2edit_beauty.py +95 -0
- sam2edit_handsome.py +90 -0
- sam2edit_lora.py +478 -0
- utils/stable_diffusion_controlnet_inpaint.py +172 -88
app.py
CHANGED
@@ -1,9 +1,12 @@
|
|
1 |
import gradio as gr
|
2 |
-
|
3 |
|
4 |
from sam2edit import create_demo as create_demo_edit_anything
|
5 |
-
|
6 |
-
|
|
|
|
|
|
|
7 |
|
8 |
DESCRIPTION = f'''# [Edit Anything](https://github.com/sail-sg/EditAnything)
|
9 |
**Edit anything and keep the layout by segmenting anything in the image.**
|
@@ -12,13 +15,45 @@ SHARED_UI_WARNING = f'''### [NOTE] Inference may be slow in this shared UI.
|
|
12 |
You can duplicate and use it with a paid private GPU.
|
13 |
<a class="duplicate-button" style="display:inline-block" target="_blank" href="https://huggingface.co/spaces/jyseo/3DFuse?duplicate=true"><img style="margin-top:0;margin-bottom:0" src="https://huggingface.co/datasets/huggingface/badges/raw/main/duplicate-this-space-xl-dark.svg" alt="Duplicate Space"></a>
|
14 |
'''
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
15 |
with gr.Blocks() as demo:
|
16 |
gr.Markdown(DESCRIPTION)
|
17 |
-
gr.Markdown(SHARED_UI_WARNING)
|
18 |
with gr.Tabs():
|
19 |
-
with gr.TabItem('Edit Anything'):
|
20 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
21 |
# with gr.TabItem('Generate Anything'):
|
22 |
# create_demo_generate_anything()
|
|
|
|
|
23 |
|
24 |
demo.queue(api_open=False).launch()
|
|
|
1 |
import gradio as gr
|
2 |
+
import os
|
3 |
|
4 |
from sam2edit import create_demo as create_demo_edit_anything
|
5 |
+
from sam2image import create_demo as create_demo_generate_anything
|
6 |
+
from sam2edit_beauty import create_demo as create_demo_beauty
|
7 |
+
from sam2edit_handsome import create_demo as create_demo_handsome
|
8 |
+
from sam2edit_lora import EditAnythingLoraModel, init_sam_model, init_blip_processor, init_blip_model
|
9 |
+
from huggingface_hub import hf_hub_download, snapshot_download
|
10 |
|
11 |
DESCRIPTION = f'''# [Edit Anything](https://github.com/sail-sg/EditAnything)
|
12 |
**Edit anything and keep the layout by segmenting anything in the image.**
|
|
|
15 |
You can duplicate and use it with a paid private GPU.
|
16 |
<a class="duplicate-button" style="display:inline-block" target="_blank" href="https://huggingface.co/spaces/jyseo/3DFuse?duplicate=true"><img style="margin-top:0;margin-bottom:0" src="https://huggingface.co/datasets/huggingface/badges/raw/main/duplicate-this-space-xl-dark.svg" alt="Duplicate Space"></a>
|
17 |
'''
|
18 |
+
|
19 |
+
#
|
20 |
+
sam_generator = init_sam_model()
|
21 |
+
blip_processor = init_blip_processor()
|
22 |
+
blip_model = init_blip_model()
|
23 |
+
|
24 |
+
sd_models_path = snapshot_download("shgao/sdmodels")
|
25 |
+
|
26 |
with gr.Blocks() as demo:
|
27 |
gr.Markdown(DESCRIPTION)
|
|
|
28 |
with gr.Tabs():
|
29 |
+
with gr.TabItem('πEdit Anything'):
|
30 |
+
model = EditAnythingLoraModel(base_model_path="stabilityai/stable-diffusion-2-inpainting",
|
31 |
+
controlmodel_name='LAION Pretrained(v0-4)-SD21',
|
32 |
+
lora_model_path=None, use_blip=True, extra_inpaint=False,
|
33 |
+
sam_generator=sam_generator,
|
34 |
+
blip_processor=blip_processor,
|
35 |
+
blip_model=blip_model)
|
36 |
+
create_demo_edit_anything(model.process)
|
37 |
+
with gr.TabItem(' π©βπ¦°Beauty Edit/Generation'):
|
38 |
+
lora_model_path = hf_hub_download(
|
39 |
+
"mlida/Cute_girl_mix4", "cuteGirlMix4_v10.safetensors")
|
40 |
+
model = EditAnythingLoraModel(base_model_path=os.path.join(sd_models_path, "chilloutmix_NiPrunedFp32Fix"),
|
41 |
+
lora_model_path=lora_model_path, use_blip=True, extra_inpaint=True,
|
42 |
+
sam_generator=sam_generator,
|
43 |
+
blip_processor=blip_processor,
|
44 |
+
blip_model=blip_model
|
45 |
+
)
|
46 |
+
create_demo_beauty(model.process)
|
47 |
+
with gr.TabItem(' π¨βπΎHandsome Edit/Generation'):
|
48 |
+
model = EditAnythingLoraModel(base_model_path=os.path.join(sd_models_path, "Realistic_Vision_V2.0"),
|
49 |
+
lora_model_path=None, use_blip=True, extra_inpaint=True,
|
50 |
+
sam_generator=sam_generator,
|
51 |
+
blip_processor=blip_processor,
|
52 |
+
blip_model=blip_model)
|
53 |
+
create_demo_handsome(model.process)
|
54 |
# with gr.TabItem('Generate Anything'):
|
55 |
# create_demo_generate_anything()
|
56 |
+
with gr.Tabs():
|
57 |
+
gr.Markdown(SHARED_UI_WARNING)
|
58 |
|
59 |
demo.queue(api_open=False).launch()
|
sam2edit.py
CHANGED
@@ -1,321 +1,85 @@
|
|
1 |
# Edit Anything trained with Stable Diffusion + ControlNet + SAM + BLIP2
|
2 |
-
from torchvision.utils import save_image
|
3 |
-
from PIL import Image
|
4 |
-
from pytorch_lightning import seed_everything
|
5 |
-
import subprocess
|
6 |
-
from collections import OrderedDict
|
7 |
-
|
8 |
-
import cv2
|
9 |
-
import einops
|
10 |
import gradio as gr
|
11 |
-
|
12 |
-
import
|
13 |
-
|
14 |
-
|
15 |
-
|
16 |
-
|
17 |
-
|
18 |
-
|
19 |
-
|
20 |
-
|
21 |
-
|
22 |
-
|
23 |
-
|
24 |
-
|
25 |
-
|
26 |
-
|
27 |
-
|
28 |
-
|
29 |
-
|
30 |
-
|
31 |
-
|
32 |
-
|
33 |
-
|
34 |
-
|
35 |
-
|
36 |
-
|
37 |
-
|
38 |
-
|
39 |
-
|
40 |
-
"
|
41 |
-
|
42 |
-
|
43 |
-
|
44 |
-
|
45 |
-
|
46 |
-
|
47 |
-
|
48 |
-
|
49 |
-
|
50 |
-
# remove following line if xformers is not installed
|
51 |
-
pipe.enable_xformers_memory_efficient_attention()
|
52 |
-
|
53 |
-
pipe.enable_model_cpu_offload() # disable for now because of unknow bug in accelerate
|
54 |
-
# pipe.to(device)
|
55 |
-
return pipe
|
56 |
-
global default_controlnet_path
|
57 |
-
global pipe
|
58 |
-
default_controlnet_path = config_dict['LAION Pretrained(v0-3): Good Face']
|
59 |
-
pipe = obtain_generation_model(default_controlnet_path)
|
60 |
-
|
61 |
-
# Segment-Anything init.
|
62 |
-
# pip install git+https://github.com/facebookresearch/segment-anything.git
|
63 |
-
|
64 |
-
try:
|
65 |
-
from segment_anything import sam_model_registry, SamAutomaticMaskGenerator
|
66 |
-
except ImportError:
|
67 |
-
print('segment_anything not installed')
|
68 |
-
result = subprocess.run(['pip', 'install', 'git+https://github.com/facebookresearch/segment-anything.git'], check=True)
|
69 |
-
print(f'Install segment_anything {result}')
|
70 |
-
from segment_anything import sam_model_registry, SamAutomaticMaskGenerator
|
71 |
-
if not os.path.exists('./models/sam_vit_h_4b8939.pth'):
|
72 |
-
result = subprocess.run(['wget', 'https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth', '-P', 'models'], check=True)
|
73 |
-
print(f'Download sam_vit_h_4b8939.pth {result}')
|
74 |
-
sam_checkpoint = "models/sam_vit_h_4b8939.pth"
|
75 |
-
model_type = "default"
|
76 |
-
sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
|
77 |
-
sam.to(device=device)
|
78 |
-
mask_generator = SamAutomaticMaskGenerator(sam)
|
79 |
-
|
80 |
-
|
81 |
-
# BLIP2 init.
|
82 |
-
if use_blip:
|
83 |
-
# need the latest transformers
|
84 |
-
# pip install git+https://github.com/huggingface/transformers.git
|
85 |
-
from transformers import AutoProcessor, Blip2ForConditionalGeneration
|
86 |
-
|
87 |
-
processor = AutoProcessor.from_pretrained("Salesforce/blip2-opt-2.7b")
|
88 |
-
blip_model = Blip2ForConditionalGeneration.from_pretrained(
|
89 |
-
"Salesforce/blip2-opt-2.7b", torch_dtype=torch.float16, device_map="auto")
|
90 |
-
|
91 |
-
|
92 |
-
def get_blip2_text(image):
|
93 |
-
inputs = processor(image, return_tensors="pt").to(device, torch.float16)
|
94 |
-
generated_ids = blip_model.generate(**inputs, max_new_tokens=50)
|
95 |
-
generated_text = processor.batch_decode(
|
96 |
-
generated_ids, skip_special_tokens=True)[0].strip()
|
97 |
-
return generated_text
|
98 |
-
|
99 |
-
|
100 |
-
def show_anns(anns):
|
101 |
-
if len(anns) == 0:
|
102 |
-
return
|
103 |
-
sorted_anns = sorted(anns, key=(lambda x: x['area']), reverse=True)
|
104 |
-
full_img = None
|
105 |
-
|
106 |
-
# for ann in sorted_anns:
|
107 |
-
for i in range(len(sorted_anns)):
|
108 |
-
ann = anns[i]
|
109 |
-
m = ann['segmentation']
|
110 |
-
if full_img is None:
|
111 |
-
full_img = np.zeros((m.shape[0], m.shape[1], 3))
|
112 |
-
map = np.zeros((m.shape[0], m.shape[1]), dtype=np.uint16)
|
113 |
-
map[m != 0] = i + 1
|
114 |
-
color_mask = np.random.random((1, 3)).tolist()[0]
|
115 |
-
full_img[m != 0] = color_mask
|
116 |
-
full_img = full_img*255
|
117 |
-
# anno encoding from https://github.com/LUSSeg/ImageNet-S
|
118 |
-
res = np.zeros((map.shape[0], map.shape[1], 3))
|
119 |
-
res[:, :, 0] = map % 256
|
120 |
-
res[:, :, 1] = map // 256
|
121 |
-
res.astype(np.float32)
|
122 |
-
full_img = Image.fromarray(np.uint8(full_img))
|
123 |
-
return full_img, res
|
124 |
-
|
125 |
-
|
126 |
-
def get_sam_control(image):
|
127 |
-
masks = mask_generator.generate(image)
|
128 |
-
full_img, res = show_anns(masks)
|
129 |
-
return full_img, res
|
130 |
-
|
131 |
-
|
132 |
-
def process(condition_model, source_image, enable_all_generate, mask_image, control_scale, enable_auto_prompt, prompt, a_prompt, n_prompt, num_samples, image_resolution, detect_resolution, ddim_steps, guess_mode, strength, scale, seed, eta):
|
133 |
-
|
134 |
-
input_image = source_image["image"]
|
135 |
-
if mask_image is None:
|
136 |
-
if enable_all_generate:
|
137 |
-
print("source_image", source_image["mask"].shape, input_image.shape,)
|
138 |
-
print(source_image["mask"].max())
|
139 |
-
mask_image = np.ones((input_image.shape[0], input_image.shape[1], 3))*255
|
140 |
-
else:
|
141 |
-
mask_image = source_image["mask"]
|
142 |
-
global default_controlnet_path
|
143 |
-
print("To Use:", config_dict[condition_model], "Current:", default_controlnet_path)
|
144 |
-
if default_controlnet_path!=config_dict[condition_model]:
|
145 |
-
print("Change condition model to:", config_dict[condition_model])
|
146 |
-
global pipe
|
147 |
-
pipe = obtain_generation_model(config_dict[condition_model])
|
148 |
-
default_controlnet_path = config_dict[condition_model]
|
149 |
-
torch.cuda.empty_cache()
|
150 |
-
|
151 |
-
with torch.no_grad():
|
152 |
-
if use_blip and (enable_auto_prompt or len(prompt) == 0):
|
153 |
-
print("Generating text:")
|
154 |
-
blip2_prompt = get_blip2_text(input_image)
|
155 |
-
print("Generated text:", blip2_prompt)
|
156 |
-
if len(prompt)>0:
|
157 |
-
prompt = blip2_prompt + ',' + prompt
|
158 |
-
else:
|
159 |
-
prompt = blip2_prompt
|
160 |
-
print("All text:", prompt)
|
161 |
-
|
162 |
-
input_image = HWC3(input_image)
|
163 |
-
|
164 |
-
img = resize_image(input_image, image_resolution)
|
165 |
-
H, W, C = img.shape
|
166 |
-
|
167 |
-
print("Generating SAM seg:")
|
168 |
-
# the default SAM model is trained with 1024 size.
|
169 |
-
full_segmask, detected_map = get_sam_control(
|
170 |
-
resize_image(input_image, detect_resolution))
|
171 |
-
|
172 |
-
detected_map = HWC3(detected_map.astype(np.uint8))
|
173 |
-
detected_map = cv2.resize(
|
174 |
-
detected_map, (W, H), interpolation=cv2.INTER_LINEAR)
|
175 |
-
|
176 |
-
control = torch.from_numpy(
|
177 |
-
detected_map.copy()).float().cuda()
|
178 |
-
control = torch.stack([control for _ in range(num_samples)], dim=0)
|
179 |
-
control = einops.rearrange(control, 'b h w c -> b c h w').clone()
|
180 |
-
|
181 |
-
mask_image = HWC3(mask_image.astype(np.uint8))
|
182 |
-
mask_image = cv2.resize(
|
183 |
-
mask_image, (W, H), interpolation=cv2.INTER_LINEAR)
|
184 |
-
mask_image = Image.fromarray(mask_image)
|
185 |
-
|
186 |
-
|
187 |
-
if seed == -1:
|
188 |
-
seed = random.randint(0, 65535)
|
189 |
-
seed_everything(seed)
|
190 |
-
generator = torch.manual_seed(seed)
|
191 |
-
if condition_model=='SD Inpainting: Not keep position':
|
192 |
-
x_samples = pipe(
|
193 |
-
image=img,
|
194 |
-
mask_image=mask_image,
|
195 |
-
prompt=[prompt + ', ' + a_prompt] * num_samples,
|
196 |
-
negative_prompt=[n_prompt] * num_samples,
|
197 |
-
num_images_per_prompt=num_samples,
|
198 |
-
num_inference_steps=ddim_steps,
|
199 |
-
generator=generator,
|
200 |
-
height=H,
|
201 |
-
width=W,
|
202 |
-
).images
|
203 |
-
else:
|
204 |
-
x_samples = pipe(
|
205 |
-
image=img,
|
206 |
-
mask_image=mask_image,
|
207 |
-
prompt=[prompt + ', ' + a_prompt] * num_samples,
|
208 |
-
negative_prompt=[n_prompt] * num_samples,
|
209 |
-
num_images_per_prompt=num_samples,
|
210 |
-
num_inference_steps=ddim_steps,
|
211 |
-
generator=generator,
|
212 |
-
controlnet_conditioning_image=control.type(torch.float16),
|
213 |
-
height=H,
|
214 |
-
width=W,
|
215 |
-
controlnet_conditioning_scale=control_scale,
|
216 |
-
).images
|
217 |
-
|
218 |
-
|
219 |
-
results = [x_samples[i] for i in range(num_samples)]
|
220 |
-
return [full_segmask, mask_image] + results, prompt
|
221 |
-
|
222 |
-
|
223 |
-
def download_image(url):
|
224 |
-
response = requests.get(url)
|
225 |
-
return Image.open(BytesIO(response.content)).convert("RGB")
|
226 |
-
|
227 |
-
# disable gradio when not using GUI.
|
228 |
-
if not use_gradio:
|
229 |
-
# This part is not updated, it's just a example to use it without GUI.
|
230 |
-
image_path = "../data/samples/sa_223750.jpg"
|
231 |
-
mask_path = "../data/samples/sa_223750inpaint.png"
|
232 |
-
input_image = Image.open(image_path)
|
233 |
-
mask_image = Image.open(mask_path)
|
234 |
-
enable_auto_prompt = True
|
235 |
-
input_image = np.array(input_image, dtype=np.uint8)
|
236 |
-
mask_image = np.array(mask_image, dtype=np.uint8)
|
237 |
-
prompt = "esplendent sunset sky, red brick wall"
|
238 |
-
a_prompt = 'best quality, extremely detailed'
|
239 |
-
n_prompt = 'longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality'
|
240 |
-
num_samples = 3
|
241 |
-
image_resolution = 512
|
242 |
-
detect_resolution = 512
|
243 |
-
ddim_steps = 30
|
244 |
-
guess_mode = False
|
245 |
-
strength = 1.0
|
246 |
-
scale = 9.0
|
247 |
-
seed = -1
|
248 |
-
eta = 0.0
|
249 |
-
|
250 |
-
outputs = process(condition_model, input_image, mask_image, enable_auto_prompt, prompt, a_prompt, n_prompt, num_samples, image_resolution,
|
251 |
-
detect_resolution, ddim_steps, guess_mode, strength, scale, seed, eta)
|
252 |
-
|
253 |
-
image_list = []
|
254 |
-
input_image = resize_image(input_image, 512)
|
255 |
-
image_list.append(torch.tensor(input_image))
|
256 |
-
for i in range(len(outputs)):
|
257 |
-
each = outputs[i]
|
258 |
-
if type(each) is not np.ndarray:
|
259 |
-
each = np.array(each, dtype=np.uint8)
|
260 |
-
each = resize_image(each, 512)
|
261 |
-
print(i, each.shape)
|
262 |
-
image_list.append(torch.tensor(each))
|
263 |
-
|
264 |
-
image_list = torch.stack(image_list).permute(0, 3, 1, 2)
|
265 |
-
|
266 |
-
save_image(image_list, "sample.jpg", nrow=3,
|
267 |
-
normalize=True, value_range=(0, 255))
|
268 |
-
else:
|
269 |
-
print("The GUI is not fully tested yet. Please open an issue if you find bugs.")
|
270 |
-
block = gr.Blocks()
|
271 |
-
with block as demo:
|
272 |
-
with gr.Row():
|
273 |
-
gr.Markdown(
|
274 |
-
"## Edit Anything")
|
275 |
-
with gr.Row():
|
276 |
-
with gr.Column():
|
277 |
-
source_image = gr.Image(source='upload',label="Image (Upload an image and cover the region you want to edit with sketch)", type="numpy", tool="sketch")
|
278 |
-
enable_all_generate = gr.Checkbox(label='Auto generation on all region.', value=False)
|
279 |
-
prompt = gr.Textbox(label="Prompt (Text in the expected things of edited region)")
|
280 |
-
enable_auto_prompt = gr.Checkbox(label='Auto generate text prompt from input image with BLIP2: Warning: Enable this may makes your prompt not working.', value=True)
|
281 |
-
control_scale = gr.Slider(
|
282 |
-
label="Mask Align strength (Large value means more strict alignment with SAM mask)", minimum=0, maximum=1, value=1, step=0.1)
|
283 |
-
run_button = gr.Button(label="Run")
|
284 |
condition_model = gr.Dropdown(choices=list(config_dict.keys()),
|
285 |
-
|
286 |
-
|
287 |
-
|
288 |
-
|
289 |
-
|
290 |
-
|
291 |
-
|
292 |
-
|
293 |
-
|
294 |
-
|
295 |
-
|
296 |
-
|
297 |
-
|
298 |
-
|
299 |
-
|
300 |
-
|
301 |
-
|
302 |
-
|
303 |
-
|
304 |
-
|
305 |
-
|
306 |
-
|
307 |
-
|
308 |
-
|
309 |
-
|
310 |
-
|
311 |
-
|
312 |
-
|
313 |
-
|
314 |
-
|
315 |
-
|
316 |
-
|
317 |
-
|
|
|
|
|
|
|
318 |
|
319 |
if __name__ == '__main__':
|
320 |
-
|
|
|
|
|
|
|
321 |
demo.queue().launch(server_name='0.0.0.0')
|
|
|
1 |
# Edit Anything trained with Stable Diffusion + ControlNet + SAM + BLIP2
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
2 |
import gradio as gr
|
3 |
+
from diffusers.utils import load_image
|
4 |
+
from sam2edit_lora import EditAnythingLoraModel, config_dict
|
5 |
+
|
6 |
+
|
7 |
+
def create_demo(process):
|
8 |
+
|
9 |
+
|
10 |
+
|
11 |
+
print("The GUI is not fully tested yet. Please open an issue if you find bugs.")
|
12 |
+
WARNING_INFO = f'''### [NOTE] the model is collected from the Internet for demo only, please do not use it for commercial purposes.
|
13 |
+
We are not responsible for possible risks using this model.
|
14 |
+
'''
|
15 |
+
block = gr.Blocks()
|
16 |
+
with block as demo:
|
17 |
+
with gr.Row():
|
18 |
+
gr.Markdown(
|
19 |
+
"## Generate Your Beauty powered by EditAnything https://github.com/sail-sg/EditAnything ")
|
20 |
+
with gr.Row():
|
21 |
+
with gr.Column():
|
22 |
+
source_image = gr.Image(
|
23 |
+
source='upload', label="Image (Upload an image and cover the region you want to edit with sketch)", type="numpy", tool="sketch")
|
24 |
+
enable_all_generate = gr.Checkbox(
|
25 |
+
label='Auto generation on all region.', value=False)
|
26 |
+
prompt = gr.Textbox(
|
27 |
+
label="Prompt (Text in the expected things of edited region)")
|
28 |
+
enable_auto_prompt = gr.Checkbox(
|
29 |
+
label='Auto generate text prompt from input image with BLIP2: Warning: Enable this may makes your prompt not working.', value=False)
|
30 |
+
a_prompt = gr.Textbox(
|
31 |
+
label="Added Prompt", value='best quality, extremely detailed')
|
32 |
+
n_prompt = gr.Textbox(label="Negative Prompt",
|
33 |
+
value='longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality')
|
34 |
+
control_scale = gr.Slider(
|
35 |
+
label="Mask Align strength (Large value means more strict alignment with SAM mask)", minimum=0, maximum=1, value=1, step=0.1)
|
36 |
+
run_button = gr.Button(label="Run")
|
37 |
+
num_samples = gr.Slider(
|
38 |
+
label="Images", minimum=1, maximum=12, value=2, step=1)
|
39 |
+
seed = gr.Slider(label="Seed", minimum=-1,
|
40 |
+
maximum=2147483647, step=1, randomize=True)
|
41 |
+
with gr.Accordion("Advanced options", open=False):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
42 |
condition_model = gr.Dropdown(choices=list(config_dict.keys()),
|
43 |
+
value=list(
|
44 |
+
config_dict.keys())[1],
|
45 |
+
label='Model',
|
46 |
+
multiselect=False)
|
47 |
+
mask_image = gr.Image(
|
48 |
+
source='upload', label="(Optional) Upload a predefined mask of edit region if you do not want to write your prompt.", type="numpy", value=None)
|
49 |
+
image_resolution = gr.Slider(
|
50 |
+
label="Image Resolution", minimum=256, maximum=768, value=512, step=64)
|
51 |
+
strength = gr.Slider(
|
52 |
+
label="Control Strength", minimum=0.0, maximum=2.0, value=1.0, step=0.01)
|
53 |
+
guess_mode = gr.Checkbox(
|
54 |
+
label='Guess Mode', value=False)
|
55 |
+
detect_resolution = gr.Slider(
|
56 |
+
label="SAM Resolution", minimum=128, maximum=2048, value=1024, step=1)
|
57 |
+
ddim_steps = gr.Slider(
|
58 |
+
label="Steps", minimum=1, maximum=100, value=30, step=1)
|
59 |
+
scale = gr.Slider(
|
60 |
+
label="Guidance Scale", minimum=0.1, maximum=30.0, value=9.0, step=0.1)
|
61 |
+
eta = gr.Number(label="eta (DDIM)", value=0.0)
|
62 |
+
with gr.Column():
|
63 |
+
result_gallery = gr.Gallery(
|
64 |
+
label='Output', show_label=False, elem_id="gallery").style(grid=2, height='auto')
|
65 |
+
result_text = gr.Text(label='BLIP2+Human Prompt Text')
|
66 |
+
ips = [condition_model, source_image, enable_all_generate, mask_image, control_scale, enable_auto_prompt, prompt, a_prompt, n_prompt, num_samples, image_resolution,
|
67 |
+
detect_resolution, ddim_steps, guess_mode, strength, scale, seed, eta]
|
68 |
+
run_button.click(fn=process, inputs=ips, outputs=[
|
69 |
+
result_gallery, result_text])
|
70 |
+
# with gr.Row():
|
71 |
+
# ex = gr.Examples(examples=examples, fn=process,
|
72 |
+
# inputs=[a_prompt, n_prompt, scale],
|
73 |
+
# outputs=[result_gallery],
|
74 |
+
# cache_examples=False)
|
75 |
+
with gr.Row():
|
76 |
+
gr.Markdown(WARNING_INFO)
|
77 |
+
return demo
|
78 |
+
|
79 |
|
80 |
if __name__ == '__main__':
|
81 |
+
model = EditAnythingLoraModel(base_model_path="stabilityai/stable-diffusion-2-inpainting",
|
82 |
+
controlmodel_name='LAION Pretrained(v0-4)-SD21', extra_inpaint=False,
|
83 |
+
lora_model_path=None, use_blip=True)
|
84 |
+
demo = create_demo(model.process)
|
85 |
demo.queue().launch(server_name='0.0.0.0')
|
sam2edit_beauty.py
ADDED
@@ -0,0 +1,95 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Edit Anything trained with Stable Diffusion + ControlNet + SAM + BLIP2
|
2 |
+
import gradio as gr
|
3 |
+
from diffusers.utils import load_image
|
4 |
+
from sam2edit_lora import EditAnythingLoraModel, config_dict
|
5 |
+
|
6 |
+
|
7 |
+
def create_demo(process):
|
8 |
+
|
9 |
+
examples = [
|
10 |
+
["dudou,1girl, beautiful face, solo, candle, brown hair, long hair, <lora:flowergirl:0.9>,ulzzang-6500-v1.1,(raw photo:1.2),((photorealistic:1.4))best quality ,masterpiece, illustration, an extremely delicate and beautiful, extremely detailed ,CG ,unity ,8k wallpaper, Amazing, finely detail, masterpiece,best quality,official art,extremely detailed CG unity 8k wallpaper,absurdres, incredibly absurdres, huge filesize, ultra-detailed, highres, extremely detailed,beautiful detailed girl, extremely detailed eyes and face, beautiful detailed eyes,cinematic lighting,1girl,see-through,looking at viewer,full body,full-body shot,outdoors,arms behind back,(chinese clothes) <lora:cuteGirlMix4_v10:1>",
|
11 |
+
"(((mole))),sketches, (worst quality:2), (low quality:2), (normal quality:2), lowres, normal quality, ((monochrome)), ((grayscale)), skin spots, acnes, skin blemishes, bad anatomy,(long hair:1.4),DeepNegative,(fat:1.2),facing away, looking away,tilted head, lowres,bad anatomy,bad hands, text, error, missing fingers,extra digit, fewer digits, cropped, worstquality, low quality, normal quality,jpegartifacts,signature, watermark, username,blurry,bad feet,cropped,poorly drawn hands,poorly drawn face,mutation,deformed,worst quality,low quality,normal quality,jpeg artifacts,signature,watermark,extra fingers,fewer digits,extra limbs,extra arms,extra legs,malformed limbs,fused fingers,too many fingers,long neck,cross-eyed,mutated hands,polar lowres,bad body,bad proportions,gross proportions,text,error,missing fingers,missing arms,missing legs,extra digit, extra arms, extra leg, extra foot,(freckles),(mole:2)", 5],
|
12 |
+
["best quality, ultra high res, (photorealistic:1.4), (detailed beautiful girl:1.4), (medium breasts:0.8), looking_at_viewer, Detailed facial details, beautiful detailed eyes, (multicolored|blue|pink hair: 1.2), green eyes, slender, haunting smile, (makeup:0.3), red lips, <lora:cuteGirlMix4_v10:0.7>, highly detailed clothes, (ulzzang-6500-v1.1:0.3)",
|
13 |
+
"EasyNegative, paintings, sketches, ugly, 3d, (worst quality:2), (low quality:2), (normal quality:2), lowres, normal quality, ((monochrome)), ((grayscale)), skin spots, acnes, skin blemishes, age spot, manboobs, backlight,(ugly:1.3), (duplicate:1.3), (morbid:1.2), (mutilated:1.2), (tranny:1.3), mutated hands, (poorly drawn hands:1.3), blurry, (bad anatomy:1.2), (bad proportions:1.3), extra limbs, (disfigured:1.3), (more than 2 nipples:1.3), (more than 1 navel:1.3), (missing arms:1.3), (extra legs:1.3), (fused fingers:1.6), (too many fingers:1.6), (unclear eyes:1.3), bad hands, missing fingers, extra digit, (futa:1.1), bad body, double navel, mutad arms, hused arms, (puffy nipples, dark areolae, dark nipples, rei no himo, inverted nipples, long nipples), NG_DeepNegative_V1_75t, pubic hair, fat rolls, obese, bad-picture-chill-75v", 8],
|
14 |
+
["best quality, ultra high res, (photorealistic:1.4), (detailed beautiful girl:1.4), (medium breasts:0.8), looking_at_viewer, Detailed facial details, beautiful detailed eyes, (blue|pink hair), green eyes, slender, smile, (makeup:0.4), red lips, (full body, sitting, beach), <lora:cuteGirlMix4_v10:0.7>, highly detailed clothes, (ulzzang-6500-v1.1:0.3)",
|
15 |
+
"asyNegative, paintings, sketches, ugly, 3d, (worst quality:2), (low quality:2), (normal quality:2), lowres, normal quality, ((monochrome)), ((grayscale)), skin spots, acnes, skin blemishes, age spot, manboobs, backlight,(ugly:1.3), (duplicate:1.3), (morbid:1.2), (mutilated:1.2), (tranny:1.3), mutated hands, (poorly drawn hands:1.3), blurry, (bad anatomy:1.2), (bad proportions:1.3), extra limbs, (disfigured:1.3), (more than 2 nipples:1.3), (more than 1 navel:1.3), (missing arms:1.3), (extra legs:1.3), (fused fingers:1.6), (too many fingers:1.6), (unclear eyes:1.3), bad hands, missing fingers, extra digit, (futa:1.1), bad body, double navel, mutad arms, hused arms, (puffy nipples, dark areolae, dark nipples, rei no himo, inverted nipples, long nipples), NG_DeepNegative_V1_75t, pubic hair, fat rolls, obese, bad-picture-chill-75v", 7],
|
16 |
+
["mix4, whole body shot, ((8k, RAW photo, highest quality, masterpiece), High detail RAW color photo professional close-up photo, shy expression, cute, beautiful detailed girl, detailed fingers, extremely detailed eyes and face, beautiful detailed nose, beautiful detailed eyes, long eyelashes, light on face, looking at viewer, (closed mouth:1.2), 1girl, cute, young, mature face, (full body:1.3), ((small breasts)), realistic face, realistic body, beautiful detailed thigh,s, same eyes color, (realistic, photo realism:1. 37), (highest quality), (best shadow), (best illustration), ultra high resolution, physics-based rendering, cinematic lighting), solo, 1girl, highly detailed, in office, detailed office, open cardigan, ponytail contorted, beautiful eyes ,sitting in office,dating, business suit, cross-laced clothes, collared shirt, beautiful breast, small breast, Chinese dress, white pantyhose, natural breasts, pink and white hair, <lora:cuteGirlMix4_v10:1>",
|
17 |
+
"paintings, sketches, (worst quality:2), (low quality:2), (normal quality:2), cloth, underwear, bra, low-res, normal quality, ((monochrome)), ((grayscale)), skin spots, acne, skin blemishes, age spots, glans, bad nipples, long nipples, bad vagina, extra fingers,fewer fingers,strange fingers,bad hand, ng_deepnegative_v1_75t, bad-picture-chill-75v", 7]
|
18 |
+
]
|
19 |
+
|
20 |
+
print("The GUI is not fully tested yet. Please open an issue if you find bugs.")
|
21 |
+
WARNING_INFO = f'''### [NOTE] the model is collected from the Internet for demo only, please do not use it for commercial purposes.
|
22 |
+
We are not responsible for possible risks using this model.
|
23 |
+
|
24 |
+
Lora model from https://civitai.com/models/14171/cutegirlmix4 Thanks!
|
25 |
+
'''
|
26 |
+
block = gr.Blocks()
|
27 |
+
with block as demo:
|
28 |
+
with gr.Row():
|
29 |
+
gr.Markdown(
|
30 |
+
"## Generate Your Beauty powered by EditAnything https://github.com/sail-sg/EditAnything ")
|
31 |
+
with gr.Row():
|
32 |
+
with gr.Column():
|
33 |
+
source_image = gr.Image(
|
34 |
+
source='upload', label="Image (Upload an image and cover the region you want to edit with sketch)", type="numpy", tool="sketch")
|
35 |
+
enable_all_generate = gr.Checkbox(
|
36 |
+
label='Auto generation on all region.', value=False)
|
37 |
+
prompt = gr.Textbox(
|
38 |
+
label="Prompt (Text in the expected things of edited region)")
|
39 |
+
enable_auto_prompt = gr.Checkbox(
|
40 |
+
label='Auto generate text prompt from input image with BLIP2: Warning: Enable this may makes your prompt not working.', value=False)
|
41 |
+
a_prompt = gr.Textbox(
|
42 |
+
label="Added Prompt", value='best quality, extremely detailed')
|
43 |
+
n_prompt = gr.Textbox(label="Negative Prompt",
|
44 |
+
value='longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality')
|
45 |
+
control_scale = gr.Slider(
|
46 |
+
label="Mask Align strength (Large value means more strict alignment with SAM mask)", minimum=0, maximum=1, value=1, step=0.1)
|
47 |
+
run_button = gr.Button(label="Run")
|
48 |
+
num_samples = gr.Slider(
|
49 |
+
label="Images", minimum=1, maximum=12, value=2, step=1)
|
50 |
+
seed = gr.Slider(label="Seed", minimum=-1,
|
51 |
+
maximum=2147483647, step=1, randomize=True)
|
52 |
+
with gr.Accordion("Advanced options", open=False):
|
53 |
+
condition_model = gr.Dropdown(choices=list(config_dict.keys()),
|
54 |
+
value=list(
|
55 |
+
config_dict.keys())[0],
|
56 |
+
label='Model',
|
57 |
+
multiselect=False)
|
58 |
+
mask_image = gr.Image(
|
59 |
+
source='upload', label="(Optional) Upload a predefined mask of edit region if you do not want to write your prompt.", type="numpy", value=None)
|
60 |
+
image_resolution = gr.Slider(
|
61 |
+
label="Image Resolution", minimum=256, maximum=768, value=512, step=64)
|
62 |
+
strength = gr.Slider(
|
63 |
+
label="Control Strength", minimum=0.0, maximum=2.0, value=1.0, step=0.01)
|
64 |
+
guess_mode = gr.Checkbox(
|
65 |
+
label='Guess Mode', value=False)
|
66 |
+
detect_resolution = gr.Slider(
|
67 |
+
label="SAM Resolution", minimum=128, maximum=2048, value=1024, step=1)
|
68 |
+
ddim_steps = gr.Slider(
|
69 |
+
label="Steps", minimum=1, maximum=100, value=30, step=1)
|
70 |
+
scale = gr.Slider(
|
71 |
+
label="Guidance Scale", minimum=0.1, maximum=30.0, value=9.0, step=0.1)
|
72 |
+
eta = gr.Number(label="eta (DDIM)", value=0.0)
|
73 |
+
with gr.Column():
|
74 |
+
result_gallery = gr.Gallery(
|
75 |
+
label='Output', show_label=False, elem_id="gallery").style(grid=2, height='auto')
|
76 |
+
result_text = gr.Text(label='BLIP2+Human Prompt Text')
|
77 |
+
ips = [condition_model, source_image, enable_all_generate, mask_image, control_scale, enable_auto_prompt, prompt, a_prompt, n_prompt, num_samples, image_resolution,
|
78 |
+
detect_resolution, ddim_steps, guess_mode, strength, scale, seed, eta]
|
79 |
+
run_button.click(fn=process, inputs=ips, outputs=[
|
80 |
+
result_gallery, result_text])
|
81 |
+
with gr.Row():
|
82 |
+
ex = gr.Examples(examples=examples, fn=process,
|
83 |
+
inputs=[a_prompt, n_prompt, scale],
|
84 |
+
outputs=[result_gallery],
|
85 |
+
cache_examples=False)
|
86 |
+
with gr.Row():
|
87 |
+
gr.Markdown(WARNING_INFO)
|
88 |
+
return demo
|
89 |
+
|
90 |
+
|
91 |
+
if __name__ == '__main__':
|
92 |
+
model = EditAnythingLoraModel(base_model_path='../chilloutmix_NiPrunedFp32Fix',
|
93 |
+
lora_model_path='../40806/mix4', use_blip=True)
|
94 |
+
demo = create_demo(model.process)
|
95 |
+
demo.queue().launch(server_name='0.0.0.0')
|
sam2edit_handsome.py
ADDED
@@ -0,0 +1,90 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Edit Anything trained with Stable Diffusion + ControlNet + SAM + BLIP2
|
2 |
+
import gradio as gr
|
3 |
+
from diffusers.utils import load_image
|
4 |
+
from sam2edit_lora import EditAnythingLoraModel, config_dict
|
5 |
+
|
6 |
+
|
7 |
+
|
8 |
+
def create_demo(process):
|
9 |
+
|
10 |
+
examples = [
|
11 |
+
["1man, muscle,full body, vest, short straight hair, glasses, Gym, barbells, dumbbells, treadmills, boxing rings, squat racks, plates, dumbbell racks soft lighting, masterpiece, best quality, 8k uhd, film grain, Fujifilm XT3 photorealistic painting art by midjourney and greg rutkowski <lora:asianmale_v10:0.6>", "(deformed iris, deformed pupils, semi-realistic, cgi, 3d, render, sketch, cartoon, drawing, anime:1.4), text, close up, cropped, out of frame, worst quality, low quality, jpeg artifacts, ugly, duplicate, morbid, mutilated, extra fingers, mutated hands, poorly drawn hands, poorly drawn face, mutation, deformed, blurry, dehydrated, bad anatomy, bad proportions, extra limbs, cloned face, disfigured, gross proportions, malformed limbs, missing arms, missing legs, extra arms, extra legs, fused fingers, too many fingers, long neck", 6],
|
12 |
+
["1man, 25 years- old, full body, wearing long-sleeve white shirt and tie, muscular rand black suit, soft lighting, masterpiece, best quality, 8k uhd, dslr, film grain, Fujifilm XT3 photorealistic painting art by midjourney and greg rutkowski <lora:asianmale_v10:0.6> <lora:uncutPenisLora_v10:0.6>","(deformed iris, deformed pupils, semi-realistic, cgi, 3d, render, sketch, cartoon, drawing, anime:1.4), text, close up, cropped, out of frame, worst quality, low quality, jpeg artifacts, ugly, duplicate, morbid, mutilated, extra fingers, mutated hands, poorly drawn hands, poorly drawn face, mutation, deformed, blurry, dehydrated, bad anatomy, bad proportions, extra limbs, cloned face, disfigured, gross proportions, malformed limbs, missing arms, missing legs, extra arms, extra legs, fused fingers, too many fingers, long neck",6],
|
13 |
+
]
|
14 |
+
|
15 |
+
print("The GUI is not fully tested yet. Please open an issue if you find bugs.")
|
16 |
+
WARNING_INFO = f'''### [NOTE] the model is collected from the Internet for demo only, please do not use it for commercial purposes.
|
17 |
+
We are not responsible for possible risks using this model.
|
18 |
+
Base model from https://huggingface.co/SG161222/Realistic_Vision_V2.0 Thanks!
|
19 |
+
'''
|
20 |
+
block = gr.Blocks()
|
21 |
+
with block as demo:
|
22 |
+
with gr.Row():
|
23 |
+
gr.Markdown(
|
24 |
+
"## Generate Your Handsome powered by EditAnything https://github.com/sail-sg/EditAnything ")
|
25 |
+
with gr.Row():
|
26 |
+
with gr.Column():
|
27 |
+
source_image = gr.Image(
|
28 |
+
source='upload', label="Image (Upload an image and cover the region you want to edit with sketch)", type="numpy", tool="sketch")
|
29 |
+
enable_all_generate = gr.Checkbox(
|
30 |
+
label='Auto generation on all region.', value=False)
|
31 |
+
prompt = gr.Textbox(
|
32 |
+
label="Prompt (Text in the expected things of edited region)")
|
33 |
+
enable_auto_prompt = gr.Checkbox(
|
34 |
+
label='Auto generate text prompt from input image with BLIP2: Warning: Enable this may makes your prompt not working.', value=False)
|
35 |
+
a_prompt = gr.Textbox(
|
36 |
+
label="Added Prompt", value='best quality, extremely detailed')
|
37 |
+
n_prompt = gr.Textbox(label="Negative Prompt",
|
38 |
+
value='longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality')
|
39 |
+
control_scale = gr.Slider(
|
40 |
+
label="Mask Align strength (Large value means more strict alignment with SAM mask)", minimum=0, maximum=1, value=1, step=0.1)
|
41 |
+
run_button = gr.Button(label="Run")
|
42 |
+
num_samples = gr.Slider(
|
43 |
+
label="Images", minimum=1, maximum=12, value=2, step=1)
|
44 |
+
seed = gr.Slider(label="Seed", minimum=-1,
|
45 |
+
maximum=2147483647, step=1, randomize=True)
|
46 |
+
with gr.Accordion("Advanced options", open=False):
|
47 |
+
condition_model = gr.Dropdown(choices=list(config_dict.keys()),
|
48 |
+
value=list(
|
49 |
+
config_dict.keys())[0],
|
50 |
+
label='Model',
|
51 |
+
multiselect=False)
|
52 |
+
mask_image = gr.Image(
|
53 |
+
source='upload', label="(Optional) Upload a predefined mask of edit region if you do not want to write your prompt.", type="numpy", value=None)
|
54 |
+
image_resolution = gr.Slider(
|
55 |
+
label="Image Resolution", minimum=256, maximum=768, value=512, step=64)
|
56 |
+
strength = gr.Slider(
|
57 |
+
label="Control Strength", minimum=0.0, maximum=2.0, value=1.0, step=0.01)
|
58 |
+
guess_mode = gr.Checkbox(
|
59 |
+
label='Guess Mode', value=False)
|
60 |
+
detect_resolution = gr.Slider(
|
61 |
+
label="SAM Resolution", minimum=128, maximum=2048, value=1024, step=1)
|
62 |
+
ddim_steps = gr.Slider(
|
63 |
+
label="Steps", minimum=1, maximum=100, value=30, step=1)
|
64 |
+
scale = gr.Slider(
|
65 |
+
label="Guidance Scale", minimum=0.1, maximum=30.0, value=9.0, step=0.1)
|
66 |
+
eta = gr.Number(label="eta (DDIM)", value=0.0)
|
67 |
+
with gr.Column():
|
68 |
+
result_gallery = gr.Gallery(
|
69 |
+
label='Output', show_label=False, elem_id="gallery").style(grid=2, height='auto')
|
70 |
+
result_text = gr.Text(label='BLIP2+Human Prompt Text')
|
71 |
+
ips = [condition_model, source_image, enable_all_generate, mask_image, control_scale, enable_auto_prompt, prompt, a_prompt, n_prompt, num_samples, image_resolution,
|
72 |
+
detect_resolution, ddim_steps, guess_mode, strength, scale, seed, eta]
|
73 |
+
run_button.click(fn=process, inputs=ips, outputs=[
|
74 |
+
result_gallery, result_text])
|
75 |
+
with gr.Row():
|
76 |
+
ex = gr.Examples(examples=examples, fn=process,
|
77 |
+
inputs=[a_prompt, n_prompt, scale],
|
78 |
+
outputs=[result_gallery],
|
79 |
+
cache_examples=False)
|
80 |
+
with gr.Row():
|
81 |
+
gr.Markdown(WARNING_INFO)
|
82 |
+
return demo
|
83 |
+
|
84 |
+
|
85 |
+
|
86 |
+
if __name__ == '__main__':
|
87 |
+
model = EditAnythingLoraModel(base_model_path= '../../gradio-rel/EditAnything/models/Realistic_Vision_V2.0',
|
88 |
+
lora_model_path= '../../gradio-rel/EditAnything/models/asianmale', use_blip=True)
|
89 |
+
demo = create_demo(model.process)
|
90 |
+
demo.queue().launch(server_name='0.0.0.0')
|
sam2edit_lora.py
ADDED
@@ -0,0 +1,478 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Edit Anything trained with Stable Diffusion + ControlNet + SAM + BLIP2
|
2 |
+
from torchvision.utils import save_image
|
3 |
+
from PIL import Image
|
4 |
+
from pytorch_lightning import seed_everything
|
5 |
+
import subprocess
|
6 |
+
from collections import OrderedDict
|
7 |
+
import re
|
8 |
+
import cv2
|
9 |
+
import einops
|
10 |
+
import gradio as gr
|
11 |
+
import numpy as np
|
12 |
+
import torch
|
13 |
+
import random
|
14 |
+
import os
|
15 |
+
import requests
|
16 |
+
from io import BytesIO
|
17 |
+
from annotator.util import resize_image, HWC3
|
18 |
+
|
19 |
+
import torch
|
20 |
+
from safetensors.torch import load_file
|
21 |
+
from collections import defaultdict
|
22 |
+
from diffusers import StableDiffusionControlNetPipeline
|
23 |
+
from diffusers import ControlNetModel, UniPCMultistepScheduler
|
24 |
+
from utils.stable_diffusion_controlnet_inpaint import StableDiffusionControlNetInpaintPipeline
|
25 |
+
# from utils.tmp import StableDiffusionControlNetInpaintPipeline
|
26 |
+
# need the latest transformers
|
27 |
+
# pip install git+https://github.com/huggingface/transformers.git
|
28 |
+
from transformers import AutoProcessor, Blip2ForConditionalGeneration
|
29 |
+
|
30 |
+
# Segment-Anything init.
|
31 |
+
# pip install git+https://github.com/facebookresearch/segment-anything.git
|
32 |
+
try:
|
33 |
+
from segment_anything import sam_model_registry, SamAutomaticMaskGenerator
|
34 |
+
except ImportError:
|
35 |
+
print('segment_anything not installed')
|
36 |
+
result = subprocess.run(
|
37 |
+
['pip', 'install', 'git+https://github.com/facebookresearch/segment-anything.git'], check=True)
|
38 |
+
print(f'Install segment_anything {result}')
|
39 |
+
from segment_anything import sam_model_registry, SamAutomaticMaskGenerator
|
40 |
+
if not os.path.exists('./models/sam_vit_h_4b8939.pth'):
|
41 |
+
result = subprocess.run(
|
42 |
+
['wget', 'https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth', '-P', 'models'], check=True)
|
43 |
+
print(f'Download sam_vit_h_4b8939.pth {result}')
|
44 |
+
|
45 |
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
46 |
+
|
47 |
+
config_dict = OrderedDict([
|
48 |
+
('LAION Pretrained(v0-4)-SD15', 'shgao/edit-anything-v0-4-sd15'),
|
49 |
+
('LAION Pretrained(v0-4)-SD21', 'shgao/edit-anything-v0-4-sd21'),
|
50 |
+
])
|
51 |
+
|
52 |
+
|
53 |
+
def init_sam_model():
|
54 |
+
sam_checkpoint = "models/sam_vit_h_4b8939.pth"
|
55 |
+
model_type = "default"
|
56 |
+
sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
|
57 |
+
sam.to(device=device)
|
58 |
+
sam_generator = SamAutomaticMaskGenerator(sam)
|
59 |
+
return sam_generator
|
60 |
+
|
61 |
+
|
62 |
+
def init_blip_processor():
|
63 |
+
blip_processor = AutoProcessor.from_pretrained("Salesforce/blip2-opt-2.7b")
|
64 |
+
return blip_processor
|
65 |
+
|
66 |
+
|
67 |
+
def init_blip_model():
|
68 |
+
blip_model = Blip2ForConditionalGeneration.from_pretrained(
|
69 |
+
"Salesforce/blip2-opt-2.7b", torch_dtype=torch.float16, device_map="auto")
|
70 |
+
return blip_model
|
71 |
+
|
72 |
+
|
73 |
+
def get_pipeline_embeds(pipeline, prompt, negative_prompt, device):
|
74 |
+
# https://github.com/huggingface/diffusers/issues/2136
|
75 |
+
""" Get pipeline embeds for prompts bigger than the maxlength of the pipe
|
76 |
+
:param pipeline:
|
77 |
+
:param prompt:
|
78 |
+
:param negative_prompt:
|
79 |
+
:param device:
|
80 |
+
:return:
|
81 |
+
"""
|
82 |
+
max_length = pipeline.tokenizer.model_max_length
|
83 |
+
|
84 |
+
# simple way to determine length of tokens
|
85 |
+
count_prompt = len(re.split(r', ', prompt))
|
86 |
+
count_negative_prompt = len(re.split(r', ', negative_prompt))
|
87 |
+
|
88 |
+
# create the tensor based on which prompt is longer
|
89 |
+
if count_prompt >= count_negative_prompt:
|
90 |
+
input_ids = pipeline.tokenizer(
|
91 |
+
prompt, return_tensors="pt", truncation=False).input_ids.to(device)
|
92 |
+
shape_max_length = input_ids.shape[-1]
|
93 |
+
negative_ids = pipeline.tokenizer(negative_prompt, truncation=False, padding="max_length",
|
94 |
+
max_length=shape_max_length, return_tensors="pt").input_ids.to(device)
|
95 |
+
else:
|
96 |
+
negative_ids = pipeline.tokenizer(
|
97 |
+
negative_prompt, return_tensors="pt", truncation=False).input_ids.to(device)
|
98 |
+
shape_max_length = negative_ids.shape[-1]
|
99 |
+
input_ids = pipeline.tokenizer(prompt, return_tensors="pt", truncation=False, padding="max_length",
|
100 |
+
max_length=shape_max_length).input_ids.to(device)
|
101 |
+
|
102 |
+
concat_embeds = []
|
103 |
+
neg_embeds = []
|
104 |
+
for i in range(0, shape_max_length, max_length):
|
105 |
+
concat_embeds.append(pipeline.text_encoder(
|
106 |
+
input_ids[:, i: i + max_length])[0])
|
107 |
+
neg_embeds.append(pipeline.text_encoder(
|
108 |
+
negative_ids[:, i: i + max_length])[0])
|
109 |
+
|
110 |
+
return torch.cat(concat_embeds, dim=1), torch.cat(neg_embeds, dim=1)
|
111 |
+
|
112 |
+
|
113 |
+
def load_lora_weights(pipeline, checkpoint_path, multiplier, device, dtype):
|
114 |
+
LORA_PREFIX_UNET = "lora_unet"
|
115 |
+
LORA_PREFIX_TEXT_ENCODER = "lora_te"
|
116 |
+
# load LoRA weight from .safetensors
|
117 |
+
if isinstance(checkpoint_path, str):
|
118 |
+
|
119 |
+
state_dict = load_file(checkpoint_path, device=device)
|
120 |
+
|
121 |
+
updates = defaultdict(dict)
|
122 |
+
for key, value in state_dict.items():
|
123 |
+
# it is suggested to print out the key, it usually will be something like below
|
124 |
+
# "lora_te_text_model_encoder_layers_0_self_attn_k_proj.lora_down.weight"
|
125 |
+
|
126 |
+
layer, elem = key.split('.', 1)
|
127 |
+
updates[layer][elem] = value
|
128 |
+
|
129 |
+
# directly update weight in diffusers model
|
130 |
+
for layer, elems in updates.items():
|
131 |
+
|
132 |
+
if "text" in layer:
|
133 |
+
layer_infos = layer.split(
|
134 |
+
LORA_PREFIX_TEXT_ENCODER + "_")[-1].split("_")
|
135 |
+
curr_layer = pipeline.text_encoder
|
136 |
+
else:
|
137 |
+
layer_infos = layer.split(
|
138 |
+
LORA_PREFIX_UNET + "_")[-1].split("_")
|
139 |
+
curr_layer = pipeline.unet
|
140 |
+
|
141 |
+
# find the target layer
|
142 |
+
temp_name = layer_infos.pop(0)
|
143 |
+
while len(layer_infos) > -1:
|
144 |
+
try:
|
145 |
+
curr_layer = curr_layer.__getattr__(temp_name)
|
146 |
+
if len(layer_infos) > 0:
|
147 |
+
temp_name = layer_infos.pop(0)
|
148 |
+
elif len(layer_infos) == 0:
|
149 |
+
break
|
150 |
+
except Exception:
|
151 |
+
if len(temp_name) > 0:
|
152 |
+
temp_name += "_" + layer_infos.pop(0)
|
153 |
+
else:
|
154 |
+
temp_name = layer_infos.pop(0)
|
155 |
+
|
156 |
+
# get elements for this layer
|
157 |
+
weight_up = elems['lora_up.weight'].to(dtype)
|
158 |
+
weight_down = elems['lora_down.weight'].to(dtype)
|
159 |
+
alpha = elems['alpha']
|
160 |
+
if alpha:
|
161 |
+
alpha = alpha.item() / weight_up.shape[1]
|
162 |
+
else:
|
163 |
+
alpha = 1.0
|
164 |
+
|
165 |
+
# update weight
|
166 |
+
if len(weight_up.shape) == 4:
|
167 |
+
curr_layer.weight.data += multiplier * alpha * torch.mm(weight_up.squeeze(
|
168 |
+
3).squeeze(2), weight_down.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3)
|
169 |
+
else:
|
170 |
+
curr_layer.weight.data += multiplier * \
|
171 |
+
alpha * torch.mm(weight_up, weight_down)
|
172 |
+
else:
|
173 |
+
for ckptpath in checkpoint_path:
|
174 |
+
state_dict = load_file(ckptpath, device=device)
|
175 |
+
|
176 |
+
updates = defaultdict(dict)
|
177 |
+
for key, value in state_dict.items():
|
178 |
+
# it is suggested to print out the key, it usually will be something like below
|
179 |
+
# "lora_te_text_model_encoder_layers_0_self_attn_k_proj.lora_down.weight"
|
180 |
+
|
181 |
+
layer, elem = key.split('.', 1)
|
182 |
+
updates[layer][elem] = value
|
183 |
+
|
184 |
+
# directly update weight in diffusers model
|
185 |
+
for layer, elems in updates.items():
|
186 |
+
if "text" in layer:
|
187 |
+
layer_infos = layer.split(
|
188 |
+
LORA_PREFIX_TEXT_ENCODER + "_")[-1].split("_")
|
189 |
+
curr_layer = pipeline.text_encoder
|
190 |
+
else:
|
191 |
+
layer_infos = layer.split(
|
192 |
+
LORA_PREFIX_UNET + "_")[-1].split("_")
|
193 |
+
curr_layer = pipeline.unet
|
194 |
+
|
195 |
+
# find the target layer
|
196 |
+
temp_name = layer_infos.pop(0)
|
197 |
+
while len(layer_infos) > -1:
|
198 |
+
try:
|
199 |
+
curr_layer = curr_layer.__getattr__(temp_name)
|
200 |
+
if len(layer_infos) > 0:
|
201 |
+
temp_name = layer_infos.pop(0)
|
202 |
+
elif len(layer_infos) == 0:
|
203 |
+
break
|
204 |
+
except Exception:
|
205 |
+
if len(temp_name) > 0:
|
206 |
+
temp_name += "_" + layer_infos.pop(0)
|
207 |
+
else:
|
208 |
+
temp_name = layer_infos.pop(0)
|
209 |
+
|
210 |
+
# get elements for this layer
|
211 |
+
weight_up = elems['lora_up.weight'].to(dtype)
|
212 |
+
weight_down = elems['lora_down.weight'].to(dtype)
|
213 |
+
alpha = elems['alpha']
|
214 |
+
if alpha:
|
215 |
+
alpha = alpha.item() / weight_up.shape[1]
|
216 |
+
else:
|
217 |
+
alpha = 1.0
|
218 |
+
|
219 |
+
# update weight
|
220 |
+
if len(weight_up.shape) == 4:
|
221 |
+
curr_layer.weight.data += multiplier * alpha * torch.mm(weight_up.squeeze(
|
222 |
+
3).squeeze(2), weight_down.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3)
|
223 |
+
else:
|
224 |
+
curr_layer.weight.data += multiplier * \
|
225 |
+
alpha * torch.mm(weight_up, weight_down)
|
226 |
+
return pipeline
|
227 |
+
|
228 |
+
|
229 |
+
def make_inpaint_condition(image, image_mask):
|
230 |
+
# image = np.array(image.convert("RGB")).astype(np.float32) / 255.0
|
231 |
+
image = image / 255.0
|
232 |
+
print("img", image.max(), image.min(), image_mask.max(), image_mask.min())
|
233 |
+
# image_mask = np.array(image_mask.convert("L"))
|
234 |
+
assert image.shape[0:1] == image_mask.shape[0:
|
235 |
+
1], "image and image_mask must have the same image size"
|
236 |
+
image[image_mask > 128] = -1.0 # set as masked pixel
|
237 |
+
image = np.expand_dims(image, 0).transpose(0, 3, 1, 2)
|
238 |
+
image = torch.from_numpy(image)
|
239 |
+
return image
|
240 |
+
|
241 |
+
|
242 |
+
def obtain_generation_model(base_model_path, lora_model_path, controlnet_path, generation_only=False, extra_inpaint=True):
|
243 |
+
if generation_only and extra_inpaint:
|
244 |
+
controlnet = ControlNetModel.from_pretrained(
|
245 |
+
controlnet_path, torch_dtype=torch.float16)
|
246 |
+
pipe = StableDiffusionControlNetPipeline.from_pretrained(
|
247 |
+
base_model_path, controlnet=controlnet, torch_dtype=torch.float16, safety_checker=None
|
248 |
+
)
|
249 |
+
elif extra_inpaint:
|
250 |
+
print("Warning: ControlNet based inpainting model only support SD1.5 for now.")
|
251 |
+
controlnet = [
|
252 |
+
ControlNetModel.from_pretrained(
|
253 |
+
controlnet_path, torch_dtype=torch.float16),
|
254 |
+
ControlNetModel.from_pretrained(
|
255 |
+
'lllyasviel/control_v11p_sd15_inpaint', torch_dtype=torch.float16), # inpainting controlnet
|
256 |
+
]
|
257 |
+
pipe = StableDiffusionControlNetInpaintPipeline.from_pretrained(
|
258 |
+
base_model_path, controlnet=controlnet, torch_dtype=torch.float16, safety_checker=None
|
259 |
+
)
|
260 |
+
else:
|
261 |
+
controlnet = ControlNetModel.from_pretrained(
|
262 |
+
controlnet_path, torch_dtype=torch.float16)
|
263 |
+
pipe = StableDiffusionControlNetInpaintPipeline.from_pretrained(
|
264 |
+
base_model_path, controlnet=controlnet, torch_dtype=torch.float16, safety_checker=None
|
265 |
+
)
|
266 |
+
if lora_model_path is not None:
|
267 |
+
pipe = load_lora_weights(
|
268 |
+
pipe, [lora_model_path], 1.0, 'cpu', torch.float32)
|
269 |
+
# speed up diffusion process with faster scheduler and memory optimization
|
270 |
+
pipe.scheduler = UniPCMultistepScheduler.from_config(
|
271 |
+
pipe.scheduler.config)
|
272 |
+
# remove following line if xformers is not installed
|
273 |
+
pipe.enable_xformers_memory_efficient_attention()
|
274 |
+
|
275 |
+
pipe.enable_model_cpu_offload()
|
276 |
+
return pipe
|
277 |
+
|
278 |
+
|
279 |
+
def show_anns(anns):
|
280 |
+
if len(anns) == 0:
|
281 |
+
return
|
282 |
+
sorted_anns = sorted(anns, key=(lambda x: x['area']), reverse=True)
|
283 |
+
full_img = None
|
284 |
+
|
285 |
+
# for ann in sorted_anns:
|
286 |
+
for i in range(len(sorted_anns)):
|
287 |
+
ann = anns[i]
|
288 |
+
m = ann['segmentation']
|
289 |
+
if full_img is None:
|
290 |
+
full_img = np.zeros((m.shape[0], m.shape[1], 3))
|
291 |
+
map = np.zeros((m.shape[0], m.shape[1]), dtype=np.uint16)
|
292 |
+
map[m != 0] = i + 1
|
293 |
+
color_mask = np.random.random((1, 3)).tolist()[0]
|
294 |
+
full_img[m != 0] = color_mask
|
295 |
+
full_img = full_img*255
|
296 |
+
# anno encoding from https://github.com/LUSSeg/ImageNet-S
|
297 |
+
res = np.zeros((map.shape[0], map.shape[1], 3))
|
298 |
+
res[:, :, 0] = map % 256
|
299 |
+
res[:, :, 1] = map // 256
|
300 |
+
res.astype(np.float32)
|
301 |
+
full_img = Image.fromarray(np.uint8(full_img))
|
302 |
+
return full_img, res
|
303 |
+
|
304 |
+
|
305 |
+
class EditAnythingLoraModel:
|
306 |
+
def __init__(self,
|
307 |
+
base_model_path='../chilloutmix_NiPrunedFp32Fix',
|
308 |
+
lora_model_path='../40806/mix4', use_blip=True,
|
309 |
+
blip_processor=None,
|
310 |
+
blip_model=None,
|
311 |
+
sam_generator=None,
|
312 |
+
controlmodel_name='LAION Pretrained(v0-4)-SD15',
|
313 |
+
# used when the base model is not an inpainting model.
|
314 |
+
extra_inpaint=True,
|
315 |
+
):
|
316 |
+
self.device = device
|
317 |
+
self.use_blip = use_blip
|
318 |
+
|
319 |
+
# Diffusion init using diffusers.
|
320 |
+
self.default_controlnet_path = config_dict[controlmodel_name]
|
321 |
+
self.base_model_path = base_model_path
|
322 |
+
self.lora_model_path = lora_model_path
|
323 |
+
self.defalut_enable_all_generate = False
|
324 |
+
self.extra_inpaint = extra_inpaint
|
325 |
+
self.pipe = obtain_generation_model(
|
326 |
+
base_model_path, lora_model_path, self.default_controlnet_path, generation_only=False, extra_inpaint=extra_inpaint)
|
327 |
+
|
328 |
+
# Segment-Anything init.
|
329 |
+
if sam_generator is not None:
|
330 |
+
self.sam_generator = sam_generator
|
331 |
+
else:
|
332 |
+
self.sam_generator = init_sam_model()
|
333 |
+
|
334 |
+
# BLIP2 init.
|
335 |
+
if use_blip:
|
336 |
+
if blip_processor is not None:
|
337 |
+
self.blip_processor = blip_processor
|
338 |
+
else:
|
339 |
+
self.blip_processor = init_blip_processor()
|
340 |
+
|
341 |
+
if blip_model is not None:
|
342 |
+
self.blip_model = blip_model
|
343 |
+
else:
|
344 |
+
self.blip_model = init_blip_model()
|
345 |
+
|
346 |
+
def get_blip2_text(self, image):
|
347 |
+
inputs = self.blip_processor(image, return_tensors="pt").to(
|
348 |
+
self.device, torch.float16)
|
349 |
+
generated_ids = self.blip_model.generate(**inputs, max_new_tokens=50)
|
350 |
+
generated_text = self.blip_processor.batch_decode(
|
351 |
+
generated_ids, skip_special_tokens=True)[0].strip()
|
352 |
+
return generated_text
|
353 |
+
|
354 |
+
def get_sam_control(self, image):
|
355 |
+
masks = self.sam_generator.generate(image)
|
356 |
+
full_img, res = show_anns(masks)
|
357 |
+
return full_img, res
|
358 |
+
|
359 |
+
@torch.inference_mode()
|
360 |
+
def process(self, condition_model, source_image, enable_all_generate, mask_image, control_scale, enable_auto_prompt, prompt, a_prompt, n_prompt, num_samples, image_resolution, detect_resolution, ddim_steps, guess_mode, strength, scale, seed, eta):
|
361 |
+
|
362 |
+
input_image = source_image["image"]
|
363 |
+
if mask_image is None:
|
364 |
+
if enable_all_generate != self.defalut_enable_all_generate:
|
365 |
+
self.pipe = obtain_generation_model(
|
366 |
+
self.base_model_path, self.lora_model_path, config_dict[condition_model], enable_all_generate, self.extra_inpaint)
|
367 |
+
self.defalut_enable_all_generate = enable_all_generate
|
368 |
+
if enable_all_generate:
|
369 |
+
print("source_image",
|
370 |
+
source_image["mask"].shape, input_image.shape,)
|
371 |
+
mask_image = np.ones(
|
372 |
+
(input_image.shape[0], input_image.shape[1], 3))*255
|
373 |
+
else:
|
374 |
+
mask_image = source_image["mask"]
|
375 |
+
if self.default_controlnet_path != config_dict[condition_model]:
|
376 |
+
print("To Use:", config_dict[condition_model],
|
377 |
+
"Current:", self.default_controlnet_path)
|
378 |
+
print("Change condition model to:", config_dict[condition_model])
|
379 |
+
self.pipe = obtain_generation_model(
|
380 |
+
self.base_model_path, self.lora_model_path, config_dict[condition_model], enable_all_generate, self.extra_inpaint)
|
381 |
+
self.default_controlnet_path = config_dict[condition_model]
|
382 |
+
torch.cuda.empty_cache()
|
383 |
+
|
384 |
+
with torch.no_grad():
|
385 |
+
if self.use_blip and enable_auto_prompt:
|
386 |
+
print("Generating text:")
|
387 |
+
blip2_prompt = self.get_blip2_text(input_image)
|
388 |
+
print("Generated text:", blip2_prompt)
|
389 |
+
if len(prompt) > 0:
|
390 |
+
prompt = blip2_prompt + ',' + prompt
|
391 |
+
else:
|
392 |
+
prompt = blip2_prompt
|
393 |
+
|
394 |
+
input_image = HWC3(input_image)
|
395 |
+
|
396 |
+
img = resize_image(input_image, image_resolution)
|
397 |
+
H, W, C = img.shape
|
398 |
+
|
399 |
+
print("Generating SAM seg:")
|
400 |
+
# the default SAM model is trained with 1024 size.
|
401 |
+
full_segmask, detected_map = self.get_sam_control(
|
402 |
+
resize_image(input_image, detect_resolution))
|
403 |
+
|
404 |
+
detected_map = HWC3(detected_map.astype(np.uint8))
|
405 |
+
detected_map = cv2.resize(
|
406 |
+
detected_map, (W, H), interpolation=cv2.INTER_LINEAR)
|
407 |
+
|
408 |
+
control = torch.from_numpy(
|
409 |
+
detected_map.copy()).float().cuda()
|
410 |
+
control = torch.stack([control for _ in range(num_samples)], dim=0)
|
411 |
+
control = einops.rearrange(control, 'b h w c -> b c h w').clone()
|
412 |
+
|
413 |
+
mask_image = HWC3(mask_image.astype(np.uint8))
|
414 |
+
mask_image = cv2.resize(
|
415 |
+
mask_image, (W, H), interpolation=cv2.INTER_LINEAR)
|
416 |
+
if self.extra_inpaint:
|
417 |
+
inpaint_image = make_inpaint_condition(img, mask_image)
|
418 |
+
mask_image = Image.fromarray(mask_image)
|
419 |
+
|
420 |
+
if seed == -1:
|
421 |
+
seed = random.randint(0, 65535)
|
422 |
+
seed_everything(seed)
|
423 |
+
generator = torch.manual_seed(seed)
|
424 |
+
postive_prompt = prompt + ', ' + a_prompt
|
425 |
+
negative_prompt = n_prompt
|
426 |
+
prompt_embeds, negative_prompt_embeds = get_pipeline_embeds(
|
427 |
+
self.pipe, postive_prompt, negative_prompt, "cuda")
|
428 |
+
prompt_embeds = torch.cat([prompt_embeds] * num_samples, dim=0)
|
429 |
+
negative_prompt_embeds = torch.cat(
|
430 |
+
[negative_prompt_embeds] * num_samples, dim=0)
|
431 |
+
if enable_all_generate and self.extra_inpaint:
|
432 |
+
print(control.shape, control_scale)
|
433 |
+
self.pipe.safety_checker = lambda images, clip_input: (
|
434 |
+
images, False)
|
435 |
+
x_samples = self.pipe(
|
436 |
+
prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_prompt_embeds,
|
437 |
+
num_images_per_prompt=num_samples,
|
438 |
+
num_inference_steps=ddim_steps,
|
439 |
+
generator=generator,
|
440 |
+
height=H,
|
441 |
+
width=W,
|
442 |
+
image=control.type(torch.float16),
|
443 |
+
controlnet_conditioning_scale=float(control_scale),
|
444 |
+
).images
|
445 |
+
elif self.extra_inpaint:
|
446 |
+
x_samples = self.pipe(
|
447 |
+
image=img,
|
448 |
+
mask_image=mask_image,
|
449 |
+
prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_prompt_embeds,
|
450 |
+
num_images_per_prompt=num_samples,
|
451 |
+
num_inference_steps=ddim_steps,
|
452 |
+
generator=generator,
|
453 |
+
controlnet_conditioning_image=[control.type(
|
454 |
+
torch.float16), inpaint_image.type(torch.float16)],
|
455 |
+
height=H,
|
456 |
+
width=W,
|
457 |
+
controlnet_conditioning_scale=(float(control_scale), 1.0),
|
458 |
+
).images
|
459 |
+
else:
|
460 |
+
x_samples = self.pipe(
|
461 |
+
image=img,
|
462 |
+
mask_image=mask_image,
|
463 |
+
prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_prompt_embeds,
|
464 |
+
num_images_per_prompt=num_samples,
|
465 |
+
num_inference_steps=ddim_steps,
|
466 |
+
generator=generator,
|
467 |
+
controlnet_conditioning_image=control.type(torch.float16),
|
468 |
+
height=H,
|
469 |
+
width=W,
|
470 |
+
controlnet_conditioning_scale=float(control_scale),
|
471 |
+
).images
|
472 |
+
|
473 |
+
results = [x_samples[i] for i in range(num_samples)]
|
474 |
+
return [full_segmask, mask_image] + results, prompt
|
475 |
+
|
476 |
+
def download_image(url):
|
477 |
+
response = requests.get(url)
|
478 |
+
return Image.open(BytesIO(response.content)).convert("RGB")
|
utils/stable_diffusion_controlnet_inpaint.py
CHANGED
@@ -1,7 +1,7 @@
|
|
1 |
# Inspired by: https://github.com/haofanwang/ControlNet-for-Diffusers/
|
2 |
# From https://raw.githubusercontent.com/huggingface/diffusers/53377ef83c6446033f3ee506e3ef718db817b293/examples/community/stable_diffusion_controlnet_inpaint.py
|
3 |
import inspect
|
4 |
-
from typing import Any, Callable, Dict, List, Optional, Union
|
5 |
|
6 |
import numpy as np
|
7 |
import PIL.Image
|
@@ -11,6 +11,7 @@ from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
|
|
11 |
|
12 |
from diffusers import AutoencoderKL, ControlNetModel, DiffusionPipeline, UNet2DConditionModel, logging
|
13 |
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput, StableDiffusionSafetyChecker
|
|
|
14 |
from diffusers.schedulers import KarrasDiffusionSchedulers
|
15 |
from diffusers.utils import (
|
16 |
PIL_INTERPOLATION,
|
@@ -19,7 +20,7 @@ from diffusers.utils import (
|
|
19 |
randn_tensor,
|
20 |
replace_example_docstring,
|
21 |
)
|
22 |
-
|
23 |
|
24 |
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
25 |
|
@@ -184,7 +185,7 @@ def prepare_mask_image(mask_image):
|
|
184 |
|
185 |
|
186 |
def prepare_controlnet_conditioning_image(
|
187 |
-
controlnet_conditioning_image, width, height, batch_size, num_images_per_prompt, device, dtype
|
188 |
):
|
189 |
if not isinstance(controlnet_conditioning_image, torch.Tensor):
|
190 |
if isinstance(controlnet_conditioning_image, PIL.Image.Image):
|
@@ -214,10 +215,13 @@ def prepare_controlnet_conditioning_image(
|
|
214 |
|
215 |
controlnet_conditioning_image = controlnet_conditioning_image.to(device=device, dtype=dtype)
|
216 |
|
|
|
|
|
|
|
217 |
return controlnet_conditioning_image
|
218 |
|
219 |
|
220 |
-
class StableDiffusionControlNetInpaintPipeline(DiffusionPipeline):
|
221 |
"""
|
222 |
Inspired by: https://github.com/haofanwang/ControlNet-for-Diffusers/
|
223 |
"""
|
@@ -230,7 +234,7 @@ class StableDiffusionControlNetInpaintPipeline(DiffusionPipeline):
|
|
230 |
text_encoder: CLIPTextModel,
|
231 |
tokenizer: CLIPTokenizer,
|
232 |
unet: UNet2DConditionModel,
|
233 |
-
controlnet: ControlNetModel,
|
234 |
scheduler: KarrasDiffusionSchedulers,
|
235 |
safety_checker: StableDiffusionSafetyChecker,
|
236 |
feature_extractor: CLIPImageProcessor,
|
@@ -253,7 +257,8 @@ class StableDiffusionControlNetInpaintPipeline(DiffusionPipeline):
|
|
253 |
"Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
|
254 |
" checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
|
255 |
)
|
256 |
-
|
|
|
257 |
self.register_modules(
|
258 |
vae=vae,
|
259 |
text_encoder=text_encoder,
|
@@ -522,6 +527,42 @@ class StableDiffusionControlNetInpaintPipeline(DiffusionPipeline):
|
|
522 |
extra_step_kwargs["generator"] = generator
|
523 |
return extra_step_kwargs
|
524 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
525 |
def check_inputs(
|
526 |
self,
|
527 |
prompt,
|
@@ -534,6 +575,7 @@ class StableDiffusionControlNetInpaintPipeline(DiffusionPipeline):
|
|
534 |
negative_prompt=None,
|
535 |
prompt_embeds=None,
|
536 |
negative_prompt_embeds=None,
|
|
|
537 |
):
|
538 |
if height % 8 != 0 or width % 8 != 0:
|
539 |
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
|
@@ -572,45 +614,35 @@ class StableDiffusionControlNetInpaintPipeline(DiffusionPipeline):
|
|
572 |
f" {negative_prompt_embeds.shape}."
|
573 |
)
|
574 |
|
575 |
-
|
576 |
-
|
577 |
-
|
578 |
-
|
579 |
-
|
580 |
-
|
581 |
-
controlnet_conditioning_image
|
582 |
-
|
583 |
-
|
584 |
-
|
585 |
-
|
586 |
-
|
587 |
-
|
588 |
-
|
589 |
-
|
590 |
-
|
591 |
-
|
592 |
-
)
|
593 |
-
|
594 |
-
|
595 |
-
|
596 |
-
|
597 |
-
|
598 |
-
|
599 |
-
|
600 |
-
|
601 |
-
|
602 |
-
|
603 |
-
|
604 |
-
prompt_batch_size = 1
|
605 |
-
elif prompt is not None and isinstance(prompt, list):
|
606 |
-
prompt_batch_size = len(prompt)
|
607 |
-
elif prompt_embeds is not None:
|
608 |
-
prompt_batch_size = prompt_embeds.shape[0]
|
609 |
-
|
610 |
-
if controlnet_cond_image_batch_size != 1 and controlnet_cond_image_batch_size != prompt_batch_size:
|
611 |
-
raise ValueError(
|
612 |
-
f"If image batch size is not 1, image batch size must be same as prompt batch size. image batch size: {controlnet_cond_image_batch_size}, prompt batch size: {prompt_batch_size}"
|
613 |
-
)
|
614 |
|
615 |
if isinstance(image, torch.Tensor) and not isinstance(mask_image, torch.Tensor):
|
616 |
raise TypeError("if `image` is a tensor, `mask_image` must also be a tensor")
|
@@ -630,6 +662,8 @@ class StableDiffusionControlNetInpaintPipeline(DiffusionPipeline):
|
|
630 |
image_channels, image_height, image_width = image.shape
|
631 |
elif image.ndim == 4:
|
632 |
image_batch_size, image_channels, image_height, image_width = image.shape
|
|
|
|
|
633 |
|
634 |
if mask_image.ndim == 2:
|
635 |
mask_image_batch_size = 1
|
@@ -664,8 +698,11 @@ class StableDiffusionControlNetInpaintPipeline(DiffusionPipeline):
|
|
664 |
|
665 |
single_image_latent_channels = self.vae.config.latent_channels
|
666 |
|
667 |
-
|
668 |
-
|
|
|
|
|
|
|
669 |
if total_latent_channels != self.unet.config.in_channels:
|
670 |
raise ValueError(
|
671 |
f"The config of `pipeline.unet` expects {self.unet.config.in_channels} but received"
|
@@ -797,7 +834,7 @@ class StableDiffusionControlNetInpaintPipeline(DiffusionPipeline):
|
|
797 |
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
|
798 |
callback_steps: int = 1,
|
799 |
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
800 |
-
controlnet_conditioning_scale: float = 1.0,
|
801 |
):
|
802 |
r"""
|
803 |
Function invoked when calling the pipeline for generation.
|
@@ -897,6 +934,7 @@ class StableDiffusionControlNetInpaintPipeline(DiffusionPipeline):
|
|
897 |
negative_prompt,
|
898 |
prompt_embeds,
|
899 |
negative_prompt_embeds,
|
|
|
900 |
)
|
901 |
|
902 |
# 2. Define call parameters
|
@@ -913,6 +951,9 @@ class StableDiffusionControlNetInpaintPipeline(DiffusionPipeline):
|
|
913 |
# corresponds to doing no classifier free guidance.
|
914 |
do_classifier_free_guidance = guidance_scale > 1.0
|
915 |
|
|
|
|
|
|
|
916 |
# 3. Encode input prompt
|
917 |
prompt_embeds = self._encode_prompt(
|
918 |
prompt,
|
@@ -929,15 +970,37 @@ class StableDiffusionControlNetInpaintPipeline(DiffusionPipeline):
|
|
929 |
|
930 |
mask_image = prepare_mask_image(mask_image)
|
931 |
|
932 |
-
|
933 |
-
|
934 |
-
|
935 |
-
|
936 |
-
|
937 |
-
|
938 |
-
|
939 |
-
|
940 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
941 |
|
942 |
masked_image = image * (mask_image < 0.5)
|
943 |
|
@@ -958,29 +1021,45 @@ class StableDiffusionControlNetInpaintPipeline(DiffusionPipeline):
|
|
958 |
latents,
|
959 |
)
|
960 |
|
961 |
-
|
962 |
-
mask_image,
|
963 |
-
batch_size * num_images_per_prompt,
|
964 |
-
height,
|
965 |
-
width,
|
966 |
-
prompt_embeds.dtype,
|
967 |
-
device,
|
968 |
-
do_classifier_free_guidance,
|
969 |
-
)
|
970 |
|
971 |
-
|
972 |
-
|
973 |
-
|
974 |
-
|
975 |
-
|
976 |
-
|
977 |
-
|
978 |
-
|
979 |
-
|
980 |
-
|
981 |
|
982 |
-
|
983 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
984 |
|
985 |
# 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
|
986 |
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
|
@@ -997,25 +1076,22 @@ class StableDiffusionControlNetInpaintPipeline(DiffusionPipeline):
|
|
997 |
non_inpainting_latent_model_input = self.scheduler.scale_model_input(
|
998 |
non_inpainting_latent_model_input, t
|
999 |
)
|
1000 |
-
|
1001 |
-
|
1002 |
-
|
1003 |
-
|
|
|
|
|
1004 |
|
1005 |
down_block_res_samples, mid_block_res_sample = self.controlnet(
|
1006 |
non_inpainting_latent_model_input,
|
1007 |
t,
|
1008 |
encoder_hidden_states=prompt_embeds,
|
1009 |
controlnet_cond=controlnet_conditioning_image,
|
|
|
1010 |
return_dict=False,
|
1011 |
)
|
1012 |
|
1013 |
-
down_block_res_samples = [
|
1014 |
-
down_block_res_sample * controlnet_conditioning_scale
|
1015 |
-
for down_block_res_sample in down_block_res_samples
|
1016 |
-
]
|
1017 |
-
mid_block_res_sample *= controlnet_conditioning_scale
|
1018 |
-
|
1019 |
# predict the noise residual
|
1020 |
noise_pred = self.unet(
|
1021 |
inpainting_latent_model_input,
|
@@ -1039,6 +1115,14 @@ class StableDiffusionControlNetInpaintPipeline(DiffusionPipeline):
|
|
1039 |
progress_bar.update()
|
1040 |
if callback is not None and i % callback_steps == 0:
|
1041 |
callback(i, t, latents)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1042 |
|
1043 |
# If we do sequential model offloading, let's offload unet and controlnet
|
1044 |
# manually for max memory savings
|
|
|
1 |
# Inspired by: https://github.com/haofanwang/ControlNet-for-Diffusers/
|
2 |
# From https://raw.githubusercontent.com/huggingface/diffusers/53377ef83c6446033f3ee506e3ef718db817b293/examples/community/stable_diffusion_controlnet_inpaint.py
|
3 |
import inspect
|
4 |
+
from typing import Any, Callable, Dict, List, Optional, Union, Tuple
|
5 |
|
6 |
import numpy as np
|
7 |
import PIL.Image
|
|
|
11 |
|
12 |
from diffusers import AutoencoderKL, ControlNetModel, DiffusionPipeline, UNet2DConditionModel, logging
|
13 |
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput, StableDiffusionSafetyChecker
|
14 |
+
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_controlnet import MultiControlNetModel
|
15 |
from diffusers.schedulers import KarrasDiffusionSchedulers
|
16 |
from diffusers.utils import (
|
17 |
PIL_INTERPOLATION,
|
|
|
20 |
randn_tensor,
|
21 |
replace_example_docstring,
|
22 |
)
|
23 |
+
from diffusers.loaders import LoraLoaderMixin
|
24 |
|
25 |
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
26 |
|
|
|
185 |
|
186 |
|
187 |
def prepare_controlnet_conditioning_image(
|
188 |
+
controlnet_conditioning_image, width, height, batch_size, num_images_per_prompt, device, dtype, do_classifier_free_guidance,
|
189 |
):
|
190 |
if not isinstance(controlnet_conditioning_image, torch.Tensor):
|
191 |
if isinstance(controlnet_conditioning_image, PIL.Image.Image):
|
|
|
215 |
|
216 |
controlnet_conditioning_image = controlnet_conditioning_image.to(device=device, dtype=dtype)
|
217 |
|
218 |
+
if do_classifier_free_guidance:
|
219 |
+
controlnet_conditioning_image = torch.cat([controlnet_conditioning_image] * 2)
|
220 |
+
|
221 |
return controlnet_conditioning_image
|
222 |
|
223 |
|
224 |
+
class StableDiffusionControlNetInpaintPipeline(DiffusionPipeline, LoraLoaderMixin):
|
225 |
"""
|
226 |
Inspired by: https://github.com/haofanwang/ControlNet-for-Diffusers/
|
227 |
"""
|
|
|
234 |
text_encoder: CLIPTextModel,
|
235 |
tokenizer: CLIPTokenizer,
|
236 |
unet: UNet2DConditionModel,
|
237 |
+
controlnet: Union[ControlNetModel, List[ControlNetModel], Tuple[ControlNetModel], MultiControlNetModel],
|
238 |
scheduler: KarrasDiffusionSchedulers,
|
239 |
safety_checker: StableDiffusionSafetyChecker,
|
240 |
feature_extractor: CLIPImageProcessor,
|
|
|
257 |
"Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
|
258 |
" checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
|
259 |
)
|
260 |
+
if isinstance(controlnet, (list, tuple)):
|
261 |
+
controlnet = MultiControlNetModel(controlnet)
|
262 |
self.register_modules(
|
263 |
vae=vae,
|
264 |
text_encoder=text_encoder,
|
|
|
527 |
extra_step_kwargs["generator"] = generator
|
528 |
return extra_step_kwargs
|
529 |
|
530 |
+
def check_controlnet_conditioning_image(self, image, prompt, prompt_embeds):
|
531 |
+
image_is_pil = isinstance(image, PIL.Image.Image)
|
532 |
+
image_is_tensor = isinstance(image, torch.Tensor)
|
533 |
+
image_is_pil_list = isinstance(image, list) and isinstance(image[0], PIL.Image.Image)
|
534 |
+
image_is_tensor_list = isinstance(image, list) and isinstance(image[0], torch.Tensor)
|
535 |
+
|
536 |
+
if not image_is_pil and not image_is_tensor and not image_is_pil_list and not image_is_tensor_list:
|
537 |
+
raise TypeError(
|
538 |
+
"image must be passed and be one of PIL image, torch tensor, list of PIL images, or list of torch tensors"
|
539 |
+
)
|
540 |
+
|
541 |
+
if image_is_pil:
|
542 |
+
image_batch_size = 1
|
543 |
+
elif image_is_tensor:
|
544 |
+
image_batch_size = image.shape[0]
|
545 |
+
elif image_is_pil_list:
|
546 |
+
image_batch_size = len(image)
|
547 |
+
elif image_is_tensor_list:
|
548 |
+
image_batch_size = len(image)
|
549 |
+
else:
|
550 |
+
raise ValueError("controlnet condition image is not valid")
|
551 |
+
|
552 |
+
if prompt is not None and isinstance(prompt, str):
|
553 |
+
prompt_batch_size = 1
|
554 |
+
elif prompt is not None and isinstance(prompt, list):
|
555 |
+
prompt_batch_size = len(prompt)
|
556 |
+
elif prompt_embeds is not None:
|
557 |
+
prompt_batch_size = prompt_embeds.shape[0]
|
558 |
+
else:
|
559 |
+
raise ValueError("prompt or prompt_embeds are not valid")
|
560 |
+
|
561 |
+
if image_batch_size != 1 and image_batch_size != prompt_batch_size:
|
562 |
+
raise ValueError(
|
563 |
+
f"If image batch size is not 1, image batch size must be same as prompt batch size. image batch size: {image_batch_size}, prompt batch size: {prompt_batch_size}"
|
564 |
+
)
|
565 |
+
|
566 |
def check_inputs(
|
567 |
self,
|
568 |
prompt,
|
|
|
575 |
negative_prompt=None,
|
576 |
prompt_embeds=None,
|
577 |
negative_prompt_embeds=None,
|
578 |
+
controlnet_conditioning_scale=None,
|
579 |
):
|
580 |
if height % 8 != 0 or width % 8 != 0:
|
581 |
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
|
|
|
614 |
f" {negative_prompt_embeds.shape}."
|
615 |
)
|
616 |
|
617 |
+
# check controlnet condition image
|
618 |
+
if isinstance(self.controlnet, ControlNetModel):
|
619 |
+
self.check_controlnet_conditioning_image(controlnet_conditioning_image, prompt, prompt_embeds)
|
620 |
+
elif isinstance(self.controlnet, MultiControlNetModel):
|
621 |
+
if not isinstance(controlnet_conditioning_image, list):
|
622 |
+
raise TypeError("For multiple controlnets: `image` must be type `list`")
|
623 |
+
if len(controlnet_conditioning_image) != len(self.controlnet.nets):
|
624 |
+
raise ValueError(
|
625 |
+
"For multiple controlnets: `image` must have the same length as the number of controlnets."
|
626 |
+
)
|
627 |
+
for image_ in controlnet_conditioning_image:
|
628 |
+
self.check_controlnet_conditioning_image(image_, prompt, prompt_embeds)
|
629 |
+
else:
|
630 |
+
assert False
|
631 |
+
|
632 |
+
# Check `controlnet_conditioning_scale`
|
633 |
+
if isinstance(self.controlnet, ControlNetModel):
|
634 |
+
if not isinstance(controlnet_conditioning_scale, float):
|
635 |
+
raise TypeError("For single controlnet: `controlnet_conditioning_scale` must be type `float`.")
|
636 |
+
elif isinstance(self.controlnet, MultiControlNetModel):
|
637 |
+
if isinstance(controlnet_conditioning_scale, list) and len(controlnet_conditioning_scale) != len(
|
638 |
+
self.controlnet.nets
|
639 |
+
):
|
640 |
+
raise ValueError(
|
641 |
+
"For multiple controlnets: When `controlnet_conditioning_scale` is specified as `list`, it must have"
|
642 |
+
" the same length as the number of controlnets"
|
643 |
+
)
|
644 |
+
else:
|
645 |
+
assert False
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
646 |
|
647 |
if isinstance(image, torch.Tensor) and not isinstance(mask_image, torch.Tensor):
|
648 |
raise TypeError("if `image` is a tensor, `mask_image` must also be a tensor")
|
|
|
662 |
image_channels, image_height, image_width = image.shape
|
663 |
elif image.ndim == 4:
|
664 |
image_batch_size, image_channels, image_height, image_width = image.shape
|
665 |
+
else:
|
666 |
+
assert False
|
667 |
|
668 |
if mask_image.ndim == 2:
|
669 |
mask_image_batch_size = 1
|
|
|
698 |
|
699 |
single_image_latent_channels = self.vae.config.latent_channels
|
700 |
|
701 |
+
if self.unet.config.in_channels==4:
|
702 |
+
total_latent_channels = single_image_latent_channels # support base model without inpainting ability.
|
703 |
+
else:
|
704 |
+
total_latent_channels = single_image_latent_channels * 2 + mask_image_channels
|
705 |
+
|
706 |
if total_latent_channels != self.unet.config.in_channels:
|
707 |
raise ValueError(
|
708 |
f"The config of `pipeline.unet` expects {self.unet.config.in_channels} but received"
|
|
|
834 |
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
|
835 |
callback_steps: int = 1,
|
836 |
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
837 |
+
controlnet_conditioning_scale: Union[float, List[float]] = 1.0,
|
838 |
):
|
839 |
r"""
|
840 |
Function invoked when calling the pipeline for generation.
|
|
|
934 |
negative_prompt,
|
935 |
prompt_embeds,
|
936 |
negative_prompt_embeds,
|
937 |
+
controlnet_conditioning_scale,
|
938 |
)
|
939 |
|
940 |
# 2. Define call parameters
|
|
|
951 |
# corresponds to doing no classifier free guidance.
|
952 |
do_classifier_free_guidance = guidance_scale > 1.0
|
953 |
|
954 |
+
if isinstance(self.controlnet, MultiControlNetModel) and isinstance(controlnet_conditioning_scale, float):
|
955 |
+
controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(self.controlnet.nets)
|
956 |
+
|
957 |
# 3. Encode input prompt
|
958 |
prompt_embeds = self._encode_prompt(
|
959 |
prompt,
|
|
|
970 |
|
971 |
mask_image = prepare_mask_image(mask_image)
|
972 |
|
973 |
+
# condition image(s)
|
974 |
+
if isinstance(self.controlnet, ControlNetModel):
|
975 |
+
controlnet_conditioning_image = prepare_controlnet_conditioning_image(
|
976 |
+
controlnet_conditioning_image=controlnet_conditioning_image,
|
977 |
+
width=width,
|
978 |
+
height=height,
|
979 |
+
batch_size=batch_size * num_images_per_prompt,
|
980 |
+
num_images_per_prompt=num_images_per_prompt,
|
981 |
+
device=device,
|
982 |
+
dtype=self.controlnet.dtype,
|
983 |
+
do_classifier_free_guidance=do_classifier_free_guidance,
|
984 |
+
)
|
985 |
+
elif isinstance(self.controlnet, MultiControlNetModel):
|
986 |
+
controlnet_conditioning_images = []
|
987 |
+
|
988 |
+
for image_ in controlnet_conditioning_image:
|
989 |
+
image_ = prepare_controlnet_conditioning_image(
|
990 |
+
controlnet_conditioning_image=image_,
|
991 |
+
width=width,
|
992 |
+
height=height,
|
993 |
+
batch_size=batch_size * num_images_per_prompt,
|
994 |
+
num_images_per_prompt=num_images_per_prompt,
|
995 |
+
device=device,
|
996 |
+
dtype=self.controlnet.dtype,
|
997 |
+
do_classifier_free_guidance=do_classifier_free_guidance,
|
998 |
+
)
|
999 |
+
controlnet_conditioning_images.append(image_)
|
1000 |
+
|
1001 |
+
controlnet_conditioning_image = controlnet_conditioning_images
|
1002 |
+
else:
|
1003 |
+
assert False
|
1004 |
|
1005 |
masked_image = image * (mask_image < 0.5)
|
1006 |
|
|
|
1021 |
latents,
|
1022 |
)
|
1023 |
|
1024 |
+
noise = latents
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1025 |
|
1026 |
+
if self.unet.config.in_channels!=4:
|
1027 |
+
mask_image_latents = self.prepare_mask_latents(
|
1028 |
+
mask_image,
|
1029 |
+
batch_size * num_images_per_prompt,
|
1030 |
+
height,
|
1031 |
+
width,
|
1032 |
+
prompt_embeds.dtype,
|
1033 |
+
device,
|
1034 |
+
do_classifier_free_guidance,
|
1035 |
+
)
|
1036 |
|
1037 |
+
masked_image_latents = self.prepare_masked_image_latents(
|
1038 |
+
masked_image,
|
1039 |
+
batch_size * num_images_per_prompt,
|
1040 |
+
height,
|
1041 |
+
width,
|
1042 |
+
prompt_embeds.dtype,
|
1043 |
+
device,
|
1044 |
+
generator,
|
1045 |
+
do_classifier_free_guidance,
|
1046 |
+
)
|
1047 |
+
if self.unet.config.in_channels==4:
|
1048 |
+
init_masked_image_latents, _ = self.prepare_masked_image_latents(
|
1049 |
+
image,
|
1050 |
+
batch_size * num_images_per_prompt,
|
1051 |
+
height,
|
1052 |
+
width,
|
1053 |
+
prompt_embeds.dtype,
|
1054 |
+
device,
|
1055 |
+
generator,
|
1056 |
+
do_classifier_free_guidance,
|
1057 |
+
).chunk(2)
|
1058 |
+
print(type(mask_image), mask_image.shape)
|
1059 |
+
_, _, w, h = mask_image.shape
|
1060 |
+
mask_image = torch.nn.functional.interpolate(mask_image, ((w // 8, h // 8)), mode='nearest')
|
1061 |
+
mask_image = mask_image.to(latents.device).type_as(latents)
|
1062 |
+
mask_image = 1 - mask_image
|
1063 |
|
1064 |
# 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
|
1065 |
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
|
|
|
1076 |
non_inpainting_latent_model_input = self.scheduler.scale_model_input(
|
1077 |
non_inpainting_latent_model_input, t
|
1078 |
)
|
1079 |
+
if self.unet.config.in_channels!=4:
|
1080 |
+
inpainting_latent_model_input = torch.cat(
|
1081 |
+
[non_inpainting_latent_model_input, mask_image_latents, masked_image_latents], dim=1
|
1082 |
+
)
|
1083 |
+
else:
|
1084 |
+
inpainting_latent_model_input = non_inpainting_latent_model_input
|
1085 |
|
1086 |
down_block_res_samples, mid_block_res_sample = self.controlnet(
|
1087 |
non_inpainting_latent_model_input,
|
1088 |
t,
|
1089 |
encoder_hidden_states=prompt_embeds,
|
1090 |
controlnet_cond=controlnet_conditioning_image,
|
1091 |
+
conditioning_scale=controlnet_conditioning_scale,
|
1092 |
return_dict=False,
|
1093 |
)
|
1094 |
|
|
|
|
|
|
|
|
|
|
|
|
|
1095 |
# predict the noise residual
|
1096 |
noise_pred = self.unet(
|
1097 |
inpainting_latent_model_input,
|
|
|
1115 |
progress_bar.update()
|
1116 |
if callback is not None and i % callback_steps == 0:
|
1117 |
callback(i, t, latents)
|
1118 |
+
# if self.unet.config.in_channels==4:
|
1119 |
+
# # masking for non-inpainting models
|
1120 |
+
# init_latents_proper = self.scheduler.add_noise(init_masked_image_latents, noise, t)
|
1121 |
+
# latents = (init_latents_proper * mask_image) + (latents * (1 - mask_image))
|
1122 |
+
|
1123 |
+
if self.unet.config.in_channels==4:
|
1124 |
+
# fill the unmasked part with original image
|
1125 |
+
latents = (init_masked_image_latents * mask_image) + (latents * (1 - mask_image))
|
1126 |
|
1127 |
# If we do sequential model offloading, let's offload unet and controlnet
|
1128 |
# manually for max memory savings
|