Spaces: Paused

ekhatskevich committed · Commit 9235b7f
1 Parent(s): 08b0954
initial commit

Browse files
- .gitignore +2 -0
- app.py +76 -0
- inference/__init__.py +2 -0
- inference/ace_plus_diffusers.py +121 -0
- inference/ace_plus_inference.py +83 -0
- inference/registry.py +228 -0
- inference/utils.py +132 -0
- requirements.txt +5 -0
.gitignore ADDED
@@ -0,0 +1,2 @@
venv
.idea

app.py ADDED
@@ -0,0 +1,76 @@
import os
import gradio as gr

# Set necessary environment variables for ACE++
os.environ["FLUX_FILL_PATH"] = "hf://black-forest-labs/FLUX.1-Fill-dev"
os.environ["PORTRAIT_MODEL_PATH"] = "ms://iic/ACE_Plus@portrait/comfyui_portrait_lora64.safetensors"
os.environ["SUBJECT_MODEL_PATH"] = "ms://iic/ACE_Plus@subject/comfyui_subject_lora16.safetensors"
os.environ["LOCAL_MODEL_PATH"] = "ms://iic/ACE_Plus@local_editing/comfyui_local_lora16.safetensors"

# Import ACEInference and Config from the ACE_plus repo
from inference.ace_plus_inference import ACEInference
from scepter.modules.utils.config import Config

# Define a minimal configuration dictionary.
# Adjust the "MODEL" field as required by your ACE++ setup.
config_dict = {
    "MODEL": {
        "type": "YourACEModelType",  # Replace with the actual model type string used in ACE_plus.
        "pretrained_path": os.getenv("PORTRAIT_MODEL_PATH")
    },
    "MAX_SEQ_LEN": 77,
    "SAMPLE_ARGS": {
        "prompt": "Face swap"
    },
    "DTYPE": "bfloat16"
}
# Build the scepter Config from the dict (keyword form, as used elsewhere in this repo).
cfg = Config(cfg_dict=config_dict, load=False)

# Instantiate the ACEInference object.
ace_infer = ACEInference(cfg)

def face_swap_app(target_img, face_img):
    """
    Swaps the face in the target image using the provided face image via ACE++.

    Parameters:
        target_img: The image in which you want to swap a face.
        face_img: The reference face image to insert.

    Returns:
        The output image after applying ACE++ face swapping.
    """
    # For ACEInference, we pass:
    # - reference_image: the target image,
    # - edit_image: the new face image,
    # - edit_mask: set to None so the image processor will create it,
    # - prompt: "Face swap" instructs the model to perform face swapping.
    # Other parameters (output dimensions, sampler, etc.) are set here as desired.
    output_img, edit_image, change_image, mask, seed = ace_infer(
        reference_image=target_img,
        edit_image=face_img,
        edit_mask=None,  # No manual mask provided; let ACE++ handle it
        prompt="Face swap",
        output_height=1024,
        output_width=1024,
        sampler='flow_euler',
        sample_steps=28,
        guide_scale=50,
        seed=-1  # Use a random seed if not specified
    )
    return output_img

# Create the Gradio interface.
iface = gr.Interface(
    fn=face_swap_app,
    inputs=[
        gr.Image(type="pil", label="Target Image"),
        gr.Image(type="pil", label="Face Image")
    ],
    outputs=gr.Image(type="pil", label="Swapped Face Output"),
    title="ACE++ Face Swap Demo",
    description="Upload a target image and a face image to swap the face using the ACE++ model."
)

if __name__ == "__main__":
    iface.launch()

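As a quick sanity check of the wiring above, a minimal sketch (not part of the committed files) of calling the demo function directly, without the Gradio UI. It assumes app.py is importable from the working directory and that the configured models can actually be loaded; the image file names are hypothetical.

# Minimal local smoke test of the demo function, bypassing the Gradio UI.
# Importing app instantiates ACEInference, so the configured weights must be reachable.
from PIL import Image
from app import face_swap_app  # hypothetical local import of the file above

target = Image.open("target.jpg").convert("RGB")  # image whose face will be replaced
face = Image.open("face.jpg").convert("RGB")      # reference face to insert

result = face_swap_app(target, face)
result.save("swapped.png")
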
inference/__init__.py ADDED
@@ -0,0 +1,2 @@
from .ace_plus_diffusers import ACEPlusDiffuserInference
from .ace_plus_inference import ACEInference

inference/ace_plus_diffusers.py ADDED
@@ -0,0 +1,121 @@
# -*- coding: utf-8 -*-
# Copyright (c) Alibaba, Inc. and its affiliates.
import random
from collections import OrderedDict

import torch, os
from diffusers import FluxFillPipeline
from scepter.modules.utils.config import Config
from scepter.modules.utils.distribute import we
from scepter.modules.utils.file_system import FS
from scepter.modules.utils.logger import get_logger
from transformers import T5TokenizerFast
from .utils import ACEPlusImageProcessor

class ACEPlusDiffuserInference():
    def __init__(self, logger=None):
        if logger is None:
            logger = get_logger(name='ace_plus')
        self.logger = logger
        self.input = {}

    def load_default(self, cfg):
        if cfg is not None:
            self.input_cfg = {k.lower(): v for k, v in cfg.INPUT.items()}
            self.input = {k.lower(): dict(v).get('DEFAULT', None) if isinstance(v, (dict, OrderedDict, Config)) else v for k, v in cfg.INPUT.items()}
            self.output = {k.lower(): v for k, v in cfg.OUTPUT.items()}

    def init_from_cfg(self, cfg):
        self.max_seq_len = cfg.get("MAX_SEQ_LEN", 4096)
        self.image_processor = ACEPlusImageProcessor(max_seq_len=self.max_seq_len)

        local_folder = FS.get_dir_to_local_dir(cfg.MODEL.PRETRAINED_MODEL)

        self.pipe = FluxFillPipeline.from_pretrained(local_folder, torch_dtype=torch.bfloat16).to(we.device_id)

        tokenizer_2 = T5TokenizerFast.from_pretrained(os.path.join(local_folder, "tokenizer_2"),
                                                      additional_special_tokens=["{image}"])
        self.pipe.tokenizer_2 = tokenizer_2
        self.load_default(cfg.DEFAULT_PARAS)

    def prepare_input(self,
                      image,
                      mask,
                      batch_size=1,
                      dtype=torch.bfloat16,
                      num_images_per_prompt=1,
                      height=512,
                      width=512,
                      generator=None):
        num_channels_latents = self.pipe.vae.config.latent_channels
        mask, masked_image_latents = self.pipe.prepare_mask_latents(
            mask.unsqueeze(0),
            image.unsqueeze(0).to(we.device_id, dtype=dtype),
            batch_size,
            num_channels_latents,
            num_images_per_prompt,
            height,
            width,
            dtype,
            we.device_id,
            generator,
        )
        masked_image_latents = torch.cat((masked_image_latents, mask), dim=-1)
        return masked_image_latents

    @torch.no_grad()
    def __call__(self,
                 reference_image=None,
                 edit_image=None,
                 edit_mask=None,
                 prompt='',
                 task=None,
                 output_height=1024,
                 output_width=1024,
                 sampler='flow_euler',
                 sample_steps=28,
                 guide_scale=50,
                 lora_path=None,
                 seed=-1,
                 tar_index=0,
                 align=0,
                 repainting_scale=0,
                 **kwargs):
        if isinstance(prompt, str):
            prompt = [prompt]
        seed = seed if seed >= 0 else random.randint(0, 2 ** 32 - 1)
        # returns: edit_image, edit_mask, change_image, content_image, out_h, out_w, slice_w
        image, mask, _, _, out_h, out_w, slice_w = self.image_processor.preprocess(reference_image, edit_image, edit_mask,
                                                                                   width=output_width,
                                                                                   height=output_height,
                                                                                   repainting_scale=repainting_scale)
        h, w = image.shape[1:]
        generator = torch.Generator("cpu").manual_seed(seed)
        masked_image_latents = self.prepare_input(image, mask,
                                                  batch_size=len(prompt), height=h, width=w, generator=generator)

        if lora_path is not None:
            with FS.get_from(lora_path) as local_path:
                self.pipe.load_lora_weights(local_path)

        image = self.pipe(
            prompt=prompt,
            masked_image_latents=masked_image_latents,
            height=h,
            width=w,
            guidance_scale=guide_scale,
            num_inference_steps=sample_steps,
            max_sequence_length=512,
            generator=generator
        ).images[0]
        if lora_path is not None:
            self.pipe.unload_lora_weights()
        return self.image_processor.postprocess(image, slice_w, out_w, out_h), seed


if __name__ == '__main__':
    pass

inference/ace_plus_inference.py ADDED
@@ -0,0 +1,83 @@
# -*- coding: utf-8 -*-
# Copyright (c) Alibaba, Inc. and its affiliates.
import random
from collections import OrderedDict

import torch, numpy as np
from PIL import Image
from scepter.modules.model.registry import MODELS
from scepter.modules.utils.config import Config
from scepter.modules.utils.distribute import we
from .registry import BaseInference, INFERENCES
from .utils import ACEPlusImageProcessor

@INFERENCES.register_class()
class ACEInference(BaseInference):
    '''
    Reuses the LDM code.
    '''
    def __init__(self, cfg, logger=None):
        super().__init__(cfg, logger)
        self.pipe = MODELS.build(cfg.MODEL, logger=self.logger).eval().to(we.device_id)
        self.image_processor = ACEPlusImageProcessor(max_seq_len=cfg.MAX_SEQ_LEN)
        self.input = {k.lower(): dict(v).get('DEFAULT', None) if isinstance(v, (dict, OrderedDict, Config)) else v for
                      k, v in cfg.SAMPLE_ARGS.items()}
        self.dtype = getattr(torch, cfg.get("DTYPE", "bfloat16"))

    @torch.no_grad()
    def __call__(self,
                 reference_image=None,
                 edit_image=None,
                 edit_mask=None,
                 prompt='',
                 edit_type=None,
                 output_height=1024,
                 output_width=1024,
                 sampler='flow_euler',
                 sample_steps=28,
                 guide_scale=50,
                 lora_path=None,
                 seed=-1,
                 repainting_scale=0,
                 use_change=False,
                 keep_pixels=False,
                 keep_pixels_rate=0.8,
                 **kwargs):
        # Convert the input info to the input of the LDM.
        if isinstance(prompt, str):
            prompt = [prompt]
        seed = seed if seed >= 0 else random.randint(0, 2 ** 24 - 1)
        image, mask, change_image, content_image, out_h, out_w, slice_w = self.image_processor.preprocess(reference_image, edit_image, edit_mask,
                                                                                                          height=output_height, width=output_width,
                                                                                                          repainting_scale=repainting_scale,
                                                                                                          keep_pixels=keep_pixels,
                                                                                                          keep_pixels_rate=keep_pixels_rate,
                                                                                                          use_change=use_change)
        change_image = [None] if change_image is None else [change_image.to(we.device_id)]
        image, mask = [image.to(we.device_id)], [mask.to(we.device_id)]

        (src_image_list, src_mask_list, modify_image_list,
         edit_id, prompt) = [image], [mask], [change_image], [[0]], [prompt]

        with torch.amp.autocast(enabled=True, dtype=self.dtype, device_type='cuda'):
            out_image = self.pipe(
                src_image_list=src_image_list,
                modify_image_list=modify_image_list,
                src_mask_list=src_mask_list,
                edit_id=edit_id,
                image=image,
                image_mask=mask,
                prompt=prompt,
                sampler='flow_euler',
                sample_steps=sample_steps,
                seed=seed,
                guide_scale=guide_scale,
                show_process=True,
            )
        imgs = [x_i['reconstruct_image'].float().permute(1, 2, 0).cpu().numpy()
                for x_i in out_image]
        imgs = [Image.fromarray((img * 255).astype(np.uint8)) for img in imgs]
        edit_image = Image.fromarray((torch.clamp(image[0] / 2 + 0.5, min=0.0, max=1.0) * 255).float().permute(1, 2, 0).cpu().numpy().astype(np.uint8))
        # change_image is [None] when use_change=False; only convert it to a PIL image when present.
        change_image = None if change_image[0] is None else Image.fromarray(
            (torch.clamp(change_image[0] / 2 + 0.5, min=0.0, max=1.0) * 255).float().permute(1, 2, 0).cpu().numpy().astype(np.uint8))
        mask = Image.fromarray((mask[0] * 255).squeeze(0).cpu().numpy().astype(np.uint8))
        return self.image_processor.postprocess(imgs[0], slice_w, out_w, out_h), edit_image, change_image, mask, seed

inference/registry.py ADDED
@@ -0,0 +1,228 @@
# -*- coding: utf-8 -*-
# Copyright (c) Alibaba, Inc. and its affiliates.

import torch
from PIL.Image import Image
from collections import OrderedDict
from scepter.modules.utils.distribute import we
from scepter.modules.utils.config import Config
from scepter.modules.utils.logger import get_logger
from scepter.studio.utils.env import get_available_memory
from scepter.modules.model.registry import MODELS, BACKBONES, EMBEDDERS
from scepter.modules.utils.registry import Registry, build_from_config

def get_model(model_tuple):
    assert 'model' in model_tuple
    return model_tuple['model']

class BaseInference():
    '''
    Supports loading the components dynamically.
    The model is created and loaded the first time it is run.
    '''
    def __init__(self, cfg, logger=None):
        if logger is None:
            logger = get_logger(name='scepter')
        self.logger = logger
        self.name = cfg.NAME

    def init_from_modules(self, modules):
        for k, v in modules.items():
            self.__setattr__(k, v)

    def infer_model(self, cfg, module_paras=None):
        module = {
            'model': None,
            'cfg': cfg,
            'device': 'offline',
            'name': cfg.NAME,
            'function_info': {},
            'paras': {}
        }
        if module_paras is None:
            return module
        function_info = {}
        paras = {
            k.lower(): v
            for k, v in module_paras.get('PARAS', {}).items()
        }
        for function in module_paras.get('FUNCTION', []):
            input_dict = {}
            for inp in function.get('INPUT', []):
                if inp.lower() in self.input:
                    input_dict[inp.lower()] = self.input[inp.lower()]
            function_info[function.NAME] = {
                'dtype': function.get('DTYPE', 'float32'),
                'input': input_dict
            }
        module['paras'] = paras
        module['function_info'] = function_info
        return module

    def init_from_ckpt(self, path, model, ignore_keys=list()):
        if path.endswith('safetensors'):
            from safetensors.torch import load_file as load_safetensors
            sd = load_safetensors(path)
        else:
            sd = torch.load(path, map_location='cpu', weights_only=True)

        new_sd = OrderedDict()
        for k, v in sd.items():
            ignored = False
            for ik in ignore_keys:
                if ik in k:
                    if we.rank == 0:
                        self.logger.info(
                            'Ignore key {} from state_dict.'.format(k))
                    ignored = True
                    break
            if not ignored:
                new_sd[k] = v

        missing, unexpected = model.load_state_dict(new_sd, strict=False)
        if we.rank == 0:
            self.logger.info(
                f'Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys'
            )
            if len(missing) > 0:
                self.logger.info(f'Missing Keys:\n {missing}')
            if len(unexpected) > 0:
                self.logger.info(f'\nUnexpected Keys:\n {unexpected}')

    def load(self, module):
        if module['device'] == 'offline':
            from scepter.modules.utils.import_utils import LazyImportModule
            if (LazyImportModule.get_module_type(('MODELS', module['cfg'].NAME)) or
                    module['cfg'].NAME in MODELS.class_map):
                model = MODELS.build(module['cfg'], logger=self.logger).eval()
            elif (LazyImportModule.get_module_type(('BACKBONES', module['cfg'].NAME)) or
                    module['cfg'].NAME in BACKBONES.class_map):
                model = BACKBONES.build(module['cfg'],
                                        logger=self.logger).eval()
            elif (LazyImportModule.get_module_type(('EMBEDDERS', module['cfg'].NAME)) or
                    module['cfg'].NAME in EMBEDDERS.class_map):
                model = EMBEDDERS.build(module['cfg'],
                                        logger=self.logger).eval()
            else:
                raise NotImplementedError
            if 'DTYPE' in module['cfg'] and module['cfg']['DTYPE'] is not None:
                model = model.to(getattr(torch, module['cfg'].DTYPE))
            if module['cfg'].get('RELOAD_MODEL', None):
                self.init_from_ckpt(module['cfg'].RELOAD_MODEL, model)
            module['model'] = model
            module['device'] = 'cpu'
        if module['device'] == 'cpu':
            module['device'] = we.device_id
            module['model'] = module['model'].to(we.device_id)
        return module

    def unload(self, module):
        if module is None:
            return module
        mem = get_available_memory()
        free_mem = int(mem['available'] / (1024**2))
        total_mem = int(mem['total'] / (1024**2))
        if free_mem < 0.5 * total_mem:
            if module['model'] is not None:
                module['model'] = module['model'].to('cpu')
                del module['model']
                module['model'] = None
                module['device'] = 'offline'
                print('delete module')
        else:
            if module['model'] is not None:
                module['model'] = module['model'].to('cpu')
                module['device'] = 'cpu'
            else:
                module['device'] = 'offline'
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            torch.cuda.ipc_collect()
        return module

    def dynamic_load(self, module=None, name=''):
        self.logger.info('Loading {} model'.format(name))
        if name == 'all':
            for subname in self.loaded_model_name:
                self.loaded_model[subname] = self.dynamic_load(
                    getattr(self, subname), subname)
        elif name in self.loaded_model_name:
            if name in self.loaded_model:
                if module['cfg'] != self.loaded_model[name]['cfg']:
                    self.unload(self.loaded_model[name])
                    module = self.load(module)
                    self.loaded_model[name] = module
                    return module
                elif module['device'] == 'cpu' or module['device'] == 'offline':
                    module = self.load(module)
                    return module
                else:
                    return module
            else:
                module = self.load(module)
                self.loaded_model[name] = module
                return module
        else:
            return self.load(module)

    def dynamic_unload(self, module=None, name='', skip_loaded=False):
        self.logger.info('Unloading {} model'.format(name))
        if name == 'all':
            for name, module in self.loaded_model.items():
                module = self.unload(self.loaded_model[name])
                self.loaded_model[name] = module
        elif name in self.loaded_model_name:
            if name in self.loaded_model:
                if not skip_loaded:
                    module = self.unload(self.loaded_model[name])
                    self.loaded_model[name] = module
            else:
                self.unload(module)
        else:
            self.unload(module)

    def load_default(self, cfg):
        module_paras = {}
        if cfg is not None:
            self.paras = cfg.PARAS
            self.input_cfg = {k.lower(): v for k, v in cfg.INPUT.items()}
            self.input = {k.lower(): dict(v).get('DEFAULT', None) if isinstance(v, (dict, OrderedDict, Config)) else v for k, v in cfg.INPUT.items()}
            self.output = {k.lower(): v for k, v in cfg.OUTPUT.items()}
            module_paras = cfg.MODULES_PARAS
        return module_paras

    def load_image(self, image, num_samples=1):
        if isinstance(image, torch.Tensor):
            pass
        elif isinstance(image, Image):
            pass

    def get_function_info(self, module, function_name=None):
        all_function = module['function_info']
        if function_name in all_function:
            return function_name, all_function[function_name]['dtype']
        if function_name is None and len(all_function) == 1:
            for k, v in all_function.items():
                return k, v['dtype']

    @torch.no_grad()
    def __call__(self,
                 input,
                 **kwargs):
        return

def build_inference(cfg, registry, logger=None, *args, **kwargs):
    """ After building the model, load pretrained weights if the key `pretrain` exists.

    pretrain (str, dict): Describes how to load the pretrained model.
        str: treat pretrain as a model path;
        dict: should contain the key `path`, plus other parameters taken by load_pretrained().
    """
    if not isinstance(cfg, Config):
        raise TypeError(f'cfg must be a Config instance, got {type(cfg)}')
    model = build_from_config(cfg, registry, logger=logger, *args, **kwargs)
    return model

# Register classes for inference.
INFERENCES = Registry('INFERENCE', build_func=build_inference)

inference/utils.py ADDED
@@ -0,0 +1,132 @@
# -*- coding: utf-8 -*-
# Copyright (c) Alibaba, Inc. and its affiliates.
import math

import torch
import torchvision.transforms as T
import numpy as np
from scepter.modules.annotator.registry import ANNOTATORS
from scepter.modules.utils.config import Config
from PIL import Image


def edit_preprocess(processor, device, edit_image, edit_mask):
    if edit_image is None or processor is None:
        return edit_image
    processor = Config(cfg_dict=processor, load=False)
    processor = ANNOTATORS.build(processor).to(device)
    new_edit_image = processor(np.asarray(edit_image))
    processor = processor.to("cpu")
    del processor
    new_edit_image = Image.fromarray(new_edit_image)
    return Image.composite(new_edit_image, edit_image, edit_mask)

class ACEPlusImageProcessor():
    def __init__(self, max_aspect_ratio=4, d=16, max_seq_len=1024):
        self.max_aspect_ratio = max_aspect_ratio
        self.d = d
        self.max_seq_len = max_seq_len
        self.transforms = T.Compose([
            T.ToTensor(),
            T.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
        ])

    def image_check(self, image):
        if image is None:
            return image
        # preprocess
        W, H = image.size
        if H / W > self.max_aspect_ratio:
            image = T.CenterCrop([int(self.max_aspect_ratio * W), W])(image)
        elif W / H > self.max_aspect_ratio:
            image = T.CenterCrop([H, int(self.max_aspect_ratio * H)])(image)
        return self.transforms(image)

    def preprocess(self,
                   reference_image=None,
                   edit_image=None,
                   edit_mask=None,
                   height=1024,
                   width=1024,
                   repainting_scale=1.0,
                   keep_pixels=False,
                   keep_pixels_rate=0.8,
                   use_change=False):
        reference_image = self.image_check(reference_image)
        edit_image = self.image_check(edit_image)
        # for reference generation
        if edit_image is None:
            edit_image = torch.zeros([3, height, width])
            edit_mask = torch.ones([1, height, width])
        else:
            if edit_mask is None:
                _, eH, eW = edit_image.shape
                edit_mask = np.ones((eH, eW))
            else:
                edit_mask = np.asarray(edit_mask)
                edit_mask = np.where(edit_mask > 128, 1, 0)
            edit_mask = edit_mask.astype(
                np.float32) if np.any(edit_mask) else np.ones_like(edit_mask).astype(
                np.float32)
            edit_mask = torch.tensor(edit_mask).unsqueeze(0)

        edit_image = edit_image * (1 - edit_mask * repainting_scale)

        out_h, out_w = edit_image.shape[-2:]

        assert edit_mask is not None
        if reference_image is not None:
            _, H, W = reference_image.shape
            _, eH, eW = edit_image.shape
            if not keep_pixels:
                # align height with edit_image
                scale = eH / H
                tH, tW = eH, int(W * scale)
                reference_image = T.Resize((tH, tW), interpolation=T.InterpolationMode.BILINEAR, antialias=True)(
                    reference_image)
            else:
                # padding
                if H >= keep_pixels_rate * eH:
                    tH = int(eH * keep_pixels_rate)
                    scale = tH / H
                    tW = int(W * scale)
                    reference_image = T.Resize((tH, tW), interpolation=T.InterpolationMode.BILINEAR, antialias=True)(
                        reference_image)
                rH, rW = reference_image.shape[-2:]
                delta_w = 0
                delta_h = eH - rH
                padding = (delta_w // 2, delta_h // 2, delta_w - (delta_w // 2), delta_h - (delta_h // 2))
                reference_image = T.Pad(padding, fill=0, padding_mode="constant")(reference_image)
            edit_image = torch.cat([reference_image, edit_image], dim=-1)
            edit_mask = torch.cat([torch.zeros([1, reference_image.shape[1], reference_image.shape[2]]), edit_mask], dim=-1)
            slice_w = reference_image.shape[-1]
        else:
            slice_w = 0

        H, W = edit_image.shape[-2:]
        scale = min(1.0, math.sqrt(self.max_seq_len * 2 / ((H / self.d) * (W / self.d))))
        rH = int(H * scale) // self.d * self.d  # ensure divisible by self.d
        rW = int(W * scale) // self.d * self.d
        slice_w = int(slice_w * scale) // self.d * self.d

        edit_image = T.Resize((rH, rW), interpolation=T.InterpolationMode.NEAREST_EXACT, antialias=True)(edit_image)
        edit_mask = T.Resize((rH, rW), interpolation=T.InterpolationMode.NEAREST_EXACT, antialias=True)(edit_mask)
        content_image = edit_image
        if use_change:
            change_image = edit_image * edit_mask
            edit_image = edit_image * (1 - edit_mask)
        else:
            change_image = None
        return edit_image, edit_mask, change_image, content_image, out_h, out_w, slice_w

    def postprocess(self, image, slice_w, out_w, out_h):
        w, h = image.size
        if slice_w > 0:
            output_image = image.crop((slice_w + 30, 0, w, h))
            output_image = output_image.resize((out_w, out_h))
        else:
            output_image = image
        return output_image

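The resizing arithmetic in ACEPlusImageProcessor.preprocess caps the number of d x d latent patches at roughly 2 * max_seq_len. A small worked sketch (not part of the committed files) of the same formula, using illustrative sizes:

import math

# Illustrative numbers: a 1024x1024 edit image with a 1024-wide reference
# concatenated on the left gives H=1024, W=2048 before rescaling.
H, W, d, max_seq_len = 1024, 2048, 16, 1024

# Same formula as in preprocess: limit (H/d)*(W/d) patches to about 2 * max_seq_len.
scale = min(1.0, math.sqrt(max_seq_len * 2 / ((H / d) * (W / d))))
rH = int(H * scale) // d * d  # rounded down to a multiple of d
rW = int(W * scale) // d * d
print(scale, rH, rW)  # scale = 0.5, so the canvas is resized to 512 x 1024

With the processor's default max_seq_len of 1024, a concatenated reference-plus-edit canvas of that size is therefore halved before being fed to the model.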
requirements.txt ADDED
@@ -0,0 +1,5 @@
gradio
scepter
torch
torchvision
transformers