%cd rem
# Best object removal model

import gradio as gr
import numpy as np
import torch
import cv2
import time
from diffusers import ControlNetModel, DEISMultistepScheduler
from transformers import DPTFeatureExtractor, DPTForDepthEstimation
from scipy.ndimage import label, find_objects
from PIL import Image, ImageDraw
from src.pipeline_stable_diffusion_controlnet_inpaint import *  # provides StableDiffusionControlNetInpaintPipeline
from src.core import process_inpaint

# Depth estimator used to build the ControlNet conditioning image
depth_estimator = DPTForDepthEstimation.from_pretrained("Intel/dpt-hybrid-midas").to("cuda")
feature_extractor = DPTFeatureExtractor.from_pretrained("Intel/dpt-hybrid-midas")

controlnet = ControlNetModel.from_pretrained(
    "thibaud/controlnet-sd21-depth-diffusers", torch_dtype=torch.float16
)
pipe = StableDiffusionControlNetInpaintPipeline.from_pretrained(
    "stabilityai/stable-diffusion-2-inpainting",
    controlnet=controlnet,
    torch_dtype=torch.float16,
)
pipe.scheduler = DEISMultistepScheduler.from_config(pipe.scheduler.config)
pipe.to("cuda")


def resize_image(image, target_size):
    # Resize so the longer side equals target_size while preserving the aspect ratio
    width, height = image.size
    aspect_ratio = float(width) / float(height)
    if width > height:
        new_width = target_size
        new_height = int(target_size / aspect_ratio)
    else:
        new_width = int(target_size * aspect_ratio)
        new_height = target_size
    return image.resize((new_width, new_height), Image.BICUBIC)


def get_depth_map(image, target_size):
    image = feature_extractor(images=image, return_tensors="pt").pixel_values.to("cuda")
    with torch.no_grad(), torch.autocast("cuda"):
        depth_map = depth_estimator(image).predicted_depth

    depth_map = torch.nn.functional.interpolate(
        depth_map.unsqueeze(1),
        size=target_size,  # (height, width) of the image being inpainted
        mode="bicubic",
        align_corners=False,
    )
    # Normalize the depth map to [0, 1]
    depth_min = torch.amin(depth_map, dim=[1, 2, 3], keepdim=True)
    depth_max = torch.amax(depth_map, dim=[1, 2, 3], keepdim=True)
    depth_map = (depth_map - depth_min) / (depth_max - depth_min)
    image = torch.cat([depth_map] * 3, dim=1)

    image = image.permute(0, 2, 3, 1).cpu().numpy()[0]
    image = Image.fromarray((image * 255.0).clip(0, 255).astype(np.uint8))
    return image


def add_split_line(mask_image, line_thickness):
    # Ensure the mask is in single-channel (grayscale) mode
    if mask_image.mode != 'L':
        mask_image = mask_image.convert('L')

    # Convert the mask to a numpy array
    mask_array = np.array(mask_image)

    # Label the connected white regions in the mask
    labeled_array, num_features = label(mask_array == 255)

    # Create a draw object
    draw = ImageDraw.Draw(mask_image)

    # Iterate over each white area
    for i in range(1, num_features + 1):
        # Find the bounding box of the white area
        slice_x, slice_y = find_objects(labeled_array == i)[0]
        top, bottom = slice_x.start, slice_x.stop
        left, right = slice_y.start, slice_y.stop

        # Draw a line that splits the white area
        if (right - left) > (bottom - top):
            # Area is wider than it is tall: draw a vertical split line
            center_x = (left + right) // 2
            draw.line([(center_x, top), (center_x, bottom)], fill=0, width=line_thickness)
        else:
            # Area is taller than it is wide: draw a horizontal split line
            center_y = (top + bottom) // 2
            draw.line([(left, center_y), (right, center_y)], fill=0, width=line_thickness)

    return mask_image


def predict(input_dict):
    start_time = time.time()

    # Get the uploaded image and the drawn mask
    image = input_dict["image"].convert("RGB")
    input_image = input_dict["mask"].convert("RGBA")

    image = resize_image(image, 768)
    input_image = resize_image(input_image, 768)

    # Split each masked region with a 10-pixel-wide line
    mask_holes = add_split_line(input_image, 10)

    # Convert to numpy arrays
    image_npp = np.array(image)
    drawing_np = np.array(input_image)

    if image_npp.shape[2] == 4:
        image_npp = cv2.cvtColor(image_npp, cv2.COLOR_RGBA2RGB)

    # Process the mask the same way as the Streamlit version
    background = np.where(
        (drawing_np[:, :, 0] == 0) &
        (drawing_np[:, :, 1] == 0) &
        (drawing_np[:, :, 2] == 0)
    )
    drawing = np.where(
        (drawing_np[:, :, 0] == 255) &
        (drawing_np[:, :, 1] == 0) &
        (drawing_np[:, :, 2] == 255)
    )
    mask_npp = np.zeros_like(drawing_np)
    mask_npp[background] = [0, 0, 0, 255]  # Opaque where nothing was drawn
    mask_npp[drawing] = [0, 0, 0, 0]       # Transparent where the user drew

    # First-pass inpainting
    inpainted_image_np = process_inpaint(image_npp, mask_npp)
    inpainted_image = Image.fromarray(inpainted_image_np)

    # Non-zero alpha in the mask marks the unmasked (untouched) region
    unmasked_region = np.where(mask_npp[:, :, 3] != 0, True, False)

    # Prepare inputs for the depth-guided ControlNet inpainting pass
    blended_image_np = np.array(inpainted_image_np)
    blended_image_size = inpainted_image.size  # PIL gives (width, height)
    # torch.nn.functional.interpolate expects (height, width), so flip the dimensions
    flipped_size = (blended_image_size[1], blended_image_size[0])
    depth_image = get_depth_map(inpainted_image, flipped_size)

    generator = torch.manual_seed(0)
    output = pipe(
        prompt="",
        num_inference_steps=8,
        generator=generator,
        image=blended_image_np,
        control_image=depth_image,
        controlnet_conditioning_scale=0.9,
        mask_image=mask_holes,
    ).images[0]

    # Convert the final output to a NumPy array
    output_np = np.array(output)

    # Ensure dimensions match before applying unmasked_region
    if output_np.shape[:2] == inpainted_image_np.shape[:2]:
        # Paste the unmasked region from inpainted_image_np onto the final output
        output_np[unmasked_region] = inpainted_image_np[unmasked_region]
    else:
        print("Dimension mismatch: cannot apply unmasked_region")

    # Convert back to a PIL Image
    final_output = Image.fromarray(output_np)

    end_time = time.time()
    inference_time = end_time - start_time
    inference_time_str = f"Inference Time: {inference_time:.2f} seconds"

    # Return both the image and the inference time
    return final_output, inference_time_str


image_blocks = gr.Blocks()
with image_blocks as demo:
    with gr.Row():
        with gr.Column():
            input_image = gr.Image(
                source="upload",
                tool="sketch",
                elem_id="input_image_upload",
                type="pil",
                label="Upload & Draw on Image",
            )
            btn = gr.Button("Remove Object")
        with gr.Column():
            result = gr.Image(label="Result")
            inference_time_label = gr.Label()  # Displays the inference time

    btn.click(fn=predict, inputs=[input_image], outputs=[result, inference_time_label])

demo.launch(debug=True)
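
# --- Optional sanity check (a minimal sketch, not part of the app above) ---
# Calls predict() directly, bypassing the Gradio UI, to verify the pipeline end to end.
# The file names below are placeholders, not files shipped with this repo:
# "photo.png" is the RGB photo and "sketch_mask.png" is the drawn mask layer in the
# same format the sketch tool produces. Adjust them to real files before uncommenting.
#
# sample = {
#     "image": Image.open("photo.png"),
#     "mask": Image.open("sketch_mask.png"),
# }
# removed, timing = predict(sample)
# removed.save("object_removed.png")
# print(timing)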