File size: 5,442 Bytes
d0ffe9c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
import json
import logging
from os import PathLike
from pathlib import Path
from typing import Any, Dict, Optional, Tuple, Union

from pydantic.v1 import BaseConfig, BaseSettings, Field
from pydantic.env_settings import (EnvSettingsSource, InitSettingsSource,
                                   SecretsSettingsSource,
                                   SettingsSourceCallable)

from animatediff import get_dir
from animatediff.schedulers import DiffusionScheduler

# Configure the root logger when this module is imported. Per the stdlib
# logging documentation, basicConfig() is a no-op if the root logger already
# has handlers, so an application that configured logging earlier wins.
logging.basicConfig(level=logging.INFO)
# Module-level logger used by the settings source and config classes below.
# NOTE(review): those classes log at DEBUG, which the INFO level set above
# would filter out unless logging is reconfigured elsewhere — confirm intent.
logger = logging.getLogger(__name__)

# Presumably the file suffixes treated as model checkpoints; nothing in the
# visible part of this module reads it, so it is likely consumed by importers.
CKPT_EXTENSIONS = [".pt", ".ckpt", ".pth", ".safetensors"]


class JsonSettingsSource:
    """Callable settings source that reads values from JSON files.

    An instance is called with the settings object being built and returns a
    plain dict of values. When several files are configured they are merged in
    order with ``dict.update``: the merge is shallow, and a key from a later
    file replaces the same key from an earlier one.
    """

    __slots__ = ["json_config_path"]

    def __init__(
        self,
        json_config_path: Optional[Union[PathLike, list[PathLike]]] = None,
    ) -> None:
        """Normalise ``json_config_path`` into a list of ``Path`` objects.

        Args:
            json_config_path: ``None`` (no files), a single path given as a
                string or ``os.PathLike``, or an iterable of such paths
                (list, tuple, ...).
        """
        if json_config_path is None:
            candidates = []
        elif isinstance(json_config_path, (str, PathLike)):
            # A lone path; the str check also stops a string from being
            # iterated character by character in the branch below.
            candidates = [json_config_path]
        else:
            candidates = list(json_config_path)
        self.json_config_path = [Path(candidate) for candidate in candidates]

    def __call__(self, settings: "BaseSettings") -> Dict[str, Any]:
        """Load and merge every configured JSON file.

        Args:
            settings: The settings instance being populated. Its class name is
                used in log and error messages, and
                ``settings.__config__.env_file_encoding`` is used as the text
                encoding when reading the files.

        Returns:
            The merged mapping; an empty dict when no paths are configured.

        Raises:
            FileNotFoundError: A configured path is missing or is not a file.
            ValueError: A file's top-level JSON value is not an object.
        """
        classname = settings.__class__.__name__
        encoding = settings.__config__.env_file_encoding

        merged_config: Dict[str, Any] = {}
        for idx, path in enumerate(self.json_config_path, start=1):
            # is_file() is already False for a path that does not exist.
            if not path.is_file():
                raise FileNotFoundError(f"{classname}: config #{idx} at {path} not found or not a file")

            logger.debug("%s: loading config #%d from %s", classname, idx, path)
            loaded = json.loads(path.read_text(encoding=encoding))
            if not isinstance(loaded, dict):
                # dict.update() on a list or scalar either fails with an
                # unrelated-looking error or merges nonsense; fail clearly.
                raise ValueError(
                    f"{classname}: config #{idx} at {path} must contain a JSON object at the top level"
                )
            merged_config.update(loaded)
            logger.debug("%s: config state #%d: %s", classname, idx, merged_config)

        logger.debug("%s: loaded config: %s", classname, merged_config)
        return merged_config

    def __repr__(self) -> str:
        """Return a constructor-style representation listing the paths."""
        return f"JsonSettingsSource(json_config_path={self.json_config_path!r})"


class JsonConfig(BaseConfig):
    """Config base class that adds a JSON file source to settings loading.

    ``json_config_path`` is the class-level fallback location of the JSON
    file(s); ``env_file_encoding`` is the encoding ``JsonSettingsSource``
    reads them with.
    """

    json_config_path: Optional[Union[Path, list[Path]]] = None
    env_file_encoding: str = "utf-8"

    @classmethod
    def customise_sources(
        cls,
        init_settings: InitSettingsSource,
        env_settings: EnvSettingsSource,
        file_secret_settings: SecretsSettingsSource,
    ) -> Tuple[SettingsSourceCallable, ...]:
        """Return the sources used to populate the settings object.

        The result is ``(init_settings, <JSON source>)``. ``env_settings`` and
        ``file_secret_settings`` are accepted but not part of the result.
        """
        # A "json_config_path" passed to the constructor overrides the class
        # attribute. pop() also removes it from the init kwargs — presumably
        # so it is not handled as a model field afterwards.
        class_default = cls.json_config_path
        init_kwargs = init_settings.init_kwargs
        selected_path = init_kwargs.pop("json_config_path", class_default)

        logger.debug(f"Using JsonSettingsSource for {cls.__name__}")
        json_source = JsonSettingsSource(json_config_path=selected_path)

        sources = (init_settings, json_source)
        return sources


# Settings model holding two keyword-argument dictionaries. Both fields are
# required (no defaults are declared). The nested Config derives from
# JsonConfig, whose customise_sources() supplies values from init kwargs and
# from the JSON file(s) named by json_config_path.
class InferenceConfig(BaseSettings):
    # Presumably extra keyword arguments for the UNet — confirm with the consumer.
    unet_additional_kwargs: dict[str, Any]
    # Presumably keyword arguments for the noise scheduler — confirm with the consumer.
    noise_scheduler_kwargs: dict[str, Any]

    class Config(JsonConfig):
        # Annotation only: no value is assigned here, so the attribute value
        # is still the one inherited from JsonConfig.
        json_config_path: Path


def get_infer_config(
    is_v2: bool,
    is_sdxl: bool,
) -> InferenceConfig:
    """Load the inference configuration matching the given flags.

    The file is looked up under ``get_dir("config")``:
    ``inference/motion_sdxl.json`` when ``is_sdxl`` is true (regardless of
    ``is_v2``), otherwise ``inference/motion_v2.json`` when ``is_v2`` is true,
    otherwise ``inference/default.json``.
    """
    if is_v2:
        relative_name = "inference/motion_v2.json"
    else:
        relative_name = "inference/default.json"
    config_path: Path = get_dir("config").joinpath(relative_name)

    # SDXL overrides whichever file was picked above.
    if is_sdxl:
        config_path = get_dir("config").joinpath("inference/motion_sdxl.json")

    return InferenceConfig(json_config_path=config_path)


# Settings model for a single run. ``name``, ``path`` and ``motion_module``
# are required; every other field declares a default. Values are supplied by
# the sources configured in JsonConfig (init kwargs and the JSON file).
class ModelConfig(BaseSettings):
    name: str = Field(...)  # Config name, not actually used for much of anything
    path: Path = Field(...)  # Path to the model
    vae_path: str = ""  # Path to the VAE, kept as a plain string; empty by default
    motion_module: Path = Field(...)  # Path to the motion module
    context_schedule: str = "uniform"
    lcm_map: Dict[str, Any] = Field({})
    gradual_latent_hires_fix_map: Dict[str, Any] = Field({})
    compile: bool = Field(False)  # whether to compile the model with TorchDynamo
    tensor_interpolation_slerp: bool = Field(True)
    seed: list[int] = Field([])  # Seed(s) for the random number generators
    scheduler: DiffusionScheduler = Field(DiffusionScheduler.k_dpmpp_2m)  # Scheduler to use
    steps: int = 25  # Number of inference steps to run
    guidance_scale: float = 7.5  # CFG scale to use
    unet_batch_size: int = 1
    clip_skip: int = 1  # skip the last N-1 layers of the CLIP text encoder
    prompt_fixed_ratio: float = 0.5
    head_prompt: str = ""
    prompt_map: Dict[str, str] = Field({})
    tail_prompt: str = ""
    n_prompt: list[str] = Field([])  # Anti-prompt(s) to use
    is_single_prompt_mode: bool = Field(False)
    lora_map: Dict[str, Any] = Field({})
    motion_lora_map: Dict[str, float] = Field({})
    ip_adapter_map: Dict[str, Any] = Field({})
    img2img_map: Dict[str, Any] = Field({})
    region_map: Dict[str, Any] = Field({})
    controlnet_map: Dict[str, Any] = Field({})
    upscale_config: Dict[str, Any] = Field({})
    stylize_config: Dict[str, Any] = Field({})
    output: Dict[str, Any] = Field({})
    result: Dict[str, Any] = Field({})

    class Config(JsonConfig):
        # Annotation only; the attribute value is inherited from JsonConfig.
        json_config_path: Path

    @property
    def save_name(self) -> str:
        """Return ``"<name>-<model file stem>"`` with both parts lower-cased."""
        config_label = self.name.lower()
        model_stem = self.path.stem.lower()
        return f"{config_label}-{model_stem}"


def get_model_config(config_path: Path) -> ModelConfig:
    """Build a ``ModelConfig`` that reads its values from ``config_path``."""
    return ModelConfig(json_config_path=config_path)