import json
from dataclasses import asdict, dataclass, field
from typing import Any, Dict, List, Optional


@dataclass
class BaseConfig:
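    """Shared helpers for all config sections: dict-style get/pop and JSON pretty-printing."""
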
    def get(self, attribute_name, default=None):
        """Dict-style read: return the attribute if present, else `default`."""
        return getattr(self, attribute_name, default)

    def pop(self, attribute_name, default=None):
        """Dict-style pop: remove and return the attribute, or return `default`."""
        if hasattr(self, attribute_name):
            value = getattr(self, attribute_name)
            delattr(self, attribute_name)
            return value
        else:
            return default

    def __str__(self):
        return json.dumps(asdict(self), indent=4)


@dataclass
class DataConfig(BaseConfig):
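    """Dataset locations, caption/CLIP-score filtering, and preprocessing options."""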
    data_dir: List[Optional[str]] = field(default_factory=list)
    caption_proportion: Dict[str, int] = field(default_factory=lambda: {"prompt": 1})
    external_caption_suffixes: List[str] = field(default_factory=list)
    external_clipscore_suffixes: List[str] = field(default_factory=list)
    clip_thr_temperature: float = 1.0
    clip_thr: float = 0.0
    sort_dataset: bool = False
    load_text_feat: bool = False
    load_vae_feat: bool = False
    transform: str = "default_train"
    type: str = "SanaWebDatasetMS"
    image_size: int = 512
    hq_only: bool = False
    valid_num: int = 0
    data: Any = None
    extra: Any = None


@dataclass
class ModelConfig(BaseConfig):
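    """Diffusion backbone architecture, precision, and checkpoint load/resume options."""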
    model: str = "SanaMS_600M_P1_D28"
    image_size: int = 512
    mixed_precision: str = "fp16"  # ['fp16', 'fp32', 'bf16']
    fp32_attention: bool = True
    load_from: Optional[str] = None
    resume_from: Optional[Dict[str, Any]] = field(
        default_factory=lambda: {
            "checkpoint": None,
            "load_ema": False,
            "resume_lr_scheduler": True,
            "resume_optimizer": True,
        }
    )
    aspect_ratio_type: str = "ASPECT_RATIO_1024"
    multi_scale: bool = True
    pe_interpolation: float = 1.0
    micro_condition: bool = False
    attn_type: str = "linear"
    autocast_linear_attn: bool = False
    ffn_type: str = "glumbconv"
    mlp_acts: List[Optional[str]] = field(default_factory=lambda: ["silu", "silu", None])
    mlp_ratio: float = 2.5
    use_pe: bool = False
    qk_norm: bool = False
    class_dropout_prob: float = 0.0
    linear_head_dim: int = 32
    cross_norm: bool = False
    cfg_scale: int = 4
    guidance_type: str = "classifier-free"
    pag_applied_layers: List[int] = field(default_factory=lambda: [14])
    extra: Any = None


@dataclass
class AEConfig(BaseConfig):
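    """Autoencoder (DC-AE) settings: pretrained weights, latent dimensions, scaling."""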
    vae_type: str = "dc-ae"
    vae_pretrained: str = "mit-han-lab/dc-ae-f32c32-sana-1.0"
    scale_factor: float = 0.41407
    vae_latent_dim: int = 32
    vae_downsample_rate: int = 32
    sample_posterior: bool = True
    extra: Any = None


@dataclass
class TextEncoderConfig(BaseConfig):
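    """Text encoder choice (Gemma-2-2B-IT by default) and prompt-embedding options."""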
    text_encoder_name: str = "gemma-2-2b-it"
    caption_channels: int = 2304
    y_norm: bool = True
    y_norm_scale_factor: float = 1.0
    model_max_length: int = 300
    chi_prompt: List[Optional[str]] = field(default_factory=lambda: [])
    extra: Any = None


@dataclass
class SchedulerConfig(BaseConfig):
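    """Training-time noise schedule, sampler, and timestep-weighting options."""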
    train_sampling_steps: int = 1000
    predict_v: bool = True
    noise_schedule: str = "linear_flow"
    pred_sigma: bool = False
    learn_sigma: bool = True
    vis_sampler: str = "flow_dpm-solver"
    flow_shift: float = 1.0
    # logit-normal timestep
    weighting_scheme: Optional[str] = "logit_normal"
    logit_mean: float = 0.0
    logit_std: float = 1.0
    extra: Any = None


@dataclass
class TrainingConfig(BaseConfig):
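    """Optimization, logging, checkpointing, and validation settings."""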
    num_workers: int = 4
    seed: int = 43
    train_batch_size: int = 32
    num_epochs: int = 100
    gradient_accumulation_steps: int = 1
    grad_checkpointing: bool = False
    gradient_clip: float = 1.0
    gc_step: int = 1
    optimizer: Dict[str, Any] = field(
        default_factory=lambda: {"eps": 1.0e-10, "lr": 0.0001, "type": "AdamW", "weight_decay": 0.03}
    )
    lr_schedule: str = "constant"
    lr_schedule_args: Dict[str, int] = field(default_factory=lambda: {"num_warmup_steps": 500})
    auto_lr: Dict[str, str] = field(default_factory=lambda: {"rule": "sqrt"})
    ema_rate: float = 0.9999
    eval_batch_size: int = 16
    use_fsdp: bool = False
    use_flash_attn: bool = False
    eval_sampling_steps: int = 250
    lora_rank: int = 4
    log_interval: int = 50
    mask_type: str = "null"
    mask_loss_coef: float = 0.0
    load_mask_index: bool = False
    snr_loss: bool = False
    real_prompt_ratio: float = 1.0
    save_image_epochs: int = 1
    save_model_epochs: int = 1
    save_model_steps: int = 1000000
    visualize: bool = False
    null_embed_root: str = "output/pretrained_models/"
    valid_prompt_embed_root: str = "output/tmp_embed/"
    validation_prompts: List[str] = field(
        default_factory=lambda: [
            "dog",
            "portrait photo of a girl, photograph, highly detailed face, depth of field",
            "Self-portrait oil painting, a beautiful cyborg with golden hair, 8k",
            "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
            "A photo of beautiful mountain with realistic sunset and blue lake, highly detailed, masterpiece",
        ]
    )
    local_save_vis: bool = False
    deterministic_validation: bool = True
    online_metric: bool = False
    eval_metric_step: int = 5000
    online_metric_dir: str = "metric_helper"
    work_dir: str = "/cache/exps/"
    skip_step: int = 0
    loss_type: str = "huber"
    huber_c: float = 0.001
    num_ddim_timesteps: int = 50
    w_max: float = 15.0
    w_min: float = 3.0
    ema_decay: float = 0.95
    debug_nan: bool = False
    extra: Any = None


@dataclass
class SanaConfig(BaseConfig):
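    """Top-level config aggregating all sub-configs plus experiment metadata."""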
    data: DataConfig
    model: ModelConfig
    vae: AEConfig
    text_encoder: TextEncoderConfig
    scheduler: SchedulerConfig
    train: TrainingConfig
    work_dir: str = "output/"
    resume_from: Optional[str] = None
    load_from: Optional[str] = None
    debug: bool = False
    caching: bool = False
    report_to: str = "wandb"
    tracker_project_name: str = "t2i-evit-baseline"
    name: str = "baseline"
    loss_report_name: str = "loss"
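

# --- Illustrative usage sketch (an assumption for demonstration, not part of the
# original module) --- builds a full SanaConfig from the defaults above, overrides
# one nested value, and dumps the whole tree as JSON via BaseConfig.__str__. The
# six sub-configs are required; everything else falls back to declared defaults.
if __name__ == "__main__":
    config = SanaConfig(
        data=DataConfig(data_dir=["data/example"]),  # hypothetical dataset path
        model=ModelConfig(mixed_precision="bf16"),
        vae=AEConfig(),
        text_encoder=TextEncoderConfig(),
        scheduler=SchedulerConfig(),
        train=TrainingConfig(),
    )
    # Dict-style helpers inherited from BaseConfig:
    assert config.model.get("mixed_precision") == "bf16"
    assert config.model.get("missing_key", "fallback") == "fallback"
    print(config)  # pretty-printed JSON of the entire config tree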