import gradio as gr
from PIL import Image
import numpy as np
from transformers import AutoModelForImageSegmentation
import torch
from torchvision import transforms
import spaces  # ZeroGPU support

# Use CUDA when available; otherwise fall back to CPU
device = "cuda" if torch.cuda.is_available() else "cpu"

# Load the BiRefNet segmentation model
torch.set_float32_matmul_precision("high")
birefnet = AutoModelForImageSegmentation.from_pretrained(
    "ZhengPeng7/BiRefNet", trust_remote_code=True
)
birefnet.to(device)

# Image transformation pipeline (ImageNet mean and standard deviation)
transform_image = transforms.Compose(
    [
        transforms.Resize((1024, 1024)),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ]
)


# Request a GPU only for the duration of this call (inference)
@spaces.GPU(duration=70)
def create_mask(image):
    """Extract the subject with BiRefNet and return a grayscale mask."""
    image_size = image.size
    # Normalize expects three channels, so force RGB before transforming
    input_images = transform_image(image.convert("RGB")).unsqueeze(0).to(device)

    with torch.no_grad():
        # Always move results to CPU for post-processing
        preds = birefnet(input_images)[-1].sigmoid().cpu()

    pred = preds[0].squeeze()
    mask_pil = transforms.ToPILImage()(pred)
    # Resize the 1024x1024 mask back to the original image size
    mask = mask_pil.resize(image_size)
    return mask


def apply_filter(image, mask=None, apply_to_subject=True):
    """Blend a pink tint into the subject (where the mask is set) or the whole image."""
    # Convert image to an RGBA numpy array
    image_np = np.array(image.convert("RGBA"))

    # Pink (magenta) in RGBA; the alpha value 128 is blended like the colour channels
    pink_color = np.array([255, 0, 255, 128])

    if apply_to_subject and mask is not None:
        # Mask values above 128 indicate subject presence
        subject = np.array(mask) > 128
        # Blend 50/50 with pink on the subject pixels only (vectorized, no per-pixel loop)
        image_np[subject] = (image_np[subject] * 0.5 + pink_color * 0.5).astype(np.uint8)
    else:
        # Apply the pink filter to the whole image if no subject mask is
        # available or if the user chose "Full Image"
        image_np = (image_np * 0.5 + pink_color * 0.5).astype(np.uint8)

    # Convert back to a PIL image
    result_image = Image.fromarray(image_np)
    return result_image


def process(input_image, subject_choice):
    """Main processing function for Gradio."""
    if input_image is None:
        raise gr.Error("Please upload an input image")

    # Convert the numpy input to a PIL image
    original_image = Image.fromarray(input_image)

    # Generate a mask with BiRefNet only if the user selected "Subject Only"
    mask = None
    if subject_choice == "Subject Only":
        mask = create_mask(original_image)

    # Apply the pink filter based on the user's choice
    apply_to_subject = subject_choice == "Subject Only" and mask is not None
    result_image = apply_filter(original_image, mask, apply_to_subject)

    return result_image


# Define the Gradio interface
block = gr.Blocks()

with block:
    with gr.Row():
        gr.Markdown("Apply Pink Filter Effect to Subject or Full Image")

    with gr.Row():
        with gr.Column():
            input_image = gr.Image(type="numpy", label="Input Image", height=640)
            subject_choice = gr.Radio(
                choices=["Subject Only", "Full Image"],
                value="Subject Only",
                label="Apply Pink Filter to:",
            )
            run_button = gr.Button("Run")

        with gr.Column():
            output_image = gr.Image(label="Output Image")

    # Wire the button to the processing function
    run_button.click(
        fn=process,
        inputs=[input_image, subject_choice],
        outputs=output_image,
    )

block.launch()