Spaces:
Runtime error
Runtime error
# ************************************************************************* | |
# This file may have been modified by Bytedance Inc. (“Bytedance Inc.'s Mo- | |
# difications”). All Bytedance Inc.'s Modifications are Copyright (2023) B- | |
# ytedance Inc.. | |
# ************************************************************************* | |
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
from einops import rearrange, repeat | |
class AttentionBase: | |
def __init__(self): | |
self.cur_step = 0 | |
self.num_att_layers = -1 | |
self.cur_att_layer = 0 | |
def after_step(self): | |
pass | |
def __call__(self, q, k, v, is_cross, place_in_unet, num_heads, **kwargs): | |
out = self.forward(q, k, v, is_cross, place_in_unet, num_heads, **kwargs) | |
self.cur_att_layer += 1 | |
if self.cur_att_layer == self.num_att_layers: | |
self.cur_att_layer = 0 | |
self.cur_step += 1 | |
# after step | |
self.after_step() | |
return out | |
def forward(self, q, k, v, is_cross, place_in_unet, num_heads, **kwargs): | |
out = F.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False) | |
out = rearrange(out, 'b h n d -> b n (h d)') | |
return out | |
def reset(self): | |
self.cur_step = 0 | |
self.cur_att_layer = 0 | |
class MutualSelfAttentionControl(AttentionBase): | |
def __init__(self, start_step=4, start_layer=10, layer_idx=None, step_idx=None, total_steps=50, guidance_scale=7.5): | |
""" | |
Mutual self-attention control for Stable-Diffusion model | |
Args: | |
start_step: the step to start mutual self-attention control | |
start_layer: the layer to start mutual self-attention control | |
layer_idx: list of the layers to apply mutual self-attention control | |
step_idx: list the steps to apply mutual self-attention control | |
total_steps: the total number of steps | |
""" | |
super().__init__() | |
self.total_steps = total_steps | |
self.start_step = start_step | |
self.start_layer = start_layer | |
self.layer_idx = layer_idx if layer_idx is not None else list(range(start_layer, 16)) | |
self.step_idx = step_idx if step_idx is not None else list(range(start_step, total_steps)) | |
# store the guidance scale to decide whether there are unconditional branch | |
self.guidance_scale = guidance_scale | |
print("step_idx: ", self.step_idx) | |
print("layer_idx: ", self.layer_idx) | |
def forward(self, q, k, v, is_cross, place_in_unet, num_heads, **kwargs): | |
""" | |
Attention forward function | |
""" | |
if is_cross or self.cur_step not in self.step_idx or self.cur_att_layer // 2 not in self.layer_idx: | |
return super().forward(q, k, v, is_cross, place_in_unet, num_heads, **kwargs) | |
if self.guidance_scale > 1.0: | |
qu, qc = q[0:2], q[2:4] | |
ku, kc = k[0:2], k[2:4] | |
vu, vc = v[0:2], v[2:4] | |
# merge queries of source and target branch into one so we can use torch API | |
qu = torch.cat([qu[0:1], qu[1:2]], dim=2) | |
qc = torch.cat([qc[0:1], qc[1:2]], dim=2) | |
out_u = F.scaled_dot_product_attention(qu, ku[0:1], vu[0:1], attn_mask=None, dropout_p=0.0, is_causal=False) | |
out_u = torch.cat(out_u.chunk(2, dim=2), dim=0) # split the queries into source and target batch | |
out_u = rearrange(out_u, 'b h n d -> b n (h d)') | |
out_c = F.scaled_dot_product_attention(qc, kc[0:1], vc[0:1], attn_mask=None, dropout_p=0.0, is_causal=False) | |
out_c = torch.cat(out_c.chunk(2, dim=2), dim=0) # split the queries into source and target batch | |
out_c = rearrange(out_c, 'b h n d -> b n (h d)') | |
out = torch.cat([out_u, out_c], dim=0) | |
else: | |
q = torch.cat([q[0:1], q[1:2]], dim=2) | |
out = F.scaled_dot_product_attention(q, k[0:1], v[0:1], attn_mask=None, dropout_p=0.0, is_causal=False) | |
out = torch.cat(out.chunk(2, dim=2), dim=0) # split the queries into source and target batch | |
out = rearrange(out, 'b h n d -> b n (h d)') | |
return out | |
# forward function for default attention processor | |
# modified from __call__ function of AttnProcessor in diffusers | |
def override_attn_proc_forward(attn, editor, place_in_unet): | |
def forward(x, encoder_hidden_states=None, attention_mask=None, context=None, mask=None): | |
""" | |
The attention is similar to the original implementation of LDM CrossAttention class | |
except adding some modifications on the attention | |
""" | |
if encoder_hidden_states is not None: | |
context = encoder_hidden_states | |
if attention_mask is not None: | |
mask = attention_mask | |
to_out = attn.to_out | |
if isinstance(to_out, nn.modules.container.ModuleList): | |
to_out = attn.to_out[0] | |
else: | |
to_out = attn.to_out | |
h = attn.heads | |
q = attn.to_q(x) | |
is_cross = context is not None | |
context = context if is_cross else x | |
k = attn.to_k(context) | |
v = attn.to_v(context) | |
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=h), (q, k, v)) | |
# the only difference | |
out = editor( | |
q, k, v, is_cross, place_in_unet, | |
attn.heads, scale=attn.scale) | |
return to_out(out) | |
return forward | |
# forward function for lora attention processor | |
# modified from __call__ function of LoRAAttnProcessor2_0 in diffusers v0.17.1 | |
def override_lora_attn_proc_forward(attn, editor, place_in_unet): | |
def forward(hidden_states, encoder_hidden_states=None, attention_mask=None, lora_scale=1.0): | |
residual = hidden_states | |
input_ndim = hidden_states.ndim | |
is_cross = encoder_hidden_states is not None | |
if input_ndim == 4: | |
batch_size, channel, height, width = hidden_states.shape | |
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) | |
batch_size, sequence_length, _ = ( | |
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape | |
) | |
if attention_mask is not None: | |
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) | |
# scaled_dot_product_attention expects attention_mask shape to be | |
# (batch, heads, source_length, target_length) | |
attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) | |
if attn.group_norm is not None: | |
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) | |
query = attn.to_q(hidden_states) + lora_scale * attn.processor.to_q_lora(hidden_states) | |
if encoder_hidden_states is None: | |
encoder_hidden_states = hidden_states | |
elif attn.norm_cross: | |
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) | |
key = attn.to_k(encoder_hidden_states) + lora_scale * attn.processor.to_k_lora(encoder_hidden_states) | |
value = attn.to_v(encoder_hidden_states) + lora_scale * attn.processor.to_v_lora(encoder_hidden_states) | |
query, key, value = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=attn.heads), (query, key, value)) | |
# the only difference | |
hidden_states = editor( | |
query, key, value, is_cross, place_in_unet, | |
attn.heads, scale=attn.scale) | |
# linear proj | |
hidden_states = attn.to_out[0](hidden_states) + lora_scale * attn.processor.to_out_lora(hidden_states) | |
# dropout | |
hidden_states = attn.to_out[1](hidden_states) | |
if input_ndim == 4: | |
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) | |
if attn.residual_connection: | |
hidden_states = hidden_states + residual | |
hidden_states = hidden_states / attn.rescale_output_factor | |
return hidden_states | |
return forward | |
def register_attention_editor_diffusers(model, editor: AttentionBase, attn_processor='attn_proc'): | |
""" | |
Register a attention editor to Diffuser Pipeline, refer from [Prompt-to-Prompt] | |
""" | |
def register_editor(net, count, place_in_unet): | |
for name, subnet in net.named_children(): | |
if net.__class__.__name__ == 'Attention': # spatial Transformer layer | |
if attn_processor == 'attn_proc': | |
net.forward = override_attn_proc_forward(net, editor, place_in_unet) | |
elif attn_processor == 'lora_attn_proc': | |
net.forward = override_lora_attn_proc_forward(net, editor, place_in_unet) | |
else: | |
raise NotImplementedError("not implemented") | |
return count + 1 | |
elif hasattr(net, 'children'): | |
count = register_editor(subnet, count, place_in_unet) | |
return count | |
cross_att_count = 0 | |
for net_name, net in model.unet.named_children(): | |
if "down" in net_name: | |
cross_att_count += register_editor(net, 0, "down") | |
elif "mid" in net_name: | |
cross_att_count += register_editor(net, 0, "mid") | |
elif "up" in net_name: | |
cross_att_count += register_editor(net, 0, "up") | |
editor.num_att_layers = cross_att_count | |