File size: 2,228 Bytes
6c4dee3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
from typing import Tuple

import torch.nn as nn

from .clip import FrozenCLIPEmbedder
from .quant import VectorQuantizer2
from .var import VAR
from .vqvae import VQVAE
from .pipeline import TVARPipeline


def build_vae_var(
    # Shared args
    device,
    patch_nums=(1, 2, 3, 4, 5, 6, 8, 10, 13, 16),  # 10 steps by default
    # VQVAE args
    V=4096,
    Cvae=32,
    ch=160,
    share_quant_resi=4,
    # VAR args
    depth=16,
    shared_aln=False,
    attn_l2_norm=True,
    init_adaln=0.5,
    init_adaln_gamma=1e-5,
    init_head=0.02,
    init_std=-1,  # init_std < 0: automated
    text_encoder_path=None,
    text_encoder_2_path=None,
    rope=False,
    rope_theta=100,
    rope_size=None,
    dpr=0,
    use_swiglu_ffn=False,
) -> Tuple[VQVAE, VAR, TVARPipeline]:
    """Build a VQVAE tokenizer, a VAR transformer, and a text-conditioned pipeline.

    The VAR width and head count are derived from ``depth`` (``width = depth * 64``,
    ``heads = depth``). Both models are moved to ``device``; the VAR weights are then
    initialized via ``init_weights``. Two frozen CLIP text encoders are constructed
    from the given checkpoint paths and wrapped together with the models in a
    :class:`TVARPipeline`.

    Args:
        device: torch device the VQVAE and VAR are moved to.
        patch_nums: per-scale patch grid sizes shared by VQVAE and VAR.
        V: VQVAE codebook (vocabulary) size.
        Cvae: VQVAE latent channel count.
        ch: VQVAE base channel width.
        share_quant_resi: VQVAE shared quantization-residual setting.
        depth: VAR transformer depth; also determines width and head count.
        shared_aln: whether VAR uses shared adaptive LayerNorm.
        attn_l2_norm: whether VAR attention uses L2-normalized queries/keys.
        init_adaln, init_adaln_gamma, init_head, init_std: VAR weight-init
            hyperparameters (``init_std < 0`` selects automated std).
        text_encoder_path, text_encoder_2_path: checkpoint paths for the two
            frozen CLIP text encoders.
        rope, rope_theta, rope_size: rotary position embedding settings for VAR.
        dpr: drop-path rate; when > 0 it is rescaled by ``depth / 24``.
        use_swiglu_ffn: whether VAR feed-forward blocks use SwiGLU.

    Returns:
        Tuple of ``(vae_local, var_wo_ddp, pipe)``: the VQVAE, the (non-DDP)
        VAR model, and the assembled TVARPipeline.
    """
    heads = depth
    width = depth * 64
    if dpr > 0:
        # Scale the drop-path rate relative to a depth-24 reference model.
        dpr = dpr * depth / 24

    # Disable built-in initialization for speed.
    # NOTE: this monkey-patches the torch.nn classes globally (process-wide),
    # so any module of these types created afterwards skips reset_parameters.
    for clz in (
        nn.Linear,
        nn.LayerNorm,
        nn.BatchNorm2d,
        nn.SyncBatchNorm,
        nn.Conv1d,
        nn.Conv2d,
        nn.ConvTranspose1d,
        nn.ConvTranspose2d,
    ):
        setattr(clz, "reset_parameters", lambda self: None)

    # build models
    vae_local = VQVAE(
        vocab_size=V,
        z_channels=Cvae,
        ch=ch,
        test_mode=True,
        share_quant_resi=share_quant_resi,
        v_patch_nums=patch_nums,
    ).to(device)
    var_wo_ddp = VAR(
        depth=depth,
        embed_dim=width,
        num_heads=heads,
        drop_rate=0.0,
        attn_drop_rate=0.0,
        drop_path_rate=dpr,
        norm_eps=1e-6,
        shared_aln=shared_aln,
        attn_l2_norm=attn_l2_norm,
        patch_nums=patch_nums,
        rope=rope,
        rope_theta=rope_theta,
        rope_size=rope_size,
        use_swiglu_ffn=use_swiglu_ffn,
    ).to(device)
    var_wo_ddp.init_weights(
        init_adaln=init_adaln,
        init_adaln_gamma=init_adaln_gamma,
        init_head=init_head,
        init_std=init_std,
    )
    text_encoder = FrozenCLIPEmbedder(text_encoder_path)
    text_encoder_2 = FrozenCLIPEmbedder(text_encoder_2_path)
    pipe = TVARPipeline(var_wo_ddp, vae_local, text_encoder, text_encoder_2, device)

    return vae_local, var_wo_ddp, pipe