# fuchsia-filter / app.py
# Author: Golfies — last commit 7cddaa4 "Update app.py" (verified)
import os
import gradio as gr
import cv2
from PIL import Image
import numpy as np
from transformers import AutoModelForImageSegmentation
import torch
from torchvision import transforms
import spaces # Import ZeroGPU support
# Use the GPU when torch can see one; otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# "high" lets float32 matmuls use faster reduced-precision internal algorithms
# (e.g. TF32) when the hardware supports them; good enough for segmentation.
torch.set_float32_matmul_precision("high")

# Load the BiRefNet segmentation model from the Hugging Face Hub.
# trust_remote_code=True lets transformers execute the model code shipped in
# the "ZhengPeng7/BiRefNet" repository.
birefnet = AutoModelForImageSegmentation.from_pretrained(
    "ZhengPeng7/BiRefNet", trust_remote_code=True
)
birefnet.to(device)
# The model is only used for inference; make that explicit so layers such as
# dropout/batch-norm behave deterministically regardless of loader defaults.
birefnet.eval()

# Preprocessing applied to every input before it is fed to BiRefNet:
# resize to the model's 1024x1024 input, convert to a [0, 1] float tensor,
# then normalise each of the three RGB channels with ImageNet mean/std.
transform_image = transforms.Compose(
    [
        transforms.Resize((1024, 1024)),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ]
)
@spaces.GPU(duration=70)  # ZeroGPU: a GPU is attached for each call of this function (budget: 70 s per call)
def create_mask(image):
    """Segment the main subject of ``image`` with BiRefNet.

    Args:
        image: A PIL image of any mode. It is converted to RGB first, because
            ``transform_image`` normalises exactly three channels; greyscale
            ("L") or RGBA input would otherwise fail inside ``Normalize``.

    Returns:
        A single-channel PIL mask the same size as ``image``; brighter pixels
        correspond to a higher predicted subject probability.
    """
    image_size = image.size  # PIL size is (width, height), which resize() expects
    rgb_image = image.convert("RGB")
    input_images = transform_image(rgb_image).unsqueeze(0).to(device)  # add batch dim
    with torch.no_grad():
        # BiRefNet returns a list of outputs; the last one is used (presumably
        # the final, full-resolution prediction). sigmoid() maps logits to [0, 1].
        # Move to CPU so the PIL conversion below works on any device.
        preds = birefnet(input_images)[-1].sigmoid().cpu()
    pred = preds[0].squeeze()  # drop batch and channel dims -> 2-D probability map
    mask_pil = transforms.ToPILImage()(pred)
    # The model works at 1024x1024; scale the mask back to the original size.
    return mask_pil.resize(image_size)
# Blend a pink (fuchsia) tint into the whole image or only into the subject.
def apply_filter(image, mask=None, apply_to_subject=True, *,
                 color=(255, 0, 255, 128), strength=0.5, threshold=128):
    """Tint ``image`` toward ``color``, either everywhere or only where ``mask`` is set.

    Args:
        image: A PIL image; it is converted to RGBA before blending.
        mask: Optional mask (PIL image or 2-D array) with the same height and
            width as ``image``. Pixels whose value is strictly greater than
            ``threshold`` count as subject. Multi-band PIL masks (e.g. RGB)
            are converted to greyscale ("L") first.
        apply_to_subject: If True and ``mask`` is given, tint only the subject
            pixels; otherwise tint the whole image.
        color: RGBA colour blended into the image. Its alpha component is
            blended into the output alpha channel too, so with the defaults an
            opaque pixel ends up with alpha 191 (255 * 0.5 + 128 * 0.5,
            truncated). NOTE(review): a partially transparent result may be
            unintended; kept for backward compatibility.
        strength: Weight of ``color`` in the blend (0 = unchanged, 1 = solid colour).
        threshold: Mask value above which a pixel is treated as subject.

    Returns:
        A new RGBA PIL image. Blended values are truncated (not rounded) to uint8.

    Raises:
        ValueError: If the mask's height/width does not match the image's.
    """
    image_np = np.array(image.convert("RGBA"))  # writable (H, W, 4) uint8 copy
    tint_part = np.asarray(color, dtype=np.float64) * strength
    keep = 1.0 - strength

    if apply_to_subject and mask is not None:
        # A multi-band mask would make the per-pixel comparison ambiguous.
        if hasattr(mask, "getbands") and len(mask.getbands()) > 1:
            mask = mask.convert("L")
        mask_np = np.asarray(mask)
        if mask_np.shape != image_np.shape[:2]:
            raise ValueError(
                f"mask shape {mask_np.shape} does not match image "
                f"(height, width) {image_np.shape[:2]}"
            )
        # Vectorised blend of only the subject pixels (replaces a per-pixel loop).
        subject = mask_np > threshold
        image_np[subject] = (image_np[subject] * keep + tint_part).astype(np.uint8)
    else:
        # Tint the whole image (no mask, or the user chose the full image).
        image_np = (image_np * keep + tint_part).astype(np.uint8)

    return Image.fromarray(image_np)
# Gradio callback: validate the upload, optionally segment, then tint.
def process(input_image, subject_choice):
    """Run the pink filter for the Gradio UI.

    Args:
        input_image: The uploaded image as a numpy array, or None when nothing
            was uploaded.
        subject_choice: The radio value, either "Subject Only" or "Full Image".

    Returns:
        The filtered PIL image.

    Raises:
        gr.Error: If no image was uploaded (shown to the user by Gradio).
    """
    if input_image is None:
        raise gr.Error('Please upload an input image')

    source = Image.fromarray(input_image)
    subject_only = subject_choice == "Subject Only"
    # Only run the (GPU-backed) segmentation when the subject is targeted.
    subject_mask = create_mask(source) if subject_only else None
    return apply_filter(source, subject_mask, subject_only and subject_mask is not None)
# Gradio interface: a title row, then an input column (image, target choice,
# Run button) beside an output column that shows the filtered image.
with gr.Blocks() as block:
    with gr.Row():
        gr.Markdown("Apply Pink Filter Effect to Subject or Full Image")
    with gr.Row():
        with gr.Column():
            input_image = gr.Image(type="numpy", label="Input Image", height=640)
            subject_choice = gr.Radio(
                choices=["Subject Only", "Full Image"],
                value="Subject Only",
                label="Apply Pink Filter to:",
            )
            run_button = gr.Button("Run")
        with gr.Column():
            output_image = gr.Image(label="Output Image")
    # Clicking "Run" passes the uploaded image and the radio choice to process().
    run_button.click(
        fn=process,
        inputs=[input_image, subject_choice],
        outputs=output_image,
    )

block.launch()