%cd rem
# Move into the project directory so the local `src` package below can be imported.
# Best object removal model
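# Overall flow: the user sketches over the object to remove; a coarse inpainting pass
# (src.core.process_inpaint) fills the masked region, and a Stable Diffusion 2 inpainting
# pass guided by a depth ControlNet then refines the fill so it matches the scene geometry.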

import time

import cv2
import gradio as gr
import numpy as np
import torch
from PIL import Image, ImageDraw
from scipy.ndimage import label, find_objects

from diffusers import ControlNetModel, DEISMultistepScheduler
from transformers import DPTFeatureExtractor, DPTForDepthEstimation

# Local modules: the custom ControlNet inpainting pipeline and the coarse inpainting core.
from src.pipeline_stable_diffusion_controlnet_inpaint import *
from src.core import process_inpaint

# Depth estimator (DPT/MiDaS) used to build the ControlNet conditioning image.
depth_estimator = DPTForDepthEstimation.from_pretrained("Intel/dpt-hybrid-midas").to("cuda")
feature_extractor = DPTFeatureExtractor.from_pretrained("Intel/dpt-hybrid-midas")
controlnet = ControlNetModel.from_pretrained("thibaud/controlnet-sd21-depth-diffusers", torch_dtype=torch.float16)

# Stable Diffusion 2 inpainting pipeline, conditioned on the depth ControlNet.
pipe = StableDiffusionControlNetInpaintPipeline.from_pretrained(
    "stabilityai/stable-diffusion-2-inpainting",
    controlnet=controlnet,
    torch_dtype=torch.float16,
)
pipe.scheduler = DEISMultistepScheduler.from_config(pipe.scheduler.config)
pipe.to("cuda")
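
# Optional (assumption): on GPUs with limited memory, attention slicing trades a little
# speed for lower peak usage. Left commented out to keep the original behavior.
# pipe.enable_attention_slicing()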

def resize_image(image, target_size):
    """Resize so the longer side equals target_size while preserving the aspect ratio."""
    width, height = image.size
    aspect_ratio = float(width) / float(height)
    if width > height:
        new_width = target_size
        new_height = int(target_size / aspect_ratio)
    else:
        new_width = int(target_size * aspect_ratio)
        new_height = target_size
    return image.resize((new_width, new_height), Image.BICUBIC)
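
# Optional sketch (assumption, not used anywhere below): Stable Diffusion pipelines
# generally expect width and height to be multiples of 8. If the resized image triggers
# shape errors, a hypothetical helper like this could be applied to new_width/new_height
# inside resize_image.
def snap_to_multiple_of_8(value):
    # Round down to the nearest multiple of 8, never below 8.
    return max(8, (value // 8) * 8)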

def get_depth_map(image, target_size):
    """Estimate a depth map for `image` and return it as a 3-channel PIL image sized `target_size` (height, width)."""
    image = feature_extractor(images=image, return_tensors="pt").pixel_values.to("cuda")
    with torch.no_grad(), torch.autocast("cuda"):
        depth_map = depth_estimator(image).predicted_depth

    # Resize the predicted depth to the working resolution; interpolate expects (height, width).
    depth_map = torch.nn.functional.interpolate(
        depth_map.unsqueeze(1),
        size=target_size,
        mode="bicubic",
        align_corners=False,
    )
    # Normalize to [0, 1] and replicate to three channels for the ControlNet conditioning image.
    depth_min = torch.amin(depth_map, dim=[1, 2, 3], keepdim=True)
    depth_max = torch.amax(depth_map, dim=[1, 2, 3], keepdim=True)
    depth_map = (depth_map - depth_min) / (depth_max - depth_min)
    image = torch.cat([depth_map] * 3, dim=1)

    image = image.permute(0, 2, 3, 1).cpu().numpy()[0]
    image = Image.fromarray((image * 255.0).clip(0, 255).astype(np.uint8))
    return image

def add_split_line(mask_image, line_thickness):
    """Draw a black line through each white region of the mask, splitting large holes into smaller ones."""
    # Ensure the mask is a single-channel (grayscale) image.
    if mask_image.mode != 'L':
        mask_image = mask_image.convert('L')

    # Label the connected white regions in the mask.
    mask_array = np.array(mask_image)
    labeled_array, num_features = label(mask_array == 255)

    draw = ImageDraw.Draw(mask_image)

    # Split each white region with a line through its center.
    for i in range(1, num_features + 1):
        # Bounding box of the region: the first slice covers rows, the second covers columns.
        row_slice, col_slice = find_objects(labeled_array == i)[0]
        top, bottom = row_slice.start, row_slice.stop
        left, right = col_slice.start, col_slice.stop

        if (right - left) > (bottom - top):
            # Wider than tall: draw a vertical line through the horizontal center.
            center_x = (left + right) // 2
            draw.line([(center_x, top), (center_x, bottom)], fill=0, width=line_thickness)
        else:
            # Taller than wide: draw a horizontal line through the vertical center.
            center_y = (top + bottom) // 2
            draw.line([(left, center_y), (right, center_y)], fill=0, width=line_thickness)

    return mask_image
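
# Hypothetical usage example for add_split_line (kept commented out so the script's
# behavior is unchanged): a 64x64 mask with one white rectangle is split in two by a
# 3-pixel-wide black line.
# _demo_mask = Image.new("L", (64, 64), 0)
# ImageDraw.Draw(_demo_mask).rectangle([8, 8, 56, 40], fill=255)
# _demo_split = add_split_line(_demo_mask, line_thickness=3)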

def predict(input_dict):
    start_time = time.time()

    # Get the uploaded image and the sketched mask from the Gradio sketch tool.
    image = input_dict["image"].convert("RGB")
    input_image = input_dict["mask"].convert("RGBA")
    image = resize_image(image, 768)
    input_image = resize_image(input_image, 768)
    # Split large masked regions with a 10-pixel-wide black line so each hole is smaller.
    mask_holes = add_split_line(input_image, 10)

    # Convert to numpy arrays.
    image_npp = np.array(image)
    drawing_np = np.array(input_image)

    # Defensive check; the image was already converted to RGB above.
    if image_npp.shape[2] == 4:
        image_npp = cv2.cvtColor(image_npp, cv2.COLOR_RGBA2RGB)

    # Build the RGBA mask expected by process_inpaint: opaque where the image is kept,
    # transparent where the user drew (the Gradio sketch strokes are white).
    background = np.where(
        (drawing_np[:, :, 0] == 0) &
        (drawing_np[:, :, 1] == 0) &
        (drawing_np[:, :, 2] == 0)
    )
    drawing = np.where(
        (drawing_np[:, :, 0] == 255) &
        (drawing_np[:, :, 1] == 255) &
        (drawing_np[:, :, 2] == 255)
    )
    mask_npp = np.zeros_like(drawing_np)
    mask_npp[background] = [0, 0, 0, 255]  # Opaque where nothing was drawn
    mask_npp[drawing] = [0, 0, 0, 0]       # Transparent where the user drew

    # Coarse inpainting pass to fill the masked (transparent) region.
    inpainted_image_np = process_inpaint(image_npp, mask_npp)
    inpainted_image = Image.fromarray(inpainted_image_np)

    # Non-zero alpha in mask_npp marks the pixels the user did not draw over.
    unmasked_region = mask_npp[:, :, 3] != 0

    blended_image_np = np.array(inpainted_image_np)

    # PIL reports size as (width, height); get_depth_map expects (height, width).
    blended_image_size = inpainted_image.size
    flipped_size = (blended_image_size[1], blended_image_size[0])
    depth_image = get_depth_map(inpainted_image, flipped_size)

    # Refine the coarse fill with SD2 inpainting guided by the depth ControlNet.
    generator = torch.manual_seed(0)
    output = pipe(
        prompt="",
        num_inference_steps=8,
        generator=generator,
        image=blended_image_np,
        control_image=depth_image,
        controlnet_conditioning_scale=0.9,
        mask_image=mask_holes,
    ).images[0]

    # Convert the final output to a NumPy array
    output_np = np.array(output)

    # Paste the untouched region from the coarse result back over the refined output,
    # provided the spatial dimensions match.
    if output_np.shape[:2] == inpainted_image_np.shape[:2]:
        output_np[unmasked_region] = inpainted_image_np[unmasked_region]
    else:
        print("Dimension mismatch: cannot apply unmasked_region")

    # Convert back to PIL Image
    final_output = Image.fromarray(output_np)

    end_time = time.time()
    inference_time = end_time - start_time
    inference_time_str = f"Inference Time: {inference_time:.2f} seconds"

    # Return both image and inference time
    return final_output, inference_time_str

image_blocks = gr.Blocks()

with image_blocks as demo:
    with gr.Row():
        with gr.Column():
            input_image = gr.Image(source='upload', tool='sketch', elem_id="input_image_upload", type="pil", label="Upload & Draw on Image")
            btn = gr.Button("Remove Object")
        with gr.Column():
            result = gr.Image(label="Result")
            inference_time_label = gr.Label(label="Inference Time")  # Displays the measured inference time
        btn.click(fn=predict, inputs=[input_image], outputs=[result, inference_time_label])

demo.launch(debug=True)
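
# Note (assumption): in hosted notebooks such as Colab, a public link may be needed;
# in that case launch with demo.launch(debug=True, share=True).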