Vijish committed
Commit 9d33283
1 Parent(s): 236fbce

Create app.py

Files changed (1):
  app.py (+189, -0)
app.py ADDED
@@ -0,0 +1,189 @@
# app.py is meant to run from inside the cloned "rem" repository so that the
# local "src" package is importable. The original "%cd rem" is an IPython
# magic and is not valid in a plain Python script, so the equivalent is done
# with os/sys here.
import os
import sys

os.chdir("rem")
sys.path.insert(0, os.getcwd())  # make the repo's "src" package importable

# Object removal in two stages: classical inpainting (src.core.process_inpaint)
# followed by a depth-guided Stable Diffusion ControlNet inpainting pass.

import time

import cv2
import gradio as gr
import numpy as np
import torch
from PIL import Image, ImageDraw
from scipy.ndimage import label, find_objects
from transformers import DPTFeatureExtractor, DPTForDepthEstimation
from diffusers import ControlNetModel, DEISMultistepScheduler

from src.core import process_inpaint
# Provides StableDiffusionControlNetInpaintPipeline.
from src.pipeline_stable_diffusion_controlnet_inpaint import *

# Depth estimator that produces the ControlNet conditioning image.
depth_estimator = DPTForDepthEstimation.from_pretrained("Intel/dpt-hybrid-midas").to("cuda")
feature_extractor = DPTFeatureExtractor.from_pretrained("Intel/dpt-hybrid-midas")
controlnet = ControlNetModel.from_pretrained(
    "thibaud/controlnet-sd21-depth-diffusers", torch_dtype=torch.float16
)

pipe = StableDiffusionControlNetInpaintPipeline.from_pretrained(
    "stabilityai/stable-diffusion-2-inpainting",
    controlnet=controlnet,
    torch_dtype=torch.float16,
)
pipe.scheduler = DEISMultistepScheduler.from_config(pipe.scheduler.config)
pipe.to("cuda")


def resize_image(image, target_size):
    """Resize so the longer side equals target_size, preserving aspect ratio."""
    width, height = image.size
    aspect_ratio = float(width) / float(height)
    if width > height:
        new_width = target_size
        new_height = int(target_size / aspect_ratio)
    else:
        new_width = int(target_size * aspect_ratio)
        new_height = target_size
    return image.resize((new_width, new_height), Image.BICUBIC)


def get_depth_map(image, target_size):
    """Return a 3-channel depth map of `image`, resized to target_size (height, width)."""
    image = feature_extractor(images=image, return_tensors="pt").pixel_values.to("cuda")
    with torch.no_grad(), torch.autocast("cuda"):
        depth_map = depth_estimator(image).predicted_depth

    depth_map = torch.nn.functional.interpolate(
        depth_map.unsqueeze(1),
        size=target_size,  # (height, width) of the inpainted image
        mode="bicubic",
        align_corners=False,
    )
    # Normalize to [0, 1] and replicate to 3 channels for ControlNet.
    depth_min = torch.amin(depth_map, dim=[1, 2, 3], keepdim=True)
    depth_max = torch.amax(depth_map, dim=[1, 2, 3], keepdim=True)
    depth_map = (depth_map - depth_min) / (depth_max - depth_min)
    image = torch.cat([depth_map] * 3, dim=1)

    image = image.permute(0, 2, 3, 1).cpu().numpy()[0]
    return Image.fromarray((image * 255.0).clip(0, 255).astype(np.uint8))


def add_split_line(mask_image, line_thickness):
    """Draw a black line through each connected white region of the mask,
    splitting every large hole into two smaller ones."""
    # Convert to single-channel; for RGBA input this also yields a copy, so
    # the caller's mask is not mutated.
    if mask_image.mode != "L":
        mask_image = mask_image.convert("L")

    mask_array = np.array(mask_image)

    # Label the connected white regions of the mask.
    labeled_array, num_features = label(mask_array == 255)

    draw = ImageDraw.Draw(mask_image)

    for i in range(1, num_features + 1):
        # Bounding box of this region: the first slice is rows, the second columns.
        row_slice, col_slice = find_objects(labeled_array == i)[0]
        top, bottom = row_slice.start, row_slice.stop
        left, right = col_slice.start, col_slice.stop

        if (right - left) > (bottom - top):
            # Wider than tall: split with a vertical line.
            center_x = (left + right) // 2
            draw.line([(center_x, top), (center_x, bottom)], fill=0, width=line_thickness)
        else:
            # Taller than wide (or square): split with a horizontal line.
            center_y = (top + bottom) // 2
            draw.line([(left, center_y), (right, center_y)], fill=0, width=line_thickness)

    return mask_image
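
# Hypothetical usage sketch for add_split_line (kept commented out so it does
# not execute at import time): a lone 100x100 white square gains one 10 px
# black line through its center and splits into two regions.
#
#   demo_mask = Image.new("L", (200, 200), 0)
#   ImageDraw.Draw(demo_mask).rectangle([50, 50, 150, 150], fill=255)
#   split = add_split_line(demo_mask, 10)
#   print(label(np.array(split) == 255)[1])  # -> 2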


def predict(input_dict):
    start_time = time.time()

    # Gradio's sketch tool supplies the uploaded image and the drawn mask.
    image = input_dict["image"].convert("RGB")
    input_image = input_dict["mask"].convert("RGBA")
    image = resize_image(image, 768)
    input_image = resize_image(input_image, 768)
    # Split each masked region with a black line 10 px thick (line thickness
    # in pixels, not a percentage of the region size).
    mask_holes = add_split_line(input_image, 10)

    # Convert to numpy arrays
    image_npp = np.array(image)
    drawing_np = np.array(input_image)

    if image_npp.shape[2] == 4:
        image_npp = cv2.cvtColor(image_npp, cv2.COLOR_RGBA2RGB)

    # Build the RGBA mask expected by process_inpaint: opaque (alpha 255)
    # where the user drew nothing, transparent where they drew.
    background = np.where(
        (drawing_np[:, :, 0] == 0) &
        (drawing_np[:, :, 1] == 0) &
        (drawing_np[:, :, 2] == 0)
    )
    drawing = np.where(
        (drawing_np[:, :, 0] == 255) &
        (drawing_np[:, :, 1] == 0) &
        (drawing_np[:, :, 2] == 255)
    )
    mask_npp = np.zeros_like(drawing_np)
    mask_npp[background] = [0, 0, 0, 255]  # Opaque where not drawing
    mask_npp[drawing] = [0, 0, 0, 0]       # Transparent where drawing

    # First pass: classical inpainting coarsely removes the object.
    inpainted_image_np = process_inpaint(image_npp, mask_npp)
    inpainted_image = Image.fromarray(inpainted_image_np)

    # Non-zero alpha marks the pixels outside the user's mask.
    unmasked_region = mask_npp[:, :, 3] != 0

    blended_image_np = np.array(inpainted_image_np)

    # PIL's size is (width, height); get_depth_map expects (height, width).
    width, height = inpainted_image.size
    depth_image = get_depth_map(inpainted_image, (height, width))

    # Second pass: depth-guided diffusion inpainting refines the split holes.
    generator = torch.manual_seed(0)
    output = pipe(
        prompt="",
        num_inference_steps=8,
        generator=generator,
        image=blended_image_np,
        control_image=depth_image,
        controlnet_conditioning_scale=0.9,
        mask_image=mask_holes,
    ).images[0]

    output_np = np.array(output)

    # Paste the untouched pixels back so only the masked hole changes.
    if output_np.shape[:2] == inpainted_image_np.shape[:2]:
        output_np[unmasked_region] = inpainted_image_np[unmasked_region]
    else:
        print("Dimension mismatch: cannot apply unmasked_region")

    final_output = Image.fromarray(output_np)

    inference_time = time.time() - start_time
    inference_time_str = f"Inference Time: {inference_time:.2f} seconds"

    # Return both the image and the inference time
    return final_output, inference_time_str


image_blocks = gr.Blocks()

with image_blocks as demo:
    with gr.Row():
        with gr.Column():
            input_image = gr.Image(
                source="upload",
                tool="sketch",
                elem_id="input_image_upload",
                type="pil",
                label="Upload & Draw on Image",
            )
            btn = gr.Button("Remove Object")
        with gr.Column():
            result = gr.Image(label="Result")
            inference_time_label = gr.Label()  # Displays the inference time
    btn.click(fn=predict, inputs=[input_image], outputs=[result, inference_time_label])

demo.launch(debug=True)
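
Running "python app.py" from the directory that contains the "rem" checkout starts the Gradio demo. For a quick smoke test without the UI, the handler can also be called directly once the module has loaded; the sketch below is hypothetical (not part of this commit) and assumes a CUDA GPU, downloaded model weights, and that the sketch tool delivers drawn strokes as white RGBA pixels:

    from PIL import Image, ImageDraw

    img = Image.new("RGB", (768, 512), "gray")
    mask = Image.new("RGBA", (768, 512), (0, 0, 0, 255))
    # Simulated stroke: a white rectangle over the object to remove.
    ImageDraw.Draw(mask).rectangle([300, 200, 460, 320], fill=(255, 255, 255, 255))

    result, timing = predict({"image": img, "mask": mask})
    result.save("result.png")
    print(timing)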