import functools

import torch
from diffusers.models.attention import BasicTransformerBlock
from diffusers.utils.import_utils import is_xformers_available

from .lora import LoraInjectedLinear

if is_xformers_available():
    import xformers
    import xformers.ops
else:
    xformers = None

@functools.cache
def test_xformers_backwards(size):
    """Return True if xformers memory-efficient attention supports a backward
    pass for head dimension `size` on this device; results are cached per size."""

    @torch.enable_grad()
    def _grad(size):
        q = torch.randn((1, 4, size), device="cuda")
        k = torch.randn((1, 4, size), device="cuda")
        v = torch.randn((1, 4, size), device="cuda")

        q = q.detach().requires_grad_()
        k = k.detach().requires_grad_()
        v = v.detach().requires_grad_()

        out = xformers.ops.memory_efficient_attention(q, k, v)
        loss = out.sum(2).mean(0).sum()

        return torch.autograd.grad(loss, v)

    try:
        _grad(size)
        print(size, "pass")
        return True
    except Exception:
        print(size, "fail")
        return False

def set_use_memory_efficient_attention_xformers(
    module: torch.nn.Module, valid: bool
) -> None:
    def fn_test_dim_head(module: torch.nn.Module):
        if isinstance(module, BasicTransformerBlock):
            # dim_head is not stored on the attention module, so recover it
            # from the value projection: out_features // number of heads.
            source = module.attn1.to_v
            if isinstance(source, LoraInjectedLinear):
                source = source.linear

            dim_head = source.out_features // module.attn1.heads

            result = test_xformers_backwards(dim_head)

            # If the backward pass fails for this head dimension, turn
            # xformers back off for this block only.
            if not result:
                module.set_use_memory_efficient_attention_xformers(False)

        for child in module.children():
            fn_test_dim_head(child)

    if not is_xformers_available() and valid:
        print("XFormers is not available. Skipping.")
        return

    module.set_use_memory_efficient_attention_xformers(valid)

    # When enabling, probe every transformer block and disable xformers on
    # those whose head dimension cannot run a backward pass.
    if valid:
        fn_test_dim_head(module)
|
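# Illustrative usage sketch (not part of this module's API): assumes diffusers'
# UNet2DConditionModel and a CUDA device are available; the model id and
# variable names below are hypothetical placeholders.
#
#   from diffusers import UNet2DConditionModel
#
#   unet = UNet2DConditionModel.from_pretrained(
#       "runwayml/stable-diffusion-v1-5", subfolder="unet"
#   ).to("cuda")
#
#   # Enable xformers attention everywhere, then automatically disable it on
#   # any BasicTransformerBlock whose head dimension fails the backward test.
#   set_use_memory_efficient_attention_xformers(unet, True)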