# run_magicfu.py
# Copyright 2024 Adobe. All rights reserved.
#%%
import argparse
import glob
import os
import pathlib
from itertools import islice

import cv2
import numpy as np
import torch
import torchvision
from omegaconf import OmegaConf
from PIL import Image
from torch import autocast
from torchvision.transforms import Resize

from ldm.util import instantiate_from_config
from ldm.models.diffusion.ddim import DDIMSampler
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
def fix_img(test_img):
    # Crop a non-square image to a square (top-left corner, side = height).
    # Note: this assumes width >= height; otherwise the crop box exceeds the image.
    width, height = test_img.size
    if width != height:
        left = 0
        top = 0
        right = height
        bottom = height
        return test_img.crop((left, top, right, bottom))
    else:
        return test_img
# util funcs
def chunk(it, size):
    it = iter(it)
    return iter(lambda: tuple(islice(it, size)), ())
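# Splits an iterable into fixed-size tuples (last one may be short), e.g.:
#   list(chunk(range(5), 2)) -> [(0, 1), (2, 3), (4,)]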
def get_tensor_clip(normalize=True, toTensor=True):
    transform_list = []
    if toTensor:
        transform_list += [torchvision.transforms.ToTensor()]
    if normalize:
        transform_list += [torchvision.transforms.Normalize((0.48145466, 0.4578275, 0.40821073),
                                                            (0.26862954, 0.26130258, 0.27577711))]
    return torchvision.transforms.Compose(transform_list)
def get_tensor_dino(normalize=True, toTensor=True):
    transform_list = [torchvision.transforms.Resize((224, 224))]
    if toTensor:
        transform_list += [torchvision.transforms.ToTensor()]
    if normalize:
        transform_list += [lambda x: 255.0 * x[:3],
                           torchvision.transforms.Normalize(
                               mean=(123.675, 116.28, 103.53),
                               std=(58.395, 57.12, 57.375),
                           )]
    return torchvision.transforms.Compose(transform_list)
def get_tensor(normalize=True, toTensor=True):
    transform_list = []
    if toTensor:
        transform_list += [torchvision.transforms.ToTensor()]
    if normalize:
        transform_list += [torchvision.transforms.Normalize((0.5, 0.5, 0.5),
                                                            (0.5, 0.5, 0.5))]
    transform_list += [
        torchvision.transforms.Resize(512),
        torchvision.transforms.CenterCrop(512)
    ]
    return torchvision.transforms.Compose(transform_list)
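# Note: the normalization constants above match the standard preprocessing of the
# encoders they feed -- get_tensor_clip uses CLIP's image mean/std, get_tensor_dino
# uses ImageNet mean/std scaled by 255 (DINO-style preprocessing), and get_tensor
# maps images to [-1, 1] as the Stable Diffusion autoencoder expects.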
def numpy_to_pil(images):
    """
    Convert a numpy image or a batch of images to a PIL image.
    """
    if images.ndim == 3:
        images = images[None, ...]
    images = (images * 255).round().astype("uint8")
    pil_images = [Image.fromarray(image) for image in images]
    return pil_images
def load_model_from_config(config, ckpt, verbose=False):
    model = instantiate_from_config(config.model)
    if ckpt is not None:
        print(f"Loading model from {ckpt}")
        pl_sd = torch.load(ckpt, map_location="cpu")
        if "global_step" in pl_sd:
            print(f"Global Step: {pl_sd['global_step']}")
        sd = pl_sd["state_dict"]
        m, u = model.load_state_dict(sd, strict=False)
        if len(m) > 0 and verbose:
            print("missing keys:")
            print(m)
        if len(u) > 0 and verbose:
            print("unexpected keys:")
            print(u)
    model.cuda()
    model.eval()
    return model
def get_model(config_path, ckpt_path):
    config = OmegaConf.load(f"{config_path}")
    # Build the model without weights, then load the full checkpoint strictly.
    model = load_model_from_config(config, None)
    pl_sd = torch.load(ckpt_path, map_location="cpu")
    m, u = model.load_state_dict(pl_sd, strict=True)
    if len(m) > 0:
        print("WARNING: missing keys:")
        print(m)
    if len(u) > 0:
        print("WARNING: unexpected keys:")
        print(u)
    model = model.to(device)
    return model
def get_grid(size):
    # Build a (size, size, 2) coordinate grid where out[i, j] = (i, j).
    y = np.repeat(np.arange(size)[None, ...], size)
    y = y.reshape(size, size)
    x = y.transpose()
    out = np.stack([y, x], -1)
    return out
def un_norm(x):
    # Map from [-1, 1] back to [0, 1].
    return (x + 1.0) / 2.0
class MagicFixup:
    def __init__(self, model_path='/sensei-fs/users/halzayer/collage2photo/Paint-by-Example/official_checkpoint_image_attn_200k.pt'):
        self.model = get_model('configs/collage_mix_train.yaml', model_path)

    def edit_image(self, ref_image, coarse_edit, mask_tensor, start_step, steps):
        # essentially sample
        sampler = DDIMSampler(self.model)
        start_code = None

        transformed_grid = torch.zeros((2, 64, 64))
        self.model.model.og_grid = None
        self.model.model.transformed_grid = transformed_grid.unsqueeze(0).to(self.model.device)

        scale = 1.0  # unconditional guidance scale (1.0 disables classifier-free guidance)
        C, f, H, W = 4, 8, 512, 512  # latent channels, downsampling factor, image size
        n_samples = 1
        ddim_steps = steps
        ddim_eta = 1.0
        step = start_step

        with torch.no_grad():
            with autocast("cuda"):
                with self.model.ema_scope():
                    image_tensor = get_tensor(toTensor=False)(coarse_edit)
                    clean_ref_tensor = get_tensor(toTensor=False)(ref_image)
                    clean_ref_tensor = clean_ref_tensor.unsqueeze(0)
                    ref_tensor = get_tensor_dino(toTensor=False)(ref_image).unsqueeze(0)
                    b_mask = mask_tensor.cpu() < 0.5

                    # Fill the masked regions of the coarse edit with OpenCV's
                    # Navier-Stokes inpainting before encoding.
                    reference = un_norm(image_tensor)
                    reference = reference.squeeze()
                    ref_cv = torch.moveaxis(reference, 0, -1).cpu().numpy()
                    ref_cv = (ref_cv * 255).astype(np.uint8)

                    cv_mask = b_mask.int().squeeze().cpu().numpy().astype(np.uint8)
                    kernel = np.ones((7, 7))
                    dilated_mask = cv2.dilate(cv_mask, kernel)

                    dst = cv2.inpaint(ref_cv, dilated_mask, 3, cv2.INPAINT_NS)
                    dst_tensor = torch.tensor(dst).moveaxis(-1, 0) / 255.0
                    image_tensor = (dst_tensor * 2.0) - 1.0
                    image_tensor = image_tensor.unsqueeze(0)

                    inpaint_image = image_tensor  # (* mask_tensor multiplication left disabled)

                    test_model_kwargs = {}
                    test_model_kwargs['inpaint_mask'] = mask_tensor.to(device)
                    test_model_kwargs['inpaint_image'] = inpaint_image.to(device)
                    clean_ref_tensor = clean_ref_tensor.to(device)
                    ref_tensor = ref_tensor.to(device)

                    uc = None
                    if scale != 1.0:
                        uc = self.model.learnable_vector
                    c = self.model.get_learned_conditioning(ref_tensor.to(torch.float16))
                    c = self.model.proj_out(c)

                    # Encode the inpainted edit and the clean reference into latent space.
                    z_inpaint = self.model.encode_first_stage(test_model_kwargs['inpaint_image'])
                    z_inpaint = self.model.get_first_stage_encoding(z_inpaint).detach()
                    z_ref = self.model.encode_first_stage(clean_ref_tensor)
                    z_ref = self.model.get_first_stage_encoding(z_ref).detach()

                    test_model_kwargs['inpaint_image'] = z_inpaint
                    test_model_kwargs['inpaint_mask'] = Resize([z_inpaint.shape[-2], z_inpaint.shape[-1]])(test_model_kwargs['inpaint_mask'])

                    shape = [C, H // f, W // f]
                    samples_ddim, _ = sampler.sample(S=ddim_steps,
                                                     conditioning=c,
                                                     z_ref=z_ref,
                                                     batch_size=n_samples,
                                                     shape=shape,
                                                     verbose=False,
                                                     unconditional_guidance_scale=scale,
                                                     unconditional_conditioning=uc,
                                                     eta=ddim_eta,
                                                     x_T=start_code,
                                                     test_model_kwargs=test_model_kwargs,
                                                     x0=z_inpaint,
                                                     x0_step=step,
                                                     ddim_discretize='uniform',
                                                     drop_latent_guidance=1.0)

                    # Decode latents and clamp to [0, 1] for saving.
                    x_samples_ddim = self.model.decode_first_stage(samples_ddim)
                    x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
                    x_samples_ddim = x_samples_ddim.cpu().permute(0, 2, 3, 1).numpy()

                    x_checked_image_torch = torch.from_numpy(x_samples_ddim).permute(0, 3, 1, 2)
                    return x_checked_image_torch
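# Illustrative programmatic use (mirrors main() below; the checkpoint path is
# hypothetical -- substitute your own):
#   fixer = MagicFixup(model_path='checkpoints/magic_fixup.pt')
#   out = fixer.edit_image(ref_image_t, coarse_edit_rgb_t, mask_t, start_step=1.0, steps=50)
#   torchvision.utils.save_image(out.squeeze().float(), 'sample.png')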
#%%
def file_exists(path):
    """Check if a file exists and is not a directory."""
    if not os.path.isfile(path):
        raise argparse.ArgumentTypeError(f"{path} is not a valid file.")
    return path
def parse_arguments():
    """Parse command-line arguments."""
    parser = argparse.ArgumentParser(description="Process images based on provided paths.")
    parser.add_argument("--checkpoint", type=file_exists, required=True, help="Path to the MagicFixup checkpoint file.")
    parser.add_argument("--reference", type=file_exists, default='examples/fox_drinking_og.png', help="Path to the reference (original) image.")
    parser.add_argument("--edit", type=file_exists, default='examples/fox_drinking__edit__01.png', help="Path to the edited image. Make sure its alpha channel is set properly.")
    parser.add_argument("--output-dir", type=str, default='./outputs', help="Folder where the outputs are saved.")
    parser.add_argument("--samples", type=int, default=5, help="Number of samples to generate.")
    return parser.parse_args()
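# Example invocation (the checkpoint path is illustrative):
#   python run_magicfu.py --checkpoint checkpoints/magic_fixup.pt \
#       --reference examples/fox_drinking_og.png \
#       --edit examples/fox_drinking__edit__01.png \
#       --output-dir ./outputs --samples 5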
def main():
    args = parse_arguments()

    # Create the MagicFixup model.
    magic_fixup = MagicFixup(model_path=args.checkpoint)

    output_dir = args.output_dir
    os.makedirs(output_dir, exist_ok=True)

    to_tensor = torchvision.transforms.ToTensor()
    ref_path = args.reference
    coarse_edit_path = args.edit
    # The mask is read from the alpha channel of the edit itself.
    mask_edit_path = coarse_edit_path

    # Continue numbering from any samples already saved for this edit.
    edit_file_name = pathlib.Path(coarse_edit_path).stem
    save_pattern = f'{output_dir}/{edit_file_name}__sample__*.png'
    save_counter = len(glob.glob(save_pattern))

    all_rgbs = []
    for i in range(args.samples):
        with autocast("cuda"):
            ref_image_t = to_tensor(Image.open(ref_path).convert('RGB').resize((512, 512))).half().cuda()
            coarse_edit_t = to_tensor(Image.open(coarse_edit_path).resize((512, 512))).half().cuda()
            coarse_edit_mask_t = to_tensor(Image.open(mask_edit_path).resize((512, 512))).half().cuda()
            # Get the mask from the alpha channel of the coarse edit.
            mask_t = (coarse_edit_mask_t[-1][None, None, ...]).half()
            coarse_edit_t_rgb = coarse_edit_t[:-1]

            out_rgb = magic_fixup.edit_image(ref_image_t, coarse_edit_t_rgb, mask_t, start_step=1.0, steps=50)
            all_rgbs.append(out_rgb.squeeze().cpu().detach().float())

        save_path = f'{output_dir}/{edit_file_name}__sample__{save_counter:03d}.png'
        torchvision.utils.save_image(all_rgbs[i], save_path)
        save_counter += 1


if __name__ == "__main__":
    main()