Create app.py
app.py
ADDED
@@ -0,0 +1,189 @@
# Object removal app: classical inpainting first, then a depth-guided
# Stable Diffusion 2 inpainting pass to blend the result.

import os
import sys
import time

import cv2
import gradio as gr
import numpy as np
import torch
from diffusers import ControlNetModel, DEISMultistepScheduler
from PIL import Image, ImageDraw
from scipy.ndimage import label, find_objects
from transformers import DPTFeatureExtractor, DPTForDepthEstimation

os.chdir("rem")                  # stand-in for the original notebook magic `%cd rem`
sys.path.insert(0, os.getcwd())  # make the repo's `src` package importable when run as a script

from src.core import process_inpaint
from src.pipeline_stable_diffusion_controlnet_inpaint import StableDiffusionControlNetInpaintPipeline

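# Assumed (unpinned) dependencies, judging from the imports above: gradio 3.x,
# diffusers, transformers, opencv-python, scipy, and torch with CUDA; `src/` is
# this repo's local package providing process_inpaint and the pipeline class.
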
# MiDaS depth estimator used to build the ControlNet conditioning image
depth_estimator = DPTForDepthEstimation.from_pretrained("Intel/dpt-hybrid-midas").to("cuda")
feature_extractor = DPTFeatureExtractor.from_pretrained("Intel/dpt-hybrid-midas")

# Depth ControlNet trained for Stable Diffusion 2.1
controlnet = ControlNetModel.from_pretrained(
    "thibaud/controlnet-sd21-depth-diffusers", torch_dtype=torch.float16
)

pipe = StableDiffusionControlNetInpaintPipeline.from_pretrained(
    "stabilityai/stable-diffusion-2-inpainting",
    controlnet=controlnet,
    torch_dtype=torch.float16,
)
# DEIS scheduler keeps quality acceptable at very few (here 8) denoising steps
pipe.scheduler = DEISMultistepScheduler.from_config(pipe.scheduler.config)
pipe.to("cuda")

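# Overall flow, as implemented below: predict() erases the drawn region with
# src.core.process_inpaint, then runs a short depth-conditioned SD2 inpainting
# pass over a "split-line" mask to re-synthesize and blend the hole.
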
def resize_image(image, target_size):
    """Resize so the longer side equals target_size, preserving aspect ratio."""
    width, height = image.size
    aspect_ratio = float(width) / float(height)
    if width > height:
        new_width = target_size
        new_height = int(target_size / aspect_ratio)
    else:
        new_width = int(target_size * aspect_ratio)
        new_height = target_size
    return image.resize((new_width, new_height), Image.BICUBIC)

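# Worked example (hypothetical sizes): a 1024x683 landscape photo has
# aspect_ratio ~= 1.499, so resize_image(img, 768) returns a 768x512 image;
# a 683x1024 portrait comes back as 512x768.
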
def get_depth_map(image, target_size):
    """Estimate depth and return a 3-channel PIL image of size target_size (height, width)."""
    image = feature_extractor(images=image, return_tensors="pt").pixel_values.to("cuda")
    with torch.no_grad(), torch.autocast("cuda"):
        depth_map = depth_estimator(image).predicted_depth

    depth_map = torch.nn.functional.interpolate(
        depth_map.unsqueeze(1),
        size=target_size,  # (height, width) of the image being refined
        mode="bicubic",
        align_corners=False,
    )
    # Normalize to [0, 1], then replicate to three channels for ControlNet
    depth_min = torch.amin(depth_map, dim=[1, 2, 3], keepdim=True)
    depth_max = torch.amax(depth_map, dim=[1, 2, 3], keepdim=True)
    depth_map = (depth_map - depth_min) / (depth_max - depth_min)
    image = torch.cat([depth_map] * 3, dim=1)

    image = image.permute(0, 2, 3, 1).cpu().numpy()[0]
    image = Image.fromarray((image * 255.0).clip(0, 255).astype(np.uint8))
    return image

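# Note: target_size follows torch.nn.functional.interpolate's (height, width)
# convention, which is why predict() below flips PIL's (width, height) tuple
# before calling get_depth_map.
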
def add_split_line(mask_image, line_thickness):
    """Draw a thin black line through each white (masked) region, splitting it in two.

    Black pixels are preserved by the inpainting pass, so the split line keeps
    a strip of the first-stage result anchored inside every hole.
    """
    # Ensure the mask is single-channel
    if mask_image.mode != 'L':
        mask_image = mask_image.convert('L')

    mask_array = np.array(mask_image)

    # Label each connected white region in the mask
    labeled_array, num_features = label(mask_array == 255)

    draw = ImageDraw.Draw(mask_image)

    for i in range(1, num_features + 1):
        # Bounding box of this white region
        slice_x, slice_y = find_objects(labeled_array == i)[0]
        top, bottom = slice_x.start, slice_x.stop
        left, right = slice_y.start, slice_y.stop

        # Split along the shorter dimension of the bounding box
        if (right - left) > (bottom - top):
            # Wider than tall: draw a vertical line through the center
            center_x = (left + right) // 2
            draw.line([(center_x, top), (center_x, bottom)], fill=0, width=line_thickness)
        else:
            # Taller than wide: draw a horizontal line through the center
            center_y = (top + bottom) // 2
            draw.line([(left, center_y), (right, center_y)], fill=0, width=line_thickness)

    return mask_image

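# Worked example (hypothetical mask): a white blob spanning columns 10..110
# and rows 20..60 is wider (100 px) than tall (40 px), so a vertical black
# line of the given thickness is drawn at x = 60, splitting the blob in two.
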
def predict(input_dict):
    start_time = time.time()

    # Gradio's sketch tool returns the uploaded image and the drawn mask
    image = input_dict["image"].convert("RGB")
    input_image = input_dict["mask"].convert("RGBA")
    image = resize_image(image, 768)
    input_image = resize_image(input_image, 768)
    mask_holes = add_split_line(input_image, 10)  # split each masked region with a 10 px black line

    image_npp = np.array(image)
    drawing_np = np.array(input_image)

    if image_npp.shape[2] == 4:
        image_npp = cv2.cvtColor(image_npp, cv2.COLOR_RGBA2RGB)

    # Build the RGBA mask expected by process_inpaint (ported from a Streamlit
    # version of this app): opaque where the user did not draw, transparent
    # where the object should be removed. Pixels matching neither test keep
    # the all-zero default, i.e. they also count as transparent.
    background = np.where(
        (drawing_np[:, :, 0] == 0) &
        (drawing_np[:, :, 1] == 0) &
        (drawing_np[:, :, 2] == 0)
    )
    drawing = np.where(
        (drawing_np[:, :, 0] == 255) &
        (drawing_np[:, :, 1] == 0) &
        (drawing_np[:, :, 2] == 255)
    )
    mask_npp = np.zeros_like(drawing_np)
    mask_npp[background] = [0, 0, 0, 255]  # opaque where nothing was drawn
    mask_npp[drawing] = [0, 0, 0, 0]       # explicitly transparent where magenta strokes were drawn

    # Stage 1: classical inpainting removes the object
    inpainted_image_np = process_inpaint(image_npp, mask_npp)
    inpainted_image = Image.fromarray(inpainted_image_np)

    # Non-zero alpha marks pixels that were never masked
    unmasked_region = mask_npp[:, :, 3] != 0

    blended_image_np = inpainted_image_np.copy()

    # PIL gives (width, height); get_depth_map wants (height, width)
    blended_image_size = inpainted_image.size
    flipped_size = (blended_image_size[1], blended_image_size[0])
    depth_image = get_depth_map(inpainted_image, flipped_size)

    # Stage 2: depth-guided diffusion pass re-synthesizes the masked area
    generator = torch.manual_seed(0)
    output = pipe(
        prompt="",
        num_inference_steps=8,
        generator=generator,
        image=blended_image_np,
        control_image=depth_image,
        controlnet_conditioning_scale=0.9,
        mask_image=mask_holes
    ).images[0]

    output_np = np.array(output)

    # Paste the untouched pixels back so only the hole is changed
    if output_np.shape[:2] == inpainted_image_np.shape[:2]:
        output_np[unmasked_region] = inpainted_image_np[unmasked_region]
    else:
        print("Dimension mismatch: cannot apply unmasked_region")

    final_output = Image.fromarray(output_np)

    inference_time = time.time() - start_time
    inference_time_str = f"Inference Time: {inference_time:.2f} seconds"

    return final_output, inference_time_str

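# Direct (non-UI) usage sketch, assuming a photo.png plus a mask.png with the
# object painted over (file names are placeholders):
#   out, timing = predict({"image": Image.open("photo.png"),
#                          "mask": Image.open("mask.png")})
#   out.save("removed.png")
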
image_blocks = gr.Blocks()

with image_blocks as demo:
    with gr.Row():
        with gr.Column():
            # Sketch tool: the user uploads an image and paints over the object to remove
            input_image = gr.Image(source='upload', tool='sketch', elem_id="input_image_upload", type="pil", label="Upload & Draw on Image")
            btn = gr.Button("Remove Object")
        with gr.Column():
            result = gr.Image(label="Result")
            inference_time_label = gr.Label()  # displays the measured inference time
    btn.click(fn=predict, inputs=[input_image], outputs=[result, inference_time_label])

demo.launch(debug=True)
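# Compatibility note: gr.Image(source=..., tool=...) is the Gradio 3.x API;
# Gradio 4+ removed these arguments, so pin gradio<4 when reproducing this app.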