import os
import gradio as gr
import cv2
from PIL import Image
import numpy as np
from transformers import AutoModelForImageSegmentation
import torch
from torchvision import transforms
import spaces  # Import ZeroGPU support

# Detect if CUDA is available; otherwise fall back to CPU
device = "cuda" if torch.cuda.is_available() else "cpu"

# Load the BiRefNet model
torch.set_float32_matmul_precision("high")
birefnet = AutoModelForImageSegmentation.from_pretrained(
    "ZhengPeng7/BiRefNet", trust_remote_code=True
)
birefnet.to(device)
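# Note: trust_remote_code=True is needed above because the BiRefNet checkpoint
# ships its own modeling code rather than a built-in transformers architecture.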

# Image transformation pipeline
transform_image = transforms.Compose(
    [
        transforms.Resize((1024, 1024)),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ]
)
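# The Normalize() mean/std above are the standard ImageNet statistics.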

# Function to extract the subject using BiRefNet and create a mask.
# The @spaces.GPU decorator allocates a ZeroGPU device only while this function runs.
@spaces.GPU
def create_mask(image):
    image_size = image.size
    input_images = transform_image(image).unsqueeze(0).to(device)
    with torch.no_grad():
        preds = birefnet(input_images)[-1].sigmoid().cpu()  # Move results to CPU for post-processing
    pred = preds[0].squeeze()
    mask_pil = transforms.ToPILImage()(pred)
    mask = mask_pil.resize(image_size)
    return mask
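
# Note: create_mask() returns an 8-bit grayscale PIL image at the original
# resolution; apply_filter() below treats values above 128 as subject pixels.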

# Function to apply the pink filter-like color change
def apply_filter(image, mask=None, apply_to_subject=True):
    # Convert the image to a numpy array
    image_np = np.array(image.convert("RGBA"))
    # Define the pink color in RGBA (alpha = 128 for transparency)
    pink_color = np.array([255, 0, 255, 128])
    if apply_to_subject and mask is not None:
        # Convert the mask to a numpy array and blend the original image 50/50
        # with the pink color wherever the mask marks the subject
        mask_np = np.array(mask)
        subject = mask_np > 128
        image_np[subject] = (image_np[subject] * 0.5 + pink_color * 0.5).astype(np.uint8)
    else:
        # Apply the pink filter to the whole image if no subject mask is used
        # or if the user chose "Full Image"
        image_np = (image_np * 0.5 + pink_color * 0.5).astype(np.uint8)
    # Convert back to a PIL image
    result_image = Image.fromarray(image_np)
    return result_image
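
# Tweaking pink_color or the 0.5/0.5 blend weights above changes the hue and
# strength of the effect; the current values give a semi-transparent magenta wash.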

# Main processing function for Gradio
def process(input_image, subject_choice):
    if input_image is None:
        raise gr.Error("Please upload an input image")
    # Convert the input numpy array to a PIL image
    original_image = Image.fromarray(input_image)
    # Default mask is None
    mask = None
    # Generate a mask with BiRefNet if the user selected "Subject Only"
    if subject_choice == "Subject Only":
        mask = create_mask(original_image)
    # Apply the pink filter based on the user's choice
    apply_to_subject = subject_choice == "Subject Only" and mask is not None
    result_image = apply_filter(original_image, mask, apply_to_subject)
    return result_image

# Define the Gradio interface
block = gr.Blocks()
with block:
    with gr.Row():
        gr.Markdown("Apply Pink Filter Effect to Subject or Full Image")
    with gr.Row():
        with gr.Column():
            input_image = gr.Image(type="numpy", label="Input Image", height=640)
            subject_choice = gr.Radio(
                choices=["Subject Only", "Full Image"],
                value="Subject Only",
                label="Apply Pink Filter to:",
            )
            run_button = gr.Button("Run")
        with gr.Column():
            output_image = gr.Image(label="Output Image")

    # Set the processing function
    run_button.click(
        fn=process,
        inputs=[input_image, subject_choice],
        outputs=output_image,
    )

block.launch()
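
# On Hugging Face Spaces (Gradio SDK) this file runs as the Space's app.py entry point;
# locally, running `python app.py` starts the demo and prints a local URL.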