ekhatskevich committed on
Commit 9235b7f · 1 Parent(s): 08b0954

initial commit

.gitignore ADDED
@@ -0,0 +1,2 @@
+ venv
+ .idea
app.py ADDED
@@ -0,0 +1,76 @@
+ import os
+ import gradio as gr
+
+ # Set the environment variables ACE++ expects before importing it.
+ os.environ["FLUX_FILL_PATH"] = "hf://black-forest-labs/FLUX.1-Fill-dev"
+ os.environ["PORTRAIT_MODEL_PATH"] = "ms://iic/ACE_Plus@portrait/comfyui_portrait_lora64.safetensors"
+ os.environ["SUBJECT_MODEL_PATH"] = "ms://iic/ACE_Plus@subject/comfyui_subject_lora16.safetensors"
+ os.environ["LOCAL_MODEL_PATH"] = "ms://iic/ACE_Plus@local_editing/comfyui_local_lora16.safetensors"
+
+ # Import ACEInference and Config from the ACE_plus repo.
+ from inference.ace_plus_inference import ACEInference
+ from scepter.modules.utils.config import Config
+
+ # Define a minimal configuration dictionary.
+ # Adjust the "MODEL" field as required by your ACE++ setup; scepter
+ # registries resolve classes by the "NAME" key.
+ config_dict = {
+     "NAME": "ACEInference",  # BaseInference reads cfg.NAME.
+     "MODEL": {
+         "NAME": "YourACEModelType",  # Replace with the actual model class name registered in ACE_plus.
+         "pretrained_path": os.getenv("PORTRAIT_MODEL_PATH")
+     },
+     "MAX_SEQ_LEN": 77,  # Latent token budget used by ACEPlusImageProcessor.
+     "SAMPLE_ARGS": {
+         "prompt": "Face swap"
+     },
+     "DTYPE": "bfloat16"
+ }
+ cfg = Config(cfg_dict=config_dict, load=False)
+
+ # Instantiate the ACEInference object.
+ ace_infer = ACEInference(cfg)
+
+ def face_swap_app(target_img, face_img):
+     """
+     Swaps the face in the target image using the provided face image via ACE++.
+
+     Parameters:
+         target_img: The image in which you want to swap a face.
+         face_img: The reference face image to insert.
+
+     Returns:
+         The output image after applying ACE++ face swapping.
+     """
+     # For ACEInference, we pass:
+     # - reference_image: the target image,
+     # - edit_image: the new face image,
+     # - edit_mask: None, so the image processor creates one,
+     # - prompt: "Face swap" instructs the model to perform face swapping.
+     # Other parameters (output dimensions, sampler, etc.) are set here as desired.
+     output_img, edit_image, change_image, mask, seed = ace_infer(
+         reference_image=target_img,
+         edit_image=face_img,
+         edit_mask=None,  # No manual mask provided; let ACE++ handle it.
+         prompt="Face swap",
+         output_height=1024,
+         output_width=1024,
+         sampler='flow_euler',
+         sample_steps=28,
+         guide_scale=50,
+         seed=-1  # Use a random seed if not specified.
+     )
+     return output_img
+
+ # Create the Gradio interface.
+ iface = gr.Interface(
+     fn=face_swap_app,
+     inputs=[
+         gr.Image(type="pil", label="Target Image"),
+         gr.Image(type="pil", label="Face Image")
+     ],
+     outputs=gr.Image(type="pil", label="Swapped Face Output"),
+     title="ACE++ Face Swap Demo",
+     description="Upload a target image and a face image to swap the face using the ACE++ model."
+ )
+
+ if __name__ == "__main__":
+     iface.launch()
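
A quick smoke test of the handler above, outside Gradio (a minimal sketch, not part of the commit; the file names are hypothetical, and it assumes the placeholder MODEL config has been replaced with a real registered ACE++ model):

from PIL import Image

target = Image.open("target.jpg").convert("RGB")
face = Image.open("face.jpg").convert("RGB")
swapped = face_swap_app(target, face)  # first element of the ACEInference output tuple
swapped.save("swapped.png")
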
inference/__init__.py ADDED
@@ -0,0 +1,2 @@
+ from .ace_plus_diffusers import ACEPlusDiffuserInference
+ from .ace_plus_inference import ACEInference
inference/ace_plus_diffusers.py ADDED
@@ -0,0 +1,121 @@
+ # -*- coding: utf-8 -*-
+ # Copyright (c) Alibaba, Inc. and its affiliates.
+ import os
+ import random
+ from collections import OrderedDict
+
+ import torch
+ from diffusers import FluxFillPipeline
+ from scepter.modules.utils.config import Config
+ from scepter.modules.utils.distribute import we
+ from scepter.modules.utils.file_system import FS
+ from scepter.modules.utils.logger import get_logger
+ from transformers import T5TokenizerFast
+ from .utils import ACEPlusImageProcessor
+
+ class ACEPlusDiffuserInference():
+     def __init__(self, logger=None):
+         if logger is None:
+             logger = get_logger(name='ace_plus')
+         self.logger = logger
+         self.input = {}
+
+     def load_default(self, cfg):
+         if cfg is not None:
+             self.input_cfg = {k.lower(): v for k, v in cfg.INPUT.items()}
+             self.input = {k.lower(): dict(v).get('DEFAULT', None) if isinstance(v, (dict, OrderedDict, Config)) else v for k, v in cfg.INPUT.items()}
+             self.output = {k.lower(): v for k, v in cfg.OUTPUT.items()}
+
+     def init_from_cfg(self, cfg):
+         self.max_seq_len = cfg.get("MAX_SEQ_LEN", 4096)
+         self.image_processor = ACEPlusImageProcessor(max_seq_len=self.max_seq_len)
+
+         local_folder = FS.get_dir_to_local_dir(cfg.MODEL.PRETRAINED_MODEL)
+
+         self.pipe = FluxFillPipeline.from_pretrained(local_folder, torch_dtype=torch.bfloat16).to(we.device_id)
+
+         # Replace the tokenizer so that "{image}" is treated as a single special token.
+         tokenizer_2 = T5TokenizerFast.from_pretrained(os.path.join(local_folder, "tokenizer_2"),
+                                                       additional_special_tokens=["{image}"])
+         self.pipe.tokenizer_2 = tokenizer_2
+         self.load_default(cfg.DEFAULT_PARAS)
+
+     def prepare_input(self,
+                       image,
+                       mask,
+                       batch_size=1,
+                       dtype=torch.bfloat16,
+                       num_images_per_prompt=1,
+                       height=512,
+                       width=512,
+                       generator=None):
+         num_channels_latents = self.pipe.vae.config.latent_channels
+         mask, masked_image_latents = self.pipe.prepare_mask_latents(
+             mask.unsqueeze(0),
+             image.unsqueeze(0).to(we.device_id, dtype=dtype),
+             batch_size,
+             num_channels_latents,
+             num_images_per_prompt,
+             height,
+             width,
+             dtype,
+             we.device_id,
+             generator,
+         )
+         masked_image_latents = torch.cat((masked_image_latents, mask), dim=-1)
+         return masked_image_latents
+
+     @torch.no_grad()
+     def __call__(self,
+                  reference_image=None,
+                  edit_image=None,
+                  edit_mask=None,
+                  prompt='',
+                  task=None,
+                  output_height=1024,
+                  output_width=1024,
+                  sampler='flow_euler',
+                  sample_steps=28,
+                  guide_scale=50,
+                  lora_path=None,
+                  seed=-1,
+                  tar_index=0,
+                  align=0,
+                  repainting_scale=0,
+                  **kwargs):
+         if isinstance(prompt, str):
+             prompt = [prompt]
+         seed = seed if seed >= 0 else random.randint(0, 2 ** 32 - 1)
+         # Returns: edit_image, edit_mask, change_image, content_image, out_h, out_w, slice_w
+         image, mask, _, _, out_h, out_w, slice_w = self.image_processor.preprocess(reference_image, edit_image, edit_mask,
+                                                                                    width=output_width,
+                                                                                    height=output_height,
+                                                                                    repainting_scale=repainting_scale)
+         h, w = image.shape[1:]
+         generator = torch.Generator("cpu").manual_seed(seed)
+         masked_image_latents = self.prepare_input(image, mask,
+                                                   batch_size=len(prompt), height=h, width=w, generator=generator)
+
+         if lora_path is not None:
+             with FS.get_from(lora_path) as local_path:
+                 self.pipe.load_lora_weights(local_path)
+
+         image = self.pipe(
+             prompt=prompt,
+             masked_image_latents=masked_image_latents,
+             height=h,
+             width=w,
+             guidance_scale=guide_scale,
+             num_inference_steps=sample_steps,
+             max_sequence_length=512,
+             generator=generator
+         ).images[0]
+         if lora_path is not None:
+             self.pipe.unload_lora_weights()
+         return self.image_processor.postprocess(image, slice_w, out_w, out_h), seed
+
+
+ if __name__ == '__main__':
+     pass
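
For reference, a minimal driver for the diffusers-based path (a sketch, not part of the commit; it assumes a scepter Config carrying the keys init_from_cfg reads: MODEL.PRETRAINED_MODEL, MAX_SEQ_LEN, and a DEFAULT_PARAS block with INPUT/OUTPUT, and the file names are hypothetical):

from PIL import Image
from scepter.modules.utils.config import Config
from inference.ace_plus_diffusers import ACEPlusDiffuserInference

cfg = Config(cfg_dict={
    "MODEL": {"PRETRAINED_MODEL": "hf://black-forest-labs/FLUX.1-Fill-dev"},
    "MAX_SEQ_LEN": 4096,
    "DEFAULT_PARAS": {"INPUT": {}, "OUTPUT": {}},
}, load=False)

infer = ACEPlusDiffuserInference()
infer.init_from_cfg(cfg)
out_image, seed = infer(
    reference_image=Image.open("reference.png").convert("RGB"),
    edit_image=Image.open("edit.png").convert("RGB"),
    edit_mask=Image.open("mask.png").convert("L"),
    prompt="Face swap",
    lora_path="ms://iic/ACE_Plus@portrait/comfyui_portrait_lora64.safetensors")
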
inference/ace_plus_inference.py ADDED
@@ -0,0 +1,83 @@
+ # -*- coding: utf-8 -*-
+ # Copyright (c) Alibaba, Inc. and its affiliates.
+ import random
+ from collections import OrderedDict
+
+ import numpy as np
+ import torch
+ from PIL import Image
+ from scepter.modules.model.registry import MODELS
+ from scepter.modules.utils.config import Config
+ from scepter.modules.utils.distribute import we
+ from .registry import BaseInference, INFERENCES
+ from .utils import ACEPlusImageProcessor
+
+ @INFERENCES.register_class()
+ class ACEInference(BaseInference):
+     '''
+     Reuses the LDM inference code.
+     '''
+     def __init__(self, cfg, logger=None):
+         super().__init__(cfg, logger)
+         self.pipe = MODELS.build(cfg.MODEL, logger=self.logger).eval().to(we.device_id)
+         self.image_processor = ACEPlusImageProcessor(max_seq_len=cfg.MAX_SEQ_LEN)
+         self.input = {k.lower(): dict(v).get('DEFAULT', None) if isinstance(v, (dict, OrderedDict, Config)) else v for
+                       k, v in cfg.SAMPLE_ARGS.items()}
+         self.dtype = getattr(torch, cfg.get("DTYPE", "bfloat16"))
+
+     @torch.no_grad()
+     def __call__(self,
+                  reference_image=None,
+                  edit_image=None,
+                  edit_mask=None,
+                  prompt='',
+                  edit_type=None,
+                  output_height=1024,
+                  output_width=1024,
+                  sampler='flow_euler',
+                  sample_steps=28,
+                  guide_scale=50,
+                  lora_path=None,
+                  seed=-1,
+                  repainting_scale=0,
+                  use_change=False,
+                  keep_pixels=False,
+                  keep_pixels_rate=0.8,
+                  **kwargs):
+         # Convert the inputs into the format expected by the LDM pipeline.
+         if isinstance(prompt, str):
+             prompt = [prompt]
+         seed = seed if seed >= 0 else random.randint(0, 2 ** 24 - 1)
+         image, mask, change_image, content_image, out_h, out_w, slice_w = self.image_processor.preprocess(reference_image, edit_image, edit_mask,
+                                                                                                           height=output_height, width=output_width,
+                                                                                                           repainting_scale=repainting_scale,
+                                                                                                           keep_pixels=keep_pixels,
+                                                                                                           keep_pixels_rate=keep_pixels_rate,
+                                                                                                           use_change=use_change)
+         change_image = [None] if change_image is None else [change_image.to(we.device_id)]
+         image, mask = [image.to(we.device_id)], [mask.to(we.device_id)]
+
+         (src_image_list, src_mask_list, modify_image_list,
+          edit_id, prompt) = [image], [mask], [change_image], [[0]], [prompt]
+
+         with torch.amp.autocast(enabled=True, dtype=self.dtype, device_type='cuda'):
+             out_image = self.pipe(
+                 src_image_list=src_image_list,
+                 modify_image_list=modify_image_list,
+                 src_mask_list=src_mask_list,
+                 edit_id=edit_id,
+                 image=image,
+                 image_mask=mask,
+                 prompt=prompt,
+                 sampler=sampler,
+                 sample_steps=sample_steps,
+                 seed=seed,
+                 guide_scale=guide_scale,
+                 show_process=True,
+             )
+         imgs = [x_i['reconstruct_image'].float().permute(1, 2, 0).cpu().numpy()
+                 for x_i in out_image]
+         imgs = [Image.fromarray((img * 255).astype(np.uint8)) for img in imgs]
+         edit_image = Image.fromarray((torch.clamp(image[0] / 2 + 0.5, min=0.0, max=1.0) * 255).float().permute(1, 2, 0).cpu().numpy().astype(np.uint8))
+         # change_image is [None] unless use_change=True, so guard before converting.
+         change_image = Image.fromarray((torch.clamp(change_image[0] / 2 + 0.5, min=0.0, max=1.0) * 255).float().permute(1, 2, 0).cpu().numpy().astype(np.uint8)) if change_image[0] is not None else None
+         mask = Image.fromarray((mask[0] * 255).squeeze(0).cpu().numpy().astype(np.uint8))
+         return self.image_processor.postprocess(imgs[0], slice_w, out_w, out_h), edit_image, change_image, mask, seed
inference/registry.py ADDED
@@ -0,0 +1,228 @@
+ # -*- coding: utf-8 -*-
+ # Copyright (c) Alibaba, Inc. and its affiliates.
+
+ import torch
+ from PIL.Image import Image
+ from collections import OrderedDict
+ from scepter.modules.utils.distribute import we
+ from scepter.modules.utils.config import Config
+ from scepter.modules.utils.logger import get_logger
+ from scepter.studio.utils.env import get_available_memory
+ from scepter.modules.model.registry import MODELS, BACKBONES, EMBEDDERS
+ from scepter.modules.utils.registry import Registry, build_from_config
+
+ def get_model(model_tuple):
+     assert 'model' in model_tuple
+     return model_tuple['model']
+
+ class BaseInference():
+     '''
+     Supports loading components dynamically: a model is created and
+     loaded the first time it is run.
+     '''
+     def __init__(self, cfg, logger=None):
+         if logger is None:
+             logger = get_logger(name='scepter')
+         self.logger = logger
+         self.name = cfg.NAME
+
+     def init_from_modules(self, modules):
+         for k, v in modules.items():
+             self.__setattr__(k, v)
+
+     def infer_model(self, cfg, module_paras=None):
+         module = {
+             'model': None,
+             'cfg': cfg,
+             'device': 'offline',
+             'name': cfg.NAME,
+             'function_info': {},
+             'paras': {}
+         }
+         if module_paras is None:
+             return module
+         function_info = {}
+         paras = {
+             k.lower(): v
+             for k, v in module_paras.get('PARAS', {}).items()
+         }
+         for function in module_paras.get('FUNCTION', []):
+             input_dict = {}
+             for inp in function.get('INPUT', []):
+                 if inp.lower() in self.input:
+                     input_dict[inp.lower()] = self.input[inp.lower()]
+             function_info[function.NAME] = {
+                 'dtype': function.get('DTYPE', 'float32'),
+                 'input': input_dict
+             }
+         module['paras'] = paras
+         module['function_info'] = function_info
+         return module
+
+     def init_from_ckpt(self, path, model, ignore_keys=list()):
+         if path.endswith('safetensors'):
+             from safetensors.torch import load_file as load_safetensors
+             sd = load_safetensors(path)
+         else:
+             sd = torch.load(path, map_location='cpu', weights_only=True)
+
+         new_sd = OrderedDict()
+         for k, v in sd.items():
+             ignored = False
+             for ik in ignore_keys:
+                 if ik in k:
+                     if we.rank == 0:
+                         self.logger.info(
+                             'Ignore key {} from state_dict.'.format(k))
+                     ignored = True
+                     break
+             if not ignored:
+                 new_sd[k] = v
+
+         missing, unexpected = model.load_state_dict(new_sd, strict=False)
+         if we.rank == 0:
+             self.logger.info(
+                 f'Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys'
+             )
+             if len(missing) > 0:
+                 self.logger.info(f'Missing Keys:\n {missing}')
+             if len(unexpected) > 0:
+                 self.logger.info(f'\nUnexpected Keys:\n {unexpected}')
+
+     def load(self, module):
+         if module['device'] == 'offline':
+             from scepter.modules.utils.import_utils import LazyImportModule
+             if (LazyImportModule.get_module_type(('MODELS', module['cfg'].NAME)) or
+                     module['cfg'].NAME in MODELS.class_map):
+                 model = MODELS.build(module['cfg'], logger=self.logger).eval()
+             elif (LazyImportModule.get_module_type(('BACKBONES', module['cfg'].NAME)) or
+                     module['cfg'].NAME in BACKBONES.class_map):
+                 model = BACKBONES.build(module['cfg'],
+                                         logger=self.logger).eval()
+             elif (LazyImportModule.get_module_type(('EMBEDDERS', module['cfg'].NAME)) or
+                     module['cfg'].NAME in EMBEDDERS.class_map):
+                 model = EMBEDDERS.build(module['cfg'],
+                                         logger=self.logger).eval()
+             else:
+                 raise NotImplementedError
+             if 'DTYPE' in module['cfg'] and module['cfg']['DTYPE'] is not None:
+                 model = model.to(getattr(torch, module['cfg'].DTYPE))
+             if module['cfg'].get('RELOAD_MODEL', None):
+                 self.init_from_ckpt(module['cfg'].RELOAD_MODEL, model)
+             module['model'] = model
+             module['device'] = 'cpu'
+         if module['device'] == 'cpu':
+             module['device'] = we.device_id
+             module['model'] = module['model'].to(we.device_id)
+         return module
+
+     def unload(self, module):
+         if module is None:
+             return module
+         mem = get_available_memory()
+         free_mem = int(mem['available'] / (1024**2))
+         total_mem = int(mem['total'] / (1024**2))
+         if free_mem < 0.5 * total_mem:
+             if module['model'] is not None:
+                 module['model'] = module['model'].to('cpu')
+                 del module['model']
+                 module['model'] = None
+                 module['device'] = 'offline'
+                 self.logger.info('Deleted module to free memory.')
+         else:
+             if module['model'] is not None:
+                 module['model'] = module['model'].to('cpu')
+                 module['device'] = 'cpu'
+             else:
+                 module['device'] = 'offline'
+         if torch.cuda.is_available():
+             torch.cuda.empty_cache()
+             torch.cuda.ipc_collect()
+         return module
+
+     def dynamic_load(self, module=None, name=''):
+         self.logger.info('Loading {} model'.format(name))
+         if name == 'all':
+             for subname in self.loaded_model_name:
+                 self.loaded_model[subname] = self.dynamic_load(
+                     getattr(self, subname), subname)
+         elif name in self.loaded_model_name:
+             if name in self.loaded_model:
+                 if module['cfg'] != self.loaded_model[name]['cfg']:
+                     self.unload(self.loaded_model[name])
+                     module = self.load(module)
+                     self.loaded_model[name] = module
+                     return module
+                 elif module['device'] == 'cpu' or module['device'] == 'offline':
+                     module = self.load(module)
+                     return module
+                 else:
+                     return module
+             else:
+                 module = self.load(module)
+                 self.loaded_model[name] = module
+                 return module
+         else:
+             return self.load(module)
+
+     def dynamic_unload(self, module=None, name='', skip_loaded=False):
+         self.logger.info('Unloading {} model'.format(name))
+         if name == 'all':
+             for name, module in self.loaded_model.items():
+                 module = self.unload(self.loaded_model[name])
+                 self.loaded_model[name] = module
+         elif name in self.loaded_model_name:
+             if name in self.loaded_model:
+                 if not skip_loaded:
+                     module = self.unload(self.loaded_model[name])
+                     self.loaded_model[name] = module
+             else:
+                 self.unload(module)
+         else:
+             self.unload(module)
+
+     def load_default(self, cfg):
+         module_paras = {}
+         if cfg is not None:
+             self.paras = cfg.PARAS
+             self.input_cfg = {k.lower(): v for k, v in cfg.INPUT.items()}
+             self.input = {k.lower(): dict(v).get('DEFAULT', None) if isinstance(v, (dict, OrderedDict, Config)) else v for k, v in cfg.INPUT.items()}
+             self.output = {k.lower(): v for k, v in cfg.OUTPUT.items()}
+             module_paras = cfg.MODULES_PARAS
+         return module_paras
+
+     def load_image(self, image, num_samples=1):
+         # Placeholder: tensor and PIL inputs are currently passed through unchanged.
+         if isinstance(image, (torch.Tensor, Image)):
+             pass
+
+     def get_function_info(self, module, function_name=None):
+         all_function = module['function_info']
+         if function_name in all_function:
+             return function_name, all_function[function_name]['dtype']
+         if function_name is None and len(all_function) == 1:
+             for k, v in all_function.items():
+                 return k, v['dtype']
+
+     @torch.no_grad()
+     def __call__(self,
+                  input,
+                  **kwargs):
+         return
+
+ def build_inference(cfg, registry, logger=None, *args, **kwargs):
+     """ After building the model, load the pretrained weights if the key `pretrain` exists.
+
+     pretrain (str, dict): Describes how to load the pretrained model.
+         str: treat pretrain as the model path;
+         dict: should contain the key `path`, plus other parameters taken by load_pretrained().
+     """
+     if not isinstance(cfg, Config):
+         raise TypeError(f'cfg must be of type Config, got {type(cfg)}')
+     model = build_from_config(cfg, registry, logger=logger, *args, **kwargs)
+     return model
+
+ # Register the inference registry for diffusion.
+ INFERENCES = Registry('INFERENCE', build_func=build_inference)
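
The Registry/BaseInference pair is how ACEInference plugs in above. A minimal sketch of registering and building another inference class through the same mechanism (MyInference is hypothetical, and it assumes scepter's build_from_config resolves classes by the config's NAME key, as load() does above):

from scepter.modules.utils.config import Config
from inference.registry import INFERENCES, BaseInference

@INFERENCES.register_class()
class MyInference(BaseInference):
    def __call__(self, input, **kwargs):
        # Echo the input; a real subclass would run a model here.
        return input

cfg = Config(cfg_dict={"NAME": "MyInference"}, load=False)
runner = INFERENCES.build(cfg)  # dispatches through build_inference
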
inference/utils.py ADDED
@@ -0,0 +1,132 @@
+ # -*- coding: utf-8 -*-
+ # Copyright (c) Alibaba, Inc. and its affiliates.
+ import math
+
+ import torch
+ import torchvision.transforms as T
+ import numpy as np
+ from scepter.modules.annotator.registry import ANNOTATORS
+ from scepter.modules.utils.config import Config
+ from PIL import Image
+
+
+ def edit_preprocess(processor, device, edit_image, edit_mask):
+     if edit_image is None or processor is None:
+         return edit_image
+     processor = Config(cfg_dict=processor, load=False)
+     processor = ANNOTATORS.build(processor).to(device)
+     new_edit_image = processor(np.asarray(edit_image))
+     processor = processor.to("cpu")
+     del processor
+     new_edit_image = Image.fromarray(new_edit_image)
+     return Image.composite(new_edit_image, edit_image, edit_mask)
+
+ class ACEPlusImageProcessor():
+     def __init__(self, max_aspect_ratio=4, d=16, max_seq_len=1024):
+         self.max_aspect_ratio = max_aspect_ratio
+         self.d = d
+         self.max_seq_len = max_seq_len
+         self.transforms = T.Compose([
+             T.ToTensor(),
+             T.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
+         ])
+
+     def image_check(self, image):
+         if image is None:
+             return image
+         # Crop overly elongated images down to the maximum aspect ratio.
+         W, H = image.size
+         if H / W > self.max_aspect_ratio:
+             image = T.CenterCrop([int(self.max_aspect_ratio * W), W])(image)
+         elif W / H > self.max_aspect_ratio:
+             image = T.CenterCrop([H, int(self.max_aspect_ratio * H)])(image)
+         return self.transforms(image)
+
+     def preprocess(self,
+                    reference_image=None,
+                    edit_image=None,
+                    edit_mask=None,
+                    height=1024,
+                    width=1024,
+                    repainting_scale=1.0,
+                    keep_pixels=False,
+                    keep_pixels_rate=0.8,
+                    use_change=False):
+         reference_image = self.image_check(reference_image)
+         edit_image = self.image_check(edit_image)
+         # For reference-only generation, synthesize a blank canvas and a full mask.
+         if edit_image is None:
+             edit_image = torch.zeros([3, height, width])
+             edit_mask = torch.ones([1, height, width])
+         else:
+             if edit_mask is None:
+                 _, eH, eW = edit_image.shape
+                 edit_mask = np.ones((eH, eW))
+             else:
+                 edit_mask = np.asarray(edit_mask)
+                 edit_mask = np.where(edit_mask > 128, 1, 0)
+             edit_mask = edit_mask.astype(
+                 np.float32) if np.any(edit_mask) else np.ones_like(edit_mask).astype(
+                 np.float32)
+             edit_mask = torch.tensor(edit_mask).unsqueeze(0)
+
+         edit_image = edit_image * (1 - edit_mask * repainting_scale)
+
+         out_h, out_w = edit_image.shape[-2:]
+
+         assert edit_mask is not None
+         if reference_image is not None:
+             _, H, W = reference_image.shape
+             _, eH, eW = edit_image.shape
+             if not keep_pixels:
+                 # Align height with edit_image.
+                 scale = eH / H
+                 tH, tW = eH, int(W * scale)
+                 reference_image = T.Resize((tH, tW), interpolation=T.InterpolationMode.BILINEAR, antialias=True)(
+                     reference_image)
+             else:
+                 # Padding.
+                 if H >= keep_pixels_rate * eH:
+                     tH = int(eH * keep_pixels_rate)
+                     scale = tH / H
+                     tW = int(W * scale)
+                     reference_image = T.Resize((tH, tW), interpolation=T.InterpolationMode.BILINEAR, antialias=True)(
+                         reference_image)
+                 rH, rW = reference_image.shape[-2:]
+                 delta_w = 0
+                 delta_h = eH - rH
+                 padding = (delta_w // 2, delta_h // 2, delta_w - (delta_w // 2), delta_h - (delta_h // 2))
+                 reference_image = T.Pad(padding, fill=0, padding_mode="constant")(reference_image)
+             edit_image = torch.cat([reference_image, edit_image], dim=-1)
+             edit_mask = torch.cat([torch.zeros([1, reference_image.shape[1], reference_image.shape[2]]), edit_mask], dim=-1)
+             slice_w = reference_image.shape[-1]
+         else:
+             slice_w = 0
+
+         H, W = edit_image.shape[-2:]
+         scale = min(1.0, math.sqrt(self.max_seq_len * 2 / ((H / self.d) * (W / self.d))))
+         rH = int(H * scale) // self.d * self.d  # ensure divisible by self.d
+         rW = int(W * scale) // self.d * self.d
+         slice_w = int(slice_w * scale) // self.d * self.d
+
+         edit_image = T.Resize((rH, rW), interpolation=T.InterpolationMode.NEAREST_EXACT, antialias=True)(edit_image)
+         edit_mask = T.Resize((rH, rW), interpolation=T.InterpolationMode.NEAREST_EXACT, antialias=True)(edit_mask)
+         content_image = edit_image
+         if use_change:
+             change_image = edit_image * edit_mask
+             edit_image = edit_image * (1 - edit_mask)
+         else:
+             change_image = None
+         return edit_image, edit_mask, change_image, content_image, out_h, out_w, slice_w
+
+     def postprocess(self, image, slice_w, out_w, out_h):
+         w, h = image.size
+         if slice_w > 0:
+             # Crop away the reference strip (plus a small margin) and restore the output size.
+             output_image = image.crop((slice_w + 30, 0, w, h))
+             output_image = output_image.resize((out_w, out_h))
+         else:
+             output_image = image
+         return output_image
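
The processor is usable on its own. A sketch of the round trip (hypothetical file names), showing why ACEInference.__call__ threads slice_w/out_w/out_h from preprocess into postprocess:

from PIL import Image
from inference.utils import ACEPlusImageProcessor

proc = ACEPlusImageProcessor(max_seq_len=1024)
edit_image, edit_mask, change_image, content_image, out_h, out_w, slice_w = proc.preprocess(
    reference_image=Image.open("reference.png").convert("RGB"),
    edit_image=Image.open("edit.png").convert("RGB"),
    edit_mask=Image.open("mask.png").convert("L"))
# edit_image now holds the reference and edit panels side by side; after generation,
# postprocess crops the reference strip back off and restores the original size:
# final = proc.postprocess(generated_pil, slice_w, out_w, out_h)
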
requirements.txt ADDED
@@ -0,0 +1,6 @@
+ diffusers
+ gradio
+ scepter
+ torch
+ torchvision
+ transformers
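
Note: diffusers is listed because inference/__init__.py unconditionally imports the FluxFillPipeline path. With dependencies in place, `pip install -r requirements.txt` followed by `python app.py` should bring up the Gradio demo, assuming the model paths configured in app.py resolve.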