from typing import TypedDict

import diffusers.image_processor
import gradio as gr
import numpy as np
import pillow_heif
import spaces
import torch
from PIL import Image
from transformers import AutoModelForSemanticSegmentation, SegformerImageProcessor

from pipeline import TryOffAnyone
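
# register HEIF/AVIF decoders so PIL can open those upload formats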
pillow_heif.register_heif_opener()
pillow_heif.register_avif_opener()
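
# trade a little float32 matmul precision for TF32 speed on supported GPUs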
torch.set_float32_matmul_precision("high")
torch.backends.cuda.matmul.allow_tf32 = True
TITLE = """ | |
# Try Off Anyone | |
## Important | |
1. Choose an example image or upload your own | |
[[arxiv:2412.08573]](https://arxiv.org/abs/2412.08573) | |
[[github:ixarchakos/try-off-anyone]](https://github.com/ixarchakos/try-off-anyone) | |
""" | |

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# bfloat16 requires a CUDA device; fall back to float32 on CPU
DTYPE = torch.bfloat16 if DEVICE.type == "cuda" else torch.float32

pipeline_tryoff = TryOffAnyone(
    device=DEVICE,
    dtype=DTYPE,
)
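
# VAE-aligned processors (8x latent downsampling); the mask variant
# binarizes and converts to grayscale instead of normalizing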
mask_processor = diffusers.image_processor.VaeImageProcessor(
    vae_scale_factor=8,
    do_normalize=False,
    do_binarize=True,
    do_convert_grayscale=True,
)
vae_processor = diffusers.image_processor.VaeImageProcessor(
    vae_scale_factor=8,
)
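

# segment the garment with SegFormer and return a binary PIL mask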
def mask_generation(image, processor, model, category):
    inputs = processor(images=image, return_tensors="pt").to(DEVICE)
    with torch.no_grad():
        outputs = model(**inputs)
    logits = outputs.logits.cpu()
    # upsample the low-resolution logits back to the input image size
    upsampled_logits = torch.nn.functional.interpolate(
        logits,
        size=image.size[::-1],
        mode="bilinear",
        align_corners=False,
    )
    predicted_mask = upsampled_logits.argmax(dim=1).squeeze().numpy()
    # label ids in sayeed99/segformer_b3_clothes:
    # 4 = upper clothes, 7 = dress, 5 = skirt, 6 = pants
    if category == "Tops":
        predicted_mask_1 = predicted_mask == 4
        predicted_mask_2 = predicted_mask == 7
    elif category == "Bottoms":
        predicted_mask_1 = predicted_mask == 5
        predicted_mask_2 = predicted_mask == 6
    else:
        raise NotImplementedError(f"Unsupported category: {category}")
    predicted_mask = predicted_mask_1 | predicted_mask_2
    mask_image = Image.fromarray((predicted_mask * 255).astype(np.uint8))
    return mask_image
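

# payload produced by gr.ImageMask when type="pil"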
class ImageData(TypedDict):
    background: Image.Image
    composite: Image.Image
    layers: list[Image.Image]


@spaces.GPU  # ZeroGPU: attach a GPU only for the duration of this call
def process(
    image_data: ImageData,
    image_width: int,
    image_height: int,
    num_inference_steps: int,
    condition_scale: float,
    seed: int,
) -> Image.Image:
    assert image_width > 0
    assert image_height > 0
    assert num_inference_steps > 0
    assert condition_scale > 0
    assert seed >= 0
    # extract the background image from the editor payload
    image = image_data["background"]
    # the garment mask is derived automatically with a clothes-segmentation model
    processor = SegformerImageProcessor.from_pretrained("sayeed99/segformer_b3_clothes")
    model = AutoModelForSemanticSegmentation.from_pretrained("sayeed99/segformer_b3_clothes")
    model.to(DEVICE)
    # preprocess image
    image = image.convert("RGB").resize((image_width, image_height))
    mask = mask_generation(image, processor, model, "Tops")
    image_preprocessed = vae_processor.preprocess(
        image=image,
        width=image_width,
        height=image_height,
    )[0]
    # preprocess mask
    mask = mask.resize((image_width, image_height))
    mask_preprocessed = mask_processor.preprocess(  # pyright: ignore[reportUnknownMemberType]
        image=mask,
        width=image_width,
        height=image_height,
    )[0]
    # generate the TryOff image
    gen = torch.Generator(device=DEVICE).manual_seed(seed)
    tryoff_image = pipeline_tryoff(
        image_preprocessed,
        mask_preprocessed,
        inference_steps=num_inference_steps,
        scale=condition_scale,
        generator=gen,
    )[0]
    return tryoff_image


with gr.Blocks() as demo:
    gr.Markdown(TITLE)
    with gr.Row():
        with gr.Column():
            input_image = gr.ImageMask(
                label="Input Image",
                height=1024,
                type="pil",
                interactive=True,
            )
            run_button = gr.Button(
                value="Extract Clothing",
            )
            gr.Examples(
                examples=[
                    ["examples/model_1.jpg"],
                    ["examples/model_2.jpg"],
                    ["examples/model_3.jpg"],
                    ["examples/model_4.jpg"],
                    ["examples/model_5.jpg"],
                    ["examples/model_6.jpg"],
                    ["examples/model_7.jpg"],
                    ["examples/model_8.jpg"],
                    ["examples/model_9.jpg"],
                ],
                inputs=[input_image],
            )
        with gr.Column():
            output_image = gr.Image(
                label="TryOff result",
                height=1024,
                image_mode="RGB",
                type="pil",
            )
    with gr.Accordion("Advanced Settings", open=False):
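        # all sliders are locked to fixed defaults (minimum == maximum)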
        seed = gr.Slider(
            label="Seed",
            minimum=36,
            maximum=36,
            value=36,
            step=1,
        )
        scale = gr.Slider(
            label="Scale",
            minimum=2.5,
            maximum=2.5,
            value=2.5,
            step=0.1,
        )
        num_inference_steps = gr.Slider(
            label="Number of inference steps",
            minimum=50,
            maximum=50,
            value=50,
            step=1,
        )
        with gr.Row():
            image_width = gr.Slider(
                label="Image Width",
                minimum=384,
                maximum=384,
                value=384,
                step=8,
            )
            image_height = gr.Slider(
                label="Image Height",
                minimum=512,
                maximum=512,
                value=512,
                step=8,
            )
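
    # wire the button to the TryOff pipeline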
    run_button.click(
        fn=process,
        inputs=[
            input_image,
            image_width,
            image_height,
            num_inference_steps,
            scale,
            seed,
        ],
        outputs=output_image,
    )

demo.launch()