# *************************************************************************
# This file may have been modified by Bytedance Inc. (“Bytedance Inc.'s Mo-
# difications”). All Bytedance Inc.'s Modifications are Copyright (2023) B-
# ytedance Inc..
# *************************************************************************
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, repeat
class AttentionBase:
    def __init__(self):
        self.cur_step = 0
        self.num_att_layers = -1
        self.cur_att_layer = 0

    def after_step(self):
        pass

    def __call__(self, q, k, v, is_cross, place_in_unet, num_heads, **kwargs):
        out = self.forward(q, k, v, is_cross, place_in_unet, num_heads, **kwargs)
        self.cur_att_layer += 1
        if self.cur_att_layer == self.num_att_layers:
            self.cur_att_layer = 0
            self.cur_step += 1
            # one full denoising step has finished: reset the layer counter and run the after-step hook
            self.after_step()
        return out

    def forward(self, q, k, v, is_cross, place_in_unet, num_heads, **kwargs):
        out = F.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False)
        out = rearrange(out, 'b h n d -> b n (h d)')
        return out

    def reset(self):
        self.cur_step = 0
        self.cur_att_layer = 0
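
# Illustrative sketch (not part of the original file): a minimal custom editor built on
# AttentionBase, showing the interface that register_attention_editor_diffusers expects.
# Only forward() needs to be overridden; the step/layer bookkeeping is inherited from the
# base class. The class name StoreAttentionProbe and its attribute names are assumptions
# made for this example.
class StoreAttentionProbe(AttentionBase):
    """Example editor that records the shape of every self-attention call."""
    def __init__(self):
        super().__init__()
        self.self_attn_shapes = []

    def forward(self, q, k, v, is_cross, place_in_unet, num_heads, **kwargs):
        if not is_cross:
            # q has shape (batch, heads, tokens, head_dim) after the rearrange
            # performed in the overridden processor forward functions below
            self.self_attn_shapes.append((self.cur_step, place_in_unet, tuple(q.shape)))
        return super().forward(q, k, v, is_cross, place_in_unet, num_heads, **kwargs)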
class MutualSelfAttentionControl(AttentionBase):
    def __init__(self, start_step=4, start_layer=10, layer_idx=None, step_idx=None, total_steps=50, guidance_scale=7.5):
        """
        Mutual self-attention control for the Stable Diffusion model
        Args:
            start_step: the step at which to start mutual self-attention control
            start_layer: the layer at which to start mutual self-attention control
            layer_idx: list of layers to apply mutual self-attention control to
            step_idx: list of steps to apply mutual self-attention control to
            total_steps: the total number of denoising steps
            guidance_scale: classifier-free guidance scale; a value > 1.0 means the batch contains an unconditional branch
        """
        super().__init__()
        self.total_steps = total_steps
        self.start_step = start_step
        self.start_layer = start_layer
        self.layer_idx = layer_idx if layer_idx is not None else list(range(start_layer, 16))
        self.step_idx = step_idx if step_idx is not None else list(range(start_step, total_steps))
        # store the guidance scale to decide whether an unconditional branch is present
        self.guidance_scale = guidance_scale
        print("step_idx: ", self.step_idx)
        print("layer_idx: ", self.layer_idx)
    def forward(self, q, k, v, is_cross, place_in_unet, num_heads, **kwargs):
        """
        Attention forward function
        """
        if is_cross or self.cur_step not in self.step_idx or self.cur_att_layer // 2 not in self.layer_idx:
            return super().forward(q, k, v, is_cross, place_in_unet, num_heads, **kwargs)

        if self.guidance_scale > 1.0:
            # split the batch into its unconditional (qu, ku, vu) and conditional (qc, kc, vc) halves
            qu, qc = q[0:2], q[2:4]
            ku, kc = k[0:2], k[2:4]
            vu, vc = v[0:2], v[2:4]

            # concatenate the source- and target-branch queries along the token dimension so a
            # single scaled_dot_product_attention call can attend to the source keys/values
            qu = torch.cat([qu[0:1], qu[1:2]], dim=2)
            qc = torch.cat([qc[0:1], qc[1:2]], dim=2)

            out_u = F.scaled_dot_product_attention(qu, ku[0:1], vu[0:1], attn_mask=None, dropout_p=0.0, is_causal=False)
            out_u = torch.cat(out_u.chunk(2, dim=2), dim=0)  # split the output back into source and target batches
            out_u = rearrange(out_u, 'b h n d -> b n (h d)')

            out_c = F.scaled_dot_product_attention(qc, kc[0:1], vc[0:1], attn_mask=None, dropout_p=0.0, is_causal=False)
            out_c = torch.cat(out_c.chunk(2, dim=2), dim=0)  # split the output back into source and target batches
            out_c = rearrange(out_c, 'b h n d -> b n (h d)')

            out = torch.cat([out_u, out_c], dim=0)
        else:
            # no unconditional branch: the batch holds only the source and target samples
            q = torch.cat([q[0:1], q[1:2]], dim=2)
            out = F.scaled_dot_product_attention(q, k[0:1], v[0:1], attn_mask=None, dropout_p=0.0, is_causal=False)
            out = torch.cat(out.chunk(2, dim=2), dim=0)  # split the output back into source and target batches
            out = rearrange(out, 'b h n d -> b n (h d)')
        return out
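
# Illustrative usage (not part of the original file): construct the controller with its defaults.
# Mutual self-attention then kicks in from denoising step start_step onward and only in
# attention blocks whose index (cur_att_layer // 2) is at least start_layer; the numbers
# below simply restate the default arguments above.
# editor = MutualSelfAttentionControl(start_step=4, start_layer=10,
#                                     total_steps=50, guidance_scale=7.5)
# editor.reset()  # clear the step/layer counters before every new sampling run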
# forward function for the default attention processor
# modified from the __call__ function of AttnProcessor in diffusers
def override_attn_proc_forward(attn, editor, place_in_unet):
    def forward(x, encoder_hidden_states=None, attention_mask=None, context=None, mask=None):
        """
        The attention computation is the same as in the original LDM CrossAttention class,
        except that the attention output is produced by the registered editor
        """
        if encoder_hidden_states is not None:
            context = encoder_hidden_states
        if attention_mask is not None:
            mask = attention_mask

        to_out = attn.to_out
        if isinstance(to_out, nn.ModuleList):
            # to_out is [Linear, Dropout]; keep only the linear projection
            to_out = attn.to_out[0]
        else:
            to_out = attn.to_out

        h = attn.heads
        q = attn.to_q(x)
        is_cross = context is not None
        context = context if is_cross else x
        k = attn.to_k(context)
        v = attn.to_v(context)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=h), (q, k, v))

        # the only difference from the stock processor: delegate the attention computation to the editor
        out = editor(
            q, k, v, is_cross, place_in_unet,
            attn.heads, scale=attn.scale)

        return to_out(out)

    return forward
# forward function for the LoRA attention processor
# modified from the __call__ function of LoRAAttnProcessor2_0 in diffusers v0.17.1
def override_lora_attn_proc_forward(attn, editor, place_in_unet):
    def forward(hidden_states, encoder_hidden_states=None, attention_mask=None, lora_scale=1.0):
        residual = hidden_states
        input_ndim = hidden_states.ndim
        is_cross = encoder_hidden_states is not None

        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )

        if attention_mask is not None:
            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
            # scaled_dot_product_attention expects attention_mask shape to be
            # (batch, heads, source_length, target_length)
            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states) + lora_scale * attn.processor.to_q_lora(hidden_states)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states) + lora_scale * attn.processor.to_k_lora(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states) + lora_scale * attn.processor.to_v_lora(encoder_hidden_states)

        query, key, value = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=attn.heads), (query, key, value))

        # the only difference
        hidden_states = editor(
            query, key, value, is_cross, place_in_unet,
            attn.heads, scale=attn.scale)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states) + lora_scale * attn.processor.to_out_lora(hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states

    return forward
def register_attention_editor_diffusers(model, editor: AttentionBase, attn_processor='attn_proc'):
    """
    Register an attention editor to a diffusers pipeline; adapted from [Prompt-to-Prompt]
    """
    def register_editor(net, count, place_in_unet):
        for name, subnet in net.named_children():
            if net.__class__.__name__ == 'Attention':  # spatial Transformer layer
                if attn_processor == 'attn_proc':
                    net.forward = override_attn_proc_forward(net, editor, place_in_unet)
                elif attn_processor == 'lora_attn_proc':
                    net.forward = override_lora_attn_proc_forward(net, editor, place_in_unet)
                else:
                    raise NotImplementedError("not implemented")
                return count + 1
            elif hasattr(net, 'children'):
                count = register_editor(subnet, count, place_in_unet)
        return count

    cross_att_count = 0
    for net_name, net in model.unet.named_children():
        if "down" in net_name:
            cross_att_count += register_editor(net, 0, "down")
        elif "mid" in net_name:
            cross_att_count += register_editor(net, 0, "mid")
        elif "up" in net_name:
            cross_att_count += register_editor(net, 0, "up")
    editor.num_att_layers = cross_att_count
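
# Hedged usage sketch (not part of the original file): how an editor is typically wired
# into a diffusers pipeline. It assumes a StableDiffusionPipeline-style object whose
# .unet attribute holds the U-Net that register_attention_editor_diffusers walks; the
# model id below is an illustrative assumption.
if __name__ == "__main__":
    from diffusers import StableDiffusionPipeline

    pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
    editor = MutualSelfAttentionControl(start_step=4, start_layer=10,
                                        total_steps=50, guidance_scale=7.5)
    register_attention_editor_diffusers(pipe, editor, attn_processor='attn_proc')

    # After registration, editor.num_att_layers holds the number of Attention modules
    # whose forward() was overridden; call editor.reset() before every new sampling run
    # so the step and layer counters start from zero.
    editor.reset()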