from typing import TypedDict import diffusers.image_processor import gradio as gr import pillow_heif import spaces import torch from PIL import Image from transformers import SegformerImageProcessor, AutoModelForSemanticSegmentation from pipeline import TryOffAnyone import numpy as np pillow_heif.register_heif_opener() pillow_heif.register_avif_opener() torch.set_float32_matmul_precision("high") torch.backends.cuda.matmul.allow_tf32 = True TITLE = """ # Try Off Anyone ## Important 1. Choose an example image or upload your own [[arxiv:2412.08573]](https://arxiv.org/abs/2412.08573) [[github:ixarchakos/try-off-anyone]](https://github.com/ixarchakos/try-off-anyone) """ DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") DTYPE = torch.bfloat16 if DEVICE == 'cuda' else torch.float32 pipeline_tryoff = TryOffAnyone( device=DEVICE, dtype=DTYPE, ) mask_processor = diffusers.image_processor.VaeImageProcessor( vae_scale_factor=8, do_normalize=False, do_binarize=True, do_convert_grayscale=True, ) vae_processor = diffusers.image_processor.VaeImageProcessor( vae_scale_factor=8, ) def mask_generation(image, processor, model, category): inputs = processor(images=image, return_tensors="pt").to("cuda") outputs = model(**inputs) logits = outputs.logits.cpu() upsampled_logits = torch.nn.functional.interpolate( logits, size=image.size[::-1], mode="bilinear", align_corners=False, ) predicted_mask = upsampled_logits.argmax(dim=1).squeeze().cpu().numpy() if category == "Tops": predicted_mask_1 = predicted_mask == 4 predicted_mask_2 = predicted_mask == 7 elif category == "Bottoms": predicted_mask_1 = predicted_mask == 5 predicted_mask_2 = predicted_mask == 6 else: raise NotImplementedError predicted_mask = predicted_mask_1 + predicted_mask_2 mask_image = Image.fromarray((predicted_mask * 255).astype(np.uint8)) return mask_image class ImageData(TypedDict): background: Image.Image composite: Image.Image layers: list[Image.Image] @spaces.GPU def process( image_data: ImageData, image_width: int, image_height: int, num_inference_steps: int, condition_scale: float, seed: int, ) -> Image.Image: assert image_width > 0 assert image_height > 0 assert num_inference_steps > 0 assert condition_scale > 0 assert seed >= 0 # extract image and mask from image_data image = image_data["background"] processor = SegformerImageProcessor.from_pretrained("sayeed99/segformer_b3_clothes") model = AutoModelForSemanticSegmentation.from_pretrained("sayeed99/segformer_b3_clothes") model.to("cuda") # preprocess image image = image.convert("RGB").resize((image_width, image_height)) mask = mask_generation(image, processor, model, "Tops") image_preprocessed = vae_processor.preprocess( image=image, width=image_width, height=image_height, )[0] # preprocess mask mask = mask.resize((image_width, image_height)) mask_preprocessed = mask_processor.preprocess( # pyright: ignore[reportUnknownMemberType] image=mask, width=image_width, height=image_height, )[0] # generate the TryOff image gen = torch.Generator(device=DEVICE).manual_seed(seed) tryoff_image = pipeline_tryoff( image_preprocessed, mask_preprocessed, inference_steps=num_inference_steps, scale=condition_scale, generator=gen, )[0] return tryoff_image with gr.Blocks() as demo: gr.Markdown(TITLE) with gr.Row(): with gr.Column(): input_image = gr.ImageMask( label="Input Image", height=1024, type="pil", interactive=True, ) run_button = gr.Button( value="Extract Clothing", ) gr.Examples( examples=[ ["examples/model_1.jpg"], ["examples/model_2.jpg"], ["examples/model_3.jpg"], ["examples/model_4.jpg"], ["examples/model_5.jpg"], ["examples/model_6.jpg"], ["examples/model_7.jpg"], ["examples/model_8.jpg"], ["examples/model_9.jpg"], ], inputs=[input_image], ) with gr.Column(): output_image = gr.Image( label="TryOff result", height=1024, image_mode="RGB", type="pil", ) with gr.Accordion("Advanced Settings", open=False): seed = gr.Slider( label="Seed", minimum=36, maximum=36, value=36, step=1, ) scale = gr.Slider( label="Scale", minimum=2.5, maximum=2.5, value=2.5, step=0, ) num_inference_steps = gr.Slider( label="Number of inference steps", minimum=50, maximum=50, value=50, step=1, ) with gr.Row(): image_width = gr.Slider( label="Image Width", minimum=384, maximum=384, value=384, step=8, ) image_height = gr.Slider( label="Image Height", minimum=512, maximum=512, value=512, step=8, ) run_button.click( fn=process, inputs=[ input_image, image_width, image_height, num_inference_steps, scale, seed, ], outputs=output_image, ) demo.launch()