"""Gradio app: BiRefNet background removal plus simple text / inpainting edits."""

import gradio as gr
from gradio_imageslider import ImageSlider
from PIL import Image, ImageDraw, ImageFont
import numpy as np
import cv2
import torch
from torchvision import transforms
from transformers import AutoModelForImageSegmentation

# Trade a little float32 matmul precision for speed (PyTorch "high" setting).
torch.set_float32_matmul_precision("high")

# Use the GPU when one is present; fall back to CPU instead of failing at import.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Load BiRefNet model for background removal
birefnet = AutoModelForImageSegmentation.from_pretrained(
    "ZhengPeng7/BiRefNet", trust_remote_code=True
)
birefnet.to(DEVICE)

# Model input pipeline: resize to 1024x1024, convert to a tensor, and normalize
# with the (ImageNet) mean/std listed below.
transform_image = transforms.Compose(
    [
        transforms.Resize((1024, 1024)),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ]
)


def load_img(image, output_type="numpy"):
    """Load *image* as RGB, returned as a numpy array or (``output_type="pil"``) a PIL image.

    Accepts a file path / file-like object, a ``PIL.Image.Image``, or a numpy
    array (what ``gr.Image`` delivers by default).
    """
    if isinstance(image, Image.Image):
        pil_img = image.convert("RGB")
    elif isinstance(image, np.ndarray):
        pil_img = Image.fromarray(image).convert("RGB")
    else:
        # Context manager closes the underlying file once the pixels are converted.
        with Image.open(image) as opened:
            pil_img = opened.convert("RGB")
    if output_type == "pil":
        return pil_img
    return np.array(pil_img)


def _load_font(font_size):
    """Return a font of roughly *font_size* pixels, falling back if Arial is missing."""
    size = max(1, int(round(font_size)))
    try:
        return ImageFont.truetype("arial.ttf", size)
    except OSError:
        # arial.ttf is usually not installed on Linux hosts.
        try:
            return ImageFont.load_default(size=size)
        except TypeError:
            # Pillow < 10.1 does not accept a size for the default font.
            return ImageFont.load_default()


def add_text_to_image(image, text, position, color, font_size):
    """Draw *text* at *position* on a numpy image and return a new numpy array.

    Empty / ``None`` text leaves the image unchanged (a copy is still returned).
    """
    img = Image.fromarray(image)
    if text:
        draw = ImageDraw.Draw(img)
        draw.text(position, text, fill=color, font=_load_font(font_size))
    return np.array(img)


def _prepare_mask(mask, shape):
    """Convert *mask* to the 8-bit single-channel array ``cv2.inpaint`` requires.

    *shape* is the target ``(height, width)``; the mask is resized to it if needed.
    """
    mask = np.asarray(mask)
    if mask.dtype == np.bool_:
        mask = mask.astype(np.uint8) * 255
    elif mask.dtype != np.uint8:
        mask = np.clip(mask, 0, 255).astype(np.uint8)
    if mask.ndim == 3:
        channels = mask.shape[2]
        if channels == 4:
            mask = cv2.cvtColor(mask, cv2.COLOR_RGBA2GRAY)
        elif channels == 3:
            mask = cv2.cvtColor(mask, cv2.COLOR_RGB2GRAY)
        else:
            mask = mask[..., 0]
    mask = np.ascontiguousarray(mask)
    height, width = shape
    if mask.shape != (height, width):
        # Nearest-neighbour keeps the mask's hard edges.
        mask = cv2.resize(mask, (width, height), interpolation=cv2.INTER_NEAREST)
    return mask


def inpaint_image(image, mask, inpaint_radius):
    """Inpaint the non-zero pixels of *mask* in RGB *image* using OpenCV Telea.

    *mask* may be grayscale, RGB or RGBA and of any size; it is normalized first.
    Returns a new RGB numpy array.
    """
    img_cv = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
    mask_cv = _prepare_mask(mask, img_cv.shape[:2])
    result = cv2.inpaint(img_cv, mask_cv, float(inpaint_radius), cv2.INPAINT_TELEA)
    return cv2.cvtColor(result, cv2.COLOR_BGR2RGB)


def background_removal(image):
    """Cut out the foreground of *image* with BiRefNet.

    Returns ``(cutout, original)``: the RGBA cutout whose alpha channel is the
    predicted mask, and an untouched RGB copy of the input.
    """
    if image is None:
        raise gr.Error("Please upload an image first.")
    im = load_img(image, output_type="pil")
    image_size = im.size
    origin = im.copy()
    # The transform pipeline starts with Resize, which needs a PIL image (not numpy).
    input_images = transform_image(im).unsqueeze(0).to(DEVICE)
    with torch.no_grad():
        preds = birefnet(input_images)[-1].sigmoid().cpu()
    pred = preds[0].squeeze()
    mask = transforms.ToPILImage()(pred).resize(image_size)
    im.putalpha(mask)
    return (im, origin)


def update_image(image, text, color, font_size, mask_image, inpaint_radius,
                 position=(50, 50)):
    """Add *text* to *image* at *position*, then inpaint *mask_image*'s region if given.

    Returns the edited image as a numpy array.
    """
    img_with_text = add_text_to_image(image, text, position, color, font_size)
    if mask_image is not None:
        mask = np.array(mask_image)
        img_with_text = inpaint_image(img_with_text, mask, inpaint_radius)
    return img_with_text


def fn(image):
    """Gradio handler for the Background Removal tab."""
    return background_removal(image)


def design_edit(image, text, color, font_size, mask_image, inpaint_radius):
    """Gradio handler for the Design Editing tab.

    The output component is an ``ImageSlider``, which needs a pair of images, so
    this returns ``(edited, original)`` as PIL images.
    """
    if image is None:
        raise gr.Error("Please upload an image to edit.")
    edited = update_image(image, text, color, font_size, mask_image, inpaint_radius)
    return (Image.fromarray(edited), Image.fromarray(image))


slider1 = ImageSlider(label="Original Image", type="pil")
slider2 = ImageSlider(label="Processed Image", type="pil")
image_input = gr.Image(label="Upload an image for background removal")
# Each Interface gets its own component instance; a component cannot be shared.
edit_image_input = gr.Image(type="numpy", label="Upload an image to edit")
text_input = gr.Textbox(label="Enter Text to Add", placeholder="Your text here...")
color_input = gr.ColorPicker(label="Text Color")
font_size_input = gr.Slider(minimum=10, maximum=100, label="Font Size")
# Inputs are optional by default (None when empty); Gradio 4 has no `optional` kwarg.
mask_input = gr.Image(type="numpy", label="Upload Mask Image (for Inpainting)")
inpaint_radius_input = gr.Slider(minimum=1, maximum=50, value=3, label="Inpaint Radius")

bg_removal_interface = gr.Interface(
    fn, inputs=image_input, outputs=slider1, examples=["chameleon.jpg"]
)

design_editing_interface = gr.Interface(
    fn=design_edit,
    inputs=[
        edit_image_input,
        text_input,
        color_input,
        font_size_input,
        mask_input,
        inpaint_radius_input,
    ],
    outputs=slider2,
)

demo = gr.TabbedInterface(
    [bg_removal_interface, design_editing_interface],
    ["Background Removal", "Design Editing"],
    title="Advanced Image Editor",
)

if __name__ == "__main__":
    demo.launch()