# Modified from https://github.com/tencent-ailab/IP-Adapter
import os
from typing import List

import torch
from diffusers import StableDiffusionPipeline
from diffusers.pipelines.controlnet import MultiControlNetModel
from PIL import Image
from safetensors import safe_open
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
from diffusers.image_processor import PipelineImageInput, VaeImageProcessor

from .utils import is_torch2_available

if is_torch2_available():
    from .attention_processor import (
        AttnProcessor2_0 as AttnProcessor,
    )
else:
    from .attention_processor import AttnProcessor

from .resampler import LinearResampler
class MimicBrush_RefNet:
    def __init__(self, sd_pipe, image_encoder_path, model_ckpt, depth_estimator, depth_guider, referencenet, device):
        # Takes model paths as input
        self.device = device
        self.image_encoder_path = image_encoder_path
        self.model_ckpt = model_ckpt
        self.referencenet = referencenet.to(self.device)
        self.depth_estimator = depth_estimator.to(self.device).eval()
        self.depth_guider = depth_guider.to(self.device, dtype=torch.float16)

        self.pipe = sd_pipe.to(self.device)
        self.pipe.unet.set_attn_processor(AttnProcessor())

        # load image encoder
        self.image_encoder = CLIPVisionModelWithProjection.from_pretrained(self.image_encoder_path).to(
            self.device, dtype=torch.float16
        )
        self.clip_image_processor = CLIPImageProcessor()
        # image proj model
        self.image_proj_model = self.init_proj()
        self.image_processor = VaeImageProcessor()
        self.load_checkpoint()
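
    # Builds the projection module that maps CLIP image features (dim 1280) into
    # the U-Net cross-attention space, so image embeddings can replace text prompts.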
    def init_proj(self):
        image_proj_model = LinearResampler(
            input_dim=1280,
            output_dim=self.pipe.unet.config.cross_attention_dim,
        ).to(self.device, dtype=torch.float16)
        return image_proj_model
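
    # Restores the trained weights of every sub-module (image projector, depth guider,
    # reference U-Net, image encoder, and optionally the main U-Net) from one checkpoint.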
    def load_checkpoint(self):
        state_dict = torch.load(self.model_ckpt, map_location="cpu")
        self.image_proj_model.load_state_dict(state_dict["image_proj"])
        self.depth_guider.load_state_dict(state_dict["depth_guider"])
        print('=== load depth_guider ===')
        self.referencenet.load_state_dict(state_dict["referencenet"])
        print('=== load referencenet ===')
        self.image_encoder.load_state_dict(state_dict["image_encoder"])
        print('=== load image_encoder ===')
        if "unet" in state_dict.keys():
            self.pipe.unet.load_state_dict(state_dict["unet"])
            print('=== load unet ===')
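
    # Encodes the reference image with CLIP (penultimate hidden states) and projects it;
    # also produces an unconditional embedding from a zero image for classifier-free guidance.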
    def get_image_embeds(self, pil_image=None, clip_image_embeds=None):
        if isinstance(pil_image, Image.Image):
            pil_image = [pil_image]
        clip_image = self.clip_image_processor(images=pil_image, return_tensors="pt").pixel_values
        clip_image = clip_image.to(self.device, dtype=torch.float16)
        clip_image_embeds = self.image_encoder(clip_image, output_hidden_states=True).hidden_states[-2]
        image_prompt_embeds = self.image_proj_model(clip_image_embeds).to(dtype=torch.float16)
        uncond_clip_image_embeds = self.image_encoder(
            torch.zeros_like(clip_image), output_hidden_states=True
        ).hidden_states[-2]
        uncond_image_prompt_embeds = self.image_proj_model(uncond_clip_image_embeds)
        return image_prompt_embeds, uncond_image_prompt_embeds
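
    # Runs the full editing pass: projected image embeddings stand in for the text prompt,
    # the depth guider injects geometry from the estimated depth map, and the reference
    # U-Net consumes the concatenated unconditional/conditional CLIP embeddings.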
    def generate(
        self,
        pil_image=None,
        depth_image=None,
        clip_image_embeds=None,
        prompt=None,
        negative_prompt=None,
        num_samples=4,
        seed=None,
        image=None,
        guidance_scale=7.5,
        num_inference_steps=30,
        **kwargs,
    ):
        if pil_image is not None:
            num_prompts = 1 if isinstance(pil_image, Image.Image) else len(pil_image)
        else:
            num_prompts = clip_image_embeds.size(0)

        image_prompt_embeds, uncond_image_prompt_embeds = self.get_image_embeds(
            pil_image=pil_image, clip_image_embeds=clip_image_embeds
        )
        bs_embed, seq_len, _ = image_prompt_embeds.shape
        image_prompt_embeds = image_prompt_embeds.repeat(1, num_samples, 1)
        image_prompt_embeds = image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)
        uncond_image_prompt_embeds = uncond_image_prompt_embeds.repeat(1, num_samples, 1)
        uncond_image_prompt_embeds = uncond_image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)

        depth_image = depth_image.to(self.device)
        depth_map = self.depth_estimator(depth_image).unsqueeze(1)
        depth_feature = self.depth_guider(depth_map.to(self.device, dtype=torch.float16))

        generator = torch.Generator(self.device).manual_seed(seed) if seed is not None else None

        images = self.pipe(
            prompt_embeds=image_prompt_embeds,  # image CLIP embedding
            negative_prompt_embeds=uncond_image_prompt_embeds,  # uncond image CLIP embedding
            guidance_scale=guidance_scale,
            num_inference_steps=num_inference_steps,
            generator=generator,
            referencenet=self.referencenet,
            source_image=pil_image,
            image=image,
            clip_image_embed=torch.cat([uncond_image_prompt_embeds, image_prompt_embeds], dim=0),  # for reference U-Net
            depth_feature=depth_feature,
            **kwargs,
        ).images

        return images, depth_map

class MimicBrush_RefNet_inputmodel(MimicBrush_RefNet):
    # Takes already-instantiated models as input (instead of checkpoint paths).
    def __init__(self, sd_pipe, image_encoder, image_proj_model, depth_estimator, depth_guider, referencenet, device):
        self.device = device
        self.image_encoder = image_encoder.to(
            self.device, dtype=torch.float16
        )
        self.depth_estimator = depth_estimator.to(self.device)
        self.depth_guider = depth_guider.to(self.device, dtype=torch.float16)
        self.image_proj_model = image_proj_model.to(self.device, dtype=torch.float16)
        self.referencenet = referencenet.to(self.device, dtype=torch.float16)
        self.pipe = sd_pipe.to(self.device)
        self.pipe.unet.set_attn_processor(AttnProcessor())
        self.referencenet.set_attn_processor(AttnProcessor())
        self.clip_image_processor = CLIPImageProcessor()
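
# Example usage, a minimal sketch: the pipeline, depth estimator, depth guider and
# reference U-Net instances are assumed to be built elsewhere in the MimicBrush codebase,
# and the paths and argument values below are illustrative assumptions only.
#
#   model = MimicBrush_RefNet(
#       sd_pipe=pipe,
#       image_encoder_path="path/to/clip_image_encoder",
#       model_ckpt="path/to/mimicbrush_checkpoint.bin",
#       depth_estimator=depth_estimator,
#       depth_guider=depth_guider,
#       referencenet=referencenet,
#       device="cuda",
#   )
#   edited_images, depth_map = model.generate(
#       pil_image=reference_image,   # PIL image providing the appearance to transfer
#       depth_image=depth_input,     # tensor fed to the depth estimator
#       image=target_image,          # target image input expected by the custom pipeline
#       num_samples=1,
#       guidance_scale=5.0,
#       seed=0,
#   )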