# TryOffAnyone / app.py
# ixarchakos's picture
# Upload app.py
# 4d6431f verified
from typing import TypedDict
import diffusers.image_processor
import gradio as gr
import pillow_heif
import spaces
import torch
from PIL import Image
from transformers import SegformerImageProcessor, AutoModelForSemanticSegmentation
from pipeline import TryOffAnyone
import numpy as np
# Matmul precision settings: request "high" float32 matmul precision and allow
# TF32 for CUDA matmuls.
torch.backends.cuda.matmul.allow_tf32 = True
torch.set_float32_matmul_precision("high")

# Register the HEIF and AVIF openers provided by pillow_heif.
pillow_heif.register_avif_opener()
pillow_heif.register_heif_opener()
# Markdown header rendered at the top of the page via gr.Markdown(TITLE):
# app name, a one-step usage note, and links to the paper and the repository.
TITLE = """
# Try Off Anyone
## Important
1. Choose an example image or upload your own
[[arxiv:2412.08573]](https://arxiv.org/abs/2412.08573)
[[github:ixarchakos/try-off-anyone]](https://github.com/ixarchakos/try-off-anyone)
"""
# Run on the GPU when torch can see one, otherwise fall back to the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Compare the device *type*: a torch.device never compares equal to a plain
# string, so the previous `DEVICE == 'cuda'` was always False and bfloat16 was
# never selected, even on a GPU.
DTYPE = torch.bfloat16 if DEVICE.type == "cuda" else torch.float32
# Module-level TryOff pipeline, built once at import time with the device and
# dtype selected above; `process` calls it for every request.
pipeline_tryoff = TryOffAnyone(
    device=DEVICE,
    dtype=DTYPE,
)
# Preprocessor for the garment mask: normalization is disabled while
# binarization and grayscale conversion are enabled (see the flags below).
mask_processor = diffusers.image_processor.VaeImageProcessor(
    vae_scale_factor=8,
    do_normalize=False,
    do_binarize=True,
    do_convert_grayscale=True,
)
# Preprocessor for the input photo; only the scale factor is overridden, every
# other option keeps the library default.
vae_processor = diffusers.image_processor.VaeImageProcessor(
    vae_scale_factor=8,
)
def mask_generation(image, processor, model, category):
    """Segment *image* and return a binary mask for the requested garment category.

    Args:
        image: PIL image to segment. Its ``size`` determines the resolution of
            the returned mask.
        processor: Image processor, called as
            ``processor(images=image, return_tensors="pt")``.
        model: Segmentation model (a ``torch.nn.Module``) whose output exposes
            ``.logits``.
        category: ``"Tops"`` or ``"Bottoms"``.

    Returns:
        A single-channel PIL image with the same size as *image*: 255 where the
        predicted class belongs to *category*, 0 elsewhere.

    Raises:
        NotImplementedError: If *category* is not one of the supported values.
    """
    # NOTE(review): these class ids presumably follow the label map of the
    # segmentation checkpoint loaded by the caller — confirm against its
    # model card before adding new categories.
    category_labels = {
        "Tops": (4, 7),
        "Bottoms": (5, 6),
    }
    # Validate before running the (expensive) forward pass.
    try:
        labels = category_labels[category]
    except KeyError:
        raise NotImplementedError(f"Unsupported category: {category!r}") from None

    # Put the inputs wherever the model lives instead of assuming CUDA.
    device = next(model.parameters()).device
    inputs = processor(images=image, return_tensors="pt").to(device)

    # Pure inference: no autograd graph is needed.
    with torch.no_grad():
        logits = model(**inputs).logits.cpu()

    # PIL reports (width, height); interpolate expects (height, width).
    upsampled_logits = torch.nn.functional.interpolate(
        logits,
        size=image.size[::-1],
        mode="bilinear",
        align_corners=False,
    )
    # Index the batch dimension explicitly; squeeze() would also drop a
    # height or width of 1.
    predicted = upsampled_logits.argmax(dim=1)[0].numpy()

    selected = np.zeros(predicted.shape, dtype=bool)
    for label in labels:
        selected |= predicted == label

    return Image.fromarray(selected.astype(np.uint8) * 255)
# Dictionary shape of the value that `process` receives as `image_data` (the
# `gr.ImageMask` input). Only the "background" entry is read by this file.
ImageData = TypedDict(
    "ImageData",
    {
        "background": Image.Image,
        "composite": Image.Image,
        "layers": list[Image.Image],
    },
)
# Hub repository of the clothes-segmentation checkpoint used to build the mask.
_SEGMENTER_REPO = "sayeed99/segformer_b3_clothes"
# Per-process cache so the processor/model are built once, not on every request.
_segmenter_cache = {}


def _get_segmenter():
    """Return the cached ``(processor, model)`` pair, loading it on first use."""
    if "pair" not in _segmenter_cache:
        processor = SegformerImageProcessor.from_pretrained(_SEGMENTER_REPO)
        model = AutoModelForSemanticSegmentation.from_pretrained(_SEGMENTER_REPO)
        # Store only after both loads succeeded so a failed load is retried.
        _segmenter_cache["pair"] = (processor, model)
    return _segmenter_cache["pair"]


@spaces.GPU
def process(
    image_data: ImageData,
    image_width: int,
    image_height: int,
    num_inference_steps: int,
    condition_scale: float,
    seed: int,
) -> Image.Image:
    """Run the full try-off flow for one request.

    The background image is resized, a "Tops" garment mask is predicted with
    the segmentation model, both are preprocessed, and the TryOff pipeline is
    invoked with a seeded generator.

    Args:
        image_data: Value of the image input; only ``"background"`` is used.
        image_width: Target width in pixels, must be positive.
        image_height: Target height in pixels, must be positive.
        num_inference_steps: Number of pipeline steps, must be positive.
        condition_scale: Scale passed to the pipeline, must be positive.
        seed: Non-negative seed for the random generator.

    Returns:
        The first element of the pipeline output.

    Raises:
        ValueError: If a numeric argument is out of range.
        gr.Error: If no input image was provided.
    """
    # Explicit checks instead of `assert`, which is stripped under `python -O`.
    if image_width <= 0:
        raise ValueError("image_width must be positive")
    if image_height <= 0:
        raise ValueError("image_height must be positive")
    if num_inference_steps <= 0:
        raise ValueError("num_inference_steps must be positive")
    if condition_scale <= 0:
        raise ValueError("condition_scale must be positive")
    if seed < 0:
        raise ValueError("seed must be non-negative")

    # UI sliders may deliver floats; resize() and manual_seed() need ints.
    image_width = int(image_width)
    image_height = int(image_height)
    num_inference_steps = int(num_inference_steps)
    seed = int(seed)

    # extract image from image_data
    image = image_data["background"] if image_data else None
    if image is None:
        raise gr.Error("Please upload or select an input image first.")

    processor, model = _get_segmenter()
    # Use the same device as the rest of the app rather than hard-coding CUDA.
    model.to(DEVICE)

    # preprocess image
    image = image.convert("RGB").resize((image_width, image_height))
    mask = mask_generation(image, processor, model, "Tops")
    image_preprocessed = vae_processor.preprocess(
        image=image,
        width=image_width,
        height=image_height,
    )[0]

    # preprocess mask
    # The mask is produced at the resized image's resolution, so no extra
    # resize is needed before handing it to the mask processor.
    mask_preprocessed = mask_processor.preprocess(  # pyright: ignore[reportUnknownMemberType]
        image=mask,
        width=image_width,
        height=image_height,
    )[0]

    # generate the TryOff image
    gen = torch.Generator(device=DEVICE).manual_seed(seed)
    tryoff_image = pipeline_tryoff(
        image_preprocessed,
        mask_preprocessed,
        inference_steps=num_inference_steps,
        scale=condition_scale,
        generator=gen,
    )[0]
    return tryoff_image
# Gradio UI: input image + examples on the left, result + settings on the right.
# NOTE(review): the nesting below is reconstructed from a paste that had lost
# its indentation — confirm the Accordion/Row placement against the live app.
with gr.Blocks() as demo:
    gr.Markdown(TITLE)
    with gr.Row():
        with gr.Column():
            # Input image editor; its value is passed to `process` as `image_data`.
            input_image = gr.ImageMask(
                label="Input Image",
                height=1024,
                type="pil",
                interactive=True,
            )
            run_button = gr.Button(
                value="Extract Clothing",
            )
            # Bundled example photos that can be loaded into the input component.
            gr.Examples(
                examples=[
                    ["examples/model_1.jpg"],
                    ["examples/model_2.jpg"],
                    ["examples/model_3.jpg"],
                    ["examples/model_4.jpg"],
                    ["examples/model_5.jpg"],
                    ["examples/model_6.jpg"],
                    ["examples/model_7.jpg"],
                    ["examples/model_8.jpg"],
                    ["examples/model_9.jpg"],
                ],
                inputs=[input_image],
            )
        with gr.Column():
            output_image = gr.Image(
                label="TryOff result",
                height=1024,
                image_mode="RGB",
                type="pil",
            )
            # Every slider below has minimum == maximum, so each setting is
            # effectively pinned to a single value.
            with gr.Accordion("Advanced Settings", open=False):
                seed = gr.Slider(
                    label="Seed",
                    minimum=36,
                    maximum=36,
                    value=36,
                    step=1,
                )
                # NOTE(review): step=0 on a single-value range — confirm the
                # installed Gradio version accepts a zero step.
                scale = gr.Slider(
                    label="Scale",
                    minimum=2.5,
                    maximum=2.5,
                    value=2.5,
                    step=0,
                )
                num_inference_steps = gr.Slider(
                    label="Number of inference steps",
                    minimum=50,
                    maximum=50,
                    value=50,
                    step=1,
                )
                with gr.Row():
                    image_width = gr.Slider(
                        label="Image Width",
                        minimum=384,
                        maximum=384,
                        value=384,
                        step=8,
                    )
                    image_height = gr.Slider(
                        label="Image Height",
                        minimum=512,
                        maximum=512,
                        value=512,
                        step=8,
                    )
    # The input order here must match the parameter order of `process`.
    run_button.click(
        fn=process,
        inputs=[
            input_image,
            image_width,
            image_height,
            num_inference_steps,
            scale,
            seed,
        ],
        outputs=output_image,
    )
# Start the Gradio server when the module is executed.
demo.launch()