# Based on https://raw.githubusercontent.com/okotaku/diffusers/feature/reference_only_control/examples/community/stable_diffusion_reference.py
# Inspired by: https://github.com/Mikubill/sd-webui-controlnet/discussions/1236
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import numpy as np
import PIL.Image
import torch
import torch.fft as fft
import torch.nn.functional as F
from diffusers.models.attention import BasicTransformerBlock
from diffusers.models.unet_2d_blocks import (
    CrossAttnDownBlock2D,
    CrossAttnUpBlock2D,
    DownBlock2D,
    UpBlock2D,
)
from diffusers.utils import PIL_INTERPOLATION, logging

logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

EXAMPLE_DOC_STRING = """ | |
Examples: | |
```py | |
>>> import torch | |
>>> from diffusers import UniPCMultistepScheduler | |
>>> from diffusers.utils import load_image | |
>>> input_image = load_image("https://hf.co/datasets/huggingface/documentation-images/resolve/main/diffusers/input_image_vermeer.png") | |
>>> pipe = StableDiffusionReferencePipeline.from_pretrained( | |
"runwayml/stable-diffusion-v1-5", | |
safety_checker=None, | |
torch_dtype=torch.float16 | |
).to('cuda:0') | |
>>> pipe.scheduler = UniPCMultistepScheduler.from_config(pipe_controlnet.scheduler.config) | |
>>> result_img = pipe(ref_image=input_image, | |
prompt="1girl", | |
num_inference_steps=20, | |
reference_attn=True, | |
reference_adain=True).images[0] | |
>>> result_img.show() | |
``` | |
""" | |
def torch_dfs(model: torch.nn.Module):
    """Return a depth-first list of `model` and all of its submodules."""
    result = [model]
    for child in model.children():
        result += torch_dfs(child)
    return result
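
# Illustrative sketch (not executed here): torch_dfs flattens a module tree so
# specific block types can be collected, e.g. every BasicTransformerBlock in a
# UNet. `unet` is a hypothetical, already-loaded UNet2DConditionModel.
#
#   all_modules = torch_dfs(unet)
#   attn_blocks = [m for m in all_modules if isinstance(m, BasicTransformerBlock)]
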
def add_freq_feature(feature1, feature2, ref_ratio):
    """Mix the Fourier magnitude of a reference feature into a target feature.

    feature1: reference feature (n, c, h, w)
    feature2: target feature (n, c, h, w)
    ref_ratio: weight of the reference magnitude; a larger ratio injects more
        of the reference's frequency content
    """
    # fft ops require float32; remember the original dtype so we can cast back.
    data_type = feature2.dtype
    feature1 = feature1.to(torch.float32)
    feature2 = feature2.to(torch.float32)
    # Compute the 2D Fourier transforms of both features.
    spectrum1 = fft.fftn(feature1, dim=(-2, -1))
    spectrum2 = fft.fftn(feature2, dim=(-2, -1))
    # Blend the magnitudes; the phase is taken from the target feature only.
    magnitude1 = torch.abs(spectrum1)
    magnitude2 = torch.abs(spectrum2)
    phase2 = torch.angle(spectrum2)
    magnitude2.mul_(1 - ref_ratio).add_(magnitude1 * ref_ratio)
    # Recombine magnitude and phase, then invert the transform.
    mixed_spectrum = torch.polar(magnitude2, phase2)
    mixed_feature = fft.ifftn(mixed_spectrum, dim=(-2, -1))
    del feature1, feature2, spectrum1, spectrum2, magnitude1, magnitude2, phase2, mixed_spectrum
    # ifftn returns a complex tensor; take the real part and restore the dtype.
    return mixed_feature.real.to(data_type)
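
# Minimal sketch of the frequency mixing above, on random tensors (assumed
# shapes for illustration only; not part of the pipeline):
#
#   ref = torch.randn(1, 4, 32, 32)
#   tgt = torch.randn(1, 4, 32, 32)
#   mixed = add_freq_feature(ref, tgt, ref_ratio=0.3)
#   # `mixed` keeps tgt's phase, but 30% of its magnitude comes from `ref`.
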
def save_ref_feature(feature, mask):
    """Zero out everything outside the reference mask.

    feature: (n, c, h, w)
    mask:    (n, 1, h, w)
    returns: (n, c, h, w)
    """
    return feature * mask
def mix_ref_feature(feature, ref_fea_bank, cfg=True, ref_scale=0.0, dim3=False):
    """Blend stored reference features into `feature` in the frequency domain.

    feature:      (n, l, c) if dim3 else (n, c, h, w)
    ref_fea_bank: list of reference features, each (n, c, h, w)
    cfg:          if True, duplicate the reference batch so it covers both
                  halves of the classifier-free-guidance batch
    returns:      same shape as `feature`
    """
    if cfg:
        ref_fea = torch.cat(ref_fea_bank + ref_fea_bank, dim=0)
    else:
        ref_fea = torch.cat(ref_fea_bank, dim=0)
    if dim3:
        # (n, l, c) -> (n, c, h, w) so the 2D FFT can act on the spatial dims.
        feature = feature.permute(0, 2, 1).view(ref_fea.shape)
    mixed_feature = add_freq_feature(ref_fea, feature, ref_scale)
    if dim3:
        # Back to (n, l, c).
        mixed_feature = mixed_feature.view(
            ref_fea.shape[0], ref_fea.shape[1], -1
        ).permute(0, 2, 1)
    del ref_fea
    del feature
    return mixed_feature
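
# Sketch of the dim3 path (assumed shapes): a transformer feature (n, l, c)
# with l == h * w is reshaped to (n, c, h, w), mixed, and restored.
#
#   bank = [torch.randn(1, 320, 16, 16)]   # one saved reference feature
#   feat = torch.randn(2, 256, 320)        # CFG batch: uncond + cond halves
#   out = mix_ref_feature(feat, bank, cfg=True, ref_scale=0.2, dim3=True)
#   assert out.shape == feat.shape
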
def mix_norm_feature(
    x,
    inpaint_mask,
    mean_bank,
    var_bank,
    do_classifier_free_guidance,
    style_fidelity,
    uc_mask,
    eps=1e-6,
):
    """AdaIN-style renormalization of the inpainting region.

    x:            input feature (n, c, h, w)
    inpaint_mask: (n, 1, H, W) mask of the region to inpaint; only this region
                  is renormalized with the accumulated reference statistics
    """
    # Resize the mask to the feature resolution and restrict mixing to it.
    scale_ratio = inpaint_mask.shape[2] / x.shape[2]
    this_inpaint_mask = F.interpolate(
        inpaint_mask.to(x.device), scale_factor=1 / scale_ratio
    )
    this_inpaint_mask = this_inpaint_mask.repeat(
        x.shape[0], x.shape[1], 1, 1
    ).bool()
    masked_x = (
        x[this_inpaint_mask]
        .detach()
        .clone()
        .view(x.shape[0], x.shape[1], -1, 1)
    )
    var, mean = torch.var_mean(masked_x, dim=(2, 3), keepdim=True, correction=0)
    std = torch.maximum(var, torch.zeros_like(var) + eps) ** 0.5
    # Statistics accumulated from the reference ("write") pass.
    mean_acc = sum(mean_bank) / float(len(mean_bank))
    var_acc = sum(var_bank) / float(len(var_bank))
    std_acc = torch.maximum(var_acc, torch.zeros_like(var_acc) + eps) ** 0.5
    x_uc = (((masked_x - mean) / std) * std_acc) + mean_acc
    x_c = x_uc.clone()
    if do_classifier_free_guidance and style_fidelity > 0:
        # Keep the unconditional half untouched for style fidelity.
        x_c[uc_mask] = masked_x[uc_mask]
    masked_x = style_fidelity * x_c + (1.0 - style_fidelity) * x_uc
    x[this_inpaint_mask] = masked_x.view(-1)
    return x
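
# Sketch of the AdaIN-style mixing (assumed shapes; `uc_mask` marks the
# unconditional half of a CFG batch):
#
#   x = torch.randn(2, 320, 64, 64)
#   inpaint_mask = torch.ones(1, 1, 512, 512)
#   mean_bank = [torch.zeros(2, 320, 1, 1)]
#   var_bank = [torch.ones(2, 320, 1, 1)]
#   uc_mask = torch.tensor([True, False])
#   x = mix_norm_feature(x, inpaint_mask, mean_bank, var_bank,
#                        do_classifier_free_guidance=True,
#                        style_fidelity=0.5, uc_mask=uc_mask)
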
class StableDiffusionReferencePipeline:
    """Masked reference-only control helpers for Stable Diffusion inpainting."""
    def prepare_ref_image(
        self,
        image,
        width,
        height,
        batch_size,
        num_images_per_prompt,
        device,
        dtype,
        do_classifier_free_guidance=False,
        guess_mode=False,
    ):
        if not isinstance(image, torch.Tensor):
            if isinstance(image, PIL.Image.Image):
                image = [image]
            if isinstance(image[0], PIL.Image.Image):
                images = []
                for image_ in image:
                    image_ = image_.convert("RGB")
                    image_ = image_.resize(
                        (width, height), resample=PIL_INTERPOLATION["lanczos"]
                    )
                    image_ = np.array(image_)
                    image_ = image_[None, :]
                    images.append(image_)
                image = images
                image = np.concatenate(image, axis=0)
                # Normalize to [-1, 1] and convert to NCHW.
                image = np.array(image).astype(np.float32) / 255.0
                image = (image - 0.5) / 0.5
                image = image.transpose(0, 3, 1, 2)
                image = torch.from_numpy(image)
            elif isinstance(image[0], torch.Tensor):
                image = torch.cat(image, dim=0)
        image_batch_size = image.shape[0]
        if image_batch_size == 1:
            repeat_by = batch_size
        else:
            # image batch size is the same as prompt batch size
            repeat_by = num_images_per_prompt
        image = image.repeat_interleave(repeat_by, dim=0)
        image = image.to(device=device, dtype=dtype)
        if do_classifier_free_guidance and not guess_mode:
            image = torch.cat([image] * 2)
        return image
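    # Hypothetical call (names and values for illustration only):
    #
    #   ref = self.prepare_ref_image(
    #       image=pil_ref, width=512, height=512, batch_size=1,
    #       num_images_per_prompt=1, device=torch.device("cuda"),
    #       dtype=torch.float16, do_classifier_free_guidance=True,
    #   )
    #   # ref: (2, 3, 512, 512) in [-1, 1], duplicated for the CFG batch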
    def prepare_ref_latents(
        self,
        refimage,
        batch_size,
        dtype,
        device,
        generator,
        do_classifier_free_guidance,
    ):
        refimage = refimage.to(device=device, dtype=dtype)
        # encode the reference image into latent space so it can be run through
        # the UNet during the "write" pass
        if isinstance(generator, list):
            ref_image_latents = [
                self.vae.encode(refimage[i: i + 1]).latent_dist.sample(
                    generator=generator[i]
                )
                for i in range(batch_size)
            ]
            ref_image_latents = torch.cat(ref_image_latents, dim=0)
        else:
            ref_image_latents = self.vae.encode(refimage).latent_dist.sample(
                generator=generator
            )
        ref_image_latents = self.vae.config.scaling_factor * ref_image_latents
        # duplicate ref_image_latents for each generation per prompt, using an mps-friendly method
        if ref_image_latents.shape[0] < batch_size:
            if batch_size % ref_image_latents.shape[0] != 0:
                raise ValueError(
                    "The passed images and the required batch size don't match. Images are supposed to be duplicated"
                    f" to a total batch size of {batch_size}, but {ref_image_latents.shape[0]} images were passed."
                    " Make sure the requested batch size is divisible by the number of images that you pass."
                )
            ref_image_latents = ref_image_latents.repeat(
                batch_size // ref_image_latents.shape[0], 1, 1, 1
            )
        ref_image_latents = (
            torch.cat([ref_image_latents] * 2)
            if do_classifier_free_guidance
            else ref_image_latents
        )
        # align device to prevent device errors when concatenating with the latent model input
        ref_image_latents = ref_image_latents.to(device=device, dtype=dtype)
        return ref_image_latents
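    # Hypothetical follow-up (illustration only): encode the prepared reference
    # image and keep the latents around for the "write" pass.
    #
    #   ref_image_latents = self.prepare_ref_latents(
    #       ref, batch_size=1, dtype=torch.float16,
    #       device=torch.device("cuda"), generator=None,
    #       do_classifier_free_guidance=True,
    #   )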
    def check_ref_input(self, reference_attn, reference_adain):
        assert (
            reference_attn or reference_adain
        ), "`reference_attn` or `reference_adain` must be True."
    def redefine_ref_model(
        self, model, reference_attn, reference_adain, model_type="unet"
    ):
        def hacked_basic_transformer_inner_forward(
            self,
            hidden_states: torch.FloatTensor,
            attention_mask: Optional[torch.FloatTensor] = None,
            encoder_hidden_states: Optional[torch.FloatTensor] = None,
            encoder_attention_mask: Optional[torch.FloatTensor] = None,
            timestep: Optional[torch.LongTensor] = None,
            cross_attention_kwargs: Optional[Dict[str, Any]] = None,
            class_labels: Optional[torch.LongTensor] = None,
        ):
            if self.use_ada_layer_norm:
                norm_hidden_states = self.norm1(hidden_states, timestep)
            elif self.use_ada_layer_norm_zero:
                (
                    norm_hidden_states,
                    gate_msa,
                    shift_mlp,
                    scale_mlp,
                    gate_mlp,
                ) = self.norm1(
                    hidden_states,
                    timestep,
                    class_labels,
                    hidden_dtype=hidden_states.dtype,
                )
            else:
                norm_hidden_states = self.norm1(hidden_states)

            # 1. Self-Attention
            cross_attention_kwargs = (
                cross_attention_kwargs if cross_attention_kwargs is not None else {}
            )
            if self.only_cross_attention:
                attn_output = self.attn1(
                    norm_hidden_states,
                    encoder_hidden_states=encoder_hidden_states
                    if self.only_cross_attention
                    else None,
                    attention_mask=attention_mask,
                    **cross_attention_kwargs,
                )
            else:
                if self.MODE == "write":
                    if self.attention_auto_machine_weight > self.attn_weight:
                        # Resize the reference mask to this block's token grid.
                        scale_ratio = (
                            (self.ref_mask.shape[2] * self.ref_mask.shape[3])
                            / norm_hidden_states.shape[1]
                        ) ** 0.5
                        this_ref_mask = F.interpolate(
                            self.ref_mask.to(norm_hidden_states.device),
                            scale_factor=1 / scale_ratio,
                        )
                        # (n, l, c) -> (n, c, h, w) so the mask applies spatially.
                        resize_norm_hidden_states = norm_hidden_states.view(
                            norm_hidden_states.shape[0],
                            this_ref_mask.shape[2],
                            this_ref_mask.shape[3],
                            -1,
                        ).permute(0, 3, 1, 2)
                        # ref_scale = 1.0 makes these interpolations no-ops; kept
                        # as a hook for optionally rescaling the reference.
                        ref_scale = 1.0
                        resize_norm_hidden_states = F.interpolate(
                            resize_norm_hidden_states,
                            scale_factor=ref_scale,
                            mode="bilinear",
                        )
                        this_ref_mask = F.interpolate(
                            this_ref_mask, scale_factor=ref_scale
                        )
                        self.fea_bank.append(
                            save_ref_feature(resize_norm_hidden_states, this_ref_mask)
                        )
                        this_ref_mask = this_ref_mask.repeat(
                            resize_norm_hidden_states.shape[0],
                            resize_norm_hidden_states.shape[1],
                            1,
                            1,
                        ).bool()
                        # Stash the masked reference tokens for the "read" pass.
                        masked_norm_hidden_states = (
                            resize_norm_hidden_states[this_ref_mask]
                            .detach()
                            .clone()
                            .view(
                                resize_norm_hidden_states.shape[0],
                                resize_norm_hidden_states.shape[1],
                                -1,
                            )
                        )
                        masked_norm_hidden_states = masked_norm_hidden_states.permute(
                            0, 2, 1
                        )
                        self.bank.append(masked_norm_hidden_states)
                        del masked_norm_hidden_states
                        del this_ref_mask
                        del resize_norm_hidden_states
                    attn_output = self.attn1(
                        norm_hidden_states,
                        encoder_hidden_states=encoder_hidden_states
                        if self.only_cross_attention
                        else None,
                        attention_mask=attention_mask,
                        **cross_attention_kwargs,
                    )
                if self.MODE == "read":
                    if self.attention_auto_machine_weight > self.attn_weight:
                        # Inject reference frequency content into the tokens.
                        freq_norm_hidden_states = mix_ref_feature(
                            norm_hidden_states,
                            self.fea_bank,
                            cfg=self.do_classifier_free_guidance,
                            ref_scale=self.ref_scale,
                            dim3=True,
                        )
                        self.fea_bank.clear()
                        # Double the bank so the stored reference tokens cover
                        # both halves of the classifier-free-guidance batch.
                        this_bank = torch.cat(self.bank + self.bank, dim=0)
                        ref_hidden_states = torch.cat(
                            (freq_norm_hidden_states, this_bank), dim=1
                        )
                        del this_bank
                        self.bank.clear()
                        # Attend over own tokens plus the reference tokens.
                        attn_output_uc = self.attn1(
                            freq_norm_hidden_states,
                            encoder_hidden_states=ref_hidden_states,
                            **cross_attention_kwargs,
                        )
                        del ref_hidden_states
                        attn_output_c = attn_output_uc.clone()
                        if self.do_classifier_free_guidance and self.style_fidelity > 0:
                            # Plain self-attention for the unconditional half.
                            attn_output_c[self.uc_mask] = self.attn1(
                                norm_hidden_states[self.uc_mask],
                                encoder_hidden_states=norm_hidden_states[self.uc_mask],
                                **cross_attention_kwargs,
                            )
                        attn_output = (
                            self.style_fidelity * attn_output_c
                            + (1.0 - self.style_fidelity) * attn_output_uc
                        )
                        self.bank.clear()
                        self.fea_bank.clear()
                        del attn_output_c
                        del attn_output_uc
                    else:
                        attn_output = self.attn1(
                            norm_hidden_states,
                            encoder_hidden_states=encoder_hidden_states
                            if self.only_cross_attention
                            else None,
                            attention_mask=attention_mask,
                            **cross_attention_kwargs,
                        )
                        self.bank.clear()
                        self.fea_bank.clear()
            if self.use_ada_layer_norm_zero:
                attn_output = gate_msa.unsqueeze(1) * attn_output
            hidden_states = attn_output + hidden_states

            if self.attn2 is not None:
                norm_hidden_states = (
                    self.norm2(hidden_states, timestep)
                    if self.use_ada_layer_norm
                    else self.norm2(hidden_states)
                )
                # 2. Cross-Attention
                attn_output = self.attn2(
                    norm_hidden_states,
                    encoder_hidden_states=encoder_hidden_states,
                    attention_mask=encoder_attention_mask,
                    **cross_attention_kwargs,
                )
                hidden_states = attn_output + hidden_states

            # 3. Feed-forward
            norm_hidden_states = self.norm3(hidden_states)
            if self.use_ada_layer_norm_zero:
                norm_hidden_states = (
                    norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
                )
            ff_output = self.ff(norm_hidden_states)
            if self.use_ada_layer_norm_zero:
                ff_output = gate_mlp.unsqueeze(1) * ff_output
            hidden_states = ff_output + hidden_states
            return hidden_states
        def hacked_mid_forward(self, *args, **kwargs):
            x = self.original_forward(*args, **kwargs)
            if self.MODE == "write":
                if self.gn_auto_machine_weight >= self.gn_weight:
                    # Save the masked reference feature and its statistics.
                    scale_ratio = self.ref_mask.shape[2] / x.shape[2]
                    this_ref_mask = F.interpolate(
                        self.ref_mask.to(x.device), scale_factor=1 / scale_ratio
                    )
                    self.fea_bank.append(save_ref_feature(x, this_ref_mask))
                    this_ref_mask = this_ref_mask.repeat(
                        x.shape[0], x.shape[1], 1, 1
                    ).bool()
                    masked_x = (
                        x[this_ref_mask]
                        .detach()
                        .clone()
                        .view(x.shape[0], x.shape[1], -1, 1)
                    )
                    var, mean = torch.var_mean(
                        masked_x, dim=(2, 3), keepdim=True, correction=0
                    )
                    # Duplicate the statistics to cover both halves of the CFG batch.
                    self.mean_bank.append(torch.cat([mean] * 2, dim=0))
                    self.var_bank.append(torch.cat([var] * 2, dim=0))
            if self.MODE == "read":
                if (
                    self.gn_auto_machine_weight >= self.gn_weight
                    and len(self.mean_bank) > 0
                    and len(self.var_bank) > 0
                ):
                    x = mix_ref_feature(
                        x,
                        self.fea_bank,
                        cfg=self.do_classifier_free_guidance,
                        ref_scale=self.ref_scale,
                    )
                    self.fea_bank = []
                    x = mix_norm_feature(
                        x,
                        self.inpaint_mask,
                        self.mean_bank,
                        self.var_bank,
                        self.do_classifier_free_guidance,
                        self.style_fidelity,
                        self.uc_mask,
                    )
                self.mean_bank = []
                self.var_bank = []
            return x
        def hack_CrossAttnDownBlock2D_forward(
            self,
            hidden_states: torch.FloatTensor,
            temb: Optional[torch.FloatTensor] = None,
            encoder_hidden_states: Optional[torch.FloatTensor] = None,
            attention_mask: Optional[torch.FloatTensor] = None,
            cross_attention_kwargs: Optional[Dict[str, Any]] = None,
            encoder_attention_mask: Optional[torch.FloatTensor] = None,
        ):
            # TODO(Patrick, William) - attention mask is not used
            output_states = ()
            for i, (resnet, attn) in enumerate(zip(self.resnets, self.attentions)):
                hidden_states = resnet(hidden_states, temb)
                if self.MODE == "write":
                    if self.gn_auto_machine_weight >= self.gn_weight:
                        # Save the masked post-resnet feature and its statistics.
                        scale_ratio = self.ref_mask.shape[2] / hidden_states.shape[2]
                        this_ref_mask = F.interpolate(
                            self.ref_mask.to(hidden_states.device),
                            scale_factor=1 / scale_ratio,
                        )
                        self.fea_bank0.append(
                            save_ref_feature(hidden_states, this_ref_mask)
                        )
                        this_ref_mask = this_ref_mask.repeat(
                            hidden_states.shape[0], hidden_states.shape[1], 1, 1
                        ).bool()
                        masked_hidden_states = (
                            hidden_states[this_ref_mask]
                            .detach()
                            .clone()
                            .view(hidden_states.shape[0], hidden_states.shape[1], -1, 1)
                        )
                        var, mean = torch.var_mean(
                            masked_hidden_states, dim=(2, 3), keepdim=True, correction=0
                        )
                        # Duplicate the statistics to cover both halves of the CFG batch.
                        self.mean_bank0.append(torch.cat([mean] * 2, dim=0))
                        self.var_bank0.append(torch.cat([var] * 2, dim=0))
                if self.MODE == "read":
                    if (
                        self.gn_auto_machine_weight >= self.gn_weight
                        and len(self.mean_bank0) > 0
                        and len(self.var_bank0) > 0
                    ):
                        hidden_states = mix_ref_feature(
                            hidden_states,
                            [self.fea_bank0[i]],
                            cfg=self.do_classifier_free_guidance,
                            ref_scale=self.ref_scale,
                        )
                        # Per-resnet statistics are single tensors; wrap them in
                        # lists so mix_norm_feature's bank averaging applies.
                        hidden_states = mix_norm_feature(
                            hidden_states,
                            self.inpaint_mask,
                            [self.mean_bank0[i]],
                            [self.var_bank0[i]],
                            self.do_classifier_free_guidance,
                            self.style_fidelity,
                            self.uc_mask,
                        )
                hidden_states = attn(
                    hidden_states,
                    encoder_hidden_states=encoder_hidden_states,
                    cross_attention_kwargs=cross_attention_kwargs,
                    return_dict=False,
                )[0]
                if self.MODE == "write":
                    if self.gn_auto_machine_weight >= self.gn_weight:
                        # Repeat after the attention block, into the second banks.
                        scale_ratio = self.ref_mask.shape[2] / hidden_states.shape[2]
                        this_ref_mask = F.interpolate(
                            self.ref_mask.to(hidden_states.device),
                            scale_factor=1 / scale_ratio,
                        )
                        self.fea_bank.append(
                            save_ref_feature(hidden_states, this_ref_mask)
                        )
                        this_ref_mask = this_ref_mask.repeat(
                            hidden_states.shape[0], hidden_states.shape[1], 1, 1
                        ).bool()
                        masked_hidden_states = (
                            hidden_states[this_ref_mask]
                            .detach()
                            .clone()
                            .view(hidden_states.shape[0], hidden_states.shape[1], -1, 1)
                        )
                        var, mean = torch.var_mean(
                            masked_hidden_states, dim=(2, 3), keepdim=True, correction=0
                        )
                        self.mean_bank.append(torch.cat([mean] * 2, dim=0))
                        self.var_bank.append(torch.cat([var] * 2, dim=0))
                if self.MODE == "read":
                    if (
                        self.gn_auto_machine_weight >= self.gn_weight
                        and len(self.mean_bank) > 0
                        and len(self.var_bank) > 0
                    ):
                        hidden_states = mix_ref_feature(
                            hidden_states,
                            [self.fea_bank[i]],
                            cfg=self.do_classifier_free_guidance,
                            ref_scale=self.ref_scale,
                        )
                        hidden_states = mix_norm_feature(
                            hidden_states,
                            self.inpaint_mask,
                            [self.mean_bank[i]],
                            [self.var_bank[i]],
                            self.do_classifier_free_guidance,
                            self.style_fidelity,
                            self.uc_mask,
                        )
                output_states = output_states + (hidden_states,)
            if self.MODE == "read":
                self.mean_bank0 = []
                self.var_bank0 = []
                self.mean_bank = []
                self.var_bank = []
                self.fea_bank0 = []
                self.fea_bank = []
            if self.downsamplers is not None:
                for downsampler in self.downsamplers:
                    hidden_states = downsampler(hidden_states)
                output_states = output_states + (hidden_states,)
            return hidden_states, output_states
        def hacked_DownBlock2D_forward(self, hidden_states, temb=None):
            output_states = ()
            for i, resnet in enumerate(self.resnets):
                hidden_states = resnet(hidden_states, temb)
                if self.MODE == "write":
                    if self.gn_auto_machine_weight >= self.gn_weight:
                        # Save the masked post-resnet feature and its statistics.
                        scale_ratio = self.ref_mask.shape[2] / hidden_states.shape[2]
                        this_ref_mask = F.interpolate(
                            self.ref_mask.to(hidden_states.device),
                            scale_factor=1 / scale_ratio,
                        )
                        self.fea_bank.append(
                            save_ref_feature(hidden_states, this_ref_mask)
                        )
                        this_ref_mask = this_ref_mask.repeat(
                            hidden_states.shape[0], hidden_states.shape[1], 1, 1
                        ).bool()
                        masked_hidden_states = (
                            hidden_states[this_ref_mask]
                            .detach()
                            .clone()
                            .view(hidden_states.shape[0], hidden_states.shape[1], -1, 1)
                        )
                        var, mean = torch.var_mean(
                            masked_hidden_states, dim=(2, 3), keepdim=True, correction=0
                        )
                        self.mean_bank.append(torch.cat([mean] * 2, dim=0))
                        self.var_bank.append(torch.cat([var] * 2, dim=0))
                if self.MODE == "read":
                    if (
                        self.gn_auto_machine_weight >= self.gn_weight
                        and len(self.mean_bank) > 0
                        and len(self.var_bank) > 0
                    ):
                        hidden_states = mix_ref_feature(
                            hidden_states,
                            [self.fea_bank[i]],
                            cfg=self.do_classifier_free_guidance,
                            ref_scale=self.ref_scale,
                        )
                        # Per-resnet statistics are wrapped in lists so
                        # mix_norm_feature's bank averaging applies.
                        hidden_states = mix_norm_feature(
                            hidden_states,
                            self.inpaint_mask,
                            [self.mean_bank[i]],
                            [self.var_bank[i]],
                            self.do_classifier_free_guidance,
                            self.style_fidelity,
                            self.uc_mask,
                        )
                output_states = output_states + (hidden_states,)
            if self.MODE == "read":
                self.mean_bank = []
                self.var_bank = []
                self.fea_bank = []
            if self.downsamplers is not None:
                for downsampler in self.downsamplers:
                    hidden_states = downsampler(hidden_states)
                output_states = output_states + (hidden_states,)
            return hidden_states, output_states
        def hacked_CrossAttnUpBlock2D_forward(
            self,
            hidden_states: torch.FloatTensor,
            res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
            temb: Optional[torch.FloatTensor] = None,
            encoder_hidden_states: Optional[torch.FloatTensor] = None,
            cross_attention_kwargs: Optional[Dict[str, Any]] = None,
            upsample_size: Optional[int] = None,
            attention_mask: Optional[torch.FloatTensor] = None,
            encoder_attention_mask: Optional[torch.FloatTensor] = None,
        ):
            # TODO(Patrick, William) - attention mask is not used
            for i, (resnet, attn) in enumerate(zip(self.resnets, self.attentions)):
                # pop res hidden states
                res_hidden_states = res_hidden_states_tuple[-1]
                res_hidden_states_tuple = res_hidden_states_tuple[:-1]
                hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
                hidden_states = resnet(hidden_states, temb)
                if self.MODE == "write":
                    if self.gn_auto_machine_weight >= self.gn_weight:
                        # Save the masked post-resnet feature and its statistics.
                        scale_ratio = self.ref_mask.shape[2] / hidden_states.shape[2]
                        this_ref_mask = F.interpolate(
                            self.ref_mask.to(hidden_states.device),
                            scale_factor=1 / scale_ratio,
                        )
                        self.fea_bank0.append(
                            save_ref_feature(hidden_states, this_ref_mask)
                        )
                        this_ref_mask = this_ref_mask.repeat(
                            hidden_states.shape[0], hidden_states.shape[1], 1, 1
                        ).bool()
                        masked_hidden_states = (
                            hidden_states[this_ref_mask]
                            .detach()
                            .clone()
                            .view(hidden_states.shape[0], hidden_states.shape[1], -1, 1)
                        )
                        var, mean = torch.var_mean(
                            masked_hidden_states, dim=(2, 3), keepdim=True, correction=0
                        )
                        self.mean_bank0.append(torch.cat([mean] * 2, dim=0))
                        self.var_bank0.append(torch.cat([var] * 2, dim=0))
                if self.MODE == "read":
                    if (
                        self.gn_auto_machine_weight >= self.gn_weight
                        and len(self.mean_bank0) > 0
                        and len(self.var_bank0) > 0
                    ):
                        hidden_states = mix_ref_feature(
                            hidden_states,
                            [self.fea_bank0[i]],
                            cfg=self.do_classifier_free_guidance,
                            ref_scale=self.ref_scale,
                        )
                        hidden_states = mix_norm_feature(
                            hidden_states,
                            self.inpaint_mask,
                            [self.mean_bank0[i]],
                            [self.var_bank0[i]],
                            self.do_classifier_free_guidance,
                            self.style_fidelity,
                            self.uc_mask,
                        )
                hidden_states = attn(
                    hidden_states,
                    encoder_hidden_states=encoder_hidden_states,
                    cross_attention_kwargs=cross_attention_kwargs,
                    return_dict=False,
                )[0]
                if self.MODE == "write":
                    if self.gn_auto_machine_weight >= self.gn_weight:
                        # Repeat after the attention block, into the second banks.
                        scale_ratio = self.ref_mask.shape[2] / hidden_states.shape[2]
                        this_ref_mask = F.interpolate(
                            self.ref_mask.to(hidden_states.device),
                            scale_factor=1 / scale_ratio,
                        )
                        self.fea_bank.append(
                            save_ref_feature(hidden_states, this_ref_mask)
                        )
                        this_ref_mask = this_ref_mask.repeat(
                            hidden_states.shape[0], hidden_states.shape[1], 1, 1
                        ).bool()
                        masked_hidden_states = (
                            hidden_states[this_ref_mask]
                            .detach()
                            .clone()
                            .view(hidden_states.shape[0], hidden_states.shape[1], -1, 1)
                        )
                        var, mean = torch.var_mean(
                            masked_hidden_states, dim=(2, 3), keepdim=True, correction=0
                        )
                        self.mean_bank.append(torch.cat([mean] * 2, dim=0))
                        self.var_bank.append(torch.cat([var] * 2, dim=0))
                if self.MODE == "read":
                    if (
                        self.gn_auto_machine_weight >= self.gn_weight
                        and len(self.mean_bank) > 0
                        and len(self.var_bank) > 0
                    ):
                        hidden_states = mix_ref_feature(
                            hidden_states,
                            [self.fea_bank[i]],
                            cfg=self.do_classifier_free_guidance,
                            ref_scale=self.ref_scale,
                        )
                        hidden_states = mix_norm_feature(
                            hidden_states,
                            self.inpaint_mask,
                            [self.mean_bank[i]],
                            [self.var_bank[i]],
                            self.do_classifier_free_guidance,
                            self.style_fidelity,
                            self.uc_mask,
                        )
            if self.MODE == "read":
                self.mean_bank0 = []
                self.var_bank0 = []
                self.mean_bank = []
                self.var_bank = []
                self.fea_bank = []
                self.fea_bank0 = []
            if self.upsamplers is not None:
                for upsampler in self.upsamplers:
                    hidden_states = upsampler(hidden_states, upsample_size)
            return hidden_states
        def hacked_UpBlock2D_forward(
            self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None
        ):
            for i, resnet in enumerate(self.resnets):
                # pop res hidden states
                res_hidden_states = res_hidden_states_tuple[-1]
                res_hidden_states_tuple = res_hidden_states_tuple[:-1]
                hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
                hidden_states = resnet(hidden_states, temb)
                if self.MODE == "write":
                    if self.gn_auto_machine_weight >= self.gn_weight:
                        # Save the masked post-resnet feature and its statistics.
                        scale_ratio = self.ref_mask.shape[2] / hidden_states.shape[2]
                        this_ref_mask = F.interpolate(
                            self.ref_mask.to(hidden_states.device),
                            scale_factor=1 / scale_ratio,
                        )
                        self.fea_bank.append(
                            save_ref_feature(hidden_states, this_ref_mask)
                        )
                        this_ref_mask = this_ref_mask.repeat(
                            hidden_states.shape[0], hidden_states.shape[1], 1, 1
                        ).bool()
                        masked_hidden_states = (
                            hidden_states[this_ref_mask]
                            .detach()
                            .clone()
                            .view(hidden_states.shape[0], hidden_states.shape[1], -1, 1)
                        )
                        var, mean = torch.var_mean(
                            masked_hidden_states, dim=(2, 3), keepdim=True, correction=0
                        )
                        self.mean_bank.append(torch.cat([mean] * 2, dim=0))
                        self.var_bank.append(torch.cat([var] * 2, dim=0))
                if self.MODE == "read":
                    if (
                        self.gn_auto_machine_weight >= self.gn_weight
                        and len(self.mean_bank) > 0
                        and len(self.var_bank) > 0
                    ):
                        hidden_states = mix_ref_feature(
                            hidden_states,
                            [self.fea_bank[i]],
                            cfg=self.do_classifier_free_guidance,
                            ref_scale=self.ref_scale,
                        )
                        hidden_states = mix_norm_feature(
                            hidden_states,
                            self.inpaint_mask,
                            [self.mean_bank[i]],
                            [self.var_bank[i]],
                            self.do_classifier_free_guidance,
                            self.style_fidelity,
                            self.uc_mask,
                        )
            if self.MODE == "read":
                self.mean_bank = []
                self.var_bank = []
                self.fea_bank = []
            if self.upsamplers is not None:
                for upsampler in self.upsamplers:
                    hidden_states = upsampler(hidden_states, upsample_size)
            return hidden_states
if model_type == "unet": | |
if reference_attn: | |
attn_modules = [ | |
module | |
for module in torch_dfs(model) | |
if isinstance(module, BasicTransformerBlock) | |
] | |
attn_modules = sorted( | |
attn_modules, key=lambda x: -x.norm1.normalized_shape[0] | |
) | |
for i, module in enumerate(attn_modules): | |
module._original_inner_forward = module.forward | |
module.forward = hacked_basic_transformer_inner_forward.__get__( | |
module, BasicTransformerBlock | |
) | |
module.bank = [] | |
module.fea_bank = [] | |
module.attn_weight = float(i) / float(len(attn_modules)) | |
module.attention_auto_machine_weight = ( | |
self.attention_auto_machine_weight | |
) | |
module.gn_auto_machine_weight = self.gn_auto_machine_weight | |
module.do_classifier_free_guidance = ( | |
self.do_classifier_free_guidance | |
) | |
module.do_classifier_free_guidance = ( | |
self.do_classifier_free_guidance | |
) | |
module.uc_mask = self.uc_mask | |
module.style_fidelity = self.style_fidelity | |
module.ref_mask = self.ref_mask | |
module.ref_scale = self.ref_scale | |
else: | |
attn_modules = None | |
if reference_adain: | |
gn_modules = [model.mid_block] | |
model.mid_block.gn_weight = 0 | |
down_blocks = model.down_blocks | |
for w, module in enumerate(down_blocks): | |
module.gn_weight = 1.0 - float(w) / float(len(down_blocks)) | |
gn_modules.append(module) | |
# print(module.__class__.__name__,module.gn_weight) | |
up_blocks = model.up_blocks | |
for w, module in enumerate(up_blocks): | |
module.gn_weight = float(w) / float(len(up_blocks)) | |
gn_modules.append(module) | |
# print(module.__class__.__name__,module.gn_weight) | |
for i, module in enumerate(gn_modules): | |
if getattr(module, "original_forward", None) is None: | |
module.original_forward = module.forward | |
if i == 0: | |
# mid_block | |
module.forward = hacked_mid_forward.__get__( | |
module, torch.nn.Module | |
) | |
# elif isinstance(module, CrossAttnDownBlock2D): | |
# module.forward = hack_CrossAttnDownBlock2D_forward.__get__( | |
# module, CrossAttnDownBlock2D | |
# ) | |
# module.mean_bank0 = [] | |
# module.var_bank0 = [] | |
# module.fea_bank0 = [] | |
elif isinstance(module, DownBlock2D): | |
module.forward = hacked_DownBlock2D_forward.__get__( | |
module, DownBlock2D | |
) | |
# elif isinstance(module, CrossAttnUpBlock2D): | |
# module.forward = hacked_CrossAttnUpBlock2D_forward.__get__(module, CrossAttnUpBlock2D) | |
# module.mean_bank0 = [] | |
# module.var_bank0 = [] | |
# module.fea_bank0 = [] | |
elif isinstance(module, UpBlock2D): | |
module.forward = hacked_UpBlock2D_forward.__get__( | |
module, UpBlock2D | |
) | |
module.mean_bank0 = [] | |
module.var_bank0 = [] | |
module.fea_bank0 = [] | |
module.mean_bank = [] | |
module.var_bank = [] | |
module.fea_bank = [] | |
module.attention_auto_machine_weight = ( | |
self.attention_auto_machine_weight | |
) | |
module.gn_auto_machine_weight = self.gn_auto_machine_weight | |
module.do_classifier_free_guidance = ( | |
self.do_classifier_free_guidance | |
) | |
module.do_classifier_free_guidance = ( | |
self.do_classifier_free_guidance | |
) | |
module.uc_mask = self.uc_mask | |
module.style_fidelity = self.style_fidelity | |
module.ref_mask = self.ref_mask | |
module.inpaint_mask = self.inpaint_mask | |
module.ref_scale = self.ref_scale | |
else: | |
gn_modules = None | |
elif model_type == "controlnet": | |
model = model.nets[-1] # only hack the inpainting controlnet | |
if reference_attn: | |
attn_modules = [ | |
module | |
for module in torch_dfs(model) | |
if isinstance(module, BasicTransformerBlock) | |
] | |
attn_modules = sorted( | |
attn_modules, key=lambda x: -x.norm1.normalized_shape[0] | |
) | |
for i, module in enumerate(attn_modules): | |
module._original_inner_forward = module.forward | |
module.forward = hacked_basic_transformer_inner_forward.__get__( | |
module, BasicTransformerBlock | |
) | |
module.bank = [] | |
module.fea_bank = [] | |
# float(i) / float(len(attn_modules)) | |
module.attn_weight = 0.0 | |
module.attention_auto_machine_weight = ( | |
self.attention_auto_machine_weight | |
) | |
module.gn_auto_machine_weight = self.gn_auto_machine_weight | |
module.do_classifier_free_guidance = ( | |
self.do_classifier_free_guidance | |
) | |
module.do_classifier_free_guidance = ( | |
self.do_classifier_free_guidance | |
) | |
module.uc_mask = self.uc_mask | |
module.style_fidelity = self.style_fidelity | |
module.ref_mask = self.ref_mask | |
module.ref_scale = self.ref_scale | |
else: | |
attn_modules = None | |
# gn_modules = None | |
if reference_adain: | |
gn_modules = [model.mid_block] | |
model.mid_block.gn_weight = 0 | |
down_blocks = model.down_blocks | |
for w, module in enumerate(down_blocks): | |
module.gn_weight = 1.0 - float(w) / float(len(down_blocks)) | |
gn_modules.append(module) | |
# print(module.__class__.__name__,module.gn_weight) | |
for i, module in enumerate(gn_modules): | |
if getattr(module, "original_forward", None) is None: | |
module.original_forward = module.forward | |
if i == 0: | |
# mid_block | |
module.forward = hacked_mid_forward.__get__( | |
module, torch.nn.Module | |
) | |
# elif isinstance(module, CrossAttnDownBlock2D): | |
# module.forward = hack_CrossAttnDownBlock2D_forward.__get__( | |
# module, CrossAttnDownBlock2D | |
# ) | |
# module.mean_bank0 = [] | |
# module.var_bank0 = [] | |
# module.fea_bank0 = [] | |
elif isinstance(module, DownBlock2D): | |
module.forward = hacked_DownBlock2D_forward.__get__( | |
module, DownBlock2D | |
) | |
module.mean_bank = [] | |
module.var_bank = [] | |
module.fea_bank = [] | |
module.attention_auto_machine_weight = ( | |
self.attention_auto_machine_weight | |
) | |
module.gn_auto_machine_weight = self.gn_auto_machine_weight | |
module.do_classifier_free_guidance = ( | |
self.do_classifier_free_guidance | |
) | |
module.do_classifier_free_guidance = ( | |
self.do_classifier_free_guidance | |
) | |
module.uc_mask = self.uc_mask | |
module.style_fidelity = self.style_fidelity | |
module.ref_mask = self.ref_mask | |
module.inpaint_mask = self.inpaint_mask | |
module.ref_scale = self.ref_scale | |
else: | |
gn_modules = None | |
return attn_modules, gn_modules | |
    def change_module_mode(self, mode, attn_modules, gn_modules):
        if attn_modules is not None:
            for module in attn_modules:
                module.MODE = mode
        if gn_modules is not None:
            for module in gn_modules:
                module.MODE = mode
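
    # Sketch of how these hooks are driven inside a denoising loop (assumed
    # names; the actual loop lives in the pipeline's __call__):
    #
    #   attn_modules, gn_modules = self.redefine_ref_model(
    #       self.unet, reference_attn=True, reference_adain=True, model_type="unet"
    #   )
    #   for t in timesteps:
    #       # 1) "write": run the noised reference latents to fill the banks
    #       self.change_module_mode("write", attn_modules, gn_modules)
    #       self.unet(ref_latent_input, t, encoder_hidden_states=prompt_embeds)
    #       # 2) "read": denoise the actual latents, consuming the banks
    #       self.change_module_mode("read", attn_modules, gn_modules)
    #       noise_pred = self.unet(
    #           latent_input, t, encoder_hidden_states=prompt_embeds
    #       ).sample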