Spaces:
Runtime error
Runtime error
update new versiion
Browse files- app.py +14 -13
- requirements.txt +1 -1
- sam2edit.py +14 -68
- sam2edit_beauty.py +18 -66
- sam2edit_demo.py +140 -0
- sam2edit_handsome.py +17 -67
- sam2edit_lora.py +143 -62
- utils/stable_diffusion_controlnet_inpaint.py +9 -5
app.py
CHANGED
@@ -15,9 +15,7 @@ SHARED_UI_WARNING = f'''### [NOTE] Inference may be slow in this shared UI.
|
|
15 |
You can duplicate and use it with a paid private GPU.
|
16 |
<a class="duplicate-button" style="display:inline-block" target="_blank" href="https://huggingface.co/spaces/jyseo/3DFuse?duplicate=true"><img style="margin-top:0;margin-bottom:0" src="https://huggingface.co/datasets/huggingface/badges/raw/main/duplicate-this-space-xl-dark.svg" alt="Duplicate Space"></a>
|
17 |
'''
|
18 |
-
|
19 |
-
#
|
20 |
-
sam_generator = init_sam_model()
|
21 |
blip_processor = init_blip_processor()
|
22 |
blip_model = init_blip_model()
|
23 |
|
@@ -31,30 +29,33 @@ with gr.Blocks() as demo:
|
|
31 |
controlmodel_name='LAION Pretrained(v0-4)-SD21',
|
32 |
lora_model_path=None, use_blip=True, extra_inpaint=False,
|
33 |
sam_generator=sam_generator,
|
|
|
34 |
blip_processor=blip_processor,
|
35 |
blip_model=blip_model)
|
36 |
-
create_demo_edit_anything(model.process)
|
37 |
with gr.TabItem(' π©βπ¦°Beauty Edit/Generation'):
|
38 |
lora_model_path = hf_hub_download(
|
39 |
"mlida/Cute_girl_mix4", "cuteGirlMix4_v10.safetensors")
|
40 |
model = EditAnythingLoraModel(base_model_path=os.path.join(sd_models_path, "chilloutmix_NiPrunedFp32Fix"),
|
41 |
lora_model_path=lora_model_path, use_blip=True, extra_inpaint=True,
|
42 |
sam_generator=sam_generator,
|
|
|
43 |
blip_processor=blip_processor,
|
44 |
blip_model=blip_model,
|
45 |
lora_weight=0.5,
|
46 |
)
|
47 |
-
create_demo_beauty(model.process)
|
48 |
-
|
49 |
-
|
50 |
-
|
51 |
-
|
52 |
-
|
53 |
-
|
54 |
-
|
|
|
55 |
# with gr.TabItem('Generate Anything'):
|
56 |
# create_demo_generate_anything()
|
57 |
with gr.Tabs():
|
58 |
gr.Markdown(SHARED_UI_WARNING)
|
59 |
|
60 |
-
demo.queue(api_open=False).launch()
|
|
|
15 |
You can duplicate and use it with a paid private GPU.
|
16 |
<a class="duplicate-button" style="display:inline-block" target="_blank" href="https://huggingface.co/spaces/jyseo/3DFuse?duplicate=true"><img style="margin-top:0;margin-bottom:0" src="https://huggingface.co/datasets/huggingface/badges/raw/main/duplicate-this-space-xl-dark.svg" alt="Duplicate Space"></a>
|
17 |
'''
|
18 |
+
sam_generator, mask_predictor = init_sam_model()
|
|
|
|
|
19 |
blip_processor = init_blip_processor()
|
20 |
blip_model = init_blip_model()
|
21 |
|
|
|
29 |
controlmodel_name='LAION Pretrained(v0-4)-SD21',
|
30 |
lora_model_path=None, use_blip=True, extra_inpaint=False,
|
31 |
sam_generator=sam_generator,
|
32 |
+
mask_predictor=mask_predictor,
|
33 |
blip_processor=blip_processor,
|
34 |
blip_model=blip_model)
|
35 |
+
create_demo_edit_anything(model.process, model.process_image_click)
|
36 |
with gr.TabItem(' π©βπ¦°Beauty Edit/Generation'):
|
37 |
lora_model_path = hf_hub_download(
|
38 |
"mlida/Cute_girl_mix4", "cuteGirlMix4_v10.safetensors")
|
39 |
model = EditAnythingLoraModel(base_model_path=os.path.join(sd_models_path, "chilloutmix_NiPrunedFp32Fix"),
|
40 |
lora_model_path=lora_model_path, use_blip=True, extra_inpaint=True,
|
41 |
sam_generator=sam_generator,
|
42 |
+
mask_predictor=mask_predictor,
|
43 |
blip_processor=blip_processor,
|
44 |
blip_model=blip_model,
|
45 |
lora_weight=0.5,
|
46 |
)
|
47 |
+
create_demo_beauty(model.process, model.process_image_click)
|
48 |
+
with gr.TabItem(' π¨βπΎHandsome Edit/Generation'):
|
49 |
+
model = EditAnythingLoraModel(base_model_path=os.path.join(sd_models_path, "Realistic_Vision_V2.0"),
|
50 |
+
lora_model_path=None, use_blip=True, extra_inpaint=True,
|
51 |
+
sam_generator=sam_generator,
|
52 |
+
mask_predictor=mask_predictor,
|
53 |
+
blip_processor=blip_processor,
|
54 |
+
blip_model=blip_model)
|
55 |
+
create_demo_handsome(model.process, model.process_image_click)
|
56 |
# with gr.TabItem('Generate Anything'):
|
57 |
# create_demo_generate_anything()
|
58 |
with gr.Tabs():
|
59 |
gr.Markdown(SHARED_UI_WARNING)
|
60 |
|
61 |
+
demo.queue(api_open=False).launch(server_name='0.0.0.0', share=False)
|
requirements.txt
CHANGED
@@ -3,7 +3,7 @@ torch==1.13.1+cu117
|
|
3 |
torchvision==0.14.1+cu117
|
4 |
torchaudio==0.13.1
|
5 |
numpy==1.23.1
|
6 |
-
gradio==3.
|
7 |
gradio_client==0.1.4
|
8 |
albumentations==1.3.0
|
9 |
opencv-contrib-python==4.3.0.36
|
|
|
3 |
torchvision==0.14.1+cu117
|
4 |
torchaudio==0.13.1
|
5 |
numpy==1.23.1
|
6 |
+
gradio==3.30.0
|
7 |
gradio_client==0.1.4
|
8 |
albumentations==1.3.0
|
9 |
opencv-contrib-python==4.3.0.36
|
sam2edit.py
CHANGED
@@ -1,82 +1,28 @@
|
|
1 |
# Edit Anything trained with Stable Diffusion + ControlNet + SAM + BLIP2
|
|
|
2 |
import gradio as gr
|
3 |
from diffusers.utils import load_image
|
4 |
from sam2edit_lora import EditAnythingLoraModel, config_dict
|
|
|
|
|
5 |
|
6 |
|
7 |
-
def create_demo(process):
|
8 |
|
9 |
-
|
10 |
-
|
11 |
-
|
12 |
-
WARNING_INFO = f'''### [NOTE] the model is collected from the Internet for demo only, please do not use it for commercial purposes.
|
13 |
-
We are not responsible for possible risks using this model.
|
14 |
'''
|
15 |
-
|
16 |
-
|
17 |
-
|
18 |
-
|
19 |
-
"## EditAnything https://github.com/sail-sg/EditAnything ")
|
20 |
-
with gr.Row():
|
21 |
-
with gr.Column():
|
22 |
-
source_image = gr.Image(
|
23 |
-
source='upload', label="Image (Upload an image and cover the region you want to edit with sketch)", type="numpy", tool="sketch")
|
24 |
-
enable_all_generate = gr.Checkbox(
|
25 |
-
label='Auto generation on all region.', value=False)
|
26 |
-
prompt = gr.Textbox(
|
27 |
-
label="Prompt (Text in the expected things of edited region)")
|
28 |
-
enable_auto_prompt = gr.Checkbox(
|
29 |
-
label='Auto generate text prompt from input image with BLIP2: Warning: Enable this may makes your prompt not working.', value=False)
|
30 |
-
a_prompt = gr.Textbox(
|
31 |
-
label="Added Prompt", value='best quality, extremely detailed')
|
32 |
-
n_prompt = gr.Textbox(label="Negative Prompt",
|
33 |
-
value='longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality')
|
34 |
-
control_scale = gr.Slider(
|
35 |
-
label="Mask Align strength (Large value means more strict alignment with SAM mask)", minimum=0, maximum=1, value=1, step=0.1)
|
36 |
-
run_button = gr.Button(label="Run")
|
37 |
-
num_samples = gr.Slider(
|
38 |
-
label="Images", minimum=1, maximum=12, value=2, step=1)
|
39 |
-
seed = gr.Slider(label="Seed", minimum=-1,
|
40 |
-
maximum=2147483647, step=1, randomize=True)
|
41 |
-
enable_tile = gr.Checkbox(
|
42 |
-
label='Tile refinement for high resolution generation.', value=True)
|
43 |
-
with gr.Accordion("Advanced options", open=False):
|
44 |
-
mask_image = gr.Image(
|
45 |
-
source='upload', label="(Optional) Upload a predefined mask of edit region if you do not want to write your prompt.", type="numpy", value=None)
|
46 |
-
image_resolution = gr.Slider(
|
47 |
-
label="Image Resolution", minimum=256, maximum=768, value=512, step=64)
|
48 |
-
strength = gr.Slider(
|
49 |
-
label="Control Strength", minimum=0.0, maximum=2.0, value=1.0, step=0.01)
|
50 |
-
guess_mode = gr.Checkbox(
|
51 |
-
label='Guess Mode', value=False)
|
52 |
-
detect_resolution = gr.Slider(
|
53 |
-
label="SAM Resolution", minimum=128, maximum=2048, value=1024, step=1)
|
54 |
-
ddim_steps = gr.Slider(
|
55 |
-
label="Steps", minimum=1, maximum=100, value=30, step=1)
|
56 |
-
scale = gr.Slider(
|
57 |
-
label="Guidance Scale", minimum=0.1, maximum=30.0, value=9.0, step=0.1)
|
58 |
-
eta = gr.Number(label="eta (DDIM)", value=0.0)
|
59 |
-
with gr.Column():
|
60 |
-
result_gallery = gr.Gallery(
|
61 |
-
label='Output', show_label=False, elem_id="gallery").style(grid=2, height='auto')
|
62 |
-
result_text = gr.Text(label='BLIP2+Human Prompt Text')
|
63 |
-
ips = [source_image, enable_all_generate, mask_image, control_scale, enable_auto_prompt, prompt, a_prompt, n_prompt, num_samples, image_resolution,
|
64 |
-
detect_resolution, ddim_steps, guess_mode, strength, scale, seed, eta, enable_tile]
|
65 |
-
run_button.click(fn=process, inputs=ips, outputs=[
|
66 |
-
result_gallery, result_text])
|
67 |
-
# with gr.Row():
|
68 |
-
# ex = gr.Examples(examples=examples, fn=process,
|
69 |
-
# inputs=[a_prompt, n_prompt, scale],
|
70 |
-
# outputs=[result_gallery],
|
71 |
-
# cache_examples=False)
|
72 |
-
with gr.Row():
|
73 |
-
gr.Markdown(WARNING_INFO)
|
74 |
return demo
|
75 |
|
76 |
|
77 |
if __name__ == '__main__':
|
78 |
-
model = EditAnythingLoraModel(base_model_path="stabilityai/stable-diffusion-2
|
79 |
-
controlmodel_name='LAION Pretrained(v0-4)-SD21', extra_inpaint=
|
80 |
lora_model_path=None, use_blip=True)
|
81 |
-
demo = create_demo(model.process)
|
82 |
demo.queue().launch(server_name='0.0.0.0')
|
|
|
1 |
# Edit Anything trained with Stable Diffusion + ControlNet + SAM + BLIP2
|
2 |
+
import os
|
3 |
import gradio as gr
|
4 |
from diffusers.utils import load_image
|
5 |
from sam2edit_lora import EditAnythingLoraModel, config_dict
|
6 |
+
from sam2edit_demo import create_demo_template
|
7 |
+
from huggingface_hub import hf_hub_download, snapshot_download
|
8 |
|
9 |
|
10 |
+
def create_demo(process, process_image_click=None):
|
11 |
|
12 |
+
examples = None
|
13 |
+
INFO = f'''
|
14 |
+
## EditAnything https://github.com/sail-sg/EditAnything
|
|
|
|
|
15 |
'''
|
16 |
+
WARNING_INFO = None
|
17 |
+
|
18 |
+
demo = create_demo_template(process, process_image_click, examples=examples,
|
19 |
+
INFO=INFO, WARNING_INFO=WARNING_INFO, enable_auto_prompt_default=True)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
20 |
return demo
|
21 |
|
22 |
|
23 |
if __name__ == '__main__':
|
24 |
+
model = EditAnythingLoraModel(base_model_path="stabilityai/stable-diffusion-2",
|
25 |
+
controlmodel_name='LAION Pretrained(v0-4)-SD21', extra_inpaint=True,
|
26 |
lora_model_path=None, use_blip=True)
|
27 |
+
demo = create_demo(model.process, model.process_image_click)
|
28 |
demo.queue().launch(server_name='0.0.0.0')
|
sam2edit_beauty.py
CHANGED
@@ -1,10 +1,13 @@
|
|
1 |
# Edit Anything trained with Stable Diffusion + ControlNet + SAM + BLIP2
|
|
|
2 |
import gradio as gr
|
3 |
from diffusers.utils import load_image
|
4 |
from sam2edit_lora import EditAnythingLoraModel, config_dict
|
|
|
|
|
5 |
|
6 |
|
7 |
-
def create_demo(process):
|
8 |
|
9 |
examples = [
|
10 |
["dudou,1girl, beautiful face, solo, candle, brown hair, long hair, <lora:flowergirl:0.9>,ulzzang-6500-v1.1,(raw photo:1.2),((photorealistic:1.4))best quality ,masterpiece, illustration, an extremely delicate and beautiful, extremely detailed ,CG ,unity ,8k wallpaper, Amazing, finely detail, masterpiece,best quality,official art,extremely detailed CG unity 8k wallpaper,absurdres, incredibly absurdres, huge filesize, ultra-detailed, highres, extremely detailed,beautiful detailed girl, extremely detailed eyes and face, beautiful detailed eyes,cinematic lighting,1girl,see-through,looking at viewer,full body,full-body shot,outdoors,arms behind back,(chinese clothes) <lora:cuteGirlMix4_v10:1>",
|
@@ -16,77 +19,26 @@ def create_demo(process):
|
|
16 |
["mix4, whole body shot, ((8k, RAW photo, highest quality, masterpiece), High detail RAW color photo professional close-up photo, shy expression, cute, beautiful detailed girl, detailed fingers, extremely detailed eyes and face, beautiful detailed nose, beautiful detailed eyes, long eyelashes, light on face, looking at viewer, (closed mouth:1.2), 1girl, cute, young, mature face, (full body:1.3), ((small breasts)), realistic face, realistic body, beautiful detailed thigh,s, same eyes color, (realistic, photo realism:1. 37), (highest quality), (best shadow), (best illustration), ultra high resolution, physics-based rendering, cinematic lighting), solo, 1girl, highly detailed, in office, detailed office, open cardigan, ponytail contorted, beautiful eyes ,sitting in office,dating, business suit, cross-laced clothes, collared shirt, beautiful breast, small breast, Chinese dress, white pantyhose, natural breasts, pink and white hair, <lora:cuteGirlMix4_v10:1>",
|
17 |
"paintings, sketches, (worst quality:2), (low quality:2), (normal quality:2), cloth, underwear, bra, low-res, normal quality, ((monochrome)), ((grayscale)), skin spots, acne, skin blemishes, age spots, glans, bad nipples, long nipples, bad vagina, extra fingers,fewer fingers,strange fingers,bad hand, ng_deepnegative_v1_75t, bad-picture-chill-75v", 7]
|
18 |
]
|
19 |
-
|
20 |
-
|
|
|
|
|
21 |
WARNING_INFO = f'''### [NOTE] the model is collected from the Internet for demo only, please do not use it for commercial purposes.
|
22 |
We are not responsible for possible risks using this model.
|
23 |
-
|
24 |
Lora model from https://civitai.com/models/14171/cutegirlmix4 Thanks!
|
25 |
'''
|
26 |
-
|
27 |
-
|
28 |
-
with gr.Row():
|
29 |
-
gr.Markdown(
|
30 |
-
"## Generate Your Beauty powered by EditAnything https://github.com/sail-sg/EditAnything ")
|
31 |
-
with gr.Row():
|
32 |
-
with gr.Column():
|
33 |
-
source_image = gr.Image(
|
34 |
-
source='upload', label="Image (Upload an image and cover the region you want to edit with sketch)", type="numpy", tool="sketch")
|
35 |
-
enable_all_generate = gr.Checkbox(
|
36 |
-
label='Auto generation on all region.', value=False)
|
37 |
-
prompt = gr.Textbox(
|
38 |
-
label="Prompt (Text in the expected things of edited region)")
|
39 |
-
enable_auto_prompt = gr.Checkbox(
|
40 |
-
label='Auto generate text prompt from input image with BLIP2: Warning: Enable this may makes your prompt not working.', value=False)
|
41 |
-
a_prompt = gr.Textbox(
|
42 |
-
label="Added Prompt", value='best quality, extremely detailed')
|
43 |
-
n_prompt = gr.Textbox(label="Negative Prompt",
|
44 |
-
value='longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality')
|
45 |
-
control_scale = gr.Slider(
|
46 |
-
label="Mask Align strength (Large value means more strict alignment with SAM mask)", minimum=0, maximum=1, value=1, step=0.1)
|
47 |
-
run_button = gr.Button(label="Run")
|
48 |
-
num_samples = gr.Slider(
|
49 |
-
label="Images", minimum=1, maximum=12, value=2, step=1)
|
50 |
-
seed = gr.Slider(label="Seed", minimum=-1,
|
51 |
-
maximum=2147483647, step=1, randomize=True)
|
52 |
-
enable_tile = gr.Checkbox(
|
53 |
-
label='Tile refinement for high resolution generation.', value=True)
|
54 |
-
with gr.Accordion("Advanced options", open=False):
|
55 |
-
mask_image = gr.Image(
|
56 |
-
source='upload', label="(Optional) Upload a predefined mask of edit region if you do not want to write your prompt.", type="numpy", value=None)
|
57 |
-
image_resolution = gr.Slider(
|
58 |
-
label="Image Resolution", minimum=256, maximum=768, value=512, step=64)
|
59 |
-
strength = gr.Slider(
|
60 |
-
label="Control Strength", minimum=0.0, maximum=2.0, value=1.0, step=0.01)
|
61 |
-
guess_mode = gr.Checkbox(
|
62 |
-
label='Guess Mode', value=False)
|
63 |
-
detect_resolution = gr.Slider(
|
64 |
-
label="SAM Resolution", minimum=128, maximum=2048, value=1024, step=1)
|
65 |
-
ddim_steps = gr.Slider(
|
66 |
-
label="Steps", minimum=1, maximum=100, value=30, step=1)
|
67 |
-
scale = gr.Slider(
|
68 |
-
label="Guidance Scale", minimum=0.1, maximum=30.0, value=9.0, step=0.1)
|
69 |
-
eta = gr.Number(label="eta (DDIM)", value=0.0)
|
70 |
-
with gr.Column():
|
71 |
-
result_gallery = gr.Gallery(
|
72 |
-
label='Output', show_label=False, elem_id="gallery").style(grid=2, height='auto')
|
73 |
-
result_text = gr.Text(label='BLIP2+Human Prompt Text')
|
74 |
-
ips = [source_image, enable_all_generate, mask_image, control_scale, enable_auto_prompt, prompt, a_prompt, n_prompt, num_samples, image_resolution,
|
75 |
-
detect_resolution, ddim_steps, guess_mode, strength, scale, seed, eta, enable_tile]
|
76 |
-
run_button.click(fn=process, inputs=ips, outputs=[
|
77 |
-
result_gallery, result_text])
|
78 |
-
with gr.Row():
|
79 |
-
ex = gr.Examples(examples=examples, fn=process,
|
80 |
-
inputs=[a_prompt, n_prompt, scale],
|
81 |
-
outputs=[result_gallery],
|
82 |
-
cache_examples=False)
|
83 |
-
with gr.Row():
|
84 |
-
gr.Markdown(WARNING_INFO)
|
85 |
return demo
|
86 |
|
87 |
|
88 |
if __name__ == '__main__':
|
89 |
-
|
90 |
-
|
91 |
-
|
|
|
|
|
|
|
|
|
|
|
92 |
demo.queue().launch(server_name='0.0.0.0')
|
|
|
1 |
# Edit Anything trained with Stable Diffusion + ControlNet + SAM + BLIP2
|
2 |
+
import os
|
3 |
import gradio as gr
|
4 |
from diffusers.utils import load_image
|
5 |
from sam2edit_lora import EditAnythingLoraModel, config_dict
|
6 |
+
from sam2edit_demo import create_demo_template
|
7 |
+
from huggingface_hub import hf_hub_download, snapshot_download
|
8 |
|
9 |
|
10 |
+
def create_demo(process, process_image_click=None):
|
11 |
|
12 |
examples = [
|
13 |
["dudou,1girl, beautiful face, solo, candle, brown hair, long hair, <lora:flowergirl:0.9>,ulzzang-6500-v1.1,(raw photo:1.2),((photorealistic:1.4))best quality ,masterpiece, illustration, an extremely delicate and beautiful, extremely detailed ,CG ,unity ,8k wallpaper, Amazing, finely detail, masterpiece,best quality,official art,extremely detailed CG unity 8k wallpaper,absurdres, incredibly absurdres, huge filesize, ultra-detailed, highres, extremely detailed,beautiful detailed girl, extremely detailed eyes and face, beautiful detailed eyes,cinematic lighting,1girl,see-through,looking at viewer,full body,full-body shot,outdoors,arms behind back,(chinese clothes) <lora:cuteGirlMix4_v10:1>",
|
|
|
19 |
["mix4, whole body shot, ((8k, RAW photo, highest quality, masterpiece), High detail RAW color photo professional close-up photo, shy expression, cute, beautiful detailed girl, detailed fingers, extremely detailed eyes and face, beautiful detailed nose, beautiful detailed eyes, long eyelashes, light on face, looking at viewer, (closed mouth:1.2), 1girl, cute, young, mature face, (full body:1.3), ((small breasts)), realistic face, realistic body, beautiful detailed thigh,s, same eyes color, (realistic, photo realism:1. 37), (highest quality), (best shadow), (best illustration), ultra high resolution, physics-based rendering, cinematic lighting), solo, 1girl, highly detailed, in office, detailed office, open cardigan, ponytail contorted, beautiful eyes ,sitting in office,dating, business suit, cross-laced clothes, collared shirt, beautiful breast, small breast, Chinese dress, white pantyhose, natural breasts, pink and white hair, <lora:cuteGirlMix4_v10:1>",
|
20 |
"paintings, sketches, (worst quality:2), (low quality:2), (normal quality:2), cloth, underwear, bra, low-res, normal quality, ((monochrome)), ((grayscale)), skin spots, acne, skin blemishes, age spots, glans, bad nipples, long nipples, bad vagina, extra fingers,fewer fingers,strange fingers,bad hand, ng_deepnegative_v1_75t, bad-picture-chill-75v", 7]
|
21 |
]
|
22 |
+
INFO = f'''
|
23 |
+
## Generate Your Beauty powered by EditAnything https://github.com/sail-sg/EditAnything
|
24 |
+
This model is good at generating beautiful female.
|
25 |
+
'''
|
26 |
WARNING_INFO = f'''### [NOTE] the model is collected from the Internet for demo only, please do not use it for commercial purposes.
|
27 |
We are not responsible for possible risks using this model.
|
|
|
28 |
Lora model from https://civitai.com/models/14171/cutegirlmix4 Thanks!
|
29 |
'''
|
30 |
+
demo = create_demo_template(process, process_image_click,
|
31 |
+
examples=examples, INFO=INFO, WARNING_INFO=WARNING_INFO)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
32 |
return demo
|
33 |
|
34 |
|
35 |
if __name__ == '__main__':
|
36 |
+
sd_models_path = snapshot_download("shgao/sdmodels")
|
37 |
+
lora_model_path = hf_hub_download(
|
38 |
+
"mlida/Cute_girl_mix4", "cuteGirlMix4_v10.safetensors")
|
39 |
+
model = EditAnythingLoraModel(base_model_path=os.path.join(sd_models_path, "chilloutmix_NiPrunedFp32Fix"),
|
40 |
+
lora_model_path=lora_model_path, use_blip=True, extra_inpaint=True,
|
41 |
+
lora_weight=0.5,
|
42 |
+
)
|
43 |
+
demo = create_demo(model.process, model.process_image_click)
|
44 |
demo.queue().launch(server_name='0.0.0.0')
|
sam2edit_demo.py
ADDED
@@ -0,0 +1,140 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Edit Anything trained with Stable Diffusion + ControlNet + SAM + BLIP2
|
2 |
+
import gradio as gr
|
3 |
+
|
4 |
+
def create_demo_template(process, process_image_click=None, examples=None,
|
5 |
+
INFO='EditAnything https://github.com/sail-sg/EditAnything', WARNING_INFO=None,
|
6 |
+
enable_auto_prompt_default=False,
|
7 |
+
):
|
8 |
+
|
9 |
+
print("The GUI is not fully tested yet. Please open an issue if you find bugs.")
|
10 |
+
block = gr.Blocks()
|
11 |
+
with block as demo:
|
12 |
+
clicked_points = gr.State([])
|
13 |
+
origin_image = gr.State(None)
|
14 |
+
click_mask = gr.State(None)
|
15 |
+
with gr.Row():
|
16 |
+
gr.Markdown(INFO)
|
17 |
+
with gr.Row().style(equal_height=False):
|
18 |
+
with gr.Column():
|
19 |
+
with gr.Tab("Clickπ±"):
|
20 |
+
source_image_click = gr.Image(
|
21 |
+
type="pil", interactive=True,
|
22 |
+
label="Image: Upload an image and click the region you want to edit.",
|
23 |
+
)
|
24 |
+
with gr.Column():
|
25 |
+
with gr.Row():
|
26 |
+
point_prompt = gr.Radio(
|
27 |
+
choices=["Foreground Point", "Background Point"],
|
28 |
+
value="Foreground Point",
|
29 |
+
label="Point Label",
|
30 |
+
interactive=True, show_label=False)
|
31 |
+
clear_button_click = gr.Button(
|
32 |
+
value="Clear Click Points", interactive=True)
|
33 |
+
clear_button_image = gr.Button(
|
34 |
+
value="Clear Image", interactive=True)
|
35 |
+
with gr.Row():
|
36 |
+
run_button_click = gr.Button(
|
37 |
+
label="Run EditAnying", interactive=True)
|
38 |
+
with gr.Tab("BrushποΈ"):
|
39 |
+
source_image_brush = gr.Image(
|
40 |
+
source='upload',
|
41 |
+
label="Image: Upload an image and cover the region you want to edit with sketch",
|
42 |
+
type="numpy", tool="sketch"
|
43 |
+
)
|
44 |
+
run_button = gr.Button(label="Run EditAnying", interactive=True)
|
45 |
+
with gr.Column():
|
46 |
+
enable_all_generate = gr.Checkbox(
|
47 |
+
label='Auto generation on all region.', value=False)
|
48 |
+
control_scale = gr.Slider(
|
49 |
+
label="Mask Align strength", info="Large value -> strict alignment with SAM mask", minimum=0, maximum=1, value=1, step=0.1)
|
50 |
+
with gr.Column():
|
51 |
+
enable_auto_prompt = gr.Checkbox(
|
52 |
+
label='Auto generate text prompt from input image with BLIP2', info='Warning: Enable this may makes your prompt not working.', value=enable_auto_prompt_default)
|
53 |
+
a_prompt = gr.Textbox(
|
54 |
+
label="Positive Prompt", info='Text in the expected things of edited region', value='best quality, extremely detailed')
|
55 |
+
n_prompt = gr.Textbox(label="Negative Prompt",
|
56 |
+
value='longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality, NSFW')
|
57 |
+
with gr.Row():
|
58 |
+
num_samples = gr.Slider(
|
59 |
+
label="Images", minimum=1, maximum=12, value=2, step=1)
|
60 |
+
seed = gr.Slider(label="Seed", minimum=-1,
|
61 |
+
maximum=2147483647, step=1, randomize=True)
|
62 |
+
with gr.Row():
|
63 |
+
enable_tile = gr.Checkbox(
|
64 |
+
label='Tile refinement for high resolution generation', info='Slow inference', value=True)
|
65 |
+
refine_alignment_ratio = gr.Slider(
|
66 |
+
label="Alignment Strength", info='Large value -> strict alignment with input image. Small value -> strong global consistency', minimum=0.0, maximum=1.0, value=0.95, step=0.05)
|
67 |
+
|
68 |
+
with gr.Accordion("Advanced options", open=False):
|
69 |
+
mask_image = gr.Image(
|
70 |
+
source='upload', label="Upload a predefined mask of edit region if you do not want to write your prompt.", info="(Optional:Switch to Brush mode when using this!) ", type="numpy", value=None)
|
71 |
+
image_resolution = gr.Slider(
|
72 |
+
label="Image Resolution", minimum=256, maximum=768, value=512, step=64)
|
73 |
+
refine_image_resolution = gr.Slider(
|
74 |
+
label="Image Resolution", minimum=256, maximum=8192, value=1024, step=64)
|
75 |
+
strength = gr.Slider(
|
76 |
+
label="Control Strength", minimum=0.0, maximum=2.0, value=1.0, step=0.01)
|
77 |
+
guess_mode = gr.Checkbox(
|
78 |
+
label='Guess Mode', value=False)
|
79 |
+
detect_resolution = gr.Slider(
|
80 |
+
label="SAM Resolution", minimum=128, maximum=2048, value=1024, step=1)
|
81 |
+
ddim_steps = gr.Slider(
|
82 |
+
label="Steps", minimum=1, maximum=100, value=30, step=1)
|
83 |
+
scale = gr.Slider(
|
84 |
+
label="Guidance Scale", minimum=0.1, maximum=30.0, value=9.0, step=0.1)
|
85 |
+
eta = gr.Number(label="eta (DDIM)", value=0.0)
|
86 |
+
with gr.Column():
|
87 |
+
result_gallery_refine = gr.Gallery(
|
88 |
+
label='Output High quality', show_label=True, elem_id="gallery").style(grid=2, preview=False)
|
89 |
+
result_gallery_init = gr.Gallery(
|
90 |
+
label='Output Low quality', show_label=True, elem_id="gallery").style(grid=2, height='auto')
|
91 |
+
result_gallery_ref = gr.Gallery(
|
92 |
+
label='Output Ref', show_label=False, elem_id="gallery").style(grid=2, height='auto')
|
93 |
+
result_text = gr.Text(label='BLIP2+Human Prompt Text')
|
94 |
+
|
95 |
+
ips = [source_image_brush, enable_all_generate, mask_image, control_scale, enable_auto_prompt, a_prompt, n_prompt, num_samples, image_resolution,
|
96 |
+
detect_resolution, ddim_steps, guess_mode, strength, scale, seed, eta, enable_tile, refine_alignment_ratio, refine_image_resolution]
|
97 |
+
run_button.click(fn=process, inputs=ips, outputs=[
|
98 |
+
result_gallery_refine, result_gallery_init, result_gallery_ref, result_text])
|
99 |
+
|
100 |
+
ip_click = [origin_image, enable_all_generate, click_mask, control_scale, enable_auto_prompt, a_prompt, n_prompt, num_samples, image_resolution,
|
101 |
+
detect_resolution, ddim_steps, guess_mode, strength, scale, seed, eta, enable_tile, refine_alignment_ratio, refine_image_resolution]
|
102 |
+
|
103 |
+
run_button_click.click(fn=process,
|
104 |
+
inputs=ip_click,
|
105 |
+
outputs=[result_gallery_refine, result_gallery_init, result_gallery_ref, result_text])
|
106 |
+
|
107 |
+
source_image_click.upload(
|
108 |
+
lambda image: image.copy() if image is not None else None,
|
109 |
+
inputs=[source_image_click],
|
110 |
+
outputs=[origin_image]
|
111 |
+
)
|
112 |
+
source_image_click.select(
|
113 |
+
process_image_click,
|
114 |
+
inputs=[origin_image, point_prompt,
|
115 |
+
clicked_points, image_resolution],
|
116 |
+
outputs=[source_image_click, clicked_points, click_mask],
|
117 |
+
show_progress=True, queue=True
|
118 |
+
)
|
119 |
+
clear_button_click.click(
|
120 |
+
fn=lambda original_image: (original_image.copy(), [], None)
|
121 |
+
if original_image is not None else (None, [], None),
|
122 |
+
inputs=[origin_image],
|
123 |
+
outputs=[source_image_click, clicked_points, click_mask]
|
124 |
+
)
|
125 |
+
clear_button_image.click(
|
126 |
+
fn=lambda: (None, [], None, None, None),
|
127 |
+
inputs=[],
|
128 |
+
outputs=[source_image_click, clicked_points,
|
129 |
+
click_mask, result_gallery_init, result_text]
|
130 |
+
)
|
131 |
+
if examples is not None:
|
132 |
+
with gr.Row():
|
133 |
+
ex = gr.Examples(examples=examples, fn=process,
|
134 |
+
inputs=[a_prompt, n_prompt, scale],
|
135 |
+
outputs=[result_gallery_init],
|
136 |
+
cache_examples=False)
|
137 |
+
if WARNING_INFO is not None:
|
138 |
+
with gr.Row():
|
139 |
+
gr.Markdown(WARNING_INFO)
|
140 |
+
return demo
|
sam2edit_handsome.py
CHANGED
@@ -1,87 +1,37 @@
|
|
1 |
# Edit Anything trained with Stable Diffusion + ControlNet + SAM + BLIP2
|
|
|
2 |
import gradio as gr
|
3 |
from diffusers.utils import load_image
|
4 |
from sam2edit_lora import EditAnythingLoraModel, config_dict
|
|
|
|
|
5 |
|
6 |
|
7 |
-
|
8 |
-
def create_demo(process):
|
9 |
|
10 |
examples = [
|
11 |
-
["1man, muscle,full body, vest, short straight hair, glasses, Gym, barbells, dumbbells, treadmills, boxing rings, squat racks, plates, dumbbell racks soft lighting, masterpiece, best quality, 8k uhd, film grain, Fujifilm XT3 photorealistic painting art by midjourney and greg rutkowski <lora:asianmale_v10:0.6>",
|
12 |
-
|
|
|
|
|
13 |
]
|
14 |
|
15 |
print("The GUI is not fully tested yet. Please open an issue if you find bugs.")
|
|
|
|
|
|
|
|
|
|
|
16 |
WARNING_INFO = f'''### [NOTE] the model is collected from the Internet for demo only, please do not use it for commercial purposes.
|
17 |
We are not responsible for possible risks using this model.
|
18 |
Base model from https://huggingface.co/SG161222/Realistic_Vision_V2.0 Thanks!
|
19 |
'''
|
20 |
-
|
21 |
-
with block as demo:
|
22 |
-
with gr.Row():
|
23 |
-
gr.Markdown(
|
24 |
-
"## Generate Your Handsome powered by EditAnything https://github.com/sail-sg/EditAnything ")
|
25 |
-
with gr.Row():
|
26 |
-
with gr.Column():
|
27 |
-
source_image = gr.Image(
|
28 |
-
source='upload', label="Image (Upload an image and cover the region you want to edit with sketch)", type="numpy", tool="sketch")
|
29 |
-
enable_all_generate = gr.Checkbox(
|
30 |
-
label='Auto generation on all region.', value=False)
|
31 |
-
prompt = gr.Textbox(
|
32 |
-
label="Prompt (Text in the expected things of edited region)")
|
33 |
-
enable_auto_prompt = gr.Checkbox(
|
34 |
-
label='Auto generate text prompt from input image with BLIP2: Warning: Enable this may makes your prompt not working.', value=False)
|
35 |
-
a_prompt = gr.Textbox(
|
36 |
-
label="Added Prompt", value='best quality, extremely detailed')
|
37 |
-
n_prompt = gr.Textbox(label="Negative Prompt",
|
38 |
-
value='longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality')
|
39 |
-
control_scale = gr.Slider(
|
40 |
-
label="Mask Align strength (Large value means more strict alignment with SAM mask)", minimum=0, maximum=1, value=1, step=0.1)
|
41 |
-
run_button = gr.Button(label="Run")
|
42 |
-
num_samples = gr.Slider(
|
43 |
-
label="Images", minimum=1, maximum=12, value=2, step=1)
|
44 |
-
seed = gr.Slider(label="Seed", minimum=-1,
|
45 |
-
maximum=2147483647, step=1, randomize=True)
|
46 |
-
enable_tile = gr.Checkbox(
|
47 |
-
label='Tile refinement for high resolution generation.', value=True)
|
48 |
-
with gr.Accordion("Advanced options", open=False):
|
49 |
-
mask_image = gr.Image(
|
50 |
-
source='upload', label="(Optional) Upload a predefined mask of edit region if you do not want to write your prompt.", type="numpy", value=None)
|
51 |
-
image_resolution = gr.Slider(
|
52 |
-
label="Image Resolution", minimum=256, maximum=768, value=512, step=64)
|
53 |
-
strength = gr.Slider(
|
54 |
-
label="Control Strength", minimum=0.0, maximum=2.0, value=1.0, step=0.01)
|
55 |
-
guess_mode = gr.Checkbox(
|
56 |
-
label='Guess Mode', value=False)
|
57 |
-
detect_resolution = gr.Slider(
|
58 |
-
label="SAM Resolution", minimum=128, maximum=2048, value=1024, step=1)
|
59 |
-
ddim_steps = gr.Slider(
|
60 |
-
label="Steps", minimum=1, maximum=100, value=30, step=1)
|
61 |
-
scale = gr.Slider(
|
62 |
-
label="Guidance Scale", minimum=0.1, maximum=30.0, value=9.0, step=0.1)
|
63 |
-
eta = gr.Number(label="eta (DDIM)", value=0.0)
|
64 |
-
with gr.Column():
|
65 |
-
result_gallery = gr.Gallery(
|
66 |
-
label='Output', show_label=False, elem_id="gallery").style(grid=2, height='auto')
|
67 |
-
result_text = gr.Text(label='BLIP2+Human Prompt Text')
|
68 |
-
ips = [source_image, enable_all_generate, mask_image, control_scale, enable_auto_prompt, prompt, a_prompt, n_prompt, num_samples, image_resolution,
|
69 |
-
detect_resolution, ddim_steps, guess_mode, strength, scale, seed, eta, enable_tile]
|
70 |
-
run_button.click(fn=process, inputs=ips, outputs=[
|
71 |
-
result_gallery, result_text])
|
72 |
-
with gr.Row():
|
73 |
-
ex = gr.Examples(examples=examples, fn=process,
|
74 |
-
inputs=[a_prompt, n_prompt, scale],
|
75 |
-
outputs=[result_gallery],
|
76 |
-
cache_examples=False)
|
77 |
-
with gr.Row():
|
78 |
-
gr.Markdown(WARNING_INFO)
|
79 |
return demo
|
80 |
|
81 |
|
82 |
-
|
83 |
if __name__ == '__main__':
|
84 |
-
model = EditAnythingLoraModel(base_model_path=
|
85 |
-
|
86 |
-
demo = create_demo(model.process)
|
87 |
demo.queue().launch(server_name='0.0.0.0')
|
|
|
1 |
# Edit Anything trained with Stable Diffusion + ControlNet + SAM + BLIP2
|
2 |
+
import os
|
3 |
import gradio as gr
|
4 |
from diffusers.utils import load_image
|
5 |
from sam2edit_lora import EditAnythingLoraModel, config_dict
|
6 |
+
from sam2edit_demo import create_demo_template
|
7 |
+
from huggingface_hub import hf_hub_download, snapshot_download
|
8 |
|
9 |
|
10 |
+
def create_demo(process, process_image_click=None):
|
|
|
11 |
|
12 |
examples = [
|
13 |
+
["1man, muscle,full body, vest, short straight hair, glasses, Gym, barbells, dumbbells, treadmills, boxing rings, squat racks, plates, dumbbell racks soft lighting, masterpiece, best quality, 8k uhd, film grain, Fujifilm XT3 photorealistic painting art by midjourney and greg rutkowski <lora:asianmale_v10:0.6>",
|
14 |
+
"(deformed iris, deformed pupils, semi-realistic, cgi, 3d, render, sketch, cartoon, drawing, anime:1.4), text, close up, cropped, out of frame, worst quality, low quality, jpeg artifacts, ugly, duplicate, morbid, mutilated, extra fingers, mutated hands, poorly drawn hands, poorly drawn face, mutation, deformed, blurry, dehydrated, bad anatomy, bad proportions, extra limbs, cloned face, disfigured, gross proportions, malformed limbs, missing arms, missing legs, extra arms, extra legs, fused fingers, too many fingers, long neck", 6],
|
15 |
+
["1man, 25 years- old, full body, wearing long-sleeve white shirt and tie, muscular rand black suit, soft lighting, masterpiece, best quality, 8k uhd, dslr, film grain, Fujifilm XT3 photorealistic painting art by midjourney and greg rutkowski <lora:asianmale_v10:0.6> <lora:uncutPenisLora_v10:0.6>",
|
16 |
+
"(deformed iris, deformed pupils, semi-realistic, cgi, 3d, render, sketch, cartoon, drawing, anime:1.4), text, close up, cropped, out of frame, worst quality, low quality, jpeg artifacts, ugly, duplicate, morbid, mutilated, extra fingers, mutated hands, poorly drawn hands, poorly drawn face, mutation, deformed, blurry, dehydrated, bad anatomy, bad proportions, extra limbs, cloned face, disfigured, gross proportions, malformed limbs, missing arms, missing legs, extra arms, extra legs, fused fingers, too many fingers, long neck", 6],
|
17 |
]
|
18 |
|
19 |
print("The GUI is not fully tested yet. Please open an issue if you find bugs.")
|
20 |
+
|
21 |
+
INFO = f'''
|
22 |
+
## Generate Your Handsome powered by EditAnything https://github.com/sail-sg/EditAnything
|
23 |
+
This model is good at generating handsome male.
|
24 |
+
'''
|
25 |
WARNING_INFO = f'''### [NOTE] the model is collected from the Internet for demo only, please do not use it for commercial purposes.
|
26 |
We are not responsible for possible risks using this model.
|
27 |
Base model from https://huggingface.co/SG161222/Realistic_Vision_V2.0 Thanks!
|
28 |
'''
|
29 |
+
demo = create_demo_template(process, process_image_click, examples=examples, INFO=INFO, WARNING_INFO=WARNING_INFO)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
30 |
return demo
|
31 |
|
32 |
|
|
|
33 |
if __name__ == '__main__':
|
34 |
+
model = EditAnythingLoraModel(base_model_path='Realistic_Vision_V2.0',
|
35 |
+
lora_model_path=None, use_blip=True)
|
36 |
+
demo = create_demo(model.process, model.process_image_click)
|
37 |
demo.queue().launch(server_name='0.0.0.0')
|
sam2edit_lora.py
CHANGED
@@ -14,7 +14,7 @@ import random
|
|
14 |
import os
|
15 |
import requests
|
16 |
from io import BytesIO
|
17 |
-
from annotator.util import resize_image, HWC3
|
18 |
|
19 |
import torch
|
20 |
from safetensors.torch import load_file
|
@@ -22,7 +22,6 @@ from collections import defaultdict
|
|
22 |
from diffusers import StableDiffusionControlNetPipeline
|
23 |
from diffusers import ControlNetModel, UniPCMultistepScheduler
|
24 |
from utils.stable_diffusion_controlnet_inpaint import StableDiffusionControlNetInpaintPipeline
|
25 |
-
# from utils.tmp import StableDiffusionControlNetInpaintPipeline
|
26 |
# need the latest transformers
|
27 |
# pip install git+https://github.com/huggingface/transformers.git
|
28 |
from transformers import AutoProcessor, Blip2ForConditionalGeneration
|
@@ -32,13 +31,13 @@ import PIL.Image
|
|
32 |
# Segment-Anything init.
|
33 |
# pip install git+https://github.com/facebookresearch/segment-anything.git
|
34 |
try:
|
35 |
-
from segment_anything import sam_model_registry, SamAutomaticMaskGenerator
|
36 |
except ImportError:
|
37 |
print('segment_anything not installed')
|
38 |
result = subprocess.run(
|
39 |
['pip', 'install', 'git+https://github.com/facebookresearch/segment-anything.git'], check=True)
|
40 |
print(f'Install segment_anything {result}')
|
41 |
-
from segment_anything import sam_model_registry, SamAutomaticMaskGenerator
|
42 |
if not os.path.exists('./models/sam_vit_h_4b8939.pth'):
|
43 |
result = subprocess.run(
|
44 |
['wget', 'https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth', '-P', 'models'], check=True)
|
@@ -52,13 +51,18 @@ config_dict = OrderedDict([
|
|
52 |
])
|
53 |
|
54 |
|
55 |
-
def init_sam_model():
|
|
|
|
|
56 |
sam_checkpoint = "models/sam_vit_h_4b8939.pth"
|
57 |
model_type = "default"
|
58 |
sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
|
59 |
sam.to(device=device)
|
60 |
-
sam_generator = SamAutomaticMaskGenerator(
|
61 |
-
|
|
|
|
|
|
|
62 |
|
63 |
|
64 |
def init_blip_processor():
|
@@ -112,7 +116,6 @@ def get_pipeline_embeds(pipeline, prompt, negative_prompt, device):
|
|
112 |
return torch.cat(concat_embeds, dim=1), torch.cat(neg_embeds, dim=1)
|
113 |
|
114 |
|
115 |
-
|
116 |
def load_lora_weights(pipeline, checkpoint_path, multiplier, device, dtype):
|
117 |
LORA_PREFIX_UNET = "lora_unet"
|
118 |
LORA_PREFIX_TEXT_ENCODER = "lora_te"
|
@@ -241,10 +244,12 @@ def make_inpaint_condition(image, image_mask):
|
|
241 |
image = torch.from_numpy(image)
|
242 |
return image
|
243 |
|
|
|
244 |
def obtain_generation_model(base_model_path, lora_model_path, controlnet_path, generation_only=False, extra_inpaint=True, lora_weight=1.0):
|
245 |
controlnet = []
|
246 |
-
controlnet.append(ControlNetModel.from_pretrained(
|
247 |
-
|
|
|
248 |
print("Warning: ControlNet based inpainting model only support SD1.5 for now.")
|
249 |
controlnet.append(
|
250 |
ControlNetModel.from_pretrained(
|
@@ -271,17 +276,18 @@ def obtain_generation_model(base_model_path, lora_model_path, controlnet_path, g
|
|
271 |
pipe.enable_model_cpu_offload()
|
272 |
return pipe
|
273 |
|
|
|
274 |
def obtain_tile_model(base_model_path, lora_model_path, lora_weight=1.0):
|
275 |
controlnet = ControlNetModel.from_pretrained(
|
276 |
-
|
277 |
-
if base_model_path=='runwayml/stable-diffusion-v1-5' or base_model_path=='stabilityai/stable-diffusion-2-inpainting':
|
278 |
print("base_model_path", base_model_path)
|
279 |
-
pipe =
|
280 |
"runwayml/stable-diffusion-v1-5", controlnet=controlnet, torch_dtype=torch.float16, safety_checker=None
|
281 |
)
|
282 |
else:
|
283 |
-
pipe =
|
284 |
-
|
285 |
)
|
286 |
if lora_model_path is not None:
|
287 |
pipe = load_lora_weights(
|
@@ -296,7 +302,6 @@ def obtain_tile_model(base_model_path, lora_model_path, lora_weight=1.0):
|
|
296 |
return pipe
|
297 |
|
298 |
|
299 |
-
|
300 |
def show_anns(anns):
|
301 |
if len(anns) == 0:
|
302 |
return
|
@@ -331,9 +336,11 @@ class EditAnythingLoraModel:
|
|
331 |
blip_model=None,
|
332 |
sam_generator=None,
|
333 |
controlmodel_name='LAION Pretrained(v0-4)-SD15',
|
334 |
-
|
|
|
335 |
tile_model=None,
|
336 |
lora_weight=1.0,
|
|
|
337 |
):
|
338 |
self.device = device
|
339 |
self.use_blip = use_blip
|
@@ -348,11 +355,8 @@ class EditAnythingLoraModel:
|
|
348 |
base_model_path, lora_model_path, self.default_controlnet_path, generation_only=False, extra_inpaint=extra_inpaint, lora_weight=lora_weight)
|
349 |
|
350 |
# Segment-Anything init.
|
351 |
-
|
352 |
-
|
353 |
-
else:
|
354 |
-
self.sam_generator = init_sam_model()
|
355 |
-
|
356 |
# BLIP2 init.
|
357 |
if use_blip:
|
358 |
if blip_processor is not None:
|
@@ -369,7 +373,8 @@ class EditAnythingLoraModel:
|
|
369 |
if tile_model is not None:
|
370 |
self.tile_pipe = tile_model
|
371 |
else:
|
372 |
-
self.tile_pipe = obtain_tile_model(
|
|
|
373 |
|
374 |
def get_blip2_text(self, image):
|
375 |
inputs = self.blip_processor(image, return_tensors="pt").to(
|
@@ -384,19 +389,92 @@ class EditAnythingLoraModel:
|
|
384 |
full_img, res = show_anns(masks)
|
385 |
return full_img, res
|
386 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
387 |
@torch.inference_mode()
|
388 |
-
def
|
389 |
-
|
390 |
-
|
391 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
392 |
ddim_steps, guess_mode, strength, scale, seed, eta,
|
393 |
-
enable_tile=True, condition_model=None):
|
394 |
|
395 |
if condition_model is None:
|
396 |
this_controlnet_path = self.default_controlnet_path
|
397 |
else:
|
398 |
this_controlnet_path = config_dict[condition_model]
|
399 |
-
input_image = source_image["image"]
|
|
|
400 |
if mask_image is None:
|
401 |
if enable_all_generate != self.defalut_enable_all_generate:
|
402 |
self.pipe = obtain_generation_model(
|
@@ -410,6 +488,8 @@ class EditAnythingLoraModel:
|
|
410 |
(input_image.shape[0], input_image.shape[1], 3))*255
|
411 |
else:
|
412 |
mask_image = source_image["mask"]
|
|
|
|
|
413 |
if self.default_controlnet_path != this_controlnet_path:
|
414 |
print("To Use:", this_controlnet_path,
|
415 |
"Current:", self.default_controlnet_path)
|
@@ -424,10 +504,10 @@ class EditAnythingLoraModel:
|
|
424 |
print("Generating text:")
|
425 |
blip2_prompt = self.get_blip2_text(input_image)
|
426 |
print("Generated text:", blip2_prompt)
|
427 |
-
if len(
|
428 |
-
|
429 |
else:
|
430 |
-
|
431 |
|
432 |
input_image = HWC3(input_image)
|
433 |
|
@@ -448,23 +528,23 @@ class EditAnythingLoraModel:
|
|
448 |
control = torch.stack([control for _ in range(num_samples)], dim=0)
|
449 |
control = einops.rearrange(control, 'b h w c -> b c h w').clone()
|
450 |
|
451 |
-
|
452 |
mask_image_tmp = cv2.resize(
|
453 |
-
|
454 |
mask_image = Image.fromarray(mask_image_tmp)
|
455 |
|
456 |
if seed == -1:
|
457 |
seed = random.randint(0, 65535)
|
458 |
seed_everything(seed)
|
459 |
generator = torch.manual_seed(seed)
|
460 |
-
postive_prompt =
|
461 |
negative_prompt = n_prompt
|
462 |
prompt_embeds, negative_prompt_embeds = get_pipeline_embeds(
|
463 |
self.pipe, postive_prompt, negative_prompt, "cuda")
|
464 |
prompt_embeds = torch.cat([prompt_embeds] * num_samples, dim=0)
|
465 |
negative_prompt_embeds = torch.cat(
|
466 |
[negative_prompt_embeds] * num_samples, dim=0)
|
467 |
-
if enable_all_generate and self.extra_inpaint:
|
468 |
self.pipe.safety_checker = lambda images, clip_input: (
|
469 |
images, False)
|
470 |
x_samples = self.pipe(
|
@@ -485,7 +565,8 @@ class EditAnythingLoraModel:
|
|
485 |
if self.extra_inpaint:
|
486 |
inpaint_image = make_inpaint_condition(img, mask_image_tmp)
|
487 |
print(inpaint_image.shape)
|
488 |
-
multi_condition_image.append(
|
|
|
489 |
multi_condition_scale.append(1.0)
|
490 |
x_samples = self.pipe(
|
491 |
image=img,
|
@@ -501,33 +582,33 @@ class EditAnythingLoraModel:
|
|
501 |
).images
|
502 |
results = [x_samples[i] for i in range(num_samples)]
|
503 |
|
504 |
-
|
505 |
-
|
506 |
-
# for each in img_tile:
|
507 |
-
# print("tile",each.size)
|
508 |
prompt_embeds, negative_prompt_embeds = get_pipeline_embeds(
|
509 |
self.tile_pipe, postive_prompt, negative_prompt, "cuda")
|
510 |
-
|
511 |
-
|
512 |
-
|
513 |
-
|
514 |
-
|
515 |
-
|
516 |
-
|
517 |
-
|
518 |
-
|
519 |
-
|
520 |
-
|
521 |
-
|
522 |
-
|
523 |
-
|
524 |
-
|
525 |
-
|
526 |
-
|
527 |
-
|
528 |
-
|
529 |
-
|
530 |
-
|
|
|
|
|
531 |
|
532 |
def download_image(url):
|
533 |
response = requests.get(url)
|
|
|
14 |
import os
|
15 |
import requests
|
16 |
from io import BytesIO
|
17 |
+
from annotator.util import resize_image, HWC3, resize_points
|
18 |
|
19 |
import torch
|
20 |
from safetensors.torch import load_file
|
|
|
22 |
from diffusers import StableDiffusionControlNetPipeline
|
23 |
from diffusers import ControlNetModel, UniPCMultistepScheduler
|
24 |
from utils.stable_diffusion_controlnet_inpaint import StableDiffusionControlNetInpaintPipeline
|
|
|
25 |
# need the latest transformers
|
26 |
# pip install git+https://github.com/huggingface/transformers.git
|
27 |
from transformers import AutoProcessor, Blip2ForConditionalGeneration
|
|
|
31 |
# Segment-Anything init.
|
32 |
# pip install git+https://github.com/facebookresearch/segment-anything.git
|
33 |
try:
|
34 |
+
from segment_anything import sam_model_registry, SamAutomaticMaskGenerator, SamPredictor
|
35 |
except ImportError:
|
36 |
print('segment_anything not installed')
|
37 |
result = subprocess.run(
|
38 |
['pip', 'install', 'git+https://github.com/facebookresearch/segment-anything.git'], check=True)
|
39 |
print(f'Install segment_anything {result}')
|
40 |
+
from segment_anything import sam_model_registry, SamAutomaticMaskGenerator, SamPredictor
|
41 |
if not os.path.exists('./models/sam_vit_h_4b8939.pth'):
|
42 |
result = subprocess.run(
|
43 |
['wget', 'https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth', '-P', 'models'], check=True)
|
|
|
51 |
])
|
52 |
|
53 |
|
54 |
+
def init_sam_model(sam_generator=None, mask_predictor=None):
|
55 |
+
if sam_generator is not None and mask_predictor is not None:
|
56 |
+
return sam_generator, mask_predictor
|
57 |
sam_checkpoint = "models/sam_vit_h_4b8939.pth"
|
58 |
model_type = "default"
|
59 |
sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
|
60 |
sam.to(device=device)
|
61 |
+
sam_generator = SamAutomaticMaskGenerator(
|
62 |
+
sam) if sam_generator is None else sam_generator
|
63 |
+
mask_predictor = SamPredictor(
|
64 |
+
sam) if mask_predictor is None else mask_predictor
|
65 |
+
return sam_generator, mask_predictor
|
66 |
|
67 |
|
68 |
def init_blip_processor():
|
|
|
116 |
return torch.cat(concat_embeds, dim=1), torch.cat(neg_embeds, dim=1)
|
117 |
|
118 |
|
|
|
119 |
def load_lora_weights(pipeline, checkpoint_path, multiplier, device, dtype):
|
120 |
LORA_PREFIX_UNET = "lora_unet"
|
121 |
LORA_PREFIX_TEXT_ENCODER = "lora_te"
|
|
|
244 |
image = torch.from_numpy(image)
|
245 |
return image
|
246 |
|
247 |
+
|
248 |
def obtain_generation_model(base_model_path, lora_model_path, controlnet_path, generation_only=False, extra_inpaint=True, lora_weight=1.0):
|
249 |
controlnet = []
|
250 |
+
controlnet.append(ControlNetModel.from_pretrained(
|
251 |
+
controlnet_path, torch_dtype=torch.float16)) # sam control
|
252 |
+
if (not generation_only) and extra_inpaint: # inpainting control
|
253 |
print("Warning: ControlNet based inpainting model only support SD1.5 for now.")
|
254 |
controlnet.append(
|
255 |
ControlNetModel.from_pretrained(
|
|
|
276 |
pipe.enable_model_cpu_offload()
|
277 |
return pipe
|
278 |
|
279 |
+
|
280 |
def obtain_tile_model(base_model_path, lora_model_path, lora_weight=1.0):
|
281 |
controlnet = ControlNetModel.from_pretrained(
|
282 |
+
'lllyasviel/control_v11f1e_sd15_tile', torch_dtype=torch.float16) # tile controlnet
|
283 |
+
if base_model_path == 'runwayml/stable-diffusion-v1-5' or base_model_path == 'stabilityai/stable-diffusion-2-inpainting':
|
284 |
print("base_model_path", base_model_path)
|
285 |
+
pipe = StableDiffusionControlNetInpaintPipeline.from_pretrained(
|
286 |
"runwayml/stable-diffusion-v1-5", controlnet=controlnet, torch_dtype=torch.float16, safety_checker=None
|
287 |
)
|
288 |
else:
|
289 |
+
pipe = StableDiffusionControlNetInpaintPipeline.from_pretrained(
|
290 |
+
base_model_path, controlnet=controlnet, torch_dtype=torch.float16, safety_checker=None
|
291 |
)
|
292 |
if lora_model_path is not None:
|
293 |
pipe = load_lora_weights(
|
|
|
302 |
return pipe
|
303 |
|
304 |
|
|
|
305 |
def show_anns(anns):
|
306 |
if len(anns) == 0:
|
307 |
return
|
|
|
336 |
blip_model=None,
|
337 |
sam_generator=None,
|
338 |
controlmodel_name='LAION Pretrained(v0-4)-SD15',
|
339 |
+
# used when the base model is not an inpainting model.
|
340 |
+
extra_inpaint=True,
|
341 |
tile_model=None,
|
342 |
lora_weight=1.0,
|
343 |
+
mask_predictor=None
|
344 |
):
|
345 |
self.device = device
|
346 |
self.use_blip = use_blip
|
|
|
355 |
base_model_path, lora_model_path, self.default_controlnet_path, generation_only=False, extra_inpaint=extra_inpaint, lora_weight=lora_weight)
|
356 |
|
357 |
# Segment-Anything init.
|
358 |
+
self.sam_generator, self.mask_predictor = init_sam_model(
|
359 |
+
sam_generator, mask_predictor)
|
|
|
|
|
|
|
360 |
# BLIP2 init.
|
361 |
if use_blip:
|
362 |
if blip_processor is not None:
|
|
|
373 |
if tile_model is not None:
|
374 |
self.tile_pipe = tile_model
|
375 |
else:
|
376 |
+
self.tile_pipe = obtain_tile_model(
|
377 |
+
base_model_path, lora_model_path, lora_weight=lora_weight)
|
378 |
|
379 |
def get_blip2_text(self, image):
|
380 |
inputs = self.blip_processor(image, return_tensors="pt").to(
|
|
|
389 |
full_img, res = show_anns(masks)
|
390 |
return full_img, res
|
391 |
|
392 |
+
def get_click_mask(self, image, clicked_points):
|
393 |
+
self.mask_predictor.set_image(image)
|
394 |
+
# Separate the points and labels
|
395 |
+
points, labels = zip(*[(point[:2], point[2])
|
396 |
+
for point in clicked_points])
|
397 |
+
|
398 |
+
# Convert the points and labels to numpy arrays
|
399 |
+
input_point = np.array(points)
|
400 |
+
input_label = np.array(labels)
|
401 |
+
|
402 |
+
masks, _, _ = self.mask_predictor.predict(
|
403 |
+
point_coords=input_point,
|
404 |
+
point_labels=input_label,
|
405 |
+
multimask_output=False,
|
406 |
+
)
|
407 |
+
|
408 |
+
return masks
|
409 |
+
|
410 |
@torch.inference_mode()
|
411 |
+
def process_image_click(self, original_image: gr.Image,
|
412 |
+
point_prompt: gr.Radio,
|
413 |
+
clicked_points: gr.State,
|
414 |
+
image_resolution,
|
415 |
+
evt: gr.SelectData):
|
416 |
+
# Get the clicked coordinates
|
417 |
+
clicked_coords = evt.index
|
418 |
+
x, y = clicked_coords
|
419 |
+
label = point_prompt
|
420 |
+
lab = 1 if label == "Foreground Point" else 0
|
421 |
+
clicked_points.append((x, y, lab))
|
422 |
+
|
423 |
+
input_image = np.array(original_image, dtype=np.uint8)
|
424 |
+
H, W, C = input_image.shape
|
425 |
+
input_image = HWC3(input_image)
|
426 |
+
img = resize_image(input_image, image_resolution)
|
427 |
+
|
428 |
+
# Update the clicked_points
|
429 |
+
resized_points = resize_points(clicked_points,
|
430 |
+
input_image.shape,
|
431 |
+
image_resolution)
|
432 |
+
mask_click_np = self.get_click_mask(img, resized_points)
|
433 |
+
|
434 |
+
# Convert mask_click_np to HWC format
|
435 |
+
mask_click_np = np.transpose(mask_click_np, (1, 2, 0)) * 255.0
|
436 |
+
|
437 |
+
mask_image = HWC3(mask_click_np.astype(np.uint8))
|
438 |
+
mask_image = cv2.resize(
|
439 |
+
mask_image, (W, H), interpolation=cv2.INTER_LINEAR)
|
440 |
+
# mask_image = Image.fromarray(mask_image_tmp)
|
441 |
+
|
442 |
+
# Draw circles for all clicked points
|
443 |
+
edited_image = input_image
|
444 |
+
for x, y, lab in clicked_points:
|
445 |
+
# Set the circle color based on the label
|
446 |
+
color = (255, 0, 0) if lab == 1 else (0, 0, 255)
|
447 |
+
|
448 |
+
# Draw the circle
|
449 |
+
edited_image = cv2.circle(edited_image, (x, y), 20, color, -1)
|
450 |
+
|
451 |
+
# Set the opacity for the mask_image and edited_image
|
452 |
+
opacity_mask = 0.75
|
453 |
+
opacity_edited = 1.0
|
454 |
+
|
455 |
+
# Combine the edited_image and the mask_image using cv2.addWeighted()
|
456 |
+
overlay_image = cv2.addWeighted(
|
457 |
+
edited_image, opacity_edited,
|
458 |
+
(mask_image * np.array([0/255, 255/255, 0/255])).astype(np.uint8),
|
459 |
+
opacity_mask, 0
|
460 |
+
)
|
461 |
+
|
462 |
+
return Image.fromarray(overlay_image), clicked_points, Image.fromarray(mask_image)
|
463 |
+
|
464 |
+
@torch.inference_mode()
|
465 |
+
def process(self, source_image, enable_all_generate, mask_image,
|
466 |
+
control_scale,
|
467 |
+
enable_auto_prompt, a_prompt, n_prompt,
|
468 |
+
num_samples, image_resolution, detect_resolution,
|
469 |
ddim_steps, guess_mode, strength, scale, seed, eta,
|
470 |
+
enable_tile=True, refine_alignment_ratio=None, refine_image_resolution=None, condition_model=None):
|
471 |
|
472 |
if condition_model is None:
|
473 |
this_controlnet_path = self.default_controlnet_path
|
474 |
else:
|
475 |
this_controlnet_path = config_dict[condition_model]
|
476 |
+
input_image = source_image["image"] if isinstance(
|
477 |
+
source_image, dict) else np.array(source_image, dtype=np.uint8)
|
478 |
if mask_image is None:
|
479 |
if enable_all_generate != self.defalut_enable_all_generate:
|
480 |
self.pipe = obtain_generation_model(
|
|
|
488 |
(input_image.shape[0], input_image.shape[1], 3))*255
|
489 |
else:
|
490 |
mask_image = source_image["mask"]
|
491 |
+
else:
|
492 |
+
mask_image = np.array(mask_image, dtype=np.uint8)
|
493 |
if self.default_controlnet_path != this_controlnet_path:
|
494 |
print("To Use:", this_controlnet_path,
|
495 |
"Current:", self.default_controlnet_path)
|
|
|
504 |
print("Generating text:")
|
505 |
blip2_prompt = self.get_blip2_text(input_image)
|
506 |
print("Generated text:", blip2_prompt)
|
507 |
+
if len(a_prompt) > 0:
|
508 |
+
a_prompt = blip2_prompt + ',' + a_prompt
|
509 |
else:
|
510 |
+
a_prompt = blip2_prompt
|
511 |
|
512 |
input_image = HWC3(input_image)
|
513 |
|
|
|
528 |
control = torch.stack([control for _ in range(num_samples)], dim=0)
|
529 |
control = einops.rearrange(control, 'b h w c -> b c h w').clone()
|
530 |
|
531 |
+
mask_imag_ori = HWC3(mask_image.astype(np.uint8))
|
532 |
mask_image_tmp = cv2.resize(
|
533 |
+
mask_imag_ori, (W, H), interpolation=cv2.INTER_LINEAR)
|
534 |
mask_image = Image.fromarray(mask_image_tmp)
|
535 |
|
536 |
if seed == -1:
|
537 |
seed = random.randint(0, 65535)
|
538 |
seed_everything(seed)
|
539 |
generator = torch.manual_seed(seed)
|
540 |
+
postive_prompt = a_prompt
|
541 |
negative_prompt = n_prompt
|
542 |
prompt_embeds, negative_prompt_embeds = get_pipeline_embeds(
|
543 |
self.pipe, postive_prompt, negative_prompt, "cuda")
|
544 |
prompt_embeds = torch.cat([prompt_embeds] * num_samples, dim=0)
|
545 |
negative_prompt_embeds = torch.cat(
|
546 |
[negative_prompt_embeds] * num_samples, dim=0)
|
547 |
+
if enable_all_generate and not self.extra_inpaint:
|
548 |
self.pipe.safety_checker = lambda images, clip_input: (
|
549 |
images, False)
|
550 |
x_samples = self.pipe(
|
|
|
565 |
if self.extra_inpaint:
|
566 |
inpaint_image = make_inpaint_condition(img, mask_image_tmp)
|
567 |
print(inpaint_image.shape)
|
568 |
+
multi_condition_image.append(
|
569 |
+
inpaint_image.type(torch.float16))
|
570 |
multi_condition_scale.append(1.0)
|
571 |
x_samples = self.pipe(
|
572 |
image=img,
|
|
|
582 |
).images
|
583 |
results = [x_samples[i] for i in range(num_samples)]
|
584 |
|
585 |
+
results_tile = []
|
586 |
+
if enable_tile:
|
|
|
|
|
587 |
prompt_embeds, negative_prompt_embeds = get_pipeline_embeds(
|
588 |
self.tile_pipe, postive_prompt, negative_prompt, "cuda")
|
589 |
+
for i in range(num_samples):
|
590 |
+
img_tile = PIL.Image.fromarray(resize_image(
|
591 |
+
np.array(x_samples[i]), refine_image_resolution))
|
592 |
+
if i == 0:
|
593 |
+
mask_image_tile = cv2.resize(
|
594 |
+
mask_imag_ori, (img_tile.size[0], img_tile.size[1]), interpolation=cv2.INTER_LINEAR)
|
595 |
+
mask_image_tile = Image.fromarray(mask_image_tile)
|
596 |
+
x_samples_tile = self.tile_pipe(
|
597 |
+
image=img_tile,
|
598 |
+
mask_image=mask_image_tile,
|
599 |
+
prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_prompt_embeds,
|
600 |
+
num_images_per_prompt=1,
|
601 |
+
num_inference_steps=ddim_steps,
|
602 |
+
generator=generator,
|
603 |
+
controlnet_conditioning_image=img_tile,
|
604 |
+
height=img_tile.size[1],
|
605 |
+
width=img_tile.size[0],
|
606 |
+
controlnet_conditioning_scale=1.0,
|
607 |
+
alignment_ratio=refine_alignment_ratio,
|
608 |
+
).images
|
609 |
+
results_tile += x_samples_tile
|
610 |
+
|
611 |
+
return results_tile, results, [full_segmask, mask_image], postive_prompt
|
612 |
|
613 |
def download_image(url):
|
614 |
response = requests.get(url)
|
utils/stable_diffusion_controlnet_inpaint.py
CHANGED
@@ -835,6 +835,7 @@ class StableDiffusionControlNetInpaintPipeline(DiffusionPipeline, LoraLoaderMixi
|
|
835 |
callback_steps: int = 1,
|
836 |
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
837 |
controlnet_conditioning_scale: Union[float, List[float]] = 1.0,
|
|
|
838 |
):
|
839 |
r"""
|
840 |
Function invoked when calling the pipeline for generation.
|
@@ -1115,12 +1116,15 @@ class StableDiffusionControlNetInpaintPipeline(DiffusionPipeline, LoraLoaderMixi
|
|
1115 |
progress_bar.update()
|
1116 |
if callback is not None and i % callback_steps == 0:
|
1117 |
callback(i, t, latents)
|
1118 |
-
# if self.unet.config.in_channels==4:
|
1119 |
-
# # masking for non-inpainting models
|
1120 |
-
# init_latents_proper = self.scheduler.add_noise(init_masked_image_latents, noise, t)
|
1121 |
-
# latents = (init_latents_proper * mask_image) + (latents * (1 - mask_image))
|
1122 |
|
1123 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1124 |
# fill the unmasked part with original image
|
1125 |
latents = (init_masked_image_latents * mask_image) + (latents * (1 - mask_image))
|
1126 |
|
|
|
835 |
callback_steps: int = 1,
|
836 |
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
837 |
controlnet_conditioning_scale: Union[float, List[float]] = 1.0,
|
838 |
+
alignment_ratio = None,
|
839 |
):
|
840 |
r"""
|
841 |
Function invoked when calling the pipeline for generation.
|
|
|
1116 |
progress_bar.update()
|
1117 |
if callback is not None and i % callback_steps == 0:
|
1118 |
callback(i, t, latents)
|
|
|
|
|
|
|
|
|
1119 |
|
1120 |
+
if self.unet.config.in_channels==4 and alignment_ratio is not None:
|
1121 |
+
if i < len(timesteps) * alignment_ratio:
|
1122 |
+
# print(i, len(timesteps))
|
1123 |
+
# masking for non-inpainting models
|
1124 |
+
init_latents_proper = self.scheduler.add_noise(init_masked_image_latents, noise, t)
|
1125 |
+
latents = (init_latents_proper * mask_image) + (latents * (1 - mask_image))
|
1126 |
+
|
1127 |
+
if self.unet.config.in_channels==4 and (alignment_ratio==1.0 or alignment_ratio is None):
|
1128 |
# fill the unmasked part with original image
|
1129 |
latents = (init_masked_image_latents * mask_image) + (latents * (1 - mask_image))
|
1130 |
|