|
import cv2 |
|
import einops |
|
import numpy as np |
|
import torch |
|
import random |
|
from pytorch_lightning import seed_everything |
|
from cldm.model import create_model, load_state_dict |
|
from cldm.ddim_hacked import DDIMSampler |
|
from cldm.hack import disable_verbosity, enable_sliced_attention |
|
from datasets.data_utils import * |
|
cv2.setNumThreads(0) |
|
cv2.ocl.setUseOpenCL(False) |
|
import albumentations as A |
|
from omegaconf import OmegaConf |
|
from PIL import Image |
|
|
|
|
|
# Module-level setup: runs once at import time and leaves `model` and
# `ddim_sampler` ready for the inference function defined later in this file.

# When True, sliced attention is enabled here and the inference function
# calls `model.low_vram_shift(...)` around its diffusion stage.
save_memory = True

disable_verbosity()
if save_memory:
    enable_sliced_attention()

# The checkpoint path and the model definition file are both read from the
# inference config.
config = OmegaConf.load('./configs/inference.yaml')
model_ckpt = config.pretrained_model
model_config = config.config_file

# Instantiate on the CPU, load the checkpoint weights, then move to the GPU.
model = create_model(model_config).cpu()
model.load_state_dict(load_state_dict(model_ckpt, location='cuda'))
model = model.cuda()

# Sampler shared by every inference call in this process.
ddim_sampler = DDIMSampler(model)
|
|
|
|
|
|
|
def aug_data_mask(image, mask):
    """Apply the same random augmentation to an image and its mask.

    The pipeline is a horizontal flip followed by a brightness/contrast
    jitter, each applied with probability 0.5.

    Args:
        image: Image array; it is cast to ``uint8`` before augmentation.
        mask: Mask array passed through the pipeline alongside ``image``.

    Returns:
        Tuple ``(augmented_image, augmented_mask)``.
    """
    pipeline = A.Compose([
        A.HorizontalFlip(p=0.5),
        A.RandomBrightnessContrast(p=0.5),
    ])
    augmented = pipeline(image=image.astype(np.uint8), mask=mask)
    return augmented["image"], augmented["mask"]
|
|
|
|
|
def process_pairs(ref_image, ref_mask, tar_image, tar_mask):
    """Build the model inputs for one reference/target pair.

    The reference object is cut out on a white background and resized to
    224x224; the target is cropped around its mask, the edge map of the
    reference is pasted into the crop ("collage"), and both are padded to a
    square and resized to 512x512.

    Args:
        ref_image: Reference image, ``(H, W, 3)``.
        ref_mask: Reference object mask, ``(H, W)``; the caller in this file
            passes a 0/1 ``uint8`` array.
        tar_image: Target (scene) image, ``(H, W, 3)``.
        tar_mask: Mask of the region to fill in the target, ``(H, W)``;
            0/1 ``uint8`` from the caller in this file.

    Returns:
        dict with keys
            ``ref``: 224x224x3 masked reference scaled to ``[0, 1]``;
            ``jpg``: 512x512x3 target crop scaled to ``[-1, 1]``;
            ``hint``: 512x512x4 collage in ``[-1, 1]`` plus one mask channel;
            ``extra_sizes``: ``[H1, W1, H2, W2]`` - crop size before and
            after square padding;
            ``tar_box_yyxx_crop``: ``[y1, y2, x1, x2]`` of the crop inside
            ``tar_image``.

    The bbox/padding/sobel helpers come from ``datasets.data_utils``; their
    exact semantics are not visible in this file.
    """
    # ========= Reference ===========
    ref_box_yyxx = get_bbox_from_mask(ref_mask)

    # Keep the object pixels, paint everything else white.
    ref_mask_3 = np.stack([ref_mask, ref_mask, ref_mask], -1)
    masked_ref_image = ref_image * ref_mask_3 + np.ones_like(ref_image) * 255 * (1 - ref_mask_3)

    # Crop image and mask to the object's bounding box.
    y1, y2, x1, x2 = ref_box_yyxx
    masked_ref_image = masked_ref_image[y1:y2, x1:x2, :]
    ref_mask = ref_mask[y1:y2, x1:x2]

    # NOTE(review): randint's upper bound is exclusive, so this is always 1.2.
    # The call is kept as-is so the behaviour stays exactly what it was.
    ratio = np.random.randint(12, 13) / 10
    masked_ref_image, ref_mask = expand_image_mask(masked_ref_image, ref_mask, ratio=ratio)
    ref_mask_3 = np.stack([ref_mask, ref_mask, ref_mask], -1)

    # Square-pad (white for the image, 0 for the mask) and resize to 224x224.
    masked_ref_image = pad_to_square(masked_ref_image, pad_value=255, random=False)
    masked_ref_image = cv2.resize(masked_ref_image, (224, 224)).astype(np.uint8)

    ref_mask_3 = pad_to_square(ref_mask_3 * 255, pad_value=0, random=False)
    ref_mask_3 = cv2.resize(ref_mask_3, (224, 224)).astype(np.uint8)
    ref_mask = ref_mask_3[:, :, 0]

    # The reference that is returned, and the edge map pasted into the target.
    masked_ref_image_aug = masked_ref_image.copy()
    ref_image_collage = sobel(masked_ref_image, ref_mask / 255)

    # ========= Target ===========
    tar_box_yyxx = get_bbox_from_mask(tar_mask)
    tar_box_yyxx = expand_bbox(tar_mask, tar_box_yyxx, ratio=[1.1, 1.2])

    # Larger square context window around the target box.
    tar_box_yyxx_crop = expand_bbox(tar_image, tar_box_yyxx, ratio=[1.5, 3])
    tar_box_yyxx_crop = box2squre(tar_image, tar_box_yyxx_crop)
    y1, y2, x1, x2 = tar_box_yyxx_crop

    cropped_target_image = tar_image[y1:y2, x1:x2, :]
    # Re-express the target box in the crop's coordinate frame.
    tar_box_yyxx = box_in_box(tar_box_yyxx, tar_box_yyxx_crop)
    y1, y2, x1, x2 = tar_box_yyxx

    # ========= Collage ===========
    ref_image_collage = cv2.resize(ref_image_collage, (x2 - x1, y2 - y1))

    collage = cropped_target_image.copy()
    collage[y1:y2, x1:x2, :] = ref_image_collage

    collage_mask = cropped_target_image.copy() * 0.0
    collage_mask[y1:y2, x1:x2, :] = 1.0

    # Size before square padding; needed later to undo the padding.
    H1, W1 = collage.shape[0], collage.shape[1]
    cropped_target_image = pad_to_square(cropped_target_image, pad_value=0, random=False).astype(np.uint8)
    collage = pad_to_square(collage, pad_value=0, random=False).astype(np.uint8)
    # NOTE(review): presumably the padded area is filled with -1, which the
    # uint8 cast then converts out of range (platform-dependent result) -
    # TODO confirm the intended mask value for padding.
    collage_mask = pad_to_square(collage_mask, pad_value=-1, random=False).astype(np.uint8)

    # Size after square padding.
    H2, W2 = collage.shape[0], collage.shape[1]
    cropped_target_image = cv2.resize(cropped_target_image, (512, 512)).astype(np.float32)
    collage = cv2.resize(collage, (512, 512)).astype(np.float32)
    collage_mask = (cv2.resize(collage_mask, (512, 512)).astype(np.float32) > 0.5).astype(np.float32)

    # Normalise: reference to [0, 1], target and collage to [-1, 1].
    masked_ref_image_aug = masked_ref_image_aug / 255
    cropped_target_image = cropped_target_image / 127.5 - 1.0
    collage = collage / 127.5 - 1.0
    # Append a single mask channel to the collage.
    collage = np.concatenate([collage, collage_mask[:, :, :1]], -1)

    item = dict(
        ref=masked_ref_image_aug.copy(),
        jpg=cropped_target_image.copy(),
        hint=collage.copy(),
        extra_sizes=np.array([H1, W1, H2, W2]),
        tar_box_yyxx_crop=np.array(tar_box_yyxx_crop),
    )
    return item
|
|
|
|
|
def crop_back(pred, tar_image, extra_sizes, tar_box_yyxx_crop, margin=5):
    """Paste a generated crop back into the full target image.

    Args:
        pred: Generated image, ``(h, w, 3)``. It is resized to the padded
            square size ``(H2, W2)`` unless it already has that size.
        tar_image: Full target image. It is never modified; a copy is
            returned.
        extra_sizes: ``[H1, W1, H2, W2]`` - the crop's size before
            (``H1, W1``) and after (``H2, W2``) square padding.
        tar_box_yyxx_crop: ``[y1, y2, x1, x2]`` of the crop in ``tar_image``.
        margin: Border width, in pixels, left untouched on every side of the
            pasted region. Defaults to 5, the value previously hard-coded.

    Returns:
        A copy of ``tar_image`` with the interior of the crop box replaced
        by the corresponding part of ``pred``.

    Raises:
        ValueError: If ``margin`` is negative.
    """
    H1, W1, H2, W2 = (int(v) for v in extra_sizes)
    y1, y2, x1, x2 = (int(v) for v in tar_box_yyxx_crop)
    m = int(margin)
    if m < 0:
        raise ValueError("margin must be non-negative")

    # Resizing to the same size is a plain copy, so skip it when not needed.
    if pred.shape[:2] != (H2, W2):
        pred = cv2.resize(pred, (W2, H2))

    # Strip the square padding when the crop was not square. Explicit end
    # indices keep the slice correct even when the trailing pad is 0
    # (a ``-0`` end index would produce an empty slice).
    if W1 != H1:
        if W1 < W2:
            pad1 = (W2 - W1) // 2
            pred = pred[:, pad1:pad1 + W1, :]
        else:
            pad1 = (H2 - H1) // 2
            pred = pred[pad1:pad1 + H1, :, :]

    # Always work on a copy: the square case used to write into the caller's
    # array while the padded case did not.
    gen_image = tar_image.copy()
    gen_image[y1 + m:y2 - m, x1 + m:x2 - m, :] = pred[m:pred.shape[0] - m, m:pred.shape[1] - m]
    return gen_image
|
|
|
|
|
def inference_single_image(ref_image, ref_mask, tar_image, tar_mask, guidance_scale, seed, steps):
    """Generate the reference object inside the target image.

    Args:
        ref_image: Reference image, ``(H, W, 3)``.
        ref_mask: Reference object mask, ``(H, W)``.
        tar_image: Target image, ``(H, W, 3)``.
        tar_mask: Mask of the region to fill in the target, ``(H, W)``.
        guidance_scale: Passed to the sampler as
            ``unconditional_guidance_scale``.
        seed: Seed applied to the Python, NumPy and torch RNGs through
            ``seed_everything`` so a request can be reproduced. ``None`` or a
            negative value selects a random seed in ``[0, 65535]``.
        steps: Number of DDIM sampling steps.

    Returns:
        The target image with the generated crop pasted back by
        ``crop_back``.
    """
    # The seed used to be overwritten with a random value and never applied.
    # Seed before `process_pairs`, which also draws from NumPy's global RNG.
    if seed is None or seed < 0:
        seed = random.randint(0, 65535)
    seed_everything(seed)

    item = process_pairs(ref_image, ref_mask, tar_image, tar_mask)
    ref = item['ref']
    hint = item['hint']

    if save_memory:
        model.low_vram_shift(is_diffusing=False)

    num_samples = 1
    guess_mode = False
    strength = 1
    eta = 0.0
    H, W = 512, 512

    # Collage + mask hint, batched and converted to channels-first.
    control = torch.from_numpy(hint.copy()).float().cuda()
    control = torch.stack([control for _ in range(num_samples)], dim=0)
    control = einops.rearrange(control, 'b h w c -> b c h w').clone()

    # Reference image fed to the conditioning encoder, same layout.
    clip_input = torch.from_numpy(ref.copy()).float().cuda()
    clip_input = torch.stack([clip_input for _ in range(num_samples)], dim=0)
    clip_input = einops.rearrange(clip_input, 'b h w c -> b c h w').clone()

    cond = {
        "c_concat": [control],
        "c_crossattn": [model.get_learned_conditioning(clip_input)],
    }
    # The unconditional branch is conditioned on an all-zero reference image.
    un_cond = {
        "c_concat": None if guess_mode else [control],
        "c_crossattn": [model.get_learned_conditioning([torch.zeros((1, 3, 224, 224))] * num_samples)],
    }
    # Latent shape: 4 channels at 1/8 of the 512x512 resolution.
    shape = (4, H // 8, W // 8)

    if save_memory:
        model.low_vram_shift(is_diffusing=True)

    model.control_scales = (
        [strength * (0.825 ** float(12 - i)) for i in range(13)]
        if guess_mode
        else [strength] * 13
    )
    samples, _ = ddim_sampler.sample(
        steps, num_samples, shape, cond,
        verbose=False, eta=eta,
        unconditional_guidance_scale=guidance_scale,
        unconditional_conditioning=un_cond,
    )

    if save_memory:
        model.low_vram_shift(is_diffusing=False)

    # Decode latents and map from [-1, 1] to [0, 255], channels-last.
    x_samples = model.decode_first_stage(samples)
    x_samples = (einops.rearrange(x_samples, 'b c h w -> b h w c') * 127.5 + 127.5).cpu().numpy()

    # NOTE(review): the first row of the prediction is dropped before
    # `crop_back` resizes it; kept as-is to preserve the output - confirm
    # whether this is intentional.
    pred = np.clip(x_samples[0], 0, 255)[1:, :, :]
    sizes = item['extra_sizes']
    tar_box_yyxx_crop = item['tar_box_yyxx_crop']
    gen_image = crop_back(pred, tar_image, sizes, tar_box_yyxx_crop)
    return gen_image
|
|
|
|
|
import cv2 |
|
import numpy as np |
|
import base64 |
|
import os |
|
from http.server import BaseHTTPRequestHandler, HTTPServer |
|
import json |
|
from io import BytesIO |
|
from PIL import Image |
|
|
|
def base64_to_cv2_image(base64_str):
    """Decode a base64-encoded image file into an RGB array.

    Args:
        base64_str: Base64 text (or bytes) of an encoded image file such as
            JPEG or PNG.

    Returns:
        The decoded image as a 3-channel array in RGB channel order.

    Raises:
        ValueError: If the payload is not valid base64 (``binascii.Error`` is
            a ``ValueError``), is empty, or cannot be decoded as an image.
    """
    raw_bytes = base64.b64decode(base64_str)
    encoded = np.frombuffer(raw_bytes, dtype=np.uint8)
    if encoded.size == 0:
        raise ValueError("Image payload is empty")

    # imdecode signals failure by returning None instead of raising; check it
    # here so the caller gets a clear message rather than a cvtColor error.
    bgr = cv2.imdecode(encoded, cv2.IMREAD_COLOR)
    if bgr is None:
        raise ValueError("Image payload could not be decoded")

    # OpenCV decodes to BGR; the rest of this file works in RGB.
    return cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
|
|
|
def base64_to_pil_image(base64_str):
    """Decode a base64-encoded image file and open it as a PIL image.

    Args:
        base64_str: Base64 text (or bytes) of an encoded image file.

    Returns:
        The object returned by ``PIL.Image.open`` for the decoded bytes.
    """
    payload = BytesIO(base64.b64decode(base64_str))
    return Image.open(payload)
|
|
|
def pil_image_to_np_array(pil_img, target_index):
    """Return a 0/1 ``uint8`` mask of the pixels equal to ``target_index``.

    Args:
        pil_img: Image (or any array-like) convertible with ``np.array``.
        target_index: Value to select.

    Returns:
        ``uint8`` array shaped like the converted image, 1 where the pixel
        equals ``target_index`` and 0 elsewhere.
    """
    pixels = np.array(pil_img)
    selected = pixels == target_index
    return selected.astype(np.uint8)
|
|
|
def image_to_base64(img):
    """Encode an RGB image as a base64 JPEG string.

    Args:
        img: Image array in RGB channel order.

    Returns:
        Base64 text of the JPEG-encoded image.

    Raises:
        RuntimeError: If OpenCV reports that JPEG encoding failed.
    """
    # OpenCV encoders expect BGR channel order.
    bgr = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
    ok, buffer = cv2.imencode('.jpg', bgr)
    # The success flag used to be discarded.
    if not ok:
        raise RuntimeError("JPEG encoding failed")
    return base64.b64encode(buffer).decode("utf-8")
|
|
|
|
|
class RequestHandler(BaseHTTPRequestHandler):
    """HTTP handler exposing ``inference_single_image`` as a POST endpoint.

    A POST must carry the shared secret in the ``X-API-Key`` header and a
    JSON body with the keys ``seed``, ``steps``, ``guidance_scale``,
    ``ref_image``, ``tar_image``, ``ref_mask`` and ``tar_mask`` (the four
    images are base64-encoded image files). On success the reply body is the
    raw JPEG of the generated image; failures are reported as
    ``{"error": ...}`` JSON with status 400 (bad request), 401 (bad key) or
    500 (inference failure). OPTIONS answers 204; every other method 405.
    """

    # SECURITY: a credential committed to source control should be treated as
    # leaked. The literal is kept as a fallback so existing clients keep
    # working; set INFERENCE_API_KEY in the environment to override it.
    API_KEY = os.environ.get("INFERENCE_API_KEY", "xiCQTaoQKXUNATzuFLWRgtoJKiFXiDGvnk")

    # Directory that receives a debug dump of each request's decoded inputs.
    DEBUG_OUTPUT_DIR = '/work/ADOOR_ACE/test_out'

    def _set_response(self, status_code=200, content_type='application/json'):
        """Send the status line plus the content-type and CORS headers."""
        self.send_response(status_code)
        self.send_header('Content-type', content_type)
        self.send_header('Access-Control-Allow-Origin', '*')
        self.send_header('Access-Control-Allow-Methods', 'GET, POST, OPTIONS')
        self.send_header('Access-Control-Allow-Headers', 'X-API-Key, Content-Type')
        self.end_headers()

    def _send_json_error(self, status_code, message):
        """Send ``{"error": message}`` as a JSON response with the given status."""
        self._set_response(status_code)
        self.wfile.write(json.dumps({'error': message}).encode('utf-8'))

    def do_OPTIONS(self):
        """Answer CORS pre-flight requests."""
        self._set_response(204)

    def do_GET(self):
        """Reject GET: the endpoint only accepts POST."""
        self._send_json_error(405, 'GET method not allowed.')

    def handle_not_supported_method(self):
        """Shared 405 reply for PUT, DELETE and PATCH."""
        self._send_json_error(405, 'Method not supported.')

    def do_PUT(self):
        self.handle_not_supported_method()

    def do_DELETE(self):
        self.handle_not_supported_method()

    def do_PATCH(self):
        self.handle_not_supported_method()

    def _is_authorized(self):
        """Return True when the ``X-API-Key`` header matches ``API_KEY``."""
        import hmac  # local import: the file-level import block is outside this class

        supplied = self.headers.get('X-API-Key')
        if supplied is None:
            return False
        # Constant-time comparison so response timing does not leak the key;
        # compare bytes because compare_digest rejects non-ASCII str.
        return hmac.compare_digest(supplied.encode('utf-8'), self.API_KEY.encode('utf-8'))

    @staticmethod
    def _decode_mask(base64_str):
        """Decode a base64 image and binarise it: pixels brighter than 128 become 1."""
        mask_rgb = base64_to_cv2_image(base64_str)
        gray = cv2.cvtColor(mask_rgb, cv2.COLOR_RGB2GRAY)
        return (gray > 128).astype(np.uint8)

    def _parse_inputs(self, data):
        """Extract and decode every request field; raises on missing or invalid input."""
        seed = int(data['seed'])
        steps = int(data['steps'])
        guidance_scale = float(data['guidance_scale'])
        ref_image = base64_to_cv2_image(data['ref_image'])
        tar_image = base64_to_cv2_image(data['tar_image'])
        ref_mask = self._decode_mask(data['ref_mask'])
        tar_mask = self._decode_mask(data['tar_mask'])
        return ref_image, ref_mask, tar_image, tar_mask, guidance_scale, seed, steps

    def _dump_debug_inputs(self, ref_image, ref_mask, tar_image, tar_mask):
        """Write the decoded inputs to ``DEBUG_OUTPUT_DIR`` for inspection."""
        output_dir = self.DEBUG_OUTPUT_DIR
        try:
            os.makedirs(output_dir, exist_ok=True)
        except OSError as e:
            # The dump is a debugging aid; an unwritable directory must not
            # fail the request.
            print(f"Skipping debug dump: {e}")
            return

        cv2.imwrite(os.path.join(output_dir, 'out_ref_image.jpg'), cv2.cvtColor(ref_image, cv2.COLOR_RGB2BGR))
        cv2.imwrite(os.path.join(output_dir, 'out_tar_image.jpg'), cv2.cvtColor(tar_image, cv2.COLOR_RGB2BGR))
        # Masks are 0/1; scale to 0/255 so they are visible when viewed.
        cv2.imwrite(os.path.join(output_dir, 'out_ref_mask.jpg'), (ref_mask * 255).astype(np.uint8))
        cv2.imwrite(os.path.join(output_dir, 'out_tar_mask.jpg'), (tar_mask * 255).astype(np.uint8))

    def do_POST(self):
        """Authenticate, decode the request, run inference and return a JPEG."""
        print("Received POST request...")

        if not self._is_authorized():
            self._send_json_error(401, 'Invalid API key')
            print("Invalid API key")
            return

        # A missing header counts as "no body"; a malformed or negative one
        # is rejected (a negative length would make rfile.read block).
        try:
            content_length = int(self.headers.get('Content-Length') or 0)
        except ValueError:
            content_length = -1
        if content_length < 0:
            self._send_json_error(400, 'Invalid Content-Length header')
            print("Sent error response")
            return
        print(f"Content Length: {content_length}")

        if not content_length:
            print("No data received in POST request.")
            self._send_json_error(400, 'No data received')
            print("Sent error response")
            return

        post_data = self.rfile.read(content_length)
        print("Data received")

        # JSONDecodeError and UnicodeDecodeError are both ValueError subclasses.
        try:
            data = json.loads(post_data.decode('utf-8'))
        except ValueError as e:
            print(f"An error occurred: {e}")
            self._send_json_error(400, f'Request body is not valid JSON: {e}')
            print("Sent error response")
            return
        if not isinstance(data, dict):
            self._send_json_error(400, 'Request body must be a JSON object')
            print("Sent error response")
            return

        print("Processing data")
        # Anything wrong with the client's fields is a 400, not a 500.
        try:
            inputs = self._parse_inputs(data)
        except (KeyError, TypeError, ValueError, cv2.error) as e:
            print(f"An error occurred: {e}")
            self._send_json_error(400, f'Invalid request parameters: {e}')
            print("Sent error response")
            return

        try:
            ref_image, ref_mask, tar_image, tar_mask, guidance_scale, seed, steps = inputs
            self._dump_debug_inputs(ref_image, ref_mask, tar_image, tar_mask)
            gen_image = inference_single_image(ref_image, ref_mask, tar_image, tar_mask, guidance_scale, seed, steps)
            # The helper returns base64 text; the response carries raw JPEG bytes.
            jpeg_bytes = base64.b64decode(image_to_base64(gen_image))
        except Exception as e:
            # Top-level boundary: report any inference failure to the client.
            print(f"An error occurred: {e}")
            self._send_json_error(500, str(e))
            print("Sent error response")
            return

        # Go through _set_response so the success reply carries the same CORS
        # headers as the error replies; without them browsers cannot read it.
        self._set_response(200, 'image/jpeg')
        self.wfile.write(jpeg_bytes)
        print("Sent image response")
|
|
|
|
|
|
|
def run(server_class=HTTPServer, handler_class=RequestHandler, port=8084):
    """Serve requests on ``port`` (all interfaces) until the process is stopped.

    Args:
        server_class: Server type to instantiate.
        handler_class: Request handler class passed to the server.
        port: TCP port to listen on.
    """
    server_address = ('', port)
    httpd = server_class(server_address, handler_class)
    print(f"Starting HTTP server on port {port}")
    try:
        httpd.serve_forever()
    finally:
        # Release the listening socket however serve_forever exits
        # (for example on KeyboardInterrupt, which still propagates).
        httpd.server_close()


if __name__ == "__main__":
    run()
|
|
|
|
|
|