|
import json |
|
from dataclasses import asdict, dataclass, field |
|
from typing import Any, Dict, List, Optional |
|
|
|
|
|
@dataclass
class BaseConfig:
    """Base class for config dataclasses with dict-like attribute helpers.

    Provides ``get``/``pop`` mirroring ``dict`` semantics over instance
    attributes, and a pretty-printed JSON ``__str__`` for logging.
    """

    def get(self, attribute_name, default=None):
        """Return the value of ``attribute_name``, or ``default`` if absent."""
        return getattr(self, attribute_name, default)

    def pop(self, attribute_name, default=None):
        """Remove ``attribute_name`` and return its value, or ``default`` if absent.

        Raises:
            AttributeError: if the name resolves to an attribute that cannot
                be deleted from the instance (e.g. a class-level attribute),
                matching the previous hasattr/delattr behavior.
        """
        # EAFP: one lookup instead of hasattr() + getattr(), and no
        # check-then-act race between the existence test and the read.
        try:
            value = getattr(self, attribute_name)
        except AttributeError:
            return default
        delattr(self, attribute_name)
        return value

    def __str__(self):
        """Return the dataclass fields serialized as indented JSON."""
        return json.dumps(asdict(self), indent=4)
|
|
|
|
|
@dataclass
class DataConfig(BaseConfig):
    """Dataset selection, filtering, and loading options."""

    # Root directories of the dataset(s) to train on.
    data_dir: List[Optional[str]] = field(default_factory=list)
    # Relative sampling weight per caption source — TODO confirm units (ratio vs count).
    caption_proportion: Dict[str, int] = field(default_factory=lambda: {"prompt": 1})
    # File suffixes for externally generated captions / CLIP-score files.
    external_caption_suffixes: List[str] = field(default_factory=list)
    external_clipscore_suffixes: List[str] = field(default_factory=list)
    # Temperature and threshold applied when filtering by CLIP score —
    # NOTE(review): exact filtering rule lives in the dataset code; verify there.
    clip_thr_temperature: float = 1.0
    clip_thr: float = 0.0
    sort_dataset: bool = False
    # Load precomputed text / VAE features instead of raw inputs.
    load_text_feat: bool = False
    load_vae_feat: bool = False
    # Name of the image transform pipeline to apply.
    transform: str = "default_train"
    # Dataset class identifier.
    type: str = "SanaWebDatasetMS"
    image_size: int = 512
    hq_only: bool = False
    # Number of samples reserved for validation.
    valid_num: int = 0
    # Free-form payloads for dataset-specific extensions.
    data: Any = None
    extra: Any = None
|
|
|
|
|
@dataclass
class ModelConfig(BaseConfig):
    """Diffusion-transformer architecture and checkpoint options."""

    # Model architecture identifier.
    model: str = "SanaMS_600M_P1_D28"
    image_size: int = 512
    # Autocast precision for training ("fp16" here); attention can stay fp32.
    mixed_precision: str = "fp16"
    fp32_attention: bool = True
    # Checkpoint to initialize weights from (no optimizer state).
    load_from: Optional[str] = None
    # Full training-state resume: checkpoint path plus which pieces to restore.
    resume_from: Optional[Dict[str, Any]] = field(
        default_factory=lambda: {
            "checkpoint": None,
            "load_ema": False,
            "resume_lr_scheduler": True,
            "resume_optimizer": True,
        }
    )
    # Aspect-ratio bucket table name used for multi-scale training.
    aspect_ratio_type: str = "ASPECT_RATIO_1024"
    multi_scale: bool = True
    # Positional-embedding interpolation factor.
    pe_interpolation: float = 1.0
    micro_condition: bool = False
    # Attention / FFN variants — consumed by the model builder.
    attn_type: str = "linear"
    autocast_linear_attn: bool = False
    ffn_type: str = "glumbconv"
    # Activations for the MLP stages; None disables the final activation.
    mlp_acts: List[Optional[str]] = field(default_factory=lambda: ["silu", "silu", None])
    mlp_ratio: float = 2.5
    use_pe: bool = False
    qk_norm: bool = False
    # Probability of dropping conditioning (for classifier-free guidance training).
    class_dropout_prob: float = 0.0
    linear_head_dim: int = 32
    cross_norm: bool = False
    # Guidance settings used at sampling time.
    cfg_scale: int = 4
    guidance_type: str = "classifier-free"
    # Layers where PAG (perturbed-attention guidance) is applied — TODO confirm.
    pag_applied_layers: List[int] = field(default_factory=lambda: [14])
    extra: Any = None
|
|
|
|
|
@dataclass
class AEConfig(BaseConfig):
    """Autoencoder (VAE / DC-AE) settings for latent-space training."""

    # Autoencoder family identifier.
    vae_type: str = "dc-ae"
    # Pretrained weights (Hugging Face repo id).
    vae_pretrained: str = "mit-han-lab/dc-ae-f32c32-sana-1.0"
    # Latent scaling factor applied when encoding/decoding.
    scale_factor: float = 0.41407
    # Latent channel count and spatial downsample rate of the encoder.
    vae_latent_dim: int = 32
    vae_downsample_rate: int = 32
    # Sample from the posterior instead of using its mean.
    sample_posterior: bool = True
    extra: Any = None
|
|
|
|
|
@dataclass
class TextEncoderConfig(BaseConfig):
    """Text-encoder (caption embedding) settings."""

    # Text encoder identifier.
    text_encoder_name: str = "gemma-2-2b-it"
    # Dimensionality of the caption embeddings fed to the model.
    caption_channels: int = 2304
    # Normalize text embeddings, scaled by y_norm_scale_factor.
    y_norm: bool = True
    y_norm_scale_factor: float = 1.0
    # Maximum tokenized prompt length.
    model_max_length: int = 300
    # Instruction lines prepended to prompts — TODO confirm exact usage downstream.
    # Idiom fix: `list` replaces the redundant `lambda: []` factory (identical behavior).
    chi_prompt: List[Optional[str]] = field(default_factory=list)
    extra: Any = None
|
|
|
|
|
@dataclass
class SchedulerConfig(BaseConfig):
    """Diffusion / flow-matching schedule settings."""

    # Number of discretized timesteps used during training.
    train_sampling_steps: int = 1000
    # Predict velocity (v-prediction) rather than noise.
    predict_v: bool = True
    # Noise schedule identifier.
    noise_schedule: str = "linear_flow"
    pred_sigma: bool = False
    learn_sigma: bool = True
    # Sampler used for visualization during training.
    vis_sampler: str = "flow_dpm-solver"
    # Timestep shift for flow matching.
    flow_shift: float = 1.0

    # Timestep-weighting distribution and its parameters —
    # "logit_normal" draws t from a logit-normal(logit_mean, logit_std).
    weighting_scheme: Optional[str] = "logit_normal"
    logit_mean: float = 0.0
    logit_std: float = 1.0
    extra: Any = None
|
|
|
|
|
@dataclass
class TrainingConfig(BaseConfig):
    """Training-loop, optimizer, logging, and evaluation settings."""

    # --- data loading / reproducibility ---
    num_workers: int = 4
    seed: int = 43
    # Per-device batch size — NOTE(review): confirm per-device vs global.
    train_batch_size: int = 32
    num_epochs: int = 100
    # --- gradients ---
    gradient_accumulation_steps: int = 1
    grad_checkpointing: bool = False
    gradient_clip: float = 1.0
    # Steps between explicit garbage collections — TODO confirm semantics.
    gc_step: int = 1
    # --- optimizer / LR schedule ---
    optimizer: Dict[str, Any] = field(
        default_factory=lambda: {"eps": 1.0e-10, "lr": 0.0001, "type": "AdamW", "weight_decay": 0.03}
    )
    lr_schedule: str = "constant"
    lr_schedule_args: Dict[str, int] = field(default_factory=lambda: {"num_warmup_steps": 500})
    # Automatic LR scaling rule applied with batch-size changes.
    auto_lr: Dict[str, str] = field(default_factory=lambda: {"rule": "sqrt"})
    # --- EMA / evaluation ---
    ema_rate: float = 0.9999
    eval_batch_size: int = 16
    use_fsdp: bool = False
    use_flash_attn: bool = False
    eval_sampling_steps: int = 250
    lora_rank: int = 4
    log_interval: int = 50
    # --- loss options ---
    mask_type: str = "null"
    mask_loss_coef: float = 0.0
    load_mask_index: bool = False
    snr_loss: bool = False
    # Fraction of real (non-synthetic) prompts — TODO confirm against dataset code.
    real_prompt_ratio: float = 1.0
    # --- checkpointing / visualization ---
    save_image_epochs: int = 1
    save_model_epochs: int = 1
    save_model_steps: int = 1000000
    visualize: bool = False
    null_embed_root: str = "output/pretrained_models/"
    valid_prompt_embed_root: str = "output/tmp_embed/"
    # Prompts rendered periodically to monitor sample quality.
    validation_prompts: List[str] = field(
        default_factory=lambda: [
            "dog",
            "portrait photo of a girl, photograph, highly detailed face, depth of field",
            "Self-portrait oil painting, a beautiful cyborg with golden hair, 8k",
            "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
            "A photo of beautiful mountain with realistic sunset and blue lake, highly detailed, masterpiece",
        ]
    )
    local_save_vis: bool = False
    deterministic_validation: bool = True
    # --- online metrics ---
    online_metric: bool = False
    eval_metric_step: int = 5000
    online_metric_dir: str = "metric_helper"
    # --- misc ---
    work_dir: str = "/cache/exps/"
    skip_step: int = 0
    # Loss function and its Huber transition constant.
    loss_type: str = "huber"
    huber_c: float = 0.001
    num_ddim_timesteps: int = 50
    # Guidance-weight sampling range — TODO confirm consumer (distillation?).
    w_max: float = 15.0
    w_min: float = 3.0
    ema_decay: float = 0.95
    # Extra NaN diagnostics when enabled.
    debug_nan: bool = False
    extra: Any = None
|
|
|
|
|
@dataclass
class SanaConfig(BaseConfig):
    """Top-level experiment config aggregating all sub-configs.

    The six sub-configs are required (no defaults); the remaining fields
    control run identity, resumption, and experiment tracking.
    """

    data: DataConfig
    model: ModelConfig
    vae: AEConfig
    text_encoder: TextEncoderConfig
    scheduler: SchedulerConfig
    train: TrainingConfig
    # Output directory for this run.
    work_dir: str = "output/"
    # Checkpoint paths — resume_from restores full training state,
    # load_from only initializes weights.
    resume_from: Optional[str] = None
    load_from: Optional[str] = None
    debug: bool = False
    # Run in feature-caching mode — TODO confirm what is cached.
    caching: bool = False
    # Experiment-tracking backend and identifiers.
    report_to: str = "wandb"
    tracker_project_name: str = "t2i-evit-baseline"
    name: str = "baseline"
    loss_report_name: str = "loss"
|
|