fc-simple / inference /ace_plus_inference.py
ekhatskevich
initial commit
9235b7f
raw
history blame
4.12 kB
# -*- coding: utf-8 -*-
# Copyright (c) Alibaba, Inc. and its affiliates.
import random
from collections import OrderedDict
import torch, numpy as np
from PIL import Image
from scepter.modules.model.registry import MODELS
from scepter.modules.utils.config import Config
from scepter.modules.utils.distribute import we
from .registry import BaseInference, INFERENCES
from .utils import ACEPlusImageProcessor
@INFERENCES.register_class()
class ACEInference(BaseInference):
    """Inference wrapper around an ACE++ pipeline (reuses the LDM code path).

    The pipeline model is built from ``cfg.MODEL`` through the scepter
    ``MODELS`` registry. Inputs are prepared by ``ACEPlusImageProcessor`` and
    the pipeline output is converted back to PIL images.
    """

    def __init__(self, cfg, logger=None):
        """Build the pipeline, image processor and default sampling args.

        Args:
            cfg: scepter config. Reads ``MODEL``, ``MAX_SEQ_LEN``,
                ``SAMPLE_ARGS`` and, optionally, ``DTYPE`` (defaults to
                ``"bfloat16"``).
            logger: optional logger forwarded to ``BaseInference`` and the
                model builder.
        """
        super().__init__(cfg, logger)
        self.pipe = MODELS.build(cfg.MODEL, logger=self.logger).eval().to(we.device_id)
        self.image_processor = ACEPlusImageProcessor(max_seq_len=cfg.MAX_SEQ_LEN)
        # Flatten SAMPLE_ARGS into a lower-cased dict: mapping-valued entries
        # contribute their 'DEFAULT' key (or None), other values are kept as-is.
        self.input = {
            k.lower(): (dict(v).get('DEFAULT', None)
                        if isinstance(v, (dict, OrderedDict, Config)) else v)
            for k, v in cfg.SAMPLE_ARGS.items()
        }
        # DTYPE names a torch dtype attribute, e.g. "bfloat16" -> torch.bfloat16.
        self.dtype = getattr(torch, cfg.get("DTYPE", "bfloat16"))

    @staticmethod
    def _signed_to_pil(tensor):
        """Convert a channel-first (C, H, W) tensor to a PIL image.

        Values are mapped with ``x / 2 + 0.5`` and clamped to [0, 1] before
        scaling to [0, 255], i.e. the input is assumed to be in [-1, 1].
        """
        pixels = torch.clamp(tensor / 2 + 0.5, min=0.0, max=1.0) * 255
        return Image.fromarray(
            pixels.float().permute(1, 2, 0).cpu().numpy().astype(np.uint8))

    @torch.no_grad()
    def __call__(self,
                 reference_image=None,
                 edit_image=None,
                 edit_mask=None,
                 prompt='',
                 edit_type=None,
                 output_height=1024,
                 output_width=1024,
                 sampler='flow_euler',
                 sample_steps=28,
                 guide_scale=50,
                 lora_path=None,
                 seed=-1,
                 repainting_scale=0,
                 use_change=False,
                 keep_pixels=False,
                 keep_pixels_rate=0.8,
                 **kwargs):
        """Run a single edit/generation through the pipeline.

        Args:
            reference_image, edit_image, edit_mask: forwarded to
                ``ACEPlusImageProcessor.preprocess``.
            prompt: a string or a list of strings. A string is wrapped in a list.
            output_height, output_width: requested output size, forwarded to
                preprocessing.
            sampler: sampler name passed to the pipeline.
            sample_steps, guide_scale: forwarded to the pipeline.
            seed: RNG seed. A negative value picks a random seed in
                [0, 2**24 - 1].
            repainting_scale, use_change, keep_pixels, keep_pixels_rate:
                forwarded to preprocessing.
            edit_type, lora_path, **kwargs: accepted for interface
                compatibility. They are not used by this method.

        Returns:
            tuple: ``(output, edit_image, change_image, mask, seed)``, where:
                - ``output`` is the first pipeline result after
                  ``image_processor.postprocess``.
                - ``edit_image`` and ``mask`` are PIL views of the
                  preprocessed inputs.
                - ``change_image`` is a PIL image, or ``None`` when
                  preprocessing produced no change image.
                - ``seed`` is the seed actually used.
        """
        if isinstance(prompt, str):
            prompt = [prompt]
        # The chosen seed is returned so a randomly-seeded run can be reproduced.
        seed = seed if seed >= 0 else random.randint(0, 2 ** 24 - 1)

        (image, mask, change_image, _content_image,
         out_h, out_w, slice_w) = self.image_processor.preprocess(
            reference_image, edit_image, edit_mask,
            height=output_height, width=output_width,
            repainting_scale=repainting_scale,
            keep_pixels=keep_pixels,
            keep_pixels_rate=keep_pixels_rate,
            use_change=use_change)

        change_image = [None] if change_image is None else [change_image.to(we.device_id)]
        image, mask = [image.to(we.device_id)], [mask.to(we.device_id)]
        # Wrap everything as a batch of one sample.
        # NOTE(review): edit_id [[0]] presumably marks the first (only) image
        # as the edit target; confirm against the pipeline.
        src_image_list, src_mask_list, modify_image_list = [image], [mask], [change_image]
        edit_id, prompt = [[0]], [prompt]

        with torch.amp.autocast(enabled=True, dtype=self.dtype, device_type='cuda'):
            out_image = self.pipe(
                src_image_list=src_image_list,
                modify_image_list=modify_image_list,
                src_mask_list=src_mask_list,
                edit_id=edit_id,
                image=image,
                image_mask=mask,
                prompt=prompt,
                # Honour the caller's choice. This was previously hard-coded to
                # 'flow_euler', which silently ignored the argument.
                sampler=sampler,
                sample_steps=sample_steps,
                seed=seed,
                guide_scale=guide_scale,
                show_process=True,
            )

        # Only the first result is returned, so only that one is converted.
        # Clip before scaling so out-of-range values cannot wrap around in uint8.
        recon = out_image[0]['reconstruct_image'].float().permute(1, 2, 0).cpu().numpy()
        out_pil = Image.fromarray((np.clip(recon, 0.0, 1.0) * 255).astype(np.uint8))

        edit_pil = self._signed_to_pil(image[0])
        # change_image[0] is None when preprocessing produced no change image.
        # Converting it would raise TypeError, so return None in that slot.
        change_pil = None if change_image[0] is None else self._signed_to_pil(change_image[0])
        mask_pil = Image.fromarray(
            (mask[0] * 255).squeeze(0).float().cpu().numpy().astype(np.uint8))

        return (self.image_processor.postprocess(out_pil, slice_w, out_w, out_h),
                edit_pil, change_pil, mask_pil, seed)