# type: ignore
# Inspired by https://github.com/ixarchakos/try-off-anyone/blob/aa3045453013065573a647e4536922bac696b968/src/model/pipeline.py
# Inspired by https://github.com/ixarchakos/try-off-anyone/blob/aa3045453013065573a647e4536922bac696b968/src/model/attention.py
import torch
from accelerate import load_checkpoint_in_model
from diffusers import AutoencoderKL, DDIMScheduler, UNet2DConditionModel
from diffusers.models.attention_processor import Attention, AttnProcessor
from diffusers.utils.torch_utils import randn_tensor
from huggingface_hub import hf_hub_download
from PIL import Image


class Skip(torch.nn.Module):
    """Attention processor that skips the attention computation entirely.

    Used to disable the text cross-attention layers of the inpainting UNet:
    this pipeline conditions on concatenated image latents instead of text.
    """

    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor | None = None,
        attention_mask: torch.Tensor | None = None,
        temb: torch.Tensor | None = None,
    ) -> torch.Tensor:
        return hidden_states


def fine_tuned_modules(unet: UNet2DConditionModel) -> torch.nn.ModuleList:
    """Collect the attention blocks, the only fine-tuned part of the UNet."""
    trainable_modules = torch.nn.ModuleList()
    for blocks in [unet.down_blocks, unet.mid_block, unet.up_blocks]:
        # `mid_block` exposes `attentions` directly; `down_blocks` and
        # `up_blocks` are lists of blocks that may or may not have attention.
        if hasattr(blocks, "attentions"):
            trainable_modules.append(blocks.attentions)
        else:
            for block in blocks:
                if hasattr(block, "attentions"):
                    trainable_modules.append(block.attentions)
    return trainable_modules


def skip_cross_attentions(unet: UNet2DConditionModel) -> dict[str, AttnProcessor | Skip]:
    """Keep self-attention (`attn1`) processors; replace cross-attention with no-ops."""
    return {
        name: unet.attn_processors[name] if name.endswith("attn1.processor") else Skip()
        for name in unet.attn_processors
    }


def encode(image: torch.Tensor, vae: AutoencoderKL) -> torch.Tensor:
    """Encode an image tensor into scaled VAE latents."""
    image = image.to(memory_format=torch.contiguous_format).float().to(vae.device, dtype=vae.dtype)
    with torch.no_grad():
        return vae.encode(image).latent_dist.sample() * vae.config.scaling_factor


class TryOffAnyone:
    """Garment-extraction ("virtual try-off") pipeline built on SD 1.5 inpainting.

    The latents of the masked input and the reference image are concatenated
    along the height axis (`concat_dim=-2`); the UNet inpaints the masked half,
    which is then decoded as the garment image.
    """

    def __init__(
        self,
        device: torch.device,
        dtype: torch.dtype,
        concat_dim: int = -2,
    ) -> None:
        self.concat_dim = concat_dim
        self.device = device
        self.dtype = dtype

        self.noise_scheduler = DDIMScheduler.from_pretrained(
            pretrained_model_name_or_path="stable-diffusion-v1-5/stable-diffusion-inpainting",
            subfolder="scheduler",
        )
        self.vae = AutoencoderKL.from_pretrained(
            pretrained_model_name_or_path="stabilityai/sd-vae-ft-mse",
        ).to(device, dtype=dtype)
        self.unet = UNet2DConditionModel.from_pretrained(
            pretrained_model_name_or_path="stable-diffusion-v1-5/stable-diffusion-inpainting",
            subfolder="unet",
            variant="fp16",
        ).to(device, dtype=dtype)

        # Disable text cross-attention, then load the fine-tuned attention weights.
        self.unet.set_attn_processor(skip_cross_attentions(self.unet))
        load_checkpoint_in_model(
            model=fine_tuned_modules(unet=self.unet),
            checkpoint=hf_hub_download(
                repo_id="ixarchakos/tryOffAnyone",
                filename="model.safetensors",
            ),
        )

    @torch.no_grad()
    def __call__(
        self,
        image: torch.Tensor,
        mask: torch.Tensor,
        inference_steps: int,
        scale: float,
        generator: torch.Generator,
    ) -> list[Image.Image]:
        # Binarize the mask and zero out the masked region of the input image.
        image = image.unsqueeze(0).to(self.device, dtype=self.dtype)
        mask = (mask.unsqueeze(0) > 0.5).to(self.device, dtype=self.dtype)
        masked_image = image * (mask < 0.5)

        masked_latent = encode(masked_image, self.vae)
        image_latent = encode(image, self.vae)
        mask = torch.nn.functional.interpolate(mask, size=masked_latent.shape[-2:], mode="nearest")

        # Stack [masked image | reference image] latents along `concat_dim`.
        # Only the masked half is inpainted, hence the zero mask for the second half.
        masked_latent_concat = torch.cat([masked_latent, image_latent], dim=self.concat_dim)
        mask_concat = torch.cat([mask, torch.zeros_like(mask)], dim=self.concat_dim)
        latents = randn_tensor(
            shape=masked_latent_concat.shape,
            generator=generator,
            device=self.device,
            dtype=self.dtype,
        )

        self.noise_scheduler.set_timesteps(inference_steps, device=self.device)
        timesteps = self.noise_scheduler.timesteps

        if do_classifier_free_guidance := (scale > 1.0):
            # Unconditional branch: replace the reference-image latents with zeros.
            masked_latent_concat = torch.cat(
                [
                    torch.cat([masked_latent, torch.zeros_like(image_latent)], dim=self.concat_dim),
                    masked_latent_concat,
                ]
            )
            mask_concat = torch.cat([mask_concat] * 2)

        # eta=1.0 makes DDIM sampling fully stochastic (DDPM-like).
        extra_step = {"generator": generator, "eta": 1.0}
        for t in timesteps:
            input_latents = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
            input_latents = self.noise_scheduler.scale_model_input(input_latents, t)
            # Inpainting UNet input: noisy latents (4) + mask (1) + masked-image latents (4).
            input_latents = torch.cat([input_latents, mask_concat, masked_latent_concat], dim=1)

            noise_pred = self.unet(
                input_latents,
                t.to(self.device),
                encoder_hidden_states=None,
                return_dict=False,
            )[0]

            if do_classifier_free_guidance:
                noise_pred_unc, noise_pred_cond = noise_pred.chunk(2)
                noise_pred = noise_pred_unc + scale * (noise_pred_cond - noise_pred_unc)

            latents = self.noise_scheduler.step(noise_pred, t, latents, **extra_step).prev_sample

        # Keep only the inpainted (garment) half of the latents, then decode.
        latents = latents.split(latents.shape[self.concat_dim] // 2, dim=self.concat_dim)[0]
        latents = 1 / self.vae.config.scaling_factor * latents
        image = self.vae.decode(latents.to(self.device, dtype=self.dtype)).sample
        image = (image / 2 + 0.5).clamp(0, 1)
        image = image.cpu().permute(0, 2, 3, 1).float().numpy()
        image = (image * 255).round().astype("uint8")
        return [Image.fromarray(im) for im in image]
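
# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the original pipeline). The file
# names "person.jpg" and "garment_mask.png", the 384x512 working resolution,
# and the step/guidance settings are assumptions for demonstration; the
# pipeline itself only expects a CHW image tensor in [-1, 1] and a 1xHxW
# garment mask in [0, 1].
if __name__ == "__main__":
    import numpy as np

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    dtype = torch.float16 if device.type == "cuda" else torch.float32
    pipeline = TryOffAnyone(device=device, dtype=dtype)

    # Hypothetical inputs: a photo of a clothed person and a mask of the garment.
    person = Image.open("person.jpg").convert("RGB").resize((384, 512))
    garment_mask = Image.open("garment_mask.png").convert("L").resize((384, 512))

    # HWC uint8 -> CHW float in [-1, 1], matching what `encode` feeds the VAE.
    image = torch.from_numpy(np.asarray(person, dtype=np.float32) / 255.0).permute(2, 0, 1) * 2 - 1
    # L-mode mask -> 1xHxW float in [0, 1]; `__call__` binarizes it at 0.5.
    mask = torch.from_numpy(np.asarray(garment_mask, dtype=np.float32) / 255.0).unsqueeze(0)

    garments = pipeline(
        image=image,
        mask=mask,
        inference_steps=50,
        scale=2.5,
        generator=torch.Generator(device=device).manual_seed(0),
    )
    garments[0].save("garment.png")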