|
import torch |
|
|
|
from einops import rearrange |
|
from dataclasses import dataclass |
|
|
|
T = torch.Tensor |
|
|
|
@dataclass(frozen=True)
class StyleAlignedArgs:
    """Configuration flags for shared-attention ("style-aligned") generation.

    Controls which normalization layers share statistics across the batch
    and which attention tensors (queries/keys/values) get AdaIN applied.
    """

    share_group_norm: bool = True
    # BUG FIX: the original had a stray trailing comma here, which made the
    # default value the tuple (True,) instead of the bool True.
    share_layer_norm: bool = True
    share_attention: bool = True
    adain_queries: bool = True
    adain_keys: bool = True
    adain_values: bool = False
    full_attention_share: bool = False
    keys_scale: float = 1.       # multiplier applied to shared keys
    only_self_level: float = 0.  # threshold below which only self-attention is used
|
|
|
def expand_first(feat: T, scale=1., ) -> T:
    """Broadcast the two reference items of a batch over their half-batches.

    The batch is assumed to hold two halves of size ``b // 2`` (e.g. an
    unconditional and a conditional half); ``feat[0]`` and ``feat[b // 2]``
    are the style-reference items.  Each half of the output is filled with
    its half's reference, so the result has the same shape as ``feat``.

    Args:
        feat: tensor whose leading dimension is an even batch size ``b``.
        scale: multiplier applied to every copy except the first one in
            each half (damps the influence of the shared reference).

    Returns:
        Tensor of the same shape as ``feat``.
    """
    b = feat.shape[0]
    # (2, 1, *rest): the two reference items, one per half-batch.
    feat_style = torch.stack((feat[0], feat[b // 2])).unsqueeze(1)
    if scale == 1:
        # Pure broadcast: expand creates a view, no copy needed.
        feat_style = feat_style.expand(2, b // 2, *feat.shape[1:])
    else:
        # Materialize real copies so scaled entries can differ from the first.
        # Generalized from a hard-coded 4-d repeat(1, b//2, 1, 1, 1) to any
        # feature rank; identical behavior for the original 4-d inputs.
        feat_style = feat_style.repeat(1, b // 2, *([1] * (feat.dim() - 1)))
        feat_style = torch.cat([feat_style[:, :1], scale * feat_style[:, 1:]], dim=1)
    return feat_style.reshape(*feat.shape)
|
|
|
|
|
def concat_first(feat: T, dim=2, scale=1.) -> T:
    """Concatenate the broadcast style-reference features onto ``feat``.

    ``expand_first`` builds a same-shaped tensor holding each half-batch's
    reference item (optionally scaled by ``scale``); the result is ``feat``
    with that tensor appended along axis ``dim``.
    """
    reference = expand_first(feat, scale=scale)
    return torch.cat((feat, reference), dim=dim)
|
|
|
|
|
def calc_mean_std(feat, eps: float = 1e-5) -> tuple[T, T]:
    """Per-sample mean and std over the token axis (dim -2), dims kept.

    ``eps`` is added to the (unbiased) variance before the square root for
    numerical stability.  Returns ``(mean, std)``, each shaped like ``feat``
    with the -2 axis reduced to size 1.
    """
    mean = feat.mean(dim=-2, keepdim=True)
    std = torch.sqrt(feat.var(dim=-2, keepdim=True) + eps)
    return mean, std
|
|
|
|
|
def adain(feat: T) -> T:
    """AdaIN: renormalize each sample to its half-batch reference statistics.

    Every sample is standardized with its own mean/std over the token axis,
    then rescaled using the mean/std of the style-reference item that
    ``expand_first`` broadcasts over its half of the batch.
    """
    mean, std = calc_mean_std(feat)
    ref_mean = expand_first(mean)
    ref_std = expand_first(std)
    normalized = (feat - mean) / std
    return normalized * ref_std + ref_mean
|
|
|
def swapping_attention(key, value, chunk_size=2):
    """Replace every frame's key/value with those of its chunk's first frame.

    The batch axis is viewed as ``chunk_size`` chunks of equal length; within
    each chunk every entry is overwritten by the chunk's entry 0 (the
    reference frame), so attention is computed against the reference only.

    Args:
        key: tensor of shape (chunk_size * chunk_length, d, c).
        value: tensor of the same leading shape as ``key``.
        chunk_size: number of chunks the batch axis is split into.

    Returns:
        ``(key, value)`` with the reference frame broadcast over each chunk.
    """
    chunk_length = key.size()[0] // chunk_size

    def _broadcast_reference(tensor):
        # (chunk_size * chunk_length, ...) -> chunks, pick frame 0 of each
        # chunk for every position, flatten back.  This replaces the original
        # einops.rearrange round-trip with equivalent native torch ops,
        # dropping the runtime einops dependency.
        chunks = tensor.reshape(chunk_size, chunk_length, *tensor.shape[1:])
        picked = chunks[:, [0] * chunk_length]
        return picked.reshape(-1, *tensor.shape[1:])

    return _broadcast_reference(key), _broadcast_reference(value)