%cd rem

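# Object-removal demo: the user sketches over an object in a Gradio canvas,
# the sketched region is filled by a first inpainting pass
# (src.core.process_inpaint), and the result is then refined with a
# Stable Diffusion 2 inpainting pipeline guided by a MiDaS depth ControlNet.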

import time

import cv2
import gradio as gr
import numpy as np
import torch
from PIL import Image, ImageDraw
from scipy.ndimage import label, find_objects
from transformers import DPTFeatureExtractor, DPTForDepthEstimation

from diffusers import StableDiffusionInpaintPipeline, ControlNetModel, DEISMultistepScheduler
from diffusers.utils import load_image

from src.core import process_inpaint
from src.pipeline_stable_diffusion_controlnet_inpaint import *

# MiDaS depth estimator used to produce the ControlNet conditioning image.
depth_estimator = DPTForDepthEstimation.from_pretrained("Intel/dpt-hybrid-midas").to("cuda")
feature_extractor = DPTFeatureExtractor.from_pretrained("Intel/dpt-hybrid-midas")

# Depth ControlNet trained for Stable Diffusion 2.1.
controlnet = ControlNetModel.from_pretrained(
    "thibaud/controlnet-sd21-depth-diffusers", torch_dtype=torch.float16
)

# Custom pipeline (from src.pipeline_stable_diffusion_controlnet_inpaint) combining
# the SD2 inpainting weights with the depth ControlNet.
pipe = StableDiffusionControlNetInpaintPipeline.from_pretrained(
    "stabilityai/stable-diffusion-2-inpainting",
    controlnet=controlnet,
    torch_dtype=torch.float16,
)
pipe.scheduler = DEISMultistepScheduler.from_config(pipe.scheduler.config)
pipe.to("cuda")

def resize_image(image, target_size):
    # Resize so the longer side equals target_size while keeping the aspect ratio.
    width, height = image.size
    aspect_ratio = float(width) / float(height)
    if width > height:
        new_width = target_size
        new_height = int(target_size / aspect_ratio)
    else:
        new_width = int(target_size * aspect_ratio)
        new_height = target_size
    return image.resize((new_width, new_height), Image.BICUBIC)

def get_depth_map(image, target_size):
    # Run MiDaS depth estimation and return a 3-channel depth image
    # resized to target_size (height, width) for ControlNet conditioning.
    image = feature_extractor(images=image, return_tensors="pt").pixel_values.to("cuda")
    with torch.no_grad(), torch.autocast("cuda"):
        depth_map = depth_estimator(image).predicted_depth

    depth_map = torch.nn.functional.interpolate(
        depth_map.unsqueeze(1),
        size=target_size,
        mode="bicubic",
        align_corners=False,
    )
    # Normalise per image to [0, 1], then replicate to three channels.
    depth_min = torch.amin(depth_map, dim=[1, 2, 3], keepdim=True)
    depth_max = torch.amax(depth_map, dim=[1, 2, 3], keepdim=True)
    depth_map = (depth_map - depth_min) / (depth_max - depth_min)
    image = torch.cat([depth_map] * 3, dim=1)

    image = image.permute(0, 2, 3, 1).cpu().numpy()[0]
    image = Image.fromarray((image * 255.0).clip(0, 255).astype(np.uint8))
    return image

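# add_split_line draws a thin background-coloured (0) line through the centre of
# each connected component of the mask, splitting large blobs in two
# (presumably so very large masked regions are not filled as a single piece).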
def add_split_line(mask_image, line_thickness):
    if mask_image.mode != 'L':
        mask_image = mask_image.convert('L')

    mask_array = np.array(mask_image)

    # Label the connected components of fully white (255) mask pixels.
    labeled_array, num_features = label(mask_array == 255)

    draw = ImageDraw.Draw(mask_image)

    for i in range(1, num_features + 1):
        # Bounding box of the i-th component: rows first, then columns.
        row_slice, col_slice = find_objects(labeled_array == i)[0]
        top, bottom = row_slice.start, row_slice.stop
        left, right = col_slice.start, col_slice.stop

        if (right - left) > (bottom - top):
            # Wider than tall: split with a vertical line through the centre.
            center_x = (left + right) // 2
            draw.line([(center_x, top), (center_x, bottom)], fill=0, width=line_thickness)
        else:
            # Taller than wide: split with a horizontal line through the centre.
            center_y = (top + bottom) // 2
            draw.line([(left, center_y), (right, center_y)], fill=0, width=line_thickness)

    return mask_image

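# Main Gradio handler: builds the removal mask from the user's sketch, fills the
# sketched region with process_inpaint, then refines that fill with the
# depth-guided ControlNet inpainting pipeline and pastes the untouched pixels back.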
def predict(input_dict):
    start_time = time.time()

    image = input_dict["image"].convert("RGB")
    input_image = input_dict["mask"].convert("RGBA")
    image = resize_image(image, 768)
    input_image = resize_image(input_image, 768)
    # Grayscale mask with a split line drawn through each large masked blob;
    # this is later passed to the diffusion pipeline as mask_image.
    mask_holes = add_split_line(input_image, 10)

    image_npp = np.array(image)
    drawing_np = np.array(input_image)

    if image_npp.shape[2] == 4:
        image_npp = cv2.cvtColor(image_npp, cv2.COLOR_RGBA2RGB)

    # Build an RGBA mask for process_inpaint: untouched (black) pixels keep
    # alpha 255, user-drawn (magenta) pixels get alpha 0 and mark the region to fill.
    background = np.where(
        (drawing_np[:, :, 0] == 0) &
        (drawing_np[:, :, 1] == 0) &
        (drawing_np[:, :, 2] == 0)
    )
    drawing = np.where(
        (drawing_np[:, :, 0] == 255) &
        (drawing_np[:, :, 1] == 0) &
        (drawing_np[:, :, 2] == 255)
    )
    mask_npp = np.zeros_like(drawing_np)
    mask_npp[background] = [0, 0, 0, 255]
    mask_npp[drawing] = [0, 0, 0, 0]

    # First fill of the drawn region.
    inpainted_image_np = process_inpaint(image_npp, mask_npp)
    inpainted_image = Image.fromarray(inpainted_image_np)

    # Pixels that were not drawn on (alpha != 0) and must be preserved.
    unmasked_region = mask_npp[:, :, 3] != 0

    blended_image_np = np.array(inpainted_image_np)
    blended_image_size = inpainted_image.size

    # PIL size is (width, height); interpolate expects (height, width).
    flipped_size = (blended_image_size[1], blended_image_size[0])
    depth_image = get_depth_map(inpainted_image, flipped_size)

    # Refine the first fill with depth-guided ControlNet inpainting.
    generator = torch.manual_seed(0)
    output = pipe(
        prompt="",
        num_inference_steps=8,
        generator=generator,
        image=blended_image_np,
        control_image=depth_image,
        controlnet_conditioning_scale=0.9,
        mask_image=mask_holes,
    ).images[0]

    output_np = np.array(output)

    # Copy the untouched pixels back so only the drawn region is changed.
    if output_np.shape[:2] == inpainted_image_np.shape[:2]:
        output_np[unmasked_region] = inpainted_image_np[unmasked_region]
    else:
        print("Dimension mismatch: cannot apply unmasked_region")

    final_output = Image.fromarray(output_np)

    end_time = time.time()
    inference_time = end_time - start_time
    inference_time_str = f"Inference Time: {inference_time:.2f} seconds"

    return final_output, inference_time_str

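# Minimal Gradio UI (3.x-style gr.Image with source/tool arguments): upload an
# image, sketch over the object to remove, and show the result plus timing.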
image_blocks = gr.Blocks()

with image_blocks as demo:
    with gr.Row():
        with gr.Column():
            input_image = gr.Image(source='upload', tool='sketch', elem_id="input_image_upload",
                                   type="pil", label="Upload & Draw on Image")
            btn = gr.Button("Remove Object")
        with gr.Column():
            result = gr.Image(label="Result")
            inference_time_label = gr.Label()
    btn.click(fn=predict, inputs=[input_image], outputs=[result, inference_time_label])

demo.launch(debug=True)