# Based on https://raw.githubusercontent.com/okotaku/diffusers/feature/reference_only_control/examples/community/stable_diffusion_reference.py
# Inspired by: https://github.com/Mikubill/sd-webui-controlnet/discussions/1236
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import numpy as np
import PIL.Image
import torch
import torch.nn.functional as F
from diffusers import StableDiffusionPipeline
from diffusers.models.attention import BasicTransformerBlock
from diffusers.models.unet_2d_blocks import (
    CrossAttnDownBlock2D,
    CrossAttnUpBlock2D,
    DownBlock2D,
    UpBlock2D,
)
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
from diffusers.utils import PIL_INTERPOLATION, logging

logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

EXAMPLE_DOC_STRING = """
    Examples:
        ```py
        >>> import torch
        >>> from diffusers import UniPCMultistepScheduler
        >>> from diffusers.utils import load_image

        >>> input_image = load_image("https://hf.co/datasets/huggingface/documentation-images/resolve/main/diffusers/input_image_vermeer.png")

        >>> pipe = StableDiffusionReferencePipeline.from_pretrained(
        ...     "runwayml/stable-diffusion-v1-5",
        ...     safety_checker=None,
        ...     torch_dtype=torch.float16,
        ... ).to("cuda:0")
        >>> pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)

        >>> result_img = pipe(
        ...     ref_image=input_image,
        ...     prompt="1girl",
        ...     num_inference_steps=20,
        ...     reference_attn=True,
        ...     reference_adain=True,
        ... ).images[0]
        >>> result_img.show()
        ```
"""


def torch_dfs(model: torch.nn.Module) -> List[torch.nn.Module]:
    """Return ``model`` and every submodule it contains, in depth-first order."""
    result = [model]
    for child in model.children():
        result += torch_dfs(child)
    return result
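

# How the "reference-only" hooks below work, in brief: the UNet is run twice per
# denoising step. A first "write" pass runs on the noised reference latents, and
# every hooked module stores features from the reference-mask region (self-attention
# hidden states, and per-channel mean/var for AdaIN) into small per-module banks.
# The second, "read" pass runs on the actual latents: hooked self-attention layers
# attend over [current tokens, banked reference tokens], and hooked blocks
# re-normalize the inpaint-mask region with the banked statistics.
# `change_module_mode` flips all hooked modules between the two passes.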
class StableDiffusionReferencePipeline(StableDiffusionPipeline):
    """Stable Diffusion pipeline with masked reference-only control (attention and AdaIN hooks)."""

    def prepare_ref_image(
        self,
        image,
        width,
        height,
        batch_size,
        num_images_per_prompt,
        device,
        dtype,
        do_classifier_free_guidance=False,
        guess_mode=False,
    ):
        """Convert a reference image (PIL, tensor, or a list of either) into a normalized batch tensor."""
        if not isinstance(image, torch.Tensor):
            if isinstance(image, PIL.Image.Image):
                image = [image]

            if isinstance(image[0], PIL.Image.Image):
                images = []
                for image_ in image:
                    image_ = image_.convert("RGB")
                    image_ = image_.resize((width, height), resample=PIL_INTERPOLATION["lanczos"])
                    image_ = np.array(image_)
                    image_ = image_[None, :]
                    images.append(image_)

                image = np.concatenate(images, axis=0)
                image = image.astype(np.float32) / 255.0
                image = (image - 0.5) / 0.5  # scale from [0, 1] to [-1, 1]
                image = image.transpose(0, 3, 1, 2)  # NHWC -> NCHW
                image = torch.from_numpy(image)
            elif isinstance(image[0], torch.Tensor):
                image = torch.cat(image, dim=0)

        image_batch_size = image.shape[0]
        if image_batch_size == 1:
            repeat_by = batch_size
        else:
            # image batch size is the same as prompt batch size
            repeat_by = num_images_per_prompt
        image = image.repeat_interleave(repeat_by, dim=0)

        image = image.to(device=device, dtype=dtype)
        if do_classifier_free_guidance and not guess_mode:
            image = torch.cat([image] * 2)
        return image
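
    # Shape walkthrough (illustrative): a single 512x512 RGB PIL reference becomes
    # a (1, 3, 512, 512) tensor in [-1, 1]; it is then repeated to the requested
    # batch size and, with classifier-free guidance, doubled once more.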

    def prepare_ref_latents(
        self,
        refimage,
        batch_size,
        dtype,
        device,
        generator,
        do_classifier_free_guidance,
    ):
        """Encode the reference image into scaled VAE latents, batched (and doubled for CFG)."""
        refimage = refimage.to(device=device, dtype=dtype)

        # encode the reference image into latent space so it can be concatenated to the latents
        if isinstance(generator, list):
            ref_image_latents = [
                self.vae.encode(refimage[i : i + 1]).latent_dist.sample(generator=generator[i])
                for i in range(batch_size)
            ]
            ref_image_latents = torch.cat(ref_image_latents, dim=0)
        else:
            ref_image_latents = self.vae.encode(refimage).latent_dist.sample(generator=generator)
        ref_image_latents = self.vae.config.scaling_factor * ref_image_latents

        # duplicate ref_image_latents for each generation per prompt, using an mps-friendly method
        if ref_image_latents.shape[0] < batch_size:
            if batch_size % ref_image_latents.shape[0] != 0:
                raise ValueError(
                    "The passed images and the required batch size don't match. Images are supposed to be duplicated"
                    f" to a total batch size of {batch_size}, but {ref_image_latents.shape[0]} images were passed."
                    " Make sure the number of images that you pass is divisible by the total requested batch size."
                )
            ref_image_latents = ref_image_latents.repeat(batch_size // ref_image_latents.shape[0], 1, 1, 1)

        ref_image_latents = (
            torch.cat([ref_image_latents] * 2) if do_classifier_free_guidance else ref_image_latents
        )

        # align device to prevent device errors when concatenating with the latent model input
        ref_image_latents = ref_image_latents.to(device=device, dtype=dtype)
        return ref_image_latents
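
    # For Stable Diffusion v1.5 the VAE downsamples by a factor of 8, so a
    # (B, 3, 512, 512) reference becomes (B, 4, 64, 64) latents, scaled by
    # vae.config.scaling_factor (0.18215 for that model).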

    def check_ref_input(self, reference_attn, reference_adain):
        assert reference_attn or reference_adain, "`reference_attn` or `reference_adain` must be True."

    def redefine_ref_model(self, model, reference_attn, reference_adain, model_type="unet"):
        def hacked_basic_transformer_inner_forward(
            self,
            hidden_states: torch.FloatTensor,
            attention_mask: Optional[torch.FloatTensor] = None,
            encoder_hidden_states: Optional[torch.FloatTensor] = None,
            encoder_attention_mask: Optional[torch.FloatTensor] = None,
            timestep: Optional[torch.LongTensor] = None,
            cross_attention_kwargs: Optional[Dict[str, Any]] = None,
            class_labels: Optional[torch.LongTensor] = None,
        ):
            if self.use_ada_layer_norm:
                norm_hidden_states = self.norm1(hidden_states, timestep)
            elif self.use_ada_layer_norm_zero:
                (
                    norm_hidden_states,
                    gate_msa,
                    shift_mlp,
                    scale_mlp,
                    gate_mlp,
                ) = self.norm1(hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype)
            else:
                norm_hidden_states = self.norm1(hidden_states)

            # 1. Self-Attention
            cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}
            if self.only_cross_attention:
                attn_output = self.attn1(
                    norm_hidden_states,
                    encoder_hidden_states=encoder_hidden_states,
                    attention_mask=attention_mask,
                    **cross_attention_kwargs,
                )
            else:
                if self.MODE == "write":
                    if self.attention_auto_machine_weight > self.attn_weight:
                        # Downscale the reference mask to the current token grid and
                        # bank only the reference-region hidden states.
                        scale_ratio = (
                            (self.ref_mask.shape[2] * self.ref_mask.shape[3]) / norm_hidden_states.shape[1]
                        ) ** 0.5
                        this_ref_mask = F.interpolate(
                            self.ref_mask.to(norm_hidden_states.device), scale_factor=1 / scale_ratio
                        )
                        resize_norm_hidden_states = norm_hidden_states.view(
                            norm_hidden_states.shape[0],
                            this_ref_mask.shape[2],
                            this_ref_mask.shape[3],
                            -1,
                        ).permute(0, 3, 1, 2)
                        ref_scale = 1.0
                        resize_norm_hidden_states = F.interpolate(
                            resize_norm_hidden_states, scale_factor=ref_scale, mode="bilinear"
                        )
                        this_ref_mask = F.interpolate(this_ref_mask, scale_factor=ref_scale)
                        this_ref_mask = this_ref_mask.repeat(
                            resize_norm_hidden_states.shape[0],
                            resize_norm_hidden_states.shape[1],
                            1,
                            1,
                        ).bool()
                        masked_norm_hidden_states = (
                            resize_norm_hidden_states[this_ref_mask]
                            .detach()
                            .clone()
                            .view(
                                resize_norm_hidden_states.shape[0],
                                resize_norm_hidden_states.shape[1],
                                -1,
                            )
                        )
                        masked_norm_hidden_states = masked_norm_hidden_states.permute(0, 2, 1)
                        self.bank.append(masked_norm_hidden_states)
                    attn_output = self.attn1(
                        norm_hidden_states,
                        encoder_hidden_states=None,  # pure self-attention in the write pass
                        attention_mask=attention_mask,
                        **cross_attention_kwargs,
                    )
                if self.MODE == "read":
                    if self.attention_auto_machine_weight > self.attn_weight:
                        # Attend over the current tokens concatenated with the banked
                        # reference-region tokens from the write pass.
                        ref_hidden_states = torch.cat([norm_hidden_states] + self.bank, dim=1)
                        attn_output_uc = self.attn1(
                            norm_hidden_states,
                            encoder_hidden_states=ref_hidden_states,
                            **cross_attention_kwargs,
                        )
                        attn_output_c = attn_output_uc.clone()
                        if self.do_classifier_free_guidance and self.style_fidelity > 0:
                            # For the unconditional half, fall back to plain self-attention.
                            attn_output_c[self.uc_mask] = self.attn1(
                                norm_hidden_states[self.uc_mask],
                                encoder_hidden_states=norm_hidden_states[self.uc_mask],
                                **cross_attention_kwargs,
                            )
                        attn_output = (
                            self.style_fidelity * attn_output_c
                            + (1.0 - self.style_fidelity) * attn_output_uc
                        )
                        self.bank.clear()
                    else:
                        attn_output = self.attn1(
                            norm_hidden_states,
                            encoder_hidden_states=None,  # pure self-attention below the weight threshold
                            attention_mask=attention_mask,
                            **cross_attention_kwargs,
                        )

            if self.use_ada_layer_norm_zero:
                attn_output = gate_msa.unsqueeze(1) * attn_output
            hidden_states = attn_output + hidden_states

            if self.attn2 is not None:
                norm_hidden_states = (
                    self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states)
                )
                # 2. Cross-Attention
                attn_output = self.attn2(
                    norm_hidden_states,
                    encoder_hidden_states=encoder_hidden_states,
                    attention_mask=encoder_attention_mask,
                    **cross_attention_kwargs,
                )
                hidden_states = attn_output + hidden_states

            # 3. Feed-forward
            norm_hidden_states = self.norm3(hidden_states)
            if self.use_ada_layer_norm_zero:
                norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
            ff_output = self.ff(norm_hidden_states)
            if self.use_ada_layer_norm_zero:
                ff_output = gate_mlp.unsqueeze(1) * ff_output
            hidden_states = ff_output + hidden_states
            return hidden_states
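
        # Reference-attention recap: during "read", the query is still the current
        # tokens, but keys/values come from torch.cat([current, banked reference]),
        # so the generated region can borrow appearance from the reference region;
        # style_fidelity blends in a variant whose unconditional half ignores the
        # reference entirely.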

        def hacked_mid_forward(self, *args, **kwargs):
            eps = 1e-6
            x = self.original_forward(*args, **kwargs)
            if self.MODE == "write":
                if self.gn_auto_machine_weight >= self.gn_weight:
                    # Bank mean/var computed over the reference-mask region only.
                    scale_ratio = self.ref_mask.shape[2] / x.shape[2]
                    this_ref_mask = F.interpolate(self.ref_mask.to(x.device), scale_factor=1 / scale_ratio)
                    this_ref_mask = this_ref_mask.repeat(x.shape[0], x.shape[1], 1, 1).bool()
                    masked_x = x[this_ref_mask].detach().clone().view(x.shape[0], x.shape[1], -1, 1)
                    var, mean = torch.var_mean(masked_x, dim=(2, 3), keepdim=True, correction=0)
                    self.mean_bank.append(mean)
                    self.var_bank.append(var)
            if self.MODE == "read":
                if (
                    self.gn_auto_machine_weight >= self.gn_weight
                    and len(self.mean_bank) > 0
                    and len(self.var_bank) > 0
                ):
                    # Masked AdaIN: re-normalize the inpaint region with the banked
                    # reference statistics.
                    scale_ratio = self.inpaint_mask.shape[2] / x.shape[2]
                    this_inpaint_mask = F.interpolate(
                        self.inpaint_mask.to(x.device), scale_factor=1 / scale_ratio
                    )
                    this_inpaint_mask = this_inpaint_mask.repeat(x.shape[0], x.shape[1], 1, 1).bool()
                    masked_x = x[this_inpaint_mask].detach().clone().view(x.shape[0], x.shape[1], -1, 1)
                    var, mean = torch.var_mean(masked_x, dim=(2, 3), keepdim=True, correction=0)
                    std = torch.maximum(var, torch.zeros_like(var) + eps) ** 0.5
                    mean_acc = sum(self.mean_bank) / float(len(self.mean_bank))
                    var_acc = sum(self.var_bank) / float(len(self.var_bank))
                    std_acc = torch.maximum(var_acc, torch.zeros_like(var_acc) + eps) ** 0.5
                    x_uc = (((masked_x - mean) / std) * std_acc) + mean_acc
                    x_c = x_uc.clone()
                    if self.do_classifier_free_guidance and self.style_fidelity > 0:
                        x_c[self.uc_mask] = masked_x[self.uc_mask]
                    masked_x = self.style_fidelity * x_c + (1.0 - self.style_fidelity) * x_uc
                    x[this_inpaint_mask] = masked_x.view(-1)
                self.mean_bank = []
                self.var_bank = []
            return x
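
        # Masked-AdaIN recap: inside the inpaint region each channel is remapped as
        #   x <- (x - mean(x)) / std(x) * std_ref + mean_ref
        # using statistics banked from the reference region during the write pass.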

        def hack_CrossAttnDownBlock2D_forward(
            self,
            hidden_states: torch.FloatTensor,
            temb: Optional[torch.FloatTensor] = None,
            encoder_hidden_states: Optional[torch.FloatTensor] = None,
            attention_mask: Optional[torch.FloatTensor] = None,
            cross_attention_kwargs: Optional[Dict[str, Any]] = None,
            encoder_attention_mask: Optional[torch.FloatTensor] = None,
        ):
            eps = 1e-6
            # TODO(Patrick, William) - attention mask is not used
            output_states = ()

            for i, (resnet, attn) in enumerate(zip(self.resnets, self.attentions)):
                hidden_states = resnet(hidden_states, temb)
                if self.MODE == "write":
                    if self.gn_auto_machine_weight >= self.gn_weight:
                        # Bank masked mean/var taken after the resnet (pre-attention).
                        scale_ratio = self.ref_mask.shape[2] / hidden_states.shape[2]
                        this_ref_mask = F.interpolate(
                            self.ref_mask.to(hidden_states.device), scale_factor=1 / scale_ratio
                        )
                        this_ref_mask = this_ref_mask.repeat(
                            hidden_states.shape[0], hidden_states.shape[1], 1, 1
                        ).bool()
                        masked_hidden_states = (
                            hidden_states[this_ref_mask]
                            .detach()
                            .clone()
                            .view(hidden_states.shape[0], hidden_states.shape[1], -1, 1)
                        )
                        var, mean = torch.var_mean(
                            masked_hidden_states, dim=(2, 3), keepdim=True, correction=0
                        )
                        # Append one-element lists so the per-resnet indexing in read
                        # mode (`self.mean_bank0[i]`) recovers each tensor unchanged.
                        self.mean_bank0.append([mean])
                        self.var_bank0.append([var])
                if self.MODE == "read":
                    if (
                        self.gn_auto_machine_weight >= self.gn_weight
                        and len(self.mean_bank0) > 0
                        and len(self.var_bank0) > 0
                    ):
                        scale_ratio = self.inpaint_mask.shape[2] / hidden_states.shape[2]
                        this_inpaint_mask = F.interpolate(
                            self.inpaint_mask.to(hidden_states.device), scale_factor=1 / scale_ratio
                        )
                        this_inpaint_mask = this_inpaint_mask.repeat(
                            hidden_states.shape[0], hidden_states.shape[1], 1, 1
                        ).bool()
                        masked_hidden_states = (
                            hidden_states[this_inpaint_mask]
                            .detach()
                            .clone()
                            .view(hidden_states.shape[0], hidden_states.shape[1], -1, 1)
                        )
                        var, mean = torch.var_mean(
                            masked_hidden_states, dim=(2, 3), keepdim=True, correction=0
                        )
                        std = torch.maximum(var, torch.zeros_like(var) + eps) ** 0.5
                        mean_acc = sum(self.mean_bank0[i]) / float(len(self.mean_bank0[i]))
                        var_acc = sum(self.var_bank0[i]) / float(len(self.var_bank0[i]))
                        std_acc = torch.maximum(var_acc, torch.zeros_like(var_acc) + eps) ** 0.5
                        hidden_states_uc = (((masked_hidden_states - mean) / std) * std_acc) + mean_acc
                        hidden_states_c = hidden_states_uc.clone()
                        if self.do_classifier_free_guidance and self.style_fidelity > 0:
                            hidden_states_c[self.uc_mask] = masked_hidden_states[self.uc_mask]
                        masked_hidden_states = (
                            self.style_fidelity * hidden_states_c
                            + (1.0 - self.style_fidelity) * hidden_states_uc
                        )
                        hidden_states[this_inpaint_mask] = masked_hidden_states.view(-1)

                hidden_states = attn(
                    hidden_states,
                    encoder_hidden_states=encoder_hidden_states,
                    cross_attention_kwargs=cross_attention_kwargs,
                    return_dict=False,
                )[0]
                if self.MODE == "write":
                    if self.gn_auto_machine_weight >= self.gn_weight:
                        # Bank masked mean/var taken after the attention block.
                        scale_ratio = self.ref_mask.shape[2] / hidden_states.shape[2]
                        this_ref_mask = F.interpolate(
                            self.ref_mask.to(hidden_states.device), scale_factor=1 / scale_ratio
                        )
                        this_ref_mask = this_ref_mask.repeat(
                            hidden_states.shape[0], hidden_states.shape[1], 1, 1
                        ).bool()
                        masked_hidden_states = (
                            hidden_states[this_ref_mask]
                            .detach()
                            .clone()
                            .view(hidden_states.shape[0], hidden_states.shape[1], -1, 1)
                        )
                        var, mean = torch.var_mean(
                            masked_hidden_states, dim=(2, 3), keepdim=True, correction=0
                        )
                        self.mean_bank.append([mean])
                        self.var_bank.append([var])
                if self.MODE == "read":
                    if (
                        self.gn_auto_machine_weight >= self.gn_weight
                        and len(self.mean_bank) > 0
                        and len(self.var_bank) > 0
                    ):
                        scale_ratio = self.inpaint_mask.shape[2] / hidden_states.shape[2]
                        this_inpaint_mask = F.interpolate(
                            self.inpaint_mask.to(hidden_states.device), scale_factor=1 / scale_ratio
                        )
                        this_inpaint_mask = this_inpaint_mask.repeat(
                            hidden_states.shape[0], hidden_states.shape[1], 1, 1
                        ).bool()
                        masked_hidden_states = (
                            hidden_states[this_inpaint_mask]
                            .detach()
                            .clone()
                            .view(hidden_states.shape[0], hidden_states.shape[1], -1, 1)
                        )
                        var, mean = torch.var_mean(
                            masked_hidden_states, dim=(2, 3), keepdim=True, correction=0
                        )
                        std = torch.maximum(var, torch.zeros_like(var) + eps) ** 0.5
                        mean_acc = sum(self.mean_bank[i]) / float(len(self.mean_bank[i]))
                        var_acc = sum(self.var_bank[i]) / float(len(self.var_bank[i]))
                        std_acc = torch.maximum(var_acc, torch.zeros_like(var_acc) + eps) ** 0.5
                        hidden_states_uc = (((masked_hidden_states - mean) / std) * std_acc) + mean_acc
                        hidden_states_c = hidden_states_uc.clone()
                        if self.do_classifier_free_guidance and self.style_fidelity > 0:
                            hidden_states_c[self.uc_mask] = masked_hidden_states[self.uc_mask]
                        masked_hidden_states = (
                            self.style_fidelity * hidden_states_c
                            + (1.0 - self.style_fidelity) * hidden_states_uc
                        )
                        hidden_states[this_inpaint_mask] = masked_hidden_states.view(-1)

                output_states = output_states + (hidden_states,)

            if self.MODE == "read":
                self.mean_bank0 = []
                self.var_bank0 = []
                self.mean_bank = []
                self.var_bank = []

            if self.downsamplers is not None:
                for downsampler in self.downsamplers:
                    hidden_states = downsampler(hidden_states)
                output_states = output_states + (hidden_states,)

            return hidden_states, output_states

        def hacked_DownBlock2D_forward(self, hidden_states, temb=None):
            eps = 1e-6
            output_states = ()

            for i, resnet in enumerate(self.resnets):
                hidden_states = resnet(hidden_states, temb)
                if self.MODE == "write":
                    if self.gn_auto_machine_weight >= self.gn_weight:
                        # Bank masked reference statistics, one entry per resnet.
                        scale_ratio = self.ref_mask.shape[2] / hidden_states.shape[2]
                        this_ref_mask = F.interpolate(
                            self.ref_mask.to(hidden_states.device), scale_factor=1 / scale_ratio
                        )
                        this_ref_mask = this_ref_mask.repeat(
                            hidden_states.shape[0], hidden_states.shape[1], 1, 1
                        ).bool()
                        masked_hidden_states = (
                            hidden_states[this_ref_mask]
                            .detach()
                            .clone()
                            .view(hidden_states.shape[0], hidden_states.shape[1], -1, 1)
                        )
                        var, mean = torch.var_mean(
                            masked_hidden_states, dim=(2, 3), keepdim=True, correction=0
                        )
                        self.mean_bank.append([mean])
                        self.var_bank.append([var])
                if self.MODE == "read":
                    if (
                        self.gn_auto_machine_weight >= self.gn_weight
                        and len(self.mean_bank) > 0
                        and len(self.var_bank) > 0
                    ):
                        scale_ratio = self.inpaint_mask.shape[2] / hidden_states.shape[2]
                        this_inpaint_mask = F.interpolate(
                            self.inpaint_mask.to(hidden_states.device), scale_factor=1 / scale_ratio
                        )
                        this_inpaint_mask = this_inpaint_mask.repeat(
                            hidden_states.shape[0], hidden_states.shape[1], 1, 1
                        ).bool()
                        masked_hidden_states = (
                            hidden_states[this_inpaint_mask]
                            .detach()
                            .clone()
                            .view(hidden_states.shape[0], hidden_states.shape[1], -1, 1)
                        )
                        var, mean = torch.var_mean(
                            masked_hidden_states, dim=(2, 3), keepdim=True, correction=0
                        )
                        std = torch.maximum(var, torch.zeros_like(var) + eps) ** 0.5
                        mean_acc = sum(self.mean_bank[i]) / float(len(self.mean_bank[i]))
                        var_acc = sum(self.var_bank[i]) / float(len(self.var_bank[i]))
                        std_acc = torch.maximum(var_acc, torch.zeros_like(var_acc) + eps) ** 0.5
                        hidden_states_uc = (((masked_hidden_states - mean) / std) * std_acc) + mean_acc
                        hidden_states_c = hidden_states_uc.clone()
                        if self.do_classifier_free_guidance and self.style_fidelity > 0:
                            hidden_states_c[self.uc_mask] = masked_hidden_states[self.uc_mask]
                        masked_hidden_states = (
                            self.style_fidelity * hidden_states_c
                            + (1.0 - self.style_fidelity) * hidden_states_uc
                        )
                        hidden_states[this_inpaint_mask] = masked_hidden_states.view(-1)

                output_states = output_states + (hidden_states,)

            if self.MODE == "read":
                self.mean_bank = []
                self.var_bank = []

            if self.downsamplers is not None:
                for downsampler in self.downsamplers:
                    hidden_states = downsampler(hidden_states)
                output_states = output_states + (hidden_states,)

            return hidden_states, output_states

        def hacked_CrossAttnUpBlock2D_forward(
            self,
            hidden_states: torch.FloatTensor,
            res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
            temb: Optional[torch.FloatTensor] = None,
            encoder_hidden_states: Optional[torch.FloatTensor] = None,
            cross_attention_kwargs: Optional[Dict[str, Any]] = None,
            upsample_size: Optional[int] = None,
            attention_mask: Optional[torch.FloatTensor] = None,
            encoder_attention_mask: Optional[torch.FloatTensor] = None,
        ):
            eps = 1e-6
            # TODO(Patrick, William) - attention mask is not used
            for i, (resnet, attn) in enumerate(zip(self.resnets, self.attentions)):
                # pop res hidden states
                res_hidden_states = res_hidden_states_tuple[-1]
                res_hidden_states_tuple = res_hidden_states_tuple[:-1]
                hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
                hidden_states = resnet(hidden_states, temb)

                if self.MODE == "write":
                    if self.gn_auto_machine_weight >= self.gn_weight:
                        # Bank masked mean/var taken after the resnet (pre-attention).
                        scale_ratio = self.ref_mask.shape[2] / hidden_states.shape[2]
                        this_ref_mask = F.interpolate(
                            self.ref_mask.to(hidden_states.device), scale_factor=1 / scale_ratio
                        )
                        this_ref_mask = this_ref_mask.repeat(
                            hidden_states.shape[0], hidden_states.shape[1], 1, 1
                        ).bool()
                        masked_hidden_states = (
                            hidden_states[this_ref_mask]
                            .detach()
                            .clone()
                            .view(hidden_states.shape[0], hidden_states.shape[1], -1, 1)
                        )
                        var, mean = torch.var_mean(
                            masked_hidden_states, dim=(2, 3), keepdim=True, correction=0
                        )
                        self.mean_bank0.append([mean])
                        self.var_bank0.append([var])
                if self.MODE == "read":
                    if (
                        self.gn_auto_machine_weight >= self.gn_weight
                        and len(self.mean_bank0) > 0
                        and len(self.var_bank0) > 0
                    ):
                        scale_ratio = self.inpaint_mask.shape[2] / hidden_states.shape[2]
                        this_inpaint_mask = F.interpolate(
                            self.inpaint_mask.to(hidden_states.device), scale_factor=1 / scale_ratio
                        )
                        this_inpaint_mask = this_inpaint_mask.repeat(
                            hidden_states.shape[0], hidden_states.shape[1], 1, 1
                        ).bool()
                        masked_hidden_states = (
                            hidden_states[this_inpaint_mask]
                            .detach()
                            .clone()
                            .view(hidden_states.shape[0], hidden_states.shape[1], -1, 1)
                        )
                        var, mean = torch.var_mean(
                            masked_hidden_states, dim=(2, 3), keepdim=True, correction=0
                        )
                        std = torch.maximum(var, torch.zeros_like(var) + eps) ** 0.5
                        mean_acc = sum(self.mean_bank0[i]) / float(len(self.mean_bank0[i]))
                        var_acc = sum(self.var_bank0[i]) / float(len(self.var_bank0[i]))
                        std_acc = torch.maximum(var_acc, torch.zeros_like(var_acc) + eps) ** 0.5
                        hidden_states_uc = (((masked_hidden_states - mean) / std) * std_acc) + mean_acc
                        hidden_states_c = hidden_states_uc.clone()
                        if self.do_classifier_free_guidance and self.style_fidelity > 0:
                            hidden_states_c[self.uc_mask] = masked_hidden_states[self.uc_mask]
                        masked_hidden_states = (
                            self.style_fidelity * hidden_states_c
                            + (1.0 - self.style_fidelity) * hidden_states_uc
                        )
                        hidden_states[this_inpaint_mask] = masked_hidden_states.view(-1)

                hidden_states = attn(
                    hidden_states,
                    encoder_hidden_states=encoder_hidden_states,
                    cross_attention_kwargs=cross_attention_kwargs,
                    return_dict=False,
                )[0]
                if self.MODE == "write":
                    if self.gn_auto_machine_weight >= self.gn_weight:
                        # Bank masked mean/var taken after the attention block.
                        scale_ratio = self.ref_mask.shape[2] / hidden_states.shape[2]
                        this_ref_mask = F.interpolate(
                            self.ref_mask.to(hidden_states.device), scale_factor=1 / scale_ratio
                        )
                        this_ref_mask = this_ref_mask.repeat(
                            hidden_states.shape[0], hidden_states.shape[1], 1, 1
                        ).bool()
                        masked_hidden_states = (
                            hidden_states[this_ref_mask]
                            .detach()
                            .clone()
                            .view(hidden_states.shape[0], hidden_states.shape[1], -1, 1)
                        )
                        var, mean = torch.var_mean(
                            masked_hidden_states, dim=(2, 3), keepdim=True, correction=0
                        )
                        self.mean_bank.append([mean])
                        self.var_bank.append([var])
                if self.MODE == "read":
                    if (
                        self.gn_auto_machine_weight >= self.gn_weight
                        and len(self.mean_bank) > 0
                        and len(self.var_bank) > 0
                    ):
                        scale_ratio = self.inpaint_mask.shape[2] / hidden_states.shape[2]
                        this_inpaint_mask = F.interpolate(
                            self.inpaint_mask.to(hidden_states.device), scale_factor=1 / scale_ratio
                        )
                        this_inpaint_mask = this_inpaint_mask.repeat(
                            hidden_states.shape[0], hidden_states.shape[1], 1, 1
                        ).bool()
                        masked_hidden_states = (
                            hidden_states[this_inpaint_mask]
                            .detach()
                            .clone()
                            .view(hidden_states.shape[0], hidden_states.shape[1], -1, 1)
                        )
                        var, mean = torch.var_mean(
                            masked_hidden_states, dim=(2, 3), keepdim=True, correction=0
                        )
                        std = torch.maximum(var, torch.zeros_like(var) + eps) ** 0.5
                        mean_acc = sum(self.mean_bank[i]) / float(len(self.mean_bank[i]))
                        var_acc = sum(self.var_bank[i]) / float(len(self.var_bank[i]))
                        std_acc = torch.maximum(var_acc, torch.zeros_like(var_acc) + eps) ** 0.5
                        hidden_states_uc = (((masked_hidden_states - mean) / std) * std_acc) + mean_acc
                        hidden_states_c = hidden_states_uc.clone()
                        if self.do_classifier_free_guidance and self.style_fidelity > 0:
                            hidden_states_c[self.uc_mask] = masked_hidden_states[self.uc_mask]
                        masked_hidden_states = (
                            self.style_fidelity * hidden_states_c
                            + (1.0 - self.style_fidelity) * hidden_states_uc
                        )
                        hidden_states[this_inpaint_mask] = masked_hidden_states.view(-1)

            if self.MODE == "read":
                self.mean_bank0 = []
                self.var_bank0 = []
                self.mean_bank = []
                self.var_bank = []

            if self.upsamplers is not None:
                for upsampler in self.upsamplers:
                    hidden_states = upsampler(hidden_states, upsample_size)

            return hidden_states

        def hacked_UpBlock2D_forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None):
            eps = 1e-6
            for i, resnet in enumerate(self.resnets):
                # pop res hidden states
                res_hidden_states = res_hidden_states_tuple[-1]
                res_hidden_states_tuple = res_hidden_states_tuple[:-1]
                hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
                hidden_states = resnet(hidden_states, temb)

                if self.MODE == "write":
                    if self.gn_auto_machine_weight >= self.gn_weight:
                        # Bank masked reference statistics, one entry per resnet.
                        scale_ratio = self.ref_mask.shape[2] / hidden_states.shape[2]
                        this_ref_mask = F.interpolate(
                            self.ref_mask.to(hidden_states.device), scale_factor=1 / scale_ratio
                        )
                        this_ref_mask = this_ref_mask.repeat(
                            hidden_states.shape[0], hidden_states.shape[1], 1, 1
                        ).bool()
                        masked_hidden_states = (
                            hidden_states[this_ref_mask]
                            .detach()
                            .clone()
                            .view(hidden_states.shape[0], hidden_states.shape[1], -1, 1)
                        )
                        var, mean = torch.var_mean(
                            masked_hidden_states, dim=(2, 3), keepdim=True, correction=0
                        )
                        self.mean_bank.append([mean])
                        self.var_bank.append([var])
                if self.MODE == "read":
                    if (
                        self.gn_auto_machine_weight >= self.gn_weight
                        and len(self.mean_bank) > 0
                        and len(self.var_bank) > 0
                    ):
                        scale_ratio = self.inpaint_mask.shape[2] / hidden_states.shape[2]
                        this_inpaint_mask = F.interpolate(
                            self.inpaint_mask.to(hidden_states.device), scale_factor=1 / scale_ratio
                        )
                        this_inpaint_mask = this_inpaint_mask.repeat(
                            hidden_states.shape[0], hidden_states.shape[1], 1, 1
                        ).bool()
                        masked_hidden_states = (
                            hidden_states[this_inpaint_mask]
                            .detach()
                            .clone()
                            .view(hidden_states.shape[0], hidden_states.shape[1], -1, 1)
                        )
                        var, mean = torch.var_mean(
                            masked_hidden_states, dim=(2, 3), keepdim=True, correction=0
                        )
                        std = torch.maximum(var, torch.zeros_like(var) + eps) ** 0.5
                        mean_acc = sum(self.mean_bank[i]) / float(len(self.mean_bank[i]))
                        var_acc = sum(self.var_bank[i]) / float(len(self.var_bank[i]))
                        std_acc = torch.maximum(var_acc, torch.zeros_like(var_acc) + eps) ** 0.5
                        hidden_states_uc = (((masked_hidden_states - mean) / std) * std_acc) + mean_acc
                        hidden_states_c = hidden_states_uc.clone()
                        if self.do_classifier_free_guidance and self.style_fidelity > 0:
                            hidden_states_c[self.uc_mask] = masked_hidden_states[self.uc_mask]
                        masked_hidden_states = (
                            self.style_fidelity * hidden_states_c
                            + (1.0 - self.style_fidelity) * hidden_states_uc
                        )
                        hidden_states[this_inpaint_mask] = masked_hidden_states.view(-1)

            if self.MODE == "read":
                self.mean_bank = []
                self.var_bank = []

            if self.upsamplers is not None:
                for upsampler in self.upsamplers:
                    hidden_states = upsampler(hidden_states, upsample_size)

            return hidden_states
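
        # The four block hooks above repeat the same write/read pattern as
        # hacked_mid_forward, keeping one bank entry per resnet (plus a second
        # "bank0" for the pre-attention statistics in the cross-attention blocks).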

        if model_type == "unet":
            if reference_attn:
                attn_modules = [
                    module for module in torch_dfs(model) if isinstance(module, BasicTransformerBlock)
                ]
                # Sort from widest to narrowest feature width so attn_weight grows
                # with depth into the network.
                attn_modules = sorted(attn_modules, key=lambda x: -x.norm1.normalized_shape[0])

                for i, module in enumerate(attn_modules):
                    module._original_inner_forward = module.forward
                    module.forward = hacked_basic_transformer_inner_forward.__get__(
                        module, BasicTransformerBlock
                    )
                    module.bank = []
                    module.attn_weight = float(i) / float(len(attn_modules))
                    module.attention_auto_machine_weight = self.attention_auto_machine_weight
                    module.gn_auto_machine_weight = self.gn_auto_machine_weight
                    module.do_classifier_free_guidance = self.do_classifier_free_guidance
                    module.uc_mask = self.uc_mask
                    module.style_fidelity = self.style_fidelity
                    module.ref_mask = self.ref_mask
            else:
                attn_modules = None

            if reference_adain:
                gn_modules = [model.mid_block]
                model.mid_block.gn_weight = 0

                down_blocks = model.down_blocks
                for w, module in enumerate(down_blocks):
                    module.gn_weight = 1.0 - float(w) / float(len(down_blocks))
                    gn_modules.append(module)

                up_blocks = model.up_blocks
                for w, module in enumerate(up_blocks):
                    module.gn_weight = float(w) / float(len(up_blocks))
                    gn_modules.append(module)

                for i, module in enumerate(gn_modules):
                    if getattr(module, "original_forward", None) is None:
                        module.original_forward = module.forward
                    if i == 0:
                        # mid_block
                        module.forward = hacked_mid_forward.__get__(module, torch.nn.Module)
                    elif isinstance(module, CrossAttnDownBlock2D):
                        module.forward = hack_CrossAttnDownBlock2D_forward.__get__(
                            module, CrossAttnDownBlock2D
                        )
                        module.mean_bank0 = []
                        module.var_bank0 = []
                    elif isinstance(module, DownBlock2D):
                        module.forward = hacked_DownBlock2D_forward.__get__(module, DownBlock2D)
                    # CrossAttnUpBlock2D is intentionally left unhooked here:
                    # elif isinstance(module, CrossAttnUpBlock2D):
                    #     module.forward = hacked_CrossAttnUpBlock2D_forward.__get__(module, CrossAttnUpBlock2D)
                    #     module.mean_bank0 = []
                    #     module.var_bank0 = []
                    elif isinstance(module, UpBlock2D):
                        module.forward = hacked_UpBlock2D_forward.__get__(module, UpBlock2D)
                        module.mean_bank0 = []
                        module.var_bank0 = []
                    module.mean_bank = []
                    module.var_bank = []
                    module.attention_auto_machine_weight = self.attention_auto_machine_weight
                    module.gn_auto_machine_weight = self.gn_auto_machine_weight
                    module.do_classifier_free_guidance = self.do_classifier_free_guidance
                    module.uc_mask = self.uc_mask
                    module.style_fidelity = self.style_fidelity
                    module.ref_mask = self.ref_mask
                    module.inpaint_mask = self.inpaint_mask
            else:
                gn_modules = None
        elif model_type == "controlnet":
            model = model.nets[-1]  # only hack the inpainting controlnet
            if reference_attn:
                attn_modules = [
                    module for module in torch_dfs(model) if isinstance(module, BasicTransformerBlock)
                ]
                attn_modules = sorted(attn_modules, key=lambda x: -x.norm1.normalized_shape[0])

                for i, module in enumerate(attn_modules):
                    module._original_inner_forward = module.forward
                    module.forward = hacked_basic_transformer_inner_forward.__get__(
                        module, BasicTransformerBlock
                    )
                    module.bank = []
                    # Forced to 0.0 (instead of float(i) / float(len(attn_modules)))
                    # so every controlnet block writes and reads.
                    module.attn_weight = 0.0
                    module.attention_auto_machine_weight = self.attention_auto_machine_weight
                    module.gn_auto_machine_weight = self.gn_auto_machine_weight
                    module.do_classifier_free_guidance = self.do_classifier_free_guidance
                    module.uc_mask = self.uc_mask
                    module.style_fidelity = self.style_fidelity
                    module.ref_mask = self.ref_mask
            else:
                attn_modules = None
            gn_modules = None

        return attn_modules, gn_modules
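
    # Weighting recap: attn_weight runs from 0 (widest transformer blocks) toward 1
    # (narrowest), and a block participates only while attention_auto_machine_weight
    # exceeds it. gn_weight is 0 for the mid block, falls from 1 across the down
    # blocks, and rises across the up blocks; a block applies masked AdaIN only
    # while gn_auto_machine_weight >= its gn_weight.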

    def change_module_mode(self, mode, attn_modules, gn_modules):
        """Switch every hooked module between the "write" (reference) and "read" (denoising) passes."""
        if attn_modules is not None:
            for module in attn_modules:
                module.MODE = mode
        if gn_modules is not None:
            for module in gn_modules:
                module.MODE = mode
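
    # Hook lifecycle sketch (illustrative; the pipeline's __call__ is not part of
    # this section, and `ref_latent_input`, `latent_input`, `prompt_embeds`, and `t`
    # below are hypothetical names). The caller is expected to hook the model once,
    # then alternate passes inside the denoising loop, roughly:
    #
    #     attn_modules, gn_modules = self.redefine_ref_model(
    #         self.unet, reference_attn=True, reference_adain=True, model_type="unet"
    #     )
    #     for t in timesteps:
    #         # "write" pass: run the UNet on the noised reference latents to fill the banks
    #         self.change_module_mode("write", attn_modules, gn_modules)
    #         self.unet(ref_latent_input, t, encoder_hidden_states=prompt_embeds)
    #         # "read" pass: the real denoising step, consuming the banked features
    #         self.change_module_mode("read", attn_modules, gn_modules)
    #         noise_pred = self.unet(latent_input, t, encoder_hidden_states=prompt_embeds)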