import os

os.chdir("rem")  # stand-in for the original notebook magic `%cd rem`, so the local `src` package below resolves
# Object-removal demo: an initial inpainting pass followed by a depth-guided Stable Diffusion refinement.
import time

import cv2
import gradio as gr
import numpy as np
import torch
from diffusers import ControlNetModel, DEISMultistepScheduler
from PIL import Image, ImageDraw
from scipy.ndimage import label, find_objects
from transformers import DPTFeatureExtractor, DPTForDepthEstimation

# Provides StableDiffusionControlNetInpaintPipeline, used below.
from src.pipeline_stable_diffusion_controlnet_inpaint import *
from src.core import process_inpaint
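
# Load the models once at startup: the MiDaS depth estimator and its feature extractor,
# a depth ControlNet for SD 2.1, and the SD 2 inpainting pipeline (fp16; a CUDA GPU is required).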
depth_estimator = DPTForDepthEstimation.from_pretrained("Intel/dpt-hybrid-midas").to("cuda")
feature_extractor = DPTFeatureExtractor.from_pretrained("Intel/dpt-hybrid-midas")
controlnet = ControlNetModel.from_pretrained("thibaud/controlnet-sd21-depth-diffusers", torch_dtype=torch.float16)
pipe = StableDiffusionControlNetInpaintPipeline.from_pretrained(
    "stabilityai/stable-diffusion-2-inpainting", controlnet=controlnet, torch_dtype=torch.float16
)
pipe.scheduler = DEISMultistepScheduler.from_config(pipe.scheduler.config)
pipe.to('cuda')
def resize_image(image, target_size):
    # Resize so the longer side equals target_size while preserving the aspect ratio.
    width, height = image.size
    aspect_ratio = float(width) / float(height)
    if width > height:
        new_width = target_size
        new_height = int(target_size / aspect_ratio)
    else:
        new_width = int(target_size * aspect_ratio)
        new_height = target_size
    return image.resize((new_width, new_height), Image.BICUBIC)
def get_depth_map(image, target_size):
    # Estimate depth with MiDaS, upsample to target_size, normalize to [0, 1],
    # and return a 3-channel PIL image for ControlNet conditioning.
    image = feature_extractor(images=image, return_tensors="pt").pixel_values.to("cuda")
    with torch.no_grad(), torch.autocast("cuda"):
        depth_map = depth_estimator(image).predicted_depth
    depth_map = torch.nn.functional.interpolate(
        depth_map.unsqueeze(1),
        size=target_size,  # (height, width) of the image being conditioned
        mode="bicubic",
        align_corners=False,
    )
    depth_min = torch.amin(depth_map, dim=[1, 2, 3], keepdim=True)
    depth_max = torch.amax(depth_map, dim=[1, 2, 3], keepdim=True)
    depth_map = (depth_map - depth_min) / (depth_max - depth_min)
    image = torch.cat([depth_map] * 3, dim=1)
    image = image.permute(0, 2, 3, 1).cpu().numpy()[0]
    image = Image.fromarray((image * 255.0).clip(0, 255).astype(np.uint8))
    return image
def add_split_line(mask_image, line_thickness):
    # Ensure the mask is in single-channel ("L") mode.
    if mask_image.mode != 'L':
        mask_image = mask_image.convert('L')
    # Convert the mask to a numpy array and label each connected white region.
    mask_array = np.array(mask_image)
    labeled_array, num_features = label(mask_array == 255)
    draw = ImageDraw.Draw(mask_image)
    # Iterate over each white area.
    for i in range(1, num_features + 1):
        # Bounding box of the white area (row slice, column slice).
        slice_rows, slice_cols = find_objects(labeled_array == i)[0]
        top, bottom = slice_rows.start, slice_rows.stop
        left, right = slice_cols.start, slice_cols.stop
        # Draw a black line that splits the white area along its longer dimension.
        if (right - left) > (bottom - top):
            # Wider than tall: draw a vertical line.
            center_x = (left + right) // 2
            draw.line([(center_x, top), (center_x, bottom)], fill=0, width=line_thickness)
        else:
            # Taller than wide: draw a horizontal line.
            center_y = (top + bottom) // 2
            draw.line([(left, center_y), (right, center_y)], fill=0, width=line_thickness)
    return mask_image
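
# Object removal runs in two stages: an initial process_inpaint pass over the drawn mask,
# then a depth-guided Stable Diffusion ControlNet pass over the split mask, with the
# untouched region pasted back at the end.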
def predict(input_dict):
    start_time = time.time()
    # Get the uploaded image and the user-drawn mask.
    image = input_dict["image"].convert("RGB")
    input_image = input_dict["mask"].convert("RGBA")
    image = resize_image(image, 768)
    input_image = resize_image(input_image, 768)
    mask_holes = add_split_line(input_image, 10)  # split-line thickness in pixels
    # Convert to numpy arrays.
    image_npp = np.array(image)
    drawing_np = np.array(input_image)
    if image_npp.shape[2] == 4:
        image_npp = cv2.cvtColor(image_npp, cv2.COLOR_RGBA2RGB)
    # Build the RGBA mask the same way as the original Streamlit code:
    # black background pixels stay opaque, drawn (magenta) pixels become transparent.
    background = np.where(
        (drawing_np[:, :, 0] == 0) &
        (drawing_np[:, :, 1] == 0) &
        (drawing_np[:, :, 2] == 0)
    )
    drawing = np.where(
        (drawing_np[:, :, 0] == 255) &
        (drawing_np[:, :, 1] == 0) &
        (drawing_np[:, :, 2] == 255)
    )
    mask_npp = np.zeros_like(drawing_np)
    mask_npp[background] = [0, 0, 0, 255]  # Opaque where nothing was drawn
    mask_npp[drawing] = [0, 0, 0, 0]  # Transparent where the user drew
    # First inpainting pass (process_inpaint from src.core).
    inpainted_image_np = process_inpaint(image_npp, mask_npp)
    inpainted_image = Image.fromarray(inpainted_image_np)
    unmasked_region = np.where(mask_npp[:, :, 3] != 0, True, False)  # Non-zero alpha marks the unmasked area
    # Prepare the blended image and its depth map for the diffusion pass.
    blended_image_np = np.array(inpainted_image_np)
    blended_image_size = inpainted_image.size  # PIL size is (width, height)
    # Flip to (height, width), which is what the depth-map interpolation expects.
    flipped_size = (blended_image_size[1], blended_image_size[0])
    depth_image = get_depth_map(inpainted_image, flipped_size)
    # Second pass: depth-guided Stable Diffusion ControlNet inpainting over the split mask.
    generator = torch.manual_seed(0)
    output = pipe(
        prompt="",
        num_inference_steps=8,
        generator=generator,
        image=blended_image_np,
        control_image=depth_image,
        controlnet_conditioning_scale=0.9,
        mask_image=mask_holes
    ).images[0]
    # Convert the final output to a numpy array.
    output_np = np.array(output)
    # Ensure dimensions match before pasting the unmasked region back.
    if output_np.shape[:2] == inpainted_image_np.shape[:2]:
        # Paste the unmasked region from inpainted_image_np onto the final output.
        output_np[unmasked_region] = inpainted_image_np[unmasked_region]
    else:
        print("Dimension mismatch: cannot apply unmasked_region")
    # Convert back to a PIL image.
    final_output = Image.fromarray(output_np)
    end_time = time.time()
    inference_time = end_time - start_time
    inference_time_str = f"Inference Time: {inference_time:.2f} seconds"
    # Return both the image and the inference time.
    return final_output, inference_time_str
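
# Gradio UI: upload an image, sketch over the object to remove, and view the result
# together with the measured inference time.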
image_blocks = gr.Blocks()
with image_blocks as demo:
    with gr.Row():
        with gr.Column():
            input_image = gr.Image(source='upload', tool='sketch', elem_id="input_image_upload", type="pil", label="Upload & Draw on Image")
            btn = gr.Button("Remove Object")
        with gr.Column():
            result = gr.Image(label="Result")
            inference_time_label = gr.Label()  # Displays the inference time
    btn.click(fn=predict, inputs=[input_image], outputs=[result, inference_time_label])
demo.launch(debug=True)