# Modified from https://github.com/tencent-ailab/IP-Adapter
import os
from typing import List

import torch
from diffusers import StableDiffusionPipeline
from diffusers.pipelines.controlnet import MultiControlNetModel
from PIL import Image
from safetensors import safe_open
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
from diffusers.image_processor import PipelineImageInput, VaeImageProcessor

from .utils import is_torch2_available

if is_torch2_available():
    from .attention_processor import (
        AttnProcessor2_0 as AttnProcessor,
    )
else:
    from .attention_processor import AttnProcessor

from .resampler import LinearResampler

class MimicBrush_RefNet:
    def __init__(self, sd_pipe, image_encoder_path, model_ckpt, depth_estimator, depth_guider, referencenet, device):
        # Takes model paths as input
        self.device = device
        self.image_encoder_path = image_encoder_path
        self.model_ckpt = model_ckpt
        self.referencenet = referencenet.to(self.device)
        self.depth_estimator = depth_estimator.to(self.device).eval()
        self.depth_guider = depth_guider.to(self.device, dtype=torch.float16)
        self.pipe = sd_pipe.to(self.device)
        self.pipe.unet.set_attn_processor(AttnProcessor())

        # load image encoder
        self.image_encoder = CLIPVisionModelWithProjection.from_pretrained(self.image_encoder_path).to(
            self.device, dtype=torch.float16
        )
        self.clip_image_processor = CLIPImageProcessor()
        # image proj model
        self.image_proj_model = self.init_proj()
        self.image_processor = VaeImageProcessor()
        self.load_checkpoint()
    def init_proj(self):
        # Project CLIP image features into the UNet cross-attention space
        image_proj_model = LinearResampler(
            input_dim=1280,
            output_dim=self.pipe.unet.config.cross_attention_dim,
        ).to(self.device, dtype=torch.float16)
        return image_proj_model
    def load_checkpoint(self):
        # Load all sub-module weights from a single checkpoint file
        state_dict = torch.load(self.model_ckpt, map_location="cpu")
        self.image_proj_model.load_state_dict(state_dict["image_proj"])
        self.depth_guider.load_state_dict(state_dict["depth_guider"])
        print('=== load depth_guider ===')
        self.referencenet.load_state_dict(state_dict["referencenet"])
        print('=== load referencenet ===')
        self.image_encoder.load_state_dict(state_dict["image_encoder"])
        print('=== load image_encoder ===')
        if "unet" in state_dict.keys():
            self.pipe.unet.load_state_dict(state_dict["unet"])
            print('=== load unet ===')
    @torch.inference_mode()
    def get_image_embeds(self, pil_image=None, clip_image_embeds=None):
        if isinstance(pil_image, Image.Image):
            pil_image = [pil_image]
        clip_image = self.clip_image_processor(images=pil_image, return_tensors="pt").pixel_values
        clip_image = clip_image.to(self.device, dtype=torch.float16)
        # Use the penultimate hidden state of the CLIP vision encoder as the image prompt
        clip_image_embeds = self.image_encoder(clip_image, output_hidden_states=True).hidden_states[-2]
        image_prompt_embeds = self.image_proj_model(clip_image_embeds).to(dtype=torch.float16)
        # Unconditional embedding from an all-zero image, used for classifier-free guidance
        uncond_clip_image_embeds = self.image_encoder(
            torch.zeros_like(clip_image), output_hidden_states=True
        ).hidden_states[-2]
        uncond_image_prompt_embeds = self.image_proj_model(uncond_clip_image_embeds)
        return image_prompt_embeds, uncond_image_prompt_embeds
    def generate(
        self,
        pil_image=None,
        depth_image=None,
        clip_image_embeds=None,
        prompt=None,
        negative_prompt=None,
        num_samples=4,
        seed=None,
        image=None,
        guidance_scale=7.5,
        num_inference_steps=30,
        **kwargs,
    ):
        if pil_image is not None:
            num_prompts = 1 if isinstance(pil_image, Image.Image) else len(pil_image)
        else:
            num_prompts = clip_image_embeds.size(0)

        image_prompt_embeds, uncond_image_prompt_embeds = self.get_image_embeds(
            pil_image=pil_image, clip_image_embeds=clip_image_embeds
        )

        # Duplicate the image prompt embeddings for each requested sample
        bs_embed, seq_len, _ = image_prompt_embeds.shape
        image_prompt_embeds = image_prompt_embeds.repeat(1, num_samples, 1)
        image_prompt_embeds = image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)
        uncond_image_prompt_embeds = uncond_image_prompt_embeds.repeat(1, num_samples, 1)
        uncond_image_prompt_embeds = uncond_image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)

        # Estimate a depth map and encode it with the depth guider
        depth_image = depth_image.to(self.device)
        depth_map = self.depth_estimator(depth_image).unsqueeze(1)
        depth_feature = self.depth_guider(depth_map.to(self.device, dtype=torch.float16))

        generator = torch.Generator(self.device).manual_seed(seed) if seed is not None else None

        images = self.pipe(
            prompt_embeds=image_prompt_embeds,  # image clip embedding
            negative_prompt_embeds=uncond_image_prompt_embeds,  # uncond image clip embedding
            guidance_scale=guidance_scale,
            num_inference_steps=num_inference_steps,
            generator=generator,
            referencenet=self.referencenet,
            source_image=pil_image,
            image=image,
            clip_image_embed=torch.cat([uncond_image_prompt_embeds, image_prompt_embeds], dim=0),  # for reference U-Net
            depth_feature=depth_feature,
            **kwargs,
        ).images
        return images, depth_map
class MimicBrush_RefNet_inputmodel(MimicBrush_RefNet):
    # Takes already-instantiated models as input instead of paths
    def __init__(self, sd_pipe, image_encoder, image_proj_model, depth_estimator, depth_guider, referencenet, device):
        self.device = device
        self.image_encoder = image_encoder.to(
            self.device, dtype=torch.float16
        )
        self.depth_estimator = depth_estimator.to(self.device)
        self.depth_guider = depth_guider.to(self.device, dtype=torch.float16)
        self.image_proj_model = image_proj_model.to(self.device, dtype=torch.float16)
        self.referencenet = referencenet.to(self.device, dtype=torch.float16)
        self.pipe = sd_pipe.to(self.device)
        self.pipe.unet.set_attn_processor(AttnProcessor())
        self.referencenet.set_attn_processor(AttnProcessor())
        self.clip_image_processor = CLIPImageProcessor()
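
# --- Usage sketch (illustrative, not part of the original file) ---
# A minimal example of how MimicBrush_RefNet might be wired together, assuming the
# referencenet / depth_estimator / depth_guider modules and the paths below already
# exist elsewhere in the project; every name in angle brackets is a placeholder,
# not an API defined in this repository.
#
# pipe = StableDiffusionPipeline.from_pretrained("<base_sd_model>", torch_dtype=torch.float16)
# model = MimicBrush_RefNet(
#     sd_pipe=pipe,
#     image_encoder_path="<clip_vision_encoder_path>",
#     model_ckpt="<mimicbrush_checkpoint.bin>",
#     depth_estimator=<depth_estimator_module>,
#     depth_guider=<depth_guider_module>,
#     referencenet=<referencenet_module>,
#     device="cuda",
# )
# images, depth_map = model.generate(
#     pil_image=Image.open("<reference.png>").convert("RGB"),
#     depth_image=<source_image_tensor>,
#     image=<masked_source_image>,
#     num_samples=1,
#     seed=0,
# )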