jayparmr commited on
Commit
ea5c647
·
1 Parent(s): 3e9c18d

Upload folder using huggingface_hub

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. external/llite/library/__init__.py +0 -0
  2. external/llite/library/attention_processors.py +227 -0
  3. external/llite/library/config_util.py +621 -0
  4. external/llite/library/custom_train_functions.py +529 -0
  5. external/llite/library/huggingface_util.py +81 -0
  6. external/llite/library/hypernetwork.py +223 -0
  7. external/llite/library/ipex/__init__.py +169 -0
  8. external/llite/library/ipex/attention.py +151 -0
  9. external/llite/library/ipex/diffusers.py +120 -0
  10. external/llite/library/ipex/gradscaler.py +183 -0
  11. external/llite/library/ipex/hijacks.py +252 -0
  12. external/llite/library/lpw_stable_diffusion.py +1254 -0
  13. external/llite/library/model_util.py +1350 -0
  14. external/llite/library/original_unet.py +1915 -0
  15. external/llite/library/sai_model_spec.py +305 -0
  16. external/llite/library/sdxl_lpw_stable_diffusion.py +1342 -0
  17. external/llite/library/sdxl_model_util.py +578 -0
  18. external/llite/library/sdxl_original_unet.py +1281 -0
  19. external/llite/library/sdxl_train_util.py +367 -0
  20. external/llite/library/slicing_vae.py +679 -0
  21. external/llite/library/train_util.py +0 -0
  22. external/llite/library/utils.py +6 -0
  23. external/llite/networks/.ipynb_checkpoints/control_net_lllite-checkpoint.py +446 -0
  24. external/llite/networks/check_lora_weights.py +45 -0
  25. external/llite/networks/control_net_lllite.py +446 -0
  26. external/llite/networks/control_net_lllite_for_train.py +502 -0
  27. external/llite/networks/dylora.py +450 -0
  28. external/llite/networks/extract_lora_from_dylora.py +125 -0
  29. external/llite/networks/extract_lora_from_models.py +296 -0
  30. external/llite/networks/lora.py +1225 -0
  31. external/llite/networks/lora_diffusers.py +609 -0
  32. external/llite/networks/lora_fa.py +1241 -0
  33. external/llite/networks/lora_interrogator.py +139 -0
  34. external/llite/networks/merge_lora.py +357 -0
  35. external/llite/networks/merge_lora_old.py +185 -0
  36. external/llite/networks/oft.py +430 -0
  37. external/llite/networks/resize_lora.py +362 -0
  38. external/llite/networks/sdxl_merge_lora.py +348 -0
  39. external/llite/networks/svd_merge_lora.py +260 -0
  40. external/llite/tools/cache_latents.py +194 -0
  41. external/llite/tools/cache_text_encoder_outputs.py +191 -0
  42. external/llite/tools/canny.py +30 -0
  43. external/llite/tools/convert_diffusers20_original_sd.py +160 -0
  44. external/llite/tools/detect_face_rotate.py +246 -0
  45. external/llite/tools/latent_upscaler.py +348 -0
  46. external/llite/tools/merge_models.py +168 -0
  47. external/llite/tools/original_control_net.py +337 -0
  48. external/llite/tools/resize_images_to_resolution.py +128 -0
  49. external/llite/tools/show_metadata.py +19 -0
  50. inference.py +38 -16
external/llite/library/__init__.py ADDED
File without changes
external/llite/library/attention_processors.py ADDED
@@ -0,0 +1,227 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ from typing import Any
3
+ from einops import rearrange
4
+ import torch
5
+ from diffusers.models.attention_processor import Attention
6
+
7
+
8
+ # flash attention forwards and backwards
9
+
10
+ # https://arxiv.org/abs/2205.14135
11
+
12
+ EPSILON = 1e-6
13
+
14
+
15
+ class FlashAttentionFunction(torch.autograd.function.Function):
16
+ @staticmethod
17
+ @torch.no_grad()
18
+ def forward(ctx, q, k, v, mask, causal, q_bucket_size, k_bucket_size):
19
+ """Algorithm 2 in the paper"""
20
+
21
+ device = q.device
22
+ dtype = q.dtype
23
+ max_neg_value = -torch.finfo(q.dtype).max
24
+ qk_len_diff = max(k.shape[-2] - q.shape[-2], 0)
25
+
26
+ o = torch.zeros_like(q)
27
+ all_row_sums = torch.zeros((*q.shape[:-1], 1), dtype=dtype, device=device)
28
+ all_row_maxes = torch.full(
29
+ (*q.shape[:-1], 1), max_neg_value, dtype=dtype, device=device
30
+ )
31
+
32
+ scale = q.shape[-1] ** -0.5
33
+
34
+ if mask is None:
35
+ mask = (None,) * math.ceil(q.shape[-2] / q_bucket_size)
36
+ else:
37
+ mask = rearrange(mask, "b n -> b 1 1 n")
38
+ mask = mask.split(q_bucket_size, dim=-1)
39
+
40
+ row_splits = zip(
41
+ q.split(q_bucket_size, dim=-2),
42
+ o.split(q_bucket_size, dim=-2),
43
+ mask,
44
+ all_row_sums.split(q_bucket_size, dim=-2),
45
+ all_row_maxes.split(q_bucket_size, dim=-2),
46
+ )
47
+
48
+ for ind, (qc, oc, row_mask, row_sums, row_maxes) in enumerate(row_splits):
49
+ q_start_index = ind * q_bucket_size - qk_len_diff
50
+
51
+ col_splits = zip(
52
+ k.split(k_bucket_size, dim=-2),
53
+ v.split(k_bucket_size, dim=-2),
54
+ )
55
+
56
+ for k_ind, (kc, vc) in enumerate(col_splits):
57
+ k_start_index = k_ind * k_bucket_size
58
+
59
+ attn_weights = (
60
+ torch.einsum("... i d, ... j d -> ... i j", qc, kc) * scale
61
+ )
62
+
63
+ if row_mask is not None:
64
+ attn_weights.masked_fill_(~row_mask, max_neg_value)
65
+
66
+ if causal and q_start_index < (k_start_index + k_bucket_size - 1):
67
+ causal_mask = torch.ones(
68
+ (qc.shape[-2], kc.shape[-2]), dtype=torch.bool, device=device
69
+ ).triu(q_start_index - k_start_index + 1)
70
+ attn_weights.masked_fill_(causal_mask, max_neg_value)
71
+
72
+ block_row_maxes = attn_weights.amax(dim=-1, keepdims=True)
73
+ attn_weights -= block_row_maxes
74
+ exp_weights = torch.exp(attn_weights)
75
+
76
+ if row_mask is not None:
77
+ exp_weights.masked_fill_(~row_mask, 0.0)
78
+
79
+ block_row_sums = exp_weights.sum(dim=-1, keepdims=True).clamp(
80
+ min=EPSILON
81
+ )
82
+
83
+ new_row_maxes = torch.maximum(block_row_maxes, row_maxes)
84
+
85
+ exp_values = torch.einsum(
86
+ "... i j, ... j d -> ... i d", exp_weights, vc
87
+ )
88
+
89
+ exp_row_max_diff = torch.exp(row_maxes - new_row_maxes)
90
+ exp_block_row_max_diff = torch.exp(block_row_maxes - new_row_maxes)
91
+
92
+ new_row_sums = (
93
+ exp_row_max_diff * row_sums
94
+ + exp_block_row_max_diff * block_row_sums
95
+ )
96
+
97
+ oc.mul_((row_sums / new_row_sums) * exp_row_max_diff).add_(
98
+ (exp_block_row_max_diff / new_row_sums) * exp_values
99
+ )
100
+
101
+ row_maxes.copy_(new_row_maxes)
102
+ row_sums.copy_(new_row_sums)
103
+
104
+ ctx.args = (causal, scale, mask, q_bucket_size, k_bucket_size)
105
+ ctx.save_for_backward(q, k, v, o, all_row_sums, all_row_maxes)
106
+
107
+ return o
108
+
109
+ @staticmethod
110
+ @torch.no_grad()
111
+ def backward(ctx, do):
112
+ """Algorithm 4 in the paper"""
113
+
114
+ causal, scale, mask, q_bucket_size, k_bucket_size = ctx.args
115
+ q, k, v, o, l, m = ctx.saved_tensors
116
+
117
+ device = q.device
118
+
119
+ max_neg_value = -torch.finfo(q.dtype).max
120
+ qk_len_diff = max(k.shape[-2] - q.shape[-2], 0)
121
+
122
+ dq = torch.zeros_like(q)
123
+ dk = torch.zeros_like(k)
124
+ dv = torch.zeros_like(v)
125
+
126
+ row_splits = zip(
127
+ q.split(q_bucket_size, dim=-2),
128
+ o.split(q_bucket_size, dim=-2),
129
+ do.split(q_bucket_size, dim=-2),
130
+ mask,
131
+ l.split(q_bucket_size, dim=-2),
132
+ m.split(q_bucket_size, dim=-2),
133
+ dq.split(q_bucket_size, dim=-2),
134
+ )
135
+
136
+ for ind, (qc, oc, doc, row_mask, lc, mc, dqc) in enumerate(row_splits):
137
+ q_start_index = ind * q_bucket_size - qk_len_diff
138
+
139
+ col_splits = zip(
140
+ k.split(k_bucket_size, dim=-2),
141
+ v.split(k_bucket_size, dim=-2),
142
+ dk.split(k_bucket_size, dim=-2),
143
+ dv.split(k_bucket_size, dim=-2),
144
+ )
145
+
146
+ for k_ind, (kc, vc, dkc, dvc) in enumerate(col_splits):
147
+ k_start_index = k_ind * k_bucket_size
148
+
149
+ attn_weights = (
150
+ torch.einsum("... i d, ... j d -> ... i j", qc, kc) * scale
151
+ )
152
+
153
+ if causal and q_start_index < (k_start_index + k_bucket_size - 1):
154
+ causal_mask = torch.ones(
155
+ (qc.shape[-2], kc.shape[-2]), dtype=torch.bool, device=device
156
+ ).triu(q_start_index - k_start_index + 1)
157
+ attn_weights.masked_fill_(causal_mask, max_neg_value)
158
+
159
+ exp_attn_weights = torch.exp(attn_weights - mc)
160
+
161
+ if row_mask is not None:
162
+ exp_attn_weights.masked_fill_(~row_mask, 0.0)
163
+
164
+ p = exp_attn_weights / lc
165
+
166
+ dv_chunk = torch.einsum("... i j, ... i d -> ... j d", p, doc)
167
+ dp = torch.einsum("... i d, ... j d -> ... i j", doc, vc)
168
+
169
+ D = (doc * oc).sum(dim=-1, keepdims=True)
170
+ ds = p * scale * (dp - D)
171
+
172
+ dq_chunk = torch.einsum("... i j, ... j d -> ... i d", ds, kc)
173
+ dk_chunk = torch.einsum("... i j, ... i d -> ... j d", ds, qc)
174
+
175
+ dqc.add_(dq_chunk)
176
+ dkc.add_(dk_chunk)
177
+ dvc.add_(dv_chunk)
178
+
179
+ return dq, dk, dv, None, None, None, None
180
+
181
+
182
+ class FlashAttnProcessor:
183
+ def __call__(
184
+ self,
185
+ attn: Attention,
186
+ hidden_states,
187
+ encoder_hidden_states=None,
188
+ attention_mask=None,
189
+ ) -> Any:
190
+ q_bucket_size = 512
191
+ k_bucket_size = 1024
192
+
193
+ h = attn.heads
194
+ q = attn.to_q(hidden_states)
195
+
196
+ encoder_hidden_states = (
197
+ encoder_hidden_states
198
+ if encoder_hidden_states is not None
199
+ else hidden_states
200
+ )
201
+ encoder_hidden_states = encoder_hidden_states.to(hidden_states.dtype)
202
+
203
+ if hasattr(attn, "hypernetwork") and attn.hypernetwork is not None:
204
+ context_k, context_v = attn.hypernetwork.forward(
205
+ hidden_states, encoder_hidden_states
206
+ )
207
+ context_k = context_k.to(hidden_states.dtype)
208
+ context_v = context_v.to(hidden_states.dtype)
209
+ else:
210
+ context_k = encoder_hidden_states
211
+ context_v = encoder_hidden_states
212
+
213
+ k = attn.to_k(context_k)
214
+ v = attn.to_v(context_v)
215
+ del encoder_hidden_states, hidden_states
216
+
217
+ q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), (q, k, v))
218
+
219
+ out = FlashAttentionFunction.apply(
220
+ q, k, v, attention_mask, False, q_bucket_size, k_bucket_size
221
+ )
222
+
223
+ out = rearrange(out, "b h n d -> b n (h d)")
224
+
225
+ out = attn.to_out[0](out)
226
+ out = attn.to_out[1](out)
227
+ return out
external/llite/library/config_util.py ADDED
@@ -0,0 +1,621 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ from dataclasses import (
3
+ asdict,
4
+ dataclass,
5
+ )
6
+ import functools
7
+ import random
8
+ from textwrap import dedent, indent
9
+ import json
10
+ from pathlib import Path
11
+ # from toolz import curry
12
+ from typing import (
13
+ List,
14
+ Optional,
15
+ Sequence,
16
+ Tuple,
17
+ Union,
18
+ )
19
+
20
+ import toml
21
+ import voluptuous
22
+ from voluptuous import (
23
+ Any,
24
+ ExactSequence,
25
+ MultipleInvalid,
26
+ Object,
27
+ Required,
28
+ Schema,
29
+ )
30
+ from transformers import CLIPTokenizer
31
+
32
+ from . import train_util
33
+ from .train_util import (
34
+ DreamBoothSubset,
35
+ FineTuningSubset,
36
+ ControlNetSubset,
37
+ DreamBoothDataset,
38
+ FineTuningDataset,
39
+ ControlNetDataset,
40
+ DatasetGroup,
41
+ )
42
+
43
+
44
+ def add_config_arguments(parser: argparse.ArgumentParser):
45
+ parser.add_argument("--dataset_config", type=Path, default=None, help="config file for detail settings / 詳細な設定用の設定ファイル")
46
+
47
+ # TODO: inherit Params class in Subset, Dataset
48
+
49
+ @dataclass
50
+ class BaseSubsetParams:
51
+ image_dir: Optional[str] = None
52
+ num_repeats: int = 1
53
+ shuffle_caption: bool = False
54
+ caption_separator: str = ',',
55
+ keep_tokens: int = 0
56
+ keep_tokens_separator: str = None,
57
+ color_aug: bool = False
58
+ flip_aug: bool = False
59
+ face_crop_aug_range: Optional[Tuple[float, float]] = None
60
+ random_crop: bool = False
61
+ caption_prefix: Optional[str] = None
62
+ caption_suffix: Optional[str] = None
63
+ caption_dropout_rate: float = 0.0
64
+ caption_dropout_every_n_epochs: int = 0
65
+ caption_tag_dropout_rate: float = 0.0
66
+ token_warmup_min: int = 1
67
+ token_warmup_step: float = 0
68
+
69
+ @dataclass
70
+ class DreamBoothSubsetParams(BaseSubsetParams):
71
+ is_reg: bool = False
72
+ class_tokens: Optional[str] = None
73
+ caption_extension: str = ".caption"
74
+
75
+ @dataclass
76
+ class FineTuningSubsetParams(BaseSubsetParams):
77
+ metadata_file: Optional[str] = None
78
+
79
+ @dataclass
80
+ class ControlNetSubsetParams(BaseSubsetParams):
81
+ conditioning_data_dir: str = None
82
+ caption_extension: str = ".caption"
83
+
84
+ @dataclass
85
+ class BaseDatasetParams:
86
+ tokenizer: Union[CLIPTokenizer, List[CLIPTokenizer]] = None
87
+ max_token_length: int = None
88
+ resolution: Optional[Tuple[int, int]] = None
89
+ debug_dataset: bool = False
90
+
91
+ @dataclass
92
+ class DreamBoothDatasetParams(BaseDatasetParams):
93
+ batch_size: int = 1
94
+ enable_bucket: bool = False
95
+ min_bucket_reso: int = 256
96
+ max_bucket_reso: int = 1024
97
+ bucket_reso_steps: int = 64
98
+ bucket_no_upscale: bool = False
99
+ prior_loss_weight: float = 1.0
100
+
101
+ @dataclass
102
+ class FineTuningDatasetParams(BaseDatasetParams):
103
+ batch_size: int = 1
104
+ enable_bucket: bool = False
105
+ min_bucket_reso: int = 256
106
+ max_bucket_reso: int = 1024
107
+ bucket_reso_steps: int = 64
108
+ bucket_no_upscale: bool = False
109
+
110
+ @dataclass
111
+ class ControlNetDatasetParams(BaseDatasetParams):
112
+ batch_size: int = 1
113
+ enable_bucket: bool = False
114
+ min_bucket_reso: int = 256
115
+ max_bucket_reso: int = 1024
116
+ bucket_reso_steps: int = 64
117
+ bucket_no_upscale: bool = False
118
+
119
+ @dataclass
120
+ class SubsetBlueprint:
121
+ params: Union[DreamBoothSubsetParams, FineTuningSubsetParams]
122
+
123
+ @dataclass
124
+ class DatasetBlueprint:
125
+ is_dreambooth: bool
126
+ is_controlnet: bool
127
+ params: Union[DreamBoothDatasetParams, FineTuningDatasetParams]
128
+ subsets: Sequence[SubsetBlueprint]
129
+
130
+ @dataclass
131
+ class DatasetGroupBlueprint:
132
+ datasets: Sequence[DatasetBlueprint]
133
+ @dataclass
134
+ class Blueprint:
135
+ dataset_group: DatasetGroupBlueprint
136
+
137
+
138
+ class ConfigSanitizer:
139
+ # @curry
140
+ @staticmethod
141
+ def __validate_and_convert_twodim(klass, value: Sequence) -> Tuple:
142
+ Schema(ExactSequence([klass, klass]))(value)
143
+ return tuple(value)
144
+
145
+ # @curry
146
+ @staticmethod
147
+ def __validate_and_convert_scalar_or_twodim(klass, value: Union[float, Sequence]) -> Tuple:
148
+ Schema(Any(klass, ExactSequence([klass, klass])))(value)
149
+ try:
150
+ Schema(klass)(value)
151
+ return (value, value)
152
+ except:
153
+ return ConfigSanitizer.__validate_and_convert_twodim(klass, value)
154
+
155
+ # subset schema
156
+ SUBSET_ASCENDABLE_SCHEMA = {
157
+ "color_aug": bool,
158
+ "face_crop_aug_range": functools.partial(__validate_and_convert_twodim.__func__, float),
159
+ "flip_aug": bool,
160
+ "num_repeats": int,
161
+ "random_crop": bool,
162
+ "shuffle_caption": bool,
163
+ "keep_tokens": int,
164
+ "keep_tokens_separator": str,
165
+ "token_warmup_min": int,
166
+ "token_warmup_step": Any(float,int),
167
+ "caption_prefix": str,
168
+ "caption_suffix": str,
169
+ }
170
+ # DO means DropOut
171
+ DO_SUBSET_ASCENDABLE_SCHEMA = {
172
+ "caption_dropout_every_n_epochs": int,
173
+ "caption_dropout_rate": Any(float, int),
174
+ "caption_tag_dropout_rate": Any(float, int),
175
+ }
176
+ # DB means DreamBooth
177
+ DB_SUBSET_ASCENDABLE_SCHEMA = {
178
+ "caption_extension": str,
179
+ "class_tokens": str,
180
+ }
181
+ DB_SUBSET_DISTINCT_SCHEMA = {
182
+ Required("image_dir"): str,
183
+ "is_reg": bool,
184
+ }
185
+ # FT means FineTuning
186
+ FT_SUBSET_DISTINCT_SCHEMA = {
187
+ Required("metadata_file"): str,
188
+ "image_dir": str,
189
+ }
190
+ CN_SUBSET_ASCENDABLE_SCHEMA = {
191
+ "caption_extension": str,
192
+ }
193
+ CN_SUBSET_DISTINCT_SCHEMA = {
194
+ Required("image_dir"): str,
195
+ Required("conditioning_data_dir"): str,
196
+ }
197
+
198
+ # datasets schema
199
+ DATASET_ASCENDABLE_SCHEMA = {
200
+ "batch_size": int,
201
+ "bucket_no_upscale": bool,
202
+ "bucket_reso_steps": int,
203
+ "enable_bucket": bool,
204
+ "max_bucket_reso": int,
205
+ "min_bucket_reso": int,
206
+ "resolution": functools.partial(__validate_and_convert_scalar_or_twodim.__func__, int),
207
+ }
208
+
209
+ # options handled by argparse but not handled by user config
210
+ ARGPARSE_SPECIFIC_SCHEMA = {
211
+ "debug_dataset": bool,
212
+ "max_token_length": Any(None, int),
213
+ "prior_loss_weight": Any(float, int),
214
+ }
215
+ # for handling default None value of argparse
216
+ ARGPARSE_NULLABLE_OPTNAMES = [
217
+ "face_crop_aug_range",
218
+ "resolution",
219
+ ]
220
+ # prepare map because option name may differ among argparse and user config
221
+ ARGPARSE_OPTNAME_TO_CONFIG_OPTNAME = {
222
+ "train_batch_size": "batch_size",
223
+ "dataset_repeats": "num_repeats",
224
+ }
225
+
226
+ def __init__(self, support_dreambooth: bool, support_finetuning: bool, support_controlnet: bool, support_dropout: bool) -> None:
227
+ assert support_dreambooth or support_finetuning or support_controlnet, "Neither DreamBooth mode nor fine tuning mode specified. Please specify one mode or more. / DreamBooth モードか fine tuning モードのどちらも指定されていません。1つ以上指定してください。"
228
+
229
+ self.db_subset_schema = self.__merge_dict(
230
+ self.SUBSET_ASCENDABLE_SCHEMA,
231
+ self.DB_SUBSET_DISTINCT_SCHEMA,
232
+ self.DB_SUBSET_ASCENDABLE_SCHEMA,
233
+ self.DO_SUBSET_ASCENDABLE_SCHEMA if support_dropout else {},
234
+ )
235
+
236
+ self.ft_subset_schema = self.__merge_dict(
237
+ self.SUBSET_ASCENDABLE_SCHEMA,
238
+ self.FT_SUBSET_DISTINCT_SCHEMA,
239
+ self.DO_SUBSET_ASCENDABLE_SCHEMA if support_dropout else {},
240
+ )
241
+
242
+ self.cn_subset_schema = self.__merge_dict(
243
+ self.SUBSET_ASCENDABLE_SCHEMA,
244
+ self.CN_SUBSET_DISTINCT_SCHEMA,
245
+ self.CN_SUBSET_ASCENDABLE_SCHEMA,
246
+ self.DO_SUBSET_ASCENDABLE_SCHEMA if support_dropout else {},
247
+ )
248
+
249
+ self.db_dataset_schema = self.__merge_dict(
250
+ self.DATASET_ASCENDABLE_SCHEMA,
251
+ self.SUBSET_ASCENDABLE_SCHEMA,
252
+ self.DB_SUBSET_ASCENDABLE_SCHEMA,
253
+ self.DO_SUBSET_ASCENDABLE_SCHEMA if support_dropout else {},
254
+ {"subsets": [self.db_subset_schema]},
255
+ )
256
+
257
+ self.ft_dataset_schema = self.__merge_dict(
258
+ self.DATASET_ASCENDABLE_SCHEMA,
259
+ self.SUBSET_ASCENDABLE_SCHEMA,
260
+ self.DO_SUBSET_ASCENDABLE_SCHEMA if support_dropout else {},
261
+ {"subsets": [self.ft_subset_schema]},
262
+ )
263
+
264
+ self.cn_dataset_schema = self.__merge_dict(
265
+ self.DATASET_ASCENDABLE_SCHEMA,
266
+ self.SUBSET_ASCENDABLE_SCHEMA,
267
+ self.CN_SUBSET_ASCENDABLE_SCHEMA,
268
+ self.DO_SUBSET_ASCENDABLE_SCHEMA if support_dropout else {},
269
+ {"subsets": [self.cn_subset_schema]},
270
+ )
271
+
272
+ if support_dreambooth and support_finetuning:
273
+ def validate_flex_dataset(dataset_config: dict):
274
+ subsets_config = dataset_config.get("subsets", [])
275
+
276
+ if support_controlnet and all(["conditioning_data_dir" in subset for subset in subsets_config]):
277
+ return Schema(self.cn_dataset_schema)(dataset_config)
278
+ # check dataset meets FT style
279
+ # NOTE: all FT subsets should have "metadata_file"
280
+ elif all(["metadata_file" in subset for subset in subsets_config]):
281
+ return Schema(self.ft_dataset_schema)(dataset_config)
282
+ # check dataset meets DB style
283
+ # NOTE: all DB subsets should have no "metadata_file"
284
+ elif all(["metadata_file" not in subset for subset in subsets_config]):
285
+ return Schema(self.db_dataset_schema)(dataset_config)
286
+ else:
287
+ raise voluptuous.Invalid("DreamBooth subset and fine tuning subset cannot be mixed in the same dataset. Please split them into separate datasets. / DreamBoothのサブセットとfine tuninのサブセットを同一のデータセットに混在させることはできません。別々のデータセットに分割してください。")
288
+
289
+ self.dataset_schema = validate_flex_dataset
290
+ elif support_dreambooth:
291
+ self.dataset_schema = self.db_dataset_schema
292
+ elif support_finetuning:
293
+ self.dataset_schema = self.ft_dataset_schema
294
+ elif support_controlnet:
295
+ self.dataset_schema = self.cn_dataset_schema
296
+
297
+ self.general_schema = self.__merge_dict(
298
+ self.DATASET_ASCENDABLE_SCHEMA,
299
+ self.SUBSET_ASCENDABLE_SCHEMA,
300
+ self.DB_SUBSET_ASCENDABLE_SCHEMA if support_dreambooth else {},
301
+ self.CN_SUBSET_ASCENDABLE_SCHEMA if support_controlnet else {},
302
+ self.DO_SUBSET_ASCENDABLE_SCHEMA if support_dropout else {},
303
+ )
304
+
305
+ self.user_config_validator = Schema({
306
+ "general": self.general_schema,
307
+ "datasets": [self.dataset_schema],
308
+ })
309
+
310
+ self.argparse_schema = self.__merge_dict(
311
+ self.general_schema,
312
+ self.ARGPARSE_SPECIFIC_SCHEMA,
313
+ {optname: Any(None, self.general_schema[optname]) for optname in self.ARGPARSE_NULLABLE_OPTNAMES},
314
+ {a_name: self.general_schema[c_name] for a_name, c_name in self.ARGPARSE_OPTNAME_TO_CONFIG_OPTNAME.items()},
315
+ )
316
+
317
+ self.argparse_config_validator = Schema(Object(self.argparse_schema), extra=voluptuous.ALLOW_EXTRA)
318
+
319
+ def sanitize_user_config(self, user_config: dict) -> dict:
320
+ try:
321
+ return self.user_config_validator(user_config)
322
+ except MultipleInvalid:
323
+ # TODO: エラー発生時のメッセージをわかりやすくする
324
+ print("Invalid user config / ユーザ設定の形式が正しくないようです")
325
+ raise
326
+
327
+ # NOTE: In nature, argument parser result is not needed to be sanitize
328
+ # However this will help us to detect program bug
329
+ def sanitize_argparse_namespace(self, argparse_namespace: argparse.Namespace) -> argparse.Namespace:
330
+ try:
331
+ return self.argparse_config_validator(argparse_namespace)
332
+ except MultipleInvalid:
333
+ # XXX: this should be a bug
334
+ print("Invalid cmdline parsed arguments. This should be a bug. / コマンドラインのパース結果が正しくないようです。プログラムのバグの可能性が高いです。")
335
+ raise
336
+
337
+ # NOTE: value would be overwritten by latter dict if there is already the same key
338
+ @staticmethod
339
+ def __merge_dict(*dict_list: dict) -> dict:
340
+ merged = {}
341
+ for schema in dict_list:
342
+ # merged |= schema
343
+ for k, v in schema.items():
344
+ merged[k] = v
345
+ return merged
346
+
347
+
348
+ class BlueprintGenerator:
349
+ BLUEPRINT_PARAM_NAME_TO_CONFIG_OPTNAME = {
350
+ }
351
+
352
+ def __init__(self, sanitizer: ConfigSanitizer):
353
+ self.sanitizer = sanitizer
354
+
355
+ # runtime_params is for parameters which is only configurable on runtime, such as tokenizer
356
+ def generate(self, user_config: dict, argparse_namespace: argparse.Namespace, **runtime_params) -> Blueprint:
357
+ sanitized_user_config = self.sanitizer.sanitize_user_config(user_config)
358
+ sanitized_argparse_namespace = self.sanitizer.sanitize_argparse_namespace(argparse_namespace)
359
+
360
+ # convert argparse namespace to dict like config
361
+ # NOTE: it is ok to have extra entries in dict
362
+ optname_map = self.sanitizer.ARGPARSE_OPTNAME_TO_CONFIG_OPTNAME
363
+ argparse_config = {optname_map.get(optname, optname): value for optname, value in vars(sanitized_argparse_namespace).items()}
364
+
365
+ general_config = sanitized_user_config.get("general", {})
366
+
367
+ dataset_blueprints = []
368
+ for dataset_config in sanitized_user_config.get("datasets", []):
369
+ # NOTE: if subsets have no "metadata_file", these are DreamBooth datasets/subsets
370
+ subsets = dataset_config.get("subsets", [])
371
+ is_dreambooth = all(["metadata_file" not in subset for subset in subsets])
372
+ is_controlnet = all(["conditioning_data_dir" in subset for subset in subsets])
373
+ if is_controlnet:
374
+ subset_params_klass = ControlNetSubsetParams
375
+ dataset_params_klass = ControlNetDatasetParams
376
+ elif is_dreambooth:
377
+ subset_params_klass = DreamBoothSubsetParams
378
+ dataset_params_klass = DreamBoothDatasetParams
379
+ else:
380
+ subset_params_klass = FineTuningSubsetParams
381
+ dataset_params_klass = FineTuningDatasetParams
382
+
383
+ subset_blueprints = []
384
+ for subset_config in subsets:
385
+ params = self.generate_params_by_fallbacks(subset_params_klass,
386
+ [subset_config, dataset_config, general_config, argparse_config, runtime_params])
387
+ subset_blueprints.append(SubsetBlueprint(params))
388
+
389
+ params = self.generate_params_by_fallbacks(dataset_params_klass,
390
+ [dataset_config, general_config, argparse_config, runtime_params])
391
+ dataset_blueprints.append(DatasetBlueprint(is_dreambooth, is_controlnet, params, subset_blueprints))
392
+
393
+ dataset_group_blueprint = DatasetGroupBlueprint(dataset_blueprints)
394
+
395
+ return Blueprint(dataset_group_blueprint)
396
+
397
+ @staticmethod
398
+ def generate_params_by_fallbacks(param_klass, fallbacks: Sequence[dict]):
399
+ name_map = BlueprintGenerator.BLUEPRINT_PARAM_NAME_TO_CONFIG_OPTNAME
400
+ search_value = BlueprintGenerator.search_value
401
+ default_params = asdict(param_klass())
402
+ param_names = default_params.keys()
403
+
404
+ params = {name: search_value(name_map.get(name, name), fallbacks, default_params.get(name)) for name in param_names}
405
+
406
+ return param_klass(**params)
407
+
408
+ @staticmethod
409
+ def search_value(key: str, fallbacks: Sequence[dict], default_value = None):
410
+ for cand in fallbacks:
411
+ value = cand.get(key)
412
+ if value is not None:
413
+ return value
414
+
415
+ return default_value
416
+
417
+
418
+ def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlueprint):
419
+ datasets: List[Union[DreamBoothDataset, FineTuningDataset, ControlNetDataset]] = []
420
+
421
+ for dataset_blueprint in dataset_group_blueprint.datasets:
422
+ if dataset_blueprint.is_controlnet:
423
+ subset_klass = ControlNetSubset
424
+ dataset_klass = ControlNetDataset
425
+ elif dataset_blueprint.is_dreambooth:
426
+ subset_klass = DreamBoothSubset
427
+ dataset_klass = DreamBoothDataset
428
+ else:
429
+ subset_klass = FineTuningSubset
430
+ dataset_klass = FineTuningDataset
431
+
432
+ subsets = [subset_klass(**asdict(subset_blueprint.params)) for subset_blueprint in dataset_blueprint.subsets]
433
+ dataset = dataset_klass(subsets=subsets, **asdict(dataset_blueprint.params))
434
+ datasets.append(dataset)
435
+
436
+ # print info
437
+ info = ""
438
+ for i, dataset in enumerate(datasets):
439
+ is_dreambooth = isinstance(dataset, DreamBoothDataset)
440
+ is_controlnet = isinstance(dataset, ControlNetDataset)
441
+ info += dedent(f"""\
442
+ [Dataset {i}]
443
+ batch_size: {dataset.batch_size}
444
+ resolution: {(dataset.width, dataset.height)}
445
+ enable_bucket: {dataset.enable_bucket}
446
+ """)
447
+
448
+ if dataset.enable_bucket:
449
+ info += indent(dedent(f"""\
450
+ min_bucket_reso: {dataset.min_bucket_reso}
451
+ max_bucket_reso: {dataset.max_bucket_reso}
452
+ bucket_reso_steps: {dataset.bucket_reso_steps}
453
+ bucket_no_upscale: {dataset.bucket_no_upscale}
454
+ \n"""), " ")
455
+ else:
456
+ info += "\n"
457
+
458
+ for j, subset in enumerate(dataset.subsets):
459
+ info += indent(dedent(f"""\
460
+ [Subset {j} of Dataset {i}]
461
+ image_dir: "{subset.image_dir}"
462
+ image_count: {subset.img_count}
463
+ num_repeats: {subset.num_repeats}
464
+ shuffle_caption: {subset.shuffle_caption}
465
+ keep_tokens: {subset.keep_tokens}
466
+ keep_tokens_separator: {subset.keep_tokens_separator}
467
+ caption_dropout_rate: {subset.caption_dropout_rate}
468
+ caption_dropout_every_n_epoches: {subset.caption_dropout_every_n_epochs}
469
+ caption_tag_dropout_rate: {subset.caption_tag_dropout_rate}
470
+ caption_prefix: {subset.caption_prefix}
471
+ caption_suffix: {subset.caption_suffix}
472
+ color_aug: {subset.color_aug}
473
+ flip_aug: {subset.flip_aug}
474
+ face_crop_aug_range: {subset.face_crop_aug_range}
475
+ random_crop: {subset.random_crop}
476
+ token_warmup_min: {subset.token_warmup_min},
477
+ token_warmup_step: {subset.token_warmup_step},
478
+ """), " ")
479
+
480
+ if is_dreambooth:
481
+ info += indent(dedent(f"""\
482
+ is_reg: {subset.is_reg}
483
+ class_tokens: {subset.class_tokens}
484
+ caption_extension: {subset.caption_extension}
485
+ \n"""), " ")
486
+ elif not is_controlnet:
487
+ info += indent(dedent(f"""\
488
+ metadata_file: {subset.metadata_file}
489
+ \n"""), " ")
490
+
491
+ print(info)
492
+
493
+ # make buckets first because it determines the length of dataset
494
+ # and set the same seed for all datasets
495
+ seed = random.randint(0, 2**31) # actual seed is seed + epoch_no
496
+ for i, dataset in enumerate(datasets):
497
+ print(f"[Dataset {i}]")
498
+ dataset.make_buckets()
499
+ dataset.set_seed(seed)
500
+
501
+ return DatasetGroup(datasets)
502
+
503
+
504
+ def generate_dreambooth_subsets_config_by_subdirs(train_data_dir: Optional[str] = None, reg_data_dir: Optional[str] = None):
505
+ def extract_dreambooth_params(name: str) -> Tuple[int, str]:
506
+ tokens = name.split('_')
507
+ try:
508
+ n_repeats = int(tokens[0])
509
+ except ValueError as e:
510
+ print(f"ignore directory without repeats / 繰り返し回数のないディレクトリを無視します: {name}")
511
+ return 0, ""
512
+ caption_by_folder = '_'.join(tokens[1:])
513
+ return n_repeats, caption_by_folder
514
+
515
+ def generate(base_dir: Optional[str], is_reg: bool):
516
+ if base_dir is None:
517
+ return []
518
+
519
+ base_dir: Path = Path(base_dir)
520
+ if not base_dir.is_dir():
521
+ return []
522
+
523
+ subsets_config = []
524
+ for subdir in base_dir.iterdir():
525
+ if not subdir.is_dir():
526
+ continue
527
+
528
+ num_repeats, class_tokens = extract_dreambooth_params(subdir.name)
529
+ if num_repeats < 1:
530
+ continue
531
+
532
+ subset_config = {"image_dir": str(subdir), "num_repeats": num_repeats, "is_reg": is_reg, "class_tokens": class_tokens}
533
+ subsets_config.append(subset_config)
534
+
535
+ return subsets_config
536
+
537
+ subsets_config = []
538
+ subsets_config += generate(train_data_dir, False)
539
+ subsets_config += generate(reg_data_dir, True)
540
+
541
+ return subsets_config
542
+
543
+
544
+ def generate_controlnet_subsets_config_by_subdirs(train_data_dir: Optional[str] = None, conditioning_data_dir: Optional[str] = None, caption_extension: str = ".txt"):
545
+ def generate(base_dir: Optional[str]):
546
+ if base_dir is None:
547
+ return []
548
+
549
+ base_dir: Path = Path(base_dir)
550
+ if not base_dir.is_dir():
551
+ return []
552
+
553
+ subsets_config = []
554
+ subset_config = {"image_dir": train_data_dir, "conditioning_data_dir": conditioning_data_dir, "caption_extension": caption_extension, "num_repeats": 1}
555
+ subsets_config.append(subset_config)
556
+
557
+ return subsets_config
558
+
559
+ subsets_config = []
560
+ subsets_config += generate(train_data_dir)
561
+
562
+ return subsets_config
563
+
564
+
565
+ def load_user_config(file: str) -> dict:
566
+ file: Path = Path(file)
567
+ if not file.is_file():
568
+ raise ValueError(f"file not found / ファイルが見つかりません: {file}")
569
+
570
+ if file.name.lower().endswith('.json'):
571
+ try:
572
+ with open(file, 'r') as f:
573
+ config = json.load(f)
574
+ except Exception:
575
+ print(f"Error on parsing JSON config file. Please check the format. / JSON 形式の設定ファイルの読み込みに失敗しました。文法が正しいか確認してください。: {file}")
576
+ raise
577
+ elif file.name.lower().endswith('.toml'):
578
+ try:
579
+ config = toml.load(file)
580
+ except Exception:
581
+ print(f"Error on parsing TOML config file. Please check the format. / TOML 形式の設定ファイルの読み込みに失敗しました。文法が正しいか確認してください。: {file}")
582
+ raise
583
+ else:
584
+ raise ValueError(f"not supported config file format / 対応していない設定ファイルの形式です: {file}")
585
+
586
+ return config
587
+
588
+ # for config test
589
+ if __name__ == "__main__":
590
+ parser = argparse.ArgumentParser()
591
+ parser.add_argument("--support_dreambooth", action="store_true")
592
+ parser.add_argument("--support_finetuning", action="store_true")
593
+ parser.add_argument("--support_controlnet", action="store_true")
594
+ parser.add_argument("--support_dropout", action="store_true")
595
+ parser.add_argument("dataset_config")
596
+ config_args, remain = parser.parse_known_args()
597
+
598
+ parser = argparse.ArgumentParser()
599
+ train_util.add_dataset_arguments(parser, config_args.support_dreambooth, config_args.support_finetuning, config_args.support_dropout)
600
+ train_util.add_training_arguments(parser, config_args.support_dreambooth)
601
+ argparse_namespace = parser.parse_args(remain)
602
+ train_util.prepare_dataset_args(argparse_namespace, config_args.support_finetuning)
603
+
604
+ print("[argparse_namespace]")
605
+ print(vars(argparse_namespace))
606
+
607
+ user_config = load_user_config(config_args.dataset_config)
608
+
609
+ print("\n[user_config]")
610
+ print(user_config)
611
+
612
+ sanitizer = ConfigSanitizer(config_args.support_dreambooth, config_args.support_finetuning, config_args.support_controlnet, config_args.support_dropout)
613
+ sanitized_user_config = sanitizer.sanitize_user_config(user_config)
614
+
615
+ print("\n[sanitized_user_config]")
616
+ print(sanitized_user_config)
617
+
618
+ blueprint = BlueprintGenerator(sanitizer).generate(user_config, argparse_namespace)
619
+
620
+ print("\n[blueprint]")
621
+ print(blueprint)
external/llite/library/custom_train_functions.py ADDED
@@ -0,0 +1,529 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import argparse
3
+ import random
4
+ import re
5
+ from typing import List, Optional, Union
6
+
7
+
8
+ def prepare_scheduler_for_custom_training(noise_scheduler, device):
9
+ if hasattr(noise_scheduler, "all_snr"):
10
+ return
11
+
12
+ alphas_cumprod = noise_scheduler.alphas_cumprod
13
+ sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod)
14
+ sqrt_one_minus_alphas_cumprod = torch.sqrt(1.0 - alphas_cumprod)
15
+ alpha = sqrt_alphas_cumprod
16
+ sigma = sqrt_one_minus_alphas_cumprod
17
+ all_snr = (alpha / sigma) ** 2
18
+
19
+ noise_scheduler.all_snr = all_snr.to(device)
20
+
21
+
22
+ def fix_noise_scheduler_betas_for_zero_terminal_snr(noise_scheduler):
23
+ # fix beta: zero terminal SNR
24
+ print(f"fix noise scheduler betas: https://arxiv.org/abs/2305.08891")
25
+
26
+ def enforce_zero_terminal_snr(betas):
27
+ # Convert betas to alphas_bar_sqrt
28
+ alphas = 1 - betas
29
+ alphas_bar = alphas.cumprod(0)
30
+ alphas_bar_sqrt = alphas_bar.sqrt()
31
+
32
+ # Store old values.
33
+ alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone()
34
+ alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone()
35
+ # Shift so last timestep is zero.
36
+ alphas_bar_sqrt -= alphas_bar_sqrt_T
37
+ # Scale so first timestep is back to old value.
38
+ alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T)
39
+
40
+ # Convert alphas_bar_sqrt to betas
41
+ alphas_bar = alphas_bar_sqrt**2
42
+ alphas = alphas_bar[1:] / alphas_bar[:-1]
43
+ alphas = torch.cat([alphas_bar[0:1], alphas])
44
+ betas = 1 - alphas
45
+ return betas
46
+
47
+ betas = noise_scheduler.betas
48
+ betas = enforce_zero_terminal_snr(betas)
49
+ alphas = 1.0 - betas
50
+ alphas_cumprod = torch.cumprod(alphas, dim=0)
51
+
52
+ # print("original:", noise_scheduler.betas)
53
+ # print("fixed:", betas)
54
+
55
+ noise_scheduler.betas = betas
56
+ noise_scheduler.alphas = alphas
57
+ noise_scheduler.alphas_cumprod = alphas_cumprod
58
+
59
+
60
+ def apply_snr_weight(loss, timesteps, noise_scheduler, gamma, v_prediction=False):
61
+ snr = torch.stack([noise_scheduler.all_snr[t] for t in timesteps])
62
+ min_snr_gamma = torch.minimum(snr, torch.full_like(snr, gamma))
63
+ if v_prediction:
64
+ snr_weight = torch.div(min_snr_gamma, snr+1).float().to(loss.device)
65
+ else:
66
+ snr_weight = torch.div(min_snr_gamma, snr).float().to(loss.device)
67
+ loss = loss * snr_weight
68
+ return loss
69
+
70
+
71
+ def scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler):
72
+ scale = get_snr_scale(timesteps, noise_scheduler)
73
+ loss = loss * scale
74
+ return loss
75
+
76
+
77
+ def get_snr_scale(timesteps, noise_scheduler):
78
+ snr_t = torch.stack([noise_scheduler.all_snr[t] for t in timesteps]) # batch_size
79
+ snr_t = torch.minimum(snr_t, torch.ones_like(snr_t) * 1000) # if timestep is 0, snr_t is inf, so limit it to 1000
80
+ scale = snr_t / (snr_t + 1)
81
+ # # show debug info
82
+ # print(f"timesteps: {timesteps}, snr_t: {snr_t}, scale: {scale}")
83
+ return scale
84
+
85
+
86
+ def add_v_prediction_like_loss(loss, timesteps, noise_scheduler, v_pred_like_loss):
87
+ scale = get_snr_scale(timesteps, noise_scheduler)
88
+ # print(f"add v-prediction like loss: {v_pred_like_loss}, scale: {scale}, loss: {loss}, time: {timesteps}")
89
+ loss = loss + loss / scale * v_pred_like_loss
90
+ return loss
91
+
92
+ def apply_debiased_estimation(loss, timesteps, noise_scheduler):
93
+ snr_t = torch.stack([noise_scheduler.all_snr[t] for t in timesteps]) # batch_size
94
+ snr_t = torch.minimum(snr_t, torch.ones_like(snr_t) * 1000) # if timestep is 0, snr_t is inf, so limit it to 1000
95
+ weight = 1/torch.sqrt(snr_t)
96
+ loss = weight * loss
97
+ return loss
98
+
99
+ # TODO train_utilと分散しているのでどちらかに寄せる
100
+
101
+
102
+ def add_custom_train_arguments(parser: argparse.ArgumentParser, support_weighted_captions: bool = True):
103
+ parser.add_argument(
104
+ "--min_snr_gamma",
105
+ type=float,
106
+ default=None,
107
+ help="gamma for reducing the weight of high loss timesteps. Lower numbers have stronger effect. 5 is recommended by paper. / 低いタイムステップでの高いlossに対して重みを減らすためのgamma値、低いほど効果が強く、論文では5が推奨",
108
+ )
109
+ parser.add_argument(
110
+ "--scale_v_pred_loss_like_noise_pred",
111
+ action="store_true",
112
+ help="scale v-prediction loss like noise prediction loss / v-prediction lossをnoise prediction lossと同じようにスケーリングする",
113
+ )
114
+ parser.add_argument(
115
+ "--v_pred_like_loss",
116
+ type=float,
117
+ default=None,
118
+ help="add v-prediction like loss multiplied by this value / v-prediction lossをこの値をかけたものをlossに加算する",
119
+ )
120
+ parser.add_argument(
121
+ "--debiased_estimation_loss",
122
+ action="store_true",
123
+ help="debiased estimation loss / debiased estimation loss",
124
+ )
125
+ if support_weighted_captions:
126
+ parser.add_argument(
127
+ "--weighted_captions",
128
+ action="store_true",
129
+ default=False,
130
+ help="Enable weighted captions in the standard style (token:1.3). No commas inside parens, or shuffle/dropout may break the decoder. / 「[token]」、「(token)」「(token:1.3)」のような重み付きキャプションを有効にする。カンマを括弧内に入れるとシャッフルやdropoutで重みづけがおかしくなるので注意",
131
+ )
132
+
133
+
134
+ re_attention = re.compile(
135
+ r"""
136
+ \\\(|
137
+ \\\)|
138
+ \\\[|
139
+ \\]|
140
+ \\\\|
141
+ \\|
142
+ \(|
143
+ \[|
144
+ :([+-]?[.\d]+)\)|
145
+ \)|
146
+ ]|
147
+ [^\\()\[\]:]+|
148
+ :
149
+ """,
150
+ re.X,
151
+ )
152
+
153
+
154
+ def parse_prompt_attention(text):
155
+ """
156
+ Parses a string with attention tokens and returns a list of pairs: text and its associated weight.
157
+ Accepted tokens are:
158
+ (abc) - increases attention to abc by a multiplier of 1.1
159
+ (abc:3.12) - increases attention to abc by a multiplier of 3.12
160
+ [abc] - decreases attention to abc by a multiplier of 1.1
161
+ \( - literal character '('
162
+ \[ - literal character '['
163
+ \) - literal character ')'
164
+ \] - literal character ']'
165
+ \\ - literal character '\'
166
+ anything else - just text
167
+ >>> parse_prompt_attention('normal text')
168
+ [['normal text', 1.0]]
169
+ >>> parse_prompt_attention('an (important) word')
170
+ [['an ', 1.0], ['important', 1.1], [' word', 1.0]]
171
+ >>> parse_prompt_attention('(unbalanced')
172
+ [['unbalanced', 1.1]]
173
+ >>> parse_prompt_attention('\(literal\]')
174
+ [['(literal]', 1.0]]
175
+ >>> parse_prompt_attention('(unnecessary)(parens)')
176
+ [['unnecessaryparens', 1.1]]
177
+ >>> parse_prompt_attention('a (((house:1.3)) [on] a (hill:0.5), sun, (((sky))).')
178
+ [['a ', 1.0],
179
+ ['house', 1.5730000000000004],
180
+ [' ', 1.1],
181
+ ['on', 1.0],
182
+ [' a ', 1.1],
183
+ ['hill', 0.55],
184
+ [', sun, ', 1.1],
185
+ ['sky', 1.4641000000000006],
186
+ ['.', 1.1]]
187
+ """
188
+
189
+ res = []
190
+ round_brackets = []
191
+ square_brackets = []
192
+
193
+ round_bracket_multiplier = 1.1
194
+ square_bracket_multiplier = 1 / 1.1
195
+
196
+ def multiply_range(start_position, multiplier):
197
+ for p in range(start_position, len(res)):
198
+ res[p][1] *= multiplier
199
+
200
+ for m in re_attention.finditer(text):
201
+ text = m.group(0)
202
+ weight = m.group(1)
203
+
204
+ if text.startswith("\\"):
205
+ res.append([text[1:], 1.0])
206
+ elif text == "(":
207
+ round_brackets.append(len(res))
208
+ elif text == "[":
209
+ square_brackets.append(len(res))
210
+ elif weight is not None and len(round_brackets) > 0:
211
+ multiply_range(round_brackets.pop(), float(weight))
212
+ elif text == ")" and len(round_brackets) > 0:
213
+ multiply_range(round_brackets.pop(), round_bracket_multiplier)
214
+ elif text == "]" and len(square_brackets) > 0:
215
+ multiply_range(square_brackets.pop(), square_bracket_multiplier)
216
+ else:
217
+ res.append([text, 1.0])
218
+
219
+ for pos in round_brackets:
220
+ multiply_range(pos, round_bracket_multiplier)
221
+
222
+ for pos in square_brackets:
223
+ multiply_range(pos, square_bracket_multiplier)
224
+
225
+ if len(res) == 0:
226
+ res = [["", 1.0]]
227
+
228
+ # merge runs of identical weights
229
+ i = 0
230
+ while i + 1 < len(res):
231
+ if res[i][1] == res[i + 1][1]:
232
+ res[i][0] += res[i + 1][0]
233
+ res.pop(i + 1)
234
+ else:
235
+ i += 1
236
+
237
+ return res
238
+
239
+
240
+ def get_prompts_with_weights(tokenizer, prompt: List[str], max_length: int):
241
+ r"""
242
+ Tokenize a list of prompts and return its tokens with weights of each token.
243
+
244
+ No padding, starting or ending token is included.
245
+ """
246
+ tokens = []
247
+ weights = []
248
+ truncated = False
249
+ for text in prompt:
250
+ texts_and_weights = parse_prompt_attention(text)
251
+ text_token = []
252
+ text_weight = []
253
+ for word, weight in texts_and_weights:
254
+ # tokenize and discard the starting and the ending token
255
+ token = tokenizer(word).input_ids[1:-1]
256
+ text_token += token
257
+ # copy the weight by length of token
258
+ text_weight += [weight] * len(token)
259
+ # stop if the text is too long (longer than truncation limit)
260
+ if len(text_token) > max_length:
261
+ truncated = True
262
+ break
263
+ # truncate
264
+ if len(text_token) > max_length:
265
+ truncated = True
266
+ text_token = text_token[:max_length]
267
+ text_weight = text_weight[:max_length]
268
+ tokens.append(text_token)
269
+ weights.append(text_weight)
270
+ if truncated:
271
+ print("Prompt was truncated. Try to shorten the prompt or increase max_embeddings_multiples")
272
+ return tokens, weights
273
+
274
+
275
+ def pad_tokens_and_weights(tokens, weights, max_length, bos, eos, no_boseos_middle=True, chunk_length=77):
276
+ r"""
277
+ Pad the tokens (with starting and ending tokens) and weights (with 1.0) to max_length.
278
+ """
279
+ max_embeddings_multiples = (max_length - 2) // (chunk_length - 2)
280
+ weights_length = max_length if no_boseos_middle else max_embeddings_multiples * chunk_length
281
+ for i in range(len(tokens)):
282
+ tokens[i] = [bos] + tokens[i] + [eos] * (max_length - 1 - len(tokens[i]))
283
+ if no_boseos_middle:
284
+ weights[i] = [1.0] + weights[i] + [1.0] * (max_length - 1 - len(weights[i]))
285
+ else:
286
+ w = []
287
+ if len(weights[i]) == 0:
288
+ w = [1.0] * weights_length
289
+ else:
290
+ for j in range(max_embeddings_multiples):
291
+ w.append(1.0) # weight for starting token in this chunk
292
+ w += weights[i][j * (chunk_length - 2) : min(len(weights[i]), (j + 1) * (chunk_length - 2))]
293
+ w.append(1.0) # weight for ending token in this chunk
294
+ w += [1.0] * (weights_length - len(w))
295
+ weights[i] = w[:]
296
+
297
+ return tokens, weights
298
+
299
+
300
+ def get_unweighted_text_embeddings(
301
+ tokenizer,
302
+ text_encoder,
303
+ text_input: torch.Tensor,
304
+ chunk_length: int,
305
+ clip_skip: int,
306
+ eos: int,
307
+ pad: int,
308
+ no_boseos_middle: Optional[bool] = True,
309
+ ):
310
+ """
311
+ When the length of tokens is a multiple of the capacity of the text encoder,
312
+ it should be split into chunks and sent to the text encoder individually.
313
+ """
314
+ max_embeddings_multiples = (text_input.shape[1] - 2) // (chunk_length - 2)
315
+ if max_embeddings_multiples > 1:
316
+ text_embeddings = []
317
+ for i in range(max_embeddings_multiples):
318
+ # extract the i-th chunk
319
+ text_input_chunk = text_input[:, i * (chunk_length - 2) : (i + 1) * (chunk_length - 2) + 2].clone()
320
+
321
+ # cover the head and the tail by the starting and the ending tokens
322
+ text_input_chunk[:, 0] = text_input[0, 0]
323
+ if pad == eos: # v1
324
+ text_input_chunk[:, -1] = text_input[0, -1]
325
+ else: # v2
326
+ for j in range(len(text_input_chunk)):
327
+ if text_input_chunk[j, -1] != eos and text_input_chunk[j, -1] != pad: # 最後に普通の文字がある
328
+ text_input_chunk[j, -1] = eos
329
+ if text_input_chunk[j, 1] == pad: # BOSだけであとはPAD
330
+ text_input_chunk[j, 1] = eos
331
+
332
+ if clip_skip is None or clip_skip == 1:
333
+ text_embedding = text_encoder(text_input_chunk)[0]
334
+ else:
335
+ enc_out = text_encoder(text_input_chunk, output_hidden_states=True, return_dict=True)
336
+ text_embedding = enc_out["hidden_states"][-clip_skip]
337
+ text_embedding = text_encoder.text_model.final_layer_norm(text_embedding)
338
+
339
+ if no_boseos_middle:
340
+ if i == 0:
341
+ # discard the ending token
342
+ text_embedding = text_embedding[:, :-1]
343
+ elif i == max_embeddings_multiples - 1:
344
+ # discard the starting token
345
+ text_embedding = text_embedding[:, 1:]
346
+ else:
347
+ # discard both starting and ending tokens
348
+ text_embedding = text_embedding[:, 1:-1]
349
+
350
+ text_embeddings.append(text_embedding)
351
+ text_embeddings = torch.concat(text_embeddings, axis=1)
352
+ else:
353
+ if clip_skip is None or clip_skip == 1:
354
+ text_embeddings = text_encoder(text_input)[0]
355
+ else:
356
+ enc_out = text_encoder(text_input, output_hidden_states=True, return_dict=True)
357
+ text_embeddings = enc_out["hidden_states"][-clip_skip]
358
+ text_embeddings = text_encoder.text_model.final_layer_norm(text_embeddings)
359
+ return text_embeddings
360
+
361
+
362
+ def get_weighted_text_embeddings(
363
+ tokenizer,
364
+ text_encoder,
365
+ prompt: Union[str, List[str]],
366
+ device,
367
+ max_embeddings_multiples: Optional[int] = 3,
368
+ no_boseos_middle: Optional[bool] = False,
369
+ clip_skip=None,
370
+ ):
371
+ r"""
372
+ Prompts can be assigned with local weights using brackets. For example,
373
+ prompt 'A (very beautiful) masterpiece' highlights the words 'very beautiful',
374
+ and the embedding tokens corresponding to the words get multiplied by a constant, 1.1.
375
+
376
+ Also, to regularize of the embedding, the weighted embedding would be scaled to preserve the original mean.
377
+
378
+ Args:
379
+ prompt (`str` or `List[str]`):
380
+ The prompt or prompts to guide the image generation.
381
+ max_embeddings_multiples (`int`, *optional*, defaults to `3`):
382
+ The max multiple length of prompt embeddings compared to the max output length of text encoder.
383
+ no_boseos_middle (`bool`, *optional*, defaults to `False`):
384
+ If the length of text token is multiples of the capacity of text encoder, whether reserve the starting and
385
+ ending token in each of the chunk in the middle.
386
+ skip_parsing (`bool`, *optional*, defaults to `False`):
387
+ Skip the parsing of brackets.
388
+ skip_weighting (`bool`, *optional*, defaults to `False`):
389
+ Skip the weighting. When the parsing is skipped, it is forced True.
390
+ """
391
+ max_length = (tokenizer.model_max_length - 2) * max_embeddings_multiples + 2
392
+ if isinstance(prompt, str):
393
+ prompt = [prompt]
394
+
395
+ prompt_tokens, prompt_weights = get_prompts_with_weights(tokenizer, prompt, max_length - 2)
396
+
397
+ # round up the longest length of tokens to a multiple of (model_max_length - 2)
398
+ max_length = max([len(token) for token in prompt_tokens])
399
+
400
+ max_embeddings_multiples = min(
401
+ max_embeddings_multiples,
402
+ (max_length - 1) // (tokenizer.model_max_length - 2) + 1,
403
+ )
404
+ max_embeddings_multiples = max(1, max_embeddings_multiples)
405
+ max_length = (tokenizer.model_max_length - 2) * max_embeddings_multiples + 2
406
+
407
+ # pad the length of tokens and weights
408
+ bos = tokenizer.bos_token_id
409
+ eos = tokenizer.eos_token_id
410
+ pad = tokenizer.pad_token_id
411
+ prompt_tokens, prompt_weights = pad_tokens_and_weights(
412
+ prompt_tokens,
413
+ prompt_weights,
414
+ max_length,
415
+ bos,
416
+ eos,
417
+ no_boseos_middle=no_boseos_middle,
418
+ chunk_length=tokenizer.model_max_length,
419
+ )
420
+ prompt_tokens = torch.tensor(prompt_tokens, dtype=torch.long, device=device)
421
+
422
+ # get the embeddings
423
+ text_embeddings = get_unweighted_text_embeddings(
424
+ tokenizer,
425
+ text_encoder,
426
+ prompt_tokens,
427
+ tokenizer.model_max_length,
428
+ clip_skip,
429
+ eos,
430
+ pad,
431
+ no_boseos_middle=no_boseos_middle,
432
+ )
433
+ prompt_weights = torch.tensor(prompt_weights, dtype=text_embeddings.dtype, device=device)
434
+
435
+ # assign weights to the prompts and normalize in the sense of mean
436
+ previous_mean = text_embeddings.float().mean(axis=[-2, -1]).to(text_embeddings.dtype)
437
+ text_embeddings = text_embeddings * prompt_weights.unsqueeze(-1)
438
+ current_mean = text_embeddings.float().mean(axis=[-2, -1]).to(text_embeddings.dtype)
439
+ text_embeddings = text_embeddings * (previous_mean / current_mean).unsqueeze(-1).unsqueeze(-1)
440
+
441
+ return text_embeddings
442
+
443
+
444
+ # https://wandb.ai/johnowhitaker/multires_noise/reports/Multi-Resolution-Noise-for-Diffusion-Model-Training--VmlldzozNjYyOTU2
445
+ def pyramid_noise_like(noise, device, iterations=6, discount=0.4):
446
+ b, c, w, h = noise.shape # EDIT: w and h get over-written, rename for a different variant!
447
+ u = torch.nn.Upsample(size=(w, h), mode="bilinear").to(device)
448
+ for i in range(iterations):
449
+ r = random.random() * 2 + 2 # Rather than always going 2x,
450
+ wn, hn = max(1, int(w / (r**i))), max(1, int(h / (r**i)))
451
+ noise += u(torch.randn(b, c, wn, hn).to(device)) * discount**i
452
+ if wn == 1 or hn == 1:
453
+ break # Lowest resolution is 1x1
454
+ return noise / noise.std() # Scaled back to roughly unit variance
455
+
456
+
457
+ # https://www.crosslabs.org//blog/diffusion-with-offset-noise
458
+ def apply_noise_offset(latents, noise, noise_offset, adaptive_noise_scale):
459
+ if noise_offset is None:
460
+ return noise
461
+ if adaptive_noise_scale is not None:
462
+ # latent shape: (batch_size, channels, height, width)
463
+ # abs mean value for each channel
464
+ latent_mean = torch.abs(latents.mean(dim=(2, 3), keepdim=True))
465
+
466
+ # multiply adaptive noise scale to the mean value and add it to the noise offset
467
+ noise_offset = noise_offset + adaptive_noise_scale * latent_mean
468
+ noise_offset = torch.clamp(noise_offset, 0.0, None) # in case of adaptive noise scale is negative
469
+
470
+ noise = noise + noise_offset * torch.randn((latents.shape[0], latents.shape[1], 1, 1), device=latents.device)
471
+ return noise
472
+
473
+
474
+ """
475
+ ##########################################
476
+ # Perlin Noise
477
+ def rand_perlin_2d(device, shape, res, fade=lambda t: 6 * t**5 - 15 * t**4 + 10 * t**3):
478
+ delta = (res[0] / shape[0], res[1] / shape[1])
479
+ d = (shape[0] // res[0], shape[1] // res[1])
480
+
481
+ grid = (
482
+ torch.stack(
483
+ torch.meshgrid(torch.arange(0, res[0], delta[0], device=device), torch.arange(0, res[1], delta[1], device=device)),
484
+ dim=-1,
485
+ )
486
+ % 1
487
+ )
488
+ angles = 2 * torch.pi * torch.rand(res[0] + 1, res[1] + 1, device=device)
489
+ gradients = torch.stack((torch.cos(angles), torch.sin(angles)), dim=-1)
490
+
491
+ tile_grads = (
492
+ lambda slice1, slice2: gradients[slice1[0] : slice1[1], slice2[0] : slice2[1]]
493
+ .repeat_interleave(d[0], 0)
494
+ .repeat_interleave(d[1], 1)
495
+ )
496
+ dot = lambda grad, shift: (
497
+ torch.stack((grid[: shape[0], : shape[1], 0] + shift[0], grid[: shape[0], : shape[1], 1] + shift[1]), dim=-1)
498
+ * grad[: shape[0], : shape[1]]
499
+ ).sum(dim=-1)
500
+
501
+ n00 = dot(tile_grads([0, -1], [0, -1]), [0, 0])
502
+ n10 = dot(tile_grads([1, None], [0, -1]), [-1, 0])
503
+ n01 = dot(tile_grads([0, -1], [1, None]), [0, -1])
504
+ n11 = dot(tile_grads([1, None], [1, None]), [-1, -1])
505
+ t = fade(grid[: shape[0], : shape[1]])
506
+ return 1.414 * torch.lerp(torch.lerp(n00, n10, t[..., 0]), torch.lerp(n01, n11, t[..., 0]), t[..., 1])
507
+
508
+
509
+ def rand_perlin_2d_octaves(device, shape, res, octaves=1, persistence=0.5):
510
+ noise = torch.zeros(shape, device=device)
511
+ frequency = 1
512
+ amplitude = 1
513
+ for _ in range(octaves):
514
+ noise += amplitude * rand_perlin_2d(device, shape, (frequency * res[0], frequency * res[1]))
515
+ frequency *= 2
516
+ amplitude *= persistence
517
+ return noise
518
+
519
+
520
+ def perlin_noise(noise, device, octaves):
521
+ _, c, w, h = noise.shape
522
+ perlin = lambda: rand_perlin_2d_octaves(device, (w, h), (4, 4), octaves)
523
+ noise_perlin = []
524
+ for _ in range(c):
525
+ noise_perlin.append(perlin())
526
+ noise_perlin = torch.stack(noise_perlin).unsqueeze(0) # (1, c, w, h)
527
+ noise += noise_perlin # broadcast for each batch
528
+ return noise / noise.std() # Scaled back to roughly unit variance
529
+ """
external/llite/library/huggingface_util.py ADDED
@@ -0,0 +1,81 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Union, BinaryIO
2
+ from huggingface_hub import HfApi
3
+ from pathlib import Path
4
+ import argparse
5
+ import os
6
+ from external.llite.library.utils import fire_in_thread
7
+
8
+
9
+ def exists_repo(repo_id: str, repo_type: str, revision: str = "main", token: str = None):
10
+ api = HfApi(
11
+ token=token,
12
+ )
13
+ try:
14
+ api.repo_info(repo_id=repo_id, revision=revision, repo_type=repo_type)
15
+ return True
16
+ except:
17
+ return False
18
+
19
+
20
+ def upload(
21
+ args: argparse.Namespace,
22
+ src: Union[str, Path, bytes, BinaryIO],
23
+ dest_suffix: str = "",
24
+ force_sync_upload: bool = False,
25
+ ):
26
+ repo_id = args.huggingface_repo_id
27
+ repo_type = args.huggingface_repo_type
28
+ token = args.huggingface_token
29
+ path_in_repo = args.huggingface_path_in_repo + dest_suffix if args.huggingface_path_in_repo is not None else None
30
+ private = args.huggingface_repo_visibility is None or args.huggingface_repo_visibility != "public"
31
+ api = HfApi(token=token)
32
+ if not exists_repo(repo_id=repo_id, repo_type=repo_type, token=token):
33
+ try:
34
+ api.create_repo(repo_id=repo_id, repo_type=repo_type, private=private)
35
+ except Exception as e: # とりあえずRepositoryNotFoundErrorは確認したが他にあると困るので
36
+ print("===========================================")
37
+ print(f"failed to create HuggingFace repo / HuggingFaceのリポジトリの作成に失敗しました : {e}")
38
+ print("===========================================")
39
+
40
+ is_folder = (type(src) == str and os.path.isdir(src)) or (isinstance(src, Path) and src.is_dir())
41
+
42
+ def uploader():
43
+ try:
44
+ if is_folder:
45
+ api.upload_folder(
46
+ repo_id=repo_id,
47
+ repo_type=repo_type,
48
+ folder_path=src,
49
+ path_in_repo=path_in_repo,
50
+ )
51
+ else:
52
+ api.upload_file(
53
+ repo_id=repo_id,
54
+ repo_type=repo_type,
55
+ path_or_fileobj=src,
56
+ path_in_repo=path_in_repo,
57
+ )
58
+ except Exception as e: # RuntimeErrorを確認済みだが他にあると困るので
59
+ print("===========================================")
60
+ print(f"failed to upload to HuggingFace / HuggingFaceへのアップロードに失敗しました : {e}")
61
+ print("===========================================")
62
+
63
+ if args.async_upload and not force_sync_upload:
64
+ fire_in_thread(uploader)
65
+ else:
66
+ uploader()
67
+
68
+
69
+ def list_dir(
70
+ repo_id: str,
71
+ subfolder: str,
72
+ repo_type: str,
73
+ revision: str = "main",
74
+ token: str = None,
75
+ ):
76
+ api = HfApi(
77
+ token=token,
78
+ )
79
+ repo_info = api.repo_info(repo_id=repo_id, revision=revision, repo_type=repo_type)
80
+ file_list = [file for file in repo_info.siblings if file.rfilename.startswith(subfolder)]
81
+ return file_list
external/llite/library/hypernetwork.py ADDED
@@ -0,0 +1,223 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn.functional as F
3
+ from diffusers.models.attention_processor import (
4
+ Attention,
5
+ AttnProcessor2_0,
6
+ SlicedAttnProcessor,
7
+ XFormersAttnProcessor
8
+ )
9
+
10
+ try:
11
+ import xformers.ops
12
+ except:
13
+ xformers = None
14
+
15
+
16
+ loaded_networks = []
17
+
18
+
19
+ def apply_single_hypernetwork(
20
+ hypernetwork, hidden_states, encoder_hidden_states
21
+ ):
22
+ context_k, context_v = hypernetwork.forward(hidden_states, encoder_hidden_states)
23
+ return context_k, context_v
24
+
25
+
26
+ def apply_hypernetworks(context_k, context_v, layer=None):
27
+ if len(loaded_networks) == 0:
28
+ return context_v, context_v
29
+ for hypernetwork in loaded_networks:
30
+ context_k, context_v = hypernetwork.forward(context_k, context_v)
31
+
32
+ context_k = context_k.to(dtype=context_k.dtype)
33
+ context_v = context_v.to(dtype=context_k.dtype)
34
+
35
+ return context_k, context_v
36
+
37
+
38
+
39
+ def xformers_forward(
40
+ self: XFormersAttnProcessor,
41
+ attn: Attention,
42
+ hidden_states: torch.Tensor,
43
+ encoder_hidden_states: torch.Tensor = None,
44
+ attention_mask: torch.Tensor = None,
45
+ ):
46
+ batch_size, sequence_length, _ = (
47
+ hidden_states.shape
48
+ if encoder_hidden_states is None
49
+ else encoder_hidden_states.shape
50
+ )
51
+
52
+ attention_mask = attn.prepare_attention_mask(
53
+ attention_mask, sequence_length, batch_size
54
+ )
55
+
56
+ query = attn.to_q(hidden_states)
57
+
58
+ if encoder_hidden_states is None:
59
+ encoder_hidden_states = hidden_states
60
+ elif attn.norm_cross:
61
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
62
+
63
+ context_k, context_v = apply_hypernetworks(hidden_states, encoder_hidden_states)
64
+
65
+ key = attn.to_k(context_k)
66
+ value = attn.to_v(context_v)
67
+
68
+ query = attn.head_to_batch_dim(query).contiguous()
69
+ key = attn.head_to_batch_dim(key).contiguous()
70
+ value = attn.head_to_batch_dim(value).contiguous()
71
+
72
+ hidden_states = xformers.ops.memory_efficient_attention(
73
+ query,
74
+ key,
75
+ value,
76
+ attn_bias=attention_mask,
77
+ op=self.attention_op,
78
+ scale=attn.scale,
79
+ )
80
+ hidden_states = hidden_states.to(query.dtype)
81
+ hidden_states = attn.batch_to_head_dim(hidden_states)
82
+
83
+ # linear proj
84
+ hidden_states = attn.to_out[0](hidden_states)
85
+ # dropout
86
+ hidden_states = attn.to_out[1](hidden_states)
87
+ return hidden_states
88
+
89
+
90
+ def sliced_attn_forward(
91
+ self: SlicedAttnProcessor,
92
+ attn: Attention,
93
+ hidden_states: torch.Tensor,
94
+ encoder_hidden_states: torch.Tensor = None,
95
+ attention_mask: torch.Tensor = None,
96
+ ):
97
+ batch_size, sequence_length, _ = (
98
+ hidden_states.shape
99
+ if encoder_hidden_states is None
100
+ else encoder_hidden_states.shape
101
+ )
102
+ attention_mask = attn.prepare_attention_mask(
103
+ attention_mask, sequence_length, batch_size
104
+ )
105
+
106
+ query = attn.to_q(hidden_states)
107
+ dim = query.shape[-1]
108
+ query = attn.head_to_batch_dim(query)
109
+
110
+ if encoder_hidden_states is None:
111
+ encoder_hidden_states = hidden_states
112
+ elif attn.norm_cross:
113
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
114
+
115
+ context_k, context_v = apply_hypernetworks(hidden_states, encoder_hidden_states)
116
+
117
+ key = attn.to_k(context_k)
118
+ value = attn.to_v(context_v)
119
+ key = attn.head_to_batch_dim(key)
120
+ value = attn.head_to_batch_dim(value)
121
+
122
+ batch_size_attention, query_tokens, _ = query.shape
123
+ hidden_states = torch.zeros(
124
+ (batch_size_attention, query_tokens, dim // attn.heads),
125
+ device=query.device,
126
+ dtype=query.dtype,
127
+ )
128
+
129
+ for i in range(batch_size_attention // self.slice_size):
130
+ start_idx = i * self.slice_size
131
+ end_idx = (i + 1) * self.slice_size
132
+
133
+ query_slice = query[start_idx:end_idx]
134
+ key_slice = key[start_idx:end_idx]
135
+ attn_mask_slice = (
136
+ attention_mask[start_idx:end_idx] if attention_mask is not None else None
137
+ )
138
+
139
+ attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice)
140
+
141
+ attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx])
142
+
143
+ hidden_states[start_idx:end_idx] = attn_slice
144
+
145
+ hidden_states = attn.batch_to_head_dim(hidden_states)
146
+
147
+ # linear proj
148
+ hidden_states = attn.to_out[0](hidden_states)
149
+ # dropout
150
+ hidden_states = attn.to_out[1](hidden_states)
151
+
152
+ return hidden_states
153
+
154
+
155
+ def v2_0_forward(
156
+ self: AttnProcessor2_0,
157
+ attn: Attention,
158
+ hidden_states,
159
+ encoder_hidden_states=None,
160
+ attention_mask=None,
161
+ ):
162
+ batch_size, sequence_length, _ = (
163
+ hidden_states.shape
164
+ if encoder_hidden_states is None
165
+ else encoder_hidden_states.shape
166
+ )
167
+ inner_dim = hidden_states.shape[-1]
168
+
169
+ if attention_mask is not None:
170
+ attention_mask = attn.prepare_attention_mask(
171
+ attention_mask, sequence_length, batch_size
172
+ )
173
+ # scaled_dot_product_attention expects attention_mask shape to be
174
+ # (batch, heads, source_length, target_length)
175
+ attention_mask = attention_mask.view(
176
+ batch_size, attn.heads, -1, attention_mask.shape[-1]
177
+ )
178
+
179
+ query = attn.to_q(hidden_states)
180
+
181
+ if encoder_hidden_states is None:
182
+ encoder_hidden_states = hidden_states
183
+ elif attn.norm_cross:
184
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
185
+
186
+ context_k, context_v = apply_hypernetworks(hidden_states, encoder_hidden_states)
187
+
188
+ key = attn.to_k(context_k)
189
+ value = attn.to_v(context_v)
190
+
191
+ head_dim = inner_dim // attn.heads
192
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
193
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
194
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
195
+
196
+ # the output of sdp = (batch, num_heads, seq_len, head_dim)
197
+ # TODO: add support for attn.scale when we move to Torch 2.1
198
+ hidden_states = F.scaled_dot_product_attention(
199
+ query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
200
+ )
201
+
202
+ hidden_states = hidden_states.transpose(1, 2).reshape(
203
+ batch_size, -1, attn.heads * head_dim
204
+ )
205
+ hidden_states = hidden_states.to(query.dtype)
206
+
207
+ # linear proj
208
+ hidden_states = attn.to_out[0](hidden_states)
209
+ # dropout
210
+ hidden_states = attn.to_out[1](hidden_states)
211
+ return hidden_states
212
+
213
+
214
+ def replace_attentions_for_hypernetwork():
215
+ import diffusers.models.attention_processor
216
+
217
+ diffusers.models.attention_processor.XFormersAttnProcessor.__call__ = (
218
+ xformers_forward
219
+ )
220
+ diffusers.models.attention_processor.SlicedAttnProcessor.__call__ = (
221
+ sliced_attn_forward
222
+ )
223
+ diffusers.models.attention_processor.AttnProcessor2_0.__call__ = v2_0_forward
external/llite/library/ipex/__init__.py ADDED
@@ -0,0 +1,169 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import sys
3
+ import contextlib
4
+ import torch
5
+ import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import
6
+ from .hijacks import ipex_hijacks
7
+
8
+ # pylint: disable=protected-access, missing-function-docstring, line-too-long
9
+
10
+ def ipex_init(): # pylint: disable=too-many-statements
11
+ try:
12
+ # Replace cuda with xpu:
13
+ torch.cuda.current_device = torch.xpu.current_device
14
+ torch.cuda.current_stream = torch.xpu.current_stream
15
+ torch.cuda.device = torch.xpu.device
16
+ torch.cuda.device_count = torch.xpu.device_count
17
+ torch.cuda.device_of = torch.xpu.device_of
18
+ torch.cuda.get_device_name = torch.xpu.get_device_name
19
+ torch.cuda.get_device_properties = torch.xpu.get_device_properties
20
+ torch.cuda.init = torch.xpu.init
21
+ torch.cuda.is_available = torch.xpu.is_available
22
+ torch.cuda.is_initialized = torch.xpu.is_initialized
23
+ torch.cuda.is_current_stream_capturing = lambda: False
24
+ torch.cuda.set_device = torch.xpu.set_device
25
+ torch.cuda.stream = torch.xpu.stream
26
+ torch.cuda.synchronize = torch.xpu.synchronize
27
+ torch.cuda.Event = torch.xpu.Event
28
+ torch.cuda.Stream = torch.xpu.Stream
29
+ torch.cuda.FloatTensor = torch.xpu.FloatTensor
30
+ torch.Tensor.cuda = torch.Tensor.xpu
31
+ torch.Tensor.is_cuda = torch.Tensor.is_xpu
32
+ torch.UntypedStorage.cuda = torch.UntypedStorage.xpu
33
+ torch.cuda._initialization_lock = torch.xpu.lazy_init._initialization_lock
34
+ torch.cuda._initialized = torch.xpu.lazy_init._initialized
35
+ torch.cuda._lazy_seed_tracker = torch.xpu.lazy_init._lazy_seed_tracker
36
+ torch.cuda._queued_calls = torch.xpu.lazy_init._queued_calls
37
+ torch.cuda._tls = torch.xpu.lazy_init._tls
38
+ torch.cuda.threading = torch.xpu.lazy_init.threading
39
+ torch.cuda.traceback = torch.xpu.lazy_init.traceback
40
+ torch.cuda.Optional = torch.xpu.Optional
41
+ torch.cuda.__cached__ = torch.xpu.__cached__
42
+ torch.cuda.__loader__ = torch.xpu.__loader__
43
+ torch.cuda.ComplexFloatStorage = torch.xpu.ComplexFloatStorage
44
+ torch.cuda.Tuple = torch.xpu.Tuple
45
+ torch.cuda.streams = torch.xpu.streams
46
+ torch.cuda._lazy_new = torch.xpu._lazy_new
47
+ torch.cuda.FloatStorage = torch.xpu.FloatStorage
48
+ torch.cuda.Any = torch.xpu.Any
49
+ torch.cuda.__doc__ = torch.xpu.__doc__
50
+ torch.cuda.default_generators = torch.xpu.default_generators
51
+ torch.cuda.HalfTensor = torch.xpu.HalfTensor
52
+ torch.cuda._get_device_index = torch.xpu._get_device_index
53
+ torch.cuda.__path__ = torch.xpu.__path__
54
+ torch.cuda.Device = torch.xpu.Device
55
+ torch.cuda.IntTensor = torch.xpu.IntTensor
56
+ torch.cuda.ByteStorage = torch.xpu.ByteStorage
57
+ torch.cuda.set_stream = torch.xpu.set_stream
58
+ torch.cuda.BoolStorage = torch.xpu.BoolStorage
59
+ torch.cuda.os = torch.xpu.os
60
+ torch.cuda.torch = torch.xpu.torch
61
+ torch.cuda.BFloat16Storage = torch.xpu.BFloat16Storage
62
+ torch.cuda.Union = torch.xpu.Union
63
+ torch.cuda.DoubleTensor = torch.xpu.DoubleTensor
64
+ torch.cuda.ShortTensor = torch.xpu.ShortTensor
65
+ torch.cuda.LongTensor = torch.xpu.LongTensor
66
+ torch.cuda.IntStorage = torch.xpu.IntStorage
67
+ torch.cuda.LongStorage = torch.xpu.LongStorage
68
+ torch.cuda.__annotations__ = torch.xpu.__annotations__
69
+ torch.cuda.__package__ = torch.xpu.__package__
70
+ torch.cuda.__builtins__ = torch.xpu.__builtins__
71
+ torch.cuda.CharTensor = torch.xpu.CharTensor
72
+ torch.cuda.List = torch.xpu.List
73
+ torch.cuda._lazy_init = torch.xpu._lazy_init
74
+ torch.cuda.BFloat16Tensor = torch.xpu.BFloat16Tensor
75
+ torch.cuda.DoubleStorage = torch.xpu.DoubleStorage
76
+ torch.cuda.ByteTensor = torch.xpu.ByteTensor
77
+ torch.cuda.StreamContext = torch.xpu.StreamContext
78
+ torch.cuda.ComplexDoubleStorage = torch.xpu.ComplexDoubleStorage
79
+ torch.cuda.ShortStorage = torch.xpu.ShortStorage
80
+ torch.cuda._lazy_call = torch.xpu._lazy_call
81
+ torch.cuda.HalfStorage = torch.xpu.HalfStorage
82
+ torch.cuda.random = torch.xpu.random
83
+ torch.cuda._device = torch.xpu._device
84
+ torch.cuda.classproperty = torch.xpu.classproperty
85
+ torch.cuda.__name__ = torch.xpu.__name__
86
+ torch.cuda._device_t = torch.xpu._device_t
87
+ torch.cuda.warnings = torch.xpu.warnings
88
+ torch.cuda.__spec__ = torch.xpu.__spec__
89
+ torch.cuda.BoolTensor = torch.xpu.BoolTensor
90
+ torch.cuda.CharStorage = torch.xpu.CharStorage
91
+ torch.cuda.__file__ = torch.xpu.__file__
92
+ torch.cuda._is_in_bad_fork = torch.xpu.lazy_init._is_in_bad_fork
93
+ # torch.cuda.is_current_stream_capturing = torch.xpu.is_current_stream_capturing
94
+
95
+ # Memory:
96
+ torch.cuda.memory = torch.xpu.memory
97
+ if 'linux' in sys.platform and "WSL2" in os.popen("uname -a").read():
98
+ torch.xpu.empty_cache = lambda: None
99
+ torch.cuda.empty_cache = torch.xpu.empty_cache
100
+ torch.cuda.memory_stats = torch.xpu.memory_stats
101
+ torch.cuda.memory_summary = torch.xpu.memory_summary
102
+ torch.cuda.memory_snapshot = torch.xpu.memory_snapshot
103
+ torch.cuda.memory_allocated = torch.xpu.memory_allocated
104
+ torch.cuda.max_memory_allocated = torch.xpu.max_memory_allocated
105
+ torch.cuda.memory_reserved = torch.xpu.memory_reserved
106
+ torch.cuda.memory_cached = torch.xpu.memory_reserved
107
+ torch.cuda.max_memory_reserved = torch.xpu.max_memory_reserved
108
+ torch.cuda.max_memory_cached = torch.xpu.max_memory_reserved
109
+ torch.cuda.reset_peak_memory_stats = torch.xpu.reset_peak_memory_stats
110
+ torch.cuda.reset_max_memory_cached = torch.xpu.reset_peak_memory_stats
111
+ torch.cuda.reset_max_memory_allocated = torch.xpu.reset_peak_memory_stats
112
+ torch.cuda.memory_stats_as_nested_dict = torch.xpu.memory_stats_as_nested_dict
113
+ torch.cuda.reset_accumulated_memory_stats = torch.xpu.reset_accumulated_memory_stats
114
+
115
+ # RNG:
116
+ torch.cuda.get_rng_state = torch.xpu.get_rng_state
117
+ torch.cuda.get_rng_state_all = torch.xpu.get_rng_state_all
118
+ torch.cuda.set_rng_state = torch.xpu.set_rng_state
119
+ torch.cuda.set_rng_state_all = torch.xpu.set_rng_state_all
120
+ torch.cuda.manual_seed = torch.xpu.manual_seed
121
+ torch.cuda.manual_seed_all = torch.xpu.manual_seed_all
122
+ torch.cuda.seed = torch.xpu.seed
123
+ torch.cuda.seed_all = torch.xpu.seed_all
124
+ torch.cuda.initial_seed = torch.xpu.initial_seed
125
+
126
+ # AMP:
127
+ torch.cuda.amp = torch.xpu.amp
128
+ if not hasattr(torch.cuda.amp, "common"):
129
+ torch.cuda.amp.common = contextlib.nullcontext()
130
+ torch.cuda.amp.common.amp_definitely_not_available = lambda: False
131
+ try:
132
+ torch.cuda.amp.GradScaler = torch.xpu.amp.GradScaler
133
+ except Exception: # pylint: disable=broad-exception-caught
134
+ try:
135
+ from .gradscaler import gradscaler_init # pylint: disable=import-outside-toplevel, import-error
136
+ gradscaler_init()
137
+ torch.cuda.amp.GradScaler = torch.xpu.amp.GradScaler
138
+ except Exception: # pylint: disable=broad-exception-caught
139
+ torch.cuda.amp.GradScaler = ipex.cpu.autocast._grad_scaler.GradScaler
140
+
141
+ # C
142
+ torch._C._cuda_getCurrentRawStream = ipex._C._getCurrentStream
143
+ ipex._C._DeviceProperties.major = 2023
144
+ ipex._C._DeviceProperties.minor = 2
145
+
146
+ # Fix functions with ipex:
147
+ torch.cuda.mem_get_info = lambda device=None: [(torch.xpu.get_device_properties(device).total_memory - torch.xpu.memory_reserved(device)), torch.xpu.get_device_properties(device).total_memory]
148
+ torch._utils._get_available_device_type = lambda: "xpu"
149
+ torch.has_cuda = True
150
+ torch.cuda.has_half = True
151
+ torch.cuda.is_bf16_supported = lambda *args, **kwargs: True
152
+ torch.cuda.is_fp16_supported = lambda *args, **kwargs: True
153
+ torch.version.cuda = "11.7"
154
+ torch.cuda.get_device_capability = lambda *args, **kwargs: [11,7]
155
+ torch.cuda.get_device_properties.major = 11
156
+ torch.cuda.get_device_properties.minor = 7
157
+ torch.cuda.ipc_collect = lambda *args, **kwargs: None
158
+ torch.cuda.utilization = lambda *args, **kwargs: 0
159
+
160
+ ipex_hijacks()
161
+ if not torch.xpu.has_fp64_dtype():
162
+ try:
163
+ from .diffusers import ipex_diffusers
164
+ ipex_diffusers()
165
+ except Exception: # pylint: disable=broad-exception-caught
166
+ pass
167
+ except Exception as e:
168
+ return False, e
169
+ return True, None
external/llite/library/ipex/attention.py ADDED
@@ -0,0 +1,151 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import
3
+
4
+ # pylint: disable=protected-access, missing-function-docstring, line-too-long
5
+
6
+ original_torch_bmm = torch.bmm
7
+ def torch_bmm_32_bit(input, mat2, *, out=None):
8
+ # ARC GPUs can't allocate more than 4GB to a single block, Slice it:
9
+ batch_size_attention, input_tokens, mat2_shape = input.shape[0], input.shape[1], mat2.shape[2]
10
+ block_multiply = input.element_size()
11
+ slice_block_size = input_tokens * mat2_shape / 1024 / 1024 * block_multiply
12
+ block_size = batch_size_attention * slice_block_size
13
+
14
+ split_slice_size = batch_size_attention
15
+ if block_size > 4:
16
+ do_split = True
17
+ # Find something divisible with the input_tokens
18
+ while (split_slice_size * slice_block_size) > 4:
19
+ split_slice_size = split_slice_size // 2
20
+ if split_slice_size <= 1:
21
+ split_slice_size = 1
22
+ break
23
+ split_2_slice_size = input_tokens
24
+ if split_slice_size * slice_block_size > 4:
25
+ slice_block_size_2 = split_slice_size * mat2_shape / 1024 / 1024 * block_multiply
26
+ do_split_2 = True
27
+ # Find something divisible with the input_tokens
28
+ while (split_2_slice_size * slice_block_size_2) > 4:
29
+ split_2_slice_size = split_2_slice_size // 2
30
+ if split_2_slice_size <= 1:
31
+ split_2_slice_size = 1
32
+ break
33
+ else:
34
+ do_split_2 = False
35
+ else:
36
+ do_split = False
37
+
38
+ if do_split:
39
+ hidden_states = torch.zeros(input.shape[0], input.shape[1], mat2.shape[2], device=input.device, dtype=input.dtype)
40
+ for i in range(batch_size_attention // split_slice_size):
41
+ start_idx = i * split_slice_size
42
+ end_idx = (i + 1) * split_slice_size
43
+ if do_split_2:
44
+ for i2 in range(input_tokens // split_2_slice_size): # pylint: disable=invalid-name
45
+ start_idx_2 = i2 * split_2_slice_size
46
+ end_idx_2 = (i2 + 1) * split_2_slice_size
47
+ hidden_states[start_idx:end_idx, start_idx_2:end_idx_2] = original_torch_bmm(
48
+ input[start_idx:end_idx, start_idx_2:end_idx_2],
49
+ mat2[start_idx:end_idx, start_idx_2:end_idx_2],
50
+ out=out
51
+ )
52
+ else:
53
+ hidden_states[start_idx:end_idx] = original_torch_bmm(
54
+ input[start_idx:end_idx],
55
+ mat2[start_idx:end_idx],
56
+ out=out
57
+ )
58
+ else:
59
+ return original_torch_bmm(input, mat2, out=out)
60
+ return hidden_states
61
+
62
+ original_scaled_dot_product_attention = torch.nn.functional.scaled_dot_product_attention
63
+ def scaled_dot_product_attention_32_bit(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False):
64
+ # ARC GPUs can't allocate more than 4GB to a single block, Slice it:
65
+ if len(query.shape) == 3:
66
+ batch_size_attention, query_tokens, shape_three = query.shape
67
+ shape_four = 1
68
+ else:
69
+ batch_size_attention, query_tokens, shape_three, shape_four = query.shape
70
+
71
+ block_multiply = query.element_size()
72
+ slice_block_size = query_tokens * shape_three * shape_four / 1024 / 1024 * block_multiply
73
+ block_size = batch_size_attention * slice_block_size
74
+
75
+ split_slice_size = batch_size_attention
76
+ if block_size > 4:
77
+ do_split = True
78
+ # Find something divisible with the batch_size_attention
79
+ while (split_slice_size * slice_block_size) > 4:
80
+ split_slice_size = split_slice_size // 2
81
+ if split_slice_size <= 1:
82
+ split_slice_size = 1
83
+ break
84
+ split_2_slice_size = query_tokens
85
+ if split_slice_size * slice_block_size > 4:
86
+ slice_block_size_2 = split_slice_size * shape_three * shape_four / 1024 / 1024 * block_multiply
87
+ do_split_2 = True
88
+ # Find something divisible with the query_tokens
89
+ while (split_2_slice_size * slice_block_size_2) > 4:
90
+ split_2_slice_size = split_2_slice_size // 2
91
+ if split_2_slice_size <= 1:
92
+ split_2_slice_size = 1
93
+ break
94
+ split_3_slice_size = shape_three
95
+ if split_2_slice_size * slice_block_size_2 > 4:
96
+ slice_block_size_3 = split_slice_size * split_2_slice_size * shape_four / 1024 / 1024 * block_multiply
97
+ do_split_3 = True
98
+ # Find something divisible with the shape_three
99
+ while (split_3_slice_size * slice_block_size_3) > 4:
100
+ split_3_slice_size = split_3_slice_size // 2
101
+ if split_3_slice_size <= 1:
102
+ split_3_slice_size = 1
103
+ break
104
+ else:
105
+ do_split_3 = False
106
+ else:
107
+ do_split_2 = False
108
+ else:
109
+ do_split = False
110
+
111
+ if do_split:
112
+ hidden_states = torch.zeros(query.shape, device=query.device, dtype=query.dtype)
113
+ for i in range(batch_size_attention // split_slice_size):
114
+ start_idx = i * split_slice_size
115
+ end_idx = (i + 1) * split_slice_size
116
+ if do_split_2:
117
+ for i2 in range(query_tokens // split_2_slice_size): # pylint: disable=invalid-name
118
+ start_idx_2 = i2 * split_2_slice_size
119
+ end_idx_2 = (i2 + 1) * split_2_slice_size
120
+ if do_split_3:
121
+ for i3 in range(shape_three // split_3_slice_size): # pylint: disable=invalid-name
122
+ start_idx_3 = i3 * split_3_slice_size
123
+ end_idx_3 = (i3 + 1) * split_3_slice_size
124
+ hidden_states[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3] = original_scaled_dot_product_attention(
125
+ query[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3],
126
+ key[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3],
127
+ value[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3],
128
+ attn_mask=attn_mask[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3] if attn_mask is not None else attn_mask,
129
+ dropout_p=dropout_p, is_causal=is_causal
130
+ )
131
+ else:
132
+ hidden_states[start_idx:end_idx, start_idx_2:end_idx_2] = original_scaled_dot_product_attention(
133
+ query[start_idx:end_idx, start_idx_2:end_idx_2],
134
+ key[start_idx:end_idx, start_idx_2:end_idx_2],
135
+ value[start_idx:end_idx, start_idx_2:end_idx_2],
136
+ attn_mask=attn_mask[start_idx:end_idx, start_idx_2:end_idx_2] if attn_mask is not None else attn_mask,
137
+ dropout_p=dropout_p, is_causal=is_causal
138
+ )
139
+ else:
140
+ hidden_states[start_idx:end_idx] = original_scaled_dot_product_attention(
141
+ query[start_idx:end_idx],
142
+ key[start_idx:end_idx],
143
+ value[start_idx:end_idx],
144
+ attn_mask=attn_mask[start_idx:end_idx] if attn_mask is not None else attn_mask,
145
+ dropout_p=dropout_p, is_causal=is_causal
146
+ )
147
+ else:
148
+ return original_scaled_dot_product_attention(
149
+ query, key, value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal
150
+ )
151
+ return hidden_states
external/llite/library/ipex/diffusers.py ADDED
@@ -0,0 +1,120 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import
3
+ import diffusers #0.24.0 # pylint: disable=import-error
4
+ from diffusers.models.attention_processor import Attention
5
+
6
+ # pylint: disable=protected-access, missing-function-docstring, line-too-long
7
+
8
+ class SlicedAttnProcessor: # pylint: disable=too-few-public-methods
9
+ r"""
10
+ Processor for implementing sliced attention.
11
+
12
+ Args:
13
+ slice_size (`int`, *optional*):
14
+ The number of steps to compute attention. Uses as many slices as `attention_head_dim // slice_size`, and
15
+ `attention_head_dim` must be a multiple of the `slice_size`.
16
+ """
17
+
18
+ def __init__(self, slice_size):
19
+ self.slice_size = slice_size
20
+
21
+ def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None): # pylint: disable=too-many-statements, too-many-locals, too-many-branches
22
+ residual = hidden_states
23
+
24
+ input_ndim = hidden_states.ndim
25
+
26
+ if input_ndim == 4:
27
+ batch_size, channel, height, width = hidden_states.shape
28
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
29
+
30
+ batch_size, sequence_length, _ = (
31
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
32
+ )
33
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
34
+
35
+ if attn.group_norm is not None:
36
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
37
+
38
+ query = attn.to_q(hidden_states)
39
+ dim = query.shape[-1]
40
+ query = attn.head_to_batch_dim(query)
41
+
42
+ if encoder_hidden_states is None:
43
+ encoder_hidden_states = hidden_states
44
+ elif attn.norm_cross:
45
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
46
+
47
+ key = attn.to_k(encoder_hidden_states)
48
+ value = attn.to_v(encoder_hidden_states)
49
+ key = attn.head_to_batch_dim(key)
50
+ value = attn.head_to_batch_dim(value)
51
+
52
+ batch_size_attention, query_tokens, shape_three = query.shape
53
+ hidden_states = torch.zeros(
54
+ (batch_size_attention, query_tokens, dim // attn.heads), device=query.device, dtype=query.dtype
55
+ )
56
+
57
+ #ARC GPUs can't allocate more than 4GB to a single block, Slice it:
58
+ block_multiply = query.element_size()
59
+ slice_block_size = self.slice_size * shape_three / 1024 / 1024 * block_multiply
60
+ block_size = query_tokens * slice_block_size
61
+ split_2_slice_size = query_tokens
62
+ if block_size > 4:
63
+ do_split_2 = True
64
+ #Find something divisible with the query_tokens
65
+ while (split_2_slice_size * slice_block_size) > 4:
66
+ split_2_slice_size = split_2_slice_size // 2
67
+ if split_2_slice_size <= 1:
68
+ split_2_slice_size = 1
69
+ break
70
+ else:
71
+ do_split_2 = False
72
+
73
+ for i in range(batch_size_attention // self.slice_size):
74
+ start_idx = i * self.slice_size
75
+ end_idx = (i + 1) * self.slice_size
76
+
77
+ if do_split_2:
78
+ for i2 in range(query_tokens // split_2_slice_size): # pylint: disable=invalid-name
79
+ start_idx_2 = i2 * split_2_slice_size
80
+ end_idx_2 = (i2 + 1) * split_2_slice_size
81
+
82
+ query_slice = query[start_idx:end_idx, start_idx_2:end_idx_2]
83
+ key_slice = key[start_idx:end_idx, start_idx_2:end_idx_2]
84
+ attn_mask_slice = attention_mask[start_idx:end_idx, start_idx_2:end_idx_2] if attention_mask is not None else None
85
+
86
+ attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice)
87
+ attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx, start_idx_2:end_idx_2])
88
+
89
+ hidden_states[start_idx:end_idx, start_idx_2:end_idx_2] = attn_slice
90
+ else:
91
+ query_slice = query[start_idx:end_idx]
92
+ key_slice = key[start_idx:end_idx]
93
+ attn_mask_slice = attention_mask[start_idx:end_idx] if attention_mask is not None else None
94
+
95
+ attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice)
96
+
97
+ attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx])
98
+
99
+ hidden_states[start_idx:end_idx] = attn_slice
100
+
101
+ hidden_states = attn.batch_to_head_dim(hidden_states)
102
+
103
+ # linear proj
104
+ hidden_states = attn.to_out[0](hidden_states)
105
+ # dropout
106
+ hidden_states = attn.to_out[1](hidden_states)
107
+
108
+ if input_ndim == 4:
109
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
110
+
111
+ if attn.residual_connection:
112
+ hidden_states = hidden_states + residual
113
+
114
+ hidden_states = hidden_states / attn.rescale_output_factor
115
+
116
+ return hidden_states
117
+
118
+ def ipex_diffusers():
119
+ #ARC GPUs can't allocate more than 4GB to a single block:
120
+ diffusers.models.attention_processor.SlicedAttnProcessor = SlicedAttnProcessor
external/llite/library/ipex/gradscaler.py ADDED
@@ -0,0 +1,183 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from collections import defaultdict
2
+ import torch
3
+ import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import
4
+ import intel_extension_for_pytorch._C as core # pylint: disable=import-error, unused-import
5
+
6
+ # pylint: disable=protected-access, missing-function-docstring, line-too-long
7
+
8
+ device_supports_fp64 = torch.xpu.has_fp64_dtype()
9
+ OptState = ipex.cpu.autocast._grad_scaler.OptState
10
+ _MultiDeviceReplicator = ipex.cpu.autocast._grad_scaler._MultiDeviceReplicator
11
+ _refresh_per_optimizer_state = ipex.cpu.autocast._grad_scaler._refresh_per_optimizer_state
12
+
13
+ def _unscale_grads_(self, optimizer, inv_scale, found_inf, allow_fp16): # pylint: disable=unused-argument
14
+ per_device_inv_scale = _MultiDeviceReplicator(inv_scale)
15
+ per_device_found_inf = _MultiDeviceReplicator(found_inf)
16
+
17
+ # To set up _amp_foreach_non_finite_check_and_unscale_, split grads by device and dtype.
18
+ # There could be hundreds of grads, so we'd like to iterate through them just once.
19
+ # However, we don't know their devices or dtypes in advance.
20
+
21
+ # https://stackoverflow.com/questions/5029934/defaultdict-of-defaultdict
22
+ # Google says mypy struggles with defaultdicts type annotations.
23
+ per_device_and_dtype_grads = defaultdict(lambda: defaultdict(list)) # type: ignore[var-annotated]
24
+ # sync grad to master weight
25
+ if hasattr(optimizer, "sync_grad"):
26
+ optimizer.sync_grad()
27
+ with torch.no_grad():
28
+ for group in optimizer.param_groups:
29
+ for param in group["params"]:
30
+ if param.grad is None:
31
+ continue
32
+ if (not allow_fp16) and param.grad.dtype == torch.float16:
33
+ raise ValueError("Attempting to unscale FP16 gradients.")
34
+ if param.grad.is_sparse:
35
+ # is_coalesced() == False means the sparse grad has values with duplicate indices.
36
+ # coalesce() deduplicates indices and adds all values that have the same index.
37
+ # For scaled fp16 values, there's a good chance coalescing will cause overflow,
38
+ # so we should check the coalesced _values().
39
+ if param.grad.dtype is torch.float16:
40
+ param.grad = param.grad.coalesce()
41
+ to_unscale = param.grad._values()
42
+ else:
43
+ to_unscale = param.grad
44
+
45
+ # -: is there a way to split by device and dtype without appending in the inner loop?
46
+ to_unscale = to_unscale.to("cpu")
47
+ per_device_and_dtype_grads[to_unscale.device][
48
+ to_unscale.dtype
49
+ ].append(to_unscale)
50
+
51
+ for _, per_dtype_grads in per_device_and_dtype_grads.items():
52
+ for grads in per_dtype_grads.values():
53
+ core._amp_foreach_non_finite_check_and_unscale_(
54
+ grads,
55
+ per_device_found_inf.get("cpu"),
56
+ per_device_inv_scale.get("cpu"),
57
+ )
58
+
59
+ return per_device_found_inf._per_device_tensors
60
+
61
+ def unscale_(self, optimizer):
62
+ """
63
+ Divides ("unscales") the optimizer's gradient tensors by the scale factor.
64
+ :meth:`unscale_` is optional, serving cases where you need to
65
+ :ref:`modify or inspect gradients<working-with-unscaled-gradients>`
66
+ between the backward pass(es) and :meth:`step`.
67
+ If :meth:`unscale_` is not called explicitly, gradients will be unscaled automatically during :meth:`step`.
68
+ Simple example, using :meth:`unscale_` to enable clipping of unscaled gradients::
69
+ ...
70
+ scaler.scale(loss).backward()
71
+ scaler.unscale_(optimizer)
72
+ torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
73
+ scaler.step(optimizer)
74
+ scaler.update()
75
+ Args:
76
+ optimizer (torch.optim.Optimizer): Optimizer that owns the gradients to be unscaled.
77
+ .. warning::
78
+ :meth:`unscale_` should only be called once per optimizer per :meth:`step` call,
79
+ and only after all gradients for that optimizer's assigned parameters have been accumulated.
80
+ Calling :meth:`unscale_` twice for a given optimizer between each :meth:`step` triggers a RuntimeError.
81
+ .. warning::
82
+ :meth:`unscale_` may unscale sparse gradients out of place, replacing the ``.grad`` attribute.
83
+ """
84
+ if not self._enabled:
85
+ return
86
+
87
+ self._check_scale_growth_tracker("unscale_")
88
+
89
+ optimizer_state = self._per_optimizer_states[id(optimizer)]
90
+
91
+ if optimizer_state["stage"] is OptState.UNSCALED: # pylint: disable=no-else-raise
92
+ raise RuntimeError(
93
+ "unscale_() has already been called on this optimizer since the last update()."
94
+ )
95
+ elif optimizer_state["stage"] is OptState.STEPPED:
96
+ raise RuntimeError("unscale_() is being called after step().")
97
+
98
+ # FP32 division can be imprecise for certain compile options, so we carry out the reciprocal in FP64.
99
+ assert self._scale is not None
100
+ if device_supports_fp64:
101
+ inv_scale = self._scale.double().reciprocal().float()
102
+ else:
103
+ inv_scale = self._scale.to("cpu").double().reciprocal().float().to(self._scale.device)
104
+ found_inf = torch.full(
105
+ (1,), 0.0, dtype=torch.float32, device=self._scale.device
106
+ )
107
+
108
+ optimizer_state["found_inf_per_device"] = self._unscale_grads_(
109
+ optimizer, inv_scale, found_inf, False
110
+ )
111
+ optimizer_state["stage"] = OptState.UNSCALED
112
+
113
+ def update(self, new_scale=None):
114
+ """
115
+ Updates the scale factor.
116
+ If any optimizer steps were skipped the scale is multiplied by ``backoff_factor``
117
+ to reduce it. If ``growth_interval`` unskipped iterations occurred consecutively,
118
+ the scale is multiplied by ``growth_factor`` to increase it.
119
+ Passing ``new_scale`` sets the new scale value manually. (``new_scale`` is not
120
+ used directly, it's used to fill GradScaler's internal scale tensor. So if
121
+ ``new_scale`` was a tensor, later in-place changes to that tensor will not further
122
+ affect the scale GradScaler uses internally.)
123
+ Args:
124
+ new_scale (float or :class:`torch.FloatTensor`, optional, default=None): New scale factor.
125
+ .. warning::
126
+ :meth:`update` should only be called at the end of the iteration, after ``scaler.step(optimizer)`` has
127
+ been invoked for all optimizers used this iteration.
128
+ """
129
+ if not self._enabled:
130
+ return
131
+
132
+ _scale, _growth_tracker = self._check_scale_growth_tracker("update")
133
+
134
+ if new_scale is not None:
135
+ # Accept a new user-defined scale.
136
+ if isinstance(new_scale, float):
137
+ self._scale.fill_(new_scale) # type: ignore[union-attr]
138
+ else:
139
+ reason = "new_scale should be a float or a 1-element torch.FloatTensor with requires_grad=False."
140
+ assert isinstance(new_scale, torch.FloatTensor), reason # type: ignore[attr-defined]
141
+ assert new_scale.numel() == 1, reason
142
+ assert new_scale.requires_grad is False, reason
143
+ self._scale.copy_(new_scale) # type: ignore[union-attr]
144
+ else:
145
+ # Consume shared inf/nan data collected from optimizers to update the scale.
146
+ # If all found_inf tensors are on the same device as self._scale, this operation is asynchronous.
147
+ found_infs = [
148
+ found_inf.to(device="cpu", non_blocking=True)
149
+ for state in self._per_optimizer_states.values()
150
+ for found_inf in state["found_inf_per_device"].values()
151
+ ]
152
+
153
+ assert len(found_infs) > 0, "No inf checks were recorded prior to update."
154
+
155
+ found_inf_combined = found_infs[0]
156
+ if len(found_infs) > 1:
157
+ for i in range(1, len(found_infs)):
158
+ found_inf_combined += found_infs[i]
159
+
160
+ to_device = _scale.device
161
+ _scale = _scale.to("cpu")
162
+ _growth_tracker = _growth_tracker.to("cpu")
163
+
164
+ core._amp_update_scale_(
165
+ _scale,
166
+ _growth_tracker,
167
+ found_inf_combined,
168
+ self._growth_factor,
169
+ self._backoff_factor,
170
+ self._growth_interval,
171
+ )
172
+
173
+ _scale = _scale.to(to_device)
174
+ _growth_tracker = _growth_tracker.to(to_device)
175
+ # To prepare for next iteration, clear the data collected from optimizers this iteration.
176
+ self._per_optimizer_states = defaultdict(_refresh_per_optimizer_state)
177
+
178
+ def gradscaler_init():
179
+ torch.xpu.amp.GradScaler = ipex.cpu.autocast._grad_scaler.GradScaler
180
+ torch.xpu.amp.GradScaler._unscale_grads_ = _unscale_grads_
181
+ torch.xpu.amp.GradScaler.unscale_ = unscale_
182
+ torch.xpu.amp.GradScaler.update = update
183
+ return torch.xpu.amp.GradScaler
external/llite/library/ipex/hijacks.py ADDED
@@ -0,0 +1,252 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import contextlib
2
+ import importlib
3
+ import torch
4
+ import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import
5
+
6
+ # pylint: disable=protected-access, missing-function-docstring, line-too-long, unnecessary-lambda, no-else-return
7
+
8
+ class CondFunc: # pylint: disable=missing-class-docstring
9
+ def __new__(cls, orig_func, sub_func, cond_func):
10
+ self = super(CondFunc, cls).__new__(cls)
11
+ if isinstance(orig_func, str):
12
+ func_path = orig_func.split('.')
13
+ for i in range(len(func_path)-1, -1, -1):
14
+ try:
15
+ resolved_obj = importlib.import_module('.'.join(func_path[:i]))
16
+ break
17
+ except ImportError:
18
+ pass
19
+ for attr_name in func_path[i:-1]:
20
+ resolved_obj = getattr(resolved_obj, attr_name)
21
+ orig_func = getattr(resolved_obj, func_path[-1])
22
+ setattr(resolved_obj, func_path[-1], lambda *args, **kwargs: self(*args, **kwargs))
23
+ self.__init__(orig_func, sub_func, cond_func)
24
+ return lambda *args, **kwargs: self(*args, **kwargs)
25
+ def __init__(self, orig_func, sub_func, cond_func):
26
+ self.__orig_func = orig_func
27
+ self.__sub_func = sub_func
28
+ self.__cond_func = cond_func
29
+ def __call__(self, *args, **kwargs):
30
+ if not self.__cond_func or self.__cond_func(self.__orig_func, *args, **kwargs):
31
+ return self.__sub_func(self.__orig_func, *args, **kwargs)
32
+ else:
33
+ return self.__orig_func(*args, **kwargs)
34
+
35
+ _utils = torch.utils.data._utils
36
+ def _shutdown_workers(self):
37
+ if torch.utils.data._utils is None or torch.utils.data._utils.python_exit_status is True or torch.utils.data._utils.python_exit_status is None:
38
+ return
39
+ if hasattr(self, "_shutdown") and not self._shutdown:
40
+ self._shutdown = True
41
+ try:
42
+ if hasattr(self, '_pin_memory_thread'):
43
+ self._pin_memory_thread_done_event.set()
44
+ self._worker_result_queue.put((None, None))
45
+ self._pin_memory_thread.join()
46
+ self._worker_result_queue.cancel_join_thread()
47
+ self._worker_result_queue.close()
48
+ self._workers_done_event.set()
49
+ for worker_id in range(len(self._workers)):
50
+ if self._persistent_workers or self._workers_status[worker_id]:
51
+ self._mark_worker_as_unavailable(worker_id, shutdown=True)
52
+ for w in self._workers: # pylint: disable=invalid-name
53
+ w.join(timeout=torch.utils.data._utils.MP_STATUS_CHECK_INTERVAL)
54
+ for q in self._index_queues: # pylint: disable=invalid-name
55
+ q.cancel_join_thread()
56
+ q.close()
57
+ finally:
58
+ if self._worker_pids_set:
59
+ torch.utils.data._utils.signal_handling._remove_worker_pids(id(self))
60
+ self._worker_pids_set = False
61
+ for w in self._workers: # pylint: disable=invalid-name
62
+ if w.is_alive():
63
+ w.terminate()
64
+
65
+ class DummyDataParallel(torch.nn.Module): # pylint: disable=missing-class-docstring, unused-argument, too-few-public-methods
66
+ def __new__(cls, module, device_ids=None, output_device=None, dim=0): # pylint: disable=unused-argument
67
+ if isinstance(device_ids, list) and len(device_ids) > 1:
68
+ print("IPEX backend doesn't support DataParallel on multiple XPU devices")
69
+ return module.to("xpu")
70
+
71
+ def return_null_context(*args, **kwargs): # pylint: disable=unused-argument
72
+ return contextlib.nullcontext()
73
+
74
+ def check_device(device):
75
+ return bool((isinstance(device, torch.device) and device.type == "cuda") or (isinstance(device, str) and "cuda" in device) or isinstance(device, int))
76
+
77
+ def return_xpu(device):
78
+ return f"xpu:{device.split(':')[-1]}" if isinstance(device, str) and ":" in device else f"xpu:{device}" if isinstance(device, int) else torch.device("xpu") if isinstance(device, torch.device) else "xpu"
79
+
80
+ def ipex_no_cuda(orig_func, *args, **kwargs):
81
+ torch.cuda.is_available = lambda: False
82
+ orig_func(*args, **kwargs)
83
+ torch.cuda.is_available = torch.xpu.is_available
84
+
85
+ original_autocast = torch.autocast
86
+ def ipex_autocast(*args, **kwargs):
87
+ if len(args) > 0 and args[0] == "cuda":
88
+ return original_autocast("xpu", *args[1:], **kwargs)
89
+ else:
90
+ return original_autocast(*args, **kwargs)
91
+
92
+ # Embedding BF16
93
+ original_torch_cat = torch.cat
94
+ def torch_cat(tensor, *args, **kwargs):
95
+ if len(tensor) == 3 and (tensor[0].dtype != tensor[1].dtype or tensor[2].dtype != tensor[1].dtype):
96
+ return original_torch_cat([tensor[0].to(tensor[1].dtype), tensor[1], tensor[2].to(tensor[1].dtype)], *args, **kwargs)
97
+ else:
98
+ return original_torch_cat(tensor, *args, **kwargs)
99
+
100
+ # Latent antialias:
101
+ original_interpolate = torch.nn.functional.interpolate
102
+ def interpolate(tensor, size=None, scale_factor=None, mode='nearest', align_corners=None, recompute_scale_factor=None, antialias=False): # pylint: disable=too-many-arguments
103
+ if antialias or align_corners is not None:
104
+ return_device = tensor.device
105
+ return_dtype = tensor.dtype
106
+ return original_interpolate(tensor.to("cpu", dtype=torch.float32), size=size, scale_factor=scale_factor, mode=mode,
107
+ align_corners=align_corners, recompute_scale_factor=recompute_scale_factor, antialias=antialias).to(return_device, dtype=return_dtype)
108
+ else:
109
+ return original_interpolate(tensor, size=size, scale_factor=scale_factor, mode=mode,
110
+ align_corners=align_corners, recompute_scale_factor=recompute_scale_factor, antialias=antialias)
111
+
112
+ original_linalg_solve = torch.linalg.solve
113
+ def linalg_solve(A, B, *args, **kwargs): # pylint: disable=invalid-name
114
+ if A.device != torch.device("cpu") or B.device != torch.device("cpu"):
115
+ return_device = A.device
116
+ return original_linalg_solve(A.to("cpu"), B.to("cpu"), *args, **kwargs).to(return_device)
117
+ else:
118
+ return original_linalg_solve(A, B, *args, **kwargs)
119
+
120
+ if torch.xpu.has_fp64_dtype():
121
+ original_torch_bmm = torch.bmm
122
+ original_scaled_dot_product_attention = torch.nn.functional.scaled_dot_product_attention
123
+ else:
124
+ # 64 bit attention workarounds for Alchemist:
125
+ try:
126
+ from .attention import torch_bmm_32_bit as original_torch_bmm
127
+ from .attention import scaled_dot_product_attention_32_bit as original_scaled_dot_product_attention
128
+ except Exception: # pylint: disable=broad-exception-caught
129
+ original_torch_bmm = torch.bmm
130
+ original_scaled_dot_product_attention = torch.nn.functional.scaled_dot_product_attention
131
+
132
+ # dtype errors:
133
+ def torch_bmm(input, mat2, *, out=None):
134
+ if input.dtype != mat2.dtype:
135
+ mat2 = mat2.to(input.dtype)
136
+ return original_torch_bmm(input, mat2, out=out)
137
+
138
+ def scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False):
139
+ if query.dtype != key.dtype:
140
+ key = key.to(dtype=query.dtype)
141
+ if query.dtype != value.dtype:
142
+ value = value.to(dtype=query.dtype)
143
+ return original_scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal)
144
+
145
+ @property
146
+ def is_cuda(self):
147
+ return self.device.type == 'xpu'
148
+
149
+ def ipex_hijacks():
150
+ CondFunc('torch.tensor',
151
+ lambda orig_func, *args, device=None, **kwargs: orig_func(*args, device=return_xpu(device), **kwargs),
152
+ lambda orig_func, *args, device=None, **kwargs: check_device(device))
153
+ CondFunc('torch.Tensor.to',
154
+ lambda orig_func, self, device=None, *args, **kwargs: orig_func(self, return_xpu(device), *args, **kwargs),
155
+ lambda orig_func, self, device=None, *args, **kwargs: check_device(device))
156
+ CondFunc('torch.Tensor.cuda',
157
+ lambda orig_func, self, device=None, *args, **kwargs: orig_func(self, return_xpu(device), *args, **kwargs),
158
+ lambda orig_func, self, device=None, *args, **kwargs: check_device(device))
159
+ CondFunc('torch.UntypedStorage.__init__',
160
+ lambda orig_func, *args, device=None, **kwargs: orig_func(*args, device=return_xpu(device), **kwargs),
161
+ lambda orig_func, *args, device=None, **kwargs: check_device(device))
162
+ CondFunc('torch.UntypedStorage.cuda',
163
+ lambda orig_func, self, device=None, *args, **kwargs: orig_func(self, return_xpu(device), *args, **kwargs),
164
+ lambda orig_func, self, device=None, *args, **kwargs: check_device(device))
165
+ CondFunc('torch.empty',
166
+ lambda orig_func, *args, device=None, **kwargs: orig_func(*args, device=return_xpu(device), **kwargs),
167
+ lambda orig_func, *args, device=None, **kwargs: check_device(device))
168
+ CondFunc('torch.randn',
169
+ lambda orig_func, *args, device=None, **kwargs: orig_func(*args, device=return_xpu(device), **kwargs),
170
+ lambda orig_func, *args, device=None, **kwargs: check_device(device))
171
+ CondFunc('torch.ones',
172
+ lambda orig_func, *args, device=None, **kwargs: orig_func(*args, device=return_xpu(device), **kwargs),
173
+ lambda orig_func, *args, device=None, **kwargs: check_device(device))
174
+ CondFunc('torch.zeros',
175
+ lambda orig_func, *args, device=None, **kwargs: orig_func(*args, device=return_xpu(device), **kwargs),
176
+ lambda orig_func, *args, device=None, **kwargs: check_device(device))
177
+ CondFunc('torch.linspace',
178
+ lambda orig_func, *args, device=None, **kwargs: orig_func(*args, device=return_xpu(device), **kwargs),
179
+ lambda orig_func, *args, device=None, **kwargs: check_device(device))
180
+ CondFunc('torch.load',
181
+ lambda orig_func, f, map_location=None, pickle_module=None, *, weights_only=False, mmap=None, **kwargs:
182
+ orig_func(orig_func, f, map_location=return_xpu(map_location), pickle_module=pickle_module, weights_only=weights_only, mmap=mmap, **kwargs),
183
+ lambda orig_func, f, map_location=None, pickle_module=None, *, weights_only=False, mmap=None, **kwargs: check_device(map_location))
184
+ if hasattr(torch.xpu, "Generator"):
185
+ CondFunc('torch.Generator',
186
+ lambda orig_func, device=None: torch.xpu.Generator(return_xpu(device)),
187
+ lambda orig_func, device=None: device is not None and device != torch.device("cpu") and device != "cpu")
188
+ else:
189
+ CondFunc('torch.Generator',
190
+ lambda orig_func, device=None: orig_func(return_xpu(device)),
191
+ lambda orig_func, device=None: check_device(device))
192
+
193
+ # TiledVAE and ControlNet:
194
+ CondFunc('torch.batch_norm',
195
+ lambda orig_func, input, weight, bias, *args, **kwargs: orig_func(input,
196
+ weight if weight is not None else torch.ones(input.size()[1], device=input.device),
197
+ bias if bias is not None else torch.zeros(input.size()[1], device=input.device), *args, **kwargs),
198
+ lambda orig_func, input, *args, **kwargs: input.device != torch.device("cpu"))
199
+ CondFunc('torch.instance_norm',
200
+ lambda orig_func, input, weight, bias, *args, **kwargs: orig_func(input,
201
+ weight if weight is not None else torch.ones(input.size()[1], device=input.device),
202
+ bias if bias is not None else torch.zeros(input.size()[1], device=input.device), *args, **kwargs),
203
+ lambda orig_func, input, *args, **kwargs: input.device != torch.device("cpu"))
204
+
205
+ # Functions with dtype errors:
206
+ CondFunc('torch.nn.modules.GroupNorm.forward',
207
+ lambda orig_func, self, input: orig_func(self, input.to(self.weight.data.dtype)),
208
+ lambda orig_func, self, input: input.dtype != self.weight.data.dtype)
209
+ # Training:
210
+ CondFunc('torch.nn.modules.linear.Linear.forward',
211
+ lambda orig_func, self, input: orig_func(self, input.to(self.weight.data.dtype)),
212
+ lambda orig_func, self, input: input.dtype != self.weight.data.dtype)
213
+ CondFunc('torch.nn.modules.conv.Conv2d.forward',
214
+ lambda orig_func, self, input: orig_func(self, input.to(self.weight.data.dtype)),
215
+ lambda orig_func, self, input: input.dtype != self.weight.data.dtype)
216
+ # BF16:
217
+ CondFunc('torch.nn.functional.layer_norm',
218
+ lambda orig_func, input, normalized_shape=None, weight=None, *args, **kwargs:
219
+ orig_func(input.to(weight.data.dtype), normalized_shape, weight, *args, **kwargs),
220
+ lambda orig_func, input, normalized_shape=None, weight=None, *args, **kwargs:
221
+ weight is not None and input.dtype != weight.data.dtype)
222
+ # SwinIR BF16:
223
+ CondFunc('torch.nn.functional.pad',
224
+ lambda orig_func, input, pad, mode='constant', value=None: orig_func(input.to(torch.float32), pad, mode=mode, value=value).to(dtype=torch.bfloat16),
225
+ lambda orig_func, input, pad, mode='constant', value=None: mode == 'reflect' and input.dtype == torch.bfloat16)
226
+
227
+ # Diffusers Float64 (Alchemist GPUs doesn't support 64 bit):
228
+ if not torch.xpu.has_fp64_dtype():
229
+ CondFunc('torch.from_numpy',
230
+ lambda orig_func, ndarray: orig_func(ndarray.astype('float32')),
231
+ lambda orig_func, ndarray: ndarray.dtype == float)
232
+
233
+ # Broken functions when torch.cuda.is_available is True:
234
+ # Pin Memory:
235
+ CondFunc('torch.utils.data.dataloader._BaseDataLoaderIter.__init__',
236
+ lambda orig_func, *args, **kwargs: ipex_no_cuda(orig_func, *args, **kwargs),
237
+ lambda orig_func, *args, **kwargs: True)
238
+
239
+ # Functions that make compile mad with CondFunc:
240
+ torch.nn.DataParallel = DummyDataParallel
241
+ torch.utils.data.dataloader._MultiProcessingDataLoaderIter._shutdown_workers = _shutdown_workers
242
+
243
+ torch.autocast = ipex_autocast
244
+ torch.backends.cuda.sdp_kernel = return_null_context
245
+ torch.UntypedStorage.is_cuda = is_cuda
246
+
247
+ torch.nn.functional.interpolate = interpolate
248
+ torch.linalg.solve = linalg_solve
249
+
250
+ torch.bmm = torch_bmm
251
+ torch.cat = torch_cat
252
+ torch.nn.functional.scaled_dot_product_attention = scaled_dot_product_attention
external/llite/library/lpw_stable_diffusion.py ADDED
@@ -0,0 +1,1254 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # copy from https://github.com/huggingface/diffusers/blob/main/examples/community/lpw_stable_diffusion.py
2
+ # and modify to support SD2.x
3
+
4
+ import inspect
5
+ import re
6
+ from typing import Callable, List, Optional, Union
7
+
8
+ import numpy as np
9
+ import PIL.Image
10
+ import torch
11
+ from packaging import version
12
+ from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
13
+
14
+ import diffusers
15
+ from diffusers import SchedulerMixin, StableDiffusionPipeline
16
+ from diffusers.models import AutoencoderKL, UNet2DConditionModel
17
+ from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput, StableDiffusionSafetyChecker
18
+ from diffusers.utils import logging
19
+
20
+
21
+ try:
22
+ from diffusers.utils import PIL_INTERPOLATION
23
+ except ImportError:
24
+ if version.parse(version.parse(PIL.__version__).base_version) >= version.parse("9.1.0"):
25
+ PIL_INTERPOLATION = {
26
+ "linear": PIL.Image.Resampling.BILINEAR,
27
+ "bilinear": PIL.Image.Resampling.BILINEAR,
28
+ "bicubic": PIL.Image.Resampling.BICUBIC,
29
+ "lanczos": PIL.Image.Resampling.LANCZOS,
30
+ "nearest": PIL.Image.Resampling.NEAREST,
31
+ }
32
+ else:
33
+ PIL_INTERPOLATION = {
34
+ "linear": PIL.Image.LINEAR,
35
+ "bilinear": PIL.Image.BILINEAR,
36
+ "bicubic": PIL.Image.BICUBIC,
37
+ "lanczos": PIL.Image.LANCZOS,
38
+ "nearest": PIL.Image.NEAREST,
39
+ }
40
+ # ------------------------------------------------------------------------------
41
+
42
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
43
+
44
+ re_attention = re.compile(
45
+ r"""
46
+ \\\(|
47
+ \\\)|
48
+ \\\[|
49
+ \\]|
50
+ \\\\|
51
+ \\|
52
+ \(|
53
+ \[|
54
+ :([+-]?[.\d]+)\)|
55
+ \)|
56
+ ]|
57
+ [^\\()\[\]:]+|
58
+ :
59
+ """,
60
+ re.X,
61
+ )
62
+
63
+
64
+ def parse_prompt_attention(text):
65
+ """
66
+ Parses a string with attention tokens and returns a list of pairs: text and its associated weight.
67
+ Accepted tokens are:
68
+ (abc) - increases attention to abc by a multiplier of 1.1
69
+ (abc:3.12) - increases attention to abc by a multiplier of 3.12
70
+ [abc] - decreases attention to abc by a multiplier of 1.1
71
+ \( - literal character '('
72
+ \[ - literal character '['
73
+ \) - literal character ')'
74
+ \] - literal character ']'
75
+ \\ - literal character '\'
76
+ anything else - just text
77
+ >>> parse_prompt_attention('normal text')
78
+ [['normal text', 1.0]]
79
+ >>> parse_prompt_attention('an (important) word')
80
+ [['an ', 1.0], ['important', 1.1], [' word', 1.0]]
81
+ >>> parse_prompt_attention('(unbalanced')
82
+ [['unbalanced', 1.1]]
83
+ >>> parse_prompt_attention('\(literal\]')
84
+ [['(literal]', 1.0]]
85
+ >>> parse_prompt_attention('(unnecessary)(parens)')
86
+ [['unnecessaryparens', 1.1]]
87
+ >>> parse_prompt_attention('a (((house:1.3)) [on] a (hill:0.5), sun, (((sky))).')
88
+ [['a ', 1.0],
89
+ ['house', 1.5730000000000004],
90
+ [' ', 1.1],
91
+ ['on', 1.0],
92
+ [' a ', 1.1],
93
+ ['hill', 0.55],
94
+ [', sun, ', 1.1],
95
+ ['sky', 1.4641000000000006],
96
+ ['.', 1.1]]
97
+ """
98
+
99
+ res = []
100
+ round_brackets = []
101
+ square_brackets = []
102
+
103
+ round_bracket_multiplier = 1.1
104
+ square_bracket_multiplier = 1 / 1.1
105
+
106
+ def multiply_range(start_position, multiplier):
107
+ for p in range(start_position, len(res)):
108
+ res[p][1] *= multiplier
109
+
110
+ for m in re_attention.finditer(text):
111
+ text = m.group(0)
112
+ weight = m.group(1)
113
+
114
+ if text.startswith("\\"):
115
+ res.append([text[1:], 1.0])
116
+ elif text == "(":
117
+ round_brackets.append(len(res))
118
+ elif text == "[":
119
+ square_brackets.append(len(res))
120
+ elif weight is not None and len(round_brackets) > 0:
121
+ multiply_range(round_brackets.pop(), float(weight))
122
+ elif text == ")" and len(round_brackets) > 0:
123
+ multiply_range(round_brackets.pop(), round_bracket_multiplier)
124
+ elif text == "]" and len(square_brackets) > 0:
125
+ multiply_range(square_brackets.pop(), square_bracket_multiplier)
126
+ else:
127
+ res.append([text, 1.0])
128
+
129
+ for pos in round_brackets:
130
+ multiply_range(pos, round_bracket_multiplier)
131
+
132
+ for pos in square_brackets:
133
+ multiply_range(pos, square_bracket_multiplier)
134
+
135
+ if len(res) == 0:
136
+ res = [["", 1.0]]
137
+
138
+ # merge runs of identical weights
139
+ i = 0
140
+ while i + 1 < len(res):
141
+ if res[i][1] == res[i + 1][1]:
142
+ res[i][0] += res[i + 1][0]
143
+ res.pop(i + 1)
144
+ else:
145
+ i += 1
146
+
147
+ return res
148
+
149
+
150
+ def get_prompts_with_weights(pipe: StableDiffusionPipeline, prompt: List[str], max_length: int):
151
+ r"""
152
+ Tokenize a list of prompts and return its tokens with weights of each token.
153
+
154
+ No padding, starting or ending token is included.
155
+ """
156
+ tokens = []
157
+ weights = []
158
+ truncated = False
159
+ for text in prompt:
160
+ texts_and_weights = parse_prompt_attention(text)
161
+ text_token = []
162
+ text_weight = []
163
+ for word, weight in texts_and_weights:
164
+ # tokenize and discard the starting and the ending token
165
+ token = pipe.tokenizer(word).input_ids[1:-1]
166
+ text_token += token
167
+ # copy the weight by length of token
168
+ text_weight += [weight] * len(token)
169
+ # stop if the text is too long (longer than truncation limit)
170
+ if len(text_token) > max_length:
171
+ truncated = True
172
+ break
173
+ # truncate
174
+ if len(text_token) > max_length:
175
+ truncated = True
176
+ text_token = text_token[:max_length]
177
+ text_weight = text_weight[:max_length]
178
+ tokens.append(text_token)
179
+ weights.append(text_weight)
180
+ if truncated:
181
+ logger.warning("Prompt was truncated. Try to shorten the prompt or increase max_embeddings_multiples")
182
+ return tokens, weights
183
+
184
+
185
+ def pad_tokens_and_weights(tokens, weights, max_length, bos, eos, no_boseos_middle=True, chunk_length=77):
186
+ r"""
187
+ Pad the tokens (with starting and ending tokens) and weights (with 1.0) to max_length.
188
+ """
189
+ max_embeddings_multiples = (max_length - 2) // (chunk_length - 2)
190
+ weights_length = max_length if no_boseos_middle else max_embeddings_multiples * chunk_length
191
+ for i in range(len(tokens)):
192
+ tokens[i] = [bos] + tokens[i] + [eos] * (max_length - 1 - len(tokens[i]))
193
+ if no_boseos_middle:
194
+ weights[i] = [1.0] + weights[i] + [1.0] * (max_length - 1 - len(weights[i]))
195
+ else:
196
+ w = []
197
+ if len(weights[i]) == 0:
198
+ w = [1.0] * weights_length
199
+ else:
200
+ for j in range(max_embeddings_multiples):
201
+ w.append(1.0) # weight for starting token in this chunk
202
+ w += weights[i][j * (chunk_length - 2) : min(len(weights[i]), (j + 1) * (chunk_length - 2))]
203
+ w.append(1.0) # weight for ending token in this chunk
204
+ w += [1.0] * (weights_length - len(w))
205
+ weights[i] = w[:]
206
+
207
+ return tokens, weights
208
+
209
+
210
+ def get_unweighted_text_embeddings(
211
+ pipe: StableDiffusionPipeline,
212
+ text_input: torch.Tensor,
213
+ chunk_length: int,
214
+ clip_skip: int,
215
+ eos: int,
216
+ pad: int,
217
+ no_boseos_middle: Optional[bool] = True,
218
+ ):
219
+ """
220
+ When the length of tokens is a multiple of the capacity of the text encoder,
221
+ it should be split into chunks and sent to the text encoder individually.
222
+ """
223
+ max_embeddings_multiples = (text_input.shape[1] - 2) // (chunk_length - 2)
224
+ if max_embeddings_multiples > 1:
225
+ text_embeddings = []
226
+ for i in range(max_embeddings_multiples):
227
+ # extract the i-th chunk
228
+ text_input_chunk = text_input[:, i * (chunk_length - 2) : (i + 1) * (chunk_length - 2) + 2].clone()
229
+
230
+ # cover the head and the tail by the starting and the ending tokens
231
+ text_input_chunk[:, 0] = text_input[0, 0]
232
+ if pad == eos: # v1
233
+ text_input_chunk[:, -1] = text_input[0, -1]
234
+ else: # v2
235
+ for j in range(len(text_input_chunk)):
236
+ if text_input_chunk[j, -1] != eos and text_input_chunk[j, -1] != pad: # 最後に普通の文字がある
237
+ text_input_chunk[j, -1] = eos
238
+ if text_input_chunk[j, 1] == pad: # BOSだけであとはPAD
239
+ text_input_chunk[j, 1] = eos
240
+
241
+ if clip_skip is None or clip_skip == 1:
242
+ text_embedding = pipe.text_encoder(text_input_chunk)[0]
243
+ else:
244
+ enc_out = pipe.text_encoder(text_input_chunk, output_hidden_states=True, return_dict=True)
245
+ text_embedding = enc_out["hidden_states"][-clip_skip]
246
+ text_embedding = pipe.text_encoder.text_model.final_layer_norm(text_embedding)
247
+
248
+ if no_boseos_middle:
249
+ if i == 0:
250
+ # discard the ending token
251
+ text_embedding = text_embedding[:, :-1]
252
+ elif i == max_embeddings_multiples - 1:
253
+ # discard the starting token
254
+ text_embedding = text_embedding[:, 1:]
255
+ else:
256
+ # discard both starting and ending tokens
257
+ text_embedding = text_embedding[:, 1:-1]
258
+
259
+ text_embeddings.append(text_embedding)
260
+ text_embeddings = torch.concat(text_embeddings, axis=1)
261
+ else:
262
+ if clip_skip is None or clip_skip == 1:
263
+ text_embeddings = pipe.text_encoder(text_input)[0]
264
+ else:
265
+ enc_out = pipe.text_encoder(text_input, output_hidden_states=True, return_dict=True)
266
+ text_embeddings = enc_out["hidden_states"][-clip_skip]
267
+ text_embeddings = pipe.text_encoder.text_model.final_layer_norm(text_embeddings)
268
+ return text_embeddings
269
+
270
+
271
+ def get_weighted_text_embeddings(
272
+ pipe: StableDiffusionPipeline,
273
+ prompt: Union[str, List[str]],
274
+ uncond_prompt: Optional[Union[str, List[str]]] = None,
275
+ max_embeddings_multiples: Optional[int] = 3,
276
+ no_boseos_middle: Optional[bool] = False,
277
+ skip_parsing: Optional[bool] = False,
278
+ skip_weighting: Optional[bool] = False,
279
+ clip_skip=None,
280
+ ):
281
+ r"""
282
+ Prompts can be assigned with local weights using brackets. For example,
283
+ prompt 'A (very beautiful) masterpiece' highlights the words 'very beautiful',
284
+ and the embedding tokens corresponding to the words get multiplied by a constant, 1.1.
285
+
286
+ Also, to regularize of the embedding, the weighted embedding would be scaled to preserve the original mean.
287
+
288
+ Args:
289
+ pipe (`StableDiffusionPipeline`):
290
+ Pipe to provide access to the tokenizer and the text encoder.
291
+ prompt (`str` or `List[str]`):
292
+ The prompt or prompts to guide the image generation.
293
+ uncond_prompt (`str` or `List[str]`):
294
+ The unconditional prompt or prompts for guide the image generation. If unconditional prompt
295
+ is provided, the embeddings of prompt and uncond_prompt are concatenated.
296
+ max_embeddings_multiples (`int`, *optional*, defaults to `3`):
297
+ The max multiple length of prompt embeddings compared to the max output length of text encoder.
298
+ no_boseos_middle (`bool`, *optional*, defaults to `False`):
299
+ If the length of text token is multiples of the capacity of text encoder, whether reserve the starting and
300
+ ending token in each of the chunk in the middle.
301
+ skip_parsing (`bool`, *optional*, defaults to `False`):
302
+ Skip the parsing of brackets.
303
+ skip_weighting (`bool`, *optional*, defaults to `False`):
304
+ Skip the weighting. When the parsing is skipped, it is forced True.
305
+ """
306
+ max_length = (pipe.tokenizer.model_max_length - 2) * max_embeddings_multiples + 2
307
+ if isinstance(prompt, str):
308
+ prompt = [prompt]
309
+
310
+ if not skip_parsing:
311
+ prompt_tokens, prompt_weights = get_prompts_with_weights(pipe, prompt, max_length - 2)
312
+ if uncond_prompt is not None:
313
+ if isinstance(uncond_prompt, str):
314
+ uncond_prompt = [uncond_prompt]
315
+ uncond_tokens, uncond_weights = get_prompts_with_weights(pipe, uncond_prompt, max_length - 2)
316
+ else:
317
+ prompt_tokens = [token[1:-1] for token in pipe.tokenizer(prompt, max_length=max_length, truncation=True).input_ids]
318
+ prompt_weights = [[1.0] * len(token) for token in prompt_tokens]
319
+ if uncond_prompt is not None:
320
+ if isinstance(uncond_prompt, str):
321
+ uncond_prompt = [uncond_prompt]
322
+ uncond_tokens = [
323
+ token[1:-1] for token in pipe.tokenizer(uncond_prompt, max_length=max_length, truncation=True).input_ids
324
+ ]
325
+ uncond_weights = [[1.0] * len(token) for token in uncond_tokens]
326
+
327
+ # round up the longest length of tokens to a multiple of (model_max_length - 2)
328
+ max_length = max([len(token) for token in prompt_tokens])
329
+ if uncond_prompt is not None:
330
+ max_length = max(max_length, max([len(token) for token in uncond_tokens]))
331
+
332
+ max_embeddings_multiples = min(
333
+ max_embeddings_multiples,
334
+ (max_length - 1) // (pipe.tokenizer.model_max_length - 2) + 1,
335
+ )
336
+ max_embeddings_multiples = max(1, max_embeddings_multiples)
337
+ max_length = (pipe.tokenizer.model_max_length - 2) * max_embeddings_multiples + 2
338
+
339
+ # pad the length of tokens and weights
340
+ bos = pipe.tokenizer.bos_token_id
341
+ eos = pipe.tokenizer.eos_token_id
342
+ pad = pipe.tokenizer.pad_token_id
343
+ prompt_tokens, prompt_weights = pad_tokens_and_weights(
344
+ prompt_tokens,
345
+ prompt_weights,
346
+ max_length,
347
+ bos,
348
+ eos,
349
+ no_boseos_middle=no_boseos_middle,
350
+ chunk_length=pipe.tokenizer.model_max_length,
351
+ )
352
+ prompt_tokens = torch.tensor(prompt_tokens, dtype=torch.long, device=pipe.device)
353
+ if uncond_prompt is not None:
354
+ uncond_tokens, uncond_weights = pad_tokens_and_weights(
355
+ uncond_tokens,
356
+ uncond_weights,
357
+ max_length,
358
+ bos,
359
+ eos,
360
+ no_boseos_middle=no_boseos_middle,
361
+ chunk_length=pipe.tokenizer.model_max_length,
362
+ )
363
+ uncond_tokens = torch.tensor(uncond_tokens, dtype=torch.long, device=pipe.device)
364
+
365
+ # get the embeddings
366
+ text_embeddings = get_unweighted_text_embeddings(
367
+ pipe,
368
+ prompt_tokens,
369
+ pipe.tokenizer.model_max_length,
370
+ clip_skip,
371
+ eos,
372
+ pad,
373
+ no_boseos_middle=no_boseos_middle,
374
+ )
375
+ prompt_weights = torch.tensor(prompt_weights, dtype=text_embeddings.dtype, device=pipe.device)
376
+ if uncond_prompt is not None:
377
+ uncond_embeddings = get_unweighted_text_embeddings(
378
+ pipe,
379
+ uncond_tokens,
380
+ pipe.tokenizer.model_max_length,
381
+ clip_skip,
382
+ eos,
383
+ pad,
384
+ no_boseos_middle=no_boseos_middle,
385
+ )
386
+ uncond_weights = torch.tensor(uncond_weights, dtype=uncond_embeddings.dtype, device=pipe.device)
387
+
388
+ # assign weights to the prompts and normalize in the sense of mean
389
+ # TODO: should we normalize by chunk or in a whole (current implementation)?
390
+ if (not skip_parsing) and (not skip_weighting):
391
+ previous_mean = text_embeddings.float().mean(axis=[-2, -1]).to(text_embeddings.dtype)
392
+ text_embeddings *= prompt_weights.unsqueeze(-1)
393
+ current_mean = text_embeddings.float().mean(axis=[-2, -1]).to(text_embeddings.dtype)
394
+ text_embeddings *= (previous_mean / current_mean).unsqueeze(-1).unsqueeze(-1)
395
+ if uncond_prompt is not None:
396
+ previous_mean = uncond_embeddings.float().mean(axis=[-2, -1]).to(uncond_embeddings.dtype)
397
+ uncond_embeddings *= uncond_weights.unsqueeze(-1)
398
+ current_mean = uncond_embeddings.float().mean(axis=[-2, -1]).to(uncond_embeddings.dtype)
399
+ uncond_embeddings *= (previous_mean / current_mean).unsqueeze(-1).unsqueeze(-1)
400
+
401
+ if uncond_prompt is not None:
402
+ return text_embeddings, uncond_embeddings
403
+ return text_embeddings, None
404
+
405
+
406
+ def preprocess_image(image):
407
+ w, h = image.size
408
+ w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32
409
+ image = image.resize((w, h), resample=PIL_INTERPOLATION["lanczos"])
410
+ image = np.array(image).astype(np.float32) / 255.0
411
+ image = image[None].transpose(0, 3, 1, 2)
412
+ image = torch.from_numpy(image)
413
+ return 2.0 * image - 1.0
414
+
415
+
416
+ def preprocess_mask(mask, scale_factor=8):
417
+ mask = mask.convert("L")
418
+ w, h = mask.size
419
+ w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32
420
+ mask = mask.resize((w // scale_factor, h // scale_factor), resample=PIL_INTERPOLATION["nearest"])
421
+ mask = np.array(mask).astype(np.float32) / 255.0
422
+ mask = np.tile(mask, (4, 1, 1))
423
+ mask = mask[None].transpose(0, 1, 2, 3) # what does this step do?
424
+ mask = 1 - mask # repaint white, keep black
425
+ mask = torch.from_numpy(mask)
426
+ return mask
427
+
428
+
429
+ def prepare_controlnet_image(
430
+ image: PIL.Image.Image,
431
+ width: int,
432
+ height: int,
433
+ batch_size: int,
434
+ num_images_per_prompt: int,
435
+ device: torch.device,
436
+ dtype: torch.dtype,
437
+ do_classifier_free_guidance: bool = False,
438
+ guess_mode: bool = False,
439
+ ):
440
+ if not isinstance(image, torch.Tensor):
441
+ if isinstance(image, PIL.Image.Image):
442
+ image = [image]
443
+
444
+ if isinstance(image[0], PIL.Image.Image):
445
+ images = []
446
+
447
+ for image_ in image:
448
+ image_ = image_.convert("RGB")
449
+ image_ = image_.resize((width, height), resample=PIL_INTERPOLATION["lanczos"])
450
+ image_ = np.array(image_)
451
+ image_ = image_[None, :]
452
+ images.append(image_)
453
+
454
+ image = images
455
+
456
+ image = np.concatenate(image, axis=0)
457
+ image = np.array(image).astype(np.float32) / 255.0
458
+ image = image.transpose(0, 3, 1, 2)
459
+ image = torch.from_numpy(image)
460
+ elif isinstance(image[0], torch.Tensor):
461
+ image = torch.cat(image, dim=0)
462
+
463
+ image_batch_size = image.shape[0]
464
+
465
+ if image_batch_size == 1:
466
+ repeat_by = batch_size
467
+ else:
468
+ # image batch size is the same as prompt batch size
469
+ repeat_by = num_images_per_prompt
470
+
471
+ image = image.repeat_interleave(repeat_by, dim=0)
472
+
473
+ image = image.to(device=device, dtype=dtype)
474
+
475
+ if do_classifier_free_guidance and not guess_mode:
476
+ image = torch.cat([image] * 2)
477
+
478
+ return image
479
+
480
+
481
+ class StableDiffusionLongPromptWeightingPipeline(StableDiffusionPipeline):
482
+ r"""
483
+ Pipeline for text-to-image generation using Stable Diffusion without tokens length limit, and support parsing
484
+ weighting in prompt.
485
+
486
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
487
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
488
+
489
+ Args:
490
+ vae ([`AutoencoderKL`]):
491
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
492
+ text_encoder ([`CLIPTextModel`]):
493
+ Frozen text-encoder. Stable Diffusion uses the text portion of
494
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
495
+ the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
496
+ tokenizer (`CLIPTokenizer`):
497
+ Tokenizer of class
498
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
499
+ unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
500
+ scheduler ([`SchedulerMixin`]):
501
+ A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
502
+ [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
503
+ safety_checker ([`StableDiffusionSafetyChecker`]):
504
+ Classification module that estimates whether generated images could be considered offensive or harmful.
505
+ Please, refer to the [model card](https://huggingface.co/CompVis/stable-diffusion-v1-4) for details.
506
+ feature_extractor ([`CLIPFeatureExtractor`]):
507
+ Model that extracts features from generated images to be used as inputs for the `safety_checker`.
508
+ """
509
+
510
+ # if version.parse(version.parse(diffusers.__version__).base_version) >= version.parse("0.9.0"):
511
+
512
+ def __init__(
513
+ self,
514
+ vae: AutoencoderKL,
515
+ text_encoder: CLIPTextModel,
516
+ tokenizer: CLIPTokenizer,
517
+ unet: UNet2DConditionModel,
518
+ scheduler: SchedulerMixin,
519
+ # clip_skip: int,
520
+ safety_checker: StableDiffusionSafetyChecker,
521
+ feature_extractor: CLIPFeatureExtractor,
522
+ requires_safety_checker: bool = True,
523
+ clip_skip: int = 1,
524
+ ):
525
+ super().__init__(
526
+ vae=vae,
527
+ text_encoder=text_encoder,
528
+ tokenizer=tokenizer,
529
+ unet=unet,
530
+ scheduler=scheduler,
531
+ safety_checker=safety_checker,
532
+ feature_extractor=feature_extractor,
533
+ requires_safety_checker=requires_safety_checker,
534
+ )
535
+ self.clip_skip = clip_skip
536
+ self.__init__additional__()
537
+
538
+ # else:
539
+ # def __init__(
540
+ # self,
541
+ # vae: AutoencoderKL,
542
+ # text_encoder: CLIPTextModel,
543
+ # tokenizer: CLIPTokenizer,
544
+ # unet: UNet2DConditionModel,
545
+ # scheduler: SchedulerMixin,
546
+ # safety_checker: StableDiffusionSafetyChecker,
547
+ # feature_extractor: CLIPFeatureExtractor,
548
+ # ):
549
+ # super().__init__(
550
+ # vae=vae,
551
+ # text_encoder=text_encoder,
552
+ # tokenizer=tokenizer,
553
+ # unet=unet,
554
+ # scheduler=scheduler,
555
+ # safety_checker=safety_checker,
556
+ # feature_extractor=feature_extractor,
557
+ # )
558
+ # self.__init__additional__()
559
+
560
+ def __init__additional__(self):
561
+ if not hasattr(self, "vae_scale_factor"):
562
+ setattr(self, "vae_scale_factor", 2 ** (len(self.vae.config.block_out_channels) - 1))
563
+
564
+ @property
565
+ def _execution_device(self):
566
+ r"""
567
+ Returns the device on which the pipeline's models will be executed. After calling
568
+ `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module
569
+ hooks.
570
+ """
571
+ if self.device != torch.device("meta") or not hasattr(self.unet, "_hf_hook"):
572
+ return self.device
573
+ for module in self.unet.modules():
574
+ if (
575
+ hasattr(module, "_hf_hook")
576
+ and hasattr(module._hf_hook, "execution_device")
577
+ and module._hf_hook.execution_device is not None
578
+ ):
579
+ return torch.device(module._hf_hook.execution_device)
580
+ return self.device
581
+
582
+ def _encode_prompt(
583
+ self,
584
+ prompt,
585
+ device,
586
+ num_images_per_prompt,
587
+ do_classifier_free_guidance,
588
+ negative_prompt,
589
+ max_embeddings_multiples,
590
+ ):
591
+ r"""
592
+ Encodes the prompt into text encoder hidden states.
593
+
594
+ Args:
595
+ prompt (`str` or `list(int)`):
596
+ prompt to be encoded
597
+ device: (`torch.device`):
598
+ torch device
599
+ num_images_per_prompt (`int`):
600
+ number of images that should be generated per prompt
601
+ do_classifier_free_guidance (`bool`):
602
+ whether to use classifier free guidance or not
603
+ negative_prompt (`str` or `List[str]`):
604
+ The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
605
+ if `guidance_scale` is less than `1`).
606
+ max_embeddings_multiples (`int`, *optional*, defaults to `3`):
607
+ The max multiple length of prompt embeddings compared to the max output length of text encoder.
608
+ """
609
+ batch_size = len(prompt) if isinstance(prompt, list) else 1
610
+
611
+ if negative_prompt is None:
612
+ negative_prompt = [""] * batch_size
613
+ elif isinstance(negative_prompt, str):
614
+ negative_prompt = [negative_prompt] * batch_size
615
+ if batch_size != len(negative_prompt):
616
+ raise ValueError(
617
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
618
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
619
+ " the batch size of `prompt`."
620
+ )
621
+
622
+ text_embeddings, uncond_embeddings = get_weighted_text_embeddings(
623
+ pipe=self,
624
+ prompt=prompt,
625
+ uncond_prompt=negative_prompt if do_classifier_free_guidance else None,
626
+ max_embeddings_multiples=max_embeddings_multiples,
627
+ clip_skip=self.clip_skip,
628
+ )
629
+ bs_embed, seq_len, _ = text_embeddings.shape
630
+ text_embeddings = text_embeddings.repeat(1, num_images_per_prompt, 1)
631
+ text_embeddings = text_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1)
632
+
633
+ if do_classifier_free_guidance:
634
+ bs_embed, seq_len, _ = uncond_embeddings.shape
635
+ uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt, 1)
636
+ uncond_embeddings = uncond_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1)
637
+ text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
638
+
639
+ return text_embeddings
640
+
641
+ def check_inputs(self, prompt, height, width, strength, callback_steps):
642
+ if not isinstance(prompt, str) and not isinstance(prompt, list):
643
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
644
+
645
+ if strength < 0 or strength > 1:
646
+ raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")
647
+
648
+ if height % 8 != 0 or width % 8 != 0:
649
+ print(height, width)
650
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
651
+
652
+ if (callback_steps is None) or (
653
+ callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
654
+ ):
655
+ raise ValueError(
656
+ f"`callback_steps` has to be a positive integer but is {callback_steps} of type" f" {type(callback_steps)}."
657
+ )
658
+
659
+ def get_timesteps(self, num_inference_steps, strength, device, is_text2img):
660
+ if is_text2img:
661
+ return self.scheduler.timesteps.to(device), num_inference_steps
662
+ else:
663
+ # get the original timestep using init_timestep
664
+ offset = self.scheduler.config.get("steps_offset", 0)
665
+ init_timestep = int(num_inference_steps * strength) + offset
666
+ init_timestep = min(init_timestep, num_inference_steps)
667
+
668
+ t_start = max(num_inference_steps - init_timestep + offset, 0)
669
+ timesteps = self.scheduler.timesteps[t_start:].to(device)
670
+ return timesteps, num_inference_steps - t_start
671
+
672
+ def run_safety_checker(self, image, device, dtype):
673
+ if self.safety_checker is not None:
674
+ safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device)
675
+ image, has_nsfw_concept = self.safety_checker(images=image, clip_input=safety_checker_input.pixel_values.to(dtype))
676
+ else:
677
+ has_nsfw_concept = None
678
+ return image, has_nsfw_concept
679
+
680
+ def decode_latents(self, latents):
681
+ latents = 1 / 0.18215 * latents
682
+ image = self.vae.decode(latents).sample
683
+ image = (image / 2 + 0.5).clamp(0, 1)
684
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
685
+ image = image.cpu().permute(0, 2, 3, 1).float().numpy()
686
+ return image
687
+
688
+ def prepare_extra_step_kwargs(self, generator, eta):
689
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
690
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
691
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
692
+ # and should be between [0, 1]
693
+
694
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
695
+ extra_step_kwargs = {}
696
+ if accepts_eta:
697
+ extra_step_kwargs["eta"] = eta
698
+
699
+ # check if the scheduler accepts generator
700
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
701
+ if accepts_generator:
702
+ extra_step_kwargs["generator"] = generator
703
+ return extra_step_kwargs
704
+
705
+ def prepare_latents(self, image, timestep, batch_size, height, width, dtype, device, generator, latents=None):
706
+ if image is None:
707
+ shape = (
708
+ batch_size,
709
+ self.unet.in_channels,
710
+ height // self.vae_scale_factor,
711
+ width // self.vae_scale_factor,
712
+ )
713
+
714
+ if latents is None:
715
+ if device.type == "mps":
716
+ # randn does not work reproducibly on mps
717
+ latents = torch.randn(shape, generator=generator, device="cpu", dtype=dtype).to(device)
718
+ else:
719
+ latents = torch.randn(shape, generator=generator, device=device, dtype=dtype)
720
+ else:
721
+ if latents.shape != shape:
722
+ raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
723
+ latents = latents.to(device)
724
+
725
+ # scale the initial noise by the standard deviation required by the scheduler
726
+ latents = latents * self.scheduler.init_noise_sigma
727
+ return latents, None, None
728
+ else:
729
+ init_latent_dist = self.vae.encode(image).latent_dist
730
+ init_latents = init_latent_dist.sample(generator=generator)
731
+ init_latents = 0.18215 * init_latents
732
+ init_latents = torch.cat([init_latents] * batch_size, dim=0)
733
+ init_latents_orig = init_latents
734
+ shape = init_latents.shape
735
+
736
+ # add noise to latents using the timesteps
737
+ if device.type == "mps":
738
+ noise = torch.randn(shape, generator=generator, device="cpu", dtype=dtype).to(device)
739
+ else:
740
+ noise = torch.randn(shape, generator=generator, device=device, dtype=dtype)
741
+ latents = self.scheduler.add_noise(init_latents, noise, timestep)
742
+ return latents, init_latents_orig, noise
743
+
744
+ @torch.no_grad()
745
+ def __call__(
746
+ self,
747
+ prompt: Union[str, List[str]],
748
+ negative_prompt: Optional[Union[str, List[str]]] = None,
749
+ image: Union[torch.FloatTensor, PIL.Image.Image] = None,
750
+ mask_image: Union[torch.FloatTensor, PIL.Image.Image] = None,
751
+ height: int = 512,
752
+ width: int = 512,
753
+ num_inference_steps: int = 50,
754
+ guidance_scale: float = 7.5,
755
+ strength: float = 0.8,
756
+ num_images_per_prompt: Optional[int] = 1,
757
+ eta: float = 0.0,
758
+ generator: Optional[torch.Generator] = None,
759
+ latents: Optional[torch.FloatTensor] = None,
760
+ max_embeddings_multiples: Optional[int] = 3,
761
+ output_type: Optional[str] = "pil",
762
+ return_dict: bool = True,
763
+ controlnet=None,
764
+ controlnet_image=None,
765
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
766
+ is_cancelled_callback: Optional[Callable[[], bool]] = None,
767
+ callback_steps: int = 1,
768
+ ):
769
+ r"""
770
+ Function invoked when calling the pipeline for generation.
771
+
772
+ Args:
773
+ prompt (`str` or `List[str]`):
774
+ The prompt or prompts to guide the image generation.
775
+ negative_prompt (`str` or `List[str]`, *optional*):
776
+ The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
777
+ if `guidance_scale` is less than `1`).
778
+ image (`torch.FloatTensor` or `PIL.Image.Image`):
779
+ `Image`, or tensor representing an image batch, that will be used as the starting point for the
780
+ process.
781
+ mask_image (`torch.FloatTensor` or `PIL.Image.Image`):
782
+ `Image`, or tensor representing an image batch, to mask `image`. White pixels in the mask will be
783
+ replaced by noise and therefore repainted, while black pixels will be preserved. If `mask_image` is a
784
+ PIL image, it will be converted to a single channel (luminance) before use. If it's a tensor, it should
785
+ contain one color channel (L) instead of 3, so the expected shape would be `(B, H, W, 1)`.
786
+ height (`int`, *optional*, defaults to 512):
787
+ The height in pixels of the generated image.
788
+ width (`int`, *optional*, defaults to 512):
789
+ The width in pixels of the generated image.
790
+ num_inference_steps (`int`, *optional*, defaults to 50):
791
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
792
+ expense of slower inference.
793
+ guidance_scale (`float`, *optional*, defaults to 7.5):
794
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
795
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
796
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
797
+ 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
798
+ usually at the expense of lower image quality.
799
+ strength (`float`, *optional*, defaults to 0.8):
800
+ Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1.
801
+ `image` will be used as a starting point, adding more noise to it the larger the `strength`. The
802
+ number of denoising steps depends on the amount of noise initially added. When `strength` is 1, added
803
+ noise will be maximum and the denoising process will run for the full number of iterations specified in
804
+ `num_inference_steps`. A value of 1, therefore, essentially ignores `image`.
805
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
806
+ The number of images to generate per prompt.
807
+ eta (`float`, *optional*, defaults to 0.0):
808
+ Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
809
+ [`schedulers.DDIMScheduler`], will be ignored for others.
810
+ generator (`torch.Generator`, *optional*):
811
+ A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
812
+ deterministic.
813
+ latents (`torch.FloatTensor`, *optional*):
814
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
815
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
816
+ tensor will ge generated by sampling using the supplied random `generator`.
817
+ max_embeddings_multiples (`int`, *optional*, defaults to `3`):
818
+ The max multiple length of prompt embeddings compared to the max output length of text encoder.
819
+ output_type (`str`, *optional*, defaults to `"pil"`):
820
+ The output format of the generate image. Choose between
821
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
822
+ return_dict (`bool`, *optional*, defaults to `True`):
823
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
824
+ plain tuple.
825
+ controlnet (`diffusers.ControlNetModel`, *optional*):
826
+ A controlnet model to be used for the inference. If not provided, controlnet will be disabled.
827
+ controlnet_image (`torch.FloatTensor` or `PIL.Image.Image`, *optional*):
828
+ `Image`, or tensor representing an image batch, to be used as the starting point for the controlnet
829
+ inference.
830
+ callback (`Callable`, *optional*):
831
+ A function that will be called every `callback_steps` steps during inference. The function will be
832
+ called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
833
+ is_cancelled_callback (`Callable`, *optional*):
834
+ A function that will be called every `callback_steps` steps during inference. If the function returns
835
+ `True`, the inference will be cancelled.
836
+ callback_steps (`int`, *optional*, defaults to 1):
837
+ The frequency at which the `callback` function will be called. If not specified, the callback will be
838
+ called at every step.
839
+
840
+ Returns:
841
+ `None` if cancelled by `is_cancelled_callback`,
842
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
843
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple.
844
+ When returning a tuple, the first element is a list with the generated images, and the second element is a
845
+ list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
846
+ (nsfw) content, according to the `safety_checker`.
847
+ """
848
+ if controlnet is not None and controlnet_image is None:
849
+ raise ValueError("controlnet_image must be provided if controlnet is not None.")
850
+
851
+ # 0. Default height and width to unet
852
+ height = height or self.unet.config.sample_size * self.vae_scale_factor
853
+ width = width or self.unet.config.sample_size * self.vae_scale_factor
854
+
855
+ # 1. Check inputs. Raise error if not correct
856
+ self.check_inputs(prompt, height, width, strength, callback_steps)
857
+
858
+ # 2. Define call parameters
859
+ batch_size = 1 if isinstance(prompt, str) else len(prompt)
860
+ device = self._execution_device
861
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
862
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
863
+ # corresponds to doing no classifier free guidance.
864
+ do_classifier_free_guidance = guidance_scale > 1.0
865
+
866
+ # 3. Encode input prompt
867
+ text_embeddings = self._encode_prompt(
868
+ prompt,
869
+ device,
870
+ num_images_per_prompt,
871
+ do_classifier_free_guidance,
872
+ negative_prompt,
873
+ max_embeddings_multiples,
874
+ )
875
+ dtype = text_embeddings.dtype
876
+
877
+ # 4. Preprocess image and mask
878
+ if isinstance(image, PIL.Image.Image):
879
+ image = preprocess_image(image)
880
+ if image is not None:
881
+ image = image.to(device=self.device, dtype=dtype)
882
+ if isinstance(mask_image, PIL.Image.Image):
883
+ mask_image = preprocess_mask(mask_image, self.vae_scale_factor)
884
+ if mask_image is not None:
885
+ mask = mask_image.to(device=self.device, dtype=dtype)
886
+ mask = torch.cat([mask] * batch_size * num_images_per_prompt)
887
+ else:
888
+ mask = None
889
+
890
+ if controlnet_image is not None:
891
+ controlnet_image = prepare_controlnet_image(
892
+ controlnet_image, width, height, batch_size, 1, self.device, controlnet.dtype, do_classifier_free_guidance, False
893
+ )
894
+
895
+ # 5. set timesteps
896
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
897
+ timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device, image is None)
898
+ latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
899
+
900
+ # 6. Prepare latent variables
901
+ latents, init_latents_orig, noise = self.prepare_latents(
902
+ image,
903
+ latent_timestep,
904
+ batch_size * num_images_per_prompt,
905
+ height,
906
+ width,
907
+ dtype,
908
+ device,
909
+ generator,
910
+ latents,
911
+ )
912
+
913
+ # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
914
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
915
+
916
+ # 8. Denoising loop
917
+ for i, t in enumerate(self.progress_bar(timesteps)):
918
+ # expand the latents if we are doing classifier free guidance
919
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
920
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
921
+
922
+ unet_additional_args = {}
923
+ if controlnet is not None:
924
+ down_block_res_samples, mid_block_res_sample = controlnet(
925
+ latent_model_input,
926
+ t,
927
+ encoder_hidden_states=text_embeddings,
928
+ controlnet_cond=controlnet_image,
929
+ conditioning_scale=1.0,
930
+ guess_mode=False,
931
+ return_dict=False,
932
+ )
933
+ unet_additional_args["down_block_additional_residuals"] = down_block_res_samples
934
+ unet_additional_args["mid_block_additional_residual"] = mid_block_res_sample
935
+
936
+ # predict the noise residual
937
+ noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings, **unet_additional_args).sample
938
+
939
+ # perform guidance
940
+ if do_classifier_free_guidance:
941
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
942
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
943
+
944
+ # compute the previous noisy sample x_t -> x_t-1
945
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
946
+
947
+ if mask is not None:
948
+ # masking
949
+ init_latents_proper = self.scheduler.add_noise(init_latents_orig, noise, torch.tensor([t]))
950
+ latents = (init_latents_proper * mask) + (latents * (1 - mask))
951
+
952
+ # call the callback, if provided
953
+ if i % callback_steps == 0:
954
+ if callback is not None:
955
+ callback(i, t, latents)
956
+ if is_cancelled_callback is not None and is_cancelled_callback():
957
+ return None
958
+
959
+ return latents
960
+
961
+ def latents_to_image(self, latents):
962
+ # 9. Post-processing
963
+ image = self.decode_latents(latents.to(self.vae.dtype))
964
+ image = self.numpy_to_pil(image)
965
+ return image
966
+
967
+ def text2img(
968
+ self,
969
+ prompt: Union[str, List[str]],
970
+ negative_prompt: Optional[Union[str, List[str]]] = None,
971
+ height: int = 512,
972
+ width: int = 512,
973
+ num_inference_steps: int = 50,
974
+ guidance_scale: float = 7.5,
975
+ num_images_per_prompt: Optional[int] = 1,
976
+ eta: float = 0.0,
977
+ generator: Optional[torch.Generator] = None,
978
+ latents: Optional[torch.FloatTensor] = None,
979
+ max_embeddings_multiples: Optional[int] = 3,
980
+ output_type: Optional[str] = "pil",
981
+ return_dict: bool = True,
982
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
983
+ is_cancelled_callback: Optional[Callable[[], bool]] = None,
984
+ callback_steps: int = 1,
985
+ ):
986
+ r"""
987
+ Function for text-to-image generation.
988
+ Args:
989
+ prompt (`str` or `List[str]`):
990
+ The prompt or prompts to guide the image generation.
991
+ negative_prompt (`str` or `List[str]`, *optional*):
992
+ The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
993
+ if `guidance_scale` is less than `1`).
994
+ height (`int`, *optional*, defaults to 512):
995
+ The height in pixels of the generated image.
996
+ width (`int`, *optional*, defaults to 512):
997
+ The width in pixels of the generated image.
998
+ num_inference_steps (`int`, *optional*, defaults to 50):
999
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
1000
+ expense of slower inference.
1001
+ guidance_scale (`float`, *optional*, defaults to 7.5):
1002
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
1003
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
1004
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
1005
+ 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
1006
+ usually at the expense of lower image quality.
1007
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
1008
+ The number of images to generate per prompt.
1009
+ eta (`float`, *optional*, defaults to 0.0):
1010
+ Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
1011
+ [`schedulers.DDIMScheduler`], will be ignored for others.
1012
+ generator (`torch.Generator`, *optional*):
1013
+ A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
1014
+ deterministic.
1015
+ latents (`torch.FloatTensor`, *optional*):
1016
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
1017
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
1018
+ tensor will ge generated by sampling using the supplied random `generator`.
1019
+ max_embeddings_multiples (`int`, *optional*, defaults to `3`):
1020
+ The max multiple length of prompt embeddings compared to the max output length of text encoder.
1021
+ output_type (`str`, *optional*, defaults to `"pil"`):
1022
+ The output format of the generate image. Choose between
1023
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
1024
+ return_dict (`bool`, *optional*, defaults to `True`):
1025
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
1026
+ plain tuple.
1027
+ callback (`Callable`, *optional*):
1028
+ A function that will be called every `callback_steps` steps during inference. The function will be
1029
+ called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
1030
+ is_cancelled_callback (`Callable`, *optional*):
1031
+ A function that will be called every `callback_steps` steps during inference. If the function returns
1032
+ `True`, the inference will be cancelled.
1033
+ callback_steps (`int`, *optional*, defaults to 1):
1034
+ The frequency at which the `callback` function will be called. If not specified, the callback will be
1035
+ called at every step.
1036
+ Returns:
1037
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
1038
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple.
1039
+ When returning a tuple, the first element is a list with the generated images, and the second element is a
1040
+ list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
1041
+ (nsfw) content, according to the `safety_checker`.
1042
+ """
1043
+ return self.__call__(
1044
+ prompt=prompt,
1045
+ negative_prompt=negative_prompt,
1046
+ height=height,
1047
+ width=width,
1048
+ num_inference_steps=num_inference_steps,
1049
+ guidance_scale=guidance_scale,
1050
+ num_images_per_prompt=num_images_per_prompt,
1051
+ eta=eta,
1052
+ generator=generator,
1053
+ latents=latents,
1054
+ max_embeddings_multiples=max_embeddings_multiples,
1055
+ output_type=output_type,
1056
+ return_dict=return_dict,
1057
+ callback=callback,
1058
+ is_cancelled_callback=is_cancelled_callback,
1059
+ callback_steps=callback_steps,
1060
+ )
1061
+
1062
+ def img2img(
1063
+ self,
1064
+ image: Union[torch.FloatTensor, PIL.Image.Image],
1065
+ prompt: Union[str, List[str]],
1066
+ negative_prompt: Optional[Union[str, List[str]]] = None,
1067
+ strength: float = 0.8,
1068
+ num_inference_steps: Optional[int] = 50,
1069
+ guidance_scale: Optional[float] = 7.5,
1070
+ num_images_per_prompt: Optional[int] = 1,
1071
+ eta: Optional[float] = 0.0,
1072
+ generator: Optional[torch.Generator] = None,
1073
+ max_embeddings_multiples: Optional[int] = 3,
1074
+ output_type: Optional[str] = "pil",
1075
+ return_dict: bool = True,
1076
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
1077
+ is_cancelled_callback: Optional[Callable[[], bool]] = None,
1078
+ callback_steps: int = 1,
1079
+ ):
1080
+ r"""
1081
+ Function for image-to-image generation.
1082
+ Args:
1083
+ image (`torch.FloatTensor` or `PIL.Image.Image`):
1084
+ `Image`, or tensor representing an image batch, that will be used as the starting point for the
1085
+ process.
1086
+ prompt (`str` or `List[str]`):
1087
+ The prompt or prompts to guide the image generation.
1088
+ negative_prompt (`str` or `List[str]`, *optional*):
1089
+ The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
1090
+ if `guidance_scale` is less than `1`).
1091
+ strength (`float`, *optional*, defaults to 0.8):
1092
+ Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1.
1093
+ `image` will be used as a starting point, adding more noise to it the larger the `strength`. The
1094
+ number of denoising steps depends on the amount of noise initially added. When `strength` is 1, added
1095
+ noise will be maximum and the denoising process will run for the full number of iterations specified in
1096
+ `num_inference_steps`. A value of 1, therefore, essentially ignores `image`.
1097
+ num_inference_steps (`int`, *optional*, defaults to 50):
1098
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
1099
+ expense of slower inference. This parameter will be modulated by `strength`.
1100
+ guidance_scale (`float`, *optional*, defaults to 7.5):
1101
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
1102
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
1103
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
1104
+ 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
1105
+ usually at the expense of lower image quality.
1106
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
1107
+ The number of images to generate per prompt.
1108
+ eta (`float`, *optional*, defaults to 0.0):
1109
+ Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
1110
+ [`schedulers.DDIMScheduler`], will be ignored for others.
1111
+ generator (`torch.Generator`, *optional*):
1112
+ A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
1113
+ deterministic.
1114
+ max_embeddings_multiples (`int`, *optional*, defaults to `3`):
1115
+ The max multiple length of prompt embeddings compared to the max output length of text encoder.
1116
+ output_type (`str`, *optional*, defaults to `"pil"`):
1117
+ The output format of the generate image. Choose between
1118
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
1119
+ return_dict (`bool`, *optional*, defaults to `True`):
1120
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
1121
+ plain tuple.
1122
+ callback (`Callable`, *optional*):
1123
+ A function that will be called every `callback_steps` steps during inference. The function will be
1124
+ called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
1125
+ is_cancelled_callback (`Callable`, *optional*):
1126
+ A function that will be called every `callback_steps` steps during inference. If the function returns
1127
+ `True`, the inference will be cancelled.
1128
+ callback_steps (`int`, *optional*, defaults to 1):
1129
+ The frequency at which the `callback` function will be called. If not specified, the callback will be
1130
+ called at every step.
1131
+ Returns:
1132
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
1133
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple.
1134
+ When returning a tuple, the first element is a list with the generated images, and the second element is a
1135
+ list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
1136
+ (nsfw) content, according to the `safety_checker`.
1137
+ """
1138
+ return self.__call__(
1139
+ prompt=prompt,
1140
+ negative_prompt=negative_prompt,
1141
+ image=image,
1142
+ num_inference_steps=num_inference_steps,
1143
+ guidance_scale=guidance_scale,
1144
+ strength=strength,
1145
+ num_images_per_prompt=num_images_per_prompt,
1146
+ eta=eta,
1147
+ generator=generator,
1148
+ max_embeddings_multiples=max_embeddings_multiples,
1149
+ output_type=output_type,
1150
+ return_dict=return_dict,
1151
+ callback=callback,
1152
+ is_cancelled_callback=is_cancelled_callback,
1153
+ callback_steps=callback_steps,
1154
+ )
1155
+
1156
+ def inpaint(
1157
+ self,
1158
+ image: Union[torch.FloatTensor, PIL.Image.Image],
1159
+ mask_image: Union[torch.FloatTensor, PIL.Image.Image],
1160
+ prompt: Union[str, List[str]],
1161
+ negative_prompt: Optional[Union[str, List[str]]] = None,
1162
+ strength: float = 0.8,
1163
+ num_inference_steps: Optional[int] = 50,
1164
+ guidance_scale: Optional[float] = 7.5,
1165
+ num_images_per_prompt: Optional[int] = 1,
1166
+ eta: Optional[float] = 0.0,
1167
+ generator: Optional[torch.Generator] = None,
1168
+ max_embeddings_multiples: Optional[int] = 3,
1169
+ output_type: Optional[str] = "pil",
1170
+ return_dict: bool = True,
1171
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
1172
+ is_cancelled_callback: Optional[Callable[[], bool]] = None,
1173
+ callback_steps: int = 1,
1174
+ ):
1175
+ r"""
1176
+ Function for inpaint.
1177
+ Args:
1178
+ image (`torch.FloatTensor` or `PIL.Image.Image`):
1179
+ `Image`, or tensor representing an image batch, that will be used as the starting point for the
1180
+ process. This is the image whose masked region will be inpainted.
1181
+ mask_image (`torch.FloatTensor` or `PIL.Image.Image`):
1182
+ `Image`, or tensor representing an image batch, to mask `image`. White pixels in the mask will be
1183
+ replaced by noise and therefore repainted, while black pixels will be preserved. If `mask_image` is a
1184
+ PIL image, it will be converted to a single channel (luminance) before use. If it's a tensor, it should
1185
+ contain one color channel (L) instead of 3, so the expected shape would be `(B, H, W, 1)`.
1186
+ prompt (`str` or `List[str]`):
1187
+ The prompt or prompts to guide the image generation.
1188
+ negative_prompt (`str` or `List[str]`, *optional*):
1189
+ The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
1190
+ if `guidance_scale` is less than `1`).
1191
+ strength (`float`, *optional*, defaults to 0.8):
1192
+ Conceptually, indicates how much to inpaint the masked area. Must be between 0 and 1. When `strength`
1193
+ is 1, the denoising process will be run on the masked area for the full number of iterations specified
1194
+ in `num_inference_steps`. `image` will be used as a reference for the masked area, adding more
1195
+ noise to that region the larger the `strength`. If `strength` is 0, no inpainting will occur.
1196
+ num_inference_steps (`int`, *optional*, defaults to 50):
1197
+ The reference number of denoising steps. More denoising steps usually lead to a higher quality image at
1198
+ the expense of slower inference. This parameter will be modulated by `strength`, as explained above.
1199
+ guidance_scale (`float`, *optional*, defaults to 7.5):
1200
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
1201
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
1202
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
1203
+ 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
1204
+ usually at the expense of lower image quality.
1205
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
1206
+ The number of images to generate per prompt.
1207
+ eta (`float`, *optional*, defaults to 0.0):
1208
+ Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
1209
+ [`schedulers.DDIMScheduler`], will be ignored for others.
1210
+ generator (`torch.Generator`, *optional*):
1211
+ A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
1212
+ deterministic.
1213
+ max_embeddings_multiples (`int`, *optional*, defaults to `3`):
1214
+ The max multiple length of prompt embeddings compared to the max output length of text encoder.
1215
+ output_type (`str`, *optional*, defaults to `"pil"`):
1216
+ The output format of the generate image. Choose between
1217
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
1218
+ return_dict (`bool`, *optional*, defaults to `True`):
1219
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
1220
+ plain tuple.
1221
+ callback (`Callable`, *optional*):
1222
+ A function that will be called every `callback_steps` steps during inference. The function will be
1223
+ called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
1224
+ is_cancelled_callback (`Callable`, *optional*):
1225
+ A function that will be called every `callback_steps` steps during inference. If the function returns
1226
+ `True`, the inference will be cancelled.
1227
+ callback_steps (`int`, *optional*, defaults to 1):
1228
+ The frequency at which the `callback` function will be called. If not specified, the callback will be
1229
+ called at every step.
1230
+ Returns:
1231
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
1232
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple.
1233
+ When returning a tuple, the first element is a list with the generated images, and the second element is a
1234
+ list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
1235
+ (nsfw) content, according to the `safety_checker`.
1236
+ """
1237
+ return self.__call__(
1238
+ prompt=prompt,
1239
+ negative_prompt=negative_prompt,
1240
+ image=image,
1241
+ mask_image=mask_image,
1242
+ num_inference_steps=num_inference_steps,
1243
+ guidance_scale=guidance_scale,
1244
+ strength=strength,
1245
+ num_images_per_prompt=num_images_per_prompt,
1246
+ eta=eta,
1247
+ generator=generator,
1248
+ max_embeddings_multiples=max_embeddings_multiples,
1249
+ output_type=output_type,
1250
+ return_dict=return_dict,
1251
+ callback=callback,
1252
+ is_cancelled_callback=is_cancelled_callback,
1253
+ callback_steps=callback_steps,
1254
+ )
external/llite/library/model_util.py ADDED
@@ -0,0 +1,1350 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # v1: split from train_db_fixed.py.
2
+ # v2: support safetensors
3
+
4
+ import math
5
+ import os
6
+ import torch
7
+ try:
8
+ import intel_extension_for_pytorch as ipex
9
+ if torch.xpu.is_available():
10
+ from library.ipex import ipex_init
11
+ ipex_init()
12
+ except Exception:
13
+ pass
14
+ import diffusers
15
+ from transformers import CLIPTextModel, CLIPTokenizer, CLIPTextConfig, logging
16
+ from diffusers import AutoencoderKL, DDIMScheduler, StableDiffusionPipeline # , UNet2DConditionModel
17
+ from safetensors.torch import load_file, save_file
18
+ from external.llite.library.original_unet import UNet2DConditionModel
19
+
20
+ # DiffUsers版StableDiffusionのモデルパラメータ
21
+ NUM_TRAIN_TIMESTEPS = 1000
22
+ BETA_START = 0.00085
23
+ BETA_END = 0.0120
24
+
25
+ UNET_PARAMS_MODEL_CHANNELS = 320
26
+ UNET_PARAMS_CHANNEL_MULT = [1, 2, 4, 4]
27
+ UNET_PARAMS_ATTENTION_RESOLUTIONS = [4, 2, 1]
28
+ UNET_PARAMS_IMAGE_SIZE = 64 # fixed from old invalid value `32`
29
+ UNET_PARAMS_IN_CHANNELS = 4
30
+ UNET_PARAMS_OUT_CHANNELS = 4
31
+ UNET_PARAMS_NUM_RES_BLOCKS = 2
32
+ UNET_PARAMS_CONTEXT_DIM = 768
33
+ UNET_PARAMS_NUM_HEADS = 8
34
+ # UNET_PARAMS_USE_LINEAR_PROJECTION = False
35
+
36
+ VAE_PARAMS_Z_CHANNELS = 4
37
+ VAE_PARAMS_RESOLUTION = 256
38
+ VAE_PARAMS_IN_CHANNELS = 3
39
+ VAE_PARAMS_OUT_CH = 3
40
+ VAE_PARAMS_CH = 128
41
+ VAE_PARAMS_CH_MULT = [1, 2, 4, 4]
42
+ VAE_PARAMS_NUM_RES_BLOCKS = 2
43
+
44
+ # V2
45
+ V2_UNET_PARAMS_ATTENTION_HEAD_DIM = [5, 10, 20, 20]
46
+ V2_UNET_PARAMS_CONTEXT_DIM = 1024
47
+ # V2_UNET_PARAMS_USE_LINEAR_PROJECTION = True
48
+
49
+ # Diffusersの設定を読み込むための参照モデル
50
+ DIFFUSERS_REF_MODEL_ID_V1 = "runwayml/stable-diffusion-v1-5"
51
+ DIFFUSERS_REF_MODEL_ID_V2 = "stabilityai/stable-diffusion-2-1"
52
+
53
+
54
+ # region StableDiffusion->Diffusersの変換コード
55
+ # convert_original_stable_diffusion_to_diffusers をコピーして修正している(ASL 2.0)
56
+
57
+
58
+ def shave_segments(path, n_shave_prefix_segments=1):
59
+ """
60
+ Removes segments. Positive values shave the first segments, negative shave the last segments.
61
+ """
62
+ if n_shave_prefix_segments >= 0:
63
+ return ".".join(path.split(".")[n_shave_prefix_segments:])
64
+ else:
65
+ return ".".join(path.split(".")[:n_shave_prefix_segments])
66
+
67
+
68
+ def renew_resnet_paths(old_list, n_shave_prefix_segments=0):
69
+ """
70
+ Updates paths inside resnets to the new naming scheme (local renaming)
71
+ """
72
+ mapping = []
73
+ for old_item in old_list:
74
+ new_item = old_item.replace("in_layers.0", "norm1")
75
+ new_item = new_item.replace("in_layers.2", "conv1")
76
+
77
+ new_item = new_item.replace("out_layers.0", "norm2")
78
+ new_item = new_item.replace("out_layers.3", "conv2")
79
+
80
+ new_item = new_item.replace("emb_layers.1", "time_emb_proj")
81
+ new_item = new_item.replace("skip_connection", "conv_shortcut")
82
+
83
+ new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
84
+
85
+ mapping.append({"old": old_item, "new": new_item})
86
+
87
+ return mapping
88
+
89
+
90
+ def renew_vae_resnet_paths(old_list, n_shave_prefix_segments=0):
91
+ """
92
+ Updates paths inside resnets to the new naming scheme (local renaming)
93
+ """
94
+ mapping = []
95
+ for old_item in old_list:
96
+ new_item = old_item
97
+
98
+ new_item = new_item.replace("nin_shortcut", "conv_shortcut")
99
+ new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
100
+
101
+ mapping.append({"old": old_item, "new": new_item})
102
+
103
+ return mapping
104
+
105
+
106
+ def renew_attention_paths(old_list, n_shave_prefix_segments=0):
107
+ """
108
+ Updates paths inside attentions to the new naming scheme (local renaming)
109
+ """
110
+ mapping = []
111
+ for old_item in old_list:
112
+ new_item = old_item
113
+
114
+ # new_item = new_item.replace('norm.weight', 'group_norm.weight')
115
+ # new_item = new_item.replace('norm.bias', 'group_norm.bias')
116
+
117
+ # new_item = new_item.replace('proj_out.weight', 'proj_attn.weight')
118
+ # new_item = new_item.replace('proj_out.bias', 'proj_attn.bias')
119
+
120
+ # new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
121
+
122
+ mapping.append({"old": old_item, "new": new_item})
123
+
124
+ return mapping
125
+
126
+
127
+ def renew_vae_attention_paths(old_list, n_shave_prefix_segments=0):
128
+ """
129
+ Updates paths inside attentions to the new naming scheme (local renaming)
130
+ """
131
+ mapping = []
132
+ for old_item in old_list:
133
+ new_item = old_item
134
+
135
+ new_item = new_item.replace("norm.weight", "group_norm.weight")
136
+ new_item = new_item.replace("norm.bias", "group_norm.bias")
137
+
138
+ if diffusers.__version__ < "0.17.0":
139
+ new_item = new_item.replace("q.weight", "query.weight")
140
+ new_item = new_item.replace("q.bias", "query.bias")
141
+
142
+ new_item = new_item.replace("k.weight", "key.weight")
143
+ new_item = new_item.replace("k.bias", "key.bias")
144
+
145
+ new_item = new_item.replace("v.weight", "value.weight")
146
+ new_item = new_item.replace("v.bias", "value.bias")
147
+
148
+ new_item = new_item.replace("proj_out.weight", "proj_attn.weight")
149
+ new_item = new_item.replace("proj_out.bias", "proj_attn.bias")
150
+ else:
151
+ new_item = new_item.replace("q.weight", "to_q.weight")
152
+ new_item = new_item.replace("q.bias", "to_q.bias")
153
+
154
+ new_item = new_item.replace("k.weight", "to_k.weight")
155
+ new_item = new_item.replace("k.bias", "to_k.bias")
156
+
157
+ new_item = new_item.replace("v.weight", "to_v.weight")
158
+ new_item = new_item.replace("v.bias", "to_v.bias")
159
+
160
+ new_item = new_item.replace("proj_out.weight", "to_out.0.weight")
161
+ new_item = new_item.replace("proj_out.bias", "to_out.0.bias")
162
+
163
+ new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
164
+
165
+ mapping.append({"old": old_item, "new": new_item})
166
+
167
+ return mapping
168
+
169
+
170
+ def assign_to_checkpoint(
171
+ paths, checkpoint, old_checkpoint, attention_paths_to_split=None, additional_replacements=None, config=None
172
+ ):
173
+ """
174
+ This does the final conversion step: take locally converted weights and apply a global renaming
175
+ to them. It splits attention layers, and takes into account additional replacements
176
+ that may arise.
177
+
178
+ Assigns the weights to the new checkpoint.
179
+ """
180
+ assert isinstance(paths, list), "Paths should be a list of dicts containing 'old' and 'new' keys."
181
+
182
+ # Splits the attention layers into three variables.
183
+ if attention_paths_to_split is not None:
184
+ for path, path_map in attention_paths_to_split.items():
185
+ old_tensor = old_checkpoint[path]
186
+ channels = old_tensor.shape[0] // 3
187
+
188
+ target_shape = (-1, channels) if len(old_tensor.shape) == 3 else (-1)
189
+
190
+ num_heads = old_tensor.shape[0] // config["num_head_channels"] // 3
191
+
192
+ old_tensor = old_tensor.reshape((num_heads, 3 * channels // num_heads) + old_tensor.shape[1:])
193
+ query, key, value = old_tensor.split(channels // num_heads, dim=1)
194
+
195
+ checkpoint[path_map["query"]] = query.reshape(target_shape)
196
+ checkpoint[path_map["key"]] = key.reshape(target_shape)
197
+ checkpoint[path_map["value"]] = value.reshape(target_shape)
198
+
199
+ for path in paths:
200
+ new_path = path["new"]
201
+
202
+ # These have already been assigned
203
+ if attention_paths_to_split is not None and new_path in attention_paths_to_split:
204
+ continue
205
+
206
+ # Global renaming happens here
207
+ new_path = new_path.replace("middle_block.0", "mid_block.resnets.0")
208
+ new_path = new_path.replace("middle_block.1", "mid_block.attentions.0")
209
+ new_path = new_path.replace("middle_block.2", "mid_block.resnets.1")
210
+
211
+ if additional_replacements is not None:
212
+ for replacement in additional_replacements:
213
+ new_path = new_path.replace(replacement["old"], replacement["new"])
214
+
215
+ # proj_attn.weight has to be converted from conv 1D to linear
216
+ reshaping = False
217
+ if diffusers.__version__ < "0.17.0":
218
+ if "proj_attn.weight" in new_path:
219
+ reshaping = True
220
+ else:
221
+ if ".attentions." in new_path and ".0.to_" in new_path and old_checkpoint[path["old"]].ndim > 2:
222
+ reshaping = True
223
+
224
+ if reshaping:
225
+ checkpoint[new_path] = old_checkpoint[path["old"]][:, :, 0, 0]
226
+ else:
227
+ checkpoint[new_path] = old_checkpoint[path["old"]]
228
+
229
+
230
+ def conv_attn_to_linear(checkpoint):
231
+ keys = list(checkpoint.keys())
232
+ attn_keys = ["query.weight", "key.weight", "value.weight"]
233
+ for key in keys:
234
+ if ".".join(key.split(".")[-2:]) in attn_keys:
235
+ if checkpoint[key].ndim > 2:
236
+ checkpoint[key] = checkpoint[key][:, :, 0, 0]
237
+ elif "proj_attn.weight" in key:
238
+ if checkpoint[key].ndim > 2:
239
+ checkpoint[key] = checkpoint[key][:, :, 0]
240
+
241
+
242
+ def linear_transformer_to_conv(checkpoint):
243
+ keys = list(checkpoint.keys())
244
+ tf_keys = ["proj_in.weight", "proj_out.weight"]
245
+ for key in keys:
246
+ if ".".join(key.split(".")[-2:]) in tf_keys:
247
+ if checkpoint[key].ndim == 2:
248
+ checkpoint[key] = checkpoint[key].unsqueeze(2).unsqueeze(2)
249
+
250
+
251
+ def convert_ldm_unet_checkpoint(v2, checkpoint, config):
252
+ """
253
+ Takes a state dict and a config, and returns a converted checkpoint.
254
+ """
255
+
256
+ # extract state_dict for UNet
257
+ unet_state_dict = {}
258
+ unet_key = "model.diffusion_model."
259
+ keys = list(checkpoint.keys())
260
+ for key in keys:
261
+ if key.startswith(unet_key):
262
+ unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(key)
263
+
264
+ new_checkpoint = {}
265
+
266
+ new_checkpoint["time_embedding.linear_1.weight"] = unet_state_dict["time_embed.0.weight"]
267
+ new_checkpoint["time_embedding.linear_1.bias"] = unet_state_dict["time_embed.0.bias"]
268
+ new_checkpoint["time_embedding.linear_2.weight"] = unet_state_dict["time_embed.2.weight"]
269
+ new_checkpoint["time_embedding.linear_2.bias"] = unet_state_dict["time_embed.2.bias"]
270
+
271
+ new_checkpoint["conv_in.weight"] = unet_state_dict["input_blocks.0.0.weight"]
272
+ new_checkpoint["conv_in.bias"] = unet_state_dict["input_blocks.0.0.bias"]
273
+
274
+ new_checkpoint["conv_norm_out.weight"] = unet_state_dict["out.0.weight"]
275
+ new_checkpoint["conv_norm_out.bias"] = unet_state_dict["out.0.bias"]
276
+ new_checkpoint["conv_out.weight"] = unet_state_dict["out.2.weight"]
277
+ new_checkpoint["conv_out.bias"] = unet_state_dict["out.2.bias"]
278
+
279
+ # Retrieves the keys for the input blocks only
280
+ num_input_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "input_blocks" in layer})
281
+ input_blocks = {
282
+ layer_id: [key for key in unet_state_dict if f"input_blocks.{layer_id}." in key] for layer_id in range(num_input_blocks)
283
+ }
284
+
285
+ # Retrieves the keys for the middle blocks only
286
+ num_middle_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "middle_block" in layer})
287
+ middle_blocks = {
288
+ layer_id: [key for key in unet_state_dict if f"middle_block.{layer_id}." in key] for layer_id in range(num_middle_blocks)
289
+ }
290
+
291
+ # Retrieves the keys for the output blocks only
292
+ num_output_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "output_blocks" in layer})
293
+ output_blocks = {
294
+ layer_id: [key for key in unet_state_dict if f"output_blocks.{layer_id}." in key] for layer_id in range(num_output_blocks)
295
+ }
296
+
297
+ for i in range(1, num_input_blocks):
298
+ block_id = (i - 1) // (config["layers_per_block"] + 1)
299
+ layer_in_block_id = (i - 1) % (config["layers_per_block"] + 1)
300
+
301
+ resnets = [key for key in input_blocks[i] if f"input_blocks.{i}.0" in key and f"input_blocks.{i}.0.op" not in key]
302
+ attentions = [key for key in input_blocks[i] if f"input_blocks.{i}.1" in key]
303
+
304
+ if f"input_blocks.{i}.0.op.weight" in unet_state_dict:
305
+ new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.weight"] = unet_state_dict.pop(
306
+ f"input_blocks.{i}.0.op.weight"
307
+ )
308
+ new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.bias"] = unet_state_dict.pop(f"input_blocks.{i}.0.op.bias")
309
+
310
+ paths = renew_resnet_paths(resnets)
311
+ meta_path = {"old": f"input_blocks.{i}.0", "new": f"down_blocks.{block_id}.resnets.{layer_in_block_id}"}
312
+ assign_to_checkpoint(paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config)
313
+
314
+ if len(attentions):
315
+ paths = renew_attention_paths(attentions)
316
+ meta_path = {"old": f"input_blocks.{i}.1", "new": f"down_blocks.{block_id}.attentions.{layer_in_block_id}"}
317
+ assign_to_checkpoint(paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config)
318
+
319
+ resnet_0 = middle_blocks[0]
320
+ attentions = middle_blocks[1]
321
+ resnet_1 = middle_blocks[2]
322
+
323
+ resnet_0_paths = renew_resnet_paths(resnet_0)
324
+ assign_to_checkpoint(resnet_0_paths, new_checkpoint, unet_state_dict, config=config)
325
+
326
+ resnet_1_paths = renew_resnet_paths(resnet_1)
327
+ assign_to_checkpoint(resnet_1_paths, new_checkpoint, unet_state_dict, config=config)
328
+
329
+ attentions_paths = renew_attention_paths(attentions)
330
+ meta_path = {"old": "middle_block.1", "new": "mid_block.attentions.0"}
331
+ assign_to_checkpoint(attentions_paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config)
332
+
333
+ for i in range(num_output_blocks):
334
+ block_id = i // (config["layers_per_block"] + 1)
335
+ layer_in_block_id = i % (config["layers_per_block"] + 1)
336
+ output_block_layers = [shave_segments(name, 2) for name in output_blocks[i]]
337
+ output_block_list = {}
338
+
339
+ for layer in output_block_layers:
340
+ layer_id, layer_name = layer.split(".")[0], shave_segments(layer, 1)
341
+ if layer_id in output_block_list:
342
+ output_block_list[layer_id].append(layer_name)
343
+ else:
344
+ output_block_list[layer_id] = [layer_name]
345
+
346
+ if len(output_block_list) > 1:
347
+ resnets = [key for key in output_blocks[i] if f"output_blocks.{i}.0" in key]
348
+ attentions = [key for key in output_blocks[i] if f"output_blocks.{i}.1" in key]
349
+
350
+ resnet_0_paths = renew_resnet_paths(resnets)
351
+ paths = renew_resnet_paths(resnets)
352
+
353
+ meta_path = {"old": f"output_blocks.{i}.0", "new": f"up_blocks.{block_id}.resnets.{layer_in_block_id}"}
354
+ assign_to_checkpoint(paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config)
355
+
356
+ # オリジナル:
357
+ # if ["conv.weight", "conv.bias"] in output_block_list.values():
358
+ # index = list(output_block_list.values()).index(["conv.weight", "conv.bias"])
359
+
360
+ # biasとweightの順番に依存しないようにする:もっといいやり方がありそうだが
361
+ for l in output_block_list.values():
362
+ l.sort()
363
+
364
+ if ["conv.bias", "conv.weight"] in output_block_list.values():
365
+ index = list(output_block_list.values()).index(["conv.bias", "conv.weight"])
366
+ new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.bias"] = unet_state_dict[
367
+ f"output_blocks.{i}.{index}.conv.bias"
368
+ ]
369
+ new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.weight"] = unet_state_dict[
370
+ f"output_blocks.{i}.{index}.conv.weight"
371
+ ]
372
+
373
+ # Clear attentions as they have been attributed above.
374
+ if len(attentions) == 2:
375
+ attentions = []
376
+
377
+ if len(attentions):
378
+ paths = renew_attention_paths(attentions)
379
+ meta_path = {
380
+ "old": f"output_blocks.{i}.1",
381
+ "new": f"up_blocks.{block_id}.attentions.{layer_in_block_id}",
382
+ }
383
+ assign_to_checkpoint(paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config)
384
+ else:
385
+ resnet_0_paths = renew_resnet_paths(output_block_layers, n_shave_prefix_segments=1)
386
+ for path in resnet_0_paths:
387
+ old_path = ".".join(["output_blocks", str(i), path["old"]])
388
+ new_path = ".".join(["up_blocks", str(block_id), "resnets", str(layer_in_block_id), path["new"]])
389
+
390
+ new_checkpoint[new_path] = unet_state_dict[old_path]
391
+
392
+ # SDのv2では1*1のconv2dがlinearに変わっている
393
+ # 誤って Diffusers 側を conv2d のままにしてしまったので、変換必要
394
+ if v2 and not config.get("use_linear_projection", False):
395
+ linear_transformer_to_conv(new_checkpoint)
396
+
397
+ return new_checkpoint
398
+
399
+
400
+ def convert_ldm_vae_checkpoint(checkpoint, config):
401
+ # extract state dict for VAE
402
+ vae_state_dict = {}
403
+ vae_key = "first_stage_model."
404
+ keys = list(checkpoint.keys())
405
+ for key in keys:
406
+ if key.startswith(vae_key):
407
+ vae_state_dict[key.replace(vae_key, "")] = checkpoint.get(key)
408
+ # if len(vae_state_dict) == 0:
409
+ # # 渡されたcheckpointは.ckptから読み込んだcheckpointではなくvaeのstate_dict
410
+ # vae_state_dict = checkpoint
411
+
412
+ new_checkpoint = {}
413
+
414
+ new_checkpoint["encoder.conv_in.weight"] = vae_state_dict["encoder.conv_in.weight"]
415
+ new_checkpoint["encoder.conv_in.bias"] = vae_state_dict["encoder.conv_in.bias"]
416
+ new_checkpoint["encoder.conv_out.weight"] = vae_state_dict["encoder.conv_out.weight"]
417
+ new_checkpoint["encoder.conv_out.bias"] = vae_state_dict["encoder.conv_out.bias"]
418
+ new_checkpoint["encoder.conv_norm_out.weight"] = vae_state_dict["encoder.norm_out.weight"]
419
+ new_checkpoint["encoder.conv_norm_out.bias"] = vae_state_dict["encoder.norm_out.bias"]
420
+
421
+ new_checkpoint["decoder.conv_in.weight"] = vae_state_dict["decoder.conv_in.weight"]
422
+ new_checkpoint["decoder.conv_in.bias"] = vae_state_dict["decoder.conv_in.bias"]
423
+ new_checkpoint["decoder.conv_out.weight"] = vae_state_dict["decoder.conv_out.weight"]
424
+ new_checkpoint["decoder.conv_out.bias"] = vae_state_dict["decoder.conv_out.bias"]
425
+ new_checkpoint["decoder.conv_norm_out.weight"] = vae_state_dict["decoder.norm_out.weight"]
426
+ new_checkpoint["decoder.conv_norm_out.bias"] = vae_state_dict["decoder.norm_out.bias"]
427
+
428
+ new_checkpoint["quant_conv.weight"] = vae_state_dict["quant_conv.weight"]
429
+ new_checkpoint["quant_conv.bias"] = vae_state_dict["quant_conv.bias"]
430
+ new_checkpoint["post_quant_conv.weight"] = vae_state_dict["post_quant_conv.weight"]
431
+ new_checkpoint["post_quant_conv.bias"] = vae_state_dict["post_quant_conv.bias"]
432
+
433
+ # Retrieves the keys for the encoder down blocks only
434
+ num_down_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "encoder.down" in layer})
435
+ down_blocks = {layer_id: [key for key in vae_state_dict if f"down.{layer_id}" in key] for layer_id in range(num_down_blocks)}
436
+
437
+ # Retrieves the keys for the decoder up blocks only
438
+ num_up_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "decoder.up" in layer})
439
+ up_blocks = {layer_id: [key for key in vae_state_dict if f"up.{layer_id}" in key] for layer_id in range(num_up_blocks)}
440
+
441
+ for i in range(num_down_blocks):
442
+ resnets = [key for key in down_blocks[i] if f"down.{i}" in key and f"down.{i}.downsample" not in key]
443
+
444
+ if f"encoder.down.{i}.downsample.conv.weight" in vae_state_dict:
445
+ new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.weight"] = vae_state_dict.pop(
446
+ f"encoder.down.{i}.downsample.conv.weight"
447
+ )
448
+ new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.bias"] = vae_state_dict.pop(
449
+ f"encoder.down.{i}.downsample.conv.bias"
450
+ )
451
+
452
+ paths = renew_vae_resnet_paths(resnets)
453
+ meta_path = {"old": f"down.{i}.block", "new": f"down_blocks.{i}.resnets"}
454
+ assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
455
+
456
+ mid_resnets = [key for key in vae_state_dict if "encoder.mid.block" in key]
457
+ num_mid_res_blocks = 2
458
+ for i in range(1, num_mid_res_blocks + 1):
459
+ resnets = [key for key in mid_resnets if f"encoder.mid.block_{i}" in key]
460
+
461
+ paths = renew_vae_resnet_paths(resnets)
462
+ meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"}
463
+ assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
464
+
465
+ mid_attentions = [key for key in vae_state_dict if "encoder.mid.attn" in key]
466
+ paths = renew_vae_attention_paths(mid_attentions)
467
+ meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"}
468
+ assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
469
+ conv_attn_to_linear(new_checkpoint)
470
+
471
+ for i in range(num_up_blocks):
472
+ block_id = num_up_blocks - 1 - i
473
+ resnets = [key for key in up_blocks[block_id] if f"up.{block_id}" in key and f"up.{block_id}.upsample" not in key]
474
+
475
+ if f"decoder.up.{block_id}.upsample.conv.weight" in vae_state_dict:
476
+ new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.weight"] = vae_state_dict[
477
+ f"decoder.up.{block_id}.upsample.conv.weight"
478
+ ]
479
+ new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.bias"] = vae_state_dict[
480
+ f"decoder.up.{block_id}.upsample.conv.bias"
481
+ ]
482
+
483
+ paths = renew_vae_resnet_paths(resnets)
484
+ meta_path = {"old": f"up.{block_id}.block", "new": f"up_blocks.{i}.resnets"}
485
+ assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
486
+
487
+ mid_resnets = [key for key in vae_state_dict if "decoder.mid.block" in key]
488
+ num_mid_res_blocks = 2
489
+ for i in range(1, num_mid_res_blocks + 1):
490
+ resnets = [key for key in mid_resnets if f"decoder.mid.block_{i}" in key]
491
+
492
+ paths = renew_vae_resnet_paths(resnets)
493
+ meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"}
494
+ assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
495
+
496
+ mid_attentions = [key for key in vae_state_dict if "decoder.mid.attn" in key]
497
+ paths = renew_vae_attention_paths(mid_attentions)
498
+ meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"}
499
+ assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
500
+ conv_attn_to_linear(new_checkpoint)
501
+ return new_checkpoint
502
+
503
+
504
+ def create_unet_diffusers_config(v2, use_linear_projection_in_v2=False):
505
+ """
506
+ Creates a config for the diffusers based on the config of the LDM model.
507
+ """
508
+ # unet_params = original_config.model.params.unet_config.params
509
+
510
+ block_out_channels = [UNET_PARAMS_MODEL_CHANNELS * mult for mult in UNET_PARAMS_CHANNEL_MULT]
511
+
512
+ down_block_types = []
513
+ resolution = 1
514
+ for i in range(len(block_out_channels)):
515
+ block_type = "CrossAttnDownBlock2D" if resolution in UNET_PARAMS_ATTENTION_RESOLUTIONS else "DownBlock2D"
516
+ down_block_types.append(block_type)
517
+ if i != len(block_out_channels) - 1:
518
+ resolution *= 2
519
+
520
+ up_block_types = []
521
+ for i in range(len(block_out_channels)):
522
+ block_type = "CrossAttnUpBlock2D" if resolution in UNET_PARAMS_ATTENTION_RESOLUTIONS else "UpBlock2D"
523
+ up_block_types.append(block_type)
524
+ resolution //= 2
525
+
526
+ config = dict(
527
+ sample_size=UNET_PARAMS_IMAGE_SIZE,
528
+ in_channels=UNET_PARAMS_IN_CHANNELS,
529
+ out_channels=UNET_PARAMS_OUT_CHANNELS,
530
+ down_block_types=tuple(down_block_types),
531
+ up_block_types=tuple(up_block_types),
532
+ block_out_channels=tuple(block_out_channels),
533
+ layers_per_block=UNET_PARAMS_NUM_RES_BLOCKS,
534
+ cross_attention_dim=UNET_PARAMS_CONTEXT_DIM if not v2 else V2_UNET_PARAMS_CONTEXT_DIM,
535
+ attention_head_dim=UNET_PARAMS_NUM_HEADS if not v2 else V2_UNET_PARAMS_ATTENTION_HEAD_DIM,
536
+ # use_linear_projection=UNET_PARAMS_USE_LINEAR_PROJECTION if not v2 else V2_UNET_PARAMS_USE_LINEAR_PROJECTION,
537
+ )
538
+ if v2 and use_linear_projection_in_v2:
539
+ config["use_linear_projection"] = True
540
+
541
+ return config
542
+
543
+
544
+ def create_vae_diffusers_config():
545
+ """
546
+ Creates a config for the diffusers based on the config of the LDM model.
547
+ """
548
+ # vae_params = original_config.model.params.first_stage_config.params.ddconfig
549
+ # _ = original_config.model.params.first_stage_config.params.embed_dim
550
+ block_out_channels = [VAE_PARAMS_CH * mult for mult in VAE_PARAMS_CH_MULT]
551
+ down_block_types = ["DownEncoderBlock2D"] * len(block_out_channels)
552
+ up_block_types = ["UpDecoderBlock2D"] * len(block_out_channels)
553
+
554
+ config = dict(
555
+ sample_size=VAE_PARAMS_RESOLUTION,
556
+ in_channels=VAE_PARAMS_IN_CHANNELS,
557
+ out_channels=VAE_PARAMS_OUT_CH,
558
+ down_block_types=tuple(down_block_types),
559
+ up_block_types=tuple(up_block_types),
560
+ block_out_channels=tuple(block_out_channels),
561
+ latent_channels=VAE_PARAMS_Z_CHANNELS,
562
+ layers_per_block=VAE_PARAMS_NUM_RES_BLOCKS,
563
+ )
564
+ return config
565
+
566
+
567
+ def convert_ldm_clip_checkpoint_v1(checkpoint):
568
+ keys = list(checkpoint.keys())
569
+ text_model_dict = {}
570
+ for key in keys:
571
+ if key.startswith("cond_stage_model.transformer"):
572
+ text_model_dict[key[len("cond_stage_model.transformer.") :]] = checkpoint[key]
573
+
574
+ # support checkpoint without position_ids (invalid checkpoint)
575
+ if "text_model.embeddings.position_ids" not in text_model_dict:
576
+ text_model_dict["text_model.embeddings.position_ids"] = torch.arange(77).unsqueeze(0) # 77 is the max length of the text
577
+
578
+ return text_model_dict
579
+
580
+
581
+ def convert_ldm_clip_checkpoint_v2(checkpoint, max_length):
582
+ # 嫌になるくらい違うぞ!
583
+ def convert_key(key):
584
+ if not key.startswith("cond_stage_model"):
585
+ return None
586
+
587
+ # common conversion
588
+ key = key.replace("cond_stage_model.model.transformer.", "text_model.encoder.")
589
+ key = key.replace("cond_stage_model.model.", "text_model.")
590
+
591
+ if "resblocks" in key:
592
+ # resblocks conversion
593
+ key = key.replace(".resblocks.", ".layers.")
594
+ if ".ln_" in key:
595
+ key = key.replace(".ln_", ".layer_norm")
596
+ elif ".mlp." in key:
597
+ key = key.replace(".c_fc.", ".fc1.")
598
+ key = key.replace(".c_proj.", ".fc2.")
599
+ elif ".attn.out_proj" in key:
600
+ key = key.replace(".attn.out_proj.", ".self_attn.out_proj.")
601
+ elif ".attn.in_proj" in key:
602
+ key = None # 特殊なので後で処理する
603
+ else:
604
+ raise ValueError(f"unexpected key in SD: {key}")
605
+ elif ".positional_embedding" in key:
606
+ key = key.replace(".positional_embedding", ".embeddings.position_embedding.weight")
607
+ elif ".text_projection" in key:
608
+ key = None # 使われない???
609
+ elif ".logit_scale" in key:
610
+ key = None # 使われない???
611
+ elif ".token_embedding" in key:
612
+ key = key.replace(".token_embedding.weight", ".embeddings.token_embedding.weight")
613
+ elif ".ln_final" in key:
614
+ key = key.replace(".ln_final", ".final_layer_norm")
615
+ return key
616
+
617
+ keys = list(checkpoint.keys())
618
+ new_sd = {}
619
+ for key in keys:
620
+ # remove resblocks 23
621
+ if ".resblocks.23." in key:
622
+ continue
623
+ new_key = convert_key(key)
624
+ if new_key is None:
625
+ continue
626
+ new_sd[new_key] = checkpoint[key]
627
+
628
+ # attnの変換
629
+ for key in keys:
630
+ if ".resblocks.23." in key:
631
+ continue
632
+ if ".resblocks" in key and ".attn.in_proj_" in key:
633
+ # 三つに分割
634
+ values = torch.chunk(checkpoint[key], 3)
635
+
636
+ key_suffix = ".weight" if "weight" in key else ".bias"
637
+ key_pfx = key.replace("cond_stage_model.model.transformer.resblocks.", "text_model.encoder.layers.")
638
+ key_pfx = key_pfx.replace("_weight", "")
639
+ key_pfx = key_pfx.replace("_bias", "")
640
+ key_pfx = key_pfx.replace(".attn.in_proj", ".self_attn.")
641
+ new_sd[key_pfx + "q_proj" + key_suffix] = values[0]
642
+ new_sd[key_pfx + "k_proj" + key_suffix] = values[1]
643
+ new_sd[key_pfx + "v_proj" + key_suffix] = values[2]
644
+
645
+ # rename or add position_ids
646
+ ANOTHER_POSITION_IDS_KEY = "text_model.encoder.text_model.embeddings.position_ids"
647
+ if ANOTHER_POSITION_IDS_KEY in new_sd:
648
+ # waifu diffusion v1.4
649
+ position_ids = new_sd[ANOTHER_POSITION_IDS_KEY]
650
+ del new_sd[ANOTHER_POSITION_IDS_KEY]
651
+ else:
652
+ position_ids = torch.Tensor([list(range(max_length))]).to(torch.int64)
653
+
654
+ new_sd["text_model.embeddings.position_ids"] = position_ids
655
+ return new_sd
656
+
657
+
658
+ # endregion
659
+
660
+
661
+ # region Diffusers->StableDiffusion の変換コード
662
+ # convert_diffusers_to_original_stable_diffusion をコピーして修正している(ASL 2.0)
663
+
664
+
665
+ def conv_transformer_to_linear(checkpoint):
666
+ keys = list(checkpoint.keys())
667
+ tf_keys = ["proj_in.weight", "proj_out.weight"]
668
+ for key in keys:
669
+ if ".".join(key.split(".")[-2:]) in tf_keys:
670
+ if checkpoint[key].ndim > 2:
671
+ checkpoint[key] = checkpoint[key][:, :, 0, 0]
672
+
673
+
674
+ def convert_unet_state_dict_to_sd(v2, unet_state_dict):
675
+ unet_conversion_map = [
676
+ # (stable-diffusion, HF Diffusers)
677
+ ("time_embed.0.weight", "time_embedding.linear_1.weight"),
678
+ ("time_embed.0.bias", "time_embedding.linear_1.bias"),
679
+ ("time_embed.2.weight", "time_embedding.linear_2.weight"),
680
+ ("time_embed.2.bias", "time_embedding.linear_2.bias"),
681
+ ("input_blocks.0.0.weight", "conv_in.weight"),
682
+ ("input_blocks.0.0.bias", "conv_in.bias"),
683
+ ("out.0.weight", "conv_norm_out.weight"),
684
+ ("out.0.bias", "conv_norm_out.bias"),
685
+ ("out.2.weight", "conv_out.weight"),
686
+ ("out.2.bias", "conv_out.bias"),
687
+ ]
688
+
689
+ unet_conversion_map_resnet = [
690
+ # (stable-diffusion, HF Diffusers)
691
+ ("in_layers.0", "norm1"),
692
+ ("in_layers.2", "conv1"),
693
+ ("out_layers.0", "norm2"),
694
+ ("out_layers.3", "conv2"),
695
+ ("emb_layers.1", "time_emb_proj"),
696
+ ("skip_connection", "conv_shortcut"),
697
+ ]
698
+
699
+ unet_conversion_map_layer = []
700
+ for i in range(4):
701
+ # loop over downblocks/upblocks
702
+
703
+ for j in range(2):
704
+ # loop over resnets/attentions for downblocks
705
+ hf_down_res_prefix = f"down_blocks.{i}.resnets.{j}."
706
+ sd_down_res_prefix = f"input_blocks.{3*i + j + 1}.0."
707
+ unet_conversion_map_layer.append((sd_down_res_prefix, hf_down_res_prefix))
708
+
709
+ if i < 3:
710
+ # no attention layers in down_blocks.3
711
+ hf_down_atn_prefix = f"down_blocks.{i}.attentions.{j}."
712
+ sd_down_atn_prefix = f"input_blocks.{3*i + j + 1}.1."
713
+ unet_conversion_map_layer.append((sd_down_atn_prefix, hf_down_atn_prefix))
714
+
715
+ for j in range(3):
716
+ # loop over resnets/attentions for upblocks
717
+ hf_up_res_prefix = f"up_blocks.{i}.resnets.{j}."
718
+ sd_up_res_prefix = f"output_blocks.{3*i + j}.0."
719
+ unet_conversion_map_layer.append((sd_up_res_prefix, hf_up_res_prefix))
720
+
721
+ if i > 0:
722
+ # no attention layers in up_blocks.0
723
+ hf_up_atn_prefix = f"up_blocks.{i}.attentions.{j}."
724
+ sd_up_atn_prefix = f"output_blocks.{3*i + j}.1."
725
+ unet_conversion_map_layer.append((sd_up_atn_prefix, hf_up_atn_prefix))
726
+
727
+ if i < 3:
728
+ # no downsample in down_blocks.3
729
+ hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0.conv."
730
+ sd_downsample_prefix = f"input_blocks.{3*(i+1)}.0.op."
731
+ unet_conversion_map_layer.append((sd_downsample_prefix, hf_downsample_prefix))
732
+
733
+ # no upsample in up_blocks.3
734
+ hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
735
+ sd_upsample_prefix = f"output_blocks.{3*i + 2}.{1 if i == 0 else 2}."
736
+ unet_conversion_map_layer.append((sd_upsample_prefix, hf_upsample_prefix))
737
+
738
+ hf_mid_atn_prefix = "mid_block.attentions.0."
739
+ sd_mid_atn_prefix = "middle_block.1."
740
+ unet_conversion_map_layer.append((sd_mid_atn_prefix, hf_mid_atn_prefix))
741
+
742
+ for j in range(2):
743
+ hf_mid_res_prefix = f"mid_block.resnets.{j}."
744
+ sd_mid_res_prefix = f"middle_block.{2*j}."
745
+ unet_conversion_map_layer.append((sd_mid_res_prefix, hf_mid_res_prefix))
746
+
747
+ # buyer beware: this is a *brittle* function,
748
+ # and correct output requires that all of these pieces interact in
749
+ # the exact order in which I have arranged them.
750
+ mapping = {k: k for k in unet_state_dict.keys()}
751
+ for sd_name, hf_name in unet_conversion_map:
752
+ mapping[hf_name] = sd_name
753
+ for k, v in mapping.items():
754
+ if "resnets" in k:
755
+ for sd_part, hf_part in unet_conversion_map_resnet:
756
+ v = v.replace(hf_part, sd_part)
757
+ mapping[k] = v
758
+ for k, v in mapping.items():
759
+ for sd_part, hf_part in unet_conversion_map_layer:
760
+ v = v.replace(hf_part, sd_part)
761
+ mapping[k] = v
762
+ new_state_dict = {v: unet_state_dict[k] for k, v in mapping.items()}
763
+
764
+ if v2:
765
+ conv_transformer_to_linear(new_state_dict)
766
+
767
+ return new_state_dict
768
+
769
+
770
+ def controlnet_conversion_map():
771
+ unet_conversion_map = [
772
+ ("time_embed.0.weight", "time_embedding.linear_1.weight"),
773
+ ("time_embed.0.bias", "time_embedding.linear_1.bias"),
774
+ ("time_embed.2.weight", "time_embedding.linear_2.weight"),
775
+ ("time_embed.2.bias", "time_embedding.linear_2.bias"),
776
+ ("input_blocks.0.0.weight", "conv_in.weight"),
777
+ ("input_blocks.0.0.bias", "conv_in.bias"),
778
+ ("middle_block_out.0.weight", "controlnet_mid_block.weight"),
779
+ ("middle_block_out.0.bias", "controlnet_mid_block.bias"),
780
+ ]
781
+
782
+ unet_conversion_map_resnet = [
783
+ ("in_layers.0", "norm1"),
784
+ ("in_layers.2", "conv1"),
785
+ ("out_layers.0", "norm2"),
786
+ ("out_layers.3", "conv2"),
787
+ ("emb_layers.1", "time_emb_proj"),
788
+ ("skip_connection", "conv_shortcut"),
789
+ ]
790
+
791
+ unet_conversion_map_layer = []
792
+ for i in range(4):
793
+ for j in range(2):
794
+ hf_down_res_prefix = f"down_blocks.{i}.resnets.{j}."
795
+ sd_down_res_prefix = f"input_blocks.{3*i + j + 1}.0."
796
+ unet_conversion_map_layer.append((sd_down_res_prefix, hf_down_res_prefix))
797
+
798
+ if i < 3:
799
+ hf_down_atn_prefix = f"down_blocks.{i}.attentions.{j}."
800
+ sd_down_atn_prefix = f"input_blocks.{3*i + j + 1}.1."
801
+ unet_conversion_map_layer.append((sd_down_atn_prefix, hf_down_atn_prefix))
802
+
803
+ if i < 3:
804
+ hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0.conv."
805
+ sd_downsample_prefix = f"input_blocks.{3*(i+1)}.0.op."
806
+ unet_conversion_map_layer.append((sd_downsample_prefix, hf_downsample_prefix))
807
+
808
+ hf_mid_atn_prefix = "mid_block.attentions.0."
809
+ sd_mid_atn_prefix = "middle_block.1."
810
+ unet_conversion_map_layer.append((sd_mid_atn_prefix, hf_mid_atn_prefix))
811
+
812
+ for j in range(2):
813
+ hf_mid_res_prefix = f"mid_block.resnets.{j}."
814
+ sd_mid_res_prefix = f"middle_block.{2*j}."
815
+ unet_conversion_map_layer.append((sd_mid_res_prefix, hf_mid_res_prefix))
816
+
817
+ controlnet_cond_embedding_names = ["conv_in"] + [f"blocks.{i}" for i in range(6)] + ["conv_out"]
818
+ for i, hf_prefix in enumerate(controlnet_cond_embedding_names):
819
+ hf_prefix = f"controlnet_cond_embedding.{hf_prefix}."
820
+ sd_prefix = f"input_hint_block.{i*2}."
821
+ unet_conversion_map_layer.append((sd_prefix, hf_prefix))
822
+
823
+ for i in range(12):
824
+ hf_prefix = f"controlnet_down_blocks.{i}."
825
+ sd_prefix = f"zero_convs.{i}.0."
826
+ unet_conversion_map_layer.append((sd_prefix, hf_prefix))
827
+
828
+ return unet_conversion_map, unet_conversion_map_resnet, unet_conversion_map_layer
829
+
830
+
831
+ def convert_controlnet_state_dict_to_sd(controlnet_state_dict):
832
+ unet_conversion_map, unet_conversion_map_resnet, unet_conversion_map_layer = controlnet_conversion_map()
833
+
834
+ mapping = {k: k for k in controlnet_state_dict.keys()}
835
+ for sd_name, diffusers_name in unet_conversion_map:
836
+ mapping[diffusers_name] = sd_name
837
+ for k, v in mapping.items():
838
+ if "resnets" in k:
839
+ for sd_part, diffusers_part in unet_conversion_map_resnet:
840
+ v = v.replace(diffusers_part, sd_part)
841
+ mapping[k] = v
842
+ for k, v in mapping.items():
843
+ for sd_part, diffusers_part in unet_conversion_map_layer:
844
+ v = v.replace(diffusers_part, sd_part)
845
+ mapping[k] = v
846
+ new_state_dict = {v: controlnet_state_dict[k] for k, v in mapping.items()}
847
+ return new_state_dict
848
+
849
+
850
+ def convert_controlnet_state_dict_to_diffusers(controlnet_state_dict):
851
+ unet_conversion_map, unet_conversion_map_resnet, unet_conversion_map_layer = controlnet_conversion_map()
852
+
853
+ mapping = {k: k for k in controlnet_state_dict.keys()}
854
+ for sd_name, diffusers_name in unet_conversion_map:
855
+ mapping[sd_name] = diffusers_name
856
+ for k, v in mapping.items():
857
+ for sd_part, diffusers_part in unet_conversion_map_layer:
858
+ v = v.replace(sd_part, diffusers_part)
859
+ mapping[k] = v
860
+ for k, v in mapping.items():
861
+ if "resnets" in v:
862
+ for sd_part, diffusers_part in unet_conversion_map_resnet:
863
+ v = v.replace(sd_part, diffusers_part)
864
+ mapping[k] = v
865
+ new_state_dict = {v: controlnet_state_dict[k] for k, v in mapping.items()}
866
+ return new_state_dict
867
+
868
+
869
+ # ================#
870
+ # VAE Conversion #
871
+ # ================#
872
+
873
+
874
+ def reshape_weight_for_sd(w):
875
+ # convert HF linear weights to SD conv2d weights
876
+ return w.reshape(*w.shape, 1, 1)
877
+
878
+
879
+ def convert_vae_state_dict(vae_state_dict):
880
+ vae_conversion_map = [
881
+ # (stable-diffusion, HF Diffusers)
882
+ ("nin_shortcut", "conv_shortcut"),
883
+ ("norm_out", "conv_norm_out"),
884
+ ("mid.attn_1.", "mid_block.attentions.0."),
885
+ ]
886
+
887
+ for i in range(4):
888
+ # down_blocks have two resnets
889
+ for j in range(2):
890
+ hf_down_prefix = f"encoder.down_blocks.{i}.resnets.{j}."
891
+ sd_down_prefix = f"encoder.down.{i}.block.{j}."
892
+ vae_conversion_map.append((sd_down_prefix, hf_down_prefix))
893
+
894
+ if i < 3:
895
+ hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0."
896
+ sd_downsample_prefix = f"down.{i}.downsample."
897
+ vae_conversion_map.append((sd_downsample_prefix, hf_downsample_prefix))
898
+
899
+ hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
900
+ sd_upsample_prefix = f"up.{3-i}.upsample."
901
+ vae_conversion_map.append((sd_upsample_prefix, hf_upsample_prefix))
902
+
903
+ # up_blocks have three resnets
904
+ # also, up blocks in hf are numbered in reverse from sd
905
+ for j in range(3):
906
+ hf_up_prefix = f"decoder.up_blocks.{i}.resnets.{j}."
907
+ sd_up_prefix = f"decoder.up.{3-i}.block.{j}."
908
+ vae_conversion_map.append((sd_up_prefix, hf_up_prefix))
909
+
910
+ # this part accounts for mid blocks in both the encoder and the decoder
911
+ for i in range(2):
912
+ hf_mid_res_prefix = f"mid_block.resnets.{i}."
913
+ sd_mid_res_prefix = f"mid.block_{i+1}."
914
+ vae_conversion_map.append((sd_mid_res_prefix, hf_mid_res_prefix))
915
+
916
+ if diffusers.__version__ < "0.17.0":
917
+ vae_conversion_map_attn = [
918
+ # (stable-diffusion, HF Diffusers)
919
+ ("norm.", "group_norm."),
920
+ ("q.", "query."),
921
+ ("k.", "key."),
922
+ ("v.", "value."),
923
+ ("proj_out.", "proj_attn."),
924
+ ]
925
+ else:
926
+ vae_conversion_map_attn = [
927
+ # (stable-diffusion, HF Diffusers)
928
+ ("norm.", "group_norm."),
929
+ ("q.", "to_q."),
930
+ ("k.", "to_k."),
931
+ ("v.", "to_v."),
932
+ ("proj_out.", "to_out.0."),
933
+ ]
934
+
935
+ mapping = {k: k for k in vae_state_dict.keys()}
936
+ for k, v in mapping.items():
937
+ for sd_part, hf_part in vae_conversion_map:
938
+ v = v.replace(hf_part, sd_part)
939
+ mapping[k] = v
940
+ for k, v in mapping.items():
941
+ if "attentions" in k:
942
+ for sd_part, hf_part in vae_conversion_map_attn:
943
+ v = v.replace(hf_part, sd_part)
944
+ mapping[k] = v
945
+ new_state_dict = {v: vae_state_dict[k] for k, v in mapping.items()}
946
+ weights_to_convert = ["q", "k", "v", "proj_out"]
947
+ for k, v in new_state_dict.items():
948
+ for weight_name in weights_to_convert:
949
+ if f"mid.attn_1.{weight_name}.weight" in k:
950
+ # print(f"Reshaping {k} for SD format: shape {v.shape} -> {v.shape} x 1 x 1")
951
+ new_state_dict[k] = reshape_weight_for_sd(v)
952
+
953
+ return new_state_dict
954
+
955
+
956
+ # endregion
957
+
958
+ # region 自作のモデル読み書きなど
959
+
960
+
961
+ def is_safetensors(path):
962
+ return os.path.splitext(path)[1].lower() == ".safetensors"
963
+
964
+
965
+ def load_checkpoint_with_text_encoder_conversion(ckpt_path, device="cpu"):
966
+ # text encoderの格納形式が違うモデルに対応する ('text_model'がない)
967
+ TEXT_ENCODER_KEY_REPLACEMENTS = [
968
+ ("cond_stage_model.transformer.embeddings.", "cond_stage_model.transformer.text_model.embeddings."),
969
+ ("cond_stage_model.transformer.encoder.", "cond_stage_model.transformer.text_model.encoder."),
970
+ ("cond_stage_model.transformer.final_layer_norm.", "cond_stage_model.transformer.text_model.final_layer_norm."),
971
+ ]
972
+
973
+ if is_safetensors(ckpt_path):
974
+ checkpoint = None
975
+ state_dict = load_file(ckpt_path) # , device) # may causes error
976
+ else:
977
+ checkpoint = torch.load(ckpt_path, map_location=device)
978
+ if "state_dict" in checkpoint:
979
+ state_dict = checkpoint["state_dict"]
980
+ else:
981
+ state_dict = checkpoint
982
+ checkpoint = None
983
+
984
+ key_reps = []
985
+ for rep_from, rep_to in TEXT_ENCODER_KEY_REPLACEMENTS:
986
+ for key in state_dict.keys():
987
+ if key.startswith(rep_from):
988
+ new_key = rep_to + key[len(rep_from) :]
989
+ key_reps.append((key, new_key))
990
+
991
+ for key, new_key in key_reps:
992
+ state_dict[new_key] = state_dict[key]
993
+ del state_dict[key]
994
+
995
+ return checkpoint, state_dict
996
+
997
+
998
+ # TODO dtype指定の動作が怪しいので確認する text_encoderを指定形式で作れるか未確認
999
+ def load_models_from_stable_diffusion_checkpoint(v2, ckpt_path, device="cpu", dtype=None, unet_use_linear_projection_in_v2=True):
1000
+ _, state_dict = load_checkpoint_with_text_encoder_conversion(ckpt_path, device)
1001
+
1002
+ # Convert the UNet2DConditionModel model.
1003
+ unet_config = create_unet_diffusers_config(v2, unet_use_linear_projection_in_v2)
1004
+ converted_unet_checkpoint = convert_ldm_unet_checkpoint(v2, state_dict, unet_config)
1005
+
1006
+ unet = UNet2DConditionModel(**unet_config).to(device)
1007
+ info = unet.load_state_dict(converted_unet_checkpoint)
1008
+ print("loading u-net:", info)
1009
+
1010
+ # Convert the VAE model.
1011
+ vae_config = create_vae_diffusers_config()
1012
+ converted_vae_checkpoint = convert_ldm_vae_checkpoint(state_dict, vae_config)
1013
+
1014
+ vae = AutoencoderKL(**vae_config).to(device)
1015
+ info = vae.load_state_dict(converted_vae_checkpoint)
1016
+ print("loading vae:", info)
1017
+
1018
+ # convert text_model
1019
+ if v2:
1020
+ converted_text_encoder_checkpoint = convert_ldm_clip_checkpoint_v2(state_dict, 77)
1021
+ cfg = CLIPTextConfig(
1022
+ vocab_size=49408,
1023
+ hidden_size=1024,
1024
+ intermediate_size=4096,
1025
+ num_hidden_layers=23,
1026
+ num_attention_heads=16,
1027
+ max_position_embeddings=77,
1028
+ hidden_act="gelu",
1029
+ layer_norm_eps=1e-05,
1030
+ dropout=0.0,
1031
+ attention_dropout=0.0,
1032
+ initializer_range=0.02,
1033
+ initializer_factor=1.0,
1034
+ pad_token_id=1,
1035
+ bos_token_id=0,
1036
+ eos_token_id=2,
1037
+ model_type="clip_text_model",
1038
+ projection_dim=512,
1039
+ torch_dtype="float32",
1040
+ transformers_version="4.25.0.dev0",
1041
+ )
1042
+ text_model = CLIPTextModel._from_config(cfg)
1043
+ info = text_model.load_state_dict(converted_text_encoder_checkpoint)
1044
+ else:
1045
+ converted_text_encoder_checkpoint = convert_ldm_clip_checkpoint_v1(state_dict)
1046
+
1047
+ # logging.set_verbosity_error() # don't show annoying warning
1048
+ # text_model = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14").to(device)
1049
+ # logging.set_verbosity_warning()
1050
+ # print(f"config: {text_model.config}")
1051
+ cfg = CLIPTextConfig(
1052
+ vocab_size=49408,
1053
+ hidden_size=768,
1054
+ intermediate_size=3072,
1055
+ num_hidden_layers=12,
1056
+ num_attention_heads=12,
1057
+ max_position_embeddings=77,
1058
+ hidden_act="quick_gelu",
1059
+ layer_norm_eps=1e-05,
1060
+ dropout=0.0,
1061
+ attention_dropout=0.0,
1062
+ initializer_range=0.02,
1063
+ initializer_factor=1.0,
1064
+ pad_token_id=1,
1065
+ bos_token_id=0,
1066
+ eos_token_id=2,
1067
+ model_type="clip_text_model",
1068
+ projection_dim=768,
1069
+ torch_dtype="float32",
1070
+ )
1071
+ text_model = CLIPTextModel._from_config(cfg)
1072
+ info = text_model.load_state_dict(converted_text_encoder_checkpoint)
1073
+ print("loading text encoder:", info)
1074
+
1075
+ return text_model, vae, unet
1076
+
1077
+
1078
+ def get_model_version_str_for_sd1_sd2(v2, v_parameterization):
1079
+ # only for reference
1080
+ version_str = "sd"
1081
+ if v2:
1082
+ version_str += "_v2"
1083
+ else:
1084
+ version_str += "_v1"
1085
+ if v_parameterization:
1086
+ version_str += "_v"
1087
+ return version_str
1088
+
1089
+
1090
+ def convert_text_encoder_state_dict_to_sd_v2(checkpoint, make_dummy_weights=False):
1091
+ def convert_key(key):
1092
+ # position_idsの除去
1093
+ if ".position_ids" in key:
1094
+ return None
1095
+
1096
+ # common
1097
+ key = key.replace("text_model.encoder.", "transformer.")
1098
+ key = key.replace("text_model.", "")
1099
+ if "layers" in key:
1100
+ # resblocks conversion
1101
+ key = key.replace(".layers.", ".resblocks.")
1102
+ if ".layer_norm" in key:
1103
+ key = key.replace(".layer_norm", ".ln_")
1104
+ elif ".mlp." in key:
1105
+ key = key.replace(".fc1.", ".c_fc.")
1106
+ key = key.replace(".fc2.", ".c_proj.")
1107
+ elif ".self_attn.out_proj" in key:
1108
+ key = key.replace(".self_attn.out_proj.", ".attn.out_proj.")
1109
+ elif ".self_attn." in key:
1110
+ key = None # 特殊なので後で処理する
1111
+ else:
1112
+ raise ValueError(f"unexpected key in DiffUsers model: {key}")
1113
+ elif ".position_embedding" in key:
1114
+ key = key.replace("embeddings.position_embedding.weight", "positional_embedding")
1115
+ elif ".token_embedding" in key:
1116
+ key = key.replace("embeddings.token_embedding.weight", "token_embedding.weight")
1117
+ elif "final_layer_norm" in key:
1118
+ key = key.replace("final_layer_norm", "ln_final")
1119
+ return key
1120
+
1121
+ keys = list(checkpoint.keys())
1122
+ new_sd = {}
1123
+ for key in keys:
1124
+ new_key = convert_key(key)
1125
+ if new_key is None:
1126
+ continue
1127
+ new_sd[new_key] = checkpoint[key]
1128
+
1129
+ # attnの変換
1130
+ for key in keys:
1131
+ if "layers" in key and "q_proj" in key:
1132
+ # 三つを結合
1133
+ key_q = key
1134
+ key_k = key.replace("q_proj", "k_proj")
1135
+ key_v = key.replace("q_proj", "v_proj")
1136
+
1137
+ value_q = checkpoint[key_q]
1138
+ value_k = checkpoint[key_k]
1139
+ value_v = checkpoint[key_v]
1140
+ value = torch.cat([value_q, value_k, value_v])
1141
+
1142
+ new_key = key.replace("text_model.encoder.layers.", "transformer.resblocks.")
1143
+ new_key = new_key.replace(".self_attn.q_proj.", ".attn.in_proj_")
1144
+ new_sd[new_key] = value
1145
+
1146
+ # 最後の層などを捏造するか
1147
+ if make_dummy_weights:
1148
+ print("make dummy weights for resblock.23, text_projection and logit scale.")
1149
+ keys = list(new_sd.keys())
1150
+ for key in keys:
1151
+ if key.startswith("transformer.resblocks.22."):
1152
+ new_sd[key.replace(".22.", ".23.")] = new_sd[key].clone() # copyしないとsafetensorsの保存で落ちる
1153
+
1154
+ # Diffusersに含まれない重みを作っておく
1155
+ new_sd["text_projection"] = torch.ones((1024, 1024), dtype=new_sd[keys[0]].dtype, device=new_sd[keys[0]].device)
1156
+ new_sd["logit_scale"] = torch.tensor(1)
1157
+
1158
+ return new_sd
1159
+
1160
+
1161
+ def save_stable_diffusion_checkpoint(
1162
+ v2, output_file, text_encoder, unet, ckpt_path, epochs, steps, metadata, save_dtype=None, vae=None
1163
+ ):
1164
+ if ckpt_path is not None:
1165
+ # epoch/stepを参照する。またVAEがメモリ上にないときなど、もう一度VAEを含めて読み込む
1166
+ checkpoint, state_dict = load_checkpoint_with_text_encoder_conversion(ckpt_path)
1167
+ if checkpoint is None: # safetensors または state_dictのckpt
1168
+ checkpoint = {}
1169
+ strict = False
1170
+ else:
1171
+ strict = True
1172
+ if "state_dict" in state_dict:
1173
+ del state_dict["state_dict"]
1174
+ else:
1175
+ # 新しく作る
1176
+ assert vae is not None, "VAE is required to save a checkpoint without a given checkpoint"
1177
+ checkpoint = {}
1178
+ state_dict = {}
1179
+ strict = False
1180
+
1181
+ def update_sd(prefix, sd):
1182
+ for k, v in sd.items():
1183
+ key = prefix + k
1184
+ assert not strict or key in state_dict, f"Illegal key in save SD: {key}"
1185
+ if save_dtype is not None:
1186
+ v = v.detach().clone().to("cpu").to(save_dtype)
1187
+ state_dict[key] = v
1188
+
1189
+ # Convert the UNet model
1190
+ unet_state_dict = convert_unet_state_dict_to_sd(v2, unet.state_dict())
1191
+ update_sd("model.diffusion_model.", unet_state_dict)
1192
+
1193
+ # Convert the text encoder model
1194
+ if v2:
1195
+ make_dummy = ckpt_path is None # 参照元のcheckpointがない場合は最後の層を前の層から複製して作るなどダミーの重みを入れる
1196
+ text_enc_dict = convert_text_encoder_state_dict_to_sd_v2(text_encoder.state_dict(), make_dummy)
1197
+ update_sd("cond_stage_model.model.", text_enc_dict)
1198
+ else:
1199
+ text_enc_dict = text_encoder.state_dict()
1200
+ update_sd("cond_stage_model.transformer.", text_enc_dict)
1201
+
1202
+ # Convert the VAE
1203
+ if vae is not None:
1204
+ vae_dict = convert_vae_state_dict(vae.state_dict())
1205
+ update_sd("first_stage_model.", vae_dict)
1206
+
1207
+ # Put together new checkpoint
1208
+ key_count = len(state_dict.keys())
1209
+ new_ckpt = {"state_dict": state_dict}
1210
+
1211
+ # epoch and global_step are sometimes not int
1212
+ try:
1213
+ if "epoch" in checkpoint:
1214
+ epochs += checkpoint["epoch"]
1215
+ if "global_step" in checkpoint:
1216
+ steps += checkpoint["global_step"]
1217
+ except:
1218
+ pass
1219
+
1220
+ new_ckpt["epoch"] = epochs
1221
+ new_ckpt["global_step"] = steps
1222
+
1223
+ if is_safetensors(output_file):
1224
+ # TODO Tensor以外のdictの値を削除したほうがいいか
1225
+ save_file(state_dict, output_file, metadata)
1226
+ else:
1227
+ torch.save(new_ckpt, output_file)
1228
+
1229
+ return key_count
1230
+
1231
+
1232
+ def save_diffusers_checkpoint(v2, output_dir, text_encoder, unet, pretrained_model_name_or_path, vae=None, use_safetensors=False):
1233
+ if pretrained_model_name_or_path is None:
1234
+ # load default settings for v1/v2
1235
+ if v2:
1236
+ pretrained_model_name_or_path = DIFFUSERS_REF_MODEL_ID_V2
1237
+ else:
1238
+ pretrained_model_name_or_path = DIFFUSERS_REF_MODEL_ID_V1
1239
+
1240
+ scheduler = DDIMScheduler.from_pretrained(pretrained_model_name_or_path, subfolder="scheduler")
1241
+ tokenizer = CLIPTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder="tokenizer")
1242
+ if vae is None:
1243
+ vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder="vae")
1244
+
1245
+ pipeline = StableDiffusionPipeline(
1246
+ unet=unet,
1247
+ text_encoder=text_encoder,
1248
+ vae=vae,
1249
+ scheduler=scheduler,
1250
+ tokenizer=tokenizer,
1251
+ safety_checker=None,
1252
+ feature_extractor=None,
1253
+ requires_safety_checker=None,
1254
+ )
1255
+ pipeline.save_pretrained(output_dir, safe_serialization=use_safetensors)
1256
+
1257
+
1258
+ VAE_PREFIX = "first_stage_model."
1259
+
1260
+
1261
+ def load_vae(vae_id, dtype):
1262
+ print(f"load VAE: {vae_id}")
1263
+ if os.path.isdir(vae_id) or not os.path.isfile(vae_id):
1264
+ # Diffusers local/remote
1265
+ try:
1266
+ vae = AutoencoderKL.from_pretrained(vae_id, subfolder=None, torch_dtype=dtype)
1267
+ except EnvironmentError as e:
1268
+ print(f"exception occurs in loading vae: {e}")
1269
+ print("retry with subfolder='vae'")
1270
+ vae = AutoencoderKL.from_pretrained(vae_id, subfolder="vae", torch_dtype=dtype)
1271
+ return vae
1272
+
1273
+ # local
1274
+ vae_config = create_vae_diffusers_config()
1275
+
1276
+ if vae_id.endswith(".bin"):
1277
+ # SD 1.5 VAE on Huggingface
1278
+ converted_vae_checkpoint = torch.load(vae_id, map_location="cpu")
1279
+ else:
1280
+ # StableDiffusion
1281
+ vae_model = load_file(vae_id, "cpu") if is_safetensors(vae_id) else torch.load(vae_id, map_location="cpu")
1282
+ vae_sd = vae_model["state_dict"] if "state_dict" in vae_model else vae_model
1283
+
1284
+ # vae only or full model
1285
+ full_model = False
1286
+ for vae_key in vae_sd:
1287
+ if vae_key.startswith(VAE_PREFIX):
1288
+ full_model = True
1289
+ break
1290
+ if not full_model:
1291
+ sd = {}
1292
+ for key, value in vae_sd.items():
1293
+ sd[VAE_PREFIX + key] = value
1294
+ vae_sd = sd
1295
+ del sd
1296
+
1297
+ # Convert the VAE model.
1298
+ converted_vae_checkpoint = convert_ldm_vae_checkpoint(vae_sd, vae_config)
1299
+
1300
+ vae = AutoencoderKL(**vae_config)
1301
+ vae.load_state_dict(converted_vae_checkpoint)
1302
+ return vae
1303
+
1304
+
1305
+ # endregion
1306
+
1307
+
1308
+ def make_bucket_resolutions(max_reso, min_size=256, max_size=1024, divisible=64):
1309
+ max_width, max_height = max_reso
1310
+ max_area = max_width * max_height
1311
+
1312
+ resos = set()
1313
+
1314
+ width = int(math.sqrt(max_area) // divisible) * divisible
1315
+ resos.add((width, width))
1316
+
1317
+ width = min_size
1318
+ while width <= max_size:
1319
+ height = min(max_size, int((max_area // width) // divisible) * divisible)
1320
+ if height >= min_size:
1321
+ resos.add((width, height))
1322
+ resos.add((height, width))
1323
+
1324
+ # # make additional resos
1325
+ # if width >= height and width - divisible >= min_size:
1326
+ # resos.add((width - divisible, height))
1327
+ # resos.add((height, width - divisible))
1328
+ # if height >= width and height - divisible >= min_size:
1329
+ # resos.add((width, height - divisible))
1330
+ # resos.add((height - divisible, width))
1331
+
1332
+ width += divisible
1333
+
1334
+ resos = list(resos)
1335
+ resos.sort()
1336
+ return resos
1337
+
1338
+
1339
+ if __name__ == "__main__":
1340
+ resos = make_bucket_resolutions((512, 768))
1341
+ print(len(resos))
1342
+ print(resos)
1343
+ aspect_ratios = [w / h for w, h in resos]
1344
+ print(aspect_ratios)
1345
+
1346
+ ars = set()
1347
+ for ar in aspect_ratios:
1348
+ if ar in ars:
1349
+ print("error! duplicate ar:", ar)
1350
+ ars.add(ar)
external/llite/library/original_unet.py ADDED
@@ -0,0 +1,1915 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Diffusers 0.10.2からStable Diffusionに必要な部分だけを持ってくる
2
+ # 条件分岐等で不要な部分は削除している
3
+ # コードの多くはDiffusersからコピーしている
4
+ # 制約として、モデルのstate_dictがDiffusers 0.10.2のものと同じ形式である必要がある
5
+
6
+ # Copy from Diffusers 0.10.2 for Stable Diffusion. Most of the code is copied from Diffusers.
7
+ # Unnecessary parts are deleted by condition branching.
8
+ # As a constraint, the state_dict of the model must be in the same format as that of Diffusers 0.10.2
9
+
10
+ """
11
+ v1.5とv2.1の相違点は
12
+ - attention_head_dimがintかlist[int]か
13
+ - cross_attention_dimが768か1024か
14
+ - use_linear_projection: trueがない(=False, 1.5)かあるか
15
+ - upcast_attentionがFalse(1.5)かTrue(2.1)か
16
+ - (以下は多分無視していい)
17
+ - sample_sizeが64か96か
18
+ - dual_cross_attentionがあるかないか
19
+ - num_class_embedsがあるかないか
20
+ - only_cross_attentionがあるかないか
21
+
22
+ v1.5
23
+ {
24
+ "_class_name": "UNet2DConditionModel",
25
+ "_diffusers_version": "0.6.0",
26
+ "act_fn": "silu",
27
+ "attention_head_dim": 8,
28
+ "block_out_channels": [
29
+ 320,
30
+ 640,
31
+ 1280,
32
+ 1280
33
+ ],
34
+ "center_input_sample": false,
35
+ "cross_attention_dim": 768,
36
+ "down_block_types": [
37
+ "CrossAttnDownBlock2D",
38
+ "CrossAttnDownBlock2D",
39
+ "CrossAttnDownBlock2D",
40
+ "DownBlock2D"
41
+ ],
42
+ "downsample_padding": 1,
43
+ "flip_sin_to_cos": true,
44
+ "freq_shift": 0,
45
+ "in_channels": 4,
46
+ "layers_per_block": 2,
47
+ "mid_block_scale_factor": 1,
48
+ "norm_eps": 1e-05,
49
+ "norm_num_groups": 32,
50
+ "out_channels": 4,
51
+ "sample_size": 64,
52
+ "up_block_types": [
53
+ "UpBlock2D",
54
+ "CrossAttnUpBlock2D",
55
+ "CrossAttnUpBlock2D",
56
+ "CrossAttnUpBlock2D"
57
+ ]
58
+ }
59
+
60
+ v2.1
61
+ {
62
+ "_class_name": "UNet2DConditionModel",
63
+ "_diffusers_version": "0.10.0.dev0",
64
+ "act_fn": "silu",
65
+ "attention_head_dim": [
66
+ 5,
67
+ 10,
68
+ 20,
69
+ 20
70
+ ],
71
+ "block_out_channels": [
72
+ 320,
73
+ 640,
74
+ 1280,
75
+ 1280
76
+ ],
77
+ "center_input_sample": false,
78
+ "cross_attention_dim": 1024,
79
+ "down_block_types": [
80
+ "CrossAttnDownBlock2D",
81
+ "CrossAttnDownBlock2D",
82
+ "CrossAttnDownBlock2D",
83
+ "DownBlock2D"
84
+ ],
85
+ "downsample_padding": 1,
86
+ "dual_cross_attention": false,
87
+ "flip_sin_to_cos": true,
88
+ "freq_shift": 0,
89
+ "in_channels": 4,
90
+ "layers_per_block": 2,
91
+ "mid_block_scale_factor": 1,
92
+ "norm_eps": 1e-05,
93
+ "norm_num_groups": 32,
94
+ "num_class_embeds": null,
95
+ "only_cross_attention": false,
96
+ "out_channels": 4,
97
+ "sample_size": 96,
98
+ "up_block_types": [
99
+ "UpBlock2D",
100
+ "CrossAttnUpBlock2D",
101
+ "CrossAttnUpBlock2D",
102
+ "CrossAttnUpBlock2D"
103
+ ],
104
+ "use_linear_projection": true,
105
+ "upcast_attention": true
106
+ }
107
+ """
108
+
109
+ import math
110
+ from types import SimpleNamespace
111
+ from typing import Dict, Optional, Tuple, Union
112
+ import torch
113
+ from torch import nn
114
+ from torch.nn import functional as F
115
+ from einops import rearrange
116
+
117
+ BLOCK_OUT_CHANNELS: Tuple[int] = (320, 640, 1280, 1280)
118
+ TIMESTEP_INPUT_DIM = BLOCK_OUT_CHANNELS[0]
119
+ TIME_EMBED_DIM = BLOCK_OUT_CHANNELS[0] * 4
120
+ IN_CHANNELS: int = 4
121
+ OUT_CHANNELS: int = 4
122
+ LAYERS_PER_BLOCK: int = 2
123
+ LAYERS_PER_BLOCK_UP: int = LAYERS_PER_BLOCK + 1
124
+ TIME_EMBED_FLIP_SIN_TO_COS: bool = True
125
+ TIME_EMBED_FREQ_SHIFT: int = 0
126
+ NORM_GROUPS: int = 32
127
+ NORM_EPS: float = 1e-5
128
+ TRANSFORMER_NORM_NUM_GROUPS = 32
129
+
130
+ DOWN_BLOCK_TYPES = ["CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D"]
131
+ UP_BLOCK_TYPES = ["UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"]
132
+
133
+
134
+ # region memory efficient attention
135
+
136
+ # FlashAttentionを使うCrossAttention
137
+ # based on https://github.com/lucidrains/memory-efficient-attention-pytorch/blob/main/memory_efficient_attention_pytorch/flash_attention.py
138
+ # LICENSE MIT https://github.com/lucidrains/memory-efficient-attention-pytorch/blob/main/LICENSE
139
+
140
+ # constants
141
+
142
+ EPSILON = 1e-6
143
+
144
+ # helper functions
145
+
146
+
147
+ def exists(val):
148
+ return val is not None
149
+
150
+
151
+ def default(val, d):
152
+ return val if exists(val) else d
153
+
154
+
155
+ # flash attention forwards and backwards
156
+
157
+ # https://arxiv.org/abs/2205.14135
158
+
159
+
160
+ class FlashAttentionFunction(torch.autograd.Function):
161
+ @staticmethod
162
+ @torch.no_grad()
163
+ def forward(ctx, q, k, v, mask, causal, q_bucket_size, k_bucket_size):
164
+ """Algorithm 2 in the paper"""
165
+
166
+ device = q.device
167
+ dtype = q.dtype
168
+ max_neg_value = -torch.finfo(q.dtype).max
169
+ qk_len_diff = max(k.shape[-2] - q.shape[-2], 0)
170
+
171
+ o = torch.zeros_like(q)
172
+ all_row_sums = torch.zeros((*q.shape[:-1], 1), dtype=dtype, device=device)
173
+ all_row_maxes = torch.full((*q.shape[:-1], 1), max_neg_value, dtype=dtype, device=device)
174
+
175
+ scale = q.shape[-1] ** -0.5
176
+
177
+ if not exists(mask):
178
+ mask = (None,) * math.ceil(q.shape[-2] / q_bucket_size)
179
+ else:
180
+ mask = rearrange(mask, "b n -> b 1 1 n")
181
+ mask = mask.split(q_bucket_size, dim=-1)
182
+
183
+ row_splits = zip(
184
+ q.split(q_bucket_size, dim=-2),
185
+ o.split(q_bucket_size, dim=-2),
186
+ mask,
187
+ all_row_sums.split(q_bucket_size, dim=-2),
188
+ all_row_maxes.split(q_bucket_size, dim=-2),
189
+ )
190
+
191
+ for ind, (qc, oc, row_mask, row_sums, row_maxes) in enumerate(row_splits):
192
+ q_start_index = ind * q_bucket_size - qk_len_diff
193
+
194
+ col_splits = zip(
195
+ k.split(k_bucket_size, dim=-2),
196
+ v.split(k_bucket_size, dim=-2),
197
+ )
198
+
199
+ for k_ind, (kc, vc) in enumerate(col_splits):
200
+ k_start_index = k_ind * k_bucket_size
201
+
202
+ attn_weights = torch.einsum("... i d, ... j d -> ... i j", qc, kc) * scale
203
+
204
+ if exists(row_mask):
205
+ attn_weights.masked_fill_(~row_mask, max_neg_value)
206
+
207
+ if causal and q_start_index < (k_start_index + k_bucket_size - 1):
208
+ causal_mask = torch.ones((qc.shape[-2], kc.shape[-2]), dtype=torch.bool, device=device).triu(
209
+ q_start_index - k_start_index + 1
210
+ )
211
+ attn_weights.masked_fill_(causal_mask, max_neg_value)
212
+
213
+ block_row_maxes = attn_weights.amax(dim=-1, keepdims=True)
214
+ attn_weights -= block_row_maxes
215
+ exp_weights = torch.exp(attn_weights)
216
+
217
+ if exists(row_mask):
218
+ exp_weights.masked_fill_(~row_mask, 0.0)
219
+
220
+ block_row_sums = exp_weights.sum(dim=-1, keepdims=True).clamp(min=EPSILON)
221
+
222
+ new_row_maxes = torch.maximum(block_row_maxes, row_maxes)
223
+
224
+ exp_values = torch.einsum("... i j, ... j d -> ... i d", exp_weights, vc)
225
+
226
+ exp_row_max_diff = torch.exp(row_maxes - new_row_maxes)
227
+ exp_block_row_max_diff = torch.exp(block_row_maxes - new_row_maxes)
228
+
229
+ new_row_sums = exp_row_max_diff * row_sums + exp_block_row_max_diff * block_row_sums
230
+
231
+ oc.mul_((row_sums / new_row_sums) * exp_row_max_diff).add_((exp_block_row_max_diff / new_row_sums) * exp_values)
232
+
233
+ row_maxes.copy_(new_row_maxes)
234
+ row_sums.copy_(new_row_sums)
235
+
236
+ ctx.args = (causal, scale, mask, q_bucket_size, k_bucket_size)
237
+ ctx.save_for_backward(q, k, v, o, all_row_sums, all_row_maxes)
238
+
239
+ return o
240
+
241
+ @staticmethod
242
+ @torch.no_grad()
243
+ def backward(ctx, do):
244
+ """Algorithm 4 in the paper"""
245
+
246
+ causal, scale, mask, q_bucket_size, k_bucket_size = ctx.args
247
+ q, k, v, o, l, m = ctx.saved_tensors
248
+
249
+ device = q.device
250
+
251
+ max_neg_value = -torch.finfo(q.dtype).max
252
+ qk_len_diff = max(k.shape[-2] - q.shape[-2], 0)
253
+
254
+ dq = torch.zeros_like(q)
255
+ dk = torch.zeros_like(k)
256
+ dv = torch.zeros_like(v)
257
+
258
+ row_splits = zip(
259
+ q.split(q_bucket_size, dim=-2),
260
+ o.split(q_bucket_size, dim=-2),
261
+ do.split(q_bucket_size, dim=-2),
262
+ mask,
263
+ l.split(q_bucket_size, dim=-2),
264
+ m.split(q_bucket_size, dim=-2),
265
+ dq.split(q_bucket_size, dim=-2),
266
+ )
267
+
268
+ for ind, (qc, oc, doc, row_mask, lc, mc, dqc) in enumerate(row_splits):
269
+ q_start_index = ind * q_bucket_size - qk_len_diff
270
+
271
+ col_splits = zip(
272
+ k.split(k_bucket_size, dim=-2),
273
+ v.split(k_bucket_size, dim=-2),
274
+ dk.split(k_bucket_size, dim=-2),
275
+ dv.split(k_bucket_size, dim=-2),
276
+ )
277
+
278
+ for k_ind, (kc, vc, dkc, dvc) in enumerate(col_splits):
279
+ k_start_index = k_ind * k_bucket_size
280
+
281
+ attn_weights = torch.einsum("... i d, ... j d -> ... i j", qc, kc) * scale
282
+
283
+ if causal and q_start_index < (k_start_index + k_bucket_size - 1):
284
+ causal_mask = torch.ones((qc.shape[-2], kc.shape[-2]), dtype=torch.bool, device=device).triu(
285
+ q_start_index - k_start_index + 1
286
+ )
287
+ attn_weights.masked_fill_(causal_mask, max_neg_value)
288
+
289
+ exp_attn_weights = torch.exp(attn_weights - mc)
290
+
291
+ if exists(row_mask):
292
+ exp_attn_weights.masked_fill_(~row_mask, 0.0)
293
+
294
+ p = exp_attn_weights / lc
295
+
296
+ dv_chunk = torch.einsum("... i j, ... i d -> ... j d", p, doc)
297
+ dp = torch.einsum("... i d, ... j d -> ... i j", doc, vc)
298
+
299
+ D = (doc * oc).sum(dim=-1, keepdims=True)
300
+ ds = p * scale * (dp - D)
301
+
302
+ dq_chunk = torch.einsum("... i j, ... j d -> ... i d", ds, kc)
303
+ dk_chunk = torch.einsum("... i j, ... i d -> ... j d", ds, qc)
304
+
305
+ dqc.add_(dq_chunk)
306
+ dkc.add_(dk_chunk)
307
+ dvc.add_(dv_chunk)
308
+
309
+ return dq, dk, dv, None, None, None, None
310
+
311
+
312
+ # endregion
313
+
314
+
315
+ def get_parameter_dtype(parameter: torch.nn.Module):
316
+ return next(parameter.parameters()).dtype
317
+
318
+
319
+ def get_parameter_device(parameter: torch.nn.Module):
320
+ return next(parameter.parameters()).device
321
+
322
+
323
+ def get_timestep_embedding(
324
+ timesteps: torch.Tensor,
325
+ embedding_dim: int,
326
+ flip_sin_to_cos: bool = False,
327
+ downscale_freq_shift: float = 1,
328
+ scale: float = 1,
329
+ max_period: int = 10000,
330
+ ):
331
+ """
332
+ This matches the implementation in Denoising Diffusion Probabilistic Models: Create sinusoidal timestep embeddings.
333
+
334
+ :param timesteps: a 1-D Tensor of N indices, one per batch element.
335
+ These may be fractional.
336
+ :param embedding_dim: the dimension of the output. :param max_period: controls the minimum frequency of the
337
+ embeddings. :return: an [N x dim] Tensor of positional embeddings.
338
+ """
339
+ assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array"
340
+
341
+ half_dim = embedding_dim // 2
342
+ exponent = -math.log(max_period) * torch.arange(start=0, end=half_dim, dtype=torch.float32, device=timesteps.device)
343
+ exponent = exponent / (half_dim - downscale_freq_shift)
344
+
345
+ emb = torch.exp(exponent)
346
+ emb = timesteps[:, None].float() * emb[None, :]
347
+
348
+ # scale embeddings
349
+ emb = scale * emb
350
+
351
+ # concat sine and cosine embeddings
352
+ emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1)
353
+
354
+ # flip sine and cosine embeddings
355
+ if flip_sin_to_cos:
356
+ emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1)
357
+
358
+ # zero pad
359
+ if embedding_dim % 2 == 1:
360
+ emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
361
+ return emb
362
+
363
+
364
+ # Deep Shrink: We do not common this function, because minimize dependencies.
365
+ def resize_like(x, target, mode="bicubic", align_corners=False):
366
+ org_dtype = x.dtype
367
+ if org_dtype == torch.bfloat16:
368
+ x = x.to(torch.float32)
369
+
370
+ if x.shape[-2:] != target.shape[-2:]:
371
+ if mode == "nearest":
372
+ x = F.interpolate(x, size=target.shape[-2:], mode=mode)
373
+ else:
374
+ x = F.interpolate(x, size=target.shape[-2:], mode=mode, align_corners=align_corners)
375
+
376
+ if org_dtype == torch.bfloat16:
377
+ x = x.to(org_dtype)
378
+ return x
379
+
380
+
381
+ class SampleOutput:
382
+ def __init__(self, sample):
383
+ self.sample = sample
384
+
385
+
386
+ class TimestepEmbedding(nn.Module):
387
+ def __init__(self, in_channels: int, time_embed_dim: int, act_fn: str = "silu", out_dim: int = None):
388
+ super().__init__()
389
+
390
+ self.linear_1 = nn.Linear(in_channels, time_embed_dim)
391
+ self.act = None
392
+ if act_fn == "silu":
393
+ self.act = nn.SiLU()
394
+ elif act_fn == "mish":
395
+ self.act = nn.Mish()
396
+
397
+ if out_dim is not None:
398
+ time_embed_dim_out = out_dim
399
+ else:
400
+ time_embed_dim_out = time_embed_dim
401
+ self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim_out)
402
+
403
+ def forward(self, sample):
404
+ sample = self.linear_1(sample)
405
+
406
+ if self.act is not None:
407
+ sample = self.act(sample)
408
+
409
+ sample = self.linear_2(sample)
410
+ return sample
411
+
412
+
413
+ class Timesteps(nn.Module):
414
+ def __init__(self, num_channels: int, flip_sin_to_cos: bool, downscale_freq_shift: float):
415
+ super().__init__()
416
+ self.num_channels = num_channels
417
+ self.flip_sin_to_cos = flip_sin_to_cos
418
+ self.downscale_freq_shift = downscale_freq_shift
419
+
420
+ def forward(self, timesteps):
421
+ t_emb = get_timestep_embedding(
422
+ timesteps,
423
+ self.num_channels,
424
+ flip_sin_to_cos=self.flip_sin_to_cos,
425
+ downscale_freq_shift=self.downscale_freq_shift,
426
+ )
427
+ return t_emb
428
+
429
+
430
+ class ResnetBlock2D(nn.Module):
431
+ def __init__(
432
+ self,
433
+ in_channels,
434
+ out_channels,
435
+ ):
436
+ super().__init__()
437
+ self.in_channels = in_channels
438
+ self.out_channels = out_channels
439
+
440
+ self.norm1 = torch.nn.GroupNorm(num_groups=NORM_GROUPS, num_channels=in_channels, eps=NORM_EPS, affine=True)
441
+
442
+ self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
443
+
444
+ self.time_emb_proj = torch.nn.Linear(TIME_EMBED_DIM, out_channels)
445
+
446
+ self.norm2 = torch.nn.GroupNorm(num_groups=NORM_GROUPS, num_channels=out_channels, eps=NORM_EPS, affine=True)
447
+ self.conv2 = torch.nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
448
+
449
+ # if non_linearity == "swish":
450
+ self.nonlinearity = lambda x: F.silu(x)
451
+
452
+ self.use_in_shortcut = self.in_channels != self.out_channels
453
+
454
+ self.conv_shortcut = None
455
+ if self.use_in_shortcut:
456
+ self.conv_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
457
+
458
+ def forward(self, input_tensor, temb):
459
+ hidden_states = input_tensor
460
+
461
+ hidden_states = self.norm1(hidden_states)
462
+ hidden_states = self.nonlinearity(hidden_states)
463
+
464
+ hidden_states = self.conv1(hidden_states)
465
+
466
+ temb = self.time_emb_proj(self.nonlinearity(temb))[:, :, None, None]
467
+ hidden_states = hidden_states + temb
468
+
469
+ hidden_states = self.norm2(hidden_states)
470
+ hidden_states = self.nonlinearity(hidden_states)
471
+
472
+ hidden_states = self.conv2(hidden_states)
473
+
474
+ if self.conv_shortcut is not None:
475
+ input_tensor = self.conv_shortcut(input_tensor)
476
+
477
+ output_tensor = input_tensor + hidden_states
478
+
479
+ return output_tensor
480
+
481
+
482
+ class DownBlock2D(nn.Module):
483
+ def __init__(
484
+ self,
485
+ in_channels: int,
486
+ out_channels: int,
487
+ add_downsample=True,
488
+ ):
489
+ super().__init__()
490
+
491
+ self.has_cross_attention = False
492
+ resnets = []
493
+
494
+ for i in range(LAYERS_PER_BLOCK):
495
+ in_channels = in_channels if i == 0 else out_channels
496
+ resnets.append(
497
+ ResnetBlock2D(
498
+ in_channels=in_channels,
499
+ out_channels=out_channels,
500
+ )
501
+ )
502
+ self.resnets = nn.ModuleList(resnets)
503
+
504
+ if add_downsample:
505
+ self.downsamplers = [Downsample2D(out_channels, out_channels=out_channels)]
506
+ else:
507
+ self.downsamplers = None
508
+
509
+ self.gradient_checkpointing = False
510
+
511
+ def set_use_memory_efficient_attention(self, xformers, mem_eff):
512
+ pass
513
+
514
+ def set_use_sdpa(self, sdpa):
515
+ pass
516
+
517
+ def forward(self, hidden_states, temb=None):
518
+ output_states = ()
519
+
520
+ for resnet in self.resnets:
521
+ if self.training and self.gradient_checkpointing:
522
+
523
+ def create_custom_forward(module):
524
+ def custom_forward(*inputs):
525
+ return module(*inputs)
526
+
527
+ return custom_forward
528
+
529
+ hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
530
+ else:
531
+ hidden_states = resnet(hidden_states, temb)
532
+
533
+ output_states += (hidden_states,)
534
+
535
+ if self.downsamplers is not None:
536
+ for downsampler in self.downsamplers:
537
+ hidden_states = downsampler(hidden_states)
538
+
539
+ output_states += (hidden_states,)
540
+
541
+ return hidden_states, output_states
542
+
543
+
544
+ class Downsample2D(nn.Module):
545
+ def __init__(self, channels, out_channels):
546
+ super().__init__()
547
+
548
+ self.channels = channels
549
+ self.out_channels = out_channels
550
+
551
+ self.conv = nn.Conv2d(self.channels, self.out_channels, 3, stride=2, padding=1)
552
+
553
+ def forward(self, hidden_states):
554
+ assert hidden_states.shape[1] == self.channels
555
+ hidden_states = self.conv(hidden_states)
556
+
557
+ return hidden_states
558
+
559
+
560
+ class CrossAttention(nn.Module):
561
+ def __init__(
562
+ self,
563
+ query_dim: int,
564
+ cross_attention_dim: Optional[int] = None,
565
+ heads: int = 8,
566
+ dim_head: int = 64,
567
+ upcast_attention: bool = False,
568
+ ):
569
+ super().__init__()
570
+ inner_dim = dim_head * heads
571
+ cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim
572
+ self.upcast_attention = upcast_attention
573
+
574
+ self.scale = dim_head**-0.5
575
+ self.heads = heads
576
+
577
+ self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
578
+ self.to_k = nn.Linear(cross_attention_dim, inner_dim, bias=False)
579
+ self.to_v = nn.Linear(cross_attention_dim, inner_dim, bias=False)
580
+
581
+ self.to_out = nn.ModuleList([])
582
+ self.to_out.append(nn.Linear(inner_dim, query_dim))
583
+ # no dropout here
584
+
585
+ self.use_memory_efficient_attention_xformers = False
586
+ self.use_memory_efficient_attention_mem_eff = False
587
+ self.use_sdpa = False
588
+
589
+ # Attention processor
590
+ self.processor = None
591
+
592
+ def set_use_memory_efficient_attention(self, xformers, mem_eff):
593
+ self.use_memory_efficient_attention_xformers = xformers
594
+ self.use_memory_efficient_attention_mem_eff = mem_eff
595
+
596
+ def set_use_sdpa(self, sdpa):
597
+ self.use_sdpa = sdpa
598
+
599
+ def reshape_heads_to_batch_dim(self, tensor):
600
+ batch_size, seq_len, dim = tensor.shape
601
+ head_size = self.heads
602
+ tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
603
+ tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size * head_size, seq_len, dim // head_size)
604
+ return tensor
605
+
606
+ def reshape_batch_dim_to_heads(self, tensor):
607
+ batch_size, seq_len, dim = tensor.shape
608
+ head_size = self.heads
609
+ tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
610
+ tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size)
611
+ return tensor
612
+
613
+ def set_processor(self):
614
+ return self.processor
615
+
616
+ def get_processor(self):
617
+ return self.processor
618
+
619
+ def forward(self, hidden_states, context=None, mask=None, **kwargs):
620
+ if self.processor is not None:
621
+ (
622
+ hidden_states,
623
+ encoder_hidden_states,
624
+ attention_mask,
625
+ ) = translate_attention_names_from_diffusers(
626
+ hidden_states=hidden_states, context=context, mask=mask, **kwargs
627
+ )
628
+ return self.processor(
629
+ attn=self,
630
+ hidden_states=hidden_states,
631
+ encoder_hidden_states=context,
632
+ attention_mask=mask,
633
+ **kwargs
634
+ )
635
+ if self.use_memory_efficient_attention_xformers:
636
+ return self.forward_memory_efficient_xformers(hidden_states, context, mask)
637
+ if self.use_memory_efficient_attention_mem_eff:
638
+ return self.forward_memory_efficient_mem_eff(hidden_states, context, mask)
639
+ if self.use_sdpa:
640
+ return self.forward_sdpa(hidden_states, context, mask)
641
+
642
+ query = self.to_q(hidden_states)
643
+ context = context if context is not None else hidden_states
644
+ key = self.to_k(context)
645
+ value = self.to_v(context)
646
+
647
+ query = self.reshape_heads_to_batch_dim(query)
648
+ key = self.reshape_heads_to_batch_dim(key)
649
+ value = self.reshape_heads_to_batch_dim(value)
650
+
651
+ hidden_states = self._attention(query, key, value)
652
+
653
+ # linear proj
654
+ hidden_states = self.to_out[0](hidden_states)
655
+ # hidden_states = self.to_out[1](hidden_states) # no dropout
656
+ return hidden_states
657
+
658
+ def _attention(self, query, key, value):
659
+ if self.upcast_attention:
660
+ query = query.float()
661
+ key = key.float()
662
+
663
+ attention_scores = torch.baddbmm(
664
+ torch.empty(query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device),
665
+ query,
666
+ key.transpose(-1, -2),
667
+ beta=0,
668
+ alpha=self.scale,
669
+ )
670
+ attention_probs = attention_scores.softmax(dim=-1)
671
+
672
+ # cast back to the original dtype
673
+ attention_probs = attention_probs.to(value.dtype)
674
+
675
+ # compute attention output
676
+ hidden_states = torch.bmm(attention_probs, value)
677
+
678
+ # reshape hidden_states
679
+ hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
680
+ return hidden_states
681
+
682
+ # TODO support Hypernetworks
683
+ def forward_memory_efficient_xformers(self, x, context=None, mask=None):
684
+ import xformers.ops
685
+
686
+ h = self.heads
687
+ q_in = self.to_q(x)
688
+ context = context if context is not None else x
689
+ context = context.to(x.dtype)
690
+ k_in = self.to_k(context)
691
+ v_in = self.to_v(context)
692
+
693
+ q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b n h d", h=h), (q_in, k_in, v_in))
694
+ del q_in, k_in, v_in
695
+
696
+ q = q.contiguous()
697
+ k = k.contiguous()
698
+ v = v.contiguous()
699
+ out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None) # 最適なのを選んでくれる
700
+
701
+ out = rearrange(out, "b n h d -> b n (h d)", h=h)
702
+
703
+ out = self.to_out[0](out)
704
+ return out
705
+
706
+ def forward_memory_efficient_mem_eff(self, x, context=None, mask=None):
707
+ flash_func = FlashAttentionFunction
708
+
709
+ q_bucket_size = 512
710
+ k_bucket_size = 1024
711
+
712
+ h = self.heads
713
+ q = self.to_q(x)
714
+ context = context if context is not None else x
715
+ context = context.to(x.dtype)
716
+ k = self.to_k(context)
717
+ v = self.to_v(context)
718
+ del context, x
719
+
720
+ q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), (q, k, v))
721
+
722
+ out = flash_func.apply(q, k, v, mask, False, q_bucket_size, k_bucket_size)
723
+
724
+ out = rearrange(out, "b h n d -> b n (h d)")
725
+
726
+ out = self.to_out[0](out)
727
+ return out
728
+
729
+ def forward_sdpa(self, x, context=None, mask=None):
730
+ h = self.heads
731
+ q_in = self.to_q(x)
732
+ context = context if context is not None else x
733
+ context = context.to(x.dtype)
734
+ k_in = self.to_k(context)
735
+ v_in = self.to_v(context)
736
+
737
+ q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), (q_in, k_in, v_in))
738
+ del q_in, k_in, v_in
739
+
740
+ out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False)
741
+
742
+ out = rearrange(out, "b h n d -> b n (h d)", h=h)
743
+
744
+ out = self.to_out[0](out)
745
+ return out
746
+
747
+ def translate_attention_names_from_diffusers(
748
+ hidden_states: torch.FloatTensor,
749
+ context: Optional[torch.FloatTensor] = None,
750
+ mask: Optional[torch.FloatTensor] = None,
751
+ # HF naming
752
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
753
+ attention_mask: Optional[torch.FloatTensor] = None
754
+ ):
755
+ # translate from hugging face diffusers
756
+ context = context if context is not None else encoder_hidden_states
757
+
758
+ # translate from hugging face diffusers
759
+ mask = mask if mask is not None else attention_mask
760
+
761
+ return hidden_states, context, mask
762
+
763
+ # feedforward
764
+ class GEGLU(nn.Module):
765
+ r"""
766
+ A variant of the gated linear unit activation function from https://arxiv.org/abs/2002.05202.
767
+
768
+ Parameters:
769
+ dim_in (`int`): The number of channels in the input.
770
+ dim_out (`int`): The number of channels in the output.
771
+ """
772
+
773
+ def __init__(self, dim_in: int, dim_out: int):
774
+ super().__init__()
775
+ self.proj = nn.Linear(dim_in, dim_out * 2)
776
+
777
+ def gelu(self, gate):
778
+ if gate.device.type != "mps":
779
+ return F.gelu(gate)
780
+ # mps: gelu is not implemented for float16
781
+ return F.gelu(gate.to(dtype=torch.float32)).to(dtype=gate.dtype)
782
+
783
+ def forward(self, hidden_states):
784
+ hidden_states, gate = self.proj(hidden_states).chunk(2, dim=-1)
785
+ return hidden_states * self.gelu(gate)
786
+
787
+
788
+ class FeedForward(nn.Module):
789
+ def __init__(
790
+ self,
791
+ dim: int,
792
+ ):
793
+ super().__init__()
794
+ inner_dim = int(dim * 4) # mult is always 4
795
+
796
+ self.net = nn.ModuleList([])
797
+ # project in
798
+ self.net.append(GEGLU(dim, inner_dim))
799
+ # project dropout
800
+ self.net.append(nn.Identity()) # nn.Dropout(0)) # dummy for dropout with 0
801
+ # project out
802
+ self.net.append(nn.Linear(inner_dim, dim))
803
+
804
+ def forward(self, hidden_states):
805
+ for module in self.net:
806
+ hidden_states = module(hidden_states)
807
+ return hidden_states
808
+
809
+
810
+ class BasicTransformerBlock(nn.Module):
811
+ def __init__(
812
+ self, dim: int, num_attention_heads: int, attention_head_dim: int, cross_attention_dim: int, upcast_attention: bool = False
813
+ ):
814
+ super().__init__()
815
+
816
+ # 1. Self-Attn
817
+ self.attn1 = CrossAttention(
818
+ query_dim=dim,
819
+ cross_attention_dim=None,
820
+ heads=num_attention_heads,
821
+ dim_head=attention_head_dim,
822
+ upcast_attention=upcast_attention,
823
+ )
824
+ self.ff = FeedForward(dim)
825
+
826
+ # 2. Cross-Attn
827
+ self.attn2 = CrossAttention(
828
+ query_dim=dim,
829
+ cross_attention_dim=cross_attention_dim,
830
+ heads=num_attention_heads,
831
+ dim_head=attention_head_dim,
832
+ upcast_attention=upcast_attention,
833
+ )
834
+
835
+ self.norm1 = nn.LayerNorm(dim)
836
+ self.norm2 = nn.LayerNorm(dim)
837
+
838
+ # 3. Feed-forward
839
+ self.norm3 = nn.LayerNorm(dim)
840
+
841
+ def set_use_memory_efficient_attention(self, xformers: bool, mem_eff: bool):
842
+ self.attn1.set_use_memory_efficient_attention(xformers, mem_eff)
843
+ self.attn2.set_use_memory_efficient_attention(xformers, mem_eff)
844
+
845
+ def set_use_sdpa(self, sdpa: bool):
846
+ self.attn1.set_use_sdpa(sdpa)
847
+ self.attn2.set_use_sdpa(sdpa)
848
+
849
+ def forward(self, hidden_states, context=None, timestep=None):
850
+ # 1. Self-Attention
851
+ norm_hidden_states = self.norm1(hidden_states)
852
+
853
+ hidden_states = self.attn1(norm_hidden_states) + hidden_states
854
+
855
+ # 2. Cross-Attention
856
+ norm_hidden_states = self.norm2(hidden_states)
857
+ hidden_states = self.attn2(norm_hidden_states, context=context) + hidden_states
858
+
859
+ # 3. Feed-forward
860
+ hidden_states = self.ff(self.norm3(hidden_states)) + hidden_states
861
+
862
+ return hidden_states
863
+
864
+
865
+ class Transformer2DModel(nn.Module):
866
+ def __init__(
867
+ self,
868
+ num_attention_heads: int = 16,
869
+ attention_head_dim: int = 88,
870
+ in_channels: Optional[int] = None,
871
+ cross_attention_dim: Optional[int] = None,
872
+ use_linear_projection: bool = False,
873
+ upcast_attention: bool = False,
874
+ ):
875
+ super().__init__()
876
+ self.in_channels = in_channels
877
+ self.num_attention_heads = num_attention_heads
878
+ self.attention_head_dim = attention_head_dim
879
+ inner_dim = num_attention_heads * attention_head_dim
880
+ self.use_linear_projection = use_linear_projection
881
+
882
+ self.norm = torch.nn.GroupNorm(num_groups=TRANSFORMER_NORM_NUM_GROUPS, num_channels=in_channels, eps=1e-6, affine=True)
883
+
884
+ if use_linear_projection:
885
+ self.proj_in = nn.Linear(in_channels, inner_dim)
886
+ else:
887
+ self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
888
+
889
+ self.transformer_blocks = nn.ModuleList(
890
+ [
891
+ BasicTransformerBlock(
892
+ inner_dim,
893
+ num_attention_heads,
894
+ attention_head_dim,
895
+ cross_attention_dim=cross_attention_dim,
896
+ upcast_attention=upcast_attention,
897
+ )
898
+ ]
899
+ )
900
+
901
+ if use_linear_projection:
902
+ self.proj_out = nn.Linear(in_channels, inner_dim)
903
+ else:
904
+ self.proj_out = nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
905
+
906
+ def set_use_memory_efficient_attention(self, xformers, mem_eff):
907
+ for transformer in self.transformer_blocks:
908
+ transformer.set_use_memory_efficient_attention(xformers, mem_eff)
909
+
910
+ def set_use_sdpa(self, sdpa):
911
+ for transformer in self.transformer_blocks:
912
+ transformer.set_use_sdpa(sdpa)
913
+
914
+ def forward(self, hidden_states, encoder_hidden_states=None, timestep=None, return_dict: bool = True):
915
+ # 1. Input
916
+ batch, _, height, weight = hidden_states.shape
917
+ residual = hidden_states
918
+
919
+ hidden_states = self.norm(hidden_states)
920
+ if not self.use_linear_projection:
921
+ hidden_states = self.proj_in(hidden_states)
922
+ inner_dim = hidden_states.shape[1]
923
+ hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim)
924
+ else:
925
+ inner_dim = hidden_states.shape[1]
926
+ hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim)
927
+ hidden_states = self.proj_in(hidden_states)
928
+
929
+ # 2. Blocks
930
+ for block in self.transformer_blocks:
931
+ hidden_states = block(hidden_states, context=encoder_hidden_states, timestep=timestep)
932
+
933
+ # 3. Output
934
+ if not self.use_linear_projection:
935
+ hidden_states = hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2).contiguous()
936
+ hidden_states = self.proj_out(hidden_states)
937
+ else:
938
+ hidden_states = self.proj_out(hidden_states)
939
+ hidden_states = hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2).contiguous()
940
+
941
+ output = hidden_states + residual
942
+
943
+ if not return_dict:
944
+ return (output,)
945
+
946
+ return SampleOutput(sample=output)
947
+
948
+
949
+ class CrossAttnDownBlock2D(nn.Module):
950
+ def __init__(
951
+ self,
952
+ in_channels: int,
953
+ out_channels: int,
954
+ add_downsample=True,
955
+ cross_attention_dim=1280,
956
+ attn_num_head_channels=1,
957
+ use_linear_projection=False,
958
+ upcast_attention=False,
959
+ ):
960
+ super().__init__()
961
+ self.has_cross_attention = True
962
+ resnets = []
963
+ attentions = []
964
+
965
+ self.attn_num_head_channels = attn_num_head_channels
966
+
967
+ for i in range(LAYERS_PER_BLOCK):
968
+ in_channels = in_channels if i == 0 else out_channels
969
+
970
+ resnets.append(ResnetBlock2D(in_channels=in_channels, out_channels=out_channels))
971
+ attentions.append(
972
+ Transformer2DModel(
973
+ attn_num_head_channels,
974
+ out_channels // attn_num_head_channels,
975
+ in_channels=out_channels,
976
+ cross_attention_dim=cross_attention_dim,
977
+ use_linear_projection=use_linear_projection,
978
+ upcast_attention=upcast_attention,
979
+ )
980
+ )
981
+ self.attentions = nn.ModuleList(attentions)
982
+ self.resnets = nn.ModuleList(resnets)
983
+
984
+ if add_downsample:
985
+ self.downsamplers = nn.ModuleList([Downsample2D(out_channels, out_channels)])
986
+ else:
987
+ self.downsamplers = None
988
+
989
+ self.gradient_checkpointing = False
990
+
991
+ def set_use_memory_efficient_attention(self, xformers, mem_eff):
992
+ for attn in self.attentions:
993
+ attn.set_use_memory_efficient_attention(xformers, mem_eff)
994
+
995
+ def set_use_sdpa(self, sdpa):
996
+ for attn in self.attentions:
997
+ attn.set_use_sdpa(sdpa)
998
+
999
+ def forward(self, hidden_states, temb=None, encoder_hidden_states=None):
1000
+ output_states = ()
1001
+
1002
+ for resnet, attn in zip(self.resnets, self.attentions):
1003
+ if self.training and self.gradient_checkpointing:
1004
+
1005
+ def create_custom_forward(module, return_dict=None):
1006
+ def custom_forward(*inputs):
1007
+ if return_dict is not None:
1008
+ return module(*inputs, return_dict=return_dict)
1009
+ else:
1010
+ return module(*inputs)
1011
+
1012
+ return custom_forward
1013
+
1014
+ hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
1015
+ hidden_states = torch.utils.checkpoint.checkpoint(
1016
+ create_custom_forward(attn, return_dict=False), hidden_states, encoder_hidden_states
1017
+ )[0]
1018
+ else:
1019
+ hidden_states = resnet(hidden_states, temb)
1020
+ hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states).sample
1021
+
1022
+ output_states += (hidden_states,)
1023
+
1024
+ if self.downsamplers is not None:
1025
+ for downsampler in self.downsamplers:
1026
+ hidden_states = downsampler(hidden_states)
1027
+
1028
+ output_states += (hidden_states,)
1029
+
1030
+ return hidden_states, output_states
1031
+
1032
+
1033
+ class UNetMidBlock2DCrossAttn(nn.Module):
1034
+ def __init__(
1035
+ self,
1036
+ in_channels: int,
1037
+ attn_num_head_channels=1,
1038
+ cross_attention_dim=1280,
1039
+ use_linear_projection=False,
1040
+ ):
1041
+ super().__init__()
1042
+
1043
+ self.has_cross_attention = True
1044
+ self.attn_num_head_channels = attn_num_head_channels
1045
+
1046
+ # Middle block has two resnets and one attention
1047
+ resnets = [
1048
+ ResnetBlock2D(
1049
+ in_channels=in_channels,
1050
+ out_channels=in_channels,
1051
+ ),
1052
+ ResnetBlock2D(
1053
+ in_channels=in_channels,
1054
+ out_channels=in_channels,
1055
+ ),
1056
+ ]
1057
+ attentions = [
1058
+ Transformer2DModel(
1059
+ attn_num_head_channels,
1060
+ in_channels // attn_num_head_channels,
1061
+ in_channels=in_channels,
1062
+ cross_attention_dim=cross_attention_dim,
1063
+ use_linear_projection=use_linear_projection,
1064
+ )
1065
+ ]
1066
+
1067
+ self.attentions = nn.ModuleList(attentions)
1068
+ self.resnets = nn.ModuleList(resnets)
1069
+
1070
+ self.gradient_checkpointing = False
1071
+
1072
+ def set_use_memory_efficient_attention(self, xformers, mem_eff):
1073
+ for attn in self.attentions:
1074
+ attn.set_use_memory_efficient_attention(xformers, mem_eff)
1075
+
1076
+ def set_use_sdpa(self, sdpa):
1077
+ for attn in self.attentions:
1078
+ attn.set_use_sdpa(sdpa)
1079
+
1080
+ def forward(self, hidden_states, temb=None, encoder_hidden_states=None):
1081
+ for i, resnet in enumerate(self.resnets):
1082
+ attn = None if i == 0 else self.attentions[i - 1]
1083
+
1084
+ if self.training and self.gradient_checkpointing:
1085
+
1086
+ def create_custom_forward(module, return_dict=None):
1087
+ def custom_forward(*inputs):
1088
+ if return_dict is not None:
1089
+ return module(*inputs, return_dict=return_dict)
1090
+ else:
1091
+ return module(*inputs)
1092
+
1093
+ return custom_forward
1094
+
1095
+ if attn is not None:
1096
+ hidden_states = torch.utils.checkpoint.checkpoint(
1097
+ create_custom_forward(attn, return_dict=False), hidden_states, encoder_hidden_states
1098
+ )[0]
1099
+
1100
+ hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
1101
+ else:
1102
+ if attn is not None:
1103
+ hidden_states = attn(hidden_states, encoder_hidden_states).sample
1104
+ hidden_states = resnet(hidden_states, temb)
1105
+
1106
+ return hidden_states
1107
+
1108
+
1109
+ class Upsample2D(nn.Module):
1110
+ def __init__(self, channels, out_channels):
1111
+ super().__init__()
1112
+ self.channels = channels
1113
+ self.out_channels = out_channels
1114
+ self.conv = nn.Conv2d(self.channels, self.out_channels, 3, padding=1)
1115
+
1116
+ def forward(self, hidden_states, output_size):
1117
+ assert hidden_states.shape[1] == self.channels
1118
+
1119
+ # Cast to float32 to as 'upsample_nearest2d_out_frame' op does not support bfloat16
1120
+ # TODO(Suraj): Remove this cast once the issue is fixed in PyTorch
1121
+ # https://github.com/pytorch/pytorch/issues/86679
1122
+ dtype = hidden_states.dtype
1123
+ if dtype == torch.bfloat16:
1124
+ hidden_states = hidden_states.to(torch.float32)
1125
+
1126
+ # upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984
1127
+ if hidden_states.shape[0] >= 64:
1128
+ hidden_states = hidden_states.contiguous()
1129
+
1130
+ # if `output_size` is passed we force the interpolation output size and do not make use of `scale_factor=2`
1131
+ if output_size is None:
1132
+ hidden_states = F.interpolate(hidden_states, scale_factor=2.0, mode="nearest")
1133
+ else:
1134
+ hidden_states = F.interpolate(hidden_states, size=output_size, mode="nearest")
1135
+
1136
+ # If the input is bfloat16, we cast back to bfloat16
1137
+ if dtype == torch.bfloat16:
1138
+ hidden_states = hidden_states.to(dtype)
1139
+
1140
+ hidden_states = self.conv(hidden_states)
1141
+
1142
+ return hidden_states
1143
+
1144
+
1145
+ class UpBlock2D(nn.Module):
1146
+ def __init__(
1147
+ self,
1148
+ in_channels: int,
1149
+ prev_output_channel: int,
1150
+ out_channels: int,
1151
+ add_upsample=True,
1152
+ ):
1153
+ super().__init__()
1154
+
1155
+ self.has_cross_attention = False
1156
+ resnets = []
1157
+
1158
+ for i in range(LAYERS_PER_BLOCK_UP):
1159
+ res_skip_channels = in_channels if (i == LAYERS_PER_BLOCK_UP - 1) else out_channels
1160
+ resnet_in_channels = prev_output_channel if i == 0 else out_channels
1161
+
1162
+ resnets.append(
1163
+ ResnetBlock2D(
1164
+ in_channels=resnet_in_channels + res_skip_channels,
1165
+ out_channels=out_channels,
1166
+ )
1167
+ )
1168
+
1169
+ self.resnets = nn.ModuleList(resnets)
1170
+
1171
+ if add_upsample:
1172
+ self.upsamplers = nn.ModuleList([Upsample2D(out_channels, out_channels)])
1173
+ else:
1174
+ self.upsamplers = None
1175
+
1176
+ self.gradient_checkpointing = False
1177
+
1178
+ def set_use_memory_efficient_attention(self, xformers, mem_eff):
1179
+ pass
1180
+
1181
+ def set_use_sdpa(self, sdpa):
1182
+ pass
1183
+
1184
+ def forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None):
1185
+ for resnet in self.resnets:
1186
+ # pop res hidden states
1187
+ res_hidden_states = res_hidden_states_tuple[-1]
1188
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
1189
+
1190
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
1191
+
1192
+ if self.training and self.gradient_checkpointing:
1193
+
1194
+ def create_custom_forward(module):
1195
+ def custom_forward(*inputs):
1196
+ return module(*inputs)
1197
+
1198
+ return custom_forward
1199
+
1200
+ hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
1201
+ else:
1202
+ hidden_states = resnet(hidden_states, temb)
1203
+
1204
+ if self.upsamplers is not None:
1205
+ for upsampler in self.upsamplers:
1206
+ hidden_states = upsampler(hidden_states, upsample_size)
1207
+
1208
+ return hidden_states
1209
+
1210
+
1211
+ class CrossAttnUpBlock2D(nn.Module):
1212
+ def __init__(
1213
+ self,
1214
+ in_channels: int,
1215
+ out_channels: int,
1216
+ prev_output_channel: int,
1217
+ attn_num_head_channels=1,
1218
+ cross_attention_dim=1280,
1219
+ add_upsample=True,
1220
+ use_linear_projection=False,
1221
+ upcast_attention=False,
1222
+ ):
1223
+ super().__init__()
1224
+ resnets = []
1225
+ attentions = []
1226
+
1227
+ self.has_cross_attention = True
1228
+ self.attn_num_head_channels = attn_num_head_channels
1229
+
1230
+ for i in range(LAYERS_PER_BLOCK_UP):
1231
+ res_skip_channels = in_channels if (i == LAYERS_PER_BLOCK_UP - 1) else out_channels
1232
+ resnet_in_channels = prev_output_channel if i == 0 else out_channels
1233
+
1234
+ resnets.append(
1235
+ ResnetBlock2D(
1236
+ in_channels=resnet_in_channels + res_skip_channels,
1237
+ out_channels=out_channels,
1238
+ )
1239
+ )
1240
+ attentions.append(
1241
+ Transformer2DModel(
1242
+ attn_num_head_channels,
1243
+ out_channels // attn_num_head_channels,
1244
+ in_channels=out_channels,
1245
+ cross_attention_dim=cross_attention_dim,
1246
+ use_linear_projection=use_linear_projection,
1247
+ upcast_attention=upcast_attention,
1248
+ )
1249
+ )
1250
+
1251
+ self.attentions = nn.ModuleList(attentions)
1252
+ self.resnets = nn.ModuleList(resnets)
1253
+
1254
+ if add_upsample:
1255
+ self.upsamplers = nn.ModuleList([Upsample2D(out_channels, out_channels)])
1256
+ else:
1257
+ self.upsamplers = None
1258
+
1259
+ self.gradient_checkpointing = False
1260
+
1261
+ def set_use_memory_efficient_attention(self, xformers, mem_eff):
1262
+ for attn in self.attentions:
1263
+ attn.set_use_memory_efficient_attention(xformers, mem_eff)
1264
+
1265
+ def set_use_sdpa(self, spda):
1266
+ for attn in self.attentions:
1267
+ attn.set_use_sdpa(spda)
1268
+
1269
+ def forward(
1270
+ self,
1271
+ hidden_states,
1272
+ res_hidden_states_tuple,
1273
+ temb=None,
1274
+ encoder_hidden_states=None,
1275
+ upsample_size=None,
1276
+ ):
1277
+ for resnet, attn in zip(self.resnets, self.attentions):
1278
+ # pop res hidden states
1279
+ res_hidden_states = res_hidden_states_tuple[-1]
1280
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
1281
+
1282
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
1283
+
1284
+ if self.training and self.gradient_checkpointing:
1285
+
1286
+ def create_custom_forward(module, return_dict=None):
1287
+ def custom_forward(*inputs):
1288
+ if return_dict is not None:
1289
+ return module(*inputs, return_dict=return_dict)
1290
+ else:
1291
+ return module(*inputs)
1292
+
1293
+ return custom_forward
1294
+
1295
+ hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
1296
+ hidden_states = torch.utils.checkpoint.checkpoint(
1297
+ create_custom_forward(attn, return_dict=False), hidden_states, encoder_hidden_states
1298
+ )[0]
1299
+ else:
1300
+ hidden_states = resnet(hidden_states, temb)
1301
+ hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states).sample
1302
+
1303
+ if self.upsamplers is not None:
1304
+ for upsampler in self.upsamplers:
1305
+ hidden_states = upsampler(hidden_states, upsample_size)
1306
+
1307
+ return hidden_states
1308
+
1309
+
1310
+ def get_down_block(
1311
+ down_block_type,
1312
+ in_channels,
1313
+ out_channels,
1314
+ add_downsample,
1315
+ attn_num_head_channels,
1316
+ cross_attention_dim,
1317
+ use_linear_projection,
1318
+ upcast_attention,
1319
+ ):
1320
+ if down_block_type == "DownBlock2D":
1321
+ return DownBlock2D(
1322
+ in_channels=in_channels,
1323
+ out_channels=out_channels,
1324
+ add_downsample=add_downsample,
1325
+ )
1326
+ elif down_block_type == "CrossAttnDownBlock2D":
1327
+ return CrossAttnDownBlock2D(
1328
+ in_channels=in_channels,
1329
+ out_channels=out_channels,
1330
+ add_downsample=add_downsample,
1331
+ cross_attention_dim=cross_attention_dim,
1332
+ attn_num_head_channels=attn_num_head_channels,
1333
+ use_linear_projection=use_linear_projection,
1334
+ upcast_attention=upcast_attention,
1335
+ )
1336
+
1337
+
1338
+ def get_up_block(
1339
+ up_block_type,
1340
+ in_channels,
1341
+ out_channels,
1342
+ prev_output_channel,
1343
+ add_upsample,
1344
+ attn_num_head_channels,
1345
+ cross_attention_dim=None,
1346
+ use_linear_projection=False,
1347
+ upcast_attention=False,
1348
+ ):
1349
+ if up_block_type == "UpBlock2D":
1350
+ return UpBlock2D(
1351
+ in_channels=in_channels,
1352
+ prev_output_channel=prev_output_channel,
1353
+ out_channels=out_channels,
1354
+ add_upsample=add_upsample,
1355
+ )
1356
+ elif up_block_type == "CrossAttnUpBlock2D":
1357
+ return CrossAttnUpBlock2D(
1358
+ in_channels=in_channels,
1359
+ out_channels=out_channels,
1360
+ prev_output_channel=prev_output_channel,
1361
+ attn_num_head_channels=attn_num_head_channels,
1362
+ cross_attention_dim=cross_attention_dim,
1363
+ add_upsample=add_upsample,
1364
+ use_linear_projection=use_linear_projection,
1365
+ upcast_attention=upcast_attention,
1366
+ )
1367
+
1368
+
1369
+ class UNet2DConditionModel(nn.Module):
1370
+ _supports_gradient_checkpointing = True
1371
+
1372
+ def __init__(
1373
+ self,
1374
+ sample_size: Optional[int] = None,
1375
+ attention_head_dim: Union[int, Tuple[int]] = 8,
1376
+ cross_attention_dim: int = 1280,
1377
+ use_linear_projection: bool = False,
1378
+ upcast_attention: bool = False,
1379
+ **kwargs,
1380
+ ):
1381
+ super().__init__()
1382
+ assert sample_size is not None, "sample_size must be specified"
1383
+ print(
1384
+ f"UNet2DConditionModel: {sample_size}, {attention_head_dim}, {cross_attention_dim}, {use_linear_projection}, {upcast_attention}"
1385
+ )
1386
+
1387
+ # 外部からの参照用に定義しておく
1388
+ self.in_channels = IN_CHANNELS
1389
+ self.out_channels = OUT_CHANNELS
1390
+
1391
+ self.sample_size = sample_size
1392
+ self.prepare_config(sample_size=sample_size)
1393
+
1394
+ # state_dictの書式が変わるのでmoduleの持ち方は変えられない
1395
+
1396
+ # input
1397
+ self.conv_in = nn.Conv2d(IN_CHANNELS, BLOCK_OUT_CHANNELS[0], kernel_size=3, padding=(1, 1))
1398
+
1399
+ # time
1400
+ self.time_proj = Timesteps(BLOCK_OUT_CHANNELS[0], TIME_EMBED_FLIP_SIN_TO_COS, TIME_EMBED_FREQ_SHIFT)
1401
+
1402
+ self.time_embedding = TimestepEmbedding(TIMESTEP_INPUT_DIM, TIME_EMBED_DIM)
1403
+
1404
+ self.down_blocks = nn.ModuleList([])
1405
+ self.mid_block = None
1406
+ self.up_blocks = nn.ModuleList([])
1407
+
1408
+ if isinstance(attention_head_dim, int):
1409
+ attention_head_dim = (attention_head_dim,) * 4
1410
+
1411
+ # down
1412
+ output_channel = BLOCK_OUT_CHANNELS[0]
1413
+ for i, down_block_type in enumerate(DOWN_BLOCK_TYPES):
1414
+ input_channel = output_channel
1415
+ output_channel = BLOCK_OUT_CHANNELS[i]
1416
+ is_final_block = i == len(BLOCK_OUT_CHANNELS) - 1
1417
+
1418
+ down_block = get_down_block(
1419
+ down_block_type,
1420
+ in_channels=input_channel,
1421
+ out_channels=output_channel,
1422
+ add_downsample=not is_final_block,
1423
+ attn_num_head_channels=attention_head_dim[i],
1424
+ cross_attention_dim=cross_attention_dim,
1425
+ use_linear_projection=use_linear_projection,
1426
+ upcast_attention=upcast_attention,
1427
+ )
1428
+ self.down_blocks.append(down_block)
1429
+
1430
+ # mid
1431
+ self.mid_block = UNetMidBlock2DCrossAttn(
1432
+ in_channels=BLOCK_OUT_CHANNELS[-1],
1433
+ attn_num_head_channels=attention_head_dim[-1],
1434
+ cross_attention_dim=cross_attention_dim,
1435
+ use_linear_projection=use_linear_projection,
1436
+ )
1437
+
1438
+ # count how many layers upsample the images
1439
+ self.num_upsamplers = 0
1440
+
1441
+ # up
1442
+ reversed_block_out_channels = list(reversed(BLOCK_OUT_CHANNELS))
1443
+ reversed_attention_head_dim = list(reversed(attention_head_dim))
1444
+ output_channel = reversed_block_out_channels[0]
1445
+ for i, up_block_type in enumerate(UP_BLOCK_TYPES):
1446
+ is_final_block = i == len(BLOCK_OUT_CHANNELS) - 1
1447
+
1448
+ prev_output_channel = output_channel
1449
+ output_channel = reversed_block_out_channels[i]
1450
+ input_channel = reversed_block_out_channels[min(i + 1, len(BLOCK_OUT_CHANNELS) - 1)]
1451
+
1452
+ # add upsample block for all BUT final layer
1453
+ if not is_final_block:
1454
+ add_upsample = True
1455
+ self.num_upsamplers += 1
1456
+ else:
1457
+ add_upsample = False
1458
+
1459
+ up_block = get_up_block(
1460
+ up_block_type,
1461
+ in_channels=input_channel,
1462
+ out_channels=output_channel,
1463
+ prev_output_channel=prev_output_channel,
1464
+ add_upsample=add_upsample,
1465
+ attn_num_head_channels=reversed_attention_head_dim[i],
1466
+ cross_attention_dim=cross_attention_dim,
1467
+ use_linear_projection=use_linear_projection,
1468
+ upcast_attention=upcast_attention,
1469
+ )
1470
+ self.up_blocks.append(up_block)
1471
+ prev_output_channel = output_channel
1472
+
1473
+ # out
1474
+ self.conv_norm_out = nn.GroupNorm(num_channels=BLOCK_OUT_CHANNELS[0], num_groups=NORM_GROUPS, eps=NORM_EPS)
1475
+ self.conv_act = nn.SiLU()
1476
+ self.conv_out = nn.Conv2d(BLOCK_OUT_CHANNELS[0], OUT_CHANNELS, kernel_size=3, padding=1)
1477
+
1478
+ # region diffusers compatibility
1479
+ def prepare_config(self, *args, **kwargs):
1480
+ self.config = SimpleNamespace(**kwargs)
1481
+
1482
+ @property
1483
+ def dtype(self) -> torch.dtype:
1484
+ # `torch.dtype`: The dtype of the module (assuming that all the module parameters have the same dtype).
1485
+ return get_parameter_dtype(self)
1486
+
1487
+ @property
1488
+ def device(self) -> torch.device:
1489
+ # `torch.device`: The device on which the module is (assuming that all the module parameters are on the same device).
1490
+ return get_parameter_device(self)
1491
+
1492
+ def set_attention_slice(self, slice_size):
1493
+ raise NotImplementedError("Attention slicing is not supported for this model.")
1494
+
1495
+ def is_gradient_checkpointing(self) -> bool:
1496
+ return any(hasattr(m, "gradient_checkpointing") and m.gradient_checkpointing for m in self.modules())
1497
+
1498
+ def enable_gradient_checkpointing(self):
1499
+ self.set_gradient_checkpointing(value=True)
1500
+
1501
+ def disable_gradient_checkpointing(self):
1502
+ self.set_gradient_checkpointing(value=False)
1503
+
1504
+ def set_use_memory_efficient_attention(self, xformers: bool, mem_eff: bool) -> None:
1505
+ modules = self.down_blocks + [self.mid_block] + self.up_blocks
1506
+ for module in modules:
1507
+ module.set_use_memory_efficient_attention(xformers, mem_eff)
1508
+
1509
+ def set_use_sdpa(self, sdpa: bool) -> None:
1510
+ modules = self.down_blocks + [self.mid_block] + self.up_blocks
1511
+ for module in modules:
1512
+ module.set_use_sdpa(sdpa)
1513
+
1514
+ def set_gradient_checkpointing(self, value=False):
1515
+ modules = self.down_blocks + [self.mid_block] + self.up_blocks
1516
+ for module in modules:
1517
+ print(module.__class__.__name__, module.gradient_checkpointing, "->", value)
1518
+ module.gradient_checkpointing = value
1519
+
1520
+ # endregion
1521
+
1522
+ def forward(
1523
+ self,
1524
+ sample: torch.FloatTensor,
1525
+ timestep: Union[torch.Tensor, float, int],
1526
+ encoder_hidden_states: torch.Tensor,
1527
+ class_labels: Optional[torch.Tensor] = None,
1528
+ return_dict: bool = True,
1529
+ down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
1530
+ mid_block_additional_residual: Optional[torch.Tensor] = None,
1531
+ ) -> Union[Dict, Tuple]:
1532
+ r"""
1533
+ Args:
1534
+ sample (`torch.FloatTensor`): (batch, channel, height, width) noisy inputs tensor
1535
+ timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps
1536
+ encoder_hidden_states (`torch.FloatTensor`): (batch, sequence_length, feature_dim) encoder hidden states
1537
+ return_dict (`bool`, *optional*, defaults to `True`):
1538
+ Whether or not to return a dict instead of a plain tuple.
1539
+
1540
+ Returns:
1541
+ `SampleOutput` or `tuple`:
1542
+ `SampleOutput` if `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is the sample tensor.
1543
+ """
1544
+ # By default samples have to be AT least a multiple of the overall upsampling factor.
1545
+ # The overall upsampling factor is equal to 2 ** (# num of upsampling layears).
1546
+ # However, the upsampling interpolation output size can be forced to fit any upsampling size
1547
+ # on the fly if necessary.
1548
+ # デフォルトではサンプルは「2^アップサンプルの数」、つまり64の倍数である必要がある
1549
+ # ただそれ以外のサイズにも対応できるように、必要ならアップサンプルのサイズを変更する
1550
+ # 多分画質が悪くなるので、64で割り切れるようにしておくのが良い
1551
+ default_overall_up_factor = 2**self.num_upsamplers
1552
+
1553
+ # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
1554
+ # 64で割り切れないときはupsamplerにサイズを伝える
1555
+ forward_upsample_size = False
1556
+ upsample_size = None
1557
+
1558
+ if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
1559
+ # logger.info("Forward upsample size to force interpolation output size.")
1560
+ forward_upsample_size = True
1561
+
1562
+ # 1. time
1563
+ timesteps = timestep
1564
+ timesteps = self.handle_unusual_timesteps(sample, timesteps) # 変な時だけ処理
1565
+
1566
+ t_emb = self.time_proj(timesteps)
1567
+
1568
+ # timesteps does not contain any weights and will always return f32 tensors
1569
+ # but time_embedding might actually be running in fp16. so we need to cast here.
1570
+ # there might be better ways to encapsulate this.
1571
+ # timestepsは重みを含まないので常にfloat32のテンソルを返す
1572
+ # しかしtime_embeddingはfp16で動いているかもしれないので、ここでキャストする必要がある
1573
+ # time_projでキャストしておけばいいんじゃね?
1574
+ t_emb = t_emb.to(dtype=self.dtype)
1575
+ emb = self.time_embedding(t_emb)
1576
+
1577
+ # 2. pre-process
1578
+ sample = self.conv_in(sample)
1579
+
1580
+ down_block_res_samples = (sample,)
1581
+ for downsample_block in self.down_blocks:
1582
+ # downblockはforwardで必ずencoder_hidden_statesを受け取るようにしても良さそうだけど、
1583
+ # まあこちらのほうがわかりやすいかもしれない
1584
+ if downsample_block.has_cross_attention:
1585
+ sample, res_samples = downsample_block(
1586
+ hidden_states=sample,
1587
+ temb=emb,
1588
+ encoder_hidden_states=encoder_hidden_states,
1589
+ )
1590
+ else:
1591
+ sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
1592
+
1593
+ down_block_res_samples += res_samples
1594
+
1595
+ # skip connectionにControlNetの出力を追加する
1596
+ if down_block_additional_residuals is not None:
1597
+ down_block_res_samples = list(down_block_res_samples)
1598
+ for i in range(len(down_block_res_samples)):
1599
+ down_block_res_samples[i] += down_block_additional_residuals[i]
1600
+ down_block_res_samples = tuple(down_block_res_samples)
1601
+
1602
+ # 4. mid
1603
+ sample = self.mid_block(sample, emb, encoder_hidden_states=encoder_hidden_states)
1604
+
1605
+ # ControlNetの出力を追加する
1606
+ if mid_block_additional_residual is not None:
1607
+ sample += mid_block_additional_residual
1608
+
1609
+ # 5. up
1610
+ for i, upsample_block in enumerate(self.up_blocks):
1611
+ is_final_block = i == len(self.up_blocks) - 1
1612
+
1613
+ res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
1614
+ down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)] # skip connection
1615
+
1616
+ # if we have not reached the final block and need to forward the upsample size, we do it here
1617
+ # 前述のように最後のブロック以外ではupsample_sizeを伝える
1618
+ if not is_final_block and forward_upsample_size:
1619
+ upsample_size = down_block_res_samples[-1].shape[2:]
1620
+
1621
+ if upsample_block.has_cross_attention:
1622
+ sample = upsample_block(
1623
+ hidden_states=sample,
1624
+ temb=emb,
1625
+ res_hidden_states_tuple=res_samples,
1626
+ encoder_hidden_states=encoder_hidden_states,
1627
+ upsample_size=upsample_size,
1628
+ )
1629
+ else:
1630
+ sample = upsample_block(
1631
+ hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size
1632
+ )
1633
+
1634
+ # 6. post-process
1635
+ sample = self.conv_norm_out(sample)
1636
+ sample = self.conv_act(sample)
1637
+ sample = self.conv_out(sample)
1638
+
1639
+ if not return_dict:
1640
+ return (sample,)
1641
+
1642
+ return SampleOutput(sample=sample)
1643
+
1644
+ def handle_unusual_timesteps(self, sample, timesteps):
1645
+ r"""
1646
+ timestampsがTensorでない場合、Tensorに変換する。またOnnx/Core MLと互換性のあるようにbatchサイズまでbroadcastする。
1647
+ """
1648
+ if not torch.is_tensor(timesteps):
1649
+ # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
1650
+ # This would be a good case for the `match` statement (Python 3.10+)
1651
+ is_mps = sample.device.type == "mps"
1652
+ if isinstance(timesteps, float):
1653
+ dtype = torch.float32 if is_mps else torch.float64
1654
+ else:
1655
+ dtype = torch.int32 if is_mps else torch.int64
1656
+ timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
1657
+ elif len(timesteps.shape) == 0:
1658
+ timesteps = timesteps[None].to(sample.device)
1659
+
1660
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
1661
+ timesteps = timesteps.expand(sample.shape[0])
1662
+
1663
+ return timesteps
1664
+
1665
+
1666
+ class InferUNet2DConditionModel:
1667
+ def __init__(self, original_unet: UNet2DConditionModel):
1668
+ self.delegate = original_unet
1669
+
1670
+ # override original model's forward method: because forward is not called by `__call__`
1671
+ # overriding `__call__` is not enough, because nn.Module.forward has a special handling
1672
+ self.delegate.forward = self.forward
1673
+
1674
+ # override original model's up blocks' forward method
1675
+ for up_block in self.delegate.up_blocks:
1676
+ if up_block.__class__.__name__ == "UpBlock2D":
1677
+
1678
+ def resnet_wrapper(func, block):
1679
+ def forward(*args, **kwargs):
1680
+ return func(block, *args, **kwargs)
1681
+
1682
+ return forward
1683
+
1684
+ up_block.forward = resnet_wrapper(self.up_block_forward, up_block)
1685
+
1686
+ elif up_block.__class__.__name__ == "CrossAttnUpBlock2D":
1687
+
1688
+ def cross_attn_up_wrapper(func, block):
1689
+ def forward(*args, **kwargs):
1690
+ return func(block, *args, **kwargs)
1691
+
1692
+ return forward
1693
+
1694
+ up_block.forward = cross_attn_up_wrapper(self.cross_attn_up_block_forward, up_block)
1695
+
1696
+ # Deep Shrink
1697
+ self.ds_depth_1 = None
1698
+ self.ds_depth_2 = None
1699
+ self.ds_timesteps_1 = None
1700
+ self.ds_timesteps_2 = None
1701
+ self.ds_ratio = None
1702
+
1703
+ # call original model's methods
1704
+ def __getattr__(self, name):
1705
+ return getattr(self.delegate, name)
1706
+
1707
+ def __call__(self, *args, **kwargs):
1708
+ return self.delegate(*args, **kwargs)
1709
+
1710
+ def set_deep_shrink(self, ds_depth_1, ds_timesteps_1=650, ds_depth_2=None, ds_timesteps_2=None, ds_ratio=0.5):
1711
+ if ds_depth_1 is None:
1712
+ print("Deep Shrink is disabled.")
1713
+ self.ds_depth_1 = None
1714
+ self.ds_timesteps_1 = None
1715
+ self.ds_depth_2 = None
1716
+ self.ds_timesteps_2 = None
1717
+ self.ds_ratio = None
1718
+ else:
1719
+ print(
1720
+ f"Deep Shrink is enabled: [depth={ds_depth_1}/{ds_depth_2}, timesteps={ds_timesteps_1}/{ds_timesteps_2}, ratio={ds_ratio}]"
1721
+ )
1722
+ self.ds_depth_1 = ds_depth_1
1723
+ self.ds_timesteps_1 = ds_timesteps_1
1724
+ self.ds_depth_2 = ds_depth_2 if ds_depth_2 is not None else -1
1725
+ self.ds_timesteps_2 = ds_timesteps_2 if ds_timesteps_2 is not None else 1000
1726
+ self.ds_ratio = ds_ratio
1727
+
1728
+ def up_block_forward(self, _self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None):
1729
+ for resnet in _self.resnets:
1730
+ # pop res hidden states
1731
+ res_hidden_states = res_hidden_states_tuple[-1]
1732
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
1733
+
1734
+ # Deep Shrink
1735
+ if res_hidden_states.shape[-2:] != hidden_states.shape[-2:]:
1736
+ hidden_states = resize_like(hidden_states, res_hidden_states)
1737
+
1738
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
1739
+ hidden_states = resnet(hidden_states, temb)
1740
+
1741
+ if _self.upsamplers is not None:
1742
+ for upsampler in _self.upsamplers:
1743
+ hidden_states = upsampler(hidden_states, upsample_size)
1744
+
1745
+ return hidden_states
1746
+
1747
+ def cross_attn_up_block_forward(
1748
+ self,
1749
+ _self,
1750
+ hidden_states,
1751
+ res_hidden_states_tuple,
1752
+ temb=None,
1753
+ encoder_hidden_states=None,
1754
+ upsample_size=None,
1755
+ ):
1756
+ for resnet, attn in zip(_self.resnets, _self.attentions):
1757
+ # pop res hidden states
1758
+ res_hidden_states = res_hidden_states_tuple[-1]
1759
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
1760
+
1761
+ # Deep Shrink
1762
+ if res_hidden_states.shape[-2:] != hidden_states.shape[-2:]:
1763
+ hidden_states = resize_like(hidden_states, res_hidden_states)
1764
+
1765
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
1766
+ hidden_states = resnet(hidden_states, temb)
1767
+ hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states).sample
1768
+
1769
+ if _self.upsamplers is not None:
1770
+ for upsampler in _self.upsamplers:
1771
+ hidden_states = upsampler(hidden_states, upsample_size)
1772
+
1773
+ return hidden_states
1774
+
1775
+ def forward(
1776
+ self,
1777
+ sample: torch.FloatTensor,
1778
+ timestep: Union[torch.Tensor, float, int],
1779
+ encoder_hidden_states: torch.Tensor,
1780
+ class_labels: Optional[torch.Tensor] = None,
1781
+ return_dict: bool = True,
1782
+ down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
1783
+ mid_block_additional_residual: Optional[torch.Tensor] = None,
1784
+ ) -> Union[Dict, Tuple]:
1785
+ r"""
1786
+ current implementation is a copy of `UNet2DConditionModel.forward()` with Deep Shrink.
1787
+ """
1788
+
1789
+ r"""
1790
+ Args:
1791
+ sample (`torch.FloatTensor`): (batch, channel, height, width) noisy inputs tensor
1792
+ timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps
1793
+ encoder_hidden_states (`torch.FloatTensor`): (batch, sequence_length, feature_dim) encoder hidden states
1794
+ return_dict (`bool`, *optional*, defaults to `True`):
1795
+ Whether or not to return a dict instead of a plain tuple.
1796
+
1797
+ Returns:
1798
+ `SampleOutput` or `tuple`:
1799
+ `SampleOutput` if `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is the sample tensor.
1800
+ """
1801
+
1802
+ _self = self.delegate
1803
+
1804
+ # By default samples have to be AT least a multiple of the overall upsampling factor.
1805
+ # The overall upsampling factor is equal to 2 ** (# num of upsampling layears).
1806
+ # However, the upsampling interpolation output size can be forced to fit any upsampling size
1807
+ # on the fly if necessary.
1808
+ # デフォルトではサンプルは「2^アップサンプルの数」、つまり64の倍数である必要がある
1809
+ # ただそれ以外のサイズにも対応できるように、必要ならアップサンプルのサイズを変更する
1810
+ # 多分画質が悪くなるので、64で割り切れるようにしておくのが良い
1811
+ default_overall_up_factor = 2**_self.num_upsamplers
1812
+
1813
+ # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
1814
+ # 64で割り切れないときはupsamplerにサイズを伝える
1815
+ forward_upsample_size = False
1816
+ upsample_size = None
1817
+
1818
+ if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
1819
+ # logger.info("Forward upsample size to force interpolation output size.")
1820
+ forward_upsample_size = True
1821
+
1822
+ # 1. time
1823
+ timesteps = timestep
1824
+ timesteps = _self.handle_unusual_timesteps(sample, timesteps) # 変な時だけ処理
1825
+
1826
+ t_emb = _self.time_proj(timesteps)
1827
+
1828
+ # timesteps does not contain any weights and will always return f32 tensors
1829
+ # but time_embedding might actually be running in fp16. so we need to cast here.
1830
+ # there might be better ways to encapsulate this.
1831
+ # timestepsは重みを含まないので常にfloat32のテンソルを返す
1832
+ # しかしtime_embeddingはfp16で動いているかもしれないので、ここでキャストする必要がある
1833
+ # time_projでキャストしておけばいいんじゃね?
1834
+ t_emb = t_emb.to(dtype=_self.dtype)
1835
+ emb = _self.time_embedding(t_emb)
1836
+
1837
+ # 2. pre-process
1838
+ sample = _self.conv_in(sample)
1839
+
1840
+ down_block_res_samples = (sample,)
1841
+ for depth, downsample_block in enumerate(_self.down_blocks):
1842
+ # Deep Shrink
1843
+ if self.ds_depth_1 is not None:
1844
+ if (depth == self.ds_depth_1 and timesteps[0] >= self.ds_timesteps_1) or (
1845
+ self.ds_depth_2 is not None
1846
+ and depth == self.ds_depth_2
1847
+ and timesteps[0] < self.ds_timesteps_1
1848
+ and timesteps[0] >= self.ds_timesteps_2
1849
+ ):
1850
+ org_dtype = sample.dtype
1851
+ if org_dtype == torch.bfloat16:
1852
+ sample = sample.to(torch.float32)
1853
+ sample = F.interpolate(sample, scale_factor=self.ds_ratio, mode="bicubic", align_corners=False).to(org_dtype)
1854
+
1855
+ # downblockはforwardで必ずencoder_hidden_statesを受け取るようにしても良さそうだけど、
1856
+ # まあこちらのほうがわかりやすいかもしれない
1857
+ if downsample_block.has_cross_attention:
1858
+ sample, res_samples = downsample_block(
1859
+ hidden_states=sample,
1860
+ temb=emb,
1861
+ encoder_hidden_states=encoder_hidden_states,
1862
+ )
1863
+ else:
1864
+ sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
1865
+
1866
+ down_block_res_samples += res_samples
1867
+
1868
+ # skip connectionにControlNetの出力を追加する
1869
+ if down_block_additional_residuals is not None:
1870
+ down_block_res_samples = list(down_block_res_samples)
1871
+ for i in range(len(down_block_res_samples)):
1872
+ down_block_res_samples[i] += down_block_additional_residuals[i]
1873
+ down_block_res_samples = tuple(down_block_res_samples)
1874
+
1875
+ # 4. mid
1876
+ sample = _self.mid_block(sample, emb, encoder_hidden_states=encoder_hidden_states)
1877
+
1878
+ # ControlNetの出力を追加する
1879
+ if mid_block_additional_residual is not None:
1880
+ sample += mid_block_additional_residual
1881
+
1882
+ # 5. up
1883
+ for i, upsample_block in enumerate(_self.up_blocks):
1884
+ is_final_block = i == len(_self.up_blocks) - 1
1885
+
1886
+ res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
1887
+ down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)] # skip connection
1888
+
1889
+ # if we have not reached the final block and need to forward the upsample size, we do it here
1890
+ # 前述のように最後のブロック以外ではupsample_sizeを伝える
1891
+ if not is_final_block and forward_upsample_size:
1892
+ upsample_size = down_block_res_samples[-1].shape[2:]
1893
+
1894
+ if upsample_block.has_cross_attention:
1895
+ sample = upsample_block(
1896
+ hidden_states=sample,
1897
+ temb=emb,
1898
+ res_hidden_states_tuple=res_samples,
1899
+ encoder_hidden_states=encoder_hidden_states,
1900
+ upsample_size=upsample_size,
1901
+ )
1902
+ else:
1903
+ sample = upsample_block(
1904
+ hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size
1905
+ )
1906
+
1907
+ # 6. post-process
1908
+ sample = _self.conv_norm_out(sample)
1909
+ sample = _self.conv_act(sample)
1910
+ sample = _self.conv_out(sample)
1911
+
1912
+ if not return_dict:
1913
+ return (sample,)
1914
+
1915
+ return SampleOutput(sample=sample)
external/llite/library/sai_model_spec.py ADDED
@@ -0,0 +1,305 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # based on https://github.com/Stability-AI/ModelSpec
2
+ import datetime
3
+ import hashlib
4
+ from io import BytesIO
5
+ import os
6
+ from typing import List, Optional, Tuple, Union
7
+ import safetensors
8
+
9
+ r"""
10
+ # Metadata Example
11
+ metadata = {
12
+ # === Must ===
13
+ "modelspec.sai_model_spec": "1.0.0", # Required version ID for the spec
14
+ "modelspec.architecture": "stable-diffusion-xl-v1-base", # Architecture, reference the ID of the original model of the arch to match the ID
15
+ "modelspec.implementation": "sgm",
16
+ "modelspec.title": "Example Model Version 1.0", # Clean, human-readable title. May use your own phrasing/language/etc
17
+ # === Should ===
18
+ "modelspec.author": "Example Corp", # Your name or company name
19
+ "modelspec.description": "This is my example model to show you how to do it!", # Describe the model in your own words/language/etc. Focus on what users need to know
20
+ "modelspec.date": "2023-07-20", # ISO-8601 compliant date of when the model was created
21
+ # === Can ===
22
+ "modelspec.license": "ExampleLicense-1.0", # eg CreativeML Open RAIL, etc.
23
+ "modelspec.usage_hint": "Use keyword 'example'" # In your own language, very short hints about how the user should use the model
24
+ }
25
+ """
26
+
27
+ BASE_METADATA = {
28
+ # === Must ===
29
+ "modelspec.sai_model_spec": "1.0.0", # Required version ID for the spec
30
+ "modelspec.architecture": None,
31
+ "modelspec.implementation": None,
32
+ "modelspec.title": None,
33
+ "modelspec.resolution": None,
34
+ # === Should ===
35
+ "modelspec.description": None,
36
+ "modelspec.author": None,
37
+ "modelspec.date": None,
38
+ # === Can ===
39
+ "modelspec.license": None,
40
+ "modelspec.tags": None,
41
+ "modelspec.merged_from": None,
42
+ "modelspec.prediction_type": None,
43
+ "modelspec.timestep_range": None,
44
+ "modelspec.encoder_layer": None,
45
+ }
46
+
47
+ # 別に使うやつだけ定義
48
+ MODELSPEC_TITLE = "modelspec.title"
49
+
50
+ ARCH_SD_V1 = "stable-diffusion-v1"
51
+ ARCH_SD_V2_512 = "stable-diffusion-v2-512"
52
+ ARCH_SD_V2_768_V = "stable-diffusion-v2-768-v"
53
+ ARCH_SD_XL_V1_BASE = "stable-diffusion-xl-v1-base"
54
+
55
+ ADAPTER_LORA = "lora"
56
+ ADAPTER_TEXTUAL_INVERSION = "textual-inversion"
57
+
58
+ IMPL_STABILITY_AI = "https://github.com/Stability-AI/generative-models"
59
+ IMPL_DIFFUSERS = "diffusers"
60
+
61
+ PRED_TYPE_EPSILON = "epsilon"
62
+ PRED_TYPE_V = "v"
63
+
64
+
65
+ def load_bytes_in_safetensors(tensors):
66
+ bytes = safetensors.torch.save(tensors)
67
+ b = BytesIO(bytes)
68
+
69
+ b.seek(0)
70
+ header = b.read(8)
71
+ n = int.from_bytes(header, "little")
72
+
73
+ offset = n + 8
74
+ b.seek(offset)
75
+
76
+ return b.read()
77
+
78
+
79
+ def precalculate_safetensors_hashes(state_dict):
80
+ # calculate each tensor one by one to reduce memory usage
81
+ hash_sha256 = hashlib.sha256()
82
+ for tensor in state_dict.values():
83
+ single_tensor_sd = {"tensor": tensor}
84
+ bytes_for_tensor = load_bytes_in_safetensors(single_tensor_sd)
85
+ hash_sha256.update(bytes_for_tensor)
86
+
87
+ return f"0x{hash_sha256.hexdigest()}"
88
+
89
+
90
+ def update_hash_sha256(metadata: dict, state_dict: dict):
91
+ raise NotImplementedError
92
+
93
+
94
+ def build_metadata(
95
+ state_dict: Optional[dict],
96
+ v2: bool,
97
+ v_parameterization: bool,
98
+ sdxl: bool,
99
+ lora: bool,
100
+ textual_inversion: bool,
101
+ timestamp: float,
102
+ title: Optional[str] = None,
103
+ reso: Optional[Union[int, Tuple[int, int]]] = None,
104
+ is_stable_diffusion_ckpt: Optional[bool] = None,
105
+ author: Optional[str] = None,
106
+ description: Optional[str] = None,
107
+ license: Optional[str] = None,
108
+ tags: Optional[str] = None,
109
+ merged_from: Optional[str] = None,
110
+ timesteps: Optional[Tuple[int, int]] = None,
111
+ clip_skip: Optional[int] = None,
112
+ ):
113
+ # if state_dict is None, hash is not calculated
114
+
115
+ metadata = {}
116
+ metadata.update(BASE_METADATA)
117
+
118
+ # TODO メモリを消費せずかつ正しいハッシュ計算の方法がわかったら実装する
119
+ # if state_dict is not None:
120
+ # hash = precalculate_safetensors_hashes(state_dict)
121
+ # metadata["modelspec.hash_sha256"] = hash
122
+
123
+ if sdxl:
124
+ arch = ARCH_SD_XL_V1_BASE
125
+ elif v2:
126
+ if v_parameterization:
127
+ arch = ARCH_SD_V2_768_V
128
+ else:
129
+ arch = ARCH_SD_V2_512
130
+ else:
131
+ arch = ARCH_SD_V1
132
+
133
+ if lora:
134
+ arch += f"/{ADAPTER_LORA}"
135
+ elif textual_inversion:
136
+ arch += f"/{ADAPTER_TEXTUAL_INVERSION}"
137
+
138
+ metadata["modelspec.architecture"] = arch
139
+
140
+ if not lora and not textual_inversion and is_stable_diffusion_ckpt is None:
141
+ is_stable_diffusion_ckpt = True # default is stable diffusion ckpt if not lora and not textual_inversion
142
+
143
+ if (lora and sdxl) or textual_inversion or is_stable_diffusion_ckpt:
144
+ # Stable Diffusion ckpt, TI, SDXL LoRA
145
+ impl = IMPL_STABILITY_AI
146
+ else:
147
+ # v1/v2 LoRA or Diffusers
148
+ impl = IMPL_DIFFUSERS
149
+ metadata["modelspec.implementation"] = impl
150
+
151
+ if title is None:
152
+ if lora:
153
+ title = "LoRA"
154
+ elif textual_inversion:
155
+ title = "TextualInversion"
156
+ else:
157
+ title = "Checkpoint"
158
+ title += f"@{timestamp}"
159
+ metadata[MODELSPEC_TITLE] = title
160
+
161
+ if author is not None:
162
+ metadata["modelspec.author"] = author
163
+ else:
164
+ del metadata["modelspec.author"]
165
+
166
+ if description is not None:
167
+ metadata["modelspec.description"] = description
168
+ else:
169
+ del metadata["modelspec.description"]
170
+
171
+ if merged_from is not None:
172
+ metadata["modelspec.merged_from"] = merged_from
173
+ else:
174
+ del metadata["modelspec.merged_from"]
175
+
176
+ if license is not None:
177
+ metadata["modelspec.license"] = license
178
+ else:
179
+ del metadata["modelspec.license"]
180
+
181
+ if tags is not None:
182
+ metadata["modelspec.tags"] = tags
183
+ else:
184
+ del metadata["modelspec.tags"]
185
+
186
+ # remove microsecond from time
187
+ int_ts = int(timestamp)
188
+
189
+ # time to iso-8601 compliant date
190
+ date = datetime.datetime.fromtimestamp(int_ts).isoformat()
191
+ metadata["modelspec.date"] = date
192
+
193
+ if reso is not None:
194
+ # comma separated to tuple
195
+ if isinstance(reso, str):
196
+ reso = tuple(map(int, reso.split(",")))
197
+ if len(reso) == 1:
198
+ reso = (reso[0], reso[0])
199
+ else:
200
+ # resolution is defined in dataset, so use default
201
+ if sdxl:
202
+ reso = 1024
203
+ elif v2 and v_parameterization:
204
+ reso = 768
205
+ else:
206
+ reso = 512
207
+ if isinstance(reso, int):
208
+ reso = (reso, reso)
209
+
210
+ metadata["modelspec.resolution"] = f"{reso[0]}x{reso[1]}"
211
+
212
+ if v_parameterization:
213
+ metadata["modelspec.prediction_type"] = PRED_TYPE_V
214
+ else:
215
+ metadata["modelspec.prediction_type"] = PRED_TYPE_EPSILON
216
+
217
+ if timesteps is not None:
218
+ if isinstance(timesteps, str) or isinstance(timesteps, int):
219
+ timesteps = (timesteps, timesteps)
220
+ if len(timesteps) == 1:
221
+ timesteps = (timesteps[0], timesteps[0])
222
+ metadata["modelspec.timestep_range"] = f"{timesteps[0]},{timesteps[1]}"
223
+ else:
224
+ del metadata["modelspec.timestep_range"]
225
+
226
+ if clip_skip is not None:
227
+ metadata["modelspec.encoder_layer"] = f"{clip_skip}"
228
+ else:
229
+ del metadata["modelspec.encoder_layer"]
230
+
231
+ # # assert all values are filled
232
+ # assert all([v is not None for v in metadata.values()]), metadata
233
+ if not all([v is not None for v in metadata.values()]):
234
+ print(f"Internal error: some metadata values are None: {metadata}")
235
+
236
+ return metadata
237
+
238
+
239
+ # region utils
240
+
241
+
242
+ def get_title(metadata: dict) -> Optional[str]:
243
+ return metadata.get(MODELSPEC_TITLE, None)
244
+
245
+
246
+ def load_metadata_from_safetensors(model: str) -> dict:
247
+ if not model.endswith(".safetensors"):
248
+ return {}
249
+
250
+ with safetensors.safe_open(model, framework="pt") as f:
251
+ metadata = f.metadata()
252
+ if metadata is None:
253
+ metadata = {}
254
+ return metadata
255
+
256
+
257
+ def build_merged_from(models: List[str]) -> str:
258
+ def get_title(model: str):
259
+ metadata = load_metadata_from_safetensors(model)
260
+ title = metadata.get(MODELSPEC_TITLE, None)
261
+ if title is None:
262
+ title = os.path.splitext(os.path.basename(model))[0] # use filename
263
+ return title
264
+
265
+ titles = [get_title(model) for model in models]
266
+ return ", ".join(titles)
267
+
268
+
269
+ # endregion
270
+
271
+
272
+ r"""
273
+ if __name__ == "__main__":
274
+ import argparse
275
+ import torch
276
+ from safetensors.torch import load_file
277
+ from library import train_util
278
+
279
+ parser = argparse.ArgumentParser()
280
+ parser.add_argument("--ckpt", type=str, required=True)
281
+ args = parser.parse_args()
282
+
283
+ print(f"Loading {args.ckpt}")
284
+ state_dict = load_file(args.ckpt)
285
+
286
+ print(f"Calculating metadata")
287
+ metadata = get(state_dict, False, False, False, False, "sgm", False, False, "title", "date", 256, 1000, 0)
288
+ print(metadata)
289
+ del state_dict
290
+
291
+ # by reference implementation
292
+ with open(args.ckpt, mode="rb") as file_data:
293
+ file_hash = hashlib.sha256()
294
+ head_len = struct.unpack("Q", file_data.read(8)) # int64 header length prefix
295
+ header = json.loads(file_data.read(head_len[0])) # header itself, json string
296
+ content = (
297
+ file_data.read()
298
+ ) # All other content is tightly packed tensors. Copy to RAM for simplicity, but you can avoid this read with a more careful FS-dependent impl.
299
+ file_hash.update(content)
300
+ # ===== Update the hash for modelspec =====
301
+ by_ref = f"0x{file_hash.hexdigest()}"
302
+ print(by_ref)
303
+ print("is same?", by_ref == metadata["modelspec.hash_sha256"])
304
+
305
+ """
external/llite/library/sdxl_lpw_stable_diffusion.py ADDED
@@ -0,0 +1,1342 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # copy from https://github.com/huggingface/diffusers/blob/main/examples/community/lpw_stable_diffusion.py
2
+ # and modify to support SD2.x
3
+
4
+ import inspect
5
+ import re
6
+ from typing import Callable, List, Optional, Union
7
+
8
+ import numpy as np
9
+ import PIL.Image
10
+ import torch
11
+ from packaging import version
12
+ from tqdm import tqdm
13
+ from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
14
+
15
+ from diffusers import SchedulerMixin, StableDiffusionPipeline
16
+ from diffusers.models import AutoencoderKL, UNet2DConditionModel
17
+ from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput, StableDiffusionSafetyChecker
18
+ from diffusers.utils import logging
19
+ from PIL import Image
20
+
21
+ from external.llite.library import sdxl_model_util, sdxl_train_util, train_util
22
+
23
+
24
+ try:
25
+ from diffusers.utils import PIL_INTERPOLATION
26
+ except ImportError:
27
+ if version.parse(version.parse(PIL.__version__).base_version) >= version.parse("9.1.0"):
28
+ PIL_INTERPOLATION = {
29
+ "linear": PIL.Image.Resampling.BILINEAR,
30
+ "bilinear": PIL.Image.Resampling.BILINEAR,
31
+ "bicubic": PIL.Image.Resampling.BICUBIC,
32
+ "lanczos": PIL.Image.Resampling.LANCZOS,
33
+ "nearest": PIL.Image.Resampling.NEAREST,
34
+ }
35
+ else:
36
+ PIL_INTERPOLATION = {
37
+ "linear": PIL.Image.LINEAR,
38
+ "bilinear": PIL.Image.BILINEAR,
39
+ "bicubic": PIL.Image.BICUBIC,
40
+ "lanczos": PIL.Image.LANCZOS,
41
+ "nearest": PIL.Image.NEAREST,
42
+ }
43
+ # ------------------------------------------------------------------------------
44
+
45
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
46
+
47
+ re_attention = re.compile(
48
+ r"""
49
+ \\\(|
50
+ \\\)|
51
+ \\\[|
52
+ \\]|
53
+ \\\\|
54
+ \\|
55
+ \(|
56
+ \[|
57
+ :([+-]?[.\d]+)\)|
58
+ \)|
59
+ ]|
60
+ [^\\()\[\]:]+|
61
+ :
62
+ """,
63
+ re.X,
64
+ )
65
+
66
+
67
+ def parse_prompt_attention(text):
68
+ """
69
+ Parses a string with attention tokens and returns a list of pairs: text and its associated weight.
70
+ Accepted tokens are:
71
+ (abc) - increases attention to abc by a multiplier of 1.1
72
+ (abc:3.12) - increases attention to abc by a multiplier of 3.12
73
+ [abc] - decreases attention to abc by a multiplier of 1.1
74
+ \( - literal character '('
75
+ \[ - literal character '['
76
+ \) - literal character ')'
77
+ \] - literal character ']'
78
+ \\ - literal character '\'
79
+ anything else - just text
80
+ >>> parse_prompt_attention('normal text')
81
+ [['normal text', 1.0]]
82
+ >>> parse_prompt_attention('an (important) word')
83
+ [['an ', 1.0], ['important', 1.1], [' word', 1.0]]
84
+ >>> parse_prompt_attention('(unbalanced')
85
+ [['unbalanced', 1.1]]
86
+ >>> parse_prompt_attention('\(literal\]')
87
+ [['(literal]', 1.0]]
88
+ >>> parse_prompt_attention('(unnecessary)(parens)')
89
+ [['unnecessaryparens', 1.1]]
90
+ >>> parse_prompt_attention('a (((house:1.3)) [on] a (hill:0.5), sun, (((sky))).')
91
+ [['a ', 1.0],
92
+ ['house', 1.5730000000000004],
93
+ [' ', 1.1],
94
+ ['on', 1.0],
95
+ [' a ', 1.1],
96
+ ['hill', 0.55],
97
+ [', sun, ', 1.1],
98
+ ['sky', 1.4641000000000006],
99
+ ['.', 1.1]]
100
+ """
101
+
102
+ res = []
103
+ round_brackets = []
104
+ square_brackets = []
105
+
106
+ round_bracket_multiplier = 1.1
107
+ square_bracket_multiplier = 1 / 1.1
108
+
109
+ def multiply_range(start_position, multiplier):
110
+ for p in range(start_position, len(res)):
111
+ res[p][1] *= multiplier
112
+
113
+ for m in re_attention.finditer(text):
114
+ text = m.group(0)
115
+ weight = m.group(1)
116
+
117
+ if text.startswith("\\"):
118
+ res.append([text[1:], 1.0])
119
+ elif text == "(":
120
+ round_brackets.append(len(res))
121
+ elif text == "[":
122
+ square_brackets.append(len(res))
123
+ elif weight is not None and len(round_brackets) > 0:
124
+ multiply_range(round_brackets.pop(), float(weight))
125
+ elif text == ")" and len(round_brackets) > 0:
126
+ multiply_range(round_brackets.pop(), round_bracket_multiplier)
127
+ elif text == "]" and len(square_brackets) > 0:
128
+ multiply_range(square_brackets.pop(), square_bracket_multiplier)
129
+ else:
130
+ res.append([text, 1.0])
131
+
132
+ for pos in round_brackets:
133
+ multiply_range(pos, round_bracket_multiplier)
134
+
135
+ for pos in square_brackets:
136
+ multiply_range(pos, square_bracket_multiplier)
137
+
138
+ if len(res) == 0:
139
+ res = [["", 1.0]]
140
+
141
+ # merge runs of identical weights
142
+ i = 0
143
+ while i + 1 < len(res):
144
+ if res[i][1] == res[i + 1][1]:
145
+ res[i][0] += res[i + 1][0]
146
+ res.pop(i + 1)
147
+ else:
148
+ i += 1
149
+
150
+ return res
151
+
152
+
153
+ def get_prompts_with_weights(pipe: StableDiffusionPipeline, prompt: List[str], max_length: int):
154
+ r"""
155
+ Tokenize a list of prompts and return its tokens with weights of each token.
156
+
157
+ No padding, starting or ending token is included.
158
+ """
159
+ tokens = []
160
+ weights = []
161
+ truncated = False
162
+ for text in prompt:
163
+ texts_and_weights = parse_prompt_attention(text)
164
+ text_token = []
165
+ text_weight = []
166
+ for word, weight in texts_and_weights:
167
+ # tokenize and discard the starting and the ending token
168
+ token = pipe.tokenizer(word).input_ids[1:-1]
169
+ text_token += token
170
+ # copy the weight by length of token
171
+ text_weight += [weight] * len(token)
172
+ # stop if the text is too long (longer than truncation limit)
173
+ if len(text_token) > max_length:
174
+ truncated = True
175
+ break
176
+ # truncate
177
+ if len(text_token) > max_length:
178
+ truncated = True
179
+ text_token = text_token[:max_length]
180
+ text_weight = text_weight[:max_length]
181
+ tokens.append(text_token)
182
+ weights.append(text_weight)
183
+ if truncated:
184
+ logger.warning("Prompt was truncated. Try to shorten the prompt or increase max_embeddings_multiples")
185
+ return tokens, weights
186
+
187
+
188
+ def pad_tokens_and_weights(tokens, weights, max_length, bos, eos, pad, no_boseos_middle=True, chunk_length=77):
189
+ r"""
190
+ Pad the tokens (with starting and ending tokens) and weights (with 1.0) to max_length.
191
+ """
192
+ max_embeddings_multiples = (max_length - 2) // (chunk_length - 2)
193
+ weights_length = max_length if no_boseos_middle else max_embeddings_multiples * chunk_length
194
+ for i in range(len(tokens)):
195
+ tokens[i] = [bos] + tokens[i] + [eos] + [pad] * (max_length - 2 - len(tokens[i]))
196
+ if no_boseos_middle:
197
+ weights[i] = [1.0] + weights[i] + [1.0] * (max_length - 1 - len(weights[i]))
198
+ else:
199
+ w = []
200
+ if len(weights[i]) == 0:
201
+ w = [1.0] * weights_length
202
+ else:
203
+ for j in range(max_embeddings_multiples):
204
+ w.append(1.0) # weight for starting token in this chunk
205
+ w += weights[i][j * (chunk_length - 2) : min(len(weights[i]), (j + 1) * (chunk_length - 2))]
206
+ w.append(1.0) # weight for ending token in this chunk
207
+ w += [1.0] * (weights_length - len(w))
208
+ weights[i] = w[:]
209
+
210
+ return tokens, weights
211
+
212
+
213
+ def get_hidden_states(text_encoder, input_ids, is_sdxl_text_encoder2: bool, eos_token_id, device):
214
+ if not is_sdxl_text_encoder2:
215
+ # text_encoder1: same as SD1/2
216
+ enc_out = text_encoder(input_ids.to(text_encoder.device), output_hidden_states=True, return_dict=True)
217
+ hidden_states = enc_out["hidden_states"][11]
218
+ pool = None
219
+ else:
220
+ # text_encoder2
221
+ enc_out = text_encoder(input_ids.to(text_encoder.device), output_hidden_states=True, return_dict=True)
222
+ hidden_states = enc_out["hidden_states"][-2] # penuultimate layer
223
+ # pool = enc_out["text_embeds"]
224
+ pool = train_util.pool_workaround(text_encoder, enc_out["last_hidden_state"], input_ids, eos_token_id)
225
+ hidden_states = hidden_states.to(device)
226
+ if pool is not None:
227
+ pool = pool.to(device)
228
+ return hidden_states, pool
229
+
230
+
231
+ def get_unweighted_text_embeddings(
232
+ pipe: StableDiffusionPipeline,
233
+ text_input: torch.Tensor,
234
+ chunk_length: int,
235
+ clip_skip: int,
236
+ eos: int,
237
+ pad: int,
238
+ is_sdxl_text_encoder2: bool,
239
+ no_boseos_middle: Optional[bool] = True,
240
+ ):
241
+ """
242
+ When the length of tokens is a multiple of the capacity of the text encoder,
243
+ it should be split into chunks and sent to the text encoder individually.
244
+ """
245
+ max_embeddings_multiples = (text_input.shape[1] - 2) // (chunk_length - 2)
246
+ text_pool = None
247
+ if max_embeddings_multiples > 1:
248
+ text_embeddings = []
249
+ for i in range(max_embeddings_multiples):
250
+ # extract the i-th chunk
251
+ text_input_chunk = text_input[:, i * (chunk_length - 2) : (i + 1) * (chunk_length - 2) + 2].clone()
252
+
253
+ # cover the head and the tail by the starting and the ending tokens
254
+ text_input_chunk[:, 0] = text_input[0, 0]
255
+ if pad == eos: # v1
256
+ text_input_chunk[:, -1] = text_input[0, -1]
257
+ else: # v2
258
+ for j in range(len(text_input_chunk)):
259
+ if text_input_chunk[j, -1] != eos and text_input_chunk[j, -1] != pad: # 最後に普通の文字がある
260
+ text_input_chunk[j, -1] = eos
261
+ if text_input_chunk[j, 1] == pad: # BOSだけであとはPAD
262
+ text_input_chunk[j, 1] = eos
263
+
264
+ text_embedding, current_text_pool = get_hidden_states(
265
+ pipe.text_encoder, text_input_chunk, is_sdxl_text_encoder2, eos, pipe.device
266
+ )
267
+ if text_pool is None:
268
+ text_pool = current_text_pool
269
+
270
+ if no_boseos_middle:
271
+ if i == 0:
272
+ # discard the ending token
273
+ text_embedding = text_embedding[:, :-1]
274
+ elif i == max_embeddings_multiples - 1:
275
+ # discard the starting token
276
+ text_embedding = text_embedding[:, 1:]
277
+ else:
278
+ # discard both starting and ending tokens
279
+ text_embedding = text_embedding[:, 1:-1]
280
+
281
+ text_embeddings.append(text_embedding)
282
+ text_embeddings = torch.concat(text_embeddings, axis=1)
283
+ else:
284
+ text_embeddings, text_pool = get_hidden_states(pipe.text_encoder, text_input, is_sdxl_text_encoder2, eos, pipe.device)
285
+ return text_embeddings, text_pool
286
+
287
+
288
+ def get_weighted_text_embeddings(
289
+ pipe, # : SdxlStableDiffusionLongPromptWeightingPipeline,
290
+ prompt: Union[str, List[str]],
291
+ uncond_prompt: Optional[Union[str, List[str]]] = None,
292
+ max_embeddings_multiples: Optional[int] = 3,
293
+ no_boseos_middle: Optional[bool] = False,
294
+ skip_parsing: Optional[bool] = False,
295
+ skip_weighting: Optional[bool] = False,
296
+ clip_skip=None,
297
+ is_sdxl_text_encoder2=False,
298
+ ):
299
+ r"""
300
+ Prompts can be assigned with local weights using brackets. For example,
301
+ prompt 'A (very beautiful) masterpiece' highlights the words 'very beautiful',
302
+ and the embedding tokens corresponding to the words get multiplied by a constant, 1.1.
303
+
304
+ Also, to regularize of the embedding, the weighted embedding would be scaled to preserve the original mean.
305
+
306
+ Args:
307
+ pipe (`StableDiffusionPipeline`):
308
+ Pipe to provide access to the tokenizer and the text encoder.
309
+ prompt (`str` or `List[str]`):
310
+ The prompt or prompts to guide the image generation.
311
+ uncond_prompt (`str` or `List[str]`):
312
+ The unconditional prompt or prompts for guide the image generation. If unconditional prompt
313
+ is provided, the embeddings of prompt and uncond_prompt are concatenated.
314
+ max_embeddings_multiples (`int`, *optional*, defaults to `3`):
315
+ The max multiple length of prompt embeddings compared to the max output length of text encoder.
316
+ no_boseos_middle (`bool`, *optional*, defaults to `False`):
317
+ If the length of text token is multiples of the capacity of text encoder, whether reserve the starting and
318
+ ending token in each of the chunk in the middle.
319
+ skip_parsing (`bool`, *optional*, defaults to `False`):
320
+ Skip the parsing of brackets.
321
+ skip_weighting (`bool`, *optional*, defaults to `False`):
322
+ Skip the weighting. When the parsing is skipped, it is forced True.
323
+ """
324
+ max_length = (pipe.tokenizer.model_max_length - 2) * max_embeddings_multiples + 2
325
+ if isinstance(prompt, str):
326
+ prompt = [prompt]
327
+
328
+ if not skip_parsing:
329
+ prompt_tokens, prompt_weights = get_prompts_with_weights(pipe, prompt, max_length - 2)
330
+ if uncond_prompt is not None:
331
+ if isinstance(uncond_prompt, str):
332
+ uncond_prompt = [uncond_prompt]
333
+ uncond_tokens, uncond_weights = get_prompts_with_weights(pipe, uncond_prompt, max_length - 2)
334
+ else:
335
+ prompt_tokens = [token[1:-1] for token in pipe.tokenizer(prompt, max_length=max_length, truncation=True).input_ids]
336
+ prompt_weights = [[1.0] * len(token) for token in prompt_tokens]
337
+ if uncond_prompt is not None:
338
+ if isinstance(uncond_prompt, str):
339
+ uncond_prompt = [uncond_prompt]
340
+ uncond_tokens = [
341
+ token[1:-1] for token in pipe.tokenizer(uncond_prompt, max_length=max_length, truncation=True).input_ids
342
+ ]
343
+ uncond_weights = [[1.0] * len(token) for token in uncond_tokens]
344
+
345
+ # round up the longest length of tokens to a multiple of (model_max_length - 2)
346
+ max_length = max([len(token) for token in prompt_tokens])
347
+ if uncond_prompt is not None:
348
+ max_length = max(max_length, max([len(token) for token in uncond_tokens]))
349
+
350
+ max_embeddings_multiples = min(
351
+ max_embeddings_multiples,
352
+ (max_length - 1) // (pipe.tokenizer.model_max_length - 2) + 1,
353
+ )
354
+ max_embeddings_multiples = max(1, max_embeddings_multiples)
355
+ max_length = (pipe.tokenizer.model_max_length - 2) * max_embeddings_multiples + 2
356
+
357
+ # pad the length of tokens and weights
358
+ bos = pipe.tokenizer.bos_token_id
359
+ eos = pipe.tokenizer.eos_token_id
360
+ pad = pipe.tokenizer.pad_token_id
361
+ prompt_tokens, prompt_weights = pad_tokens_and_weights(
362
+ prompt_tokens,
363
+ prompt_weights,
364
+ max_length,
365
+ bos,
366
+ eos,
367
+ pad,
368
+ no_boseos_middle=no_boseos_middle,
369
+ chunk_length=pipe.tokenizer.model_max_length,
370
+ )
371
+ prompt_tokens = torch.tensor(prompt_tokens, dtype=torch.long, device=pipe.device)
372
+ if uncond_prompt is not None:
373
+ uncond_tokens, uncond_weights = pad_tokens_and_weights(
374
+ uncond_tokens,
375
+ uncond_weights,
376
+ max_length,
377
+ bos,
378
+ eos,
379
+ pad,
380
+ no_boseos_middle=no_boseos_middle,
381
+ chunk_length=pipe.tokenizer.model_max_length,
382
+ )
383
+ uncond_tokens = torch.tensor(uncond_tokens, dtype=torch.long, device=pipe.device)
384
+
385
+ # get the embeddings
386
+ text_embeddings, text_pool = get_unweighted_text_embeddings(
387
+ pipe,
388
+ prompt_tokens,
389
+ pipe.tokenizer.model_max_length,
390
+ clip_skip,
391
+ eos,
392
+ pad,
393
+ is_sdxl_text_encoder2,
394
+ no_boseos_middle=no_boseos_middle,
395
+ )
396
+ prompt_weights = torch.tensor(prompt_weights, dtype=text_embeddings.dtype, device=pipe.device)
397
+
398
+ if uncond_prompt is not None:
399
+ uncond_embeddings, uncond_pool = get_unweighted_text_embeddings(
400
+ pipe,
401
+ uncond_tokens,
402
+ pipe.tokenizer.model_max_length,
403
+ clip_skip,
404
+ eos,
405
+ pad,
406
+ is_sdxl_text_encoder2,
407
+ no_boseos_middle=no_boseos_middle,
408
+ )
409
+ uncond_weights = torch.tensor(uncond_weights, dtype=uncond_embeddings.dtype, device=pipe.device)
410
+
411
+ # assign weights to the prompts and normalize in the sense of mean
412
+ # TODO: should we normalize by chunk or in a whole (current implementation)?
413
+ if (not skip_parsing) and (not skip_weighting):
414
+ previous_mean = text_embeddings.float().mean(axis=[-2, -1]).to(text_embeddings.dtype)
415
+ text_embeddings *= prompt_weights.unsqueeze(-1)
416
+ current_mean = text_embeddings.float().mean(axis=[-2, -1]).to(text_embeddings.dtype)
417
+ text_embeddings *= (previous_mean / current_mean).unsqueeze(-1).unsqueeze(-1)
418
+ if uncond_prompt is not None:
419
+ previous_mean = uncond_embeddings.float().mean(axis=[-2, -1]).to(uncond_embeddings.dtype)
420
+ uncond_embeddings *= uncond_weights.unsqueeze(-1)
421
+ current_mean = uncond_embeddings.float().mean(axis=[-2, -1]).to(uncond_embeddings.dtype)
422
+ uncond_embeddings *= (previous_mean / current_mean).unsqueeze(-1).unsqueeze(-1)
423
+
424
+ if uncond_prompt is not None:
425
+ return text_embeddings, text_pool, uncond_embeddings, uncond_pool
426
+ return text_embeddings, text_pool, None, None
427
+
428
+
429
+ def preprocess_image(image):
430
+ w, h = image.size
431
+ w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32
432
+ image = image.resize((w, h), resample=PIL_INTERPOLATION["lanczos"])
433
+ image = np.array(image).astype(np.float32) / 255.0
434
+ image = image[None].transpose(0, 3, 1, 2)
435
+ image = torch.from_numpy(image)
436
+ return 2.0 * image - 1.0
437
+
438
+
439
+ def preprocess_mask(mask, scale_factor=8):
440
+ mask = mask.convert("L")
441
+ w, h = mask.size
442
+ w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32
443
+ mask = mask.resize((w // scale_factor, h // scale_factor), resample=PIL_INTERPOLATION["nearest"])
444
+ mask = np.array(mask).astype(np.float32) / 255.0
445
+ mask = np.tile(mask, (4, 1, 1))
446
+ mask = mask[None].transpose(0, 1, 2, 3) # what does this step do?
447
+ mask = 1 - mask # repaint white, keep black
448
+ mask = torch.from_numpy(mask)
449
+ return mask
450
+
451
+
452
+ def prepare_controlnet_image(
453
+ image: PIL.Image.Image,
454
+ width: int,
455
+ height: int,
456
+ batch_size: int,
457
+ num_images_per_prompt: int,
458
+ device: torch.device,
459
+ dtype: torch.dtype,
460
+ do_classifier_free_guidance: bool = False,
461
+ guess_mode: bool = False,
462
+ ):
463
+ if not isinstance(image, torch.Tensor):
464
+ if isinstance(image, PIL.Image.Image):
465
+ image = [image]
466
+
467
+ if isinstance(image[0], PIL.Image.Image):
468
+ images = []
469
+
470
+ for image_ in image:
471
+ image_ = image_.convert("RGB")
472
+ image_ = image_.resize((width, height), resample=PIL_INTERPOLATION["lanczos"])
473
+ image_ = np.array(image_)
474
+ image_ = image_[None, :]
475
+ images.append(image_)
476
+
477
+ image = images
478
+
479
+ image = np.concatenate(image, axis=0)
480
+ image = np.array(image).astype(np.float32) / 255.0
481
+ image = image.transpose(0, 3, 1, 2)
482
+ image = torch.from_numpy(image)
483
+ elif isinstance(image[0], torch.Tensor):
484
+ image = torch.cat(image, dim=0)
485
+
486
+ image_batch_size = image.shape[0]
487
+
488
+ if image_batch_size == 1:
489
+ repeat_by = batch_size
490
+ else:
491
+ # image batch size is the same as prompt batch size
492
+ repeat_by = num_images_per_prompt
493
+
494
+ image = image.repeat_interleave(repeat_by, dim=0)
495
+
496
+ image = image.to(device=device, dtype=dtype)
497
+
498
+ if do_classifier_free_guidance and not guess_mode:
499
+ image = torch.cat([image] * 2)
500
+
501
+ return image
502
+
503
+
504
+ class SdxlStableDiffusionLongPromptWeightingPipeline:
505
+ r"""
506
+ Pipeline for text-to-image generation using Stable Diffusion without tokens length limit, and support parsing
507
+ weighting in prompt.
508
+
509
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
510
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
511
+
512
+ Args:
513
+ vae ([`AutoencoderKL`]):
514
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
515
+ text_encoder ([`CLIPTextModel`]):
516
+ Frozen text-encoder. Stable Diffusion uses the text portion of
517
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
518
+ the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
519
+ tokenizer (`CLIPTokenizer`):
520
+ Tokenizer of class
521
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
522
+ unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
523
+ scheduler ([`SchedulerMixin`]):
524
+ A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
525
+ [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
526
+ safety_checker ([`StableDiffusionSafetyChecker`]):
527
+ Classification module that estimates whether generated images could be considered offensive or harmful.
528
+ Please, refer to the [model card](https://huggingface.co/CompVis/stable-diffusion-v1-4) for details.
529
+ feature_extractor ([`CLIPFeatureExtractor`]):
530
+ Model that extracts features from generated images to be used as inputs for the `safety_checker`.
531
+ """
532
+
533
+ # if version.parse(version.parse(diffusers.__version__).base_version) >= version.parse("0.9.0"):
534
+
535
+ def __init__(
536
+ self,
537
+ vae: AutoencoderKL,
538
+ text_encoder: List[CLIPTextModel],
539
+ tokenizer: List[CLIPTokenizer],
540
+ unet: UNet2DConditionModel,
541
+ scheduler: SchedulerMixin,
542
+ # clip_skip: int,
543
+ safety_checker: StableDiffusionSafetyChecker,
544
+ feature_extractor: CLIPFeatureExtractor,
545
+ requires_safety_checker: bool = True,
546
+ clip_skip: int = 1,
547
+ ):
548
+ # clip skip is ignored currently
549
+ self.tokenizer = tokenizer[0]
550
+ self.text_encoder = text_encoder[0]
551
+ self.unet = unet
552
+ self.scheduler = scheduler
553
+ self.safety_checker = safety_checker
554
+ self.feature_extractor = feature_extractor
555
+ self.requires_safety_checker = requires_safety_checker
556
+ self.vae = vae
557
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
558
+ self.progress_bar = lambda x: tqdm(x, leave=False)
559
+
560
+ self.clip_skip = clip_skip
561
+ self.tokenizers = tokenizer
562
+ self.text_encoders = text_encoder
563
+
564
+ # self.__init__additional__()
565
+
566
+ # def __init__additional__(self):
567
+ # if not hasattr(self, "vae_scale_factor"):
568
+ # setattr(self, "vae_scale_factor", 2 ** (len(self.vae.config.block_out_channels) - 1))
569
+
570
+ def to(self, device=None, dtype=None):
571
+ if device is not None:
572
+ self.device = device
573
+ # self.vae.to(device=self.device)
574
+ if dtype is not None:
575
+ self.dtype = dtype
576
+
577
+ # do not move Text Encoders to device, because Text Encoder should be on CPU
578
+
579
+ @property
580
+ def _execution_device(self):
581
+ r"""
582
+ Returns the device on which the pipeline's models will be executed. After calling
583
+ `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module
584
+ hooks.
585
+ """
586
+ if self.device != torch.device("meta") or not hasattr(self.unet, "_hf_hook"):
587
+ return self.device
588
+ for module in self.unet.modules():
589
+ if (
590
+ hasattr(module, "_hf_hook")
591
+ and hasattr(module._hf_hook, "execution_device")
592
+ and module._hf_hook.execution_device is not None
593
+ ):
594
+ return torch.device(module._hf_hook.execution_device)
595
+ return self.device
596
+
597
+ def _encode_prompt(
598
+ self,
599
+ prompt,
600
+ device,
601
+ num_images_per_prompt,
602
+ do_classifier_free_guidance,
603
+ negative_prompt,
604
+ max_embeddings_multiples,
605
+ is_sdxl_text_encoder2,
606
+ ):
607
+ r"""
608
+ Encodes the prompt into text encoder hidden states.
609
+
610
+ Args:
611
+ prompt (`str` or `list(int)`):
612
+ prompt to be encoded
613
+ device: (`torch.device`):
614
+ torch device
615
+ num_images_per_prompt (`int`):
616
+ number of images that should be generated per prompt
617
+ do_classifier_free_guidance (`bool`):
618
+ whether to use classifier free guidance or not
619
+ negative_prompt (`str` or `List[str]`):
620
+ The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
621
+ if `guidance_scale` is less than `1`).
622
+ max_embeddings_multiples (`int`, *optional*, defaults to `3`):
623
+ The max multiple length of prompt embeddings compared to the max output length of text encoder.
624
+ """
625
+ batch_size = len(prompt) if isinstance(prompt, list) else 1
626
+
627
+ if negative_prompt is None:
628
+ negative_prompt = [""] * batch_size
629
+ elif isinstance(negative_prompt, str):
630
+ negative_prompt = [negative_prompt] * batch_size
631
+ if batch_size != len(negative_prompt):
632
+ raise ValueError(
633
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
634
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
635
+ " the batch size of `prompt`."
636
+ )
637
+
638
+ text_embeddings, text_pool, uncond_embeddings, uncond_pool = get_weighted_text_embeddings(
639
+ pipe=self,
640
+ prompt=prompt,
641
+ uncond_prompt=negative_prompt if do_classifier_free_guidance else None,
642
+ max_embeddings_multiples=max_embeddings_multiples,
643
+ clip_skip=self.clip_skip,
644
+ is_sdxl_text_encoder2=is_sdxl_text_encoder2,
645
+ )
646
+ bs_embed, seq_len, _ = text_embeddings.shape
647
+ text_embeddings = text_embeddings.repeat(1, num_images_per_prompt, 1) # ??
648
+ text_embeddings = text_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1)
649
+ if text_pool is not None:
650
+ text_pool = text_pool.repeat(1, num_images_per_prompt)
651
+ text_pool = text_pool.view(bs_embed * num_images_per_prompt, -1)
652
+
653
+ if do_classifier_free_guidance:
654
+ bs_embed, seq_len, _ = uncond_embeddings.shape
655
+ uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt, 1)
656
+ uncond_embeddings = uncond_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1)
657
+ if uncond_pool is not None:
658
+ uncond_pool = uncond_pool.repeat(1, num_images_per_prompt)
659
+ uncond_pool = uncond_pool.view(bs_embed * num_images_per_prompt, -1)
660
+
661
+ return text_embeddings, text_pool, uncond_embeddings, uncond_pool
662
+
663
+ return text_embeddings, text_pool, None, None
664
+
665
+ def check_inputs(self, prompt, height, width, strength, callback_steps):
666
+ if not isinstance(prompt, str) and not isinstance(prompt, list):
667
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
668
+
669
+ if strength < 0 or strength > 1:
670
+ raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")
671
+
672
+ if height % 8 != 0 or width % 8 != 0:
673
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
674
+
675
+ if (callback_steps is None) or (
676
+ callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
677
+ ):
678
+ raise ValueError(
679
+ f"`callback_steps` has to be a positive integer but is {callback_steps} of type" f" {type(callback_steps)}."
680
+ )
681
+
682
+ def get_timesteps(self, num_inference_steps, strength, device, is_text2img):
683
+ if is_text2img:
684
+ return self.scheduler.timesteps.to(device), num_inference_steps
685
+ else:
686
+ # get the original timestep using init_timestep
687
+ offset = self.scheduler.config.get("steps_offset", 0)
688
+ init_timestep = int(num_inference_steps * strength) + offset
689
+ init_timestep = min(init_timestep, num_inference_steps)
690
+
691
+ t_start = max(num_inference_steps - init_timestep + offset, 0)
692
+ timesteps = self.scheduler.timesteps[t_start:].to(device)
693
+ return timesteps, num_inference_steps - t_start
694
+
695
+ def run_safety_checker(self, image, device, dtype):
696
+ if self.safety_checker is not None:
697
+ safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device)
698
+ image, has_nsfw_concept = self.safety_checker(images=image, clip_input=safety_checker_input.pixel_values.to(dtype))
699
+ else:
700
+ has_nsfw_concept = None
701
+ return image, has_nsfw_concept
702
+
703
+ def decode_latents(self, latents):
704
+ with torch.no_grad():
705
+ latents = 1 / sdxl_model_util.VAE_SCALE_FACTOR * latents
706
+
707
+ # print("post_quant_conv dtype:", self.vae.post_quant_conv.weight.dtype) # torch.float32
708
+ # x = torch.nn.functional.conv2d(latents, self.vae.post_quant_conv.weight.detach(), stride=1, padding=0)
709
+ # print("latents dtype:", latents.dtype, "x dtype:", x.dtype) # torch.float32, torch.float16
710
+ # self.vae.to("cpu")
711
+ # self.vae.set_use_memory_efficient_attention_xformers(False)
712
+ # image = self.vae.decode(latents.to("cpu")).sample
713
+
714
+ image = self.vae.decode(latents.to(self.vae.dtype)).sample
715
+ image = (image / 2 + 0.5).clamp(0, 1)
716
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
717
+ image = image.cpu().permute(0, 2, 3, 1).float().numpy()
718
+ return image
719
+
720
+ def prepare_extra_step_kwargs(self, generator, eta):
721
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
722
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
723
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
724
+ # and should be between [0, 1]
725
+
726
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
727
+ extra_step_kwargs = {}
728
+ if accepts_eta:
729
+ extra_step_kwargs["eta"] = eta
730
+
731
+ # check if the scheduler accepts generator
732
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
733
+ if accepts_generator:
734
+ extra_step_kwargs["generator"] = generator
735
+ return extra_step_kwargs
736
+
737
+ def prepare_latents(self, image, timestep, batch_size, height, width, dtype, device, generator, latents=None):
738
+ if image is None:
739
+ shape = (
740
+ batch_size,
741
+ self.unet.in_channels,
742
+ height // self.vae_scale_factor,
743
+ width // self.vae_scale_factor,
744
+ )
745
+
746
+ if latents is None:
747
+ if device.type == "mps":
748
+ # randn does not work reproducibly on mps
749
+ latents = torch.randn(shape, generator=generator, device="cpu", dtype=dtype).to(device)
750
+ else:
751
+ latents = torch.randn(shape, generator=generator, device=device, dtype=dtype)
752
+ else:
753
+ if latents.shape != shape:
754
+ raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
755
+ latents = latents.to(device)
756
+
757
+ # scale the initial noise by the standard deviation required by the scheduler
758
+ latents = latents * self.scheduler.init_noise_sigma
759
+ return latents, None, None
760
+ else:
761
+ init_latent_dist = self.vae.encode(image).latent_dist
762
+ init_latents = init_latent_dist.sample(generator=generator)
763
+ init_latents = sdxl_model_util.VAE_SCALE_FACTOR * init_latents
764
+ init_latents = torch.cat([init_latents] * batch_size, dim=0)
765
+ init_latents_orig = init_latents
766
+ shape = init_latents.shape
767
+
768
+ # add noise to latents using the timesteps
769
+ if device.type == "mps":
770
+ noise = torch.randn(shape, generator=generator, device="cpu", dtype=dtype).to(device)
771
+ else:
772
+ noise = torch.randn(shape, generator=generator, device=device, dtype=dtype)
773
+ latents = self.scheduler.add_noise(init_latents, noise, timestep)
774
+ return latents, init_latents_orig, noise
775
+
776
+ @torch.no_grad()
777
+ def __call__(
778
+ self,
779
+ prompt: Union[str, List[str]],
780
+ negative_prompt: Optional[Union[str, List[str]]] = None,
781
+ image: Union[torch.FloatTensor, PIL.Image.Image] = None,
782
+ mask_image: Union[torch.FloatTensor, PIL.Image.Image] = None,
783
+ height: int = 512,
784
+ width: int = 512,
785
+ num_inference_steps: int = 50,
786
+ guidance_scale: float = 7.5,
787
+ strength: float = 0.8,
788
+ num_images_per_prompt: Optional[int] = 1,
789
+ eta: float = 0.0,
790
+ generator: Optional[torch.Generator] = None,
791
+ latents: Optional[torch.FloatTensor] = None,
792
+ max_embeddings_multiples: Optional[int] = 3,
793
+ output_type: Optional[str] = "pil",
794
+ return_dict: bool = True,
795
+ controlnet=None,
796
+ controlnet_image=None,
797
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
798
+ is_cancelled_callback: Optional[Callable[[], bool]] = None,
799
+ callback_steps: int = 1,
800
+ ):
801
+ r"""
802
+ Function invoked when calling the pipeline for generation.
803
+
804
+ Args:
805
+ prompt (`str` or `List[str]`):
806
+ The prompt or prompts to guide the image generation.
807
+ negative_prompt (`str` or `List[str]`, *optional*):
808
+ The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
809
+ if `guidance_scale` is less than `1`).
810
+ image (`torch.FloatTensor` or `PIL.Image.Image`):
811
+ `Image`, or tensor representing an image batch, that will be used as the starting point for the
812
+ process.
813
+ mask_image (`torch.FloatTensor` or `PIL.Image.Image`):
814
+ `Image`, or tensor representing an image batch, to mask `image`. White pixels in the mask will be
815
+ replaced by noise and therefore repainted, while black pixels will be preserved. If `mask_image` is a
816
+ PIL image, it will be converted to a single channel (luminance) before use. If it's a tensor, it should
817
+ contain one color channel (L) instead of 3, so the expected shape would be `(B, H, W, 1)`.
818
+ height (`int`, *optional*, defaults to 512):
819
+ The height in pixels of the generated image.
820
+ width (`int`, *optional*, defaults to 512):
821
+ The width in pixels of the generated image.
822
+ num_inference_steps (`int`, *optional*, defaults to 50):
823
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
824
+ expense of slower inference.
825
+ guidance_scale (`float`, *optional*, defaults to 7.5):
826
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
827
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
828
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
829
+ 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
830
+ usually at the expense of lower image quality.
831
+ strength (`float`, *optional*, defaults to 0.8):
832
+ Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1.
833
+ `image` will be used as a starting point, adding more noise to it the larger the `strength`. The
834
+ number of denoising steps depends on the amount of noise initially added. When `strength` is 1, added
835
+ noise will be maximum and the denoising process will run for the full number of iterations specified in
836
+ `num_inference_steps`. A value of 1, therefore, essentially ignores `image`.
837
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
838
+ The number of images to generate per prompt.
839
+ eta (`float`, *optional*, defaults to 0.0):
840
+ Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
841
+ [`schedulers.DDIMScheduler`], will be ignored for others.
842
+ generator (`torch.Generator`, *optional*):
843
+ A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
844
+ deterministic.
845
+ latents (`torch.FloatTensor`, *optional*):
846
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
847
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
848
+ tensor will ge generated by sampling using the supplied random `generator`.
849
+ max_embeddings_multiples (`int`, *optional*, defaults to `3`):
850
+ The max multiple length of prompt embeddings compared to the max output length of text encoder.
851
+ output_type (`str`, *optional*, defaults to `"pil"`):
852
+ The output format of the generate image. Choose between
853
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
854
+ return_dict (`bool`, *optional*, defaults to `True`):
855
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
856
+ plain tuple.
857
+ controlnet (`diffusers.ControlNetModel`, *optional*):
858
+ A controlnet model to be used for the inference. If not provided, controlnet will be disabled.
859
+ controlnet_image (`torch.FloatTensor` or `PIL.Image.Image`, *optional*):
860
+ `Image`, or tensor representing an image batch, to be used as the starting point for the controlnet
861
+ inference.
862
+ callback (`Callable`, *optional*):
863
+ A function that will be called every `callback_steps` steps during inference. The function will be
864
+ called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
865
+ is_cancelled_callback (`Callable`, *optional*):
866
+ A function that will be called every `callback_steps` steps during inference. If the function returns
867
+ `True`, the inference will be cancelled.
868
+ callback_steps (`int`, *optional*, defaults to 1):
869
+ The frequency at which the `callback` function will be called. If not specified, the callback will be
870
+ called at every step.
871
+
872
+ Returns:
873
+ `None` if cancelled by `is_cancelled_callback`,
874
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
875
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple.
876
+ When returning a tuple, the first element is a list with the generated images, and the second element is a
877
+ list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
878
+ (nsfw) content, according to the `safety_checker`.
879
+ """
880
+ if controlnet is not None and controlnet_image is None:
881
+ raise ValueError("controlnet_image must be provided if controlnet is not None.")
882
+
883
+ # 0. Default height and width to unet
884
+ height = height or self.unet.config.sample_size * self.vae_scale_factor
885
+ width = width or self.unet.config.sample_size * self.vae_scale_factor
886
+
887
+ # 1. Check inputs. Raise error if not correct
888
+ self.check_inputs(prompt, height, width, strength, callback_steps)
889
+
890
+ # 2. Define call parameters
891
+ batch_size = 1 if isinstance(prompt, str) else len(prompt)
892
+ device = self._execution_device
893
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
894
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
895
+ # corresponds to doing no classifier free guidance.
896
+ do_classifier_free_guidance = guidance_scale > 1.0
897
+
898
+ # 3. Encode input prompt
899
+ # 実装を簡単にするためにtokenzer/text encoderを切り替えて二回呼び出す
900
+ # To simplify the implementation, switch the tokenzer/text encoder and call it twice
901
+ text_embeddings_list = []
902
+ text_pool = None
903
+ uncond_embeddings_list = []
904
+ uncond_pool = None
905
+ for i in range(len(self.tokenizers)):
906
+ self.tokenizer = self.tokenizers[i]
907
+ self.text_encoder = self.text_encoders[i]
908
+
909
+ text_embeddings, tp1, uncond_embeddings, up1 = self._encode_prompt(
910
+ prompt,
911
+ device,
912
+ num_images_per_prompt,
913
+ do_classifier_free_guidance,
914
+ negative_prompt,
915
+ max_embeddings_multiples,
916
+ is_sdxl_text_encoder2=i == 1,
917
+ )
918
+ text_embeddings_list.append(text_embeddings)
919
+ uncond_embeddings_list.append(uncond_embeddings)
920
+
921
+ if tp1 is not None:
922
+ text_pool = tp1
923
+ if up1 is not None:
924
+ uncond_pool = up1
925
+
926
+ dtype = self.unet.dtype
927
+
928
+ # 4. Preprocess image and mask
929
+ if isinstance(image, PIL.Image.Image):
930
+ image = preprocess_image(image)
931
+ if image is not None:
932
+ image = image.to(device=self.device, dtype=dtype)
933
+ if isinstance(mask_image, PIL.Image.Image):
934
+ mask_image = preprocess_mask(mask_image, self.vae_scale_factor)
935
+ if mask_image is not None:
936
+ mask = mask_image.to(device=self.device, dtype=dtype)
937
+ mask = torch.cat([mask] * batch_size * num_images_per_prompt)
938
+ else:
939
+ mask = None
940
+
941
+ # ControlNet is not working yet in SDXL, but keep the code here for future use
942
+ if controlnet_image is not None:
943
+ controlnet_image = prepare_controlnet_image(
944
+ controlnet_image, width, height, batch_size, 1, self.device, controlnet.dtype, do_classifier_free_guidance, False
945
+ )
946
+
947
+ # 5. set timesteps
948
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
949
+ timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device, image is None)
950
+ latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
951
+
952
+ # 6. Prepare latent variables
953
+ latents, init_latents_orig, noise = self.prepare_latents(
954
+ image,
955
+ latent_timestep,
956
+ batch_size * num_images_per_prompt,
957
+ height,
958
+ width,
959
+ dtype,
960
+ device,
961
+ generator,
962
+ latents,
963
+ )
964
+
965
+ # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
966
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
967
+
968
+ # create size embs and concat embeddings for SDXL
969
+ orig_size = torch.tensor([height, width]).repeat(batch_size * num_images_per_prompt, 1).to(dtype)
970
+ crop_size = torch.zeros_like(orig_size)
971
+ target_size = orig_size
972
+ embs = sdxl_train_util.get_size_embeddings(orig_size, crop_size, target_size, device).to(dtype)
973
+
974
+ # make conditionings
975
+ if do_classifier_free_guidance:
976
+ text_embeddings = torch.cat(text_embeddings_list, dim=2)
977
+ uncond_embeddings = torch.cat(uncond_embeddings_list, dim=2)
978
+ text_embedding = torch.cat([uncond_embeddings, text_embeddings]).to(dtype)
979
+
980
+ cond_vector = torch.cat([text_pool, embs], dim=1)
981
+ uncond_vector = torch.cat([uncond_pool, embs], dim=1)
982
+ vector_embedding = torch.cat([uncond_vector, cond_vector]).to(dtype)
983
+ else:
984
+ text_embedding = torch.cat(text_embeddings_list, dim=2).to(dtype)
985
+ vector_embedding = torch.cat([text_pool, embs], dim=1).to(dtype)
986
+
987
+ # 8. Denoising loop
988
+ for i, t in enumerate(self.progress_bar(timesteps)):
989
+ # expand the latents if we are doing classifier free guidance
990
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
991
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
992
+
993
+ unet_additional_args = {}
994
+ if controlnet is not None:
995
+ down_block_res_samples, mid_block_res_sample = controlnet(
996
+ latent_model_input,
997
+ t,
998
+ encoder_hidden_states=text_embeddings,
999
+ controlnet_cond=controlnet_image,
1000
+ conditioning_scale=1.0,
1001
+ guess_mode=False,
1002
+ return_dict=False,
1003
+ )
1004
+ unet_additional_args["down_block_additional_residuals"] = down_block_res_samples
1005
+ unet_additional_args["mid_block_additional_residual"] = mid_block_res_sample
1006
+
1007
+ # predict the noise residual
1008
+ noise_pred = self.unet(latent_model_input, t, text_embedding, vector_embedding)
1009
+ noise_pred = noise_pred.to(dtype) # U-Net changes dtype in LoRA training
1010
+
1011
+ # perform guidance
1012
+ if do_classifier_free_guidance:
1013
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
1014
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
1015
+
1016
+ # compute the previous noisy sample x_t -> x_t-1
1017
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
1018
+
1019
+ if mask is not None:
1020
+ # masking
1021
+ init_latents_proper = self.scheduler.add_noise(init_latents_orig, noise, torch.tensor([t]))
1022
+ latents = (init_latents_proper * mask) + (latents * (1 - mask))
1023
+
1024
+ # call the callback, if provided
1025
+ if i % callback_steps == 0:
1026
+ if callback is not None:
1027
+ callback(i, t, latents)
1028
+ if is_cancelled_callback is not None and is_cancelled_callback():
1029
+ return None
1030
+
1031
+ return latents
1032
+
1033
+ def latents_to_image(self, latents):
1034
+ # 9. Post-processing
1035
+ image = self.decode_latents(latents.to(self.vae.dtype))
1036
+ image = self.numpy_to_pil(image)
1037
+ return image
1038
+
1039
+ # copy from pil_utils.py
1040
+ def numpy_to_pil(self, images: np.ndarray) -> Image.Image:
1041
+ """
1042
+ Convert a numpy image or a batch of images to a PIL image.
1043
+ """
1044
+ if images.ndim == 3:
1045
+ images = images[None, ...]
1046
+ images = (images * 255).round().astype("uint8")
1047
+ if images.shape[-1] == 1:
1048
+ # special case for grayscale (single channel) images
1049
+ pil_images = [Image.fromarray(image.squeeze(), mode="L") for image in images]
1050
+ else:
1051
+ pil_images = [Image.fromarray(image) for image in images]
1052
+
1053
+ return pil_images
1054
+
1055
+ def text2img(
1056
+ self,
1057
+ prompt: Union[str, List[str]],
1058
+ negative_prompt: Optional[Union[str, List[str]]] = None,
1059
+ height: int = 512,
1060
+ width: int = 512,
1061
+ num_inference_steps: int = 50,
1062
+ guidance_scale: float = 7.5,
1063
+ num_images_per_prompt: Optional[int] = 1,
1064
+ eta: float = 0.0,
1065
+ generator: Optional[torch.Generator] = None,
1066
+ latents: Optional[torch.FloatTensor] = None,
1067
+ max_embeddings_multiples: Optional[int] = 3,
1068
+ output_type: Optional[str] = "pil",
1069
+ return_dict: bool = True,
1070
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
1071
+ is_cancelled_callback: Optional[Callable[[], bool]] = None,
1072
+ callback_steps: int = 1,
1073
+ ):
1074
+ r"""
1075
+ Function for text-to-image generation.
1076
+ Args:
1077
+ prompt (`str` or `List[str]`):
1078
+ The prompt or prompts to guide the image generation.
1079
+ negative_prompt (`str` or `List[str]`, *optional*):
1080
+ The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
1081
+ if `guidance_scale` is less than `1`).
1082
+ height (`int`, *optional*, defaults to 512):
1083
+ The height in pixels of the generated image.
1084
+ width (`int`, *optional*, defaults to 512):
1085
+ The width in pixels of the generated image.
1086
+ num_inference_steps (`int`, *optional*, defaults to 50):
1087
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
1088
+ expense of slower inference.
1089
+ guidance_scale (`float`, *optional*, defaults to 7.5):
1090
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
1091
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
1092
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
1093
+ 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
1094
+ usually at the expense of lower image quality.
1095
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
1096
+ The number of images to generate per prompt.
1097
+ eta (`float`, *optional*, defaults to 0.0):
1098
+ Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
1099
+ [`schedulers.DDIMScheduler`], will be ignored for others.
1100
+ generator (`torch.Generator`, *optional*):
1101
+ A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
1102
+ deterministic.
1103
+ latents (`torch.FloatTensor`, *optional*):
1104
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
1105
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
1106
+ tensor will ge generated by sampling using the supplied random `generator`.
1107
+ max_embeddings_multiples (`int`, *optional*, defaults to `3`):
1108
+ The max multiple length of prompt embeddings compared to the max output length of text encoder.
1109
+ output_type (`str`, *optional*, defaults to `"pil"`):
1110
+ The output format of the generate image. Choose between
1111
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
1112
+ return_dict (`bool`, *optional*, defaults to `True`):
1113
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
1114
+ plain tuple.
1115
+ callback (`Callable`, *optional*):
1116
+ A function that will be called every `callback_steps` steps during inference. The function will be
1117
+ called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
1118
+ is_cancelled_callback (`Callable`, *optional*):
1119
+ A function that will be called every `callback_steps` steps during inference. If the function returns
1120
+ `True`, the inference will be cancelled.
1121
+ callback_steps (`int`, *optional*, defaults to 1):
1122
+ The frequency at which the `callback` function will be called. If not specified, the callback will be
1123
+ called at every step.
1124
+ Returns:
1125
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
1126
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple.
1127
+ When returning a tuple, the first element is a list with the generated images, and the second element is a
1128
+ list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
1129
+ (nsfw) content, according to the `safety_checker`.
1130
+ """
1131
+ return self.__call__(
1132
+ prompt=prompt,
1133
+ negative_prompt=negative_prompt,
1134
+ height=height,
1135
+ width=width,
1136
+ num_inference_steps=num_inference_steps,
1137
+ guidance_scale=guidance_scale,
1138
+ num_images_per_prompt=num_images_per_prompt,
1139
+ eta=eta,
1140
+ generator=generator,
1141
+ latents=latents,
1142
+ max_embeddings_multiples=max_embeddings_multiples,
1143
+ output_type=output_type,
1144
+ return_dict=return_dict,
1145
+ callback=callback,
1146
+ is_cancelled_callback=is_cancelled_callback,
1147
+ callback_steps=callback_steps,
1148
+ )
1149
+
1150
+ def img2img(
1151
+ self,
1152
+ image: Union[torch.FloatTensor, PIL.Image.Image],
1153
+ prompt: Union[str, List[str]],
1154
+ negative_prompt: Optional[Union[str, List[str]]] = None,
1155
+ strength: float = 0.8,
1156
+ num_inference_steps: Optional[int] = 50,
1157
+ guidance_scale: Optional[float] = 7.5,
1158
+ num_images_per_prompt: Optional[int] = 1,
1159
+ eta: Optional[float] = 0.0,
1160
+ generator: Optional[torch.Generator] = None,
1161
+ max_embeddings_multiples: Optional[int] = 3,
1162
+ output_type: Optional[str] = "pil",
1163
+ return_dict: bool = True,
1164
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
1165
+ is_cancelled_callback: Optional[Callable[[], bool]] = None,
1166
+ callback_steps: int = 1,
1167
+ ):
1168
+ r"""
1169
+ Function for image-to-image generation.
1170
+ Args:
1171
+ image (`torch.FloatTensor` or `PIL.Image.Image`):
1172
+ `Image`, or tensor representing an image batch, that will be used as the starting point for the
1173
+ process.
1174
+ prompt (`str` or `List[str]`):
1175
+ The prompt or prompts to guide the image generation.
1176
+ negative_prompt (`str` or `List[str]`, *optional*):
1177
+ The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
1178
+ if `guidance_scale` is less than `1`).
1179
+ strength (`float`, *optional*, defaults to 0.8):
1180
+ Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1.
1181
+ `image` will be used as a starting point, adding more noise to it the larger the `strength`. The
1182
+ number of denoising steps depends on the amount of noise initially added. When `strength` is 1, added
1183
+ noise will be maximum and the denoising process will run for the full number of iterations specified in
1184
+ `num_inference_steps`. A value of 1, therefore, essentially ignores `image`.
1185
+ num_inference_steps (`int`, *optional*, defaults to 50):
1186
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
1187
+ expense of slower inference. This parameter will be modulated by `strength`.
1188
+ guidance_scale (`float`, *optional*, defaults to 7.5):
1189
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
1190
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
1191
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
1192
+ 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
1193
+ usually at the expense of lower image quality.
1194
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
1195
+ The number of images to generate per prompt.
1196
+ eta (`float`, *optional*, defaults to 0.0):
1197
+ Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
1198
+ [`schedulers.DDIMScheduler`], will be ignored for others.
1199
+ generator (`torch.Generator`, *optional*):
1200
+ A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
1201
+ deterministic.
1202
+ max_embeddings_multiples (`int`, *optional*, defaults to `3`):
1203
+ The max multiple length of prompt embeddings compared to the max output length of text encoder.
1204
+ output_type (`str`, *optional*, defaults to `"pil"`):
1205
+ The output format of the generate image. Choose between
1206
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
1207
+ return_dict (`bool`, *optional*, defaults to `True`):
1208
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
1209
+ plain tuple.
1210
+ callback (`Callable`, *optional*):
1211
+ A function that will be called every `callback_steps` steps during inference. The function will be
1212
+ called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
1213
+ is_cancelled_callback (`Callable`, *optional*):
1214
+ A function that will be called every `callback_steps` steps during inference. If the function returns
1215
+ `True`, the inference will be cancelled.
1216
+ callback_steps (`int`, *optional*, defaults to 1):
1217
+ The frequency at which the `callback` function will be called. If not specified, the callback will be
1218
+ called at every step.
1219
+ Returns:
1220
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
1221
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple.
1222
+ When returning a tuple, the first element is a list with the generated images, and the second element is a
1223
+ list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
1224
+ (nsfw) content, according to the `safety_checker`.
1225
+ """
1226
+ return self.__call__(
1227
+ prompt=prompt,
1228
+ negative_prompt=negative_prompt,
1229
+ image=image,
1230
+ num_inference_steps=num_inference_steps,
1231
+ guidance_scale=guidance_scale,
1232
+ strength=strength,
1233
+ num_images_per_prompt=num_images_per_prompt,
1234
+ eta=eta,
1235
+ generator=generator,
1236
+ max_embeddings_multiples=max_embeddings_multiples,
1237
+ output_type=output_type,
1238
+ return_dict=return_dict,
1239
+ callback=callback,
1240
+ is_cancelled_callback=is_cancelled_callback,
1241
+ callback_steps=callback_steps,
1242
+ )
1243
+
1244
+ def inpaint(
1245
+ self,
1246
+ image: Union[torch.FloatTensor, PIL.Image.Image],
1247
+ mask_image: Union[torch.FloatTensor, PIL.Image.Image],
1248
+ prompt: Union[str, List[str]],
1249
+ negative_prompt: Optional[Union[str, List[str]]] = None,
1250
+ strength: float = 0.8,
1251
+ num_inference_steps: Optional[int] = 50,
1252
+ guidance_scale: Optional[float] = 7.5,
1253
+ num_images_per_prompt: Optional[int] = 1,
1254
+ eta: Optional[float] = 0.0,
1255
+ generator: Optional[torch.Generator] = None,
1256
+ max_embeddings_multiples: Optional[int] = 3,
1257
+ output_type: Optional[str] = "pil",
1258
+ return_dict: bool = True,
1259
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
1260
+ is_cancelled_callback: Optional[Callable[[], bool]] = None,
1261
+ callback_steps: int = 1,
1262
+ ):
1263
+ r"""
1264
+ Function for inpaint.
1265
+ Args:
1266
+ image (`torch.FloatTensor` or `PIL.Image.Image`):
1267
+ `Image`, or tensor representing an image batch, that will be used as the starting point for the
1268
+ process. This is the image whose masked region will be inpainted.
1269
+ mask_image (`torch.FloatTensor` or `PIL.Image.Image`):
1270
+ `Image`, or tensor representing an image batch, to mask `image`. White pixels in the mask will be
1271
+ replaced by noise and therefore repainted, while black pixels will be preserved. If `mask_image` is a
1272
+ PIL image, it will be converted to a single channel (luminance) before use. If it's a tensor, it should
1273
+ contain one color channel (L) instead of 3, so the expected shape would be `(B, H, W, 1)`.
1274
+ prompt (`str` or `List[str]`):
1275
+ The prompt or prompts to guide the image generation.
1276
+ negative_prompt (`str` or `List[str]`, *optional*):
1277
+ The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
1278
+ if `guidance_scale` is less than `1`).
1279
+ strength (`float`, *optional*, defaults to 0.8):
1280
+ Conceptually, indicates how much to inpaint the masked area. Must be between 0 and 1. When `strength`
1281
+ is 1, the denoising process will be run on the masked area for the full number of iterations specified
1282
+ in `num_inference_steps`. `image` will be used as a reference for the masked area, adding more
1283
+ noise to that region the larger the `strength`. If `strength` is 0, no inpainting will occur.
1284
+ num_inference_steps (`int`, *optional*, defaults to 50):
1285
+ The reference number of denoising steps. More denoising steps usually lead to a higher quality image at
1286
+ the expense of slower inference. This parameter will be modulated by `strength`, as explained above.
1287
+ guidance_scale (`float`, *optional*, defaults to 7.5):
1288
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
1289
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
1290
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
1291
+ 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
1292
+ usually at the expense of lower image quality.
1293
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
1294
+ The number of images to generate per prompt.
1295
+ eta (`float`, *optional*, defaults to 0.0):
1296
+ Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
1297
+ [`schedulers.DDIMScheduler`], will be ignored for others.
1298
+ generator (`torch.Generator`, *optional*):
1299
+ A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
1300
+ deterministic.
1301
+ max_embeddings_multiples (`int`, *optional*, defaults to `3`):
1302
+ The max multiple length of prompt embeddings compared to the max output length of text encoder.
1303
+ output_type (`str`, *optional*, defaults to `"pil"`):
1304
+ The output format of the generate image. Choose between
1305
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
1306
+ return_dict (`bool`, *optional*, defaults to `True`):
1307
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
1308
+ plain tuple.
1309
+ callback (`Callable`, *optional*):
1310
+ A function that will be called every `callback_steps` steps during inference. The function will be
1311
+ called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
1312
+ is_cancelled_callback (`Callable`, *optional*):
1313
+ A function that will be called every `callback_steps` steps during inference. If the function returns
1314
+ `True`, the inference will be cancelled.
1315
+ callback_steps (`int`, *optional*, defaults to 1):
1316
+ The frequency at which the `callback` function will be called. If not specified, the callback will be
1317
+ called at every step.
1318
+ Returns:
1319
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
1320
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple.
1321
+ When returning a tuple, the first element is a list with the generated images, and the second element is a
1322
+ list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
1323
+ (nsfw) content, according to the `safety_checker`.
1324
+ """
1325
+ return self.__call__(
1326
+ prompt=prompt,
1327
+ negative_prompt=negative_prompt,
1328
+ image=image,
1329
+ mask_image=mask_image,
1330
+ num_inference_steps=num_inference_steps,
1331
+ guidance_scale=guidance_scale,
1332
+ strength=strength,
1333
+ num_images_per_prompt=num_images_per_prompt,
1334
+ eta=eta,
1335
+ generator=generator,
1336
+ max_embeddings_multiples=max_embeddings_multiples,
1337
+ output_type=output_type,
1338
+ return_dict=return_dict,
1339
+ callback=callback,
1340
+ is_cancelled_callback=is_cancelled_callback,
1341
+ callback_steps=callback_steps,
1342
+ )
external/llite/library/sdxl_model_util.py ADDED
@@ -0,0 +1,578 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from accelerate import init_empty_weights
3
+ from accelerate.utils.modeling import set_module_tensor_to_device
4
+ from safetensors.torch import load_file, save_file
5
+ from transformers import CLIPTextModel, CLIPTextConfig, CLIPTextModelWithProjection, CLIPTokenizer
6
+ from typing import List
7
+ from diffusers import AutoencoderKL, EulerDiscreteScheduler, UNet2DConditionModel
8
+ from external.llite.library import model_util
9
+ from external.llite.library import sdxl_original_unet
10
+
11
+
12
+ VAE_SCALE_FACTOR = 0.13025
13
+ MODEL_VERSION_SDXL_BASE_V1_0 = "sdxl_base_v1-0"
14
+
15
+ # Diffusersの設定を読み込むための参照モデル
16
+ DIFFUSERS_REF_MODEL_ID_SDXL = "stabilityai/stable-diffusion-xl-base-1.0"
17
+
18
+ DIFFUSERS_SDXL_UNET_CONFIG = {
19
+ "act_fn": "silu",
20
+ "addition_embed_type": "text_time",
21
+ "addition_embed_type_num_heads": 64,
22
+ "addition_time_embed_dim": 256,
23
+ "attention_head_dim": [5, 10, 20],
24
+ "block_out_channels": [320, 640, 1280],
25
+ "center_input_sample": False,
26
+ "class_embed_type": None,
27
+ "class_embeddings_concat": False,
28
+ "conv_in_kernel": 3,
29
+ "conv_out_kernel": 3,
30
+ "cross_attention_dim": 2048,
31
+ "cross_attention_norm": None,
32
+ "down_block_types": ["DownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D"],
33
+ "downsample_padding": 1,
34
+ "dual_cross_attention": False,
35
+ "encoder_hid_dim": None,
36
+ "encoder_hid_dim_type": None,
37
+ "flip_sin_to_cos": True,
38
+ "freq_shift": 0,
39
+ "in_channels": 4,
40
+ "layers_per_block": 2,
41
+ "mid_block_only_cross_attention": None,
42
+ "mid_block_scale_factor": 1,
43
+ "mid_block_type": "UNetMidBlock2DCrossAttn",
44
+ "norm_eps": 1e-05,
45
+ "norm_num_groups": 32,
46
+ "num_attention_heads": None,
47
+ "num_class_embeds": None,
48
+ "only_cross_attention": False,
49
+ "out_channels": 4,
50
+ "projection_class_embeddings_input_dim": 2816,
51
+ "resnet_out_scale_factor": 1.0,
52
+ "resnet_skip_time_act": False,
53
+ "resnet_time_scale_shift": "default",
54
+ "sample_size": 128,
55
+ "time_cond_proj_dim": None,
56
+ "time_embedding_act_fn": None,
57
+ "time_embedding_dim": None,
58
+ "time_embedding_type": "positional",
59
+ "timestep_post_act": None,
60
+ "transformer_layers_per_block": [1, 2, 10],
61
+ "up_block_types": ["CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "UpBlock2D"],
62
+ "upcast_attention": False,
63
+ "use_linear_projection": True,
64
+ }
65
+
66
+
67
+ def convert_sdxl_text_encoder_2_checkpoint(checkpoint, max_length):
68
+ SDXL_KEY_PREFIX = "conditioner.embedders.1.model."
69
+
70
+ # SD2のと、基本的には同じ。logit_scaleを後で使うので、それを追加で返す
71
+ # logit_scaleはcheckpointの保存時に使用する
72
+ def convert_key(key):
73
+ # common conversion
74
+ key = key.replace(SDXL_KEY_PREFIX + "transformer.", "text_model.encoder.")
75
+ key = key.replace(SDXL_KEY_PREFIX, "text_model.")
76
+
77
+ if "resblocks" in key:
78
+ # resblocks conversion
79
+ key = key.replace(".resblocks.", ".layers.")
80
+ if ".ln_" in key:
81
+ key = key.replace(".ln_", ".layer_norm")
82
+ elif ".mlp." in key:
83
+ key = key.replace(".c_fc.", ".fc1.")
84
+ key = key.replace(".c_proj.", ".fc2.")
85
+ elif ".attn.out_proj" in key:
86
+ key = key.replace(".attn.out_proj.", ".self_attn.out_proj.")
87
+ elif ".attn.in_proj" in key:
88
+ key = None # 特殊なので後で処理する
89
+ else:
90
+ raise ValueError(f"unexpected key in SD: {key}")
91
+ elif ".positional_embedding" in key:
92
+ key = key.replace(".positional_embedding", ".embeddings.position_embedding.weight")
93
+ elif ".text_projection" in key:
94
+ key = key.replace("text_model.text_projection", "text_projection.weight")
95
+ elif ".logit_scale" in key:
96
+ key = None # 後で処理する
97
+ elif ".token_embedding" in key:
98
+ key = key.replace(".token_embedding.weight", ".embeddings.token_embedding.weight")
99
+ elif ".ln_final" in key:
100
+ key = key.replace(".ln_final", ".final_layer_norm")
101
+ # ckpt from comfy has this key: text_model.encoder.text_model.embeddings.position_ids
102
+ elif ".embeddings.position_ids" in key:
103
+ key = None # remove this key: make position_ids by ourselves
104
+ return key
105
+
106
+ keys = list(checkpoint.keys())
107
+ new_sd = {}
108
+ for key in keys:
109
+ new_key = convert_key(key)
110
+ if new_key is None:
111
+ continue
112
+ new_sd[new_key] = checkpoint[key]
113
+
114
+ # attnの変換
115
+ for key in keys:
116
+ if ".resblocks" in key and ".attn.in_proj_" in key:
117
+ # 三つに分割
118
+ values = torch.chunk(checkpoint[key], 3)
119
+
120
+ key_suffix = ".weight" if "weight" in key else ".bias"
121
+ key_pfx = key.replace(SDXL_KEY_PREFIX + "transformer.resblocks.", "text_model.encoder.layers.")
122
+ key_pfx = key_pfx.replace("_weight", "")
123
+ key_pfx = key_pfx.replace("_bias", "")
124
+ key_pfx = key_pfx.replace(".attn.in_proj", ".self_attn.")
125
+ new_sd[key_pfx + "q_proj" + key_suffix] = values[0]
126
+ new_sd[key_pfx + "k_proj" + key_suffix] = values[1]
127
+ new_sd[key_pfx + "v_proj" + key_suffix] = values[2]
128
+
129
+ # original SD にはないので、position_idsを追加
130
+ position_ids = torch.Tensor([list(range(max_length))]).to(torch.int64)
131
+ new_sd["text_model.embeddings.position_ids"] = position_ids
132
+
133
+ # logit_scale はDiffusersには含まれないが、保存時に戻したいので別途返す
134
+ logit_scale = checkpoint.get(SDXL_KEY_PREFIX + "logit_scale", None)
135
+
136
+ # temporary workaround for text_projection.weight.weight for Playground-v2
137
+ if "text_projection.weight.weight" in new_sd:
138
+ print(f"convert_sdxl_text_encoder_2_checkpoint: convert text_projection.weight.weight to text_projection.weight")
139
+ new_sd["text_projection.weight"] = new_sd["text_projection.weight.weight"]
140
+ del new_sd["text_projection.weight.weight"]
141
+
142
+ return new_sd, logit_scale
143
+
144
+
145
+ # load state_dict without allocating new tensors
146
+ def _load_state_dict_on_device(model, state_dict, device, dtype=None):
147
+ # dtype will use fp32 as default
148
+ missing_keys = list(model.state_dict().keys() - state_dict.keys())
149
+ unexpected_keys = list(state_dict.keys() - model.state_dict().keys())
150
+
151
+ # similar to model.load_state_dict()
152
+ if not missing_keys and not unexpected_keys:
153
+ for k in list(state_dict.keys()):
154
+ set_module_tensor_to_device(model, k, device, value=state_dict.pop(k), dtype=dtype)
155
+ return "<All keys matched successfully>"
156
+
157
+ # error_msgs
158
+ error_msgs: List[str] = []
159
+ if missing_keys:
160
+ error_msgs.insert(0, "Missing key(s) in state_dict: {}. ".format(", ".join('"{}"'.format(k) for k in missing_keys)))
161
+ if unexpected_keys:
162
+ error_msgs.insert(0, "Unexpected key(s) in state_dict: {}. ".format(", ".join('"{}"'.format(k) for k in unexpected_keys)))
163
+
164
+ raise RuntimeError("Error(s) in loading state_dict for {}:\n\t{}".format(model.__class__.__name__, "\n\t".join(error_msgs)))
165
+
166
+
167
+ def load_models_from_sdxl_checkpoint(model_version, ckpt_path, map_location, dtype=None):
168
+ # model_version is reserved for future use
169
+ # dtype is used for full_fp16/bf16 integration. Text Encoder will remain fp32, because it runs on CPU when caching
170
+
171
+ # Load the state dict
172
+ if model_util.is_safetensors(ckpt_path):
173
+ checkpoint = None
174
+ try:
175
+ state_dict = load_file(ckpt_path, device=map_location)
176
+ except:
177
+ state_dict = load_file(ckpt_path) # prevent device invalid Error
178
+ epoch = None
179
+ global_step = None
180
+ else:
181
+ checkpoint = torch.load(ckpt_path, map_location=map_location)
182
+ if "state_dict" in checkpoint:
183
+ state_dict = checkpoint["state_dict"]
184
+ epoch = checkpoint.get("epoch", 0)
185
+ global_step = checkpoint.get("global_step", 0)
186
+ else:
187
+ state_dict = checkpoint
188
+ epoch = 0
189
+ global_step = 0
190
+ checkpoint = None
191
+
192
+ # U-Net
193
+ print("building U-Net")
194
+ with init_empty_weights():
195
+ unet = sdxl_original_unet.SdxlUNet2DConditionModel()
196
+
197
+ print("loading U-Net from checkpoint")
198
+ unet_sd = {}
199
+ for k in list(state_dict.keys()):
200
+ if k.startswith("model.diffusion_model."):
201
+ unet_sd[k.replace("model.diffusion_model.", "")] = state_dict.pop(k)
202
+ info = _load_state_dict_on_device(unet, unet_sd, device=map_location, dtype=dtype)
203
+ print("U-Net: ", info)
204
+
205
+ # Text Encoders
206
+ print("building text encoders")
207
+
208
+ # Text Encoder 1 is same to Stability AI's SDXL
209
+ text_model1_cfg = CLIPTextConfig(
210
+ vocab_size=49408,
211
+ hidden_size=768,
212
+ intermediate_size=3072,
213
+ num_hidden_layers=12,
214
+ num_attention_heads=12,
215
+ max_position_embeddings=77,
216
+ hidden_act="quick_gelu",
217
+ layer_norm_eps=1e-05,
218
+ dropout=0.0,
219
+ attention_dropout=0.0,
220
+ initializer_range=0.02,
221
+ initializer_factor=1.0,
222
+ pad_token_id=1,
223
+ bos_token_id=0,
224
+ eos_token_id=2,
225
+ model_type="clip_text_model",
226
+ projection_dim=768,
227
+ # torch_dtype="float32",
228
+ # transformers_version="4.25.0.dev0",
229
+ )
230
+ with init_empty_weights():
231
+ text_model1 = CLIPTextModel._from_config(text_model1_cfg)
232
+
233
+ # Text Encoder 2 is different from Stability AI's SDXL. SDXL uses open clip, but we use the model from HuggingFace.
234
+ # Note: Tokenizer from HuggingFace is different from SDXL. We must use open clip's tokenizer.
235
+ text_model2_cfg = CLIPTextConfig(
236
+ vocab_size=49408,
237
+ hidden_size=1280,
238
+ intermediate_size=5120,
239
+ num_hidden_layers=32,
240
+ num_attention_heads=20,
241
+ max_position_embeddings=77,
242
+ hidden_act="gelu",
243
+ layer_norm_eps=1e-05,
244
+ dropout=0.0,
245
+ attention_dropout=0.0,
246
+ initializer_range=0.02,
247
+ initializer_factor=1.0,
248
+ pad_token_id=1,
249
+ bos_token_id=0,
250
+ eos_token_id=2,
251
+ model_type="clip_text_model",
252
+ projection_dim=1280,
253
+ # torch_dtype="float32",
254
+ # transformers_version="4.25.0.dev0",
255
+ )
256
+ with init_empty_weights():
257
+ text_model2 = CLIPTextModelWithProjection(text_model2_cfg)
258
+
259
+ print("loading text encoders from checkpoint")
260
+ te1_sd = {}
261
+ te2_sd = {}
262
+ for k in list(state_dict.keys()):
263
+ if k.startswith("conditioner.embedders.0.transformer."):
264
+ te1_sd[k.replace("conditioner.embedders.0.transformer.", "")] = state_dict.pop(k)
265
+ elif k.startswith("conditioner.embedders.1.model."):
266
+ te2_sd[k] = state_dict.pop(k)
267
+
268
+ # 一部のposition_idsがないモデルへの対応 / add position_ids for some models
269
+ if "text_model.embeddings.position_ids" not in te1_sd:
270
+ te1_sd["text_model.embeddings.position_ids"] = torch.arange(77).unsqueeze(0)
271
+
272
+ info1 = _load_state_dict_on_device(text_model1, te1_sd, device=map_location) # remain fp32
273
+ print("text encoder 1:", info1)
274
+
275
+ converted_sd, logit_scale = convert_sdxl_text_encoder_2_checkpoint(te2_sd, max_length=77)
276
+ info2 = _load_state_dict_on_device(text_model2, converted_sd, device=map_location) # remain fp32
277
+ print("text encoder 2:", info2)
278
+
279
+ # prepare vae
280
+ print("building VAE")
281
+ vae_config = model_util.create_vae_diffusers_config()
282
+ with init_empty_weights():
283
+ vae = AutoencoderKL(**vae_config)
284
+
285
+ print("loading VAE from checkpoint")
286
+ converted_vae_checkpoint = model_util.convert_ldm_vae_checkpoint(state_dict, vae_config)
287
+ info = _load_state_dict_on_device(vae, converted_vae_checkpoint, device=map_location, dtype=dtype)
288
+ print("VAE:", info)
289
+
290
+ ckpt_info = (epoch, global_step) if epoch is not None else None
291
+ return text_model1, text_model2, vae, unet, logit_scale, ckpt_info
292
+
293
+
294
+ def make_unet_conversion_map():
295
+ unet_conversion_map_layer = []
296
+
297
+ for i in range(3): # num_blocks is 3 in sdxl
298
+ # loop over downblocks/upblocks
299
+ for j in range(2):
300
+ # loop over resnets/attentions for downblocks
301
+ hf_down_res_prefix = f"down_blocks.{i}.resnets.{j}."
302
+ sd_down_res_prefix = f"input_blocks.{3*i + j + 1}.0."
303
+ unet_conversion_map_layer.append((sd_down_res_prefix, hf_down_res_prefix))
304
+
305
+ if i < 3:
306
+ # no attention layers in down_blocks.3
307
+ hf_down_atn_prefix = f"down_blocks.{i}.attentions.{j}."
308
+ sd_down_atn_prefix = f"input_blocks.{3*i + j + 1}.1."
309
+ unet_conversion_map_layer.append((sd_down_atn_prefix, hf_down_atn_prefix))
310
+
311
+ for j in range(3):
312
+ # loop over resnets/attentions for upblocks
313
+ hf_up_res_prefix = f"up_blocks.{i}.resnets.{j}."
314
+ sd_up_res_prefix = f"output_blocks.{3*i + j}.0."
315
+ unet_conversion_map_layer.append((sd_up_res_prefix, hf_up_res_prefix))
316
+
317
+ # if i > 0: commentout for sdxl
318
+ # no attention layers in up_blocks.0
319
+ hf_up_atn_prefix = f"up_blocks.{i}.attentions.{j}."
320
+ sd_up_atn_prefix = f"output_blocks.{3*i + j}.1."
321
+ unet_conversion_map_layer.append((sd_up_atn_prefix, hf_up_atn_prefix))
322
+
323
+ if i < 3:
324
+ # no downsample in down_blocks.3
325
+ hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0.conv."
326
+ sd_downsample_prefix = f"input_blocks.{3*(i+1)}.0.op."
327
+ unet_conversion_map_layer.append((sd_downsample_prefix, hf_downsample_prefix))
328
+
329
+ # no upsample in up_blocks.3
330
+ hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
331
+ sd_upsample_prefix = f"output_blocks.{3*i + 2}.{2}." # change for sdxl
332
+ unet_conversion_map_layer.append((sd_upsample_prefix, hf_upsample_prefix))
333
+
334
+ hf_mid_atn_prefix = "mid_block.attentions.0."
335
+ sd_mid_atn_prefix = "middle_block.1."
336
+ unet_conversion_map_layer.append((sd_mid_atn_prefix, hf_mid_atn_prefix))
337
+
338
+ for j in range(2):
339
+ hf_mid_res_prefix = f"mid_block.resnets.{j}."
340
+ sd_mid_res_prefix = f"middle_block.{2*j}."
341
+ unet_conversion_map_layer.append((sd_mid_res_prefix, hf_mid_res_prefix))
342
+
343
+ unet_conversion_map_resnet = [
344
+ # (stable-diffusion, HF Diffusers)
345
+ ("in_layers.0.", "norm1."),
346
+ ("in_layers.2.", "conv1."),
347
+ ("out_layers.0.", "norm2."),
348
+ ("out_layers.3.", "conv2."),
349
+ ("emb_layers.1.", "time_emb_proj."),
350
+ ("skip_connection.", "conv_shortcut."),
351
+ ]
352
+
353
+ unet_conversion_map = []
354
+ for sd, hf in unet_conversion_map_layer:
355
+ if "resnets" in hf:
356
+ for sd_res, hf_res in unet_conversion_map_resnet:
357
+ unet_conversion_map.append((sd + sd_res, hf + hf_res))
358
+ else:
359
+ unet_conversion_map.append((sd, hf))
360
+
361
+ for j in range(2):
362
+ hf_time_embed_prefix = f"time_embedding.linear_{j+1}."
363
+ sd_time_embed_prefix = f"time_embed.{j*2}."
364
+ unet_conversion_map.append((sd_time_embed_prefix, hf_time_embed_prefix))
365
+
366
+ for j in range(2):
367
+ hf_label_embed_prefix = f"add_embedding.linear_{j+1}."
368
+ sd_label_embed_prefix = f"label_emb.0.{j*2}."
369
+ unet_conversion_map.append((sd_label_embed_prefix, hf_label_embed_prefix))
370
+
371
+ unet_conversion_map.append(("input_blocks.0.0.", "conv_in."))
372
+ unet_conversion_map.append(("out.0.", "conv_norm_out."))
373
+ unet_conversion_map.append(("out.2.", "conv_out."))
374
+
375
+ return unet_conversion_map
376
+
377
+
378
+ def convert_diffusers_unet_state_dict_to_sdxl(du_sd):
379
+ unet_conversion_map = make_unet_conversion_map()
380
+
381
+ conversion_map = {hf: sd for sd, hf in unet_conversion_map}
382
+ return convert_unet_state_dict(du_sd, conversion_map)
383
+
384
+
385
+ def convert_unet_state_dict(src_sd, conversion_map):
386
+ converted_sd = {}
387
+ for src_key, value in src_sd.items():
388
+ # さすがに全部回すのは時間がかかるので右から要素を削りつつprefixを探す
389
+ src_key_fragments = src_key.split(".")[:-1] # remove weight/bias
390
+ while len(src_key_fragments) > 0:
391
+ src_key_prefix = ".".join(src_key_fragments) + "."
392
+ if src_key_prefix in conversion_map:
393
+ converted_prefix = conversion_map[src_key_prefix]
394
+ converted_key = converted_prefix + src_key[len(src_key_prefix) :]
395
+ converted_sd[converted_key] = value
396
+ break
397
+ src_key_fragments.pop(-1)
398
+ assert len(src_key_fragments) > 0, f"key {src_key} not found in conversion map"
399
+
400
+ return converted_sd
401
+
402
+
403
+ def convert_sdxl_unet_state_dict_to_diffusers(sd):
404
+ unet_conversion_map = make_unet_conversion_map()
405
+
406
+ conversion_dict = {sd: hf for sd, hf in unet_conversion_map}
407
+ return convert_unet_state_dict(sd, conversion_dict)
408
+
409
+
410
+ def convert_text_encoder_2_state_dict_to_sdxl(checkpoint, logit_scale):
411
+ def convert_key(key):
412
+ # position_idsの除去
413
+ if ".position_ids" in key:
414
+ return None
415
+
416
+ # common
417
+ key = key.replace("text_model.encoder.", "transformer.")
418
+ key = key.replace("text_model.", "")
419
+ if "layers" in key:
420
+ # resblocks conversion
421
+ key = key.replace(".layers.", ".resblocks.")
422
+ if ".layer_norm" in key:
423
+ key = key.replace(".layer_norm", ".ln_")
424
+ elif ".mlp." in key:
425
+ key = key.replace(".fc1.", ".c_fc.")
426
+ key = key.replace(".fc2.", ".c_proj.")
427
+ elif ".self_attn.out_proj" in key:
428
+ key = key.replace(".self_attn.out_proj.", ".attn.out_proj.")
429
+ elif ".self_attn." in key:
430
+ key = None # 特殊なので後で処理する
431
+ else:
432
+ raise ValueError(f"unexpected key in DiffUsers model: {key}")
433
+ elif ".position_embedding" in key:
434
+ key = key.replace("embeddings.position_embedding.weight", "positional_embedding")
435
+ elif ".token_embedding" in key:
436
+ key = key.replace("embeddings.token_embedding.weight", "token_embedding.weight")
437
+ elif "text_projection" in key: # no dot in key
438
+ key = key.replace("text_projection.weight", "text_projection")
439
+ elif "final_layer_norm" in key:
440
+ key = key.replace("final_layer_norm", "ln_final")
441
+ return key
442
+
443
+ keys = list(checkpoint.keys())
444
+ new_sd = {}
445
+ for key in keys:
446
+ new_key = convert_key(key)
447
+ if new_key is None:
448
+ continue
449
+ new_sd[new_key] = checkpoint[key]
450
+
451
+ # attnの変換
452
+ for key in keys:
453
+ if "layers" in key and "q_proj" in key:
454
+ # 三つを結合
455
+ key_q = key
456
+ key_k = key.replace("q_proj", "k_proj")
457
+ key_v = key.replace("q_proj", "v_proj")
458
+
459
+ value_q = checkpoint[key_q]
460
+ value_k = checkpoint[key_k]
461
+ value_v = checkpoint[key_v]
462
+ value = torch.cat([value_q, value_k, value_v])
463
+
464
+ new_key = key.replace("text_model.encoder.layers.", "transformer.resblocks.")
465
+ new_key = new_key.replace(".self_attn.q_proj.", ".attn.in_proj_")
466
+ new_sd[new_key] = value
467
+
468
+ if logit_scale is not None:
469
+ new_sd["logit_scale"] = logit_scale
470
+
471
+ return new_sd
472
+
473
+
474
+ def save_stable_diffusion_checkpoint(
475
+ output_file,
476
+ text_encoder1,
477
+ text_encoder2,
478
+ unet,
479
+ epochs,
480
+ steps,
481
+ ckpt_info,
482
+ vae,
483
+ logit_scale,
484
+ metadata,
485
+ save_dtype=None,
486
+ ):
487
+ state_dict = {}
488
+
489
+ def update_sd(prefix, sd):
490
+ for k, v in sd.items():
491
+ key = prefix + k
492
+ if save_dtype is not None:
493
+ v = v.detach().clone().to("cpu").to(save_dtype)
494
+ state_dict[key] = v
495
+
496
+ # Convert the UNet model
497
+ update_sd("model.diffusion_model.", unet.state_dict())
498
+
499
+ # Convert the text encoders
500
+ update_sd("conditioner.embedders.0.transformer.", text_encoder1.state_dict())
501
+
502
+ text_enc2_dict = convert_text_encoder_2_state_dict_to_sdxl(text_encoder2.state_dict(), logit_scale)
503
+ update_sd("conditioner.embedders.1.model.", text_enc2_dict)
504
+
505
+ # Convert the VAE
506
+ vae_dict = model_util.convert_vae_state_dict(vae.state_dict())
507
+ update_sd("first_stage_model.", vae_dict)
508
+
509
+ # Put together new checkpoint
510
+ key_count = len(state_dict.keys())
511
+ new_ckpt = {"state_dict": state_dict}
512
+
513
+ # epoch and global_step are sometimes not int
514
+ if ckpt_info is not None:
515
+ epochs += ckpt_info[0]
516
+ steps += ckpt_info[1]
517
+
518
+ new_ckpt["epoch"] = epochs
519
+ new_ckpt["global_step"] = steps
520
+
521
+ if model_util.is_safetensors(output_file):
522
+ save_file(state_dict, output_file, metadata)
523
+ else:
524
+ torch.save(new_ckpt, output_file)
525
+
526
+ return key_count
527
+
528
+
529
+ def save_diffusers_checkpoint(
530
+ output_dir, text_encoder1, text_encoder2, unet, pretrained_model_name_or_path, vae=None, use_safetensors=False, save_dtype=None
531
+ ):
532
+ from diffusers import StableDiffusionXLPipeline
533
+
534
+ # convert U-Net
535
+ unet_sd = unet.state_dict()
536
+ du_unet_sd = convert_sdxl_unet_state_dict_to_diffusers(unet_sd)
537
+
538
+ diffusers_unet = UNet2DConditionModel(**DIFFUSERS_SDXL_UNET_CONFIG)
539
+ if save_dtype is not None:
540
+ diffusers_unet.to(save_dtype)
541
+ diffusers_unet.load_state_dict(du_unet_sd)
542
+
543
+ # create pipeline to save
544
+ if pretrained_model_name_or_path is None:
545
+ pretrained_model_name_or_path = DIFFUSERS_REF_MODEL_ID_SDXL
546
+
547
+ scheduler = EulerDiscreteScheduler.from_pretrained(pretrained_model_name_or_path, subfolder="scheduler")
548
+ tokenizer1 = CLIPTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder="tokenizer")
549
+ tokenizer2 = CLIPTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder="tokenizer_2")
550
+ if vae is None:
551
+ vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder="vae")
552
+
553
+ # prevent local path from being saved
554
+ def remove_name_or_path(model):
555
+ if hasattr(model, "config"):
556
+ model.config._name_or_path = None
557
+ model.config._name_or_path = None
558
+
559
+ remove_name_or_path(diffusers_unet)
560
+ remove_name_or_path(text_encoder1)
561
+ remove_name_or_path(text_encoder2)
562
+ remove_name_or_path(scheduler)
563
+ remove_name_or_path(tokenizer1)
564
+ remove_name_or_path(tokenizer2)
565
+ remove_name_or_path(vae)
566
+
567
+ pipeline = StableDiffusionXLPipeline(
568
+ unet=diffusers_unet,
569
+ text_encoder=text_encoder1,
570
+ text_encoder_2=text_encoder2,
571
+ vae=vae,
572
+ scheduler=scheduler,
573
+ tokenizer=tokenizer1,
574
+ tokenizer_2=tokenizer2,
575
+ )
576
+ if save_dtype is not None:
577
+ pipeline.to(None, save_dtype)
578
+ pipeline.save_pretrained(output_dir, safe_serialization=use_safetensors)
external/llite/library/sdxl_original_unet.py ADDED
@@ -0,0 +1,1281 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Diffusersのコードをベースとした sd_xl_baseのU-Net
2
+ # state dictの形式をSDXLに合わせてある
3
+
4
+ """
5
+ target: sgm.modules.diffusionmodules.openaimodel.UNetModel
6
+ params:
7
+ adm_in_channels: 2816
8
+ num_classes: sequential
9
+ use_checkpoint: True
10
+ in_channels: 4
11
+ out_channels: 4
12
+ model_channels: 320
13
+ attention_resolutions: [4, 2]
14
+ num_res_blocks: 2
15
+ channel_mult: [1, 2, 4]
16
+ num_head_channels: 64
17
+ use_spatial_transformer: True
18
+ use_linear_in_transformer: True
19
+ transformer_depth: [1, 2, 10] # note: the first is unused (due to attn_res starting at 2) 32, 16, 8 --> 64, 32, 16
20
+ context_dim: 2048
21
+ spatial_transformer_attn_type: softmax-xformers
22
+ legacy: False
23
+ """
24
+
25
+ import math
26
+ from types import SimpleNamespace
27
+ from typing import Any, Optional
28
+ import torch
29
+ import torch.utils.checkpoint
30
+ from torch import nn
31
+ from torch.nn import functional as F
32
+ from einops import rearrange
33
+
34
+
35
+ IN_CHANNELS: int = 4
36
+ OUT_CHANNELS: int = 4
37
+ ADM_IN_CHANNELS: int = 2816
38
+ CONTEXT_DIM: int = 2048
39
+ MODEL_CHANNELS: int = 320
40
+ TIME_EMBED_DIM = 320 * 4
41
+
42
+ USE_REENTRANT = True
43
+
44
+ # region memory efficient attention
45
+
46
+ # FlashAttentionを使うCrossAttention
47
+ # based on https://github.com/lucidrains/memory-efficient-attention-pytorch/blob/main/memory_efficient_attention_pytorch/flash_attention.py
48
+ # LICENSE MIT https://github.com/lucidrains/memory-efficient-attention-pytorch/blob/main/LICENSE
49
+
50
+ # constants
51
+
52
+ EPSILON = 1e-6
53
+
54
+ # helper functions
55
+
56
+
57
+ def exists(val):
58
+ return val is not None
59
+
60
+
61
+ def default(val, d):
62
+ return val if exists(val) else d
63
+
64
+
65
+ # flash attention forwards and backwards
66
+
67
+ # https://arxiv.org/abs/2205.14135
68
+
69
+
70
+ class FlashAttentionFunction(torch.autograd.Function):
71
+ @staticmethod
72
+ @torch.no_grad()
73
+ def forward(ctx, q, k, v, mask, causal, q_bucket_size, k_bucket_size):
74
+ """Algorithm 2 in the paper"""
75
+
76
+ device = q.device
77
+ dtype = q.dtype
78
+ max_neg_value = -torch.finfo(q.dtype).max
79
+ qk_len_diff = max(k.shape[-2] - q.shape[-2], 0)
80
+
81
+ o = torch.zeros_like(q)
82
+ all_row_sums = torch.zeros((*q.shape[:-1], 1), dtype=dtype, device=device)
83
+ all_row_maxes = torch.full((*q.shape[:-1], 1), max_neg_value, dtype=dtype, device=device)
84
+
85
+ scale = q.shape[-1] ** -0.5
86
+
87
+ if not exists(mask):
88
+ mask = (None,) * math.ceil(q.shape[-2] / q_bucket_size)
89
+ else:
90
+ mask = rearrange(mask, "b n -> b 1 1 n")
91
+ mask = mask.split(q_bucket_size, dim=-1)
92
+
93
+ row_splits = zip(
94
+ q.split(q_bucket_size, dim=-2),
95
+ o.split(q_bucket_size, dim=-2),
96
+ mask,
97
+ all_row_sums.split(q_bucket_size, dim=-2),
98
+ all_row_maxes.split(q_bucket_size, dim=-2),
99
+ )
100
+
101
+ for ind, (qc, oc, row_mask, row_sums, row_maxes) in enumerate(row_splits):
102
+ q_start_index = ind * q_bucket_size - qk_len_diff
103
+
104
+ col_splits = zip(
105
+ k.split(k_bucket_size, dim=-2),
106
+ v.split(k_bucket_size, dim=-2),
107
+ )
108
+
109
+ for k_ind, (kc, vc) in enumerate(col_splits):
110
+ k_start_index = k_ind * k_bucket_size
111
+
112
+ attn_weights = torch.einsum("... i d, ... j d -> ... i j", qc, kc) * scale
113
+
114
+ if exists(row_mask):
115
+ attn_weights.masked_fill_(~row_mask, max_neg_value)
116
+
117
+ if causal and q_start_index < (k_start_index + k_bucket_size - 1):
118
+ causal_mask = torch.ones((qc.shape[-2], kc.shape[-2]), dtype=torch.bool, device=device).triu(
119
+ q_start_index - k_start_index + 1
120
+ )
121
+ attn_weights.masked_fill_(causal_mask, max_neg_value)
122
+
123
+ block_row_maxes = attn_weights.amax(dim=-1, keepdims=True)
124
+ attn_weights -= block_row_maxes
125
+ exp_weights = torch.exp(attn_weights)
126
+
127
+ if exists(row_mask):
128
+ exp_weights.masked_fill_(~row_mask, 0.0)
129
+
130
+ block_row_sums = exp_weights.sum(dim=-1, keepdims=True).clamp(min=EPSILON)
131
+
132
+ new_row_maxes = torch.maximum(block_row_maxes, row_maxes)
133
+
134
+ exp_values = torch.einsum("... i j, ... j d -> ... i d", exp_weights, vc)
135
+
136
+ exp_row_max_diff = torch.exp(row_maxes - new_row_maxes)
137
+ exp_block_row_max_diff = torch.exp(block_row_maxes - new_row_maxes)
138
+
139
+ new_row_sums = exp_row_max_diff * row_sums + exp_block_row_max_diff * block_row_sums
140
+
141
+ oc.mul_((row_sums / new_row_sums) * exp_row_max_diff).add_((exp_block_row_max_diff / new_row_sums) * exp_values)
142
+
143
+ row_maxes.copy_(new_row_maxes)
144
+ row_sums.copy_(new_row_sums)
145
+
146
+ ctx.args = (causal, scale, mask, q_bucket_size, k_bucket_size)
147
+ ctx.save_for_backward(q, k, v, o, all_row_sums, all_row_maxes)
148
+
149
+ return o
150
+
151
+ @staticmethod
152
+ @torch.no_grad()
153
+ def backward(ctx, do):
154
+ """Algorithm 4 in the paper"""
155
+
156
+ causal, scale, mask, q_bucket_size, k_bucket_size = ctx.args
157
+ q, k, v, o, l, m = ctx.saved_tensors
158
+
159
+ device = q.device
160
+
161
+ max_neg_value = -torch.finfo(q.dtype).max
162
+ qk_len_diff = max(k.shape[-2] - q.shape[-2], 0)
163
+
164
+ dq = torch.zeros_like(q)
165
+ dk = torch.zeros_like(k)
166
+ dv = torch.zeros_like(v)
167
+
168
+ row_splits = zip(
169
+ q.split(q_bucket_size, dim=-2),
170
+ o.split(q_bucket_size, dim=-2),
171
+ do.split(q_bucket_size, dim=-2),
172
+ mask,
173
+ l.split(q_bucket_size, dim=-2),
174
+ m.split(q_bucket_size, dim=-2),
175
+ dq.split(q_bucket_size, dim=-2),
176
+ )
177
+
178
+ for ind, (qc, oc, doc, row_mask, lc, mc, dqc) in enumerate(row_splits):
179
+ q_start_index = ind * q_bucket_size - qk_len_diff
180
+
181
+ col_splits = zip(
182
+ k.split(k_bucket_size, dim=-2),
183
+ v.split(k_bucket_size, dim=-2),
184
+ dk.split(k_bucket_size, dim=-2),
185
+ dv.split(k_bucket_size, dim=-2),
186
+ )
187
+
188
+ for k_ind, (kc, vc, dkc, dvc) in enumerate(col_splits):
189
+ k_start_index = k_ind * k_bucket_size
190
+
191
+ attn_weights = torch.einsum("... i d, ... j d -> ... i j", qc, kc) * scale
192
+
193
+ if causal and q_start_index < (k_start_index + k_bucket_size - 1):
194
+ causal_mask = torch.ones((qc.shape[-2], kc.shape[-2]), dtype=torch.bool, device=device).triu(
195
+ q_start_index - k_start_index + 1
196
+ )
197
+ attn_weights.masked_fill_(causal_mask, max_neg_value)
198
+
199
+ exp_attn_weights = torch.exp(attn_weights - mc)
200
+
201
+ if exists(row_mask):
202
+ exp_attn_weights.masked_fill_(~row_mask, 0.0)
203
+
204
+ p = exp_attn_weights / lc
205
+
206
+ dv_chunk = torch.einsum("... i j, ... i d -> ... j d", p, doc)
207
+ dp = torch.einsum("... i d, ... j d -> ... i j", doc, vc)
208
+
209
+ D = (doc * oc).sum(dim=-1, keepdims=True)
210
+ ds = p * scale * (dp - D)
211
+
212
+ dq_chunk = torch.einsum("... i j, ... j d -> ... i d", ds, kc)
213
+ dk_chunk = torch.einsum("... i j, ... i d -> ... j d", ds, qc)
214
+
215
+ dqc.add_(dq_chunk)
216
+ dkc.add_(dk_chunk)
217
+ dvc.add_(dv_chunk)
218
+
219
+ return dq, dk, dv, None, None, None, None
220
+
221
+
222
+ # endregion
223
+
224
+
225
+ def get_parameter_dtype(parameter: torch.nn.Module):
226
+ return next(parameter.parameters()).dtype
227
+
228
+
229
+ def get_parameter_device(parameter: torch.nn.Module):
230
+ return next(parameter.parameters()).device
231
+
232
+
233
+ def get_timestep_embedding(
234
+ timesteps: torch.Tensor,
235
+ embedding_dim: int,
236
+ downscale_freq_shift: float = 1,
237
+ scale: float = 1,
238
+ max_period: int = 10000,
239
+ ):
240
+ """
241
+ This matches the implementation in Denoising Diffusion Probabilistic Models: Create sinusoidal timestep embeddings.
242
+
243
+ :param timesteps: a 1-D Tensor of N indices, one per batch element.
244
+ These may be fractional.
245
+ :param embedding_dim: the dimension of the output. :param max_period: controls the minimum frequency of the
246
+ embeddings. :return: an [N x dim] Tensor of positional embeddings.
247
+ """
248
+ assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array"
249
+
250
+ half_dim = embedding_dim // 2
251
+ exponent = -math.log(max_period) * torch.arange(start=0, end=half_dim, dtype=torch.float32, device=timesteps.device)
252
+ exponent = exponent / (half_dim - downscale_freq_shift)
253
+
254
+ emb = torch.exp(exponent)
255
+ emb = timesteps[:, None].float() * emb[None, :]
256
+
257
+ # scale embeddings
258
+ emb = scale * emb
259
+
260
+ # concat sine and cosine embeddings: flipped from Diffusers original ver because always flip_sin_to_cos=True
261
+ emb = torch.cat([torch.cos(emb), torch.sin(emb)], dim=-1)
262
+
263
+ # zero pad
264
+ if embedding_dim % 2 == 1:
265
+ emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
266
+ return emb
267
+
268
+
269
+ # Deep Shrink: We do not common this function, because minimize dependencies.
270
+ def resize_like(x, target, mode="bicubic", align_corners=False):
271
+ org_dtype = x.dtype
272
+ if org_dtype == torch.bfloat16:
273
+ x = x.to(torch.float32)
274
+
275
+ if x.shape[-2:] != target.shape[-2:]:
276
+ if mode == "nearest":
277
+ x = F.interpolate(x, size=target.shape[-2:], mode=mode)
278
+ else:
279
+ x = F.interpolate(x, size=target.shape[-2:], mode=mode, align_corners=align_corners)
280
+
281
+ if org_dtype == torch.bfloat16:
282
+ x = x.to(org_dtype)
283
+ return x
284
+
285
+
286
+ class GroupNorm32(nn.GroupNorm):
287
+ def forward(self, x):
288
+ if self.weight.dtype != torch.float32:
289
+ return super().forward(x)
290
+ return super().forward(x.float()).type(x.dtype)
291
+
292
+
293
+ class ResnetBlock2D(nn.Module):
294
+ def __init__(
295
+ self,
296
+ in_channels,
297
+ out_channels,
298
+ ):
299
+ super().__init__()
300
+ self.in_channels = in_channels
301
+ self.out_channels = out_channels
302
+
303
+ self.in_layers = nn.Sequential(
304
+ GroupNorm32(32, in_channels),
305
+ nn.SiLU(),
306
+ nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1),
307
+ )
308
+
309
+ self.emb_layers = nn.Sequential(nn.SiLU(), nn.Linear(TIME_EMBED_DIM, out_channels))
310
+
311
+ self.out_layers = nn.Sequential(
312
+ GroupNorm32(32, out_channels),
313
+ nn.SiLU(),
314
+ nn.Identity(), # to make state_dict compatible with original model
315
+ nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1),
316
+ )
317
+
318
+ if in_channels != out_channels:
319
+ self.skip_connection = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
320
+ else:
321
+ self.skip_connection = nn.Identity()
322
+
323
+ self.gradient_checkpointing = False
324
+
325
+ def forward_body(self, x, emb):
326
+ h = self.in_layers(x)
327
+ emb_out = self.emb_layers(emb).type(h.dtype)
328
+ h = h + emb_out[:, :, None, None]
329
+ h = self.out_layers(h)
330
+ x = self.skip_connection(x)
331
+ return x + h
332
+
333
+ def forward(self, x, emb):
334
+ if self.training and self.gradient_checkpointing:
335
+ # print("ResnetBlock2D: gradient_checkpointing")
336
+
337
+ def create_custom_forward(func):
338
+ def custom_forward(*inputs):
339
+ return func(*inputs)
340
+
341
+ return custom_forward
342
+
343
+ x = torch.utils.checkpoint.checkpoint(create_custom_forward(self.forward_body), x, emb, use_reentrant=USE_REENTRANT)
344
+ else:
345
+ x = self.forward_body(x, emb)
346
+
347
+ return x
348
+
349
+
350
+ class Downsample2D(nn.Module):
351
+ def __init__(self, channels, out_channels):
352
+ super().__init__()
353
+
354
+ self.channels = channels
355
+ self.out_channels = out_channels
356
+
357
+ self.op = nn.Conv2d(self.channels, self.out_channels, 3, stride=2, padding=1)
358
+
359
+ self.gradient_checkpointing = False
360
+
361
+ def forward_body(self, hidden_states):
362
+ assert hidden_states.shape[1] == self.channels
363
+ hidden_states = self.op(hidden_states)
364
+
365
+ return hidden_states
366
+
367
+ def forward(self, hidden_states):
368
+ if self.training and self.gradient_checkpointing:
369
+ # print("Downsample2D: gradient_checkpointing")
370
+
371
+ def create_custom_forward(func):
372
+ def custom_forward(*inputs):
373
+ return func(*inputs)
374
+
375
+ return custom_forward
376
+
377
+ hidden_states = torch.utils.checkpoint.checkpoint(
378
+ create_custom_forward(self.forward_body), hidden_states, use_reentrant=USE_REENTRANT
379
+ )
380
+ else:
381
+ hidden_states = self.forward_body(hidden_states)
382
+
383
+ return hidden_states
384
+
385
+
386
+ class CrossAttention(nn.Module):
387
+ def __init__(
388
+ self,
389
+ query_dim: int,
390
+ cross_attention_dim: Optional[int] = None,
391
+ heads: int = 8,
392
+ dim_head: int = 64,
393
+ upcast_attention: bool = False,
394
+ ):
395
+ super().__init__()
396
+ inner_dim = dim_head * heads
397
+ cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim
398
+ self.upcast_attention = upcast_attention
399
+
400
+ self.scale = dim_head**-0.5
401
+ self.heads = heads
402
+
403
+ self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
404
+ self.to_k = nn.Linear(cross_attention_dim, inner_dim, bias=False)
405
+ self.to_v = nn.Linear(cross_attention_dim, inner_dim, bias=False)
406
+
407
+ self.to_out = nn.ModuleList([])
408
+ self.to_out.append(nn.Linear(inner_dim, query_dim))
409
+ # no dropout here
410
+
411
+ self.use_memory_efficient_attention_xformers = False
412
+ self.use_memory_efficient_attention_mem_eff = False
413
+ self.use_sdpa = False
414
+
415
+ def set_use_memory_efficient_attention(self, xformers, mem_eff):
416
+ self.use_memory_efficient_attention_xformers = xformers
417
+ self.use_memory_efficient_attention_mem_eff = mem_eff
418
+
419
+ def set_use_sdpa(self, sdpa):
420
+ self.use_sdpa = sdpa
421
+
422
+ def reshape_heads_to_batch_dim(self, tensor):
423
+ batch_size, seq_len, dim = tensor.shape
424
+ head_size = self.heads
425
+ tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
426
+ tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size * head_size, seq_len, dim // head_size)
427
+ return tensor
428
+
429
+ def reshape_batch_dim_to_heads(self, tensor):
430
+ batch_size, seq_len, dim = tensor.shape
431
+ head_size = self.heads
432
+ tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
433
+ tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size)
434
+ return tensor
435
+
436
+ def forward(self, hidden_states, context=None, mask=None):
437
+ if self.use_memory_efficient_attention_xformers:
438
+ return self.forward_memory_efficient_xformers(hidden_states, context, mask)
439
+ if self.use_memory_efficient_attention_mem_eff:
440
+ return self.forward_memory_efficient_mem_eff(hidden_states, context, mask)
441
+ if self.use_sdpa:
442
+ return self.forward_sdpa(hidden_states, context, mask)
443
+
444
+ query = self.to_q(hidden_states)
445
+ context = context if context is not None else hidden_states
446
+ key = self.to_k(context)
447
+ value = self.to_v(context)
448
+
449
+ query = self.reshape_heads_to_batch_dim(query)
450
+ key = self.reshape_heads_to_batch_dim(key)
451
+ value = self.reshape_heads_to_batch_dim(value)
452
+
453
+ hidden_states = self._attention(query, key, value)
454
+
455
+ # linear proj
456
+ hidden_states = self.to_out[0](hidden_states)
457
+ # hidden_states = self.to_out[1](hidden_states) # no dropout
458
+ return hidden_states
459
+
460
+ def _attention(self, query, key, value):
461
+ if self.upcast_attention:
462
+ query = query.float()
463
+ key = key.float()
464
+
465
+ attention_scores = torch.baddbmm(
466
+ torch.empty(query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device),
467
+ query,
468
+ key.transpose(-1, -2),
469
+ beta=0,
470
+ alpha=self.scale,
471
+ )
472
+ attention_probs = attention_scores.softmax(dim=-1)
473
+
474
+ # cast back to the original dtype
475
+ attention_probs = attention_probs.to(value.dtype)
476
+
477
+ # compute attention output
478
+ hidden_states = torch.bmm(attention_probs, value)
479
+
480
+ # reshape hidden_states
481
+ hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
482
+ return hidden_states
483
+
484
+ # TODO support Hypernetworks
485
+ def forward_memory_efficient_xformers(self, x, context=None, mask=None):
486
+ import xformers.ops
487
+
488
+ h = self.heads
489
+ q_in = self.to_q(x)
490
+ context = context if context is not None else x
491
+ context = context.to(x.dtype)
492
+ k_in = self.to_k(context)
493
+ v_in = self.to_v(context)
494
+
495
+ q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b n h d", h=h), (q_in, k_in, v_in))
496
+ del q_in, k_in, v_in
497
+
498
+ q = q.contiguous()
499
+ k = k.contiguous()
500
+ v = v.contiguous()
501
+ out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None) # 最適なのを選んでくれる
502
+ del q, k, v
503
+
504
+ out = rearrange(out, "b n h d -> b n (h d)", h=h)
505
+
506
+ out = self.to_out[0](out)
507
+ return out
508
+
509
+ def forward_memory_efficient_mem_eff(self, x, context=None, mask=None):
510
+ flash_func = FlashAttentionFunction
511
+
512
+ q_bucket_size = 512
513
+ k_bucket_size = 1024
514
+
515
+ h = self.heads
516
+ q = self.to_q(x)
517
+ context = context if context is not None else x
518
+ context = context.to(x.dtype)
519
+ k = self.to_k(context)
520
+ v = self.to_v(context)
521
+ del context, x
522
+
523
+ q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), (q, k, v))
524
+
525
+ out = flash_func.apply(q, k, v, mask, False, q_bucket_size, k_bucket_size)
526
+
527
+ out = rearrange(out, "b h n d -> b n (h d)")
528
+
529
+ out = self.to_out[0](out)
530
+ return out
531
+
532
+ def forward_sdpa(self, x, context=None, mask=None):
533
+ h = self.heads
534
+ q_in = self.to_q(x)
535
+ context = context if context is not None else x
536
+ context = context.to(x.dtype)
537
+ k_in = self.to_k(context)
538
+ v_in = self.to_v(context)
539
+
540
+ q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), (q_in, k_in, v_in))
541
+ del q_in, k_in, v_in
542
+
543
+ out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False)
544
+
545
+ out = rearrange(out, "b h n d -> b n (h d)", h=h)
546
+
547
+ out = self.to_out[0](out)
548
+ return out
549
+
550
+
551
+ # feedforward
552
+ class GEGLU(nn.Module):
553
+ r"""
554
+ A variant of the gated linear unit activation function from https://arxiv.org/abs/2002.05202.
555
+
556
+ Parameters:
557
+ dim_in (`int`): The number of channels in the input.
558
+ dim_out (`int`): The number of channels in the output.
559
+ """
560
+
561
+ def __init__(self, dim_in: int, dim_out: int):
562
+ super().__init__()
563
+ self.proj = nn.Linear(dim_in, dim_out * 2)
564
+
565
+ def gelu(self, gate):
566
+ if gate.device.type != "mps":
567
+ return F.gelu(gate)
568
+ # mps: gelu is not implemented for float16
569
+ return F.gelu(gate.to(dtype=torch.float32)).to(dtype=gate.dtype)
570
+
571
+ def forward(self, hidden_states):
572
+ hidden_states, gate = self.proj(hidden_states).chunk(2, dim=-1)
573
+ return hidden_states * self.gelu(gate)
574
+
575
+
576
+ class FeedForward(nn.Module):
577
+ def __init__(
578
+ self,
579
+ dim: int,
580
+ ):
581
+ super().__init__()
582
+ inner_dim = int(dim * 4) # mult is always 4
583
+
584
+ self.net = nn.ModuleList([])
585
+ # project in
586
+ self.net.append(GEGLU(dim, inner_dim))
587
+ # project dropout
588
+ self.net.append(nn.Identity()) # nn.Dropout(0)) # dummy for dropout with 0
589
+ # project out
590
+ self.net.append(nn.Linear(inner_dim, dim))
591
+
592
+ def forward(self, hidden_states):
593
+ for module in self.net:
594
+ hidden_states = module(hidden_states)
595
+ return hidden_states
596
+
597
+
598
+ class BasicTransformerBlock(nn.Module):
599
+ def __init__(
600
+ self, dim: int, num_attention_heads: int, attention_head_dim: int, cross_attention_dim: int, upcast_attention: bool = False
601
+ ):
602
+ super().__init__()
603
+
604
+ self.gradient_checkpointing = False
605
+
606
+ # 1. Self-Attn
607
+ self.attn1 = CrossAttention(
608
+ query_dim=dim,
609
+ cross_attention_dim=None,
610
+ heads=num_attention_heads,
611
+ dim_head=attention_head_dim,
612
+ upcast_attention=upcast_attention,
613
+ )
614
+ self.ff = FeedForward(dim)
615
+
616
+ # 2. Cross-Attn
617
+ self.attn2 = CrossAttention(
618
+ query_dim=dim,
619
+ cross_attention_dim=cross_attention_dim,
620
+ heads=num_attention_heads,
621
+ dim_head=attention_head_dim,
622
+ upcast_attention=upcast_attention,
623
+ )
624
+
625
+ self.norm1 = nn.LayerNorm(dim)
626
+ self.norm2 = nn.LayerNorm(dim)
627
+
628
+ # 3. Feed-forward
629
+ self.norm3 = nn.LayerNorm(dim)
630
+
631
+ def set_use_memory_efficient_attention(self, xformers: bool, mem_eff: bool):
632
+ self.attn1.set_use_memory_efficient_attention(xformers, mem_eff)
633
+ self.attn2.set_use_memory_efficient_attention(xformers, mem_eff)
634
+
635
+ def set_use_sdpa(self, sdpa: bool):
636
+ self.attn1.set_use_sdpa(sdpa)
637
+ self.attn2.set_use_sdpa(sdpa)
638
+
639
+ def forward_body(self, hidden_states, context=None, timestep=None):
640
+ # 1. Self-Attention
641
+ norm_hidden_states = self.norm1(hidden_states)
642
+
643
+ hidden_states = self.attn1(norm_hidden_states) + hidden_states
644
+
645
+ # 2. Cross-Attention
646
+ norm_hidden_states = self.norm2(hidden_states)
647
+ hidden_states = self.attn2(norm_hidden_states, context=context) + hidden_states
648
+
649
+ # 3. Feed-forward
650
+ hidden_states = self.ff(self.norm3(hidden_states)) + hidden_states
651
+
652
+ return hidden_states
653
+
654
+ def forward(self, hidden_states, context=None, timestep=None):
655
+ if self.training and self.gradient_checkpointing:
656
+ # print("BasicTransformerBlock: checkpointing")
657
+
658
+ def create_custom_forward(func):
659
+ def custom_forward(*inputs):
660
+ return func(*inputs)
661
+
662
+ return custom_forward
663
+
664
+ output = torch.utils.checkpoint.checkpoint(
665
+ create_custom_forward(self.forward_body), hidden_states, context, timestep, use_reentrant=USE_REENTRANT
666
+ )
667
+ else:
668
+ output = self.forward_body(hidden_states, context, timestep)
669
+
670
+ return output
671
+
672
+
673
+ class Transformer2DModel(nn.Module):
674
+ def __init__(
675
+ self,
676
+ num_attention_heads: int = 16,
677
+ attention_head_dim: int = 88,
678
+ in_channels: Optional[int] = None,
679
+ cross_attention_dim: Optional[int] = None,
680
+ use_linear_projection: bool = False,
681
+ upcast_attention: bool = False,
682
+ num_transformer_layers: int = 1,
683
+ ):
684
+ super().__init__()
685
+ self.in_channels = in_channels
686
+ self.num_attention_heads = num_attention_heads
687
+ self.attention_head_dim = attention_head_dim
688
+ inner_dim = num_attention_heads * attention_head_dim
689
+ self.use_linear_projection = use_linear_projection
690
+
691
+ self.norm = torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
692
+ # self.norm = GroupNorm32(32, in_channels, eps=1e-6, affine=True)
693
+
694
+ if use_linear_projection:
695
+ self.proj_in = nn.Linear(in_channels, inner_dim)
696
+ else:
697
+ self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
698
+
699
+ blocks = []
700
+ for _ in range(num_transformer_layers):
701
+ blocks.append(
702
+ BasicTransformerBlock(
703
+ inner_dim,
704
+ num_attention_heads,
705
+ attention_head_dim,
706
+ cross_attention_dim=cross_attention_dim,
707
+ upcast_attention=upcast_attention,
708
+ )
709
+ )
710
+
711
+ self.transformer_blocks = nn.ModuleList(blocks)
712
+
713
+ if use_linear_projection:
714
+ self.proj_out = nn.Linear(in_channels, inner_dim)
715
+ else:
716
+ self.proj_out = nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
717
+
718
+ self.gradient_checkpointing = False
719
+
720
+ def set_use_memory_efficient_attention(self, xformers, mem_eff):
721
+ for transformer in self.transformer_blocks:
722
+ transformer.set_use_memory_efficient_attention(xformers, mem_eff)
723
+
724
+ def set_use_sdpa(self, sdpa):
725
+ for transformer in self.transformer_blocks:
726
+ transformer.set_use_sdpa(sdpa)
727
+
728
+ def forward(self, hidden_states, encoder_hidden_states=None, timestep=None):
729
+ # 1. Input
730
+ batch, _, height, weight = hidden_states.shape
731
+ residual = hidden_states
732
+
733
+ hidden_states = self.norm(hidden_states)
734
+ if not self.use_linear_projection:
735
+ hidden_states = self.proj_in(hidden_states)
736
+ inner_dim = hidden_states.shape[1]
737
+ hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim)
738
+ else:
739
+ inner_dim = hidden_states.shape[1]
740
+ hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim)
741
+ hidden_states = self.proj_in(hidden_states)
742
+
743
+ # 2. Blocks
744
+ for block in self.transformer_blocks:
745
+ hidden_states = block(hidden_states, context=encoder_hidden_states, timestep=timestep)
746
+
747
+ # 3. Output
748
+ if not self.use_linear_projection:
749
+ hidden_states = hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2).contiguous()
750
+ hidden_states = self.proj_out(hidden_states)
751
+ else:
752
+ hidden_states = self.proj_out(hidden_states)
753
+ hidden_states = hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2).contiguous()
754
+
755
+ output = hidden_states + residual
756
+
757
+ return output
758
+
759
+
760
+ class Upsample2D(nn.Module):
761
+ def __init__(self, channels, out_channels):
762
+ super().__init__()
763
+ self.channels = channels
764
+ self.out_channels = out_channels
765
+ self.conv = nn.Conv2d(self.channels, self.out_channels, 3, padding=1)
766
+
767
+ self.gradient_checkpointing = False
768
+
769
+ def forward_body(self, hidden_states, output_size=None):
770
+ assert hidden_states.shape[1] == self.channels
771
+
772
+ # Cast to float32 to as 'upsample_nearest2d_out_frame' op does not support bfloat16
773
+ # TODO(Suraj): Remove this cast once the issue is fixed in PyTorch
774
+ # https://github.com/pytorch/pytorch/issues/86679
775
+ dtype = hidden_states.dtype
776
+ if dtype == torch.bfloat16:
777
+ hidden_states = hidden_states.to(torch.float32)
778
+
779
+ # upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984
780
+ if hidden_states.shape[0] >= 64:
781
+ hidden_states = hidden_states.contiguous()
782
+
783
+ # if `output_size` is passed we force the interpolation output size and do not make use of `scale_factor=2`
784
+ if output_size is None:
785
+ hidden_states = F.interpolate(hidden_states, scale_factor=2.0, mode="nearest")
786
+ else:
787
+ hidden_states = F.interpolate(hidden_states, size=output_size, mode="nearest")
788
+
789
+ # If the input is bfloat16, we cast back to bfloat16
790
+ if dtype == torch.bfloat16:
791
+ hidden_states = hidden_states.to(dtype)
792
+
793
+ hidden_states = self.conv(hidden_states)
794
+
795
+ return hidden_states
796
+
797
+ def forward(self, hidden_states, output_size=None):
798
+ if self.training and self.gradient_checkpointing:
799
+ # print("Upsample2D: gradient_checkpointing")
800
+
801
+ def create_custom_forward(func):
802
+ def custom_forward(*inputs):
803
+ return func(*inputs)
804
+
805
+ return custom_forward
806
+
807
+ hidden_states = torch.utils.checkpoint.checkpoint(
808
+ create_custom_forward(self.forward_body), hidden_states, output_size, use_reentrant=USE_REENTRANT
809
+ )
810
+ else:
811
+ hidden_states = self.forward_body(hidden_states, output_size)
812
+
813
+ return hidden_states
814
+
815
+
816
+ class SdxlUNet2DConditionModel(nn.Module):
817
+ _supports_gradient_checkpointing = True
818
+
819
+ def __init__(
820
+ self,
821
+ **kwargs,
822
+ ):
823
+ super().__init__()
824
+
825
+ self.in_channels = IN_CHANNELS
826
+ self.out_channels = OUT_CHANNELS
827
+ self.model_channels = MODEL_CHANNELS
828
+ self.time_embed_dim = TIME_EMBED_DIM
829
+ self.adm_in_channels = ADM_IN_CHANNELS
830
+
831
+ self.gradient_checkpointing = False
832
+ # self.sample_size = sample_size
833
+
834
+ # time embedding
835
+ self.time_embed = nn.Sequential(
836
+ nn.Linear(self.model_channels, self.time_embed_dim),
837
+ nn.SiLU(),
838
+ nn.Linear(self.time_embed_dim, self.time_embed_dim),
839
+ )
840
+
841
+ # label embedding
842
+ self.label_emb = nn.Sequential(
843
+ nn.Sequential(
844
+ nn.Linear(self.adm_in_channels, self.time_embed_dim),
845
+ nn.SiLU(),
846
+ nn.Linear(self.time_embed_dim, self.time_embed_dim),
847
+ )
848
+ )
849
+
850
+ # input
851
+ self.input_blocks = nn.ModuleList(
852
+ [
853
+ nn.Sequential(
854
+ nn.Conv2d(self.in_channels, self.model_channels, kernel_size=3, padding=(1, 1)),
855
+ )
856
+ ]
857
+ )
858
+
859
+ # level 0
860
+ for i in range(2):
861
+ layers = [
862
+ ResnetBlock2D(
863
+ in_channels=1 * self.model_channels,
864
+ out_channels=1 * self.model_channels,
865
+ ),
866
+ ]
867
+ self.input_blocks.append(nn.ModuleList(layers))
868
+
869
+ self.input_blocks.append(
870
+ nn.Sequential(
871
+ Downsample2D(
872
+ channels=1 * self.model_channels,
873
+ out_channels=1 * self.model_channels,
874
+ ),
875
+ )
876
+ )
877
+
878
+ # level 1
879
+ for i in range(2):
880
+ layers = [
881
+ ResnetBlock2D(
882
+ in_channels=(1 if i == 0 else 2) * self.model_channels,
883
+ out_channels=2 * self.model_channels,
884
+ ),
885
+ Transformer2DModel(
886
+ num_attention_heads=2 * self.model_channels // 64,
887
+ attention_head_dim=64,
888
+ in_channels=2 * self.model_channels,
889
+ num_transformer_layers=2,
890
+ use_linear_projection=True,
891
+ cross_attention_dim=2048,
892
+ ),
893
+ ]
894
+ self.input_blocks.append(nn.ModuleList(layers))
895
+
896
+ self.input_blocks.append(
897
+ nn.Sequential(
898
+ Downsample2D(
899
+ channels=2 * self.model_channels,
900
+ out_channels=2 * self.model_channels,
901
+ ),
902
+ )
903
+ )
904
+
905
+ # level 2
906
+ for i in range(2):
907
+ layers = [
908
+ ResnetBlock2D(
909
+ in_channels=(2 if i == 0 else 4) * self.model_channels,
910
+ out_channels=4 * self.model_channels,
911
+ ),
912
+ Transformer2DModel(
913
+ num_attention_heads=4 * self.model_channels // 64,
914
+ attention_head_dim=64,
915
+ in_channels=4 * self.model_channels,
916
+ num_transformer_layers=10,
917
+ use_linear_projection=True,
918
+ cross_attention_dim=2048,
919
+ ),
920
+ ]
921
+ self.input_blocks.append(nn.ModuleList(layers))
922
+
923
+ # mid
924
+ self.middle_block = nn.ModuleList(
925
+ [
926
+ ResnetBlock2D(
927
+ in_channels=4 * self.model_channels,
928
+ out_channels=4 * self.model_channels,
929
+ ),
930
+ Transformer2DModel(
931
+ num_attention_heads=4 * self.model_channels // 64,
932
+ attention_head_dim=64,
933
+ in_channels=4 * self.model_channels,
934
+ num_transformer_layers=10,
935
+ use_linear_projection=True,
936
+ cross_attention_dim=2048,
937
+ ),
938
+ ResnetBlock2D(
939
+ in_channels=4 * self.model_channels,
940
+ out_channels=4 * self.model_channels,
941
+ ),
942
+ ]
943
+ )
944
+
945
+ # output
946
+ self.output_blocks = nn.ModuleList([])
947
+
948
+ # level 2
949
+ for i in range(3):
950
+ layers = [
951
+ ResnetBlock2D(
952
+ in_channels=4 * self.model_channels + (4 if i <= 1 else 2) * self.model_channels,
953
+ out_channels=4 * self.model_channels,
954
+ ),
955
+ Transformer2DModel(
956
+ num_attention_heads=4 * self.model_channels // 64,
957
+ attention_head_dim=64,
958
+ in_channels=4 * self.model_channels,
959
+ num_transformer_layers=10,
960
+ use_linear_projection=True,
961
+ cross_attention_dim=2048,
962
+ ),
963
+ ]
964
+ if i == 2:
965
+ layers.append(
966
+ Upsample2D(
967
+ channels=4 * self.model_channels,
968
+ out_channels=4 * self.model_channels,
969
+ )
970
+ )
971
+
972
+ self.output_blocks.append(nn.ModuleList(layers))
973
+
974
+ # level 1
975
+ for i in range(3):
976
+ layers = [
977
+ ResnetBlock2D(
978
+ in_channels=2 * self.model_channels + (4 if i == 0 else (2 if i == 1 else 1)) * self.model_channels,
979
+ out_channels=2 * self.model_channels,
980
+ ),
981
+ Transformer2DModel(
982
+ num_attention_heads=2 * self.model_channels // 64,
983
+ attention_head_dim=64,
984
+ in_channels=2 * self.model_channels,
985
+ num_transformer_layers=2,
986
+ use_linear_projection=True,
987
+ cross_attention_dim=2048,
988
+ ),
989
+ ]
990
+ if i == 2:
991
+ layers.append(
992
+ Upsample2D(
993
+ channels=2 * self.model_channels,
994
+ out_channels=2 * self.model_channels,
995
+ )
996
+ )
997
+
998
+ self.output_blocks.append(nn.ModuleList(layers))
999
+
1000
+ # level 0
1001
+ for i in range(3):
1002
+ layers = [
1003
+ ResnetBlock2D(
1004
+ in_channels=1 * self.model_channels + (2 if i == 0 else 1) * self.model_channels,
1005
+ out_channels=1 * self.model_channels,
1006
+ ),
1007
+ ]
1008
+
1009
+ self.output_blocks.append(nn.ModuleList(layers))
1010
+
1011
+ # output
1012
+ self.out = nn.ModuleList(
1013
+ [GroupNorm32(32, self.model_channels), nn.SiLU(), nn.Conv2d(self.model_channels, self.out_channels, 3, padding=1)]
1014
+ )
1015
+
1016
+ # region diffusers compatibility
1017
+ def prepare_config(self):
1018
+ self.config = SimpleNamespace()
1019
+
1020
+ @property
1021
+ def dtype(self) -> torch.dtype:
1022
+ # `torch.dtype`: The dtype of the module (assuming that all the module parameters have the same dtype).
1023
+ return get_parameter_dtype(self)
1024
+
1025
+ @property
1026
+ def device(self) -> torch.device:
1027
+ # `torch.device`: The device on which the module is (assuming that all the module parameters are on the same device).
1028
+ return get_parameter_device(self)
1029
+
1030
+ def set_attention_slice(self, slice_size):
1031
+ raise NotImplementedError("Attention slicing is not supported for this model.")
1032
+
1033
+ def is_gradient_checkpointing(self) -> bool:
1034
+ return any(hasattr(m, "gradient_checkpointing") and m.gradient_checkpointing for m in self.modules())
1035
+
1036
+ def enable_gradient_checkpointing(self):
1037
+ self.gradient_checkpointing = True
1038
+ self.set_gradient_checkpointing(value=True)
1039
+
1040
+ def disable_gradient_checkpointing(self):
1041
+ self.gradient_checkpointing = False
1042
+ self.set_gradient_checkpointing(value=False)
1043
+
1044
+ def set_use_memory_efficient_attention(self, xformers: bool, mem_eff: bool) -> None:
1045
+ blocks = self.input_blocks + [self.middle_block] + self.output_blocks
1046
+ for block in blocks:
1047
+ for module in block:
1048
+ if hasattr(module, "set_use_memory_efficient_attention"):
1049
+ # print(module.__class__.__name__)
1050
+ module.set_use_memory_efficient_attention(xformers, mem_eff)
1051
+
1052
+ def set_use_sdpa(self, sdpa: bool) -> None:
1053
+ blocks = self.input_blocks + [self.middle_block] + self.output_blocks
1054
+ for block in blocks:
1055
+ for module in block:
1056
+ if hasattr(module, "set_use_sdpa"):
1057
+ module.set_use_sdpa(sdpa)
1058
+
1059
+ def set_gradient_checkpointing(self, value=False):
1060
+ blocks = self.input_blocks + [self.middle_block] + self.output_blocks
1061
+ for block in blocks:
1062
+ for module in block.modules():
1063
+ if hasattr(module, "gradient_checkpointing"):
1064
+ # print(module.__class__.__name__, module.gradient_checkpointing, "->", value)
1065
+ module.gradient_checkpointing = value
1066
+
1067
+ # endregion
1068
+
1069
+ def forward(self, x, timesteps=None, context=None, y=None, **kwargs):
1070
+ # broadcast timesteps to batch dimension
1071
+ timesteps = timesteps.expand(x.shape[0])
1072
+
1073
+ hs = []
1074
+ t_emb = get_timestep_embedding(timesteps, self.model_channels) # , repeat_only=False)
1075
+ t_emb = t_emb.to(x.dtype)
1076
+ emb = self.time_embed(t_emb)
1077
+
1078
+ assert x.shape[0] == y.shape[0], f"batch size mismatch: {x.shape[0]} != {y.shape[0]}"
1079
+ assert x.dtype == y.dtype, f"dtype mismatch: {x.dtype} != {y.dtype}"
1080
+ # assert x.dtype == self.dtype
1081
+ emb = emb + self.label_emb(y)
1082
+
1083
+ def call_module(module, h, emb, context):
1084
+ x = h
1085
+ for layer in module:
1086
+ # print(layer.__class__.__name__, x.dtype, emb.dtype, context.dtype if context is not None else None)
1087
+ if isinstance(layer, ResnetBlock2D):
1088
+ x = layer(x, emb)
1089
+ elif isinstance(layer, Transformer2DModel):
1090
+ x = layer(x, context)
1091
+ else:
1092
+ x = layer(x)
1093
+ return x
1094
+
1095
+ # h = x.type(self.dtype)
1096
+ h = x
1097
+
1098
+ for module in self.input_blocks:
1099
+ h = call_module(module, h, emb, context)
1100
+ hs.append(h)
1101
+
1102
+ h = call_module(self.middle_block, h, emb, context)
1103
+
1104
+ for module in self.output_blocks:
1105
+ h = torch.cat([h, hs.pop()], dim=1)
1106
+ h = call_module(module, h, emb, context)
1107
+
1108
+ h = h.type(x.dtype)
1109
+ h = call_module(self.out, h, emb, context)
1110
+
1111
+ return h
1112
+
1113
+
1114
+ class InferSdxlUNet2DConditionModel:
1115
+ def __init__(self, original_unet: SdxlUNet2DConditionModel, **kwargs):
1116
+ self.delegate = original_unet
1117
+
1118
+ # override original model's forward method: because forward is not called by `__call__`
1119
+ # overriding `__call__` is not enough, because nn.Module.forward has a special handling
1120
+ self.delegate.forward = self.forward
1121
+
1122
+ # Deep Shrink
1123
+ self.ds_depth_1 = None
1124
+ self.ds_depth_2 = None
1125
+ self.ds_timesteps_1 = None
1126
+ self.ds_timesteps_2 = None
1127
+ self.ds_ratio = None
1128
+
1129
+ # call original model's methods
1130
+ def __getattr__(self, name):
1131
+ return getattr(self.delegate, name)
1132
+
1133
+ def __call__(self, *args, **kwargs):
1134
+ return self.delegate(*args, **kwargs)
1135
+
1136
+ def set_deep_shrink(self, ds_depth_1, ds_timesteps_1=650, ds_depth_2=None, ds_timesteps_2=None, ds_ratio=0.5):
1137
+ if ds_depth_1 is None:
1138
+ print("Deep Shrink is disabled.")
1139
+ self.ds_depth_1 = None
1140
+ self.ds_timesteps_1 = None
1141
+ self.ds_depth_2 = None
1142
+ self.ds_timesteps_2 = None
1143
+ self.ds_ratio = None
1144
+ else:
1145
+ print(
1146
+ f"Deep Shrink is enabled: [depth={ds_depth_1}/{ds_depth_2}, timesteps={ds_timesteps_1}/{ds_timesteps_2}, ratio={ds_ratio}]"
1147
+ )
1148
+ self.ds_depth_1 = ds_depth_1
1149
+ self.ds_timesteps_1 = ds_timesteps_1
1150
+ self.ds_depth_2 = ds_depth_2 if ds_depth_2 is not None else -1
1151
+ self.ds_timesteps_2 = ds_timesteps_2 if ds_timesteps_2 is not None else 1000
1152
+ self.ds_ratio = ds_ratio
1153
+
1154
+ def forward(self, x, timesteps=None, context=None, y=None, **kwargs):
1155
+ r"""
1156
+ current implementation is a copy of `SdxlUNet2DConditionModel.forward()` with Deep Shrink.
1157
+ """
1158
+ _self = self.delegate
1159
+
1160
+ # broadcast timesteps to batch dimension
1161
+ timesteps = timesteps.expand(x.shape[0])
1162
+
1163
+ hs = []
1164
+ t_emb = get_timestep_embedding(timesteps, _self.model_channels) # , repeat_only=False)
1165
+ t_emb = t_emb.to(x.dtype)
1166
+ emb = _self.time_embed(t_emb)
1167
+
1168
+ assert x.shape[0] == y.shape[0], f"batch size mismatch: {x.shape[0]} != {y.shape[0]}"
1169
+ assert x.dtype == y.dtype, f"dtype mismatch: {x.dtype} != {y.dtype}"
1170
+ # assert x.dtype == _self.dtype
1171
+ emb = emb + _self.label_emb(y)
1172
+
1173
+ def call_module(module, h, emb, context):
1174
+ x = h
1175
+ for layer in module:
1176
+ # print(layer.__class__.__name__, x.dtype, emb.dtype, context.dtype if context is not None else None)
1177
+ if isinstance(layer, ResnetBlock2D):
1178
+ x = layer(x, emb)
1179
+ elif isinstance(layer, Transformer2DModel):
1180
+ x = layer(x, context)
1181
+ else:
1182
+ x = layer(x)
1183
+ return x
1184
+
1185
+ # h = x.type(self.dtype)
1186
+ h = x
1187
+
1188
+ for depth, module in enumerate(_self.input_blocks):
1189
+ # Deep Shrink
1190
+ if self.ds_depth_1 is not None:
1191
+ if (depth == self.ds_depth_1 and timesteps[0] >= self.ds_timesteps_1) or (
1192
+ self.ds_depth_2 is not None
1193
+ and depth == self.ds_depth_2
1194
+ and timesteps[0] < self.ds_timesteps_1
1195
+ and timesteps[0] >= self.ds_timesteps_2
1196
+ ):
1197
+ # print("downsample", h.shape, self.ds_ratio)
1198
+ org_dtype = h.dtype
1199
+ if org_dtype == torch.bfloat16:
1200
+ h = h.to(torch.float32)
1201
+ h = F.interpolate(h, scale_factor=self.ds_ratio, mode="bicubic", align_corners=False).to(org_dtype)
1202
+
1203
+ h = call_module(module, h, emb, context)
1204
+ hs.append(h)
1205
+
1206
+ h = call_module(_self.middle_block, h, emb, context)
1207
+
1208
+ for module in _self.output_blocks:
1209
+ # Deep Shrink
1210
+ if self.ds_depth_1 is not None:
1211
+ if hs[-1].shape[-2:] != h.shape[-2:]:
1212
+ # print("upsample", h.shape, hs[-1].shape)
1213
+ h = resize_like(h, hs[-1])
1214
+
1215
+ h = torch.cat([h, hs.pop()], dim=1)
1216
+ h = call_module(module, h, emb, context)
1217
+
1218
+ # Deep Shrink: in case of depth 0
1219
+ if self.ds_depth_1 == 0 and h.shape[-2:] != x.shape[-2:]:
1220
+ # print("upsample", h.shape, x.shape)
1221
+ h = resize_like(h, x)
1222
+
1223
+ h = h.type(x.dtype)
1224
+ h = call_module(_self.out, h, emb, context)
1225
+
1226
+ return h
1227
+
1228
+
1229
+ if __name__ == "__main__":
1230
+ import time
1231
+
1232
+ print("create unet")
1233
+ unet = SdxlUNet2DConditionModel()
1234
+
1235
+ unet.to("cuda")
1236
+ unet.set_use_memory_efficient_attention(True, False)
1237
+ unet.set_gradient_checkpointing(True)
1238
+ unet.train()
1239
+
1240
+ # 使用メモリ量確認用の疑似学習ループ
1241
+ print("preparing optimizer")
1242
+
1243
+ # optimizer = torch.optim.SGD(unet.parameters(), lr=1e-3, nesterov=True, momentum=0.9) # not working
1244
+
1245
+ # import bitsandbytes
1246
+ # optimizer = bitsandbytes.adam.Adam8bit(unet.parameters(), lr=1e-3) # not working
1247
+ # optimizer = bitsandbytes.optim.RMSprop8bit(unet.parameters(), lr=1e-3) # working at 23.5 GB with torch2
1248
+ # optimizer=bitsandbytes.optim.Adagrad8bit(unet.parameters(), lr=1e-3) # working at 23.5 GB with torch2
1249
+
1250
+ import transformers
1251
+
1252
+ optimizer = transformers.optimization.Adafactor(unet.parameters(), relative_step=True) # working at 22.2GB with torch2
1253
+
1254
+ scaler = torch.cuda.amp.GradScaler(enabled=True)
1255
+
1256
+ print("start training")
1257
+ steps = 10
1258
+ batch_size = 1
1259
+
1260
+ for step in range(steps):
1261
+ print(f"step {step}")
1262
+ if step == 1:
1263
+ time_start = time.perf_counter()
1264
+
1265
+ x = torch.randn(batch_size, 4, 128, 128).cuda() # 1024x1024
1266
+ t = torch.randint(low=0, high=10, size=(batch_size,), device="cuda")
1267
+ ctx = torch.randn(batch_size, 77, 2048).cuda()
1268
+ y = torch.randn(batch_size, ADM_IN_CHANNELS).cuda()
1269
+
1270
+ with torch.cuda.amp.autocast(enabled=True):
1271
+ output = unet(x, t, ctx, y)
1272
+ target = torch.randn_like(output)
1273
+ loss = torch.nn.functional.mse_loss(output, target)
1274
+
1275
+ scaler.scale(loss).backward()
1276
+ scaler.step(optimizer)
1277
+ scaler.update()
1278
+ optimizer.zero_grad(set_to_none=True)
1279
+
1280
+ time_end = time.perf_counter()
1281
+ print(f"elapsed time: {time_end - time_start} [sec] for last {steps - 1} steps")
external/llite/library/sdxl_train_util.py ADDED
@@ -0,0 +1,367 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import gc
3
+ import math
4
+ import os
5
+ from typing import Optional
6
+ import torch
7
+ from accelerate import init_empty_weights
8
+ from tqdm import tqdm
9
+ from transformers import CLIPTokenizer
10
+ from external.llite.library import model_util, sdxl_model_util, train_util, sdxl_original_unet
11
+ from external.llite.library.sdxl_lpw_stable_diffusion import SdxlStableDiffusionLongPromptWeightingPipeline
12
+
13
+ TOKENIZER1_PATH = "openai/clip-vit-large-patch14"
14
+ TOKENIZER2_PATH = "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k"
15
+
16
+ # DEFAULT_NOISE_OFFSET = 0.0357
17
+
18
+
19
+ def load_target_model(args, accelerator, model_version: str, weight_dtype):
20
+ # load models for each process
21
+ model_dtype = match_mixed_precision(args, weight_dtype) # prepare fp16/bf16
22
+ for pi in range(accelerator.state.num_processes):
23
+ if pi == accelerator.state.local_process_index:
24
+ print(f"loading model for process {accelerator.state.local_process_index}/{accelerator.state.num_processes}")
25
+
26
+ (
27
+ load_stable_diffusion_format,
28
+ text_encoder1,
29
+ text_encoder2,
30
+ vae,
31
+ unet,
32
+ logit_scale,
33
+ ckpt_info,
34
+ ) = _load_target_model(
35
+ args.pretrained_model_name_or_path,
36
+ args.vae,
37
+ model_version,
38
+ weight_dtype,
39
+ accelerator.device if args.lowram else "cpu",
40
+ model_dtype,
41
+ )
42
+
43
+ # work on low-ram device
44
+ if args.lowram:
45
+ text_encoder1.to(accelerator.device)
46
+ text_encoder2.to(accelerator.device)
47
+ unet.to(accelerator.device)
48
+ vae.to(accelerator.device)
49
+
50
+ gc.collect()
51
+ torch.cuda.empty_cache()
52
+ accelerator.wait_for_everyone()
53
+
54
+ return load_stable_diffusion_format, text_encoder1, text_encoder2, vae, unet, logit_scale, ckpt_info
55
+
56
+
57
+ def _load_target_model(
58
+ name_or_path: str, vae_path: Optional[str], model_version: str, weight_dtype, device="cpu", model_dtype=None
59
+ ):
60
+ # model_dtype only work with full fp16/bf16
61
+ name_or_path = os.readlink(name_or_path) if os.path.islink(name_or_path) else name_or_path
62
+ load_stable_diffusion_format = os.path.isfile(name_or_path) # determine SD or Diffusers
63
+
64
+ if load_stable_diffusion_format:
65
+ print(f"load StableDiffusion checkpoint: {name_or_path}")
66
+ (
67
+ text_encoder1,
68
+ text_encoder2,
69
+ vae,
70
+ unet,
71
+ logit_scale,
72
+ ckpt_info,
73
+ ) = sdxl_model_util.load_models_from_sdxl_checkpoint(model_version, name_or_path, device, model_dtype)
74
+ else:
75
+ # Diffusers model is loaded to CPU
76
+ from diffusers import StableDiffusionXLPipeline
77
+
78
+ variant = "fp16" if weight_dtype == torch.float16 else None
79
+ print(f"load Diffusers pretrained models: {name_or_path}, variant={variant}")
80
+ try:
81
+ try:
82
+ pipe = StableDiffusionXLPipeline.from_pretrained(
83
+ name_or_path, torch_dtype=model_dtype, variant=variant, tokenizer=None
84
+ )
85
+ except EnvironmentError as ex:
86
+ if variant is not None:
87
+ print("try to load fp32 model")
88
+ pipe = StableDiffusionXLPipeline.from_pretrained(name_or_path, variant=None, tokenizer=None)
89
+ else:
90
+ raise ex
91
+ except EnvironmentError as ex:
92
+ print(
93
+ f"model is not found as a file or in Hugging Face, perhaps file name is wrong? / 指定したモデル名のファイル、またはHugging Faceのモデルが見つかりません。ファイル名が誤っているかもしれません: {name_or_path}"
94
+ )
95
+ raise ex
96
+
97
+ text_encoder1 = pipe.text_encoder
98
+ text_encoder2 = pipe.text_encoder_2
99
+
100
+ # convert to fp32 for cache text_encoders outputs
101
+ if text_encoder1.dtype != torch.float32:
102
+ text_encoder1 = text_encoder1.to(dtype=torch.float32)
103
+ if text_encoder2.dtype != torch.float32:
104
+ text_encoder2 = text_encoder2.to(dtype=torch.float32)
105
+
106
+ vae = pipe.vae
107
+ unet = pipe.unet
108
+ del pipe
109
+
110
+ # Diffusers U-Net to original U-Net
111
+ state_dict = sdxl_model_util.convert_diffusers_unet_state_dict_to_sdxl(unet.state_dict())
112
+ with init_empty_weights():
113
+ unet = sdxl_original_unet.SdxlUNet2DConditionModel() # overwrite unet
114
+ sdxl_model_util._load_state_dict_on_device(unet, state_dict, device=device, dtype=model_dtype)
115
+ print("U-Net converted to original U-Net")
116
+
117
+ logit_scale = None
118
+ ckpt_info = None
119
+
120
+ # VAEを読み込む
121
+ if vae_path is not None:
122
+ vae = model_util.load_vae(vae_path, weight_dtype)
123
+ print("additional VAE loaded")
124
+
125
+ return load_stable_diffusion_format, text_encoder1, text_encoder2, vae, unet, logit_scale, ckpt_info
126
+
127
+
128
+ def load_tokenizers(args: argparse.Namespace):
129
+ print("prepare tokenizers")
130
+
131
+ original_paths = [TOKENIZER1_PATH, TOKENIZER2_PATH]
132
+ tokeniers = []
133
+ for i, original_path in enumerate(original_paths):
134
+ tokenizer: CLIPTokenizer = None
135
+ if args.tokenizer_cache_dir:
136
+ local_tokenizer_path = os.path.join(args.tokenizer_cache_dir, original_path.replace("/", "_"))
137
+ if os.path.exists(local_tokenizer_path):
138
+ print(f"load tokenizer from cache: {local_tokenizer_path}")
139
+ tokenizer = CLIPTokenizer.from_pretrained(local_tokenizer_path)
140
+
141
+ if tokenizer is None:
142
+ tokenizer = CLIPTokenizer.from_pretrained(original_path)
143
+
144
+ if args.tokenizer_cache_dir and not os.path.exists(local_tokenizer_path):
145
+ print(f"save Tokenizer to cache: {local_tokenizer_path}")
146
+ tokenizer.save_pretrained(local_tokenizer_path)
147
+
148
+ if i == 1:
149
+ tokenizer.pad_token_id = 0 # fix pad token id to make same as open clip tokenizer
150
+
151
+ tokeniers.append(tokenizer)
152
+
153
+ if hasattr(args, "max_token_length") and args.max_token_length is not None:
154
+ print(f"update token length: {args.max_token_length}")
155
+
156
+ return tokeniers
157
+
158
+
159
+ def match_mixed_precision(args, weight_dtype):
160
+ if args.full_fp16:
161
+ assert (
162
+ weight_dtype == torch.float16
163
+ ), "full_fp16 requires mixed precision='fp16' / full_fp16を使う場合はmixed_precision='fp16'を指定してください。"
164
+ return weight_dtype
165
+ elif args.full_bf16:
166
+ assert (
167
+ weight_dtype == torch.bfloat16
168
+ ), "full_bf16 requires mixed precision='bf16' / full_bf16を使う場合はmixed_precision='bf16'を指定してください。"
169
+ return weight_dtype
170
+ else:
171
+ return None
172
+
173
+
174
+ def timestep_embedding(timesteps, dim, max_period=10000):
175
+ """
176
+ Create sinusoidal timestep embeddings.
177
+ :param timesteps: a 1-D Tensor of N indices, one per batch element.
178
+ These may be fractional.
179
+ :param dim: the dimension of the output.
180
+ :param max_period: controls the minimum frequency of the embeddings.
181
+ :return: an [N x dim] Tensor of positional embeddings.
182
+ """
183
+ half = dim // 2
184
+ freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(
185
+ device=timesteps.device
186
+ )
187
+ args = timesteps[:, None].float() * freqs[None]
188
+ embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
189
+ if dim % 2:
190
+ embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
191
+ return embedding
192
+
193
+
194
+ def get_timestep_embedding(x, outdim):
195
+ assert len(x.shape) == 2
196
+ b, dims = x.shape[0], x.shape[1]
197
+ x = torch.flatten(x)
198
+ emb = timestep_embedding(x, outdim)
199
+ emb = torch.reshape(emb, (b, dims * outdim))
200
+ return emb
201
+
202
+
203
+ def get_size_embeddings(orig_size, crop_size, target_size, device):
204
+ emb1 = get_timestep_embedding(orig_size, 256)
205
+ emb2 = get_timestep_embedding(crop_size, 256)
206
+ emb3 = get_timestep_embedding(target_size, 256)
207
+ vector = torch.cat([emb1, emb2, emb3], dim=1).to(device)
208
+ return vector
209
+
210
+
211
+ def save_sd_model_on_train_end(
212
+ args: argparse.Namespace,
213
+ src_path: str,
214
+ save_stable_diffusion_format: bool,
215
+ use_safetensors: bool,
216
+ save_dtype: torch.dtype,
217
+ epoch: int,
218
+ global_step: int,
219
+ text_encoder1,
220
+ text_encoder2,
221
+ unet,
222
+ vae,
223
+ logit_scale,
224
+ ckpt_info,
225
+ ):
226
+ def sd_saver(ckpt_file, epoch_no, global_step):
227
+ sai_metadata = train_util.get_sai_model_spec(None, args, True, False, False, is_stable_diffusion_ckpt=True)
228
+ sdxl_model_util.save_stable_diffusion_checkpoint(
229
+ ckpt_file,
230
+ text_encoder1,
231
+ text_encoder2,
232
+ unet,
233
+ epoch_no,
234
+ global_step,
235
+ ckpt_info,
236
+ vae,
237
+ logit_scale,
238
+ sai_metadata,
239
+ save_dtype,
240
+ )
241
+
242
+ def diffusers_saver(out_dir):
243
+ sdxl_model_util.save_diffusers_checkpoint(
244
+ out_dir,
245
+ text_encoder1,
246
+ text_encoder2,
247
+ unet,
248
+ src_path,
249
+ vae,
250
+ use_safetensors=use_safetensors,
251
+ save_dtype=save_dtype,
252
+ )
253
+
254
+ train_util.save_sd_model_on_train_end_common(
255
+ args, save_stable_diffusion_format, use_safetensors, epoch, global_step, sd_saver, diffusers_saver
256
+ )
257
+
258
+
259
+ # epochとstepの保存、メタデータにepoch/stepが含まれ引数が同じになるため、統合している
260
+ # on_epoch_end: Trueならepoch終了時、Falseならstep経過時
261
+ def save_sd_model_on_epoch_end_or_stepwise(
262
+ args: argparse.Namespace,
263
+ on_epoch_end: bool,
264
+ accelerator,
265
+ src_path,
266
+ save_stable_diffusion_format: bool,
267
+ use_safetensors: bool,
268
+ save_dtype: torch.dtype,
269
+ epoch: int,
270
+ num_train_epochs: int,
271
+ global_step: int,
272
+ text_encoder1,
273
+ text_encoder2,
274
+ unet,
275
+ vae,
276
+ logit_scale,
277
+ ckpt_info,
278
+ ):
279
+ def sd_saver(ckpt_file, epoch_no, global_step):
280
+ sai_metadata = train_util.get_sai_model_spec(None, args, True, False, False, is_stable_diffusion_ckpt=True)
281
+ sdxl_model_util.save_stable_diffusion_checkpoint(
282
+ ckpt_file,
283
+ text_encoder1,
284
+ text_encoder2,
285
+ unet,
286
+ epoch_no,
287
+ global_step,
288
+ ckpt_info,
289
+ vae,
290
+ logit_scale,
291
+ sai_metadata,
292
+ save_dtype,
293
+ )
294
+
295
+ def diffusers_saver(out_dir):
296
+ sdxl_model_util.save_diffusers_checkpoint(
297
+ out_dir,
298
+ text_encoder1,
299
+ text_encoder2,
300
+ unet,
301
+ src_path,
302
+ vae,
303
+ use_safetensors=use_safetensors,
304
+ save_dtype=save_dtype,
305
+ )
306
+
307
+ train_util.save_sd_model_on_epoch_end_or_stepwise_common(
308
+ args,
309
+ on_epoch_end,
310
+ accelerator,
311
+ save_stable_diffusion_format,
312
+ use_safetensors,
313
+ epoch,
314
+ num_train_epochs,
315
+ global_step,
316
+ sd_saver,
317
+ diffusers_saver,
318
+ )
319
+
320
+
321
+ def add_sdxl_training_arguments(parser: argparse.ArgumentParser):
322
+ parser.add_argument(
323
+ "--cache_text_encoder_outputs", action="store_true", help="cache text encoder outputs / text encoderの出力をキャッシュする"
324
+ )
325
+ parser.add_argument(
326
+ "--cache_text_encoder_outputs_to_disk",
327
+ action="store_true",
328
+ help="cache text encoder outputs to disk / text encoderの出力をディスクにキャッシュする",
329
+ )
330
+
331
+
332
+ def verify_sdxl_training_args(args: argparse.Namespace, supportTextEncoderCaching: bool = True):
333
+ assert not args.v2, "v2 cannot be enabled in SDXL training / SDXL学習ではv2を有効にすることはできません"
334
+ if args.v_parameterization:
335
+ print("v_parameterization will be unexpected / SDXL学習ではv_parameterizationは想定外の動作になります")
336
+
337
+ if args.clip_skip is not None:
338
+ print("clip_skip will be unexpected / SDXL学習ではclip_skipは動作しません")
339
+
340
+ # if args.multires_noise_iterations:
341
+ # print(
342
+ # f"Warning: SDXL has been trained with noise_offset={DEFAULT_NOISE_OFFSET}, but noise_offset is disabled due to multires_noise_iterations / SDXLはnoise_offset={DEFAULT_NOISE_OFFSET}で学習されていますが、multires_noise_iterationsが有効になっているためnoise_offsetは無効になります"
343
+ # )
344
+ # else:
345
+ # if args.noise_offset is None:
346
+ # args.noise_offset = DEFAULT_NOISE_OFFSET
347
+ # elif args.noise_offset != DEFAULT_NOISE_OFFSET:
348
+ # print(
349
+ # f"Warning: SDXL has been trained with noise_offset={DEFAULT_NOISE_OFFSET} / SDXLはnoise_offset={DEFAULT_NOISE_OFFSET}で学習されています"
350
+ # )
351
+ # print(f"noise_offset is set to {args.noise_offset} / noise_offsetが{args.noise_offset}に設定されました")
352
+
353
+ assert (
354
+ not hasattr(args, "weighted_captions") or not args.weighted_captions
355
+ ), "weighted_captions cannot be enabled in SDXL training currently / SDXL学習では今のところweighted_captionsを有効にすることはできません"
356
+
357
+ if supportTextEncoderCaching:
358
+ if args.cache_text_encoder_outputs_to_disk and not args.cache_text_encoder_outputs:
359
+ args.cache_text_encoder_outputs = True
360
+ print(
361
+ "cache_text_encoder_outputs is enabled because cache_text_encoder_outputs_to_disk is enabled / "
362
+ + "cache_text_encoder_outputs_to_diskが有効になっているためcache_text_encoder_outputsが有効になりました"
363
+ )
364
+
365
+
366
+ def sample_images(*args, **kwargs):
367
+ return train_util.sample_images_common(SdxlStableDiffusionLongPromptWeightingPipeline, *args, **kwargs)
external/llite/library/slicing_vae.py ADDED
@@ -0,0 +1,679 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Modified from Diffusers to reduce VRAM usage
2
+
3
+ # Copyright 2022 The HuggingFace Team. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ from dataclasses import dataclass
17
+ from typing import Optional, Tuple, Union
18
+
19
+ import numpy as np
20
+ import torch
21
+ import torch.nn as nn
22
+
23
+
24
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
25
+ from diffusers.models.modeling_utils import ModelMixin
26
+ from diffusers.models.unet_2d_blocks import UNetMidBlock2D, get_down_block, get_up_block
27
+ from diffusers.models.vae import DecoderOutput, DiagonalGaussianDistribution
28
+ from diffusers.models.autoencoder_kl import AutoencoderKLOutput
29
+
30
+
31
+ def slice_h(x, num_slices):
32
+ # slice with pad 1 both sides: to eliminate side effect of padding of conv2d
33
+ # Conv2dのpaddingの副作用を排除するために、両側にpad 1しながらHをスライスする
34
+ # NCHWでもNHWCでもどちらでも動く
35
+ size = (x.shape[2] + num_slices - 1) // num_slices
36
+ sliced = []
37
+ for i in range(num_slices):
38
+ if i == 0:
39
+ sliced.append(x[:, :, : size + 1, :])
40
+ else:
41
+ end = size * (i + 1) + 1
42
+ if x.shape[2] - end < 3: # if the last slice is too small, use the rest of the tensor 最後が細すぎるとconv2dできないので全部使う
43
+ end = x.shape[2]
44
+ sliced.append(x[:, :, size * i - 1 : end, :])
45
+ if end >= x.shape[2]:
46
+ break
47
+ return sliced
48
+
49
+
50
+ def cat_h(sliced):
51
+ # padding分を除いて結合する
52
+ cat = []
53
+ for i, x in enumerate(sliced):
54
+ if i == 0:
55
+ cat.append(x[:, :, :-1, :])
56
+ elif i == len(sliced) - 1:
57
+ cat.append(x[:, :, 1:, :])
58
+ else:
59
+ cat.append(x[:, :, 1:-1, :])
60
+ del x
61
+ x = torch.cat(cat, dim=2)
62
+ return x
63
+
64
+
65
+ def resblock_forward(_self, num_slices, input_tensor, temb, **kwargs):
66
+ assert _self.upsample is None and _self.downsample is None
67
+ assert _self.norm1.num_groups == _self.norm2.num_groups
68
+ assert temb is None
69
+
70
+ # make sure norms are on cpu
71
+ org_device = input_tensor.device
72
+ cpu_device = torch.device("cpu")
73
+ _self.norm1.to(cpu_device)
74
+ _self.norm2.to(cpu_device)
75
+
76
+ # GroupNormがCPUでfp16で動かない対策
77
+ org_dtype = input_tensor.dtype
78
+ if org_dtype == torch.float16:
79
+ _self.norm1.to(torch.float32)
80
+ _self.norm2.to(torch.float32)
81
+
82
+ # すべてのテンソルをCPUに移動する
83
+ input_tensor = input_tensor.to(cpu_device)
84
+ hidden_states = input_tensor
85
+
86
+ # どうもこれは結果が異なるようだ……
87
+ # def sliced_norm1(norm, x):
88
+ # num_div = 4 if up_block_idx <= 2 else x.shape[1] // norm.num_groups
89
+ # sliced_tensor = torch.chunk(x, num_div, dim=1)
90
+ # sliced_weight = torch.chunk(norm.weight, num_div, dim=0)
91
+ # sliced_bias = torch.chunk(norm.bias, num_div, dim=0)
92
+ # print(sliced_tensor[0].shape, num_div, sliced_weight[0].shape, sliced_bias[0].shape)
93
+ # normed_tensor = []
94
+ # for i in range(num_div):
95
+ # n = torch.group_norm(sliced_tensor[i], norm.num_groups, sliced_weight[i], sliced_bias[i], norm.eps)
96
+ # normed_tensor.append(n)
97
+ # del n
98
+ # x = torch.cat(normed_tensor, dim=1)
99
+ # return num_div, x
100
+
101
+ # normを分割すると結果が変わるので、ここだけは分割しない。GPUで計算するとVRAMが足りなくなるので、CPUで計算する。幸いCPUでもそこまで遅くない
102
+ if org_dtype == torch.float16:
103
+ hidden_states = hidden_states.to(torch.float32)
104
+ hidden_states = _self.norm1(hidden_states) # run on cpu
105
+ if org_dtype == torch.float16:
106
+ hidden_states = hidden_states.to(torch.float16)
107
+
108
+ sliced = slice_h(hidden_states, num_slices)
109
+ del hidden_states
110
+
111
+ for i in range(len(sliced)):
112
+ x = sliced[i]
113
+ sliced[i] = None
114
+
115
+ # 計算する部分だけGPUに移動する、以下同様
116
+ x = x.to(org_device)
117
+ x = _self.nonlinearity(x)
118
+ x = _self.conv1(x)
119
+ x = x.to(cpu_device)
120
+ sliced[i] = x
121
+ del x
122
+
123
+ hidden_states = cat_h(sliced)
124
+ del sliced
125
+
126
+ if org_dtype == torch.float16:
127
+ hidden_states = hidden_states.to(torch.float32)
128
+ hidden_states = _self.norm2(hidden_states) # run on cpu
129
+ if org_dtype == torch.float16:
130
+ hidden_states = hidden_states.to(torch.float16)
131
+
132
+ sliced = slice_h(hidden_states, num_slices)
133
+ del hidden_states
134
+
135
+ for i in range(len(sliced)):
136
+ x = sliced[i]
137
+ sliced[i] = None
138
+
139
+ x = x.to(org_device)
140
+ x = _self.nonlinearity(x)
141
+ x = _self.dropout(x)
142
+ x = _self.conv2(x)
143
+ x = x.to(cpu_device)
144
+ sliced[i] = x
145
+ del x
146
+
147
+ hidden_states = cat_h(sliced)
148
+ del sliced
149
+
150
+ # make shortcut
151
+ if _self.conv_shortcut is not None:
152
+ sliced = list(torch.chunk(input_tensor, num_slices, dim=2)) # no padding in conv_shortcut パディングがないので普通にスライスする
153
+ del input_tensor
154
+
155
+ for i in range(len(sliced)):
156
+ x = sliced[i]
157
+ sliced[i] = None
158
+
159
+ x = x.to(org_device)
160
+ x = _self.conv_shortcut(x)
161
+ x = x.to(cpu_device)
162
+ sliced[i] = x
163
+ del x
164
+
165
+ input_tensor = torch.cat(sliced, dim=2)
166
+ del sliced
167
+
168
+ output_tensor = (input_tensor + hidden_states) / _self.output_scale_factor
169
+
170
+ output_tensor = output_tensor.to(org_device) # 次のレイヤーがGPUで計算する
171
+ return output_tensor
172
+
173
+
174
+ class SlicingEncoder(nn.Module):
175
+ def __init__(
176
+ self,
177
+ in_channels=3,
178
+ out_channels=3,
179
+ down_block_types=("DownEncoderBlock2D",),
180
+ block_out_channels=(64,),
181
+ layers_per_block=2,
182
+ norm_num_groups=32,
183
+ act_fn="silu",
184
+ double_z=True,
185
+ num_slices=2,
186
+ ):
187
+ super().__init__()
188
+ self.layers_per_block = layers_per_block
189
+
190
+ self.conv_in = torch.nn.Conv2d(in_channels, block_out_channels[0], kernel_size=3, stride=1, padding=1)
191
+
192
+ self.mid_block = None
193
+ self.down_blocks = nn.ModuleList([])
194
+
195
+ # down
196
+ output_channel = block_out_channels[0]
197
+ for i, down_block_type in enumerate(down_block_types):
198
+ input_channel = output_channel
199
+ output_channel = block_out_channels[i]
200
+ is_final_block = i == len(block_out_channels) - 1
201
+
202
+ down_block = get_down_block(
203
+ down_block_type,
204
+ num_layers=self.layers_per_block,
205
+ in_channels=input_channel,
206
+ out_channels=output_channel,
207
+ add_downsample=not is_final_block,
208
+ resnet_eps=1e-6,
209
+ downsample_padding=0,
210
+ resnet_act_fn=act_fn,
211
+ resnet_groups=norm_num_groups,
212
+ attention_head_dim=output_channel,
213
+ temb_channels=None,
214
+ )
215
+ self.down_blocks.append(down_block)
216
+
217
+ # mid
218
+ self.mid_block = UNetMidBlock2D(
219
+ in_channels=block_out_channels[-1],
220
+ resnet_eps=1e-6,
221
+ resnet_act_fn=act_fn,
222
+ output_scale_factor=1,
223
+ resnet_time_scale_shift="default",
224
+ attention_head_dim=block_out_channels[-1],
225
+ resnet_groups=norm_num_groups,
226
+ temb_channels=None,
227
+ )
228
+ self.mid_block.attentions[0].set_use_memory_efficient_attention_xformers(True) # とりあえずDiffusersのxformersを使う
229
+
230
+ # out
231
+ self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[-1], num_groups=norm_num_groups, eps=1e-6)
232
+ self.conv_act = nn.SiLU()
233
+
234
+ conv_out_channels = 2 * out_channels if double_z else out_channels
235
+ self.conv_out = nn.Conv2d(block_out_channels[-1], conv_out_channels, 3, padding=1)
236
+
237
+ # replace forward of ResBlocks
238
+ def wrapper(func, module, num_slices):
239
+ def forward(*args, **kwargs):
240
+ return func(module, num_slices, *args, **kwargs)
241
+
242
+ return forward
243
+
244
+ self.num_slices = num_slices
245
+ div = num_slices / (2 ** (len(self.down_blocks) - 1)) # 深い層はそこまで分割しなくていいので適宜減らす
246
+ # print(f"initial divisor: {div}")
247
+ if div >= 2:
248
+ div = int(div)
249
+ for resnet in self.mid_block.resnets:
250
+ resnet.forward = wrapper(resblock_forward, resnet, div)
251
+ # midblock doesn't have downsample
252
+
253
+ for i, down_block in enumerate(self.down_blocks[::-1]):
254
+ if div >= 2:
255
+ div = int(div)
256
+ # print(f"down block: {i} divisor: {div}")
257
+ for resnet in down_block.resnets:
258
+ resnet.forward = wrapper(resblock_forward, resnet, div)
259
+ if down_block.downsamplers is not None:
260
+ # print("has downsample")
261
+ for downsample in down_block.downsamplers:
262
+ downsample.forward = wrapper(self.downsample_forward, downsample, div * 2)
263
+ div *= 2
264
+
265
+ def forward(self, x):
266
+ sample = x
267
+ del x
268
+
269
+ org_device = sample.device
270
+ cpu_device = torch.device("cpu")
271
+
272
+ # sample = self.conv_in(sample)
273
+ sample = sample.to(cpu_device)
274
+ sliced = slice_h(sample, self.num_slices)
275
+ del sample
276
+
277
+ for i in range(len(sliced)):
278
+ x = sliced[i]
279
+ sliced[i] = None
280
+
281
+ x = x.to(org_device)
282
+ x = self.conv_in(x)
283
+ x = x.to(cpu_device)
284
+ sliced[i] = x
285
+ del x
286
+
287
+ sample = cat_h(sliced)
288
+ del sliced
289
+
290
+ sample = sample.to(org_device)
291
+
292
+ # down
293
+ for down_block in self.down_blocks:
294
+ sample = down_block(sample)
295
+
296
+ # middle
297
+ sample = self.mid_block(sample)
298
+
299
+ # post-process
300
+ # ここも省メモリ化したいが、恐らくそこまでメモリを食わないので省略
301
+ sample = self.conv_norm_out(sample)
302
+ sample = self.conv_act(sample)
303
+ sample = self.conv_out(sample)
304
+
305
+ return sample
306
+
307
+ def downsample_forward(self, _self, num_slices, hidden_states):
308
+ assert hidden_states.shape[1] == _self.channels
309
+ assert _self.use_conv and _self.padding == 0
310
+ print("downsample forward", num_slices, hidden_states.shape)
311
+
312
+ org_device = hidden_states.device
313
+ cpu_device = torch.device("cpu")
314
+
315
+ hidden_states = hidden_states.to(cpu_device)
316
+ pad = (0, 1, 0, 1)
317
+ hidden_states = torch.nn.functional.pad(hidden_states, pad, mode="constant", value=0)
318
+
319
+ # slice with even number because of stride 2
320
+ # strideが2なので偶数でスライスする
321
+ # slice with pad 1 both sides: to eliminate side effect of padding of conv2d
322
+ size = (hidden_states.shape[2] + num_slices - 1) // num_slices
323
+ size = size + 1 if size % 2 == 1 else size
324
+
325
+ sliced = []
326
+ for i in range(num_slices):
327
+ if i == 0:
328
+ sliced.append(hidden_states[:, :, : size + 1, :])
329
+ else:
330
+ end = size * (i + 1) + 1
331
+ if hidden_states.shape[2] - end < 4: # if the last slice is too small, use the rest of the tensor
332
+ end = hidden_states.shape[2]
333
+ sliced.append(hidden_states[:, :, size * i - 1 : end, :])
334
+ if end >= hidden_states.shape[2]:
335
+ break
336
+ del hidden_states
337
+
338
+ for i in range(len(sliced)):
339
+ x = sliced[i]
340
+ sliced[i] = None
341
+
342
+ x = x.to(org_device)
343
+ x = _self.conv(x)
344
+ x = x.to(cpu_device)
345
+
346
+ # ここだけ雰囲気が違うのはCopilotのせい
347
+ if i == 0:
348
+ hidden_states = x
349
+ else:
350
+ hidden_states = torch.cat([hidden_states, x], dim=2)
351
+
352
+ hidden_states = hidden_states.to(org_device)
353
+ # print("downsample forward done", hidden_states.shape)
354
+ return hidden_states
355
+
356
+
357
+ class SlicingDecoder(nn.Module):
358
+ def __init__(
359
+ self,
360
+ in_channels=3,
361
+ out_channels=3,
362
+ up_block_types=("UpDecoderBlock2D",),
363
+ block_out_channels=(64,),
364
+ layers_per_block=2,
365
+ norm_num_groups=32,
366
+ act_fn="silu",
367
+ num_slices=2,
368
+ ):
369
+ super().__init__()
370
+ self.layers_per_block = layers_per_block
371
+
372
+ self.conv_in = nn.Conv2d(in_channels, block_out_channels[-1], kernel_size=3, stride=1, padding=1)
373
+
374
+ self.mid_block = None
375
+ self.up_blocks = nn.ModuleList([])
376
+
377
+ # mid
378
+ self.mid_block = UNetMidBlock2D(
379
+ in_channels=block_out_channels[-1],
380
+ resnet_eps=1e-6,
381
+ resnet_act_fn=act_fn,
382
+ output_scale_factor=1,
383
+ resnet_time_scale_shift="default",
384
+ attention_head_dim=block_out_channels[-1],
385
+ resnet_groups=norm_num_groups,
386
+ temb_channels=None,
387
+ )
388
+ self.mid_block.attentions[0].set_use_memory_efficient_attention_xformers(True) # とりあえずDiffusersのxformersを使う
389
+
390
+ # up
391
+ reversed_block_out_channels = list(reversed(block_out_channels))
392
+ output_channel = reversed_block_out_channels[0]
393
+ for i, up_block_type in enumerate(up_block_types):
394
+ prev_output_channel = output_channel
395
+ output_channel = reversed_block_out_channels[i]
396
+
397
+ is_final_block = i == len(block_out_channels) - 1
398
+
399
+ up_block = get_up_block(
400
+ up_block_type,
401
+ num_layers=self.layers_per_block + 1,
402
+ in_channels=prev_output_channel,
403
+ out_channels=output_channel,
404
+ prev_output_channel=None,
405
+ add_upsample=not is_final_block,
406
+ resnet_eps=1e-6,
407
+ resnet_act_fn=act_fn,
408
+ resnet_groups=norm_num_groups,
409
+ attention_head_dim=output_channel,
410
+ temb_channels=None,
411
+ )
412
+ self.up_blocks.append(up_block)
413
+ prev_output_channel = output_channel
414
+
415
+ # out
416
+ self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=1e-6)
417
+ self.conv_act = nn.SiLU()
418
+ self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, 3, padding=1)
419
+
420
+ # replace forward of ResBlocks
421
+ def wrapper(func, module, num_slices):
422
+ def forward(*args, **kwargs):
423
+ return func(module, num_slices, *args, **kwargs)
424
+
425
+ return forward
426
+
427
+ self.num_slices = num_slices
428
+ div = num_slices / (2 ** (len(self.up_blocks) - 1))
429
+ print(f"initial divisor: {div}")
430
+ if div >= 2:
431
+ div = int(div)
432
+ for resnet in self.mid_block.resnets:
433
+ resnet.forward = wrapper(resblock_forward, resnet, div)
434
+ # midblock doesn't have upsample
435
+
436
+ for i, up_block in enumerate(self.up_blocks):
437
+ if div >= 2:
438
+ div = int(div)
439
+ # print(f"up block: {i} divisor: {div}")
440
+ for resnet in up_block.resnets:
441
+ resnet.forward = wrapper(resblock_forward, resnet, div)
442
+ if up_block.upsamplers is not None:
443
+ # print("has upsample")
444
+ for upsample in up_block.upsamplers:
445
+ upsample.forward = wrapper(self.upsample_forward, upsample, div * 2)
446
+ div *= 2
447
+
448
+ def forward(self, z):
449
+ sample = z
450
+ del z
451
+ sample = self.conv_in(sample)
452
+
453
+ # middle
454
+ sample = self.mid_block(sample)
455
+
456
+ # up
457
+ for i, up_block in enumerate(self.up_blocks):
458
+ sample = up_block(sample)
459
+
460
+ # post-process
461
+ sample = self.conv_norm_out(sample)
462
+ sample = self.conv_act(sample)
463
+
464
+ # conv_out with slicing because of VRAM usage
465
+ # conv_outはとてもVRAM使うのでスライスして対応
466
+ org_device = sample.device
467
+ cpu_device = torch.device("cpu")
468
+ sample = sample.to(cpu_device)
469
+
470
+ sliced = slice_h(sample, self.num_slices)
471
+ del sample
472
+ for i in range(len(sliced)):
473
+ x = sliced[i]
474
+ sliced[i] = None
475
+
476
+ x = x.to(org_device)
477
+ x = self.conv_out(x)
478
+ x = x.to(cpu_device)
479
+ sliced[i] = x
480
+ sample = cat_h(sliced)
481
+ del sliced
482
+
483
+ sample = sample.to(org_device)
484
+ return sample
485
+
486
+ def upsample_forward(self, _self, num_slices, hidden_states, output_size=None):
487
+ assert hidden_states.shape[1] == _self.channels
488
+ assert _self.use_conv_transpose == False and _self.use_conv
489
+
490
+ org_dtype = hidden_states.dtype
491
+ org_device = hidden_states.device
492
+ cpu_device = torch.device("cpu")
493
+
494
+ hidden_states = hidden_states.to(cpu_device)
495
+ sliced = slice_h(hidden_states, num_slices)
496
+ del hidden_states
497
+
498
+ for i in range(len(sliced)):
499
+ x = sliced[i]
500
+ sliced[i] = None
501
+
502
+ x = x.to(org_device)
503
+
504
+ # Cast to float32 to as 'upsample_nearest2d_out_frame' op does not support bfloat16
505
+ # TODO(Suraj): Remove this cast once the issue is fixed in PyTorch
506
+ # https://github.com/pytorch/pytorch/issues/86679
507
+ # PyTorch 2で直らないかね……
508
+ if org_dtype == torch.bfloat16:
509
+ x = x.to(torch.float32)
510
+
511
+ x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
512
+
513
+ if org_dtype == torch.bfloat16:
514
+ x = x.to(org_dtype)
515
+
516
+ x = _self.conv(x)
517
+
518
+ # upsampleされてるのでpadは2になる
519
+ if i == 0:
520
+ x = x[:, :, :-2, :]
521
+ elif i == num_slices - 1:
522
+ x = x[:, :, 2:, :]
523
+ else:
524
+ x = x[:, :, 2:-2, :]
525
+
526
+ x = x.to(cpu_device)
527
+ sliced[i] = x
528
+ del x
529
+
530
+ hidden_states = torch.cat(sliced, dim=2)
531
+ # print("us hidden_states", hidden_states.shape)
532
+ del sliced
533
+
534
+ hidden_states = hidden_states.to(org_device)
535
+ return hidden_states
536
+
537
+
538
+ class SlicingAutoencoderKL(ModelMixin, ConfigMixin):
539
+ r"""Variational Autoencoder (VAE) model with KL loss from the paper Auto-Encoding Variational Bayes by Diederik P. Kingma
540
+ and Max Welling.
541
+
542
+ This model inherits from [`ModelMixin`]. Check the superclass documentation for the generic methods the library
543
+ implements for all the model (such as downloading or saving, etc.)
544
+
545
+ Parameters:
546
+ in_channels (int, *optional*, defaults to 3): Number of channels in the input image.
547
+ out_channels (int, *optional*, defaults to 3): Number of channels in the output.
548
+ down_block_types (`Tuple[str]`, *optional*, defaults to :
549
+ obj:`("DownEncoderBlock2D",)`): Tuple of downsample block types.
550
+ up_block_types (`Tuple[str]`, *optional*, defaults to :
551
+ obj:`("UpDecoderBlock2D",)`): Tuple of upsample block types.
552
+ block_out_channels (`Tuple[int]`, *optional*, defaults to :
553
+ obj:`(64,)`): Tuple of block output channels.
554
+ act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
555
+ latent_channels (`int`, *optional*, defaults to `4`): Number of channels in the latent space.
556
+ sample_size (`int`, *optional*, defaults to `32`): TODO
557
+ """
558
+
559
+ @register_to_config
560
+ def __init__(
561
+ self,
562
+ in_channels: int = 3,
563
+ out_channels: int = 3,
564
+ down_block_types: Tuple[str] = ("DownEncoderBlock2D",),
565
+ up_block_types: Tuple[str] = ("UpDecoderBlock2D",),
566
+ block_out_channels: Tuple[int] = (64,),
567
+ layers_per_block: int = 1,
568
+ act_fn: str = "silu",
569
+ latent_channels: int = 4,
570
+ norm_num_groups: int = 32,
571
+ sample_size: int = 32,
572
+ num_slices: int = 16,
573
+ ):
574
+ super().__init__()
575
+
576
+ # pass init params to Encoder
577
+ self.encoder = SlicingEncoder(
578
+ in_channels=in_channels,
579
+ out_channels=latent_channels,
580
+ down_block_types=down_block_types,
581
+ block_out_channels=block_out_channels,
582
+ layers_per_block=layers_per_block,
583
+ act_fn=act_fn,
584
+ norm_num_groups=norm_num_groups,
585
+ double_z=True,
586
+ num_slices=num_slices,
587
+ )
588
+
589
+ # pass init params to Decoder
590
+ self.decoder = SlicingDecoder(
591
+ in_channels=latent_channels,
592
+ out_channels=out_channels,
593
+ up_block_types=up_block_types,
594
+ block_out_channels=block_out_channels,
595
+ layers_per_block=layers_per_block,
596
+ norm_num_groups=norm_num_groups,
597
+ act_fn=act_fn,
598
+ num_slices=num_slices,
599
+ )
600
+
601
+ self.quant_conv = torch.nn.Conv2d(2 * latent_channels, 2 * latent_channels, 1)
602
+ self.post_quant_conv = torch.nn.Conv2d(latent_channels, latent_channels, 1)
603
+ self.use_slicing = False
604
+
605
+ def encode(self, x: torch.FloatTensor, return_dict: bool = True) -> AutoencoderKLOutput:
606
+ h = self.encoder(x)
607
+ moments = self.quant_conv(h)
608
+ posterior = DiagonalGaussianDistribution(moments)
609
+
610
+ if not return_dict:
611
+ return (posterior,)
612
+
613
+ return AutoencoderKLOutput(latent_dist=posterior)
614
+
615
+ def _decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]:
616
+ z = self.post_quant_conv(z)
617
+ dec = self.decoder(z)
618
+
619
+ if not return_dict:
620
+ return (dec,)
621
+
622
+ return DecoderOutput(sample=dec)
623
+
624
+ # これはバッチ方向のスライシング 紛らわしい
625
+ def enable_slicing(self):
626
+ r"""
627
+ Enable sliced VAE decoding.
628
+
629
+ When this option is enabled, the VAE will split the input tensor in slices to compute decoding in several
630
+ steps. This is useful to save some memory and allow larger batch sizes.
631
+ """
632
+ self.use_slicing = True
633
+
634
+ def disable_slicing(self):
635
+ r"""
636
+ Disable sliced VAE decoding. If `enable_slicing` was previously invoked, this method will go back to computing
637
+ decoding in one step.
638
+ """
639
+ self.use_slicing = False
640
+
641
+ def decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]:
642
+ if self.use_slicing and z.shape[0] > 1:
643
+ decoded_slices = [self._decode(z_slice).sample for z_slice in z.split(1)]
644
+ decoded = torch.cat(decoded_slices)
645
+ else:
646
+ decoded = self._decode(z).sample
647
+
648
+ if not return_dict:
649
+ return (decoded,)
650
+
651
+ return DecoderOutput(sample=decoded)
652
+
653
+ def forward(
654
+ self,
655
+ sample: torch.FloatTensor,
656
+ sample_posterior: bool = False,
657
+ return_dict: bool = True,
658
+ generator: Optional[torch.Generator] = None,
659
+ ) -> Union[DecoderOutput, torch.FloatTensor]:
660
+ r"""
661
+ Args:
662
+ sample (`torch.FloatTensor`): Input sample.
663
+ sample_posterior (`bool`, *optional*, defaults to `False`):
664
+ Whether to sample from the posterior.
665
+ return_dict (`bool`, *optional*, defaults to `True`):
666
+ Whether or not to return a [`DecoderOutput`] instead of a plain tuple.
667
+ """
668
+ x = sample
669
+ posterior = self.encode(x).latent_dist
670
+ if sample_posterior:
671
+ z = posterior.sample(generator=generator)
672
+ else:
673
+ z = posterior.mode()
674
+ dec = self.decode(z).sample
675
+
676
+ if not return_dict:
677
+ return (dec,)
678
+
679
+ return DecoderOutput(sample=dec)
external/llite/library/train_util.py ADDED
The diff for this file is too large to render. See raw diff
 
external/llite/library/utils.py ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ import threading
2
+ from typing import *
3
+
4
+
5
+ def fire_in_thread(f, *args, **kwargs):
6
+ threading.Thread(target=f, args=args, kwargs=kwargs).start()
external/llite/networks/.ipynb_checkpoints/control_net_lllite-checkpoint.py ADDED
@@ -0,0 +1,446 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from typing import Optional, List, Type
3
+ import torch
4
+ from external.llite.library import sdxl_original_unet
5
+
6
+
7
+ # input_blocksに適用するかどうか / if True, input_blocks are not applied
8
+ SKIP_INPUT_BLOCKS = False
9
+
10
+ # output_blocksに適用するかどうか / if True, output_blocks are not applied
11
+ SKIP_OUTPUT_BLOCKS = True
12
+
13
+ # conv2dに適用するかどうか / if True, conv2d are not applied
14
+ SKIP_CONV2D = False
15
+
16
+ # transformer_blocksのみに適用するかどうか。Trueの場合、ResBlockには適用されない
17
+ # if True, only transformer_blocks are applied, and ResBlocks are not applied
18
+ TRANSFORMER_ONLY = True # if True, SKIP_CONV2D is ignored because conv2d is not used in transformer_blocks
19
+
20
+ # Trueならattn1とattn2にのみ適用し、ffなどには適用しない / if True, apply only to attn1 and attn2, not to ff etc.
21
+ ATTN1_2_ONLY = True
22
+
23
+ # Trueならattn1のQKV、attn2のQにのみ適用する、ATTN1_2_ONLY指定時のみ有効 / if True, apply only to attn1 QKV and attn2 Q, only valid when ATTN1_2_ONLY is specified
24
+ ATTN_QKV_ONLY = True
25
+
26
+ # Trueならattn1やffなどにのみ適用し、attn2などには適用しない / if True, apply only to attn1 and ff, not to attn2
27
+ # ATTN1_2_ONLYと同時にTrueにできない / cannot be True at the same time as ATTN1_2_ONLY
28
+ ATTN1_ETC_ONLY = False # True
29
+
30
+ # transformer_blocksの最大インデックス。Noneなら全てのtransformer_blocksに適用
31
+ # max index of transformer_blocks. if None, apply to all transformer_blocks
32
+ TRANSFORMER_MAX_BLOCK_INDEX = None
33
+
34
+
35
+ class LLLiteModule(torch.nn.Module):
36
+ def __init__(self, depth, cond_emb_dim, name, org_module, mlp_dim, dropout=None, multiplier=1.0):
37
+ super().__init__()
38
+
39
+ self.is_conv2d = org_module.__class__.__name__ == "Conv2d"
40
+ self.lllite_name = name
41
+ self.cond_emb_dim = cond_emb_dim
42
+ self.org_module = [org_module]
43
+ self.dropout = dropout
44
+ self.multiplier = multiplier
45
+
46
+ if self.is_conv2d:
47
+ in_dim = org_module.in_channels
48
+ else:
49
+ in_dim = org_module.in_features
50
+
51
+ # conditioning1はconditioning imageを embedding する。timestepごとに呼ばれない
52
+ # conditioning1 embeds conditioning image. it is not called for each timestep
53
+ modules = []
54
+ modules.append(torch.nn.Conv2d(3, cond_emb_dim // 2, kernel_size=4, stride=4, padding=0)) # to latent (from VAE) size
55
+ if depth == 1:
56
+ modules.append(torch.nn.ReLU(inplace=True))
57
+ modules.append(torch.nn.Conv2d(cond_emb_dim // 2, cond_emb_dim, kernel_size=2, stride=2, padding=0))
58
+ elif depth == 2:
59
+ modules.append(torch.nn.ReLU(inplace=True))
60
+ modules.append(torch.nn.Conv2d(cond_emb_dim // 2, cond_emb_dim, kernel_size=4, stride=4, padding=0))
61
+ elif depth == 3:
62
+ # kernel size 8は大きすぎるので、4にする / kernel size 8 is too large, so set it to 4
63
+ modules.append(torch.nn.ReLU(inplace=True))
64
+ modules.append(torch.nn.Conv2d(cond_emb_dim // 2, cond_emb_dim // 2, kernel_size=4, stride=4, padding=0))
65
+ modules.append(torch.nn.ReLU(inplace=True))
66
+ modules.append(torch.nn.Conv2d(cond_emb_dim // 2, cond_emb_dim, kernel_size=2, stride=2, padding=0))
67
+
68
+ self.conditioning1 = torch.nn.Sequential(*modules)
69
+
70
+ # downで入力の次元数を削減する。LoRAにヒントを得ていることにする
71
+ # midでconditioning image embeddingと入力を結合する
72
+ # upで元の次元数に戻す
73
+ # これらはtimestepごとに呼ばれる
74
+ # reduce the number of input dimensions with down. inspired by LoRA
75
+ # combine conditioning image embedding and input with mid
76
+ # restore to the original dimension with up
77
+ # these are called for each timestep
78
+
79
+ if self.is_conv2d:
80
+ self.down = torch.nn.Sequential(
81
+ torch.nn.Conv2d(in_dim, mlp_dim, kernel_size=1, stride=1, padding=0),
82
+ torch.nn.ReLU(inplace=True),
83
+ )
84
+ self.mid = torch.nn.Sequential(
85
+ torch.nn.Conv2d(mlp_dim + cond_emb_dim, mlp_dim, kernel_size=1, stride=1, padding=0),
86
+ torch.nn.ReLU(inplace=True),
87
+ )
88
+ self.up = torch.nn.Sequential(
89
+ torch.nn.Conv2d(mlp_dim, in_dim, kernel_size=1, stride=1, padding=0),
90
+ )
91
+ else:
92
+ # midの前にconditioningをreshapeすること / reshape conditioning before mid
93
+ self.down = torch.nn.Sequential(
94
+ torch.nn.Linear(in_dim, mlp_dim),
95
+ torch.nn.ReLU(inplace=True),
96
+ )
97
+ self.mid = torch.nn.Sequential(
98
+ torch.nn.Linear(mlp_dim + cond_emb_dim, mlp_dim),
99
+ torch.nn.ReLU(inplace=True),
100
+ )
101
+ self.up = torch.nn.Sequential(
102
+ torch.nn.Linear(mlp_dim, in_dim),
103
+ )
104
+
105
+ # Zero-Convにする / set to Zero-Conv
106
+ torch.nn.init.zeros_(self.up[0].weight) # zero conv
107
+
108
+ self.depth = depth # 1~3
109
+ self.cond_emb = None
110
+ self.batch_cond_only = False # Trueなら推論時のcondにのみ適用する / if True, apply only to cond at inference
111
+ self.use_zeros_for_batch_uncond = False # Trueならuncondのconditioningを0にする / if True, set uncond conditioning to 0
112
+
113
+ # batch_cond_onlyとuse_zeros_for_batch_uncondはどちらも適用すると生成画像の色味がおかしくなるので実際には使えそうにない
114
+ # Controlの種類によっては使えるかも
115
+ # both batch_cond_only and use_zeros_for_batch_uncond make the color of the generated image strange, so it doesn't seem to be usable in practice
116
+ # it may be available depending on the type of Control
117
+
118
+ def set_cond_image(self, cond_image):
119
+ r"""
120
+ 中でモデルを呼び出すので必要ならwith torch.no_grad()で囲む
121
+ / call the model inside, so if necessary, surround it with torch.no_grad()
122
+ """
123
+ if cond_image is None:
124
+ self.cond_emb = None
125
+ return
126
+
127
+ # timestepごとに呼ばれないので、あらかじめ計算しておく / it is not called for each timestep, so calculate it in advance
128
+ # print(f"C {self.lllite_name}, cond_image.shape={cond_image.shape}")
129
+ cx = self.conditioning1(cond_image)
130
+ if not self.is_conv2d:
131
+ # reshape / b,c,h,w -> b,h*w,c
132
+ n, c, h, w = cx.shape
133
+ cx = cx.view(n, c, h * w).permute(0, 2, 1)
134
+ self.cond_emb = cx
135
+
136
+ def set_batch_cond_only(self, cond_only, zeros):
137
+ self.batch_cond_only = cond_only
138
+ self.use_zeros_for_batch_uncond = zeros
139
+
140
+ def apply_to(self):
141
+ self.org_forward = self.org_module[0].forward
142
+ self.org_module[0].forward = self.forward
143
+
144
+ def forward(self, x):
145
+ r"""
146
+ 学習用の便利forward。元のモジュールのforwardを呼び出す
147
+ / convenient forward for training. call the forward of the original module
148
+ """
149
+ if self.multiplier == 0.0 or self.cond_emb is None:
150
+ return self.org_forward(x)
151
+
152
+ cx = self.cond_emb
153
+
154
+ if not self.batch_cond_only and x.shape[0] // 2 == cx.shape[0]: # inference only
155
+ cx = cx.repeat(2, 1, 1, 1) if self.is_conv2d else cx.repeat(2, 1, 1)
156
+ if self.use_zeros_for_batch_uncond:
157
+ cx[0::2] = 0.0 # uncond is zero
158
+ # print(f"C {self.lllite_name}, x.shape={x.shape}, cx.shape={cx.shape}")
159
+
160
+ # downで入力の次元数を削減し、conditioning image embeddingと結合する
161
+ # 加算ではなくchannel方向に結合することで、うまいこと混ぜてくれることを期待している
162
+ # down reduces the number of input dimensions and combines it with conditioning image embedding
163
+ # we expect that it will mix well by combining in the channel direction instead of adding
164
+
165
+ cx = torch.cat([cx, self.down(x if not self.batch_cond_only else x[1::2])], dim=1 if self.is_conv2d else 2)
166
+ cx = self.mid(cx)
167
+
168
+ if self.dropout is not None and self.training:
169
+ cx = torch.nn.functional.dropout(cx, p=self.dropout)
170
+
171
+ cx = self.up(cx) * self.multiplier
172
+
173
+ # residual (x) を加算して元のforwardを呼び出す / add residual (x) and call the original forward
174
+ if self.batch_cond_only:
175
+ zx = torch.zeros_like(x)
176
+ zx[1::2] += cx
177
+ cx = zx
178
+
179
+ x = self.org_forward(x + cx) # ここで元のモジュールを呼び出す / call the original module here
180
+ return x
181
+
182
+
183
+ class ControlNetLLLite(torch.nn.Module):
184
+ UNET_TARGET_REPLACE_MODULE = ["Transformer2DModel"]
185
+ UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 = ["ResnetBlock2D", "Downsample2D", "Upsample2D"]
186
+
187
+ def __init__(
188
+ self,
189
+ unet: sdxl_original_unet.SdxlUNet2DConditionModel,
190
+ cond_emb_dim: int = 16,
191
+ mlp_dim: int = 16,
192
+ dropout: Optional[float] = None,
193
+ varbose: Optional[bool] = False,
194
+ multiplier: Optional[float] = 1.0,
195
+ ) -> None:
196
+ super().__init__()
197
+ # self.unets = [unet]
198
+
199
+ def create_modules(
200
+ root_module: torch.nn.Module,
201
+ target_replace_modules: List[torch.nn.Module],
202
+ module_class: Type[object],
203
+ ) -> List[torch.nn.Module]:
204
+ prefix = "lllite_unet"
205
+
206
+ modules = []
207
+ for name, module in root_module.named_modules():
208
+ if module.__class__.__name__ in target_replace_modules:
209
+ for child_name, child_module in module.named_modules():
210
+ is_linear = child_module.__class__.__name__ == "Linear"
211
+ is_conv2d = child_module.__class__.__name__ == "Conv2d"
212
+ if is_linear or (is_conv2d and not SKIP_CONV2D):
213
+ # block indexからdepthを計算: depthはconditioningのサイズやチャネルを計算するのに使う
214
+ # block index to depth: depth is using to calculate conditioning size and channels
215
+ block_name, index1, index2 = (name + "." + child_name).split(".")[:3]
216
+ index1 = int(index1)
217
+ if block_name == "input_blocks":
218
+ if SKIP_INPUT_BLOCKS:
219
+ continue
220
+ depth = 1 if index1 <= 2 else (2 if index1 <= 5 else 3)
221
+ elif block_name == "middle_block":
222
+ depth = 3
223
+ elif block_name == "output_blocks":
224
+ if SKIP_OUTPUT_BLOCKS:
225
+ continue
226
+ depth = 3 if index1 <= 2 else (2 if index1 <= 5 else 1)
227
+ if int(index2) >= 2:
228
+ depth -= 1
229
+ else:
230
+ raise NotImplementedError()
231
+
232
+ lllite_name = prefix + "." + name + "." + child_name
233
+ lllite_name = lllite_name.replace(".", "_")
234
+
235
+ if TRANSFORMER_MAX_BLOCK_INDEX is not None:
236
+ p = lllite_name.find("transformer_blocks")
237
+ if p >= 0:
238
+ tf_index = int(lllite_name[p:].split("_")[2])
239
+ if tf_index > TRANSFORMER_MAX_BLOCK_INDEX:
240
+ continue
241
+
242
+ # time embは適用外とする
243
+ # attn2のconditioning (CLIPからの入力) はshapeが違うので適用できない
244
+ # time emb is not applied
245
+ # attn2 conditioning (input from CLIP) cannot be applied because the shape is different
246
+ if "emb_layers" in lllite_name or (
247
+ "attn2" in lllite_name and ("to_k" in lllite_name or "to_v" in lllite_name)
248
+ ):
249
+ continue
250
+
251
+ if ATTN1_2_ONLY:
252
+ if not ("attn1" in lllite_name or "attn2" in lllite_name):
253
+ continue
254
+ if ATTN_QKV_ONLY:
255
+ if "to_out" in lllite_name:
256
+ continue
257
+
258
+ if ATTN1_ETC_ONLY:
259
+ if "proj_out" in lllite_name:
260
+ pass
261
+ elif "attn1" in lllite_name and (
262
+ "to_k" in lllite_name or "to_v" in lllite_name or "to_out" in lllite_name
263
+ ):
264
+ pass
265
+ elif "ff_net_2" in lllite_name:
266
+ pass
267
+ else:
268
+ continue
269
+
270
+ module = module_class(
271
+ depth,
272
+ cond_emb_dim,
273
+ lllite_name,
274
+ child_module,
275
+ mlp_dim,
276
+ dropout=dropout,
277
+ multiplier=multiplier,
278
+ )
279
+ modules.append(module)
280
+ print(f"Returning {len(modules)} modules for llite net")
281
+ return modules
282
+
283
+ target_modules = ControlNetLLLite.UNET_TARGET_REPLACE_MODULE
284
+ if not TRANSFORMER_ONLY:
285
+ target_modules = target_modules + ControlNetLLLite.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3
286
+
287
+ # create module instances
288
+ self.unet_modules: List[LLLiteModule] = create_modules(unet, target_modules, LLLiteModule)
289
+ print(f"created ControlNet LLLite for U-Net: {len(self.unet_modules)} modules.")
290
+
291
+ def forward(self, x):
292
+ return x # dummy
293
+
294
+ def set_cond_image(self, cond_image):
295
+ r"""
296
+ 中でモデルを呼び出すので必要ならwith torch.no_grad()で囲む
297
+ / call the model inside, so if necessary, surround it with torch.no_grad()
298
+ """
299
+ for module in self.unet_modules:
300
+ module.set_cond_image(cond_image)
301
+
302
+ def set_batch_cond_only(self, cond_only, zeros):
303
+ for module in self.unet_modules:
304
+ module.set_batch_cond_only(cond_only, zeros)
305
+
306
+ def set_multiplier(self, multiplier):
307
+ for module in self.unet_modules:
308
+ module.multiplier = multiplier
309
+
310
+ def load_weights(self, file):
311
+ if os.path.splitext(file)[1] == ".safetensors":
312
+ from safetensors.torch import load_file
313
+
314
+ weights_sd = load_file(file)
315
+ else:
316
+ weights_sd = torch.load(file, map_location="cpu")
317
+
318
+ info = self.load_state_dict(weights_sd, False)
319
+ return info
320
+
321
+ def apply_to(self):
322
+ print("applying LLLite for U-Net...")
323
+ for module in self.unet_modules:
324
+ module.apply_to()
325
+ self.add_module(module.lllite_name, module)
326
+
327
+ # マージできるかどうかを返す
328
+ def is_mergeable(self):
329
+ return False
330
+
331
+ def merge_to(self, text_encoder, unet, weights_sd, dtype, device):
332
+ raise NotImplementedError()
333
+
334
+ def enable_gradient_checkpointing(self):
335
+ # not supported
336
+ pass
337
+
338
+ def prepare_optimizer_params(self):
339
+ self.requires_grad_(True)
340
+ return self.parameters()
341
+
342
+ def prepare_grad_etc(self):
343
+ self.requires_grad_(True)
344
+
345
+ def on_epoch_start(self):
346
+ self.train()
347
+
348
+ def get_trainable_params(self):
349
+ return self.parameters()
350
+
351
+ def save_weights(self, file, dtype, metadata):
352
+ if metadata is not None and len(metadata) == 0:
353
+ metadata = None
354
+
355
+ state_dict = self.state_dict()
356
+
357
+ if dtype is not None:
358
+ for key in list(state_dict.keys()):
359
+ v = state_dict[key]
360
+ v = v.detach().clone().to("cpu").to(dtype)
361
+ state_dict[key] = v
362
+
363
+ if os.path.splitext(file)[1] == ".safetensors":
364
+ from safetensors.torch import save_file
365
+
366
+ save_file(state_dict, file, metadata)
367
+ else:
368
+ torch.save(state_dict, file)
369
+
370
+
371
+ if __name__ == "__main__":
372
+ # デバッグ用 / for debug
373
+
374
+ # sdxl_original_unet.USE_REENTRANT = False
375
+
376
+ # test shape etc
377
+ print("create unet")
378
+ unet = sdxl_original_unet.SdxlUNet2DConditionModel()
379
+ unet.to("cuda").to(torch.float16)
380
+
381
+ print("create ControlNet-LLLite")
382
+ control_net = ControlNetLLLite(unet, 32, 64)
383
+ control_net.apply_to()
384
+ control_net.to("cuda")
385
+
386
+ print(control_net)
387
+
388
+ # print number of parameters
389
+ print("number of parameters", sum(p.numel() for p in control_net.parameters() if p.requires_grad))
390
+
391
+ input()
392
+
393
+ unet.set_use_memory_efficient_attention(True, False)
394
+ unet.set_gradient_checkpointing(True)
395
+ unet.train() # for gradient checkpointing
396
+
397
+ control_net.train()
398
+
399
+ # # visualize
400
+ # import torchviz
401
+ # print("run visualize")
402
+ # controlnet.set_control(conditioning_image)
403
+ # output = unet(x, t, ctx, y)
404
+ # print("make_dot")
405
+ # image = torchviz.make_dot(output, params=dict(controlnet.named_parameters()))
406
+ # print("render")
407
+ # image.format = "svg" # "png"
408
+ # image.render("NeuralNet") # すごく時間がかかるので注意 / be careful because it takes a long time
409
+ # input()
410
+
411
+ import bitsandbytes
412
+
413
+ optimizer = bitsandbytes.adam.Adam8bit(control_net.prepare_optimizer_params(), 1e-3)
414
+
415
+ scaler = torch.cuda.amp.GradScaler(enabled=True)
416
+
417
+ print("start training")
418
+ steps = 10
419
+
420
+ sample_param = [p for p in control_net.named_parameters() if "up" in p[0]][0]
421
+ for step in range(steps):
422
+ print(f"step {step}")
423
+
424
+ batch_size = 1
425
+ conditioning_image = torch.rand(batch_size, 3, 1024, 1024).cuda() * 2.0 - 1.0
426
+ x = torch.randn(batch_size, 4, 128, 128).cuda()
427
+ t = torch.randint(low=0, high=10, size=(batch_size,)).cuda()
428
+ ctx = torch.randn(batch_size, 77, 2048).cuda()
429
+ y = torch.randn(batch_size, sdxl_original_unet.ADM_IN_CHANNELS).cuda()
430
+
431
+ with torch.cuda.amp.autocast(enabled=True):
432
+ control_net.set_cond_image(conditioning_image)
433
+
434
+ output = unet(x, t, ctx, y)
435
+ target = torch.randn_like(output)
436
+ loss = torch.nn.functional.mse_loss(output, target)
437
+
438
+ scaler.scale(loss).backward()
439
+ scaler.step(optimizer)
440
+ scaler.update()
441
+ optimizer.zero_grad(set_to_none=True)
442
+ print(sample_param)
443
+
444
+ # from safetensors.torch import save_file
445
+
446
+ # save_file(control_net.state_dict(), "logs/control_net.safetensors")
external/llite/networks/check_lora_weights.py ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import os
3
+ import torch
4
+ from safetensors.torch import load_file
5
+
6
+
7
+ def main(file):
8
+ print(f"loading: {file}")
9
+ if os.path.splitext(file)[1] == ".safetensors":
10
+ sd = load_file(file)
11
+ else:
12
+ sd = torch.load(file, map_location="cpu")
13
+
14
+ values = []
15
+
16
+ keys = list(sd.keys())
17
+ for key in keys:
18
+ if "lora_up" in key or "lora_down" in key:
19
+ values.append((key, sd[key]))
20
+ print(f"number of LoRA modules: {len(values)}")
21
+
22
+ if args.show_all_keys:
23
+ for key in [k for k in keys if k not in values]:
24
+ values.append((key, sd[key]))
25
+ print(f"number of all modules: {len(values)}")
26
+
27
+ for key, value in values:
28
+ value = value.to(torch.float32)
29
+ print(f"{key},{str(tuple(value.size())).replace(', ', '-')},{torch.mean(torch.abs(value))},{torch.min(torch.abs(value))}")
30
+
31
+
32
+ def setup_parser() -> argparse.ArgumentParser:
33
+ parser = argparse.ArgumentParser()
34
+ parser.add_argument("file", type=str, help="model file to check / 重みを確認するモデルファイル")
35
+ parser.add_argument("-s", "--show_all_keys", action="store_true", help="show all keys / 全てのキーを表示する")
36
+
37
+ return parser
38
+
39
+
40
+ if __name__ == "__main__":
41
+ parser = setup_parser()
42
+
43
+ args = parser.parse_args()
44
+
45
+ main(args.file)
external/llite/networks/control_net_lllite.py ADDED
@@ -0,0 +1,446 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from typing import Optional, List, Type
3
+ import torch
4
+ from external.llite.library import sdxl_original_unet
5
+
6
+
7
+ # input_blocksに適用するかどうか / if True, input_blocks are not applied
8
+ SKIP_INPUT_BLOCKS = False
9
+
10
+ # output_blocksに適用するかどうか / if True, output_blocks are not applied
11
+ SKIP_OUTPUT_BLOCKS = True
12
+
13
+ # conv2dに適用するかどうか / if True, conv2d are not applied
14
+ SKIP_CONV2D = False
15
+
16
+ # transformer_blocksのみに適用するかどうか。Trueの場合、ResBlockには適用されない
17
+ # if True, only transformer_blocks are applied, and ResBlocks are not applied
18
+ TRANSFORMER_ONLY = True # if True, SKIP_CONV2D is ignored because conv2d is not used in transformer_blocks
19
+
20
+ # Trueならattn1とattn2にのみ適用し、ffなどには適用しない / if True, apply only to attn1 and attn2, not to ff etc.
21
+ ATTN1_2_ONLY = True
22
+
23
+ # Trueならattn1のQKV、attn2のQにのみ適用する、ATTN1_2_ONLY指定時のみ有効 / if True, apply only to attn1 QKV and attn2 Q, only valid when ATTN1_2_ONLY is specified
24
+ ATTN_QKV_ONLY = True
25
+
26
+ # Trueならattn1やffなどにのみ適用し、attn2などには適用しない / if True, apply only to attn1 and ff, not to attn2
27
+ # ATTN1_2_ONLYと同時にTrueにできない / cannot be True at the same time as ATTN1_2_ONLY
28
+ ATTN1_ETC_ONLY = False # True
29
+
30
+ # transformer_blocksの最大インデックス。Noneなら全てのtransformer_blocksに適用
31
+ # max index of transformer_blocks. if None, apply to all transformer_blocks
32
+ TRANSFORMER_MAX_BLOCK_INDEX = None
33
+
34
+
35
+ class LLLiteModule(torch.nn.Module):
36
+ def __init__(self, depth, cond_emb_dim, name, org_module, mlp_dim, dropout=None, multiplier=1.0):
37
+ super().__init__()
38
+
39
+ self.is_conv2d = org_module.__class__.__name__ == "Conv2d"
40
+ self.lllite_name = name
41
+ self.cond_emb_dim = cond_emb_dim
42
+ self.org_module = [org_module]
43
+ self.dropout = dropout
44
+ self.multiplier = multiplier
45
+
46
+ if self.is_conv2d:
47
+ in_dim = org_module.in_channels
48
+ else:
49
+ in_dim = org_module.in_features
50
+
51
+ # conditioning1はconditioning imageを embedding する。timestepごとに呼ばれない
52
+ # conditioning1 embeds conditioning image. it is not called for each timestep
53
+ modules = []
54
+ modules.append(torch.nn.Conv2d(3, cond_emb_dim // 2, kernel_size=4, stride=4, padding=0)) # to latent (from VAE) size
55
+ if depth == 1:
56
+ modules.append(torch.nn.ReLU(inplace=True))
57
+ modules.append(torch.nn.Conv2d(cond_emb_dim // 2, cond_emb_dim, kernel_size=2, stride=2, padding=0))
58
+ elif depth == 2:
59
+ modules.append(torch.nn.ReLU(inplace=True))
60
+ modules.append(torch.nn.Conv2d(cond_emb_dim // 2, cond_emb_dim, kernel_size=4, stride=4, padding=0))
61
+ elif depth == 3:
62
+ # kernel size 8は大きすぎるので、4にする / kernel size 8 is too large, so set it to 4
63
+ modules.append(torch.nn.ReLU(inplace=True))
64
+ modules.append(torch.nn.Conv2d(cond_emb_dim // 2, cond_emb_dim // 2, kernel_size=4, stride=4, padding=0))
65
+ modules.append(torch.nn.ReLU(inplace=True))
66
+ modules.append(torch.nn.Conv2d(cond_emb_dim // 2, cond_emb_dim, kernel_size=2, stride=2, padding=0))
67
+
68
+ self.conditioning1 = torch.nn.Sequential(*modules)
69
+
70
+ # downで入力の次元数を削減する。LoRAにヒントを得ていることにする
71
+ # midでconditioning image embeddingと入力を結合する
72
+ # upで元の次元数に戻す
73
+ # これらはtimestepごとに呼ばれる
74
+ # reduce the number of input dimensions with down. inspired by LoRA
75
+ # combine conditioning image embedding and input with mid
76
+ # restore to the original dimension with up
77
+ # these are called for each timestep
78
+
79
+ if self.is_conv2d:
80
+ self.down = torch.nn.Sequential(
81
+ torch.nn.Conv2d(in_dim, mlp_dim, kernel_size=1, stride=1, padding=0),
82
+ torch.nn.ReLU(inplace=True),
83
+ )
84
+ self.mid = torch.nn.Sequential(
85
+ torch.nn.Conv2d(mlp_dim + cond_emb_dim, mlp_dim, kernel_size=1, stride=1, padding=0),
86
+ torch.nn.ReLU(inplace=True),
87
+ )
88
+ self.up = torch.nn.Sequential(
89
+ torch.nn.Conv2d(mlp_dim, in_dim, kernel_size=1, stride=1, padding=0),
90
+ )
91
+ else:
92
+ # midの前にconditioningをreshapeすること / reshape conditioning before mid
93
+ self.down = torch.nn.Sequential(
94
+ torch.nn.Linear(in_dim, mlp_dim),
95
+ torch.nn.ReLU(inplace=True),
96
+ )
97
+ self.mid = torch.nn.Sequential(
98
+ torch.nn.Linear(mlp_dim + cond_emb_dim, mlp_dim),
99
+ torch.nn.ReLU(inplace=True),
100
+ )
101
+ self.up = torch.nn.Sequential(
102
+ torch.nn.Linear(mlp_dim, in_dim),
103
+ )
104
+
105
+ # Zero-Convにする / set to Zero-Conv
106
+ torch.nn.init.zeros_(self.up[0].weight) # zero conv
107
+
108
+ self.depth = depth # 1~3
109
+ self.cond_emb = None
110
+ self.batch_cond_only = False # Trueなら推論時のcondにのみ適用する / if True, apply only to cond at inference
111
+ self.use_zeros_for_batch_uncond = False # Trueならuncondのconditioningを0にする / if True, set uncond conditioning to 0
112
+
113
+ # batch_cond_onlyとuse_zeros_for_batch_uncondはどちらも適用すると生成画像の色味がおかしくなるので実際には使えそうにない
114
+ # Controlの種類によっては使えるかも
115
+ # both batch_cond_only and use_zeros_for_batch_uncond make the color of the generated image strange, so it doesn't seem to be usable in practice
116
+ # it may be available depending on the type of Control
117
+
118
+ def set_cond_image(self, cond_image):
119
+ r"""
120
+ 中でモデルを呼び出すので必要ならwith torch.no_grad()で囲む
121
+ / call the model inside, so if necessary, surround it with torch.no_grad()
122
+ """
123
+ if cond_image is None:
124
+ self.cond_emb = None
125
+ return
126
+
127
+ # timestepごとに呼ばれないので、あらかじめ計算しておく / it is not called for each timestep, so calculate it in advance
128
+ # print(f"C {self.lllite_name}, cond_image.shape={cond_image.shape}")
129
+ cx = self.conditioning1(cond_image)
130
+ if not self.is_conv2d:
131
+ # reshape / b,c,h,w -> b,h*w,c
132
+ n, c, h, w = cx.shape
133
+ cx = cx.view(n, c, h * w).permute(0, 2, 1)
134
+ self.cond_emb = cx
135
+
136
+ def set_batch_cond_only(self, cond_only, zeros):
137
+ self.batch_cond_only = cond_only
138
+ self.use_zeros_for_batch_uncond = zeros
139
+
140
+ def apply_to(self):
141
+ self.org_forward = self.org_module[0].forward
142
+ self.org_module[0].forward = self.forward
143
+
144
+ def forward(self, x):
145
+ r"""
146
+ 学習用の便利forward。元のモジュールのforwardを呼び出す
147
+ / convenient forward for training. call the forward of the original module
148
+ """
149
+ if self.multiplier == 0.0 or self.cond_emb is None:
150
+ return self.org_forward(x)
151
+
152
+ cx = self.cond_emb
153
+
154
+ if not self.batch_cond_only and x.shape[0] // 2 == cx.shape[0]: # inference only
155
+ cx = cx.repeat(2, 1, 1, 1) if self.is_conv2d else cx.repeat(2, 1, 1)
156
+ if self.use_zeros_for_batch_uncond:
157
+ cx[0::2] = 0.0 # uncond is zero
158
+ # print(f"C {self.lllite_name}, x.shape={x.shape}, cx.shape={cx.shape}")
159
+
160
+ # downで入力の次元数を削減し、conditioning image embeddingと結合する
161
+ # 加算ではなくchannel方向に結合することで、うまいこと混ぜてくれることを期待している
162
+ # down reduces the number of input dimensions and combines it with conditioning image embedding
163
+ # we expect that it will mix well by combining in the channel direction instead of adding
164
+
165
+ cx = torch.cat([cx, self.down(x if not self.batch_cond_only else x[1::2])], dim=1 if self.is_conv2d else 2)
166
+ cx = self.mid(cx)
167
+
168
+ if self.dropout is not None and self.training:
169
+ cx = torch.nn.functional.dropout(cx, p=self.dropout)
170
+
171
+ cx = self.up(cx) * self.multiplier
172
+
173
+ # residual (x) を加算して元のforwardを呼び出す / add residual (x) and call the original forward
174
+ if self.batch_cond_only:
175
+ zx = torch.zeros_like(x)
176
+ zx[1::2] += cx
177
+ cx = zx
178
+
179
+ x = self.org_forward(x + cx) # ここで元のモジュールを呼び出す / call the original module here
180
+ return x
181
+
182
+
183
+ class ControlNetLLLite(torch.nn.Module):
184
+ UNET_TARGET_REPLACE_MODULE = ["Transformer2DModel"]
185
+ UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 = ["ResnetBlock2D", "Downsample2D", "Upsample2D"]
186
+
187
+ def __init__(
188
+ self,
189
+ unet: sdxl_original_unet.SdxlUNet2DConditionModel,
190
+ cond_emb_dim: int = 16,
191
+ mlp_dim: int = 16,
192
+ dropout: Optional[float] = None,
193
+ varbose: Optional[bool] = False,
194
+ multiplier: Optional[float] = 1.0,
195
+ ) -> None:
196
+ super().__init__()
197
+ # self.unets = [unet]
198
+
199
+ def create_modules(
200
+ root_module: torch.nn.Module,
201
+ target_replace_modules: List[torch.nn.Module],
202
+ module_class: Type[object],
203
+ ) -> List[torch.nn.Module]:
204
+ prefix = "lllite_unet"
205
+
206
+ modules = []
207
+ for name, module in root_module.named_modules():
208
+ if module.__class__.__name__ in target_replace_modules:
209
+ for child_name, child_module in module.named_modules():
210
+ is_linear = child_module.__class__.__name__ == "Linear"
211
+ is_conv2d = child_module.__class__.__name__ == "Conv2d"
212
+ if is_linear or (is_conv2d and not SKIP_CONV2D):
213
+ # block indexからdepthを計算: depthはconditioningのサイズやチャネルを計算するのに使う
214
+ # block index to depth: depth is using to calculate conditioning size and channels
215
+ block_name, index1, index2 = (name + "." + child_name).split(".")[:3]
216
+ index1 = int(index1)
217
+ if block_name == "input_blocks":
218
+ if SKIP_INPUT_BLOCKS:
219
+ continue
220
+ depth = 1 if index1 <= 2 else (2 if index1 <= 5 else 3)
221
+ elif block_name == "middle_block":
222
+ depth = 3
223
+ elif block_name == "output_blocks":
224
+ if SKIP_OUTPUT_BLOCKS:
225
+ continue
226
+ depth = 3 if index1 <= 2 else (2 if index1 <= 5 else 1)
227
+ if int(index2) >= 2:
228
+ depth -= 1
229
+ else:
230
+ raise NotImplementedError()
231
+
232
+ lllite_name = prefix + "." + name + "." + child_name
233
+ lllite_name = lllite_name.replace(".", "_")
234
+
235
+ if TRANSFORMER_MAX_BLOCK_INDEX is not None:
236
+ p = lllite_name.find("transformer_blocks")
237
+ if p >= 0:
238
+ tf_index = int(lllite_name[p:].split("_")[2])
239
+ if tf_index > TRANSFORMER_MAX_BLOCK_INDEX:
240
+ continue
241
+
242
+ # time embは適用外とする
243
+ # attn2のconditioning (CLIPからの入力) はshapeが違うので適用できない
244
+ # time emb is not applied
245
+ # attn2 conditioning (input from CLIP) cannot be applied because the shape is different
246
+ if "emb_layers" in lllite_name or (
247
+ "attn2" in lllite_name and ("to_k" in lllite_name or "to_v" in lllite_name)
248
+ ):
249
+ continue
250
+
251
+ if ATTN1_2_ONLY:
252
+ if not ("attn1" in lllite_name or "attn2" in lllite_name):
253
+ continue
254
+ if ATTN_QKV_ONLY:
255
+ if "to_out" in lllite_name:
256
+ continue
257
+
258
+ if ATTN1_ETC_ONLY:
259
+ if "proj_out" in lllite_name:
260
+ pass
261
+ elif "attn1" in lllite_name and (
262
+ "to_k" in lllite_name or "to_v" in lllite_name or "to_out" in lllite_name
263
+ ):
264
+ pass
265
+ elif "ff_net_2" in lllite_name:
266
+ pass
267
+ else:
268
+ continue
269
+
270
+ module = module_class(
271
+ depth,
272
+ cond_emb_dim,
273
+ lllite_name,
274
+ child_module,
275
+ mlp_dim,
276
+ dropout=dropout,
277
+ multiplier=multiplier,
278
+ )
279
+ modules.append(module)
280
+ print(f"Returning {len(modules)} modules for llite net")
281
+ return modules
282
+
283
+ target_modules = ControlNetLLLite.UNET_TARGET_REPLACE_MODULE
284
+ if not TRANSFORMER_ONLY:
285
+ target_modules = target_modules + ControlNetLLLite.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3
286
+
287
+ # create module instances
288
+ self.unet_modules: List[LLLiteModule] = create_modules(unet, target_modules, LLLiteModule)
289
+ print(f"created ControlNet LLLite for U-Net: {len(self.unet_modules)} modules.")
290
+
291
+ def forward(self, x):
292
+ return x # dummy
293
+
294
+ def set_cond_image(self, cond_image):
295
+ r"""
296
+ 中でモデルを呼び出すので必要ならwith torch.no_grad()で囲む
297
+ / call the model inside, so if necessary, surround it with torch.no_grad()
298
+ """
299
+ for module in self.unet_modules:
300
+ module.set_cond_image(cond_image)
301
+
302
+ def set_batch_cond_only(self, cond_only, zeros):
303
+ for module in self.unet_modules:
304
+ module.set_batch_cond_only(cond_only, zeros)
305
+
306
+ def set_multiplier(self, multiplier):
307
+ for module in self.unet_modules:
308
+ module.multiplier = multiplier
309
+
310
+ def load_weights(self, file):
311
+ if os.path.splitext(file)[1] == ".safetensors":
312
+ from safetensors.torch import load_file
313
+
314
+ weights_sd = load_file(file)
315
+ else:
316
+ weights_sd = torch.load(file, map_location="cpu")
317
+
318
+ info = self.load_state_dict(weights_sd, False)
319
+ return info
320
+
321
+ def apply_to(self):
322
+ print("applying LLLite for U-Net...")
323
+ for module in self.unet_modules:
324
+ module.apply_to()
325
+ self.add_module(module.lllite_name, module)
326
+
327
+ # マージできるかどうかを返す
328
+ def is_mergeable(self):
329
+ return False
330
+
331
+ def merge_to(self, text_encoder, unet, weights_sd, dtype, device):
332
+ raise NotImplementedError()
333
+
334
+ def enable_gradient_checkpointing(self):
335
+ # not supported
336
+ pass
337
+
338
+ def prepare_optimizer_params(self):
339
+ self.requires_grad_(True)
340
+ return self.parameters()
341
+
342
+ def prepare_grad_etc(self):
343
+ self.requires_grad_(True)
344
+
345
+ def on_epoch_start(self):
346
+ self.train()
347
+
348
+ def get_trainable_params(self):
349
+ return self.parameters()
350
+
351
+ def save_weights(self, file, dtype, metadata):
352
+ if metadata is not None and len(metadata) == 0:
353
+ metadata = None
354
+
355
+ state_dict = self.state_dict()
356
+
357
+ if dtype is not None:
358
+ for key in list(state_dict.keys()):
359
+ v = state_dict[key]
360
+ v = v.detach().clone().to("cpu").to(dtype)
361
+ state_dict[key] = v
362
+
363
+ if os.path.splitext(file)[1] == ".safetensors":
364
+ from safetensors.torch import save_file
365
+
366
+ save_file(state_dict, file, metadata)
367
+ else:
368
+ torch.save(state_dict, file)
369
+
370
+
371
+ if __name__ == "__main__":
372
+ # デバッグ用 / for debug
373
+
374
+ # sdxl_original_unet.USE_REENTRANT = False
375
+
376
+ # test shape etc
377
+ print("create unet")
378
+ unet = sdxl_original_unet.SdxlUNet2DConditionModel()
379
+ unet.to("cuda").to(torch.float16)
380
+
381
+ print("create ControlNet-LLLite")
382
+ control_net = ControlNetLLLite(unet, 32, 64)
383
+ control_net.apply_to()
384
+ control_net.to("cuda")
385
+
386
+ print(control_net)
387
+
388
+ # print number of parameters
389
+ print("number of parameters", sum(p.numel() for p in control_net.parameters() if p.requires_grad))
390
+
391
+ input()
392
+
393
+ unet.set_use_memory_efficient_attention(True, False)
394
+ unet.set_gradient_checkpointing(True)
395
+ unet.train() # for gradient checkpointing
396
+
397
+ control_net.train()
398
+
399
+ # # visualize
400
+ # import torchviz
401
+ # print("run visualize")
402
+ # controlnet.set_control(conditioning_image)
403
+ # output = unet(x, t, ctx, y)
404
+ # print("make_dot")
405
+ # image = torchviz.make_dot(output, params=dict(controlnet.named_parameters()))
406
+ # print("render")
407
+ # image.format = "svg" # "png"
408
+ # image.render("NeuralNet") # すごく時間がかかるので注意 / be careful because it takes a long time
409
+ # input()
410
+
411
+ import bitsandbytes
412
+
413
+ optimizer = bitsandbytes.adam.Adam8bit(control_net.prepare_optimizer_params(), 1e-3)
414
+
415
+ scaler = torch.cuda.amp.GradScaler(enabled=True)
416
+
417
+ print("start training")
418
+ steps = 10
419
+
420
+ sample_param = [p for p in control_net.named_parameters() if "up" in p[0]][0]
421
+ for step in range(steps):
422
+ print(f"step {step}")
423
+
424
+ batch_size = 1
425
+ conditioning_image = torch.rand(batch_size, 3, 1024, 1024).cuda() * 2.0 - 1.0
426
+ x = torch.randn(batch_size, 4, 128, 128).cuda()
427
+ t = torch.randint(low=0, high=10, size=(batch_size,)).cuda()
428
+ ctx = torch.randn(batch_size, 77, 2048).cuda()
429
+ y = torch.randn(batch_size, sdxl_original_unet.ADM_IN_CHANNELS).cuda()
430
+
431
+ with torch.cuda.amp.autocast(enabled=True):
432
+ control_net.set_cond_image(conditioning_image)
433
+
434
+ output = unet(x, t, ctx, y)
435
+ target = torch.randn_like(output)
436
+ loss = torch.nn.functional.mse_loss(output, target)
437
+
438
+ scaler.scale(loss).backward()
439
+ scaler.step(optimizer)
440
+ scaler.update()
441
+ optimizer.zero_grad(set_to_none=True)
442
+ print(sample_param)
443
+
444
+ # from safetensors.torch import save_file
445
+
446
+ # save_file(control_net.state_dict(), "logs/control_net.safetensors")
external/llite/networks/control_net_lllite_for_train.py ADDED
@@ -0,0 +1,502 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # cond_imageをU-Netのforwardで渡すバージョンのControlNet-LLLite検証用実装
2
+ # ControlNet-LLLite implementation for verification with cond_image passed in U-Net's forward
3
+
4
+ import os
5
+ import re
6
+ from typing import Optional, List, Type
7
+ import torch
8
+ from library import sdxl_original_unet
9
+
10
+
11
+ # input_blocksに適用するかどうか / if True, input_blocks are not applied
12
+ SKIP_INPUT_BLOCKS = False
13
+
14
+ # output_blocksに適用するかどうか / if True, output_blocks are not applied
15
+ SKIP_OUTPUT_BLOCKS = True
16
+
17
+ # conv2dに適用するかどうか / if True, conv2d are not applied
18
+ SKIP_CONV2D = False
19
+
20
+ # transformer_blocksのみに適用するかどうか。Trueの場合、ResBlockには適用されない
21
+ # if True, only transformer_blocks are applied, and ResBlocks are not applied
22
+ TRANSFORMER_ONLY = True # if True, SKIP_CONV2D is ignored because conv2d is not used in transformer_blocks
23
+
24
+ # Trueならattn1とattn2にのみ適用し、ffなどには適用しない / if True, apply only to attn1 and attn2, not to ff etc.
25
+ ATTN1_2_ONLY = True
26
+
27
+ # Trueならattn1のQKV、attn2のQにのみ適用する、ATTN1_2_ONLY指定時のみ有効 / if True, apply only to attn1 QKV and attn2 Q, only valid when ATTN1_2_ONLY is specified
28
+ ATTN_QKV_ONLY = True
29
+
30
+ # Trueならattn1やffなどにのみ適用し、attn2などには適用しない / if True, apply only to attn1 and ff, not to attn2
31
+ # ATTN1_2_ONLYと同時にTrueにできない / cannot be True at the same time as ATTN1_2_ONLY
32
+ ATTN1_ETC_ONLY = False # True
33
+
34
+ # transformer_blocksの最大インデックス。Noneなら全てのtransformer_blocksに適用
35
+ # max index of transformer_blocks. if None, apply to all transformer_blocks
36
+ TRANSFORMER_MAX_BLOCK_INDEX = None
37
+
38
+ ORIGINAL_LINEAR = torch.nn.Linear
39
+ ORIGINAL_CONV2D = torch.nn.Conv2d
40
+
41
+
42
+ def add_lllite_modules(module: torch.nn.Module, in_dim: int, depth, cond_emb_dim, mlp_dim) -> None:
43
+ # conditioning1はconditioning imageを embedding する。timestepごとに呼ばれない
44
+ # conditioning1 embeds conditioning image. it is not called for each timestep
45
+ modules = []
46
+ modules.append(ORIGINAL_CONV2D(3, cond_emb_dim // 2, kernel_size=4, stride=4, padding=0)) # to latent (from VAE) size
47
+ if depth == 1:
48
+ modules.append(torch.nn.ReLU(inplace=True))
49
+ modules.append(ORIGINAL_CONV2D(cond_emb_dim // 2, cond_emb_dim, kernel_size=2, stride=2, padding=0))
50
+ elif depth == 2:
51
+ modules.append(torch.nn.ReLU(inplace=True))
52
+ modules.append(ORIGINAL_CONV2D(cond_emb_dim // 2, cond_emb_dim, kernel_size=4, stride=4, padding=0))
53
+ elif depth == 3:
54
+ # kernel size 8は大きすぎるので、4にする / kernel size 8 is too large, so set it to 4
55
+ modules.append(torch.nn.ReLU(inplace=True))
56
+ modules.append(ORIGINAL_CONV2D(cond_emb_dim // 2, cond_emb_dim // 2, kernel_size=4, stride=4, padding=0))
57
+ modules.append(torch.nn.ReLU(inplace=True))
58
+ modules.append(ORIGINAL_CONV2D(cond_emb_dim // 2, cond_emb_dim, kernel_size=2, stride=2, padding=0))
59
+
60
+ module.lllite_conditioning1 = torch.nn.Sequential(*modules)
61
+
62
+ # downで入力の次元数を削減する。LoRAにヒントを得ていることにする
63
+ # midでconditioning image embeddingと入力を結合する
64
+ # upで元の次元数に戻す
65
+ # これらはtimestepごとに呼ばれる
66
+ # reduce the number of input dimensions with down. inspired by LoRA
67
+ # combine conditioning image embedding and input with mid
68
+ # restore to the original dimension with up
69
+ # these are called for each timestep
70
+
71
+ module.lllite_down = torch.nn.Sequential(
72
+ ORIGINAL_LINEAR(in_dim, mlp_dim),
73
+ torch.nn.ReLU(inplace=True),
74
+ )
75
+ module.lllite_mid = torch.nn.Sequential(
76
+ ORIGINAL_LINEAR(mlp_dim + cond_emb_dim, mlp_dim),
77
+ torch.nn.ReLU(inplace=True),
78
+ )
79
+ module.lllite_up = torch.nn.Sequential(
80
+ ORIGINAL_LINEAR(mlp_dim, in_dim),
81
+ )
82
+
83
+ # Zero-Convにする / set to Zero-Conv
84
+ torch.nn.init.zeros_(module.lllite_up[0].weight) # zero conv
85
+
86
+
87
+ class LLLiteLinear(ORIGINAL_LINEAR):
88
+ def __init__(self, in_features: int, out_features: int, **kwargs):
89
+ super().__init__(in_features, out_features, **kwargs)
90
+ self.enabled = False
91
+
92
+ def set_lllite(self, depth, cond_emb_dim, name, mlp_dim, dropout=None, multiplier=1.0):
93
+ self.enabled = True
94
+ self.lllite_name = name
95
+ self.cond_emb_dim = cond_emb_dim
96
+ self.dropout = dropout
97
+ self.multiplier = multiplier # ignored
98
+
99
+ in_dim = self.in_features
100
+ add_lllite_modules(self, in_dim, depth, cond_emb_dim, mlp_dim)
101
+
102
+ self.cond_image = None
103
+ self.cond_emb = None
104
+
105
+ def set_cond_image(self, cond_image):
106
+ self.cond_image = cond_image
107
+ self.cond_emb = None
108
+
109
+ def forward(self, x):
110
+ if not self.enabled:
111
+ return super().forward(x)
112
+
113
+ if self.cond_emb is None:
114
+ self.cond_emb = self.lllite_conditioning1(self.cond_image)
115
+ cx = self.cond_emb
116
+
117
+ # reshape / b,c,h,w -> b,h*w,c
118
+ n, c, h, w = cx.shape
119
+ cx = cx.view(n, c, h * w).permute(0, 2, 1)
120
+
121
+ cx = torch.cat([cx, self.lllite_down(x)], dim=2)
122
+ cx = self.lllite_mid(cx)
123
+
124
+ if self.dropout is not None and self.training:
125
+ cx = torch.nn.functional.dropout(cx, p=self.dropout)
126
+
127
+ cx = self.lllite_up(cx) * self.multiplier
128
+
129
+ x = super().forward(x + cx) # ここで元のモジュールを呼び出す / call the original module here
130
+ return x
131
+
132
+
133
+ class LLLiteConv2d(ORIGINAL_CONV2D):
134
+ def __init__(self, in_channels: int, out_channels: int, kernel_size, **kwargs):
135
+ super().__init__(in_channels, out_channels, kernel_size, **kwargs)
136
+ self.enabled = False
137
+
138
+ def set_lllite(self, depth, cond_emb_dim, name, mlp_dim, dropout=None, multiplier=1.0):
139
+ self.enabled = True
140
+ self.lllite_name = name
141
+ self.cond_emb_dim = cond_emb_dim
142
+ self.dropout = dropout
143
+ self.multiplier = multiplier # ignored
144
+
145
+ in_dim = self.in_channels
146
+ add_lllite_modules(self, in_dim, depth, cond_emb_dim, mlp_dim)
147
+
148
+ self.cond_image = None
149
+ self.cond_emb = None
150
+
151
+ def set_cond_image(self, cond_image):
152
+ self.cond_image = cond_image
153
+ self.cond_emb = None
154
+
155
+ def forward(self, x): # , cond_image=None):
156
+ if not self.enabled:
157
+ return super().forward(x)
158
+
159
+ if self.cond_emb is None:
160
+ self.cond_emb = self.lllite_conditioning1(self.cond_image)
161
+ cx = self.cond_emb
162
+
163
+ cx = torch.cat([cx, self.down(x)], dim=1)
164
+ cx = self.mid(cx)
165
+
166
+ if self.dropout is not None and self.training:
167
+ cx = torch.nn.functional.dropout(cx, p=self.dropout)
168
+
169
+ cx = self.up(cx) * self.multiplier
170
+
171
+ x = super().forward(x + cx) # ここで元のモジュールを呼び出す / call the original module here
172
+ return x
173
+
174
+
175
+ class SdxlUNet2DConditionModelControlNetLLLite(sdxl_original_unet.SdxlUNet2DConditionModel):
176
+ UNET_TARGET_REPLACE_MODULE = ["Transformer2DModel"]
177
+ UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 = ["ResnetBlock2D", "Downsample2D", "Upsample2D"]
178
+ LLLITE_PREFIX = "lllite_unet"
179
+
180
+ def __init__(self, **kwargs):
181
+ super().__init__(**kwargs)
182
+
183
+ def apply_lllite(
184
+ self,
185
+ cond_emb_dim: int = 16,
186
+ mlp_dim: int = 16,
187
+ dropout: Optional[float] = None,
188
+ varbose: Optional[bool] = False,
189
+ multiplier: Optional[float] = 1.0,
190
+ ) -> None:
191
+ def apply_to_modules(
192
+ root_module: torch.nn.Module,
193
+ target_replace_modules: List[torch.nn.Module],
194
+ ) -> List[torch.nn.Module]:
195
+ prefix = "lllite_unet"
196
+
197
+ modules = []
198
+ for name, module in root_module.named_modules():
199
+ if module.__class__.__name__ in target_replace_modules:
200
+ for child_name, child_module in module.named_modules():
201
+ is_linear = child_module.__class__.__name__ == "LLLiteLinear"
202
+ is_conv2d = child_module.__class__.__name__ == "LLLiteConv2d"
203
+
204
+ if is_linear or (is_conv2d and not SKIP_CONV2D):
205
+ # block indexからdepthを計算: depthはconditioningのサイズやチャネルを計算するのに使う
206
+ # block index to depth: depth is using to calculate conditioning size and channels
207
+ block_name, index1, index2 = (name + "." + child_name).split(".")[:3]
208
+ index1 = int(index1)
209
+ if block_name == "input_blocks":
210
+ if SKIP_INPUT_BLOCKS:
211
+ continue
212
+ depth = 1 if index1 <= 2 else (2 if index1 <= 5 else 3)
213
+ elif block_name == "middle_block":
214
+ depth = 3
215
+ elif block_name == "output_blocks":
216
+ if SKIP_OUTPUT_BLOCKS:
217
+ continue
218
+ depth = 3 if index1 <= 2 else (2 if index1 <= 5 else 1)
219
+ if int(index2) >= 2:
220
+ depth -= 1
221
+ else:
222
+ raise NotImplementedError()
223
+
224
+ lllite_name = prefix + "." + name + "." + child_name
225
+ lllite_name = lllite_name.replace(".", "_")
226
+
227
+ if TRANSFORMER_MAX_BLOCK_INDEX is not None:
228
+ p = lllite_name.find("transformer_blocks")
229
+ if p >= 0:
230
+ tf_index = int(lllite_name[p:].split("_")[2])
231
+ if tf_index > TRANSFORMER_MAX_BLOCK_INDEX:
232
+ continue
233
+
234
+ # time embは適用外とする
235
+ # attn2のconditioning (CLIPからの入力) はshapeが違うので適用できない
236
+ # time emb is not applied
237
+ # attn2 conditioning (input from CLIP) cannot be applied because the shape is different
238
+ if "emb_layers" in lllite_name or (
239
+ "attn2" in lllite_name and ("to_k" in lllite_name or "to_v" in lllite_name)
240
+ ):
241
+ continue
242
+
243
+ if ATTN1_2_ONLY:
244
+ if not ("attn1" in lllite_name or "attn2" in lllite_name):
245
+ continue
246
+ if ATTN_QKV_ONLY:
247
+ if "to_out" in lllite_name:
248
+ continue
249
+
250
+ if ATTN1_ETC_ONLY:
251
+ if "proj_out" in lllite_name:
252
+ pass
253
+ elif "attn1" in lllite_name and (
254
+ "to_k" in lllite_name or "to_v" in lllite_name or "to_out" in lllite_name
255
+ ):
256
+ pass
257
+ elif "ff_net_2" in lllite_name:
258
+ pass
259
+ else:
260
+ continue
261
+
262
+ child_module.set_lllite(depth, cond_emb_dim, lllite_name, mlp_dim, dropout, multiplier)
263
+ modules.append(child_module)
264
+
265
+ return modules
266
+
267
+ target_modules = SdxlUNet2DConditionModelControlNetLLLite.UNET_TARGET_REPLACE_MODULE
268
+ if not TRANSFORMER_ONLY:
269
+ target_modules = target_modules + SdxlUNet2DConditionModelControlNetLLLite.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3
270
+
271
+ # create module instances
272
+ self.lllite_modules = apply_to_modules(self, target_modules)
273
+ print(f"enable ControlNet LLLite for U-Net: {len(self.lllite_modules)} modules.")
274
+
275
+ # def prepare_optimizer_params(self):
276
+ def prepare_params(self):
277
+ train_params = []
278
+ non_train_params = []
279
+ for name, p in self.named_parameters():
280
+ if "lllite" in name:
281
+ train_params.append(p)
282
+ else:
283
+ non_train_params.append(p)
284
+ print(f"count of trainable parameters: {len(train_params)}")
285
+ print(f"count of non-trainable parameters: {len(non_train_params)}")
286
+
287
+ for p in non_train_params:
288
+ p.requires_grad_(False)
289
+
290
+ # without this, an error occurs in the optimizer
291
+ # RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn
292
+ non_train_params[0].requires_grad_(True)
293
+
294
+ for p in train_params:
295
+ p.requires_grad_(True)
296
+
297
+ return train_params
298
+
299
+ # def prepare_grad_etc(self):
300
+ # self.requires_grad_(True)
301
+
302
+ # def on_epoch_start(self):
303
+ # self.train()
304
+
305
+ def get_trainable_params(self):
306
+ return [p[1] for p in self.named_parameters() if "lllite" in p[0]]
307
+
308
+ def save_lllite_weights(self, file, dtype, metadata):
309
+ if metadata is not None and len(metadata) == 0:
310
+ metadata = None
311
+
312
+ org_state_dict = self.state_dict()
313
+
314
+ # copy LLLite keys from org_state_dict to state_dict with key conversion
315
+ state_dict = {}
316
+ for key in org_state_dict.keys():
317
+ # split with ".lllite"
318
+ pos = key.find(".lllite")
319
+ if pos < 0:
320
+ continue
321
+ lllite_key = SdxlUNet2DConditionModelControlNetLLLite.LLLITE_PREFIX + "." + key[:pos]
322
+ lllite_key = lllite_key.replace(".", "_") + key[pos:]
323
+ lllite_key = lllite_key.replace(".lllite_", ".")
324
+ state_dict[lllite_key] = org_state_dict[key]
325
+
326
+ if dtype is not None:
327
+ for key in list(state_dict.keys()):
328
+ v = state_dict[key]
329
+ v = v.detach().clone().to("cpu").to(dtype)
330
+ state_dict[key] = v
331
+
332
+ if os.path.splitext(file)[1] == ".safetensors":
333
+ from safetensors.torch import save_file
334
+
335
+ save_file(state_dict, file, metadata)
336
+ else:
337
+ torch.save(state_dict, file)
338
+
339
+ def load_lllite_weights(self, file, non_lllite_unet_sd=None):
340
+ r"""
341
+ LLLiteの重みを読み込まない(initされた値を使う)場合はfileにNoneを指定する。
342
+ この場合、non_lllite_unet_sdにはU-Netのstate_dictを指定する。
343
+
344
+ If you do not want to load LLLite weights (use initialized values), specify None for file.
345
+ In this case, specify the state_dict of U-Net for non_lllite_unet_sd.
346
+ """
347
+ if not file:
348
+ state_dict = self.state_dict()
349
+ for key in non_lllite_unet_sd:
350
+ if key in state_dict:
351
+ state_dict[key] = non_lllite_unet_sd[key]
352
+ info = self.load_state_dict(state_dict, False)
353
+ return info
354
+
355
+ if os.path.splitext(file)[1] == ".safetensors":
356
+ from safetensors.torch import load_file
357
+
358
+ weights_sd = load_file(file)
359
+ else:
360
+ weights_sd = torch.load(file, map_location="cpu")
361
+
362
+ # module_name = module_name.replace("_block", "@blocks")
363
+ # module_name = module_name.replace("_layer", "@layer")
364
+ # module_name = module_name.replace("to_", "to@")
365
+ # module_name = module_name.replace("time_embed", "time@embed")
366
+ # module_name = module_name.replace("label_emb", "label@emb")
367
+ # module_name = module_name.replace("skip_connection", "skip@connection")
368
+ # module_name = module_name.replace("proj_in", "proj@in")
369
+ # module_name = module_name.replace("proj_out", "proj@out")
370
+ pattern = re.compile(r"(_block|_layer|to_|time_embed|label_emb|skip_connection|proj_in|proj_out)")
371
+
372
+ # convert to lllite with U-Net state dict
373
+ state_dict = non_lllite_unet_sd.copy() if non_lllite_unet_sd is not None else {}
374
+ for key in weights_sd.keys():
375
+ # split with "."
376
+ pos = key.find(".")
377
+ if pos < 0:
378
+ continue
379
+
380
+ module_name = key[:pos]
381
+ weight_name = key[pos + 1 :] # exclude "."
382
+ module_name = module_name.replace(SdxlUNet2DConditionModelControlNetLLLite.LLLITE_PREFIX + "_", "")
383
+
384
+ # これはうまくいかない。逆変換を考えなかった設計が悪い / this does not work well. bad design because I didn't think about inverse conversion
385
+ # module_name = module_name.replace("_", ".")
386
+
387
+ # ださいけどSDXLのU-Netの "_" を "@" に変換する / ugly but convert "_" of SDXL U-Net to "@"
388
+ matches = pattern.findall(module_name)
389
+ if matches is not None:
390
+ for m in matches:
391
+ print(module_name, m)
392
+ module_name = module_name.replace(m, m.replace("_", "@"))
393
+ module_name = module_name.replace("_", ".")
394
+ module_name = module_name.replace("@", "_")
395
+
396
+ lllite_key = module_name + ".lllite_" + weight_name
397
+
398
+ state_dict[lllite_key] = weights_sd[key]
399
+
400
+ info = self.load_state_dict(state_dict, False)
401
+ return info
402
+
403
+ def forward(self, x, timesteps=None, context=None, y=None, cond_image=None, **kwargs):
404
+ for m in self.lllite_modules:
405
+ m.set_cond_image(cond_image)
406
+ return super().forward(x, timesteps, context, y, **kwargs)
407
+
408
+
409
+ def replace_unet_linear_and_conv2d():
410
+ print("replace torch.nn.Linear and torch.nn.Conv2d to LLLiteLinear and LLLiteConv2d in U-Net")
411
+ sdxl_original_unet.torch.nn.Linear = LLLiteLinear
412
+ sdxl_original_unet.torch.nn.Conv2d = LLLiteConv2d
413
+
414
+
415
+ if __name__ == "__main__":
416
+ # デバッグ用 / for debug
417
+
418
+ # sdxl_original_unet.USE_REENTRANT = False
419
+ replace_unet_linear_and_conv2d()
420
+
421
+ # test shape etc
422
+ print("create unet")
423
+ unet = SdxlUNet2DConditionModelControlNetLLLite()
424
+
425
+ print("enable ControlNet-LLLite")
426
+ unet.apply_lllite(32, 64, None, False, 1.0)
427
+ unet.to("cuda") # .to(torch.float16)
428
+
429
+ # from safetensors.torch import load_file
430
+
431
+ # model_sd = load_file(r"E:\Work\SD\Models\sdxl\sd_xl_base_1.0_0.9vae.safetensors")
432
+ # unet_sd = {}
433
+
434
+ # # copy U-Net keys from unet_state_dict to state_dict
435
+ # prefix = "model.diffusion_model."
436
+ # for key in model_sd.keys():
437
+ # if key.startswith(prefix):
438
+ # converted_key = key[len(prefix) :]
439
+ # unet_sd[converted_key] = model_sd[key]
440
+
441
+ # info = unet.load_lllite_weights("r:/lllite_from_unet.safetensors", unet_sd)
442
+ # print(info)
443
+
444
+ # print(unet)
445
+
446
+ # print number of parameters
447
+ params = unet.prepare_params()
448
+ print("number of parameters", sum(p.numel() for p in params))
449
+ # print("type any key to continue")
450
+ # input()
451
+
452
+ unet.set_use_memory_efficient_attention(True, False)
453
+ unet.set_gradient_checkpointing(True)
454
+ unet.train() # for gradient checkpointing
455
+
456
+ # # visualize
457
+ # import torchviz
458
+ # print("run visualize")
459
+ # controlnet.set_control(conditioning_image)
460
+ # output = unet(x, t, ctx, y)
461
+ # print("make_dot")
462
+ # image = torchviz.make_dot(output, params=dict(controlnet.named_parameters()))
463
+ # print("render")
464
+ # image.format = "svg" # "png"
465
+ # image.render("NeuralNet") # すごく時間がかかるので注意 / be careful because it takes a long time
466
+ # input()
467
+
468
+ import bitsandbytes
469
+
470
+ optimizer = bitsandbytes.adam.Adam8bit(params, 1e-3)
471
+
472
+ scaler = torch.cuda.amp.GradScaler(enabled=True)
473
+
474
+ print("start training")
475
+ steps = 10
476
+ batch_size = 1
477
+
478
+ sample_param = [p for p in unet.named_parameters() if ".lllite_up." in p[0]][0]
479
+ for step in range(steps):
480
+ print(f"step {step}")
481
+
482
+ conditioning_image = torch.rand(batch_size, 3, 1024, 1024).cuda() * 2.0 - 1.0
483
+ x = torch.randn(batch_size, 4, 128, 128).cuda()
484
+ t = torch.randint(low=0, high=10, size=(batch_size,)).cuda()
485
+ ctx = torch.randn(batch_size, 77, 2048).cuda()
486
+ y = torch.randn(batch_size, sdxl_original_unet.ADM_IN_CHANNELS).cuda()
487
+
488
+ with torch.cuda.amp.autocast(enabled=True, dtype=torch.bfloat16):
489
+ output = unet(x, t, ctx, y, conditioning_image)
490
+ target = torch.randn_like(output)
491
+ loss = torch.nn.functional.mse_loss(output, target)
492
+
493
+ scaler.scale(loss).backward()
494
+ scaler.step(optimizer)
495
+ scaler.update()
496
+ optimizer.zero_grad(set_to_none=True)
497
+ print(sample_param)
498
+
499
+ # from safetensors.torch import save_file
500
+
501
+ # print("save weights")
502
+ # unet.save_lllite_weights("r:/lllite_from_unet.safetensors", torch.float16, None)
external/llite/networks/dylora.py ADDED
@@ -0,0 +1,450 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # some codes are copied from:
2
+ # https://github.com/huawei-noah/KD-NLP/blob/main/DyLoRA/
3
+
4
+ # Copyright (C) 2022. Huawei Technologies Co., Ltd. All rights reserved.
5
+ # Changes made to the original code:
6
+ # 2022.08.20 - Integrate the DyLoRA layer for the LoRA Linear layer
7
+ # ------------------------------------------------------------------------------------------
8
+ # Copyright (c) Microsoft Corporation. All rights reserved.
9
+ # Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
10
+ # ------------------------------------------------------------------------------------------
11
+
12
+ import math
13
+ import os
14
+ import random
15
+ from typing import List, Tuple, Union
16
+ import torch
17
+ from torch import nn
18
+
19
+
20
+ class DyLoRAModule(torch.nn.Module):
21
+ """
22
+ replaces forward method of the original Linear, instead of replacing the original Linear module.
23
+ """
24
+
25
+ # NOTE: support dropout in future
26
+ def __init__(self, lora_name, org_module: torch.nn.Module, multiplier=1.0, lora_dim=4, alpha=1, unit=1):
27
+ super().__init__()
28
+ self.lora_name = lora_name
29
+ self.lora_dim = lora_dim
30
+ self.unit = unit
31
+ assert self.lora_dim % self.unit == 0, "rank must be a multiple of unit"
32
+
33
+ if org_module.__class__.__name__ == "Conv2d":
34
+ in_dim = org_module.in_channels
35
+ out_dim = org_module.out_channels
36
+ else:
37
+ in_dim = org_module.in_features
38
+ out_dim = org_module.out_features
39
+
40
+ if type(alpha) == torch.Tensor:
41
+ alpha = alpha.detach().float().numpy() # without casting, bf16 causes error
42
+ alpha = self.lora_dim if alpha is None or alpha == 0 else alpha
43
+ self.scale = alpha / self.lora_dim
44
+ self.register_buffer("alpha", torch.tensor(alpha)) # 定数として扱える
45
+
46
+ self.is_conv2d = org_module.__class__.__name__ == "Conv2d"
47
+ self.is_conv2d_3x3 = self.is_conv2d and org_module.kernel_size == (3, 3)
48
+
49
+ if self.is_conv2d and self.is_conv2d_3x3:
50
+ kernel_size = org_module.kernel_size
51
+ self.stride = org_module.stride
52
+ self.padding = org_module.padding
53
+ self.lora_A = nn.ParameterList([org_module.weight.new_zeros((1, in_dim, *kernel_size)) for _ in range(self.lora_dim)])
54
+ self.lora_B = nn.ParameterList([org_module.weight.new_zeros((out_dim, 1, 1, 1)) for _ in range(self.lora_dim)])
55
+ else:
56
+ self.lora_A = nn.ParameterList([org_module.weight.new_zeros((1, in_dim)) for _ in range(self.lora_dim)])
57
+ self.lora_B = nn.ParameterList([org_module.weight.new_zeros((out_dim, 1)) for _ in range(self.lora_dim)])
58
+
59
+ # same as microsoft's
60
+ for lora in self.lora_A:
61
+ torch.nn.init.kaiming_uniform_(lora, a=math.sqrt(5))
62
+ for lora in self.lora_B:
63
+ torch.nn.init.zeros_(lora)
64
+
65
+ self.multiplier = multiplier
66
+ self.org_module = org_module # remove in applying
67
+
68
+ def apply_to(self):
69
+ self.org_forward = self.org_module.forward
70
+ self.org_module.forward = self.forward
71
+ del self.org_module
72
+
73
+ def forward(self, x):
74
+ result = self.org_forward(x)
75
+
76
+ # specify the dynamic rank
77
+ trainable_rank = random.randint(0, self.lora_dim - 1)
78
+ trainable_rank = trainable_rank - trainable_rank % self.unit # make sure the rank is a multiple of unit
79
+
80
+ # 一部のパラメータを固定して、残りのパラメータを学習する
81
+ for i in range(0, trainable_rank):
82
+ self.lora_A[i].requires_grad = False
83
+ self.lora_B[i].requires_grad = False
84
+ for i in range(trainable_rank, trainable_rank + self.unit):
85
+ self.lora_A[i].requires_grad = True
86
+ self.lora_B[i].requires_grad = True
87
+ for i in range(trainable_rank + self.unit, self.lora_dim):
88
+ self.lora_A[i].requires_grad = False
89
+ self.lora_B[i].requires_grad = False
90
+
91
+ lora_A = torch.cat(tuple(self.lora_A), dim=0)
92
+ lora_B = torch.cat(tuple(self.lora_B), dim=1)
93
+
94
+ # calculate with lora_A and lora_B
95
+ if self.is_conv2d_3x3:
96
+ ab = torch.nn.functional.conv2d(x, lora_A, stride=self.stride, padding=self.padding)
97
+ ab = torch.nn.functional.conv2d(ab, lora_B)
98
+ else:
99
+ ab = x
100
+ if self.is_conv2d:
101
+ ab = ab.reshape(ab.size(0), ab.size(1), -1).transpose(1, 2) # (N, C, H, W) -> (N, H*W, C)
102
+
103
+ ab = torch.nn.functional.linear(ab, lora_A)
104
+ ab = torch.nn.functional.linear(ab, lora_B)
105
+
106
+ if self.is_conv2d:
107
+ ab = ab.transpose(1, 2).reshape(ab.size(0), -1, *x.size()[2:]) # (N, H*W, C) -> (N, C, H, W)
108
+
109
+ # 最後の項は、低rankをより大きくするためのスケーリング(じゃないかな)
110
+ result = result + ab * self.scale * math.sqrt(self.lora_dim / (trainable_rank + self.unit))
111
+
112
+ # NOTE weightに加算してからlinear/conv2dを呼んだほうが��いかも
113
+ return result
114
+
115
+ def state_dict(self, destination=None, prefix="", keep_vars=False):
116
+ # state dictを通常のLoRAと同じにする:
117
+ # nn.ParameterListは `.lora_A.0` みたいな名前になるので、forwardと同様にcatして入れ替える
118
+ sd = super().state_dict(destination=destination, prefix=prefix, keep_vars=keep_vars)
119
+
120
+ lora_A_weight = torch.cat(tuple(self.lora_A), dim=0)
121
+ if self.is_conv2d and not self.is_conv2d_3x3:
122
+ lora_A_weight = lora_A_weight.unsqueeze(-1).unsqueeze(-1)
123
+
124
+ lora_B_weight = torch.cat(tuple(self.lora_B), dim=1)
125
+ if self.is_conv2d and not self.is_conv2d_3x3:
126
+ lora_B_weight = lora_B_weight.unsqueeze(-1).unsqueeze(-1)
127
+
128
+ sd[self.lora_name + ".lora_down.weight"] = lora_A_weight if keep_vars else lora_A_weight.detach()
129
+ sd[self.lora_name + ".lora_up.weight"] = lora_B_weight if keep_vars else lora_B_weight.detach()
130
+
131
+ i = 0
132
+ while True:
133
+ key_a = f"{self.lora_name}.lora_A.{i}"
134
+ key_b = f"{self.lora_name}.lora_B.{i}"
135
+ if key_a in sd:
136
+ sd.pop(key_a)
137
+ sd.pop(key_b)
138
+ else:
139
+ break
140
+ i += 1
141
+ return sd
142
+
143
+ def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs):
144
+ # 通常のLoRAと同じstate dictを読み込めるようにする:この方法はchatGPTに聞いた
145
+ lora_A_weight = state_dict.pop(self.lora_name + ".lora_down.weight", None)
146
+ lora_B_weight = state_dict.pop(self.lora_name + ".lora_up.weight", None)
147
+
148
+ if lora_A_weight is None or lora_B_weight is None:
149
+ if strict:
150
+ raise KeyError(f"{self.lora_name}.lora_down/up.weight is not found")
151
+ else:
152
+ return
153
+
154
+ if self.is_conv2d and not self.is_conv2d_3x3:
155
+ lora_A_weight = lora_A_weight.squeeze(-1).squeeze(-1)
156
+ lora_B_weight = lora_B_weight.squeeze(-1).squeeze(-1)
157
+
158
+ state_dict.update(
159
+ {f"{self.lora_name}.lora_A.{i}": nn.Parameter(lora_A_weight[i].unsqueeze(0)) for i in range(lora_A_weight.size(0))}
160
+ )
161
+ state_dict.update(
162
+ {f"{self.lora_name}.lora_B.{i}": nn.Parameter(lora_B_weight[:, i].unsqueeze(1)) for i in range(lora_B_weight.size(1))}
163
+ )
164
+
165
+ super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
166
+
167
+
168
+ def create_network(multiplier, network_dim, network_alpha, vae, text_encoder, unet, **kwargs):
169
+ if network_dim is None:
170
+ network_dim = 4 # default
171
+ if network_alpha is None:
172
+ network_alpha = 1.0
173
+
174
+ # extract dim/alpha for conv2d, and block dim
175
+ conv_dim = kwargs.get("conv_dim", None)
176
+ conv_alpha = kwargs.get("conv_alpha", None)
177
+ unit = kwargs.get("unit", None)
178
+ if conv_dim is not None:
179
+ conv_dim = int(conv_dim)
180
+ assert conv_dim == network_dim, "conv_dim must be same as network_dim"
181
+ if conv_alpha is None:
182
+ conv_alpha = 1.0
183
+ else:
184
+ conv_alpha = float(conv_alpha)
185
+ if unit is not None:
186
+ unit = int(unit)
187
+ else:
188
+ unit = 1
189
+
190
+ network = DyLoRANetwork(
191
+ text_encoder,
192
+ unet,
193
+ multiplier=multiplier,
194
+ lora_dim=network_dim,
195
+ alpha=network_alpha,
196
+ apply_to_conv=conv_dim is not None,
197
+ unit=unit,
198
+ varbose=True,
199
+ )
200
+ return network
201
+
202
+
203
+ # Create network from weights for inference, weights are not loaded here (because can be merged)
204
+ def create_network_from_weights(multiplier, file, vae, text_encoder, unet, weights_sd=None, for_inference=False, **kwargs):
205
+ if weights_sd is None:
206
+ if os.path.splitext(file)[1] == ".safetensors":
207
+ from safetensors.torch import load_file, safe_open
208
+
209
+ weights_sd = load_file(file)
210
+ else:
211
+ weights_sd = torch.load(file, map_location="cpu")
212
+
213
+ # get dim/alpha mapping
214
+ modules_dim = {}
215
+ modules_alpha = {}
216
+ for key, value in weights_sd.items():
217
+ if "." not in key:
218
+ continue
219
+
220
+ lora_name = key.split(".")[0]
221
+ if "alpha" in key:
222
+ modules_alpha[lora_name] = value
223
+ elif "lora_down" in key:
224
+ dim = value.size()[0]
225
+ modules_dim[lora_name] = dim
226
+ # print(lora_name, value.size(), dim)
227
+
228
+ # support old LoRA without alpha
229
+ for key in modules_dim.keys():
230
+ if key not in modules_alpha:
231
+ modules_alpha = modules_dim[key]
232
+
233
+ module_class = DyLoRAModule
234
+
235
+ network = DyLoRANetwork(
236
+ text_encoder, unet, multiplier=multiplier, modules_dim=modules_dim, modules_alpha=modules_alpha, module_class=module_class
237
+ )
238
+ return network, weights_sd
239
+
240
+
241
+ class DyLoRANetwork(torch.nn.Module):
242
+ UNET_TARGET_REPLACE_MODULE = ["Transformer2DModel"]
243
+ UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 = ["ResnetBlock2D", "Downsample2D", "Upsample2D"]
244
+ TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPMLP"]
245
+ LORA_PREFIX_UNET = "lora_unet"
246
+ LORA_PREFIX_TEXT_ENCODER = "lora_te"
247
+
248
+ def __init__(
249
+ self,
250
+ text_encoder,
251
+ unet,
252
+ multiplier=1.0,
253
+ lora_dim=4,
254
+ alpha=1,
255
+ apply_to_conv=False,
256
+ modules_dim=None,
257
+ modules_alpha=None,
258
+ unit=1,
259
+ module_class=DyLoRAModule,
260
+ varbose=False,
261
+ ) -> None:
262
+ super().__init__()
263
+ self.multiplier = multiplier
264
+
265
+ self.lora_dim = lora_dim
266
+ self.alpha = alpha
267
+ self.apply_to_conv = apply_to_conv
268
+
269
+ if modules_dim is not None:
270
+ print(f"create LoRA network from weights")
271
+ else:
272
+ print(f"create LoRA network. base dim (rank): {lora_dim}, alpha: {alpha}, unit: {unit}")
273
+ if self.apply_to_conv:
274
+ print(f"apply LoRA to Conv2d with kernel size (3,3).")
275
+
276
+ # create module instances
277
+ def create_modules(is_unet, root_module: torch.nn.Module, target_replace_modules) -> List[DyLoRAModule]:
278
+ prefix = DyLoRANetwork.LORA_PREFIX_UNET if is_unet else DyLoRANetwork.LORA_PREFIX_TEXT_ENCODER
279
+ loras = []
280
+ for name, module in root_module.named_modules():
281
+ if module.__class__.__name__ in target_replace_modules:
282
+ for child_name, child_module in module.named_modules():
283
+ is_linear = child_module.__class__.__name__ == "Linear"
284
+ is_conv2d = child_module.__class__.__name__ == "Conv2d"
285
+ is_conv2d_1x1 = is_conv2d and child_module.kernel_size == (1, 1)
286
+
287
+ if is_linear or is_conv2d:
288
+ lora_name = prefix + "." + name + "." + child_name
289
+ lora_name = lora_name.replace(".", "_")
290
+
291
+ dim = None
292
+ alpha = None
293
+ if modules_dim is not None:
294
+ if lora_name in modules_dim:
295
+ dim = modules_dim[lora_name]
296
+ alpha = modules_alpha[lora_name]
297
+ else:
298
+ if is_linear or is_conv2d_1x1 or apply_to_conv:
299
+ dim = self.lora_dim
300
+ alpha = self.alpha
301
+
302
+ if dim is None or dim == 0:
303
+ continue
304
+
305
+ # dropout and fan_in_fan_out is default
306
+ lora = module_class(lora_name, child_module, self.multiplier, dim, alpha, unit)
307
+ loras.append(lora)
308
+ return loras
309
+
310
+ self.text_encoder_loras = create_modules(False, text_encoder, DyLoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE)
311
+ print(f"create LoRA for Text Encoder: {len(self.text_encoder_loras)} modules.")
312
+
313
+ # extend U-Net target modules if conv2d 3x3 is enabled, or load from weights
314
+ target_modules = DyLoRANetwork.UNET_TARGET_REPLACE_MODULE
315
+ if modules_dim is not None or self.apply_to_conv:
316
+ target_modules += DyLoRANetwork.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3
317
+
318
+ self.unet_loras = create_modules(True, unet, target_modules)
319
+ print(f"create LoRA for U-Net: {len(self.unet_loras)} modules.")
320
+
321
+ def set_multiplier(self, multiplier):
322
+ self.multiplier = multiplier
323
+ for lora in self.text_encoder_loras + self.unet_loras:
324
+ lora.multiplier = self.multiplier
325
+
326
+ def load_weights(self, file):
327
+ if os.path.splitext(file)[1] == ".safetensors":
328
+ from safetensors.torch import load_file
329
+
330
+ weights_sd = load_file(file)
331
+ else:
332
+ weights_sd = torch.load(file, map_location="cpu")
333
+
334
+ info = self.load_state_dict(weights_sd, False)
335
+ return info
336
+
337
+ def apply_to(self, text_encoder, unet, apply_text_encoder=True, apply_unet=True):
338
+ if apply_text_encoder:
339
+ print("enable LoRA for text encoder")
340
+ else:
341
+ self.text_encoder_loras = []
342
+
343
+ if apply_unet:
344
+ print("enable LoRA for U-Net")
345
+ else:
346
+ self.unet_loras = []
347
+
348
+ for lora in self.text_encoder_loras + self.unet_loras:
349
+ lora.apply_to()
350
+ self.add_module(lora.lora_name, lora)
351
+
352
+ """
353
+ def merge_to(self, text_encoder, unet, weights_sd, dtype, device):
354
+ apply_text_encoder = apply_unet = False
355
+ for key in weights_sd.keys():
356
+ if key.startswith(DyLoRANetwork.LORA_PREFIX_TEXT_ENCODER):
357
+ apply_text_encoder = True
358
+ elif key.startswith(DyLoRANetwork.LORA_PREFIX_UNET):
359
+ apply_unet = True
360
+
361
+ if apply_text_encoder:
362
+ print("enable LoRA for text encoder")
363
+ else:
364
+ self.text_encoder_loras = []
365
+
366
+ if apply_unet:
367
+ print("enable LoRA for U-Net")
368
+ else:
369
+ self.unet_loras = []
370
+
371
+ for lora in self.text_encoder_loras + self.unet_loras:
372
+ sd_for_lora = {}
373
+ for key in weights_sd.keys():
374
+ if key.startswith(lora.lora_name):
375
+ sd_for_lora[key[len(lora.lora_name) + 1 :]] = weights_sd[key]
376
+ lora.merge_to(sd_for_lora, dtype, device)
377
+
378
+ print(f"weights are merged")
379
+ """
380
+
381
+ def prepare_optimizer_params(self, text_encoder_lr, unet_lr, default_lr):
382
+ self.requires_grad_(True)
383
+ all_params = []
384
+
385
+ def enumerate_params(loras):
386
+ params = []
387
+ for lora in loras:
388
+ params.extend(lora.parameters())
389
+ return params
390
+
391
+ if self.text_encoder_loras:
392
+ param_data = {"params": enumerate_params(self.text_encoder_loras)}
393
+ if text_encoder_lr is not None:
394
+ param_data["lr"] = text_encoder_lr
395
+ all_params.append(param_data)
396
+
397
+ if self.unet_loras:
398
+ param_data = {"params": enumerate_params(self.unet_loras)}
399
+ if unet_lr is not None:
400
+ param_data["lr"] = unet_lr
401
+ all_params.append(param_data)
402
+
403
+ return all_params
404
+
405
+ def enable_gradient_checkpointing(self):
406
+ # not supported
407
+ pass
408
+
409
+ def prepare_grad_etc(self, text_encoder, unet):
410
+ self.requires_grad_(True)
411
+
412
+ def on_epoch_start(self, text_encoder, unet):
413
+ self.train()
414
+
415
+ def get_trainable_params(self):
416
+ return self.parameters()
417
+
418
+ def save_weights(self, file, dtype, metadata):
419
+ if metadata is not None and len(metadata) == 0:
420
+ metadata = None
421
+
422
+ state_dict = self.state_dict()
423
+
424
+ if dtype is not None:
425
+ for key in list(state_dict.keys()):
426
+ v = state_dict[key]
427
+ v = v.detach().clone().to("cpu").to(dtype)
428
+ state_dict[key] = v
429
+
430
+ if os.path.splitext(file)[1] == ".safetensors":
431
+ from safetensors.torch import save_file
432
+ from library import train_util
433
+
434
+ # Precalculate model hashes to save time on indexing
435
+ if metadata is None:
436
+ metadata = {}
437
+ model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(state_dict, metadata)
438
+ metadata["sshs_model_hash"] = model_hash
439
+ metadata["sshs_legacy_hash"] = legacy_hash
440
+
441
+ save_file(state_dict, file, metadata)
442
+ else:
443
+ torch.save(state_dict, file)
444
+
445
+ # mask is a tensor with values from 0 to 1
446
+ def set_region(self, sub_prompt_index, is_last_network, mask):
447
+ pass
448
+
449
+ def set_current_generation(self, batch_size, num_sub_prompts, width, height, shared):
450
+ pass
external/llite/networks/extract_lora_from_dylora.py ADDED
@@ -0,0 +1,125 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Convert LoRA to different rank approximation (should only be used to go to lower rank)
2
+ # This code is based off the extract_lora_from_models.py file which is based on https://github.com/cloneofsimo/lora/blob/develop/lora_diffusion/cli_svd.py
3
+ # Thanks to cloneofsimo
4
+
5
+ import argparse
6
+ import math
7
+ import os
8
+ import torch
9
+ from safetensors.torch import load_file, save_file, safe_open
10
+ from tqdm import tqdm
11
+ from library import train_util, model_util
12
+ import numpy as np
13
+
14
+
15
+ def load_state_dict(file_name):
16
+ if model_util.is_safetensors(file_name):
17
+ sd = load_file(file_name)
18
+ with safe_open(file_name, framework="pt") as f:
19
+ metadata = f.metadata()
20
+ else:
21
+ sd = torch.load(file_name, map_location="cpu")
22
+ metadata = None
23
+
24
+ return sd, metadata
25
+
26
+
27
+ def save_to_file(file_name, model, metadata):
28
+ if model_util.is_safetensors(file_name):
29
+ save_file(model, file_name, metadata)
30
+ else:
31
+ torch.save(model, file_name)
32
+
33
+
34
+ def split_lora_model(lora_sd, unit):
35
+ max_rank = 0
36
+
37
+ # Extract loaded lora dim and alpha
38
+ for key, value in lora_sd.items():
39
+ if "lora_down" in key:
40
+ rank = value.size()[0]
41
+ if rank > max_rank:
42
+ max_rank = rank
43
+ print(f"Max rank: {max_rank}")
44
+
45
+ rank = unit
46
+ split_models = []
47
+ new_alpha = None
48
+ while rank < max_rank:
49
+ print(f"Splitting rank {rank}")
50
+ new_sd = {}
51
+ for key, value in lora_sd.items():
52
+ if "lora_down" in key:
53
+ new_sd[key] = value[:rank].contiguous()
54
+ elif "lora_up" in key:
55
+ new_sd[key] = value[:, :rank].contiguous()
56
+ else:
57
+ # なぜかscaleするとおかしくなる……
58
+ # this_rank = lora_sd[key.replace("alpha", "lora_down.weight")].size()[0]
59
+ # scale = math.sqrt(this_rank / rank) # rank is > unit
60
+ # print(key, value.size(), this_rank, rank, value, scale)
61
+ # new_alpha = value * scale # always same
62
+ # new_sd[key] = new_alpha
63
+ new_sd[key] = value
64
+
65
+ split_models.append((new_sd, rank, new_alpha))
66
+ rank += unit
67
+
68
+ return max_rank, split_models
69
+
70
+
71
+ def split(args):
72
+ print("loading Model...")
73
+ lora_sd, metadata = load_state_dict(args.model)
74
+
75
+ print("Splitting Model...")
76
+ original_rank, split_models = split_lora_model(lora_sd, args.unit)
77
+
78
+ comment = metadata.get("ss_training_comment", "")
79
+ for state_dict, new_rank, new_alpha in split_models:
80
+ # update metadata
81
+ if metadata is None:
82
+ new_metadata = {}
83
+ else:
84
+ new_metadata = metadata.copy()
85
+
86
+ new_metadata["ss_training_comment"] = f"split from DyLoRA, rank {original_rank} to {new_rank}; {comment}"
87
+ new_metadata["ss_network_dim"] = str(new_rank)
88
+ # new_metadata["ss_network_alpha"] = str(new_alpha.float().numpy())
89
+
90
+ model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(state_dict, metadata)
91
+ metadata["sshs_model_hash"] = model_hash
92
+ metadata["sshs_legacy_hash"] = legacy_hash
93
+
94
+ filename, ext = os.path.splitext(args.save_to)
95
+ model_file_name = filename + f"-{new_rank:04d}{ext}"
96
+
97
+ print(f"saving model to: {model_file_name}")
98
+ save_to_file(model_file_name, state_dict, new_metadata)
99
+
100
+
101
+ def setup_parser() -> argparse.ArgumentParser:
102
+ parser = argparse.ArgumentParser()
103
+
104
+ parser.add_argument("--unit", type=int, default=None, help="size of rank to split into / rankを分割するサイズ")
105
+ parser.add_argument(
106
+ "--save_to",
107
+ type=str,
108
+ default=None,
109
+ help="destination base file name: ckpt or safetensors file / 保存先のファイル名のbase、ckptまたはsafetensors",
110
+ )
111
+ parser.add_argument(
112
+ "--model",
113
+ type=str,
114
+ default=None,
115
+ help="DyLoRA model to resize at to new rank: ckpt or safetensors file / 読み込むDyLoRAモデル、ckptまたはsafetensors",
116
+ )
117
+
118
+ return parser
119
+
120
+
121
+ if __name__ == "__main__":
122
+ parser = setup_parser()
123
+
124
+ args = parser.parse_args()
125
+ split(args)
external/llite/networks/extract_lora_from_models.py ADDED
@@ -0,0 +1,296 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # extract approximating LoRA by svd from two SD models
2
+ # The code is based on https://github.com/cloneofsimo/lora/blob/develop/lora_diffusion/cli_svd.py
3
+ # Thanks to cloneofsimo!
4
+
5
+ import argparse
6
+ import json
7
+ import os
8
+ import time
9
+ import torch
10
+ from safetensors.torch import load_file, save_file
11
+ from tqdm import tqdm
12
+ from library import sai_model_spec, model_util, sdxl_model_util
13
+ import lora
14
+
15
+
16
+ # CLAMP_QUANTILE = 0.99
17
+ # MIN_DIFF = 1e-1
18
+
19
+
20
+ def save_to_file(file_name, model, state_dict, dtype):
21
+ if dtype is not None:
22
+ for key in list(state_dict.keys()):
23
+ if type(state_dict[key]) == torch.Tensor:
24
+ state_dict[key] = state_dict[key].to(dtype)
25
+
26
+ if os.path.splitext(file_name)[1] == ".safetensors":
27
+ save_file(model, file_name)
28
+ else:
29
+ torch.save(model, file_name)
30
+
31
+
32
+ def svd(
33
+ model_org=None,
34
+ model_tuned=None,
35
+ save_to=None,
36
+ dim=4,
37
+ v2=None,
38
+ sdxl=None,
39
+ conv_dim=None,
40
+ v_parameterization=None,
41
+ device=None,
42
+ save_precision=None,
43
+ clamp_quantile=0.99,
44
+ min_diff=0.01,
45
+ no_metadata=False,
46
+ ):
47
+ def str_to_dtype(p):
48
+ if p == "float":
49
+ return torch.float
50
+ if p == "fp16":
51
+ return torch.float16
52
+ if p == "bf16":
53
+ return torch.bfloat16
54
+ return None
55
+
56
+ assert v2 != sdxl or (not v2 and not sdxl), "v2 and sdxl cannot be specified at the same time / v2とsdxlは同時に指定できません"
57
+ if v_parameterization is None:
58
+ v_parameterization = v2
59
+
60
+ save_dtype = str_to_dtype(save_precision)
61
+
62
+ # load models
63
+ if not sdxl:
64
+ print(f"loading original SD model : {model_org}")
65
+ text_encoder_o, _, unet_o = model_util.load_models_from_stable_diffusion_checkpoint(v2, model_org)
66
+ text_encoders_o = [text_encoder_o]
67
+ print(f"loading tuned SD model : {model_tuned}")
68
+ text_encoder_t, _, unet_t = model_util.load_models_from_stable_diffusion_checkpoint(v2, model_tuned)
69
+ text_encoders_t = [text_encoder_t]
70
+ model_version = model_util.get_model_version_str_for_sd1_sd2(v2, v_parameterization)
71
+ else:
72
+ print(f"loading original SDXL model : {model_org}")
73
+ text_encoder_o1, text_encoder_o2, _, unet_o, _, _ = sdxl_model_util.load_models_from_sdxl_checkpoint(
74
+ sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, model_org, "cpu"
75
+ )
76
+ text_encoders_o = [text_encoder_o1, text_encoder_o2]
77
+ print(f"loading original SDXL model : {model_tuned}")
78
+ text_encoder_t1, text_encoder_t2, _, unet_t, _, _ = sdxl_model_util.load_models_from_sdxl_checkpoint(
79
+ sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, model_tuned, "cpu"
80
+ )
81
+ text_encoders_t = [text_encoder_t1, text_encoder_t2]
82
+ model_version = sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0
83
+
84
+ # create LoRA network to extract weights: Use dim (rank) as alpha
85
+ if conv_dim is None:
86
+ kwargs = {}
87
+ else:
88
+ kwargs = {"conv_dim": conv_dim, "conv_alpha": conv_dim}
89
+
90
+ lora_network_o = lora.create_network(1.0, dim, dim, None, text_encoders_o, unet_o, **kwargs)
91
+ lora_network_t = lora.create_network(1.0, dim, dim, None, text_encoders_t, unet_t, **kwargs)
92
+ assert len(lora_network_o.text_encoder_loras) == len(
93
+ lora_network_t.text_encoder_loras
94
+ ), f"model version is different (SD1.x vs SD2.x) / それぞれのモデルのバージョンが違います(SD1.xベースとSD2.xベース) "
95
+
96
+ # get diffs
97
+ diffs = {}
98
+ text_encoder_different = False
99
+ for i, (lora_o, lora_t) in enumerate(zip(lora_network_o.text_encoder_loras, lora_network_t.text_encoder_loras)):
100
+ lora_name = lora_o.lora_name
101
+ module_o = lora_o.org_module
102
+ module_t = lora_t.org_module
103
+ diff = module_t.weight - module_o.weight
104
+
105
+ # Text Encoder might be same
106
+ if not text_encoder_different and torch.max(torch.abs(diff)) > min_diff:
107
+ text_encoder_different = True
108
+ print(f"Text encoder is different. {torch.max(torch.abs(diff))} > {min_diff}")
109
+
110
+ diff = diff.float()
111
+ diffs[lora_name] = diff
112
+
113
+ if not text_encoder_different:
114
+ print("Text encoder is same. Extract U-Net only.")
115
+ lora_network_o.text_encoder_loras = []
116
+ diffs = {}
117
+
118
+ for i, (lora_o, lora_t) in enumerate(zip(lora_network_o.unet_loras, lora_network_t.unet_loras)):
119
+ lora_name = lora_o.lora_name
120
+ module_o = lora_o.org_module
121
+ module_t = lora_t.org_module
122
+ diff = module_t.weight - module_o.weight
123
+ diff = diff.float()
124
+
125
+ if args.device:
126
+ diff = diff.to(args.device)
127
+
128
+ diffs[lora_name] = diff
129
+
130
+ # make LoRA with svd
131
+ print("calculating by svd")
132
+ lora_weights = {}
133
+ with torch.no_grad():
134
+ for lora_name, mat in tqdm(list(diffs.items())):
135
+ # if conv_dim is None, diffs do not include LoRAs for conv2d-3x3
136
+ conv2d = len(mat.size()) == 4
137
+ kernel_size = None if not conv2d else mat.size()[2:4]
138
+ conv2d_3x3 = conv2d and kernel_size != (1, 1)
139
+
140
+ rank = dim if not conv2d_3x3 or conv_dim is None else conv_dim
141
+ out_dim, in_dim = mat.size()[0:2]
142
+
143
+ if device:
144
+ mat = mat.to(device)
145
+
146
+ # print(lora_name, mat.size(), mat.device, rank, in_dim, out_dim)
147
+ rank = min(rank, in_dim, out_dim) # LoRA rank cannot exceed the original dim
148
+
149
+ if conv2d:
150
+ if conv2d_3x3:
151
+ mat = mat.flatten(start_dim=1)
152
+ else:
153
+ mat = mat.squeeze()
154
+
155
+ U, S, Vh = torch.linalg.svd(mat)
156
+
157
+ U = U[:, :rank]
158
+ S = S[:rank]
159
+ U = U @ torch.diag(S)
160
+
161
+ Vh = Vh[:rank, :]
162
+
163
+ dist = torch.cat([U.flatten(), Vh.flatten()])
164
+ hi_val = torch.quantile(dist, clamp_quantile)
165
+ low_val = -hi_val
166
+
167
+ U = U.clamp(low_val, hi_val)
168
+ Vh = Vh.clamp(low_val, hi_val)
169
+
170
+ if conv2d:
171
+ U = U.reshape(out_dim, rank, 1, 1)
172
+ Vh = Vh.reshape(rank, in_dim, kernel_size[0], kernel_size[1])
173
+
174
+ U = U.to("cpu").contiguous()
175
+ Vh = Vh.to("cpu").contiguous()
176
+
177
+ lora_weights[lora_name] = (U, Vh)
178
+
179
+ # make state dict for LoRA
180
+ lora_sd = {}
181
+ for lora_name, (up_weight, down_weight) in lora_weights.items():
182
+ lora_sd[lora_name + ".lora_up.weight"] = up_weight
183
+ lora_sd[lora_name + ".lora_down.weight"] = down_weight
184
+ lora_sd[lora_name + ".alpha"] = torch.tensor(down_weight.size()[0])
185
+
186
+ # load state dict to LoRA and save it
187
+ lora_network_save, lora_sd = lora.create_network_from_weights(1.0, None, None, text_encoders_o, unet_o, weights_sd=lora_sd)
188
+ lora_network_save.apply_to(text_encoders_o, unet_o) # create internal module references for state_dict
189
+
190
+ info = lora_network_save.load_state_dict(lora_sd)
191
+ print(f"Loading extracted LoRA weights: {info}")
192
+
193
+ dir_name = os.path.dirname(save_to)
194
+ if dir_name and not os.path.exists(dir_name):
195
+ os.makedirs(dir_name, exist_ok=True)
196
+
197
+ # minimum metadata
198
+ net_kwargs = {}
199
+ if conv_dim is not None:
200
+ net_kwargs["conv_dim"] = str(conv_dim)
201
+ net_kwargs["conv_alpha"] = str(float(conv_dim))
202
+
203
+ metadata = {
204
+ "ss_v2": str(v2),
205
+ "ss_base_model_version": model_version,
206
+ "ss_network_module": "networks.lora",
207
+ "ss_network_dim": str(dim),
208
+ "ss_network_alpha": str(float(dim)),
209
+ "ss_network_args": json.dumps(net_kwargs),
210
+ }
211
+
212
+ if not no_metadata:
213
+ title = os.path.splitext(os.path.basename(save_to))[0]
214
+ sai_metadata = sai_model_spec.build_metadata(None, v2, v_parameterization, sdxl, True, False, time.time(), title=title)
215
+ metadata.update(sai_metadata)
216
+
217
+ lora_network_save.save_weights(save_to, save_dtype, metadata)
218
+ print(f"LoRA weights are saved to: {save_to}")
219
+
220
+
221
+ def setup_parser() -> argparse.ArgumentParser:
222
+ parser = argparse.ArgumentParser()
223
+ parser.add_argument("--v2", action="store_true", help="load Stable Diffusion v2.x model / Stable Diffusion 2.xのモデルを読み込む")
224
+ parser.add_argument(
225
+ "--v_parameterization",
226
+ action="store_true",
227
+ default=None,
228
+ help="make LoRA metadata for v-parameterization (default is same to v2) / 作成するLoRAのメタデータにv-parameterization用と設定する(省略時はv2と同じ)",
229
+ )
230
+ parser.add_argument(
231
+ "--sdxl", action="store_true", help="load Stable Diffusion SDXL base model / Stable Diffusion SDXL baseのモデルを読み込む"
232
+ )
233
+ parser.add_argument(
234
+ "--save_precision",
235
+ type=str,
236
+ default=None,
237
+ choices=[None, "float", "fp16", "bf16"],
238
+ help="precision in saving, same to merging if omitted / 保存時に精度を変更して保存する、省略時はfloat",
239
+ )
240
+ parser.add_argument(
241
+ "--model_org",
242
+ type=str,
243
+ default=None,
244
+ required=True,
245
+ help="Stable Diffusion original model: ckpt or safetensors file / 元モデル、ckptまたはsafetensors",
246
+ )
247
+ parser.add_argument(
248
+ "--model_tuned",
249
+ type=str,
250
+ default=None,
251
+ required=True,
252
+ help="Stable Diffusion tuned model, LoRA is difference of `original to tuned`: ckpt or safetensors file / 派生モデル(生成されるLoRAは元→派生の差分になります)、ckptまたはsafetensors",
253
+ )
254
+ parser.add_argument(
255
+ "--save_to",
256
+ type=str,
257
+ default=None,
258
+ required=True,
259
+ help="destination file name: ckpt or safetensors file / 保存先のファイル名、ckptまたはsafetensors",
260
+ )
261
+ parser.add_argument("--dim", type=int, default=4, help="dimension (rank) of LoRA (default 4) / LoRAの次元数(rank)(デフォルト4)")
262
+ parser.add_argument(
263
+ "--conv_dim",
264
+ type=int,
265
+ default=None,
266
+ help="dimension (rank) of LoRA for Conv2d-3x3 (default None, disabled) / LoRAのConv2d-3x3の次元数(rank)(デフォルトNone、適用なし)",
267
+ )
268
+ parser.add_argument("--device", type=str, default=None, help="device to use, cuda for GPU / 計算を行うデバイス、cuda でGPUを使う")
269
+ parser.add_argument(
270
+ "--clamp_quantile",
271
+ type=float,
272
+ default=0.99,
273
+ help="Quantile clamping value, float, (0-1). Default = 0.99 / 値をクランプするための分位点、float、(0-1)。デフォルトは0.99",
274
+ )
275
+ parser.add_argument(
276
+ "--min_diff",
277
+ type=float,
278
+ default=0.01,
279
+ help="Minimum difference between finetuned model and base to consider them different enough to extract, float, (0-1). Default = 0.01 /"
280
+ + "LoRAを抽出するために元モデルと派生モデルの差分の最小値、float、(0-1)。デフォルトは0.01",
281
+ )
282
+ parser.add_argument(
283
+ "--no_metadata",
284
+ action="store_true",
285
+ help="do not save sai modelspec metadata (minimum ss_metadata for LoRA is saved) / "
286
+ + "sai modelspecのメタデータを保存しない(LoRAの最低限のss_metadataは保存される)",
287
+ )
288
+
289
+ return parser
290
+
291
+
292
+ if __name__ == "__main__":
293
+ parser = setup_parser()
294
+
295
+ args = parser.parse_args()
296
+ svd(**vars(args))
external/llite/networks/lora.py ADDED
@@ -0,0 +1,1225 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # LoRA network module
2
+ # reference:
3
+ # https://github.com/microsoft/LoRA/blob/main/loralib/layers.py
4
+ # https://github.com/cloneofsimo/lora/blob/master/lora_diffusion/lora.py
5
+
6
+ import math
7
+ import os
8
+ from typing import Dict, List, Optional, Tuple, Type, Union
9
+ from diffusers import AutoencoderKL
10
+ from transformers import CLIPTextModel
11
+ import numpy as np
12
+ import torch
13
+ import re
14
+
15
+
16
+ RE_UPDOWN = re.compile(r"(up|down)_blocks_(\d+)_(resnets|upsamplers|downsamplers|attentions)_(\d+)_")
17
+
18
+
19
+ class LoRAModule(torch.nn.Module):
20
+ """
21
+ replaces forward method of the original Linear, instead of replacing the original Linear module.
22
+ """
23
+
24
+ def __init__(
25
+ self,
26
+ lora_name,
27
+ org_module: torch.nn.Module,
28
+ multiplier=1.0,
29
+ lora_dim=4,
30
+ alpha=1,
31
+ dropout=None,
32
+ rank_dropout=None,
33
+ module_dropout=None,
34
+ ):
35
+ """if alpha == 0 or None, alpha is rank (no scaling)."""
36
+ super().__init__()
37
+ self.lora_name = lora_name
38
+
39
+ if org_module.__class__.__name__ == "Conv2d":
40
+ in_dim = org_module.in_channels
41
+ out_dim = org_module.out_channels
42
+ else:
43
+ in_dim = org_module.in_features
44
+ out_dim = org_module.out_features
45
+
46
+ # if limit_rank:
47
+ # self.lora_dim = min(lora_dim, in_dim, out_dim)
48
+ # if self.lora_dim != lora_dim:
49
+ # print(f"{lora_name} dim (rank) is changed to: {self.lora_dim}")
50
+ # else:
51
+ self.lora_dim = lora_dim
52
+
53
+ if org_module.__class__.__name__ == "Conv2d":
54
+ kernel_size = org_module.kernel_size
55
+ stride = org_module.stride
56
+ padding = org_module.padding
57
+ self.lora_down = torch.nn.Conv2d(in_dim, self.lora_dim, kernel_size, stride, padding, bias=False)
58
+ self.lora_up = torch.nn.Conv2d(self.lora_dim, out_dim, (1, 1), (1, 1), bias=False)
59
+ else:
60
+ self.lora_down = torch.nn.Linear(in_dim, self.lora_dim, bias=False)
61
+ self.lora_up = torch.nn.Linear(self.lora_dim, out_dim, bias=False)
62
+
63
+ if type(alpha) == torch.Tensor:
64
+ alpha = alpha.detach().float().numpy() # without casting, bf16 causes error
65
+ alpha = self.lora_dim if alpha is None or alpha == 0 else alpha
66
+ self.scale = alpha / self.lora_dim
67
+ self.register_buffer("alpha", torch.tensor(alpha)) # 定数として扱える
68
+
69
+ # same as microsoft's
70
+ torch.nn.init.kaiming_uniform_(self.lora_down.weight, a=math.sqrt(5))
71
+ torch.nn.init.zeros_(self.lora_up.weight)
72
+
73
+ self.multiplier = multiplier
74
+ self.org_module = org_module # remove in applying
75
+ self.dropout = dropout
76
+ self.rank_dropout = rank_dropout
77
+ self.module_dropout = module_dropout
78
+
79
+ def apply_to(self):
80
+ self.org_forward = self.org_module.forward
81
+ self.org_module.forward = self.forward
82
+ del self.org_module
83
+
84
+ def forward(self, x):
85
+ org_forwarded = self.org_forward(x)
86
+
87
+ # module dropout
88
+ if self.module_dropout is not None and self.training:
89
+ if torch.rand(1) < self.module_dropout:
90
+ return org_forwarded
91
+
92
+ lx = self.lora_down(x)
93
+
94
+ # normal dropout
95
+ if self.dropout is not None and self.training:
96
+ lx = torch.nn.functional.dropout(lx, p=self.dropout)
97
+
98
+ # rank dropout
99
+ if self.rank_dropout is not None and self.training:
100
+ mask = torch.rand((lx.size(0), self.lora_dim), device=lx.device) > self.rank_dropout
101
+ if len(lx.size()) == 3:
102
+ mask = mask.unsqueeze(1) # for Text Encoder
103
+ elif len(lx.size()) == 4:
104
+ mask = mask.unsqueeze(-1).unsqueeze(-1) # for Conv2d
105
+ lx = lx * mask
106
+
107
+ # scaling for rank dropout: treat as if the rank is changed
108
+ # maskから計算することも考えられるが、augmentation的な効果を期待してrank_dropoutを用いる
109
+ scale = self.scale * (1.0 / (1.0 - self.rank_dropout)) # redundant for readability
110
+ else:
111
+ scale = self.scale
112
+
113
+ lx = self.lora_up(lx)
114
+
115
+ return org_forwarded + lx * self.multiplier * scale
116
+
117
+
118
+ class LoRAInfModule(LoRAModule):
119
+ def __init__(
120
+ self,
121
+ lora_name,
122
+ org_module: torch.nn.Module,
123
+ multiplier=1.0,
124
+ lora_dim=4,
125
+ alpha=1,
126
+ **kwargs,
127
+ ):
128
+ # no dropout for inference
129
+ super().__init__(lora_name, org_module, multiplier, lora_dim, alpha)
130
+
131
+ self.org_module_ref = [org_module] # 後から参照できるように
132
+ self.enabled = True
133
+
134
+ # check regional or not by lora_name
135
+ self.text_encoder = False
136
+ if lora_name.startswith("lora_te_"):
137
+ self.regional = False
138
+ self.use_sub_prompt = True
139
+ self.text_encoder = True
140
+ elif "attn2_to_k" in lora_name or "attn2_to_v" in lora_name:
141
+ self.regional = False
142
+ self.use_sub_prompt = True
143
+ elif "time_emb" in lora_name:
144
+ self.regional = False
145
+ self.use_sub_prompt = False
146
+ else:
147
+ self.regional = True
148
+ self.use_sub_prompt = False
149
+
150
+ self.network: LoRANetwork = None
151
+
152
+ def set_network(self, network):
153
+ self.network = network
154
+
155
+ # freezeしてマージする
156
+ def merge_to(self, sd, dtype, device):
157
+ # get up/down weight
158
+ up_weight = sd["lora_up.weight"].to(torch.float).to(device)
159
+ down_weight = sd["lora_down.weight"].to(torch.float).to(device)
160
+
161
+ # extract weight from org_module
162
+ org_sd = self.org_module.state_dict()
163
+ weight = org_sd["weight"].to(torch.float)
164
+
165
+ # merge weight
166
+ if len(weight.size()) == 2:
167
+ # linear
168
+ weight = weight + self.multiplier * (up_weight @ down_weight) * self.scale
169
+ elif down_weight.size()[2:4] == (1, 1):
170
+ # conv2d 1x1
171
+ weight = (
172
+ weight
173
+ + self.multiplier
174
+ * (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3)
175
+ * self.scale
176
+ )
177
+ else:
178
+ # conv2d 3x3
179
+ conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3)
180
+ # print(conved.size(), weight.size(), module.stride, module.padding)
181
+ weight = weight + self.multiplier * conved * self.scale
182
+
183
+ # set weight to org_module
184
+ org_sd["weight"] = weight.to(dtype)
185
+ self.org_module.load_state_dict(org_sd)
186
+
187
+ # 復元できるマージのため、このモジュールのweightを返す
188
+ def get_weight(self, multiplier=None):
189
+ if multiplier is None:
190
+ multiplier = self.multiplier
191
+
192
+ # get up/down weight from module
193
+ up_weight = self.lora_up.weight.to(torch.float)
194
+ down_weight = self.lora_down.weight.to(torch.float)
195
+
196
+ # pre-calculated weight
197
+ if len(down_weight.size()) == 2:
198
+ # linear
199
+ weight = self.multiplier * (up_weight @ down_weight) * self.scale
200
+ elif down_weight.size()[2:4] == (1, 1):
201
+ # conv2d 1x1
202
+ weight = (
203
+ self.multiplier
204
+ * (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3)
205
+ * self.scale
206
+ )
207
+ else:
208
+ # conv2d 3x3
209
+ conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3)
210
+ weight = self.multiplier * conved * self.scale
211
+
212
+ return weight
213
+
214
+ def set_region(self, region):
215
+ self.region = region
216
+ self.region_mask = None
217
+
218
+ def default_forward(self, x):
219
+ # print("default_forward", self.lora_name, x.size())
220
+ return self.org_forward(x) + self.lora_up(self.lora_down(x)) * self.multiplier * self.scale
221
+
222
+ def forward(self, x):
223
+ if not self.enabled:
224
+ return self.org_forward(x)
225
+
226
+ if self.network is None or self.network.sub_prompt_index is None:
227
+ return self.default_forward(x)
228
+ if not self.regional and not self.use_sub_prompt:
229
+ return self.default_forward(x)
230
+
231
+ if self.regional:
232
+ return self.regional_forward(x)
233
+ else:
234
+ return self.sub_prompt_forward(x)
235
+
236
+ def get_mask_for_x(self, x):
237
+ # calculate size from shape of x
238
+ if len(x.size()) == 4:
239
+ h, w = x.size()[2:4]
240
+ area = h * w
241
+ else:
242
+ area = x.size()[1]
243
+
244
+ mask = self.network.mask_dic.get(area, None)
245
+ if mask is None:
246
+ # raise ValueError(f"mask is None for resolution {area}")
247
+ # emb_layers in SDXL doesn't have mask
248
+ # print(f"mask is None for resolution {area}, {x.size()}")
249
+ mask_size = (1, x.size()[1]) if len(x.size()) == 2 else (1, *x.size()[1:-1], 1)
250
+ return torch.ones(mask_size, dtype=x.dtype, device=x.device) / self.network.num_sub_prompts
251
+ if len(x.size()) != 4:
252
+ mask = torch.reshape(mask, (1, -1, 1))
253
+ return mask
254
+
255
+ def regional_forward(self, x):
256
+ if "attn2_to_out" in self.lora_name:
257
+ return self.to_out_forward(x)
258
+
259
+ if self.network.mask_dic is None: # sub_prompt_index >= 3
260
+ return self.default_forward(x)
261
+
262
+ # apply mask for LoRA result
263
+ lx = self.lora_up(self.lora_down(x)) * self.multiplier * self.scale
264
+ mask = self.get_mask_for_x(lx)
265
+ # print("regional", self.lora_name, self.network.sub_prompt_index, lx.size(), mask.size())
266
+ lx = lx * mask
267
+
268
+ x = self.org_forward(x)
269
+ x = x + lx
270
+
271
+ if "attn2_to_q" in self.lora_name and self.network.is_last_network:
272
+ x = self.postp_to_q(x)
273
+
274
+ return x
275
+
276
+ def postp_to_q(self, x):
277
+ # repeat x to num_sub_prompts
278
+ has_real_uncond = x.size()[0] // self.network.batch_size == 3
279
+ qc = self.network.batch_size # uncond
280
+ qc += self.network.batch_size * self.network.num_sub_prompts # cond
281
+ if has_real_uncond:
282
+ qc += self.network.batch_size # real_uncond
283
+
284
+ query = torch.zeros((qc, x.size()[1], x.size()[2]), device=x.device, dtype=x.dtype)
285
+ query[: self.network.batch_size] = x[: self.network.batch_size]
286
+
287
+ for i in range(self.network.batch_size):
288
+ qi = self.network.batch_size + i * self.network.num_sub_prompts
289
+ query[qi : qi + self.network.num_sub_prompts] = x[self.network.batch_size + i]
290
+
291
+ if has_real_uncond:
292
+ query[-self.network.batch_size :] = x[-self.network.batch_size :]
293
+
294
+ # print("postp_to_q", self.lora_name, x.size(), query.size(), self.network.num_sub_prompts)
295
+ return query
296
+
297
+ def sub_prompt_forward(self, x):
298
+ if x.size()[0] == self.network.batch_size: # if uncond in text_encoder, do not apply LoRA
299
+ return self.org_forward(x)
300
+
301
+ emb_idx = self.network.sub_prompt_index
302
+ if not self.text_encoder:
303
+ emb_idx += self.network.batch_size
304
+
305
+ # apply sub prompt of X
306
+ lx = x[emb_idx :: self.network.num_sub_prompts]
307
+ lx = self.lora_up(self.lora_down(lx)) * self.multiplier * self.scale
308
+
309
+ # print("sub_prompt_forward", self.lora_name, x.size(), lx.size(), emb_idx)
310
+
311
+ x = self.org_forward(x)
312
+ x[emb_idx :: self.network.num_sub_prompts] += lx
313
+
314
+ return x
315
+
316
+ def to_out_forward(self, x):
317
+ # print("to_out_forward", self.lora_name, x.size(), self.network.is_last_network)
318
+
319
+ if self.network.is_last_network:
320
+ masks = [None] * self.network.num_sub_prompts
321
+ self.network.shared[self.lora_name] = (None, masks)
322
+ else:
323
+ lx, masks = self.network.shared[self.lora_name]
324
+
325
+ # call own LoRA
326
+ x1 = x[self.network.batch_size + self.network.sub_prompt_index :: self.network.num_sub_prompts]
327
+ lx1 = self.lora_up(self.lora_down(x1)) * self.multiplier * self.scale
328
+
329
+ if self.network.is_last_network:
330
+ lx = torch.zeros(
331
+ (self.network.num_sub_prompts * self.network.batch_size, *lx1.size()[1:]), device=lx1.device, dtype=lx1.dtype
332
+ )
333
+ self.network.shared[self.lora_name] = (lx, masks)
334
+
335
+ # print("to_out_forward", lx.size(), lx1.size(), self.network.sub_prompt_index, self.network.num_sub_prompts)
336
+ lx[self.network.sub_prompt_index :: self.network.num_sub_prompts] += lx1
337
+ masks[self.network.sub_prompt_index] = self.get_mask_for_x(lx1)
338
+
339
+ # if not last network, return x and masks
340
+ x = self.org_forward(x)
341
+ if not self.network.is_last_network:
342
+ return x
343
+
344
+ lx, masks = self.network.shared.pop(self.lora_name)
345
+
346
+ # if last network, combine separated x with mask weighted sum
347
+ has_real_uncond = x.size()[0] // self.network.batch_size == self.network.num_sub_prompts + 2
348
+
349
+ out = torch.zeros((self.network.batch_size * (3 if has_real_uncond else 2), *x.size()[1:]), device=x.device, dtype=x.dtype)
350
+ out[: self.network.batch_size] = x[: self.network.batch_size] # uncond
351
+ if has_real_uncond:
352
+ out[-self.network.batch_size :] = x[-self.network.batch_size :] # real_uncond
353
+
354
+ # print("to_out_forward", self.lora_name, self.network.sub_prompt_index, self.network.num_sub_prompts)
355
+ # if num_sub_prompts > num of LoRAs, fill with zero
356
+ for i in range(len(masks)):
357
+ if masks[i] is None:
358
+ masks[i] = torch.zeros_like(masks[0])
359
+
360
+ mask = torch.cat(masks)
361
+ mask_sum = torch.sum(mask, dim=0) + 1e-4
362
+ for i in range(self.network.batch_size):
363
+ # 1枚の画像ごとに処理する
364
+ lx1 = lx[i * self.network.num_sub_prompts : (i + 1) * self.network.num_sub_prompts]
365
+ lx1 = lx1 * mask
366
+ lx1 = torch.sum(lx1, dim=0)
367
+
368
+ xi = self.network.batch_size + i * self.network.num_sub_prompts
369
+ x1 = x[xi : xi + self.network.num_sub_prompts]
370
+ x1 = x1 * mask
371
+ x1 = torch.sum(x1, dim=0)
372
+ x1 = x1 / mask_sum
373
+
374
+ x1 = x1 + lx1
375
+ out[self.network.batch_size + i] = x1
376
+
377
+ # print("to_out_forward", x.size(), out.size(), has_real_uncond)
378
+ return out
379
+
380
+
381
+ def parse_block_lr_kwargs(nw_kwargs):
382
+ down_lr_weight = nw_kwargs.get("down_lr_weight", None)
383
+ mid_lr_weight = nw_kwargs.get("mid_lr_weight", None)
384
+ up_lr_weight = nw_kwargs.get("up_lr_weight", None)
385
+
386
+ # 以上のいずれにも設定がない場合は無効としてNoneを返す
387
+ if down_lr_weight is None and mid_lr_weight is None and up_lr_weight is None:
388
+ return None, None, None
389
+
390
+ # extract learning rate weight for each block
391
+ if down_lr_weight is not None:
392
+ # if some parameters are not set, use zero
393
+ if "," in down_lr_weight:
394
+ down_lr_weight = [(float(s) if s else 0.0) for s in down_lr_weight.split(",")]
395
+
396
+ if mid_lr_weight is not None:
397
+ mid_lr_weight = float(mid_lr_weight)
398
+
399
+ if up_lr_weight is not None:
400
+ if "," in up_lr_weight:
401
+ up_lr_weight = [(float(s) if s else 0.0) for s in up_lr_weight.split(",")]
402
+
403
+ down_lr_weight, mid_lr_weight, up_lr_weight = get_block_lr_weight(
404
+ down_lr_weight, mid_lr_weight, up_lr_weight, float(nw_kwargs.get("block_lr_zero_threshold", 0.0))
405
+ )
406
+
407
+ return down_lr_weight, mid_lr_weight, up_lr_weight
408
+
409
+
410
+ def create_network(
411
+ multiplier: float,
412
+ network_dim: Optional[int],
413
+ network_alpha: Optional[float],
414
+ vae: AutoencoderKL,
415
+ text_encoder: Union[CLIPTextModel, List[CLIPTextModel]],
416
+ unet,
417
+ neuron_dropout: Optional[float] = None,
418
+ **kwargs,
419
+ ):
420
+ if network_dim is None:
421
+ network_dim = 4 # default
422
+ if network_alpha is None:
423
+ network_alpha = 1.0
424
+
425
+ # extract dim/alpha for conv2d, and block dim
426
+ conv_dim = kwargs.get("conv_dim", None)
427
+ conv_alpha = kwargs.get("conv_alpha", None)
428
+ if conv_dim is not None:
429
+ conv_dim = int(conv_dim)
430
+ if conv_alpha is None:
431
+ conv_alpha = 1.0
432
+ else:
433
+ conv_alpha = float(conv_alpha)
434
+
435
+ # block dim/alpha/lr
436
+ block_dims = kwargs.get("block_dims", None)
437
+ down_lr_weight, mid_lr_weight, up_lr_weight = parse_block_lr_kwargs(kwargs)
438
+
439
+ # 以上のいずれかに指定があればblockごとのdim(rank)を有効にする
440
+ if block_dims is not None or down_lr_weight is not None or mid_lr_weight is not None or up_lr_weight is not None:
441
+ block_alphas = kwargs.get("block_alphas", None)
442
+ conv_block_dims = kwargs.get("conv_block_dims", None)
443
+ conv_block_alphas = kwargs.get("conv_block_alphas", None)
444
+
445
+ block_dims, block_alphas, conv_block_dims, conv_block_alphas = get_block_dims_and_alphas(
446
+ block_dims, block_alphas, network_dim, network_alpha, conv_block_dims, conv_block_alphas, conv_dim, conv_alpha
447
+ )
448
+
449
+ # remove block dim/alpha without learning rate
450
+ block_dims, block_alphas, conv_block_dims, conv_block_alphas = remove_block_dims_and_alphas(
451
+ block_dims, block_alphas, conv_block_dims, conv_block_alphas, down_lr_weight, mid_lr_weight, up_lr_weight
452
+ )
453
+
454
+ else:
455
+ block_alphas = None
456
+ conv_block_dims = None
457
+ conv_block_alphas = None
458
+
459
+ # rank/module dropout
460
+ rank_dropout = kwargs.get("rank_dropout", None)
461
+ if rank_dropout is not None:
462
+ rank_dropout = float(rank_dropout)
463
+ module_dropout = kwargs.get("module_dropout", None)
464
+ if module_dropout is not None:
465
+ module_dropout = float(module_dropout)
466
+
467
+ # すごく引数が多いな ( ^ω^)・・・
468
+ network = LoRANetwork(
469
+ text_encoder,
470
+ unet,
471
+ multiplier=multiplier,
472
+ lora_dim=network_dim,
473
+ alpha=network_alpha,
474
+ dropout=neuron_dropout,
475
+ rank_dropout=rank_dropout,
476
+ module_dropout=module_dropout,
477
+ conv_lora_dim=conv_dim,
478
+ conv_alpha=conv_alpha,
479
+ block_dims=block_dims,
480
+ block_alphas=block_alphas,
481
+ conv_block_dims=conv_block_dims,
482
+ conv_block_alphas=conv_block_alphas,
483
+ varbose=True,
484
+ )
485
+
486
+ if up_lr_weight is not None or mid_lr_weight is not None or down_lr_weight is not None:
487
+ network.set_block_lr_weight(up_lr_weight, mid_lr_weight, down_lr_weight)
488
+
489
+ return network
490
+
491
+
492
+ # このメソッドは外部から呼び出される可能性を考慮しておく
493
+ # network_dim, network_alpha にはデフォルト値が入っている。
494
+ # block_dims, block_alphas は両方ともNoneまたは両方とも値が入っている
495
+ # conv_dim, conv_alpha は両方ともNoneまたは両方とも値が入っている
496
+ def get_block_dims_and_alphas(
497
+ block_dims, block_alphas, network_dim, network_alpha, conv_block_dims, conv_block_alphas, conv_dim, conv_alpha
498
+ ):
499
+ num_total_blocks = LoRANetwork.NUM_OF_BLOCKS * 2 + 1
500
+
501
+ def parse_ints(s):
502
+ return [int(i) for i in s.split(",")]
503
+
504
+ def parse_floats(s):
505
+ return [float(i) for i in s.split(",")]
506
+
507
+ # block_dimsとblock_alphasをパースする。必ず値が入る
508
+ if block_dims is not None:
509
+ block_dims = parse_ints(block_dims)
510
+ assert (
511
+ len(block_dims) == num_total_blocks
512
+ ), f"block_dims must have {num_total_blocks} elements / block_dimsは{num_total_blocks}個指定してください"
513
+ else:
514
+ print(f"block_dims is not specified. all dims are set to {network_dim} / block_dimsが指定されていません。すべてのdimは{network_dim}になります")
515
+ block_dims = [network_dim] * num_total_blocks
516
+
517
+ if block_alphas is not None:
518
+ block_alphas = parse_floats(block_alphas)
519
+ assert (
520
+ len(block_alphas) == num_total_blocks
521
+ ), f"block_alphas must have {num_total_blocks} elements / block_alphasは{num_total_blocks}個指定してくださ��"
522
+ else:
523
+ print(
524
+ f"block_alphas is not specified. all alphas are set to {network_alpha} / block_alphasが指定されていません。すべてのalphaは{network_alpha}になります"
525
+ )
526
+ block_alphas = [network_alpha] * num_total_blocks
527
+
528
+ # conv_block_dimsとconv_block_alphasを、指定がある場合のみパースする。指定がなければconv_dimとconv_alphaを使う
529
+ if conv_block_dims is not None:
530
+ conv_block_dims = parse_ints(conv_block_dims)
531
+ assert (
532
+ len(conv_block_dims) == num_total_blocks
533
+ ), f"conv_block_dims must have {num_total_blocks} elements / conv_block_dimsは{num_total_blocks}個指定してください"
534
+
535
+ if conv_block_alphas is not None:
536
+ conv_block_alphas = parse_floats(conv_block_alphas)
537
+ assert (
538
+ len(conv_block_alphas) == num_total_blocks
539
+ ), f"conv_block_alphas must have {num_total_blocks} elements / conv_block_alphasは{num_total_blocks}個指定してください"
540
+ else:
541
+ if conv_alpha is None:
542
+ conv_alpha = 1.0
543
+ print(
544
+ f"conv_block_alphas is not specified. all alphas are set to {conv_alpha} / conv_block_alphasが指定されていません。すべてのalphaは{conv_alpha}になります"
545
+ )
546
+ conv_block_alphas = [conv_alpha] * num_total_blocks
547
+ else:
548
+ if conv_dim is not None:
549
+ print(
550
+ f"conv_dim/alpha for all blocks are set to {conv_dim} and {conv_alpha} / すべてのブロックのconv_dimとalphaは{conv_dim}および{conv_alpha}になります"
551
+ )
552
+ conv_block_dims = [conv_dim] * num_total_blocks
553
+ conv_block_alphas = [conv_alpha] * num_total_blocks
554
+ else:
555
+ conv_block_dims = None
556
+ conv_block_alphas = None
557
+
558
+ return block_dims, block_alphas, conv_block_dims, conv_block_alphas
559
+
560
+
561
+ # 層別学習率用に層ごとの学習率に対する倍率を定義する、外部から呼び出される可能性を考慮しておく
562
+ def get_block_lr_weight(
563
+ down_lr_weight, mid_lr_weight, up_lr_weight, zero_threshold
564
+ ) -> Tuple[List[float], List[float], List[float]]:
565
+ # パラメータ未指定時は何もせず、今までと同じ動作とする
566
+ if up_lr_weight is None and mid_lr_weight is None and down_lr_weight is None:
567
+ return None, None, None
568
+
569
+ max_len = LoRANetwork.NUM_OF_BLOCKS # フルモデル相当でのup,downの層の数
570
+
571
+ def get_list(name_with_suffix) -> List[float]:
572
+ import math
573
+
574
+ tokens = name_with_suffix.split("+")
575
+ name = tokens[0]
576
+ base_lr = float(tokens[1]) if len(tokens) > 1 else 0.0
577
+
578
+ if name == "cosine":
579
+ return [math.sin(math.pi * (i / (max_len - 1)) / 2) + base_lr for i in reversed(range(max_len))]
580
+ elif name == "sine":
581
+ return [math.sin(math.pi * (i / (max_len - 1)) / 2) + base_lr for i in range(max_len)]
582
+ elif name == "linear":
583
+ return [i / (max_len - 1) + base_lr for i in range(max_len)]
584
+ elif name == "reverse_linear":
585
+ return [i / (max_len - 1) + base_lr for i in reversed(range(max_len))]
586
+ elif name == "zeros":
587
+ return [0.0 + base_lr] * max_len
588
+ else:
589
+ print(
590
+ "Unknown lr_weight argument %s is used. Valid arguments: / 不明なlr_weightの引数 %s が使われました。有効な引数:\n\tcosine, sine, linear, reverse_linear, zeros"
591
+ % (name)
592
+ )
593
+ return None
594
+
595
+ if type(down_lr_weight) == str:
596
+ down_lr_weight = get_list(down_lr_weight)
597
+ if type(up_lr_weight) == str:
598
+ up_lr_weight = get_list(up_lr_weight)
599
+
600
+ if (up_lr_weight != None and len(up_lr_weight) > max_len) or (down_lr_weight != None and len(down_lr_weight) > max_len):
601
+ print("down_weight or up_weight is too long. Parameters after %d-th are ignored." % max_len)
602
+ print("down_weightもしくはup_weightが長すぎます。%d個目以降のパラメータは無視されます。" % max_len)
603
+ up_lr_weight = up_lr_weight[:max_len]
604
+ down_lr_weight = down_lr_weight[:max_len]
605
+
606
+ if (up_lr_weight != None and len(up_lr_weight) < max_len) or (down_lr_weight != None and len(down_lr_weight) < max_len):
607
+ print("down_weight or up_weight is too short. Parameters after %d-th are filled with 1." % max_len)
608
+ print("down_weightもしくはup_weightが短すぎます。%d個目までの不足したパラメータは1で補われます。" % max_len)
609
+
610
+ if down_lr_weight != None and len(down_lr_weight) < max_len:
611
+ down_lr_weight = down_lr_weight + [1.0] * (max_len - len(down_lr_weight))
612
+ if up_lr_weight != None and len(up_lr_weight) < max_len:
613
+ up_lr_weight = up_lr_weight + [1.0] * (max_len - len(up_lr_weight))
614
+
615
+ if (up_lr_weight != None) or (mid_lr_weight != None) or (down_lr_weight != None):
616
+ print("apply block learning rate / 階層別学習率を適用します。")
617
+ if down_lr_weight != None:
618
+ down_lr_weight = [w if w > zero_threshold else 0 for w in down_lr_weight]
619
+ print("down_lr_weight (shallower -> deeper, 浅い層->深い層):", down_lr_weight)
620
+ else:
621
+ print("down_lr_weight: all 1.0, すべて1.0")
622
+
623
+ if mid_lr_weight != None:
624
+ mid_lr_weight = mid_lr_weight if mid_lr_weight > zero_threshold else 0
625
+ print("mid_lr_weight:", mid_lr_weight)
626
+ else:
627
+ print("mid_lr_weight: 1.0")
628
+
629
+ if up_lr_weight != None:
630
+ up_lr_weight = [w if w > zero_threshold else 0 for w in up_lr_weight]
631
+ print("up_lr_weight (deeper -> shallower, 深い層->浅い層):", up_lr_weight)
632
+ else:
633
+ print("up_lr_weight: all 1.0, すべて1.0")
634
+
635
+ return down_lr_weight, mid_lr_weight, up_lr_weight
636
+
637
+
638
+ # lr_weightが0のblockをblock_dimsから除外する、外部から呼び出す可能性を考慮しておく
639
+ def remove_block_dims_and_alphas(
640
+ block_dims, block_alphas, conv_block_dims, conv_block_alphas, down_lr_weight, mid_lr_weight, up_lr_weight
641
+ ):
642
+ # set 0 to block dim without learning rate to remove the block
643
+ if down_lr_weight != None:
644
+ for i, lr in enumerate(down_lr_weight):
645
+ if lr == 0:
646
+ block_dims[i] = 0
647
+ if conv_block_dims is not None:
648
+ conv_block_dims[i] = 0
649
+ if mid_lr_weight != None:
650
+ if mid_lr_weight == 0:
651
+ block_dims[LoRANetwork.NUM_OF_BLOCKS] = 0
652
+ if conv_block_dims is not None:
653
+ conv_block_dims[LoRANetwork.NUM_OF_BLOCKS] = 0
654
+ if up_lr_weight != None:
655
+ for i, lr in enumerate(up_lr_weight):
656
+ if lr == 0:
657
+ block_dims[LoRANetwork.NUM_OF_BLOCKS + 1 + i] = 0
658
+ if conv_block_dims is not None:
659
+ conv_block_dims[LoRANetwork.NUM_OF_BLOCKS + 1 + i] = 0
660
+
661
+ return block_dims, block_alphas, conv_block_dims, conv_block_alphas
662
+
663
+
664
+ # 外部から呼び出す可能性を考慮しておく
665
+ def get_block_index(lora_name: str) -> int:
666
+ block_idx = -1 # invalid lora name
667
+
668
+ m = RE_UPDOWN.search(lora_name)
669
+ if m:
670
+ g = m.groups()
671
+ i = int(g[1])
672
+ j = int(g[3])
673
+ if g[2] == "resnets":
674
+ idx = 3 * i + j
675
+ elif g[2] == "attentions":
676
+ idx = 3 * i + j
677
+ elif g[2] == "upsamplers" or g[2] == "downsamplers":
678
+ idx = 3 * i + 2
679
+
680
+ if g[0] == "down":
681
+ block_idx = 1 + idx # 0に該当するLoRAは存在しない
682
+ elif g[0] == "up":
683
+ block_idx = LoRANetwork.NUM_OF_BLOCKS + 1 + idx
684
+
685
+ elif "mid_block_" in lora_name:
686
+ block_idx = LoRANetwork.NUM_OF_BLOCKS # idx=12
687
+
688
+ return block_idx
689
+
690
+
691
+ # Create network from weights for inference, weights are not loaded here (because can be merged)
692
+ def create_network_from_weights(multiplier, file, vae, text_encoder, unet, weights_sd=None, for_inference=False, **kwargs):
693
+ if weights_sd is None:
694
+ if os.path.splitext(file)[1] == ".safetensors":
695
+ from safetensors.torch import load_file, safe_open
696
+
697
+ weights_sd = load_file(file)
698
+ else:
699
+ weights_sd = torch.load(file, map_location="cpu")
700
+
701
+ # get dim/alpha mapping
702
+ modules_dim = {}
703
+ modules_alpha = {}
704
+ for key, value in weights_sd.items():
705
+ if "." not in key:
706
+ continue
707
+
708
+ lora_name = key.split(".")[0]
709
+ if "alpha" in key:
710
+ modules_alpha[lora_name] = value
711
+ elif "lora_down" in key:
712
+ dim = value.size()[0]
713
+ modules_dim[lora_name] = dim
714
+ # print(lora_name, value.size(), dim)
715
+
716
+ # support old LoRA without alpha
717
+ for key in modules_dim.keys():
718
+ if key not in modules_alpha:
719
+ modules_alpha[key] = modules_dim[key]
720
+
721
+ module_class = LoRAInfModule if for_inference else LoRAModule
722
+
723
+ network = LoRANetwork(
724
+ text_encoder, unet, multiplier=multiplier, modules_dim=modules_dim, modules_alpha=modules_alpha, module_class=module_class
725
+ )
726
+
727
+ # block lr
728
+ down_lr_weight, mid_lr_weight, up_lr_weight = parse_block_lr_kwargs(kwargs)
729
+ if up_lr_weight is not None or mid_lr_weight is not None or down_lr_weight is not None:
730
+ network.set_block_lr_weight(up_lr_weight, mid_lr_weight, down_lr_weight)
731
+
732
+ return network, weights_sd
733
+
734
+
735
+ class LoRANetwork(torch.nn.Module):
736
+ NUM_OF_BLOCKS = 12 # フルモデル相当でのup,downの層の数
737
+
738
+ UNET_TARGET_REPLACE_MODULE = ["Transformer2DModel"]
739
+ UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 = ["ResnetBlock2D", "Downsample2D", "Upsample2D"]
740
+ TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPMLP"]
741
+ LORA_PREFIX_UNET = "lora_unet"
742
+ LORA_PREFIX_TEXT_ENCODER = "lora_te"
743
+
744
+ # SDXL: must starts with LORA_PREFIX_TEXT_ENCODER
745
+ LORA_PREFIX_TEXT_ENCODER1 = "lora_te1"
746
+ LORA_PREFIX_TEXT_ENCODER2 = "lora_te2"
747
+
748
+ def __init__(
749
+ self,
750
+ text_encoder: Union[List[CLIPTextModel], CLIPTextModel],
751
+ unet,
752
+ multiplier: float = 1.0,
753
+ lora_dim: int = 4,
754
+ alpha: float = 1,
755
+ dropout: Optional[float] = None,
756
+ rank_dropout: Optional[float] = None,
757
+ module_dropout: Optional[float] = None,
758
+ conv_lora_dim: Optional[int] = None,
759
+ conv_alpha: Optional[float] = None,
760
+ block_dims: Optional[List[int]] = None,
761
+ block_alphas: Optional[List[float]] = None,
762
+ conv_block_dims: Optional[List[int]] = None,
763
+ conv_block_alphas: Optional[List[float]] = None,
764
+ modules_dim: Optional[Dict[str, int]] = None,
765
+ modules_alpha: Optional[Dict[str, int]] = None,
766
+ module_class: Type[object] = LoRAModule,
767
+ varbose: Optional[bool] = False,
768
+ ) -> None:
769
+ """
770
+ LoRA network: すごく引数が多いが、パターンは以下の通り
771
+ 1. lora_dimとalphaを指定
772
+ 2. lora_dim、alpha、conv_lora_dim、conv_alphaを指定
773
+ 3. block_dimsとblock_alphasを指定 : Conv2d3x3には適用しない
774
+ 4. block_dims、block_alphas、conv_block_dims、conv_block_alphasを指定 : Conv2d3x3にも適用する
775
+ 5. modules_dimとmodules_alphaを指定 (推論用)
776
+ """
777
+ super().__init__()
778
+ self.multiplier = multiplier
779
+
780
+ self.lora_dim = lora_dim
781
+ self.alpha = alpha
782
+ self.conv_lora_dim = conv_lora_dim
783
+ self.conv_alpha = conv_alpha
784
+ self.dropout = dropout
785
+ self.rank_dropout = rank_dropout
786
+ self.module_dropout = module_dropout
787
+
788
+ if modules_dim is not None:
789
+ print(f"create LoRA network from weights")
790
+ elif block_dims is not None:
791
+ print(f"create LoRA network from block_dims")
792
+ print(f"neuron dropout: p={self.dropout}, rank dropout: p={self.rank_dropout}, module dropout: p={self.module_dropout}")
793
+ print(f"block_dims: {block_dims}")
794
+ print(f"block_alphas: {block_alphas}")
795
+ if conv_block_dims is not None:
796
+ print(f"conv_block_dims: {conv_block_dims}")
797
+ print(f"conv_block_alphas: {conv_block_alphas}")
798
+ else:
799
+ print(f"create LoRA network. base dim (rank): {lora_dim}, alpha: {alpha}")
800
+ print(f"neuron dropout: p={self.dropout}, rank dropout: p={self.rank_dropout}, module dropout: p={self.module_dropout}")
801
+ if self.conv_lora_dim is not None:
802
+ print(f"apply LoRA to Conv2d with kernel size (3,3). dim (rank): {self.conv_lora_dim}, alpha: {self.conv_alpha}")
803
+
804
+ # create module instances
805
+ def create_modules(
806
+ is_unet: bool,
807
+ text_encoder_idx: Optional[int], # None, 1, 2
808
+ root_module: torch.nn.Module,
809
+ target_replace_modules: List[torch.nn.Module],
810
+ ) -> List[LoRAModule]:
811
+ prefix = (
812
+ self.LORA_PREFIX_UNET
813
+ if is_unet
814
+ else (
815
+ self.LORA_PREFIX_TEXT_ENCODER
816
+ if text_encoder_idx is None
817
+ else (self.LORA_PREFIX_TEXT_ENCODER1 if text_encoder_idx == 1 else self.LORA_PREFIX_TEXT_ENCODER2)
818
+ )
819
+ )
820
+ loras = []
821
+ skipped = []
822
+ for name, module in root_module.named_modules():
823
+ if module.__class__.__name__ in target_replace_modules:
824
+ for child_name, child_module in module.named_modules():
825
+ is_linear = child_module.__class__.__name__ == "Linear"
826
+ is_conv2d = child_module.__class__.__name__ == "Conv2d"
827
+ is_conv2d_1x1 = is_conv2d and child_module.kernel_size == (1, 1)
828
+
829
+ if is_linear or is_conv2d:
830
+ lora_name = prefix + "." + name + "." + child_name
831
+ lora_name = lora_name.replace(".", "_")
832
+
833
+ dim = None
834
+ alpha = None
835
+
836
+ if modules_dim is not None:
837
+ # モジュール指定あり
838
+ if lora_name in modules_dim:
839
+ dim = modules_dim[lora_name]
840
+ alpha = modules_alpha[lora_name]
841
+ elif is_unet and block_dims is not None:
842
+ # U-Netでblock_dims指定あり
843
+ block_idx = get_block_index(lora_name)
844
+ if is_linear or is_conv2d_1x1:
845
+ dim = block_dims[block_idx]
846
+ alpha = block_alphas[block_idx]
847
+ elif conv_block_dims is not None:
848
+ dim = conv_block_dims[block_idx]
849
+ alpha = conv_block_alphas[block_idx]
850
+ else:
851
+ # 通常、すべて対象とする
852
+ if is_linear or is_conv2d_1x1:
853
+ dim = self.lora_dim
854
+ alpha = self.alpha
855
+ elif self.conv_lora_dim is not None:
856
+ dim = self.conv_lora_dim
857
+ alpha = self.conv_alpha
858
+
859
+ if dim is None or dim == 0:
860
+ # skipした情報を出力
861
+ if is_linear or is_conv2d_1x1 or (self.conv_lora_dim is not None or conv_block_dims is not None):
862
+ skipped.append(lora_name)
863
+ continue
864
+
865
+ lora = module_class(
866
+ lora_name,
867
+ child_module,
868
+ self.multiplier,
869
+ dim,
870
+ alpha,
871
+ dropout=dropout,
872
+ rank_dropout=rank_dropout,
873
+ module_dropout=module_dropout,
874
+ )
875
+ loras.append(lora)
876
+ return loras, skipped
877
+
878
+ text_encoders = text_encoder if type(text_encoder) == list else [text_encoder]
879
+
880
+ # create LoRA for text encoder
881
+ # 毎回すべてのモジュールを作るのは無駄なので要検討
882
+ self.text_encoder_loras = []
883
+ skipped_te = []
884
+ for i, text_encoder in enumerate(text_encoders):
885
+ if len(text_encoders) > 1:
886
+ index = i + 1
887
+ print(f"create LoRA for Text Encoder {index}:")
888
+ else:
889
+ index = None
890
+ print(f"create LoRA for Text Encoder:")
891
+
892
+ text_encoder_loras, skipped = create_modules(False, index, text_encoder, LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE)
893
+ self.text_encoder_loras.extend(text_encoder_loras)
894
+ skipped_te += skipped
895
+ print(f"create LoRA for Text Encoder: {len(self.text_encoder_loras)} modules.")
896
+
897
+ # extend U-Net target modules if conv2d 3x3 is enabled, or load from weights
898
+ target_modules = LoRANetwork.UNET_TARGET_REPLACE_MODULE
899
+ if modules_dim is not None or self.conv_lora_dim is not None or conv_block_dims is not None:
900
+ target_modules += LoRANetwork.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3
901
+
902
+ self.unet_loras, skipped_un = create_modules(True, None, unet, target_modules)
903
+ print(f"create LoRA for U-Net: {len(self.unet_loras)} modules.")
904
+
905
+ skipped = skipped_te + skipped_un
906
+ if varbose and len(skipped) > 0:
907
+ print(
908
+ f"because block_lr_weight is 0 or dim (rank) is 0, {len(skipped)} LoRA modules are skipped / block_lr_weightまたはdim (rank)が0の為、次の{len(skipped)}個のLoRAモジュールはスキップされます:"
909
+ )
910
+ for name in skipped:
911
+ print(f"\t{name}")
912
+
913
+ self.up_lr_weight: List[float] = None
914
+ self.down_lr_weight: List[float] = None
915
+ self.mid_lr_weight: float = None
916
+ self.block_lr = False
917
+
918
+ # assertion
919
+ names = set()
920
+ for lora in self.text_encoder_loras + self.unet_loras:
921
+ assert lora.lora_name not in names, f"duplicated lora name: {lora.lora_name}"
922
+ names.add(lora.lora_name)
923
+
924
+ def set_multiplier(self, multiplier):
925
+ self.multiplier = multiplier
926
+ for lora in self.text_encoder_loras + self.unet_loras:
927
+ lora.multiplier = self.multiplier
928
+
929
+ def load_weights(self, file):
930
+ if os.path.splitext(file)[1] == ".safetensors":
931
+ from safetensors.torch import load_file
932
+
933
+ weights_sd = load_file(file)
934
+ else:
935
+ weights_sd = torch.load(file, map_location="cpu")
936
+
937
+ info = self.load_state_dict(weights_sd, False)
938
+ return info
939
+
940
+ def apply_to(self, text_encoder, unet, apply_text_encoder=True, apply_unet=True):
941
+ if apply_text_encoder:
942
+ print("enable LoRA for text encoder")
943
+ else:
944
+ self.text_encoder_loras = []
945
+
946
+ if apply_unet:
947
+ print("enable LoRA for U-Net")
948
+ else:
949
+ self.unet_loras = []
950
+
951
+ for lora in self.text_encoder_loras + self.unet_loras:
952
+ lora.apply_to()
953
+ self.add_module(lora.lora_name, lora)
954
+
955
+ # マージできるかどうかを返す
956
+ def is_mergeable(self):
957
+ return True
958
+
959
+ # TODO refactor to common function with apply_to
960
+ def merge_to(self, text_encoder, unet, weights_sd, dtype, device):
961
+ apply_text_encoder = apply_unet = False
962
+ for key in weights_sd.keys():
963
+ if key.startswith(LoRANetwork.LORA_PREFIX_TEXT_ENCODER):
964
+ apply_text_encoder = True
965
+ elif key.startswith(LoRANetwork.LORA_PREFIX_UNET):
966
+ apply_unet = True
967
+
968
+ if apply_text_encoder:
969
+ print("enable LoRA for text encoder")
970
+ else:
971
+ self.text_encoder_loras = []
972
+
973
+ if apply_unet:
974
+ print("enable LoRA for U-Net")
975
+ else:
976
+ self.unet_loras = []
977
+
978
+ for lora in self.text_encoder_loras + self.unet_loras:
979
+ sd_for_lora = {}
980
+ for key in weights_sd.keys():
981
+ if key.startswith(lora.lora_name):
982
+ sd_for_lora[key[len(lora.lora_name) + 1 :]] = weights_sd[key]
983
+ lora.merge_to(sd_for_lora, dtype, device)
984
+
985
+ print(f"weights are merged")
986
+
987
+ # 層別学習率用に層ごとの学習率に対する倍率を定義する 引数の順番が逆だがとりあえず気にしない
988
+ def set_block_lr_weight(
989
+ self,
990
+ up_lr_weight: List[float] = None,
991
+ mid_lr_weight: float = None,
992
+ down_lr_weight: List[float] = None,
993
+ ):
994
+ self.block_lr = True
995
+ self.down_lr_weight = down_lr_weight
996
+ self.mid_lr_weight = mid_lr_weight
997
+ self.up_lr_weight = up_lr_weight
998
+
999
+ def get_lr_weight(self, lora: LoRAModule) -> float:
1000
+ lr_weight = 1.0
1001
+ block_idx = get_block_index(lora.lora_name)
1002
+ if block_idx < 0:
1003
+ return lr_weight
1004
+
1005
+ if block_idx < LoRANetwork.NUM_OF_BLOCKS:
1006
+ if self.down_lr_weight != None:
1007
+ lr_weight = self.down_lr_weight[block_idx]
1008
+ elif block_idx == LoRANetwork.NUM_OF_BLOCKS:
1009
+ if self.mid_lr_weight != None:
1010
+ lr_weight = self.mid_lr_weight
1011
+ elif block_idx > LoRANetwork.NUM_OF_BLOCKS:
1012
+ if self.up_lr_weight != None:
1013
+ lr_weight = self.up_lr_weight[block_idx - LoRANetwork.NUM_OF_BLOCKS - 1]
1014
+
1015
+ return lr_weight
1016
+
1017
+ # 二つのText Encoderに別々の学習率を設定できるようにするといいかも
1018
+ def prepare_optimizer_params(self, text_encoder_lr, unet_lr, default_lr):
1019
+ self.requires_grad_(True)
1020
+ all_params = []
1021
+
1022
+ def enumerate_params(loras):
1023
+ params = []
1024
+ for lora in loras:
1025
+ params.extend(lora.parameters())
1026
+ return params
1027
+
1028
+ if self.text_encoder_loras:
1029
+ param_data = {"params": enumerate_params(self.text_encoder_loras)}
1030
+ if text_encoder_lr is not None:
1031
+ param_data["lr"] = text_encoder_lr
1032
+ all_params.append(param_data)
1033
+
1034
+ if self.unet_loras:
1035
+ if self.block_lr:
1036
+ # 学習率のグラフをblockごとにしたいので、blockごとにloraを分類
1037
+ block_idx_to_lora = {}
1038
+ for lora in self.unet_loras:
1039
+ idx = get_block_index(lora.lora_name)
1040
+ if idx not in block_idx_to_lora:
1041
+ block_idx_to_lora[idx] = []
1042
+ block_idx_to_lora[idx].append(lora)
1043
+
1044
+ # blockごとにパラメータを設定する
1045
+ for idx, block_loras in block_idx_to_lora.items():
1046
+ param_data = {"params": enumerate_params(block_loras)}
1047
+
1048
+ if unet_lr is not None:
1049
+ param_data["lr"] = unet_lr * self.get_lr_weight(block_loras[0])
1050
+ elif default_lr is not None:
1051
+ param_data["lr"] = default_lr * self.get_lr_weight(block_loras[0])
1052
+ if ("lr" in param_data) and (param_data["lr"] == 0):
1053
+ continue
1054
+ all_params.append(param_data)
1055
+
1056
+ else:
1057
+ param_data = {"params": enumerate_params(self.unet_loras)}
1058
+ if unet_lr is not None:
1059
+ param_data["lr"] = unet_lr
1060
+ all_params.append(param_data)
1061
+
1062
+ return all_params
1063
+
1064
+ def enable_gradient_checkpointing(self):
1065
+ # not supported
1066
+ pass
1067
+
1068
+ def prepare_grad_etc(self, text_encoder, unet):
1069
+ self.requires_grad_(True)
1070
+
1071
+ def on_epoch_start(self, text_encoder, unet):
1072
+ self.train()
1073
+
1074
+ def get_trainable_params(self):
1075
+ return self.parameters()
1076
+
1077
+ def save_weights(self, file, dtype, metadata):
1078
+ if metadata is not None and len(metadata) == 0:
1079
+ metadata = None
1080
+
1081
+ state_dict = self.state_dict()
1082
+
1083
+ if dtype is not None:
1084
+ for key in list(state_dict.keys()):
1085
+ v = state_dict[key]
1086
+ v = v.detach().clone().to("cpu").to(dtype)
1087
+ state_dict[key] = v
1088
+
1089
+ if os.path.splitext(file)[1] == ".safetensors":
1090
+ from safetensors.torch import save_file
1091
+ from library import train_util
1092
+
1093
+ # Precalculate model hashes to save time on indexing
1094
+ if metadata is None:
1095
+ metadata = {}
1096
+ model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(state_dict, metadata)
1097
+ metadata["sshs_model_hash"] = model_hash
1098
+ metadata["sshs_legacy_hash"] = legacy_hash
1099
+
1100
+ save_file(state_dict, file, metadata)
1101
+ else:
1102
+ torch.save(state_dict, file)
1103
+
1104
+ # mask is a tensor with values from 0 to 1
1105
+ def set_region(self, sub_prompt_index, is_last_network, mask):
1106
+ if mask.max() == 0:
1107
+ mask = torch.ones_like(mask)
1108
+
1109
+ self.mask = mask
1110
+ self.sub_prompt_index = sub_prompt_index
1111
+ self.is_last_network = is_last_network
1112
+
1113
+ for lora in self.text_encoder_loras + self.unet_loras:
1114
+ lora.set_network(self)
1115
+
1116
+ def set_current_generation(self, batch_size, num_sub_prompts, width, height, shared):
1117
+ self.batch_size = batch_size
1118
+ self.num_sub_prompts = num_sub_prompts
1119
+ self.current_size = (height, width)
1120
+ self.shared = shared
1121
+
1122
+ # create masks
1123
+ mask = self.mask
1124
+ mask_dic = {}
1125
+ mask = mask.unsqueeze(0).unsqueeze(1) # b(1),c(1),h,w
1126
+ ref_weight = self.text_encoder_loras[0].lora_down.weight if self.text_encoder_loras else self.unet_loras[0].lora_down.weight
1127
+ dtype = ref_weight.dtype
1128
+ device = ref_weight.device
1129
+
1130
+ def resize_add(mh, mw):
1131
+ # print(mh, mw, mh * mw)
1132
+ m = torch.nn.functional.interpolate(mask, (mh, mw), mode="bilinear") # doesn't work in bf16
1133
+ m = m.to(device, dtype=dtype)
1134
+ mask_dic[mh * mw] = m
1135
+
1136
+ h = height // 8
1137
+ w = width // 8
1138
+ for _ in range(4):
1139
+ resize_add(h, w)
1140
+ if h % 2 == 1 or w % 2 == 1: # add extra shape if h/w is not divisible by 2
1141
+ resize_add(h + h % 2, w + w % 2)
1142
+ h = (h + 1) // 2
1143
+ w = (w + 1) // 2
1144
+
1145
+ self.mask_dic = mask_dic
1146
+
1147
+ def backup_weights(self):
1148
+ # 重みのバックアップを行う
1149
+ loras: List[LoRAInfModule] = self.text_encoder_loras + self.unet_loras
1150
+ for lora in loras:
1151
+ org_module = lora.org_module_ref[0]
1152
+ if not hasattr(org_module, "_lora_org_weight"):
1153
+ sd = org_module.state_dict()
1154
+ org_module._lora_org_weight = sd["weight"].detach().clone()
1155
+ org_module._lora_restored = True
1156
+
1157
+ def restore_weights(self):
1158
+ # 重みのリストアを行う
1159
+ loras: List[LoRAInfModule] = self.text_encoder_loras + self.unet_loras
1160
+ for lora in loras:
1161
+ org_module = lora.org_module_ref[0]
1162
+ if not org_module._lora_restored:
1163
+ sd = org_module.state_dict()
1164
+ sd["weight"] = org_module._lora_org_weight
1165
+ org_module.load_state_dict(sd)
1166
+ org_module._lora_restored = True
1167
+
1168
+ def pre_calculation(self):
1169
+ # 事前計算を行う
1170
+ loras: List[LoRAInfModule] = self.text_encoder_loras + self.unet_loras
1171
+ for lora in loras:
1172
+ org_module = lora.org_module_ref[0]
1173
+ sd = org_module.state_dict()
1174
+
1175
+ org_weight = sd["weight"]
1176
+ lora_weight = lora.get_weight().to(org_weight.device, dtype=org_weight.dtype)
1177
+ sd["weight"] = org_weight + lora_weight
1178
+ assert sd["weight"].shape == org_weight.shape
1179
+ org_module.load_state_dict(sd)
1180
+
1181
+ org_module._lora_restored = False
1182
+ lora.enabled = False
1183
+
1184
+ def apply_max_norm_regularization(self, max_norm_value, device):
1185
+ downkeys = []
1186
+ upkeys = []
1187
+ alphakeys = []
1188
+ norms = []
1189
+ keys_scaled = 0
1190
+
1191
+ state_dict = self.state_dict()
1192
+ for key in state_dict.keys():
1193
+ if "lora_down" in key and "weight" in key:
1194
+ downkeys.append(key)
1195
+ upkeys.append(key.replace("lora_down", "lora_up"))
1196
+ alphakeys.append(key.replace("lora_down.weight", "alpha"))
1197
+
1198
+ for i in range(len(downkeys)):
1199
+ down = state_dict[downkeys[i]].to(device)
1200
+ up = state_dict[upkeys[i]].to(device)
1201
+ alpha = state_dict[alphakeys[i]].to(device)
1202
+ dim = down.shape[0]
1203
+ scale = alpha / dim
1204
+
1205
+ if up.shape[2:] == (1, 1) and down.shape[2:] == (1, 1):
1206
+ updown = (up.squeeze(2).squeeze(2) @ down.squeeze(2).squeeze(2)).unsqueeze(2).unsqueeze(3)
1207
+ elif up.shape[2:] == (3, 3) or down.shape[2:] == (3, 3):
1208
+ updown = torch.nn.functional.conv2d(down.permute(1, 0, 2, 3), up).permute(1, 0, 2, 3)
1209
+ else:
1210
+ updown = up @ down
1211
+
1212
+ updown *= scale
1213
+
1214
+ norm = updown.norm().clamp(min=max_norm_value / 2)
1215
+ desired = torch.clamp(norm, max=max_norm_value)
1216
+ ratio = desired.cpu() / norm.cpu()
1217
+ sqrt_ratio = ratio**0.5
1218
+ if ratio != 1:
1219
+ keys_scaled += 1
1220
+ state_dict[upkeys[i]] *= sqrt_ratio
1221
+ state_dict[downkeys[i]] *= sqrt_ratio
1222
+ scalednorm = updown.norm() * ratio
1223
+ norms.append(scalednorm.item())
1224
+
1225
+ return keys_scaled, sum(norms) / len(norms), max(norms)
external/llite/networks/lora_diffusers.py ADDED
@@ -0,0 +1,609 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Diffusersで動くLoRA。このファイル単独で完結する。
2
+ # LoRA module for Diffusers. This file works independently.
3
+
4
+ import bisect
5
+ import math
6
+ import random
7
+ from typing import Any, Dict, List, Mapping, Optional, Union
8
+ from diffusers import UNet2DConditionModel
9
+ import numpy as np
10
+ from tqdm import tqdm
11
+ from transformers import CLIPTextModel
12
+ import torch
13
+
14
+
15
+ def make_unet_conversion_map() -> Dict[str, str]:
16
+ unet_conversion_map_layer = []
17
+
18
+ for i in range(3): # num_blocks is 3 in sdxl
19
+ # loop over downblocks/upblocks
20
+ for j in range(2):
21
+ # loop over resnets/attentions for downblocks
22
+ hf_down_res_prefix = f"down_blocks.{i}.resnets.{j}."
23
+ sd_down_res_prefix = f"input_blocks.{3*i + j + 1}.0."
24
+ unet_conversion_map_layer.append((sd_down_res_prefix, hf_down_res_prefix))
25
+
26
+ if i < 3:
27
+ # no attention layers in down_blocks.3
28
+ hf_down_atn_prefix = f"down_blocks.{i}.attentions.{j}."
29
+ sd_down_atn_prefix = f"input_blocks.{3*i + j + 1}.1."
30
+ unet_conversion_map_layer.append((sd_down_atn_prefix, hf_down_atn_prefix))
31
+
32
+ for j in range(3):
33
+ # loop over resnets/attentions for upblocks
34
+ hf_up_res_prefix = f"up_blocks.{i}.resnets.{j}."
35
+ sd_up_res_prefix = f"output_blocks.{3*i + j}.0."
36
+ unet_conversion_map_layer.append((sd_up_res_prefix, hf_up_res_prefix))
37
+
38
+ # if i > 0: commentout for sdxl
39
+ # no attention layers in up_blocks.0
40
+ hf_up_atn_prefix = f"up_blocks.{i}.attentions.{j}."
41
+ sd_up_atn_prefix = f"output_blocks.{3*i + j}.1."
42
+ unet_conversion_map_layer.append((sd_up_atn_prefix, hf_up_atn_prefix))
43
+
44
+ if i < 3:
45
+ # no downsample in down_blocks.3
46
+ hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0.conv."
47
+ sd_downsample_prefix = f"input_blocks.{3*(i+1)}.0.op."
48
+ unet_conversion_map_layer.append((sd_downsample_prefix, hf_downsample_prefix))
49
+
50
+ # no upsample in up_blocks.3
51
+ hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
52
+ sd_upsample_prefix = f"output_blocks.{3*i + 2}.{2}." # change for sdxl
53
+ unet_conversion_map_layer.append((sd_upsample_prefix, hf_upsample_prefix))
54
+
55
+ hf_mid_atn_prefix = "mid_block.attentions.0."
56
+ sd_mid_atn_prefix = "middle_block.1."
57
+ unet_conversion_map_layer.append((sd_mid_atn_prefix, hf_mid_atn_prefix))
58
+
59
+ for j in range(2):
60
+ hf_mid_res_prefix = f"mid_block.resnets.{j}."
61
+ sd_mid_res_prefix = f"middle_block.{2*j}."
62
+ unet_conversion_map_layer.append((sd_mid_res_prefix, hf_mid_res_prefix))
63
+
64
+ unet_conversion_map_resnet = [
65
+ # (stable-diffusion, HF Diffusers)
66
+ ("in_layers.0.", "norm1."),
67
+ ("in_layers.2.", "conv1."),
68
+ ("out_layers.0.", "norm2."),
69
+ ("out_layers.3.", "conv2."),
70
+ ("emb_layers.1.", "time_emb_proj."),
71
+ ("skip_connection.", "conv_shortcut."),
72
+ ]
73
+
74
+ unet_conversion_map = []
75
+ for sd, hf in unet_conversion_map_layer:
76
+ if "resnets" in hf:
77
+ for sd_res, hf_res in unet_conversion_map_resnet:
78
+ unet_conversion_map.append((sd + sd_res, hf + hf_res))
79
+ else:
80
+ unet_conversion_map.append((sd, hf))
81
+
82
+ for j in range(2):
83
+ hf_time_embed_prefix = f"time_embedding.linear_{j+1}."
84
+ sd_time_embed_prefix = f"time_embed.{j*2}."
85
+ unet_conversion_map.append((sd_time_embed_prefix, hf_time_embed_prefix))
86
+
87
+ for j in range(2):
88
+ hf_label_embed_prefix = f"add_embedding.linear_{j+1}."
89
+ sd_label_embed_prefix = f"label_emb.0.{j*2}."
90
+ unet_conversion_map.append((sd_label_embed_prefix, hf_label_embed_prefix))
91
+
92
+ unet_conversion_map.append(("input_blocks.0.0.", "conv_in."))
93
+ unet_conversion_map.append(("out.0.", "conv_norm_out."))
94
+ unet_conversion_map.append(("out.2.", "conv_out."))
95
+
96
+ sd_hf_conversion_map = {sd.replace(".", "_")[:-1]: hf.replace(".", "_")[:-1] for sd, hf in unet_conversion_map}
97
+ return sd_hf_conversion_map
98
+
99
+
100
+ UNET_CONVERSION_MAP = make_unet_conversion_map()
101
+
102
+
103
+ class LoRAModule(torch.nn.Module):
104
+ """
105
+ replaces forward method of the original Linear, instead of replacing the original Linear module.
106
+ """
107
+
108
+ def __init__(
109
+ self,
110
+ lora_name,
111
+ org_module: torch.nn.Module,
112
+ multiplier=1.0,
113
+ lora_dim=4,
114
+ alpha=1,
115
+ ):
116
+ """if alpha == 0 or None, alpha is rank (no scaling)."""
117
+ super().__init__()
118
+ self.lora_name = lora_name
119
+
120
+ if org_module.__class__.__name__ == "Conv2d" or org_module.__class__.__name__ == "LoRACompatibleConv":
121
+ in_dim = org_module.in_channels
122
+ out_dim = org_module.out_channels
123
+ else:
124
+ in_dim = org_module.in_features
125
+ out_dim = org_module.out_features
126
+
127
+ self.lora_dim = lora_dim
128
+
129
+ if org_module.__class__.__name__ == "Conv2d" or org_module.__class__.__name__ == "LoRACompatibleConv":
130
+ kernel_size = org_module.kernel_size
131
+ stride = org_module.stride
132
+ padding = org_module.padding
133
+ self.lora_down = torch.nn.Conv2d(in_dim, self.lora_dim, kernel_size, stride, padding, bias=False)
134
+ self.lora_up = torch.nn.Conv2d(self.lora_dim, out_dim, (1, 1), (1, 1), bias=False)
135
+ else:
136
+ self.lora_down = torch.nn.Linear(in_dim, self.lora_dim, bias=False)
137
+ self.lora_up = torch.nn.Linear(self.lora_dim, out_dim, bias=False)
138
+
139
+ if type(alpha) == torch.Tensor:
140
+ alpha = alpha.detach().float().numpy() # without casting, bf16 causes error
141
+ alpha = self.lora_dim if alpha is None or alpha == 0 else alpha
142
+ self.scale = alpha / self.lora_dim
143
+ self.register_buffer("alpha", torch.tensor(alpha)) # 勾配計算に含めない / not included in gradient calculation
144
+
145
+ # same as microsoft's
146
+ torch.nn.init.kaiming_uniform_(self.lora_down.weight, a=math.sqrt(5))
147
+ torch.nn.init.zeros_(self.lora_up.weight)
148
+
149
+ self.multiplier = multiplier
150
+ self.org_module = [org_module]
151
+ self.enabled = True
152
+ self.network: LoRANetwork = None
153
+ self.org_forward = None
154
+
155
+ # override org_module's forward method
156
+ def apply_to(self, multiplier=None):
157
+ if multiplier is not None:
158
+ self.multiplier = multiplier
159
+ if self.org_forward is None:
160
+ self.org_forward = self.org_module[0].forward
161
+ self.org_module[0].forward = self.forward
162
+
163
+ # restore org_module's forward method
164
+ def unapply_to(self):
165
+ if self.org_forward is not None:
166
+ self.org_module[0].forward = self.org_forward
167
+
168
+ # forward with lora
169
+ # scale is used LoRACompatibleConv, but we ignore it because we have multiplier
170
+ def forward(self, x, scale=1.0):
171
+ if not self.enabled:
172
+ return self.org_forward(x)
173
+ return self.org_forward(x) + self.lora_up(self.lora_down(x)) * self.multiplier * self.scale
174
+
175
+ def set_network(self, network):
176
+ self.network = network
177
+
178
+ # merge lora weight to org weight
179
+ def merge_to(self, multiplier=1.0):
180
+ # get lora weight
181
+ lora_weight = self.get_weight(multiplier)
182
+
183
+ # get org weight
184
+ org_sd = self.org_module[0].state_dict()
185
+ org_weight = org_sd["weight"]
186
+ weight = org_weight + lora_weight.to(org_weight.device, dtype=org_weight.dtype)
187
+
188
+ # set weight to org_module
189
+ org_sd["weight"] = weight
190
+ self.org_module[0].load_state_dict(org_sd)
191
+
192
+ # restore org weight from lora weight
193
+ def restore_from(self, multiplier=1.0):
194
+ # get lora weight
195
+ lora_weight = self.get_weight(multiplier)
196
+
197
+ # get org weight
198
+ org_sd = self.org_module[0].state_dict()
199
+ org_weight = org_sd["weight"]
200
+ weight = org_weight - lora_weight.to(org_weight.device, dtype=org_weight.dtype)
201
+
202
+ # set weight to org_module
203
+ org_sd["weight"] = weight
204
+ self.org_module[0].load_state_dict(org_sd)
205
+
206
+ # return lora weight
207
+ def get_weight(self, multiplier=None):
208
+ if multiplier is None:
209
+ multiplier = self.multiplier
210
+
211
+ # get up/down weight from module
212
+ up_weight = self.lora_up.weight.to(torch.float)
213
+ down_weight = self.lora_down.weight.to(torch.float)
214
+
215
+ # pre-calculated weight
216
+ if len(down_weight.size()) == 2:
217
+ # linear
218
+ weight = self.multiplier * (up_weight @ down_weight) * self.scale
219
+ elif down_weight.size()[2:4] == (1, 1):
220
+ # conv2d 1x1
221
+ weight = (
222
+ self.multiplier
223
+ * (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3)
224
+ * self.scale
225
+ )
226
+ else:
227
+ # conv2d 3x3
228
+ conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3)
229
+ weight = self.multiplier * conved * self.scale
230
+
231
+ return weight
232
+
233
+
234
+ # Create network from weights for inference, weights are not loaded here
235
+ def create_network_from_weights(
236
+ text_encoder: Union[CLIPTextModel, List[CLIPTextModel]], unet: UNet2DConditionModel, weights_sd: Dict, multiplier: float = 1.0
237
+ ):
238
+ # get dim/alpha mapping
239
+ modules_dim = {}
240
+ modules_alpha = {}
241
+ for key, value in weights_sd.items():
242
+ if "." not in key:
243
+ continue
244
+
245
+ lora_name = key.split(".")[0]
246
+ if "alpha" in key:
247
+ modules_alpha[lora_name] = value
248
+ elif "lora_down" in key:
249
+ dim = value.size()[0]
250
+ modules_dim[lora_name] = dim
251
+ # print(lora_name, value.size(), dim)
252
+
253
+ # support old LoRA without alpha
254
+ for key in modules_dim.keys():
255
+ if key not in modules_alpha:
256
+ modules_alpha[key] = modules_dim[key]
257
+
258
+ return LoRANetwork(text_encoder, unet, multiplier=multiplier, modules_dim=modules_dim, modules_alpha=modules_alpha)
259
+
260
+
261
+ def merge_lora_weights(pipe, weights_sd: Dict, multiplier: float = 1.0):
262
+ text_encoders = [pipe.text_encoder, pipe.text_encoder_2] if hasattr(pipe, "text_encoder_2") else [pipe.text_encoder]
263
+ unet = pipe.unet
264
+
265
+ lora_network = create_network_from_weights(text_encoders, unet, weights_sd, multiplier=multiplier)
266
+ lora_network.load_state_dict(weights_sd)
267
+ lora_network.merge_to(multiplier=multiplier)
268
+
269
+
270
+ # block weightや学習に対応しない簡易版 / simple version without block weight and training
271
+ class LoRANetwork(torch.nn.Module):
272
+ UNET_TARGET_REPLACE_MODULE = ["Transformer2DModel"]
273
+ UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 = ["ResnetBlock2D", "Downsample2D", "Upsample2D"]
274
+ TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPMLP"]
275
+ LORA_PREFIX_UNET = "lora_unet"
276
+ LORA_PREFIX_TEXT_ENCODER = "lora_te"
277
+
278
+ # SDXL: must starts with LORA_PREFIX_TEXT_ENCODER
279
+ LORA_PREFIX_TEXT_ENCODER1 = "lora_te1"
280
+ LORA_PREFIX_TEXT_ENCODER2 = "lora_te2"
281
+
282
+ def __init__(
283
+ self,
284
+ text_encoder: Union[List[CLIPTextModel], CLIPTextModel],
285
+ unet: UNet2DConditionModel,
286
+ multiplier: float = 1.0,
287
+ modules_dim: Optional[Dict[str, int]] = None,
288
+ modules_alpha: Optional[Dict[str, int]] = None,
289
+ varbose: Optional[bool] = False,
290
+ ) -> None:
291
+ super().__init__()
292
+ self.multiplier = multiplier
293
+
294
+ print(f"create LoRA network from weights")
295
+
296
+ # convert SDXL Stability AI's U-Net modules to Diffusers
297
+ converted = self.convert_unet_modules(modules_dim, modules_alpha)
298
+ if converted:
299
+ print(f"converted {converted} Stability AI's U-Net LoRA modules to Diffusers (SDXL)")
300
+
301
+ # create module instances
302
+ def create_modules(
303
+ is_unet: bool,
304
+ text_encoder_idx: Optional[int], # None, 1, 2
305
+ root_module: torch.nn.Module,
306
+ target_replace_modules: List[torch.nn.Module],
307
+ ) -> List[LoRAModule]:
308
+ prefix = (
309
+ self.LORA_PREFIX_UNET
310
+ if is_unet
311
+ else (
312
+ self.LORA_PREFIX_TEXT_ENCODER
313
+ if text_encoder_idx is None
314
+ else (self.LORA_PREFIX_TEXT_ENCODER1 if text_encoder_idx == 1 else self.LORA_PREFIX_TEXT_ENCODER2)
315
+ )
316
+ )
317
+ loras = []
318
+ skipped = []
319
+ for name, module in root_module.named_modules():
320
+ if module.__class__.__name__ in target_replace_modules:
321
+ for child_name, child_module in module.named_modules():
322
+ is_linear = (
323
+ child_module.__class__.__name__ == "Linear" or child_module.__class__.__name__ == "LoRACompatibleLinear"
324
+ )
325
+ is_conv2d = (
326
+ child_module.__class__.__name__ == "Conv2d" or child_module.__class__.__name__ == "LoRACompatibleConv"
327
+ )
328
+
329
+ if is_linear or is_conv2d:
330
+ lora_name = prefix + "." + name + "." + child_name
331
+ lora_name = lora_name.replace(".", "_")
332
+
333
+ if lora_name not in modules_dim:
334
+ # print(f"skipped {lora_name} (not found in modules_dim)")
335
+ skipped.append(lora_name)
336
+ continue
337
+
338
+ dim = modules_dim[lora_name]
339
+ alpha = modules_alpha[lora_name]
340
+ lora = LoRAModule(
341
+ lora_name,
342
+ child_module,
343
+ self.multiplier,
344
+ dim,
345
+ alpha,
346
+ )
347
+ loras.append(lora)
348
+ return loras, skipped
349
+
350
+ text_encoders = text_encoder if type(text_encoder) == list else [text_encoder]
351
+
352
+ # create LoRA for text encoder
353
+ # 毎回すべてのモジュールを作るのは無駄なので要検討 / it is wasteful to create all modules every time, need to consider
354
+ self.text_encoder_loras: List[LoRAModule] = []
355
+ skipped_te = []
356
+ for i, text_encoder in enumerate(text_encoders):
357
+ if len(text_encoders) > 1:
358
+ index = i + 1
359
+ else:
360
+ index = None
361
+
362
+ text_encoder_loras, skipped = create_modules(False, index, text_encoder, LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE)
363
+ self.text_encoder_loras.extend(text_encoder_loras)
364
+ skipped_te += skipped
365
+ print(f"create LoRA for Text Encoder: {len(self.text_encoder_loras)} modules.")
366
+ if len(skipped_te) > 0:
367
+ print(f"skipped {len(skipped_te)} modules because of missing weight for text encoder.")
368
+
369
+ # extend U-Net target modules to include Conv2d 3x3
370
+ target_modules = LoRANetwork.UNET_TARGET_REPLACE_MODULE + LoRANetwork.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3
371
+
372
+ self.unet_loras: List[LoRAModule]
373
+ self.unet_loras, skipped_un = create_modules(True, None, unet, target_modules)
374
+ print(f"create LoRA for U-Net: {len(self.unet_loras)} modules.")
375
+ if len(skipped_un) > 0:
376
+ print(f"skipped {len(skipped_un)} modules because of missing weight for U-Net.")
377
+
378
+ # assertion
379
+ names = set()
380
+ for lora in self.text_encoder_loras + self.unet_loras:
381
+ names.add(lora.lora_name)
382
+ for lora_name in modules_dim.keys():
383
+ assert lora_name in names, f"{lora_name} is not found in created LoRA modules."
384
+
385
+ # make to work load_state_dict
386
+ for lora in self.text_encoder_loras + self.unet_loras:
387
+ self.add_module(lora.lora_name, lora)
388
+
389
+ # SDXL: convert SDXL Stability AI's U-Net modules to Diffusers
390
+ def convert_unet_modules(self, modules_dim, modules_alpha):
391
+ converted_count = 0
392
+ not_converted_count = 0
393
+
394
+ map_keys = list(UNET_CONVERSION_MAP.keys())
395
+ map_keys.sort()
396
+
397
+ for key in list(modules_dim.keys()):
398
+ if key.startswith(LoRANetwork.LORA_PREFIX_UNET + "_"):
399
+ search_key = key.replace(LoRANetwork.LORA_PREFIX_UNET + "_", "")
400
+ position = bisect.bisect_right(map_keys, search_key)
401
+ map_key = map_keys[position - 1]
402
+ if search_key.startswith(map_key):
403
+ new_key = key.replace(map_key, UNET_CONVERSION_MAP[map_key])
404
+ modules_dim[new_key] = modules_dim[key]
405
+ modules_alpha[new_key] = modules_alpha[key]
406
+ del modules_dim[key]
407
+ del modules_alpha[key]
408
+ converted_count += 1
409
+ else:
410
+ not_converted_count += 1
411
+ assert (
412
+ converted_count == 0 or not_converted_count == 0
413
+ ), f"some modules are not converted: {converted_count} converted, {not_converted_count} not converted"
414
+ return converted_count
415
+
416
+ def set_multiplier(self, multiplier):
417
+ self.multiplier = multiplier
418
+ for lora in self.text_encoder_loras + self.unet_loras:
419
+ lora.multiplier = self.multiplier
420
+
421
+ def apply_to(self, multiplier=1.0, apply_text_encoder=True, apply_unet=True):
422
+ if apply_text_encoder:
423
+ print("enable LoRA for text encoder")
424
+ for lora in self.text_encoder_loras:
425
+ lora.apply_to(multiplier)
426
+ if apply_unet:
427
+ print("enable LoRA for U-Net")
428
+ for lora in self.unet_loras:
429
+ lora.apply_to(multiplier)
430
+
431
+ def unapply_to(self):
432
+ for lora in self.text_encoder_loras + self.unet_loras:
433
+ lora.unapply_to()
434
+
435
+ def merge_to(self, multiplier=1.0):
436
+ print("merge LoRA weights to original weights")
437
+ for lora in tqdm(self.text_encoder_loras + self.unet_loras):
438
+ lora.merge_to(multiplier)
439
+ print(f"weights are merged")
440
+
441
+ def restore_from(self, multiplier=1.0):
442
+ print("restore LoRA weights from original weights")
443
+ for lora in tqdm(self.text_encoder_loras + self.unet_loras):
444
+ lora.restore_from(multiplier)
445
+ print(f"weights are restored")
446
+
447
+ def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True):
448
+ # convert SDXL Stability AI's state dict to Diffusers' based state dict
449
+ map_keys = list(UNET_CONVERSION_MAP.keys()) # prefix of U-Net modules
450
+ map_keys.sort()
451
+ for key in list(state_dict.keys()):
452
+ if key.startswith(LoRANetwork.LORA_PREFIX_UNET + "_"):
453
+ search_key = key.replace(LoRANetwork.LORA_PREFIX_UNET + "_", "")
454
+ position = bisect.bisect_right(map_keys, search_key)
455
+ map_key = map_keys[position - 1]
456
+ if search_key.startswith(map_key):
457
+ new_key = key.replace(map_key, UNET_CONVERSION_MAP[map_key])
458
+ state_dict[new_key] = state_dict[key]
459
+ del state_dict[key]
460
+
461
+ # in case of V2, some weights have different shape, so we need to convert them
462
+ # because V2 LoRA is based on U-Net created by use_linear_projection=False
463
+ my_state_dict = self.state_dict()
464
+ for key in state_dict.keys():
465
+ if state_dict[key].size() != my_state_dict[key].size():
466
+ # print(f"convert {key} from {state_dict[key].size()} to {my_state_dict[key].size()}")
467
+ state_dict[key] = state_dict[key].view(my_state_dict[key].size())
468
+
469
+ return super().load_state_dict(state_dict, strict)
470
+
471
+
472
+ if __name__ == "__main__":
473
+ # sample code to use LoRANetwork
474
+ import os
475
+ import argparse
476
+ from diffusers import StableDiffusionPipeline, StableDiffusionXLPipeline
477
+ import torch
478
+
479
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
480
+
481
+ parser = argparse.ArgumentParser()
482
+ parser.add_argument("--model_id", type=str, default=None, help="model id for huggingface")
483
+ parser.add_argument("--lora_weights", type=str, default=None, help="path to LoRA weights")
484
+ parser.add_argument("--sdxl", action="store_true", help="use SDXL model")
485
+ parser.add_argument("--prompt", type=str, default="A photo of cat", help="prompt text")
486
+ parser.add_argument("--negative_prompt", type=str, default="", help="negative prompt text")
487
+ parser.add_argument("--seed", type=int, default=0, help="random seed")
488
+ args = parser.parse_args()
489
+
490
+ image_prefix = args.model_id.replace("/", "_") + "_"
491
+
492
+ # load Diffusers model
493
+ print(f"load model from {args.model_id}")
494
+ pipe: Union[StableDiffusionPipeline, StableDiffusionXLPipeline]
495
+ if args.sdxl:
496
+ # use_safetensors=True does not work with 0.18.2
497
+ pipe = StableDiffusionXLPipeline.from_pretrained(args.model_id, variant="fp16", torch_dtype=torch.float16)
498
+ else:
499
+ pipe = StableDiffusionPipeline.from_pretrained(args.model_id, variant="fp16", torch_dtype=torch.float16)
500
+ pipe.to(device)
501
+ pipe.set_use_memory_efficient_attention_xformers(True)
502
+
503
+ text_encoders = [pipe.text_encoder, pipe.text_encoder_2] if args.sdxl else [pipe.text_encoder]
504
+
505
+ # load LoRA weights
506
+ print(f"load LoRA weights from {args.lora_weights}")
507
+ if os.path.splitext(args.lora_weights)[1] == ".safetensors":
508
+ from safetensors.torch import load_file
509
+
510
+ lora_sd = load_file(args.lora_weights)
511
+ else:
512
+ lora_sd = torch.load(args.lora_weights)
513
+
514
+ # create by LoRA weights and load weights
515
+ print(f"create LoRA network")
516
+ lora_network: LoRANetwork = create_network_from_weights(text_encoders, pipe.unet, lora_sd, multiplier=1.0)
517
+
518
+ print(f"load LoRA network weights")
519
+ lora_network.load_state_dict(lora_sd)
520
+
521
+ lora_network.to(device, dtype=pipe.unet.dtype) # required to apply_to. merge_to works without this
522
+
523
+ # 必要があれば、元のモデルの重みをバックアップしておく
524
+ # back-up unet/text encoder weights if necessary
525
+ def detach_and_move_to_cpu(state_dict):
526
+ for k, v in state_dict.items():
527
+ state_dict[k] = v.detach().cpu()
528
+ return state_dict
529
+
530
+ org_unet_sd = pipe.unet.state_dict()
531
+ detach_and_move_to_cpu(org_unet_sd)
532
+
533
+ org_text_encoder_sd = pipe.text_encoder.state_dict()
534
+ detach_and_move_to_cpu(org_text_encoder_sd)
535
+
536
+ if args.sdxl:
537
+ org_text_encoder_2_sd = pipe.text_encoder_2.state_dict()
538
+ detach_and_move_to_cpu(org_text_encoder_2_sd)
539
+
540
+ def seed_everything(seed):
541
+ torch.manual_seed(seed)
542
+ torch.cuda.manual_seed_all(seed)
543
+ np.random.seed(seed)
544
+ random.seed(seed)
545
+
546
+ # create image with original weights
547
+ print(f"create image with original weights")
548
+ seed_everything(args.seed)
549
+ image = pipe(args.prompt, negative_prompt=args.negative_prompt).images[0]
550
+ image.save(image_prefix + "original.png")
551
+
552
+ # apply LoRA network to the model: slower than merge_to, but can be reverted easily
553
+ print(f"apply LoRA network to the model")
554
+ lora_network.apply_to(multiplier=1.0)
555
+
556
+ print(f"create image with applied LoRA")
557
+ seed_everything(args.seed)
558
+ image = pipe(args.prompt, negative_prompt=args.negative_prompt).images[0]
559
+ image.save(image_prefix + "applied_lora.png")
560
+
561
+ # unapply LoRA network to the model
562
+ print(f"unapply LoRA network to the model")
563
+ lora_network.unapply_to()
564
+
565
+ print(f"create image with unapplied LoRA")
566
+ seed_everything(args.seed)
567
+ image = pipe(args.prompt, negative_prompt=args.negative_prompt).images[0]
568
+ image.save(image_prefix + "unapplied_lora.png")
569
+
570
+ # merge LoRA network to the model: faster than apply_to, but requires back-up of original weights (or unmerge_to)
571
+ print(f"merge LoRA network to the model")
572
+ lora_network.merge_to(multiplier=1.0)
573
+
574
+ print(f"create image with LoRA")
575
+ seed_everything(args.seed)
576
+ image = pipe(args.prompt, negative_prompt=args.negative_prompt).images[0]
577
+ image.save(image_prefix + "merged_lora.png")
578
+
579
+ # restore (unmerge) LoRA weights: numerically unstable
580
+ # マージされた重みを元に戻す。計算誤差のため、元の重みと完全に一致しないことがあるかもしれない
581
+ # 保存したstate_dictから元の重みを復元するのが確実
582
+ print(f"restore (unmerge) LoRA weights")
583
+ lora_network.restore_from(multiplier=1.0)
584
+
585
+ print(f"create image without LoRA")
586
+ seed_everything(args.seed)
587
+ image = pipe(args.prompt, negative_prompt=args.negative_prompt).images[0]
588
+ image.save(image_prefix + "unmerged_lora.png")
589
+
590
+ # restore original weights
591
+ print(f"restore original weights")
592
+ pipe.unet.load_state_dict(org_unet_sd)
593
+ pipe.text_encoder.load_state_dict(org_text_encoder_sd)
594
+ if args.sdxl:
595
+ pipe.text_encoder_2.load_state_dict(org_text_encoder_2_sd)
596
+
597
+ print(f"create image with restored original weights")
598
+ seed_everything(args.seed)
599
+ image = pipe(args.prompt, negative_prompt=args.negative_prompt).images[0]
600
+ image.save(image_prefix + "restore_original.png")
601
+
602
+ # use convenience function to merge LoRA weights
603
+ print(f"merge LoRA weights with convenience function")
604
+ merge_lora_weights(pipe, lora_sd, multiplier=1.0)
605
+
606
+ print(f"create image with merged LoRA weights")
607
+ seed_everything(args.seed)
608
+ image = pipe(args.prompt, negative_prompt=args.negative_prompt).images[0]
609
+ image.save(image_prefix + "convenience_merged_lora.png")
external/llite/networks/lora_fa.py ADDED
@@ -0,0 +1,1241 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # LoRA network module
2
+ # reference:
3
+ # https://github.com/microsoft/LoRA/blob/main/loralib/layers.py
4
+ # https://github.com/cloneofsimo/lora/blob/master/lora_diffusion/lora.py
5
+
6
+ # temporary implementation of LoRA-FA: https://arxiv.org/abs/2308.03303
7
+ # need to be refactored and merged to lora.py
8
+
9
+ import math
10
+ import os
11
+ from typing import Dict, List, Optional, Tuple, Type, Union
12
+ from diffusers import AutoencoderKL
13
+ from transformers import CLIPTextModel
14
+ import numpy as np
15
+ import torch
16
+ import re
17
+
18
+
19
+ RE_UPDOWN = re.compile(r"(up|down)_blocks_(\d+)_(resnets|upsamplers|downsamplers|attentions)_(\d+)_")
20
+
21
+
22
+ class LoRAModule(torch.nn.Module):
23
+ """
24
+ replaces forward method of the original Linear, instead of replacing the original Linear module.
25
+ """
26
+
27
+ def __init__(
28
+ self,
29
+ lora_name,
30
+ org_module: torch.nn.Module,
31
+ multiplier=1.0,
32
+ lora_dim=4,
33
+ alpha=1,
34
+ dropout=None,
35
+ rank_dropout=None,
36
+ module_dropout=None,
37
+ ):
38
+ """if alpha == 0 or None, alpha is rank (no scaling)."""
39
+ super().__init__()
40
+ self.lora_name = lora_name
41
+
42
+ if org_module.__class__.__name__ == "Conv2d":
43
+ in_dim = org_module.in_channels
44
+ out_dim = org_module.out_channels
45
+ else:
46
+ in_dim = org_module.in_features
47
+ out_dim = org_module.out_features
48
+
49
+ # if limit_rank:
50
+ # self.lora_dim = min(lora_dim, in_dim, out_dim)
51
+ # if self.lora_dim != lora_dim:
52
+ # print(f"{lora_name} dim (rank) is changed to: {self.lora_dim}")
53
+ # else:
54
+ self.lora_dim = lora_dim
55
+
56
+ if org_module.__class__.__name__ == "Conv2d":
57
+ kernel_size = org_module.kernel_size
58
+ stride = org_module.stride
59
+ padding = org_module.padding
60
+ self.lora_down = torch.nn.Conv2d(in_dim, self.lora_dim, kernel_size, stride, padding, bias=False)
61
+ self.lora_up = torch.nn.Conv2d(self.lora_dim, out_dim, (1, 1), (1, 1), bias=False)
62
+ else:
63
+ self.lora_down = torch.nn.Linear(in_dim, self.lora_dim, bias=False)
64
+ self.lora_up = torch.nn.Linear(self.lora_dim, out_dim, bias=False)
65
+
66
+ if type(alpha) == torch.Tensor:
67
+ alpha = alpha.detach().float().numpy() # without casting, bf16 causes error
68
+ alpha = self.lora_dim if alpha is None or alpha == 0 else alpha
69
+ self.scale = alpha / self.lora_dim
70
+ self.register_buffer("alpha", torch.tensor(alpha)) # 定数として扱える
71
+
72
+ # # same as microsoft's
73
+ # torch.nn.init.kaiming_uniform_(self.lora_down.weight, a=math.sqrt(5))
74
+
75
+ # according to the paper, initialize LoRA-A (down) as normal distribution
76
+ torch.nn.init.normal_(self.lora_down.weight, std=math.sqrt(2.0 / (in_dim + self.lora_dim)))
77
+
78
+ torch.nn.init.zeros_(self.lora_up.weight)
79
+
80
+ self.multiplier = multiplier
81
+ self.org_module = org_module # remove in applying
82
+ self.dropout = dropout
83
+ self.rank_dropout = rank_dropout
84
+ self.module_dropout = module_dropout
85
+
86
+ def get_trainable_params(self):
87
+ params = self.named_parameters()
88
+ trainable_params = []
89
+ for param in params:
90
+ if param[0] == "lora_up.weight": # up only
91
+ trainable_params.append(param[1])
92
+ return trainable_params
93
+
94
+ def requires_grad_(self, requires_grad: bool = True):
95
+ self.lora_up.requires_grad_(requires_grad)
96
+ self.lora_down.requires_grad_(False)
97
+ return self
98
+
99
+ def apply_to(self):
100
+ self.org_forward = self.org_module.forward
101
+ self.org_module.forward = self.forward
102
+ del self.org_module
103
+
104
+ def forward(self, x):
105
+ org_forwarded = self.org_forward(x)
106
+
107
+ # module dropout
108
+ if self.module_dropout is not None and self.training:
109
+ if torch.rand(1) < self.module_dropout:
110
+ return org_forwarded
111
+
112
+ lx = self.lora_down(x)
113
+
114
+ # normal dropout
115
+ if self.dropout is not None and self.training:
116
+ lx = torch.nn.functional.dropout(lx, p=self.dropout)
117
+
118
+ # rank dropout
119
+ if self.rank_dropout is not None and self.training:
120
+ mask = torch.rand((lx.size(0), self.lora_dim), device=lx.device) > self.rank_dropout
121
+ if len(lx.size()) == 3:
122
+ mask = mask.unsqueeze(1) # for Text Encoder
123
+ elif len(lx.size()) == 4:
124
+ mask = mask.unsqueeze(-1).unsqueeze(-1) # for Conv2d
125
+ lx = lx * mask
126
+
127
+ # scaling for rank dropout: treat as if the rank is changed
128
+ # maskから計算することも考えられるが、augmentation的な効果を期待してrank_dropoutを用いる
129
+ scale = self.scale * (1.0 / (1.0 - self.rank_dropout)) # redundant for readability
130
+ else:
131
+ scale = self.scale
132
+
133
+ lx = self.lora_up(lx)
134
+
135
+ return org_forwarded + lx * self.multiplier * scale
136
+
137
+
138
+ class LoRAInfModule(LoRAModule):
139
+ def __init__(
140
+ self,
141
+ lora_name,
142
+ org_module: torch.nn.Module,
143
+ multiplier=1.0,
144
+ lora_dim=4,
145
+ alpha=1,
146
+ **kwargs,
147
+ ):
148
+ # no dropout for inference
149
+ super().__init__(lora_name, org_module, multiplier, lora_dim, alpha)
150
+
151
+ self.org_module_ref = [org_module] # 後から参照できるように
152
+ self.enabled = True
153
+
154
+ # check regional or not by lora_name
155
+ self.text_encoder = False
156
+ if lora_name.startswith("lora_te_"):
157
+ self.regional = False
158
+ self.use_sub_prompt = True
159
+ self.text_encoder = True
160
+ elif "attn2_to_k" in lora_name or "attn2_to_v" in lora_name:
161
+ self.regional = False
162
+ self.use_sub_prompt = True
163
+ elif "time_emb" in lora_name:
164
+ self.regional = False
165
+ self.use_sub_prompt = False
166
+ else:
167
+ self.regional = True
168
+ self.use_sub_prompt = False
169
+
170
+ self.network: LoRANetwork = None
171
+
172
+ def set_network(self, network):
173
+ self.network = network
174
+
175
+ # freezeしてマージする
176
+ def merge_to(self, sd, dtype, device):
177
+ # get up/down weight
178
+ up_weight = sd["lora_up.weight"].to(torch.float).to(device)
179
+ down_weight = sd["lora_down.weight"].to(torch.float).to(device)
180
+
181
+ # extract weight from org_module
182
+ org_sd = self.org_module.state_dict()
183
+ weight = org_sd["weight"].to(torch.float)
184
+
185
+ # merge weight
186
+ if len(weight.size()) == 2:
187
+ # linear
188
+ weight = weight + self.multiplier * (up_weight @ down_weight) * self.scale
189
+ elif down_weight.size()[2:4] == (1, 1):
190
+ # conv2d 1x1
191
+ weight = (
192
+ weight
193
+ + self.multiplier
194
+ * (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3)
195
+ * self.scale
196
+ )
197
+ else:
198
+ # conv2d 3x3
199
+ conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3)
200
+ # print(conved.size(), weight.size(), module.stride, module.padding)
201
+ weight = weight + self.multiplier * conved * self.scale
202
+
203
+ # set weight to org_module
204
+ org_sd["weight"] = weight.to(dtype)
205
+ self.org_module.load_state_dict(org_sd)
206
+
207
+ # 復元できるマージのため、このモジュールのweightを返す
208
+ def get_weight(self, multiplier=None):
209
+ if multiplier is None:
210
+ multiplier = self.multiplier
211
+
212
+ # get up/down weight from module
213
+ up_weight = self.lora_up.weight.to(torch.float)
214
+ down_weight = self.lora_down.weight.to(torch.float)
215
+
216
+ # pre-calculated weight
217
+ if len(down_weight.size()) == 2:
218
+ # linear
219
+ weight = self.multiplier * (up_weight @ down_weight) * self.scale
220
+ elif down_weight.size()[2:4] == (1, 1):
221
+ # conv2d 1x1
222
+ weight = (
223
+ self.multiplier
224
+ * (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3)
225
+ * self.scale
226
+ )
227
+ else:
228
+ # conv2d 3x3
229
+ conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3)
230
+ weight = self.multiplier * conved * self.scale
231
+
232
+ return weight
233
+
234
+ def set_region(self, region):
235
+ self.region = region
236
+ self.region_mask = None
237
+
238
+ def default_forward(self, x):
239
+ # print("default_forward", self.lora_name, x.size())
240
+ return self.org_forward(x) + self.lora_up(self.lora_down(x)) * self.multiplier * self.scale
241
+
242
+ def forward(self, x):
243
+ if not self.enabled:
244
+ return self.org_forward(x)
245
+
246
+ if self.network is None or self.network.sub_prompt_index is None:
247
+ return self.default_forward(x)
248
+ if not self.regional and not self.use_sub_prompt:
249
+ return self.default_forward(x)
250
+
251
+ if self.regional:
252
+ return self.regional_forward(x)
253
+ else:
254
+ return self.sub_prompt_forward(x)
255
+
256
+ def get_mask_for_x(self, x):
257
+ # calculate size from shape of x
258
+ if len(x.size()) == 4:
259
+ h, w = x.size()[2:4]
260
+ area = h * w
261
+ else:
262
+ area = x.size()[1]
263
+
264
+ mask = self.network.mask_dic[area]
265
+ if mask is None:
266
+ raise ValueError(f"mask is None for resolution {area}")
267
+ if len(x.size()) != 4:
268
+ mask = torch.reshape(mask, (1, -1, 1))
269
+ return mask
270
+
271
+ def regional_forward(self, x):
272
+ if "attn2_to_out" in self.lora_name:
273
+ return self.to_out_forward(x)
274
+
275
+ if self.network.mask_dic is None: # sub_prompt_index >= 3
276
+ return self.default_forward(x)
277
+
278
+ # apply mask for LoRA result
279
+ lx = self.lora_up(self.lora_down(x)) * self.multiplier * self.scale
280
+ mask = self.get_mask_for_x(lx)
281
+ # print("regional", self.lora_name, self.network.sub_prompt_index, lx.size(), mask.size())
282
+ lx = lx * mask
283
+
284
+ x = self.org_forward(x)
285
+ x = x + lx
286
+
287
+ if "attn2_to_q" in self.lora_name and self.network.is_last_network:
288
+ x = self.postp_to_q(x)
289
+
290
+ return x
291
+
292
+ def postp_to_q(self, x):
293
+ # repeat x to num_sub_prompts
294
+ has_real_uncond = x.size()[0] // self.network.batch_size == 3
295
+ qc = self.network.batch_size # uncond
296
+ qc += self.network.batch_size * self.network.num_sub_prompts # cond
297
+ if has_real_uncond:
298
+ qc += self.network.batch_size # real_uncond
299
+
300
+ query = torch.zeros((qc, x.size()[1], x.size()[2]), device=x.device, dtype=x.dtype)
301
+ query[: self.network.batch_size] = x[: self.network.batch_size]
302
+
303
+ for i in range(self.network.batch_size):
304
+ qi = self.network.batch_size + i * self.network.num_sub_prompts
305
+ query[qi : qi + self.network.num_sub_prompts] = x[self.network.batch_size + i]
306
+
307
+ if has_real_uncond:
308
+ query[-self.network.batch_size :] = x[-self.network.batch_size :]
309
+
310
+ # print("postp_to_q", self.lora_name, x.size(), query.size(), self.network.num_sub_prompts)
311
+ return query
312
+
313
+ def sub_prompt_forward(self, x):
314
+ if x.size()[0] == self.network.batch_size: # if uncond in text_encoder, do not apply LoRA
315
+ return self.org_forward(x)
316
+
317
+ emb_idx = self.network.sub_prompt_index
318
+ if not self.text_encoder:
319
+ emb_idx += self.network.batch_size
320
+
321
+ # apply sub prompt of X
322
+ lx = x[emb_idx :: self.network.num_sub_prompts]
323
+ lx = self.lora_up(self.lora_down(lx)) * self.multiplier * self.scale
324
+
325
+ # print("sub_prompt_forward", self.lora_name, x.size(), lx.size(), emb_idx)
326
+
327
+ x = self.org_forward(x)
328
+ x[emb_idx :: self.network.num_sub_prompts] += lx
329
+
330
+ return x
331
+
332
+ def to_out_forward(self, x):
333
+ # print("to_out_forward", self.lora_name, x.size(), self.network.is_last_network)
334
+
335
+ if self.network.is_last_network:
336
+ masks = [None] * self.network.num_sub_prompts
337
+ self.network.shared[self.lora_name] = (None, masks)
338
+ else:
339
+ lx, masks = self.network.shared[self.lora_name]
340
+
341
+ # call own LoRA
342
+ x1 = x[self.network.batch_size + self.network.sub_prompt_index :: self.network.num_sub_prompts]
343
+ lx1 = self.lora_up(self.lora_down(x1)) * self.multiplier * self.scale
344
+
345
+ if self.network.is_last_network:
346
+ lx = torch.zeros(
347
+ (self.network.num_sub_prompts * self.network.batch_size, *lx1.size()[1:]), device=lx1.device, dtype=lx1.dtype
348
+ )
349
+ self.network.shared[self.lora_name] = (lx, masks)
350
+
351
+ # print("to_out_forward", lx.size(), lx1.size(), self.network.sub_prompt_index, self.network.num_sub_prompts)
352
+ lx[self.network.sub_prompt_index :: self.network.num_sub_prompts] += lx1
353
+ masks[self.network.sub_prompt_index] = self.get_mask_for_x(lx1)
354
+
355
+ # if not last network, return x and masks
356
+ x = self.org_forward(x)
357
+ if not self.network.is_last_network:
358
+ return x
359
+
360
+ lx, masks = self.network.shared.pop(self.lora_name)
361
+
362
+ # if last network, combine separated x with mask weighted sum
363
+ has_real_uncond = x.size()[0] // self.network.batch_size == self.network.num_sub_prompts + 2
364
+
365
+ out = torch.zeros((self.network.batch_size * (3 if has_real_uncond else 2), *x.size()[1:]), device=x.device, dtype=x.dtype)
366
+ out[: self.network.batch_size] = x[: self.network.batch_size] # uncond
367
+ if has_real_uncond:
368
+ out[-self.network.batch_size :] = x[-self.network.batch_size :] # real_uncond
369
+
370
+ # print("to_out_forward", self.lora_name, self.network.sub_prompt_index, self.network.num_sub_prompts)
371
+ # for i in range(len(masks)):
372
+ # if masks[i] is None:
373
+ # masks[i] = torch.zeros_like(masks[-1])
374
+
375
+ mask = torch.cat(masks)
376
+ mask_sum = torch.sum(mask, dim=0) + 1e-4
377
+ for i in range(self.network.batch_size):
378
+ # 1枚の画像ごとに処理する
379
+ lx1 = lx[i * self.network.num_sub_prompts : (i + 1) * self.network.num_sub_prompts]
380
+ lx1 = lx1 * mask
381
+ lx1 = torch.sum(lx1, dim=0)
382
+
383
+ xi = self.network.batch_size + i * self.network.num_sub_prompts
384
+ x1 = x[xi : xi + self.network.num_sub_prompts]
385
+ x1 = x1 * mask
386
+ x1 = torch.sum(x1, dim=0)
387
+ x1 = x1 / mask_sum
388
+
389
+ x1 = x1 + lx1
390
+ out[self.network.batch_size + i] = x1
391
+
392
+ # print("to_out_forward", x.size(), out.size(), has_real_uncond)
393
+ return out
394
+
395
+
396
+ def parse_block_lr_kwargs(nw_kwargs):
397
+ down_lr_weight = nw_kwargs.get("down_lr_weight", None)
398
+ mid_lr_weight = nw_kwargs.get("mid_lr_weight", None)
399
+ up_lr_weight = nw_kwargs.get("up_lr_weight", None)
400
+
401
+ # 以上のいずれにも設定がない場合は無効としてNoneを返す
402
+ if down_lr_weight is None and mid_lr_weight is None and up_lr_weight is None:
403
+ return None, None, None
404
+
405
+ # extract learning rate weight for each block
406
+ if down_lr_weight is not None:
407
+ # if some parameters are not set, use zero
408
+ if "," in down_lr_weight:
409
+ down_lr_weight = [(float(s) if s else 0.0) for s in down_lr_weight.split(",")]
410
+
411
+ if mid_lr_weight is not None:
412
+ mid_lr_weight = float(mid_lr_weight)
413
+
414
+ if up_lr_weight is not None:
415
+ if "," in up_lr_weight:
416
+ up_lr_weight = [(float(s) if s else 0.0) for s in up_lr_weight.split(",")]
417
+
418
+ down_lr_weight, mid_lr_weight, up_lr_weight = get_block_lr_weight(
419
+ down_lr_weight, mid_lr_weight, up_lr_weight, float(nw_kwargs.get("block_lr_zero_threshold", 0.0))
420
+ )
421
+
422
+ return down_lr_weight, mid_lr_weight, up_lr_weight
423
+
424
+
425
+ def create_network(
426
+ multiplier: float,
427
+ network_dim: Optional[int],
428
+ network_alpha: Optional[float],
429
+ vae: AutoencoderKL,
430
+ text_encoder: Union[CLIPTextModel, List[CLIPTextModel]],
431
+ unet,
432
+ neuron_dropout: Optional[float] = None,
433
+ **kwargs,
434
+ ):
435
+ if network_dim is None:
436
+ network_dim = 4 # default
437
+ if network_alpha is None:
438
+ network_alpha = 1.0
439
+
440
+ # extract dim/alpha for conv2d, and block dim
441
+ conv_dim = kwargs.get("conv_dim", None)
442
+ conv_alpha = kwargs.get("conv_alpha", None)
443
+ if conv_dim is not None:
444
+ conv_dim = int(conv_dim)
445
+ if conv_alpha is None:
446
+ conv_alpha = 1.0
447
+ else:
448
+ conv_alpha = float(conv_alpha)
449
+
450
+ # block dim/alpha/lr
451
+ block_dims = kwargs.get("block_dims", None)
452
+ down_lr_weight, mid_lr_weight, up_lr_weight = parse_block_lr_kwargs(kwargs)
453
+
454
+ # 以上のいずれかに指定があればblockごとのdim(rank)を有効にする
455
+ if block_dims is not None or down_lr_weight is not None or mid_lr_weight is not None or up_lr_weight is not None:
456
+ block_alphas = kwargs.get("block_alphas", None)
457
+ conv_block_dims = kwargs.get("conv_block_dims", None)
458
+ conv_block_alphas = kwargs.get("conv_block_alphas", None)
459
+
460
+ block_dims, block_alphas, conv_block_dims, conv_block_alphas = get_block_dims_and_alphas(
461
+ block_dims, block_alphas, network_dim, network_alpha, conv_block_dims, conv_block_alphas, conv_dim, conv_alpha
462
+ )
463
+
464
+ # remove block dim/alpha without learning rate
465
+ block_dims, block_alphas, conv_block_dims, conv_block_alphas = remove_block_dims_and_alphas(
466
+ block_dims, block_alphas, conv_block_dims, conv_block_alphas, down_lr_weight, mid_lr_weight, up_lr_weight
467
+ )
468
+
469
+ else:
470
+ block_alphas = None
471
+ conv_block_dims = None
472
+ conv_block_alphas = None
473
+
474
+ # rank/module dropout
475
+ rank_dropout = kwargs.get("rank_dropout", None)
476
+ if rank_dropout is not None:
477
+ rank_dropout = float(rank_dropout)
478
+ module_dropout = kwargs.get("module_dropout", None)
479
+ if module_dropout is not None:
480
+ module_dropout = float(module_dropout)
481
+
482
+ # すごく引数が多いな ( ^ω^)・・・
483
+ network = LoRANetwork(
484
+ text_encoder,
485
+ unet,
486
+ multiplier=multiplier,
487
+ lora_dim=network_dim,
488
+ alpha=network_alpha,
489
+ dropout=neuron_dropout,
490
+ rank_dropout=rank_dropout,
491
+ module_dropout=module_dropout,
492
+ conv_lora_dim=conv_dim,
493
+ conv_alpha=conv_alpha,
494
+ block_dims=block_dims,
495
+ block_alphas=block_alphas,
496
+ conv_block_dims=conv_block_dims,
497
+ conv_block_alphas=conv_block_alphas,
498
+ varbose=True,
499
+ )
500
+
501
+ if up_lr_weight is not None or mid_lr_weight is not None or down_lr_weight is not None:
502
+ network.set_block_lr_weight(up_lr_weight, mid_lr_weight, down_lr_weight)
503
+
504
+ return network
505
+
506
+
507
+ # このメソッドは外部から呼び出される可能性を考慮しておく
508
+ # network_dim, network_alpha にはデフォルト値が入っている。
509
+ # block_dims, block_alphas は両方ともNoneまたは両方とも値が入っている
510
+ # conv_dim, conv_alpha は両方ともNoneまたは両方とも値が入っている
511
+ def get_block_dims_and_alphas(
512
+ block_dims, block_alphas, network_dim, network_alpha, conv_block_dims, conv_block_alphas, conv_dim, conv_alpha
513
+ ):
514
+ num_total_blocks = LoRANetwork.NUM_OF_BLOCKS * 2 + 1
515
+
516
+ def parse_ints(s):
517
+ return [int(i) for i in s.split(",")]
518
+
519
+ def parse_floats(s):
520
+ return [float(i) for i in s.split(",")]
521
+
522
+ # block_dimsとblock_alphasをパースする。必ず値が入る
523
+ if block_dims is not None:
524
+ block_dims = parse_ints(block_dims)
525
+ assert (
526
+ len(block_dims) == num_total_blocks
527
+ ), f"block_dims must have {num_total_blocks} elements / block_dimsは{num_total_blocks}個指定してください"
528
+ else:
529
+ print(f"block_dims is not specified. all dims are set to {network_dim} / block_dimsが指定されていません。すべてのdimは{network_dim}になります")
530
+ block_dims = [network_dim] * num_total_blocks
531
+
532
+ if block_alphas is not None:
533
+ block_alphas = parse_floats(block_alphas)
534
+ assert (
535
+ len(block_alphas) == num_total_blocks
536
+ ), f"block_alphas must have {num_total_blocks} elements / block_alphasは{num_total_blocks}個指定してください"
537
+ else:
538
+ print(
539
+ f"block_alphas is not specified. all alphas are set to {network_alpha} / block_alphasが指定されていません。すべてのalphaは{network_alpha}になります"
540
+ )
541
+ block_alphas = [network_alpha] * num_total_blocks
542
+
543
+ # conv_block_dimsとconv_block_alphasを、指定がある場合のみパースする。指定がなければconv_dimとconv_alphaを使う
544
+ if conv_block_dims is not None:
545
+ conv_block_dims = parse_ints(conv_block_dims)
546
+ assert (
547
+ len(conv_block_dims) == num_total_blocks
548
+ ), f"conv_block_dims must have {num_total_blocks} elements / conv_block_dimsは{num_total_blocks}個指定してください"
549
+
550
+ if conv_block_alphas is not None:
551
+ conv_block_alphas = parse_floats(conv_block_alphas)
552
+ assert (
553
+ len(conv_block_alphas) == num_total_blocks
554
+ ), f"conv_block_alphas must have {num_total_blocks} elements / conv_block_alphasは{num_total_blocks}個指定してください"
555
+ else:
556
+ if conv_alpha is None:
557
+ conv_alpha = 1.0
558
+ print(
559
+ f"conv_block_alphas is not specified. all alphas are set to {conv_alpha} / conv_block_alphasが指定されていません。すべてのalphaは{conv_alpha}になります"
560
+ )
561
+ conv_block_alphas = [conv_alpha] * num_total_blocks
562
+ else:
563
+ if conv_dim is not None:
564
+ print(
565
+ f"conv_dim/alpha for all blocks are set to {conv_dim} and {conv_alpha} / すべてのブロックのconv_dimとalphaは{conv_dim}および{conv_alpha}になります"
566
+ )
567
+ conv_block_dims = [conv_dim] * num_total_blocks
568
+ conv_block_alphas = [conv_alpha] * num_total_blocks
569
+ else:
570
+ conv_block_dims = None
571
+ conv_block_alphas = None
572
+
573
+ return block_dims, block_alphas, conv_block_dims, conv_block_alphas
574
+
575
+
576
+ # 層別学習率用に層ごとの学習率に対する倍率を定義する、外部から呼び出される可能性を考慮しておく
577
+ def get_block_lr_weight(
578
+ down_lr_weight, mid_lr_weight, up_lr_weight, zero_threshold
579
+ ) -> Tuple[List[float], List[float], List[float]]:
580
+ # パラメータ未指定時は何もせず、今までと同じ動作とする
581
+ if up_lr_weight is None and mid_lr_weight is None and down_lr_weight is None:
582
+ return None, None, None
583
+
584
+ max_len = LoRANetwork.NUM_OF_BLOCKS # フルモデル相当でのup,downの層の数
585
+
586
+ def get_list(name_with_suffix) -> List[float]:
587
+ import math
588
+
589
+ tokens = name_with_suffix.split("+")
590
+ name = tokens[0]
591
+ base_lr = float(tokens[1]) if len(tokens) > 1 else 0.0
592
+
593
+ if name == "cosine":
594
+ return [math.sin(math.pi * (i / (max_len - 1)) / 2) + base_lr for i in reversed(range(max_len))]
595
+ elif name == "sine":
596
+ return [math.sin(math.pi * (i / (max_len - 1)) / 2) + base_lr for i in range(max_len)]
597
+ elif name == "linear":
598
+ return [i / (max_len - 1) + base_lr for i in range(max_len)]
599
+ elif name == "reverse_linear":
600
+ return [i / (max_len - 1) + base_lr for i in reversed(range(max_len))]
601
+ elif name == "zeros":
602
+ return [0.0 + base_lr] * max_len
603
+ else:
604
+ print(
605
+ "Unknown lr_weight argument %s is used. Valid arguments: / 不明なlr_weightの引数 %s が使われました。有効な引数:\n\tcosine, sine, linear, reverse_linear, zeros"
606
+ % (name)
607
+ )
608
+ return None
609
+
610
+ if type(down_lr_weight) == str:
611
+ down_lr_weight = get_list(down_lr_weight)
612
+ if type(up_lr_weight) == str:
613
+ up_lr_weight = get_list(up_lr_weight)
614
+
615
+ if (up_lr_weight != None and len(up_lr_weight) > max_len) or (down_lr_weight != None and len(down_lr_weight) > max_len):
616
+ print("down_weight or up_weight is too long. Parameters after %d-th are ignored." % max_len)
617
+ print("down_weightもしくはup_weightが長すぎます。%d個目以降のパラメータは無視されます。" % max_len)
618
+ up_lr_weight = up_lr_weight[:max_len]
619
+ down_lr_weight = down_lr_weight[:max_len]
620
+
621
+ if (up_lr_weight != None and len(up_lr_weight) < max_len) or (down_lr_weight != None and len(down_lr_weight) < max_len):
622
+ print("down_weight or up_weight is too short. Parameters after %d-th are filled with 1." % max_len)
623
+ print("down_weightもしくはup_weightが短すぎます。%d個目までの不足したパラメータは1で補われます。" % max_len)
624
+
625
+ if down_lr_weight != None and len(down_lr_weight) < max_len:
626
+ down_lr_weight = down_lr_weight + [1.0] * (max_len - len(down_lr_weight))
627
+ if up_lr_weight != None and len(up_lr_weight) < max_len:
628
+ up_lr_weight = up_lr_weight + [1.0] * (max_len - len(up_lr_weight))
629
+
630
+ if (up_lr_weight != None) or (mid_lr_weight != None) or (down_lr_weight != None):
631
+ print("apply block learning rate / 階層別学習率を適用します。")
632
+ if down_lr_weight != None:
633
+ down_lr_weight = [w if w > zero_threshold else 0 for w in down_lr_weight]
634
+ print("down_lr_weight (shallower -> deeper, 浅い層->深い層):", down_lr_weight)
635
+ else:
636
+ print("down_lr_weight: all 1.0, すべて1.0")
637
+
638
+ if mid_lr_weight != None:
639
+ mid_lr_weight = mid_lr_weight if mid_lr_weight > zero_threshold else 0
640
+ print("mid_lr_weight:", mid_lr_weight)
641
+ else:
642
+ print("mid_lr_weight: 1.0")
643
+
644
+ if up_lr_weight != None:
645
+ up_lr_weight = [w if w > zero_threshold else 0 for w in up_lr_weight]
646
+ print("up_lr_weight (deeper -> shallower, 深い層->浅い層):", up_lr_weight)
647
+ else:
648
+ print("up_lr_weight: all 1.0, すべて1.0")
649
+
650
+ return down_lr_weight, mid_lr_weight, up_lr_weight
651
+
652
+
653
+ # lr_weightが0のblockをblock_dimsから除外する、外部から呼び出す可能性を考慮しておく
654
+ def remove_block_dims_and_alphas(
655
+ block_dims, block_alphas, conv_block_dims, conv_block_alphas, down_lr_weight, mid_lr_weight, up_lr_weight
656
+ ):
657
+ # set 0 to block dim without learning rate to remove the block
658
+ if down_lr_weight != None:
659
+ for i, lr in enumerate(down_lr_weight):
660
+ if lr == 0:
661
+ block_dims[i] = 0
662
+ if conv_block_dims is not None:
663
+ conv_block_dims[i] = 0
664
+ if mid_lr_weight != None:
665
+ if mid_lr_weight == 0:
666
+ block_dims[LoRANetwork.NUM_OF_BLOCKS] = 0
667
+ if conv_block_dims is not None:
668
+ conv_block_dims[LoRANetwork.NUM_OF_BLOCKS] = 0
669
+ if up_lr_weight != None:
670
+ for i, lr in enumerate(up_lr_weight):
671
+ if lr == 0:
672
+ block_dims[LoRANetwork.NUM_OF_BLOCKS + 1 + i] = 0
673
+ if conv_block_dims is not None:
674
+ conv_block_dims[LoRANetwork.NUM_OF_BLOCKS + 1 + i] = 0
675
+
676
+ return block_dims, block_alphas, conv_block_dims, conv_block_alphas
677
+
678
+
679
+ # 外部から呼び出す可能性を考慮しておく
680
+ def get_block_index(lora_name: str) -> int:
681
+ block_idx = -1 # invalid lora name
682
+
683
+ m = RE_UPDOWN.search(lora_name)
684
+ if m:
685
+ g = m.groups()
686
+ i = int(g[1])
687
+ j = int(g[3])
688
+ if g[2] == "resnets":
689
+ idx = 3 * i + j
690
+ elif g[2] == "attentions":
691
+ idx = 3 * i + j
692
+ elif g[2] == "upsamplers" or g[2] == "downsamplers":
693
+ idx = 3 * i + 2
694
+
695
+ if g[0] == "down":
696
+ block_idx = 1 + idx # 0に該当するLoRAは存在しない
697
+ elif g[0] == "up":
698
+ block_idx = LoRANetwork.NUM_OF_BLOCKS + 1 + idx
699
+
700
+ elif "mid_block_" in lora_name:
701
+ block_idx = LoRANetwork.NUM_OF_BLOCKS # idx=12
702
+
703
+ return block_idx
704
+
705
+
706
+ # Create network from weights for inference, weights are not loaded here (because can be merged)
707
+ def create_network_from_weights(multiplier, file, vae, text_encoder, unet, weights_sd=None, for_inference=False, **kwargs):
708
+ if weights_sd is None:
709
+ if os.path.splitext(file)[1] == ".safetensors":
710
+ from safetensors.torch import load_file, safe_open
711
+
712
+ weights_sd = load_file(file)
713
+ else:
714
+ weights_sd = torch.load(file, map_location="cpu")
715
+
716
+ # get dim/alpha mapping
717
+ modules_dim = {}
718
+ modules_alpha = {}
719
+ for key, value in weights_sd.items():
720
+ if "." not in key:
721
+ continue
722
+
723
+ lora_name = key.split(".")[0]
724
+ if "alpha" in key:
725
+ modules_alpha[lora_name] = value
726
+ elif "lora_down" in key:
727
+ dim = value.size()[0]
728
+ modules_dim[lora_name] = dim
729
+ # print(lora_name, value.size(), dim)
730
+
731
+ # support old LoRA without alpha
732
+ for key in modules_dim.keys():
733
+ if key not in modules_alpha:
734
+ modules_alpha[key] = modules_dim[key]
735
+
736
+ module_class = LoRAInfModule if for_inference else LoRAModule
737
+
738
+ network = LoRANetwork(
739
+ text_encoder, unet, multiplier=multiplier, modules_dim=modules_dim, modules_alpha=modules_alpha, module_class=module_class
740
+ )
741
+
742
+ # block lr
743
+ down_lr_weight, mid_lr_weight, up_lr_weight = parse_block_lr_kwargs(kwargs)
744
+ if up_lr_weight is not None or mid_lr_weight is not None or down_lr_weight is not None:
745
+ network.set_block_lr_weight(up_lr_weight, mid_lr_weight, down_lr_weight)
746
+
747
+ return network, weights_sd
748
+
749
+
750
+ class LoRANetwork(torch.nn.Module):
751
+ NUM_OF_BLOCKS = 12 # フルモデル相当でのup,downの層の数
752
+
753
+ UNET_TARGET_REPLACE_MODULE = ["Transformer2DModel"]
754
+ UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 = ["ResnetBlock2D", "Downsample2D", "Upsample2D"]
755
+ TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPMLP"]
756
+ LORA_PREFIX_UNET = "lora_unet"
757
+ LORA_PREFIX_TEXT_ENCODER = "lora_te"
758
+
759
+ # SDXL: must starts with LORA_PREFIX_TEXT_ENCODER
760
+ LORA_PREFIX_TEXT_ENCODER1 = "lora_te1"
761
+ LORA_PREFIX_TEXT_ENCODER2 = "lora_te2"
762
+
763
+ def __init__(
764
+ self,
765
+ text_encoder: Union[List[CLIPTextModel], CLIPTextModel],
766
+ unet,
767
+ multiplier: float = 1.0,
768
+ lora_dim: int = 4,
769
+ alpha: float = 1,
770
+ dropout: Optional[float] = None,
771
+ rank_dropout: Optional[float] = None,
772
+ module_dropout: Optional[float] = None,
773
+ conv_lora_dim: Optional[int] = None,
774
+ conv_alpha: Optional[float] = None,
775
+ block_dims: Optional[List[int]] = None,
776
+ block_alphas: Optional[List[float]] = None,
777
+ conv_block_dims: Optional[List[int]] = None,
778
+ conv_block_alphas: Optional[List[float]] = None,
779
+ modules_dim: Optional[Dict[str, int]] = None,
780
+ modules_alpha: Optional[Dict[str, int]] = None,
781
+ module_class: Type[object] = LoRAModule,
782
+ varbose: Optional[bool] = False,
783
+ ) -> None:
784
+ """
785
+ LoRA network: すごく引数が多いが、パターンは以下の通り
786
+ 1. lora_dimとalphaを指定
787
+ 2. lora_dim、alpha、conv_lora_dim、conv_alphaを指定
788
+ 3. block_dimsとblock_alphasを指定 : Conv2d3x3には適用しない
789
+ 4. block_dims、block_alphas、conv_block_dims、conv_block_alphasを指定 : Conv2d3x3にも適用する
790
+ 5. modules_dimとmodules_alphaを指定 (推論用)
791
+ """
792
+ super().__init__()
793
+ self.multiplier = multiplier
794
+
795
+ self.lora_dim = lora_dim
796
+ self.alpha = alpha
797
+ self.conv_lora_dim = conv_lora_dim
798
+ self.conv_alpha = conv_alpha
799
+ self.dropout = dropout
800
+ self.rank_dropout = rank_dropout
801
+ self.module_dropout = module_dropout
802
+
803
+ if modules_dim is not None:
804
+ print(f"create LoRA network from weights")
805
+ elif block_dims is not None:
806
+ print(f"create LoRA network from block_dims")
807
+ print(f"neuron dropout: p={self.dropout}, rank dropout: p={self.rank_dropout}, module dropout: p={self.module_dropout}")
808
+ print(f"block_dims: {block_dims}")
809
+ print(f"block_alphas: {block_alphas}")
810
+ if conv_block_dims is not None:
811
+ print(f"conv_block_dims: {conv_block_dims}")
812
+ print(f"conv_block_alphas: {conv_block_alphas}")
813
+ else:
814
+ print(f"create LoRA network. base dim (rank): {lora_dim}, alpha: {alpha}")
815
+ print(f"neuron dropout: p={self.dropout}, rank dropout: p={self.rank_dropout}, module dropout: p={self.module_dropout}")
816
+ if self.conv_lora_dim is not None:
817
+ print(f"apply LoRA to Conv2d with kernel size (3,3). dim (rank): {self.conv_lora_dim}, alpha: {self.conv_alpha}")
818
+
819
+ # create module instances
820
+ def create_modules(
821
+ is_unet: bool,
822
+ text_encoder_idx: Optional[int], # None, 1, 2
823
+ root_module: torch.nn.Module,
824
+ target_replace_modules: List[torch.nn.Module],
825
+ ) -> List[LoRAModule]:
826
+ prefix = (
827
+ self.LORA_PREFIX_UNET
828
+ if is_unet
829
+ else (
830
+ self.LORA_PREFIX_TEXT_ENCODER
831
+ if text_encoder_idx is None
832
+ else (self.LORA_PREFIX_TEXT_ENCODER1 if text_encoder_idx == 1 else self.LORA_PREFIX_TEXT_ENCODER2)
833
+ )
834
+ )
835
+ loras = []
836
+ skipped = []
837
+ for name, module in root_module.named_modules():
838
+ if module.__class__.__name__ in target_replace_modules:
839
+ for child_name, child_module in module.named_modules():
840
+ is_linear = child_module.__class__.__name__ == "Linear"
841
+ is_conv2d = child_module.__class__.__name__ == "Conv2d"
842
+ is_conv2d_1x1 = is_conv2d and child_module.kernel_size == (1, 1)
843
+
844
+ if is_linear or is_conv2d:
845
+ lora_name = prefix + "." + name + "." + child_name
846
+ lora_name = lora_name.replace(".", "_")
847
+
848
+ dim = None
849
+ alpha = None
850
+
851
+ if modules_dim is not None:
852
+ # モジュール指定あり
853
+ if lora_name in modules_dim:
854
+ dim = modules_dim[lora_name]
855
+ alpha = modules_alpha[lora_name]
856
+ elif is_unet and block_dims is not None:
857
+ # U-Netでblock_dims指定あり
858
+ block_idx = get_block_index(lora_name)
859
+ if is_linear or is_conv2d_1x1:
860
+ dim = block_dims[block_idx]
861
+ alpha = block_alphas[block_idx]
862
+ elif conv_block_dims is not None:
863
+ dim = conv_block_dims[block_idx]
864
+ alpha = conv_block_alphas[block_idx]
865
+ else:
866
+ # 通常、すべて対象とする
867
+ if is_linear or is_conv2d_1x1:
868
+ dim = self.lora_dim
869
+ alpha = self.alpha
870
+ elif self.conv_lora_dim is not None:
871
+ dim = self.conv_lora_dim
872
+ alpha = self.conv_alpha
873
+
874
+ if dim is None or dim == 0:
875
+ # skipした情報を出力
876
+ if is_linear or is_conv2d_1x1 or (self.conv_lora_dim is not None or conv_block_dims is not None):
877
+ skipped.append(lora_name)
878
+ continue
879
+
880
+ lora = module_class(
881
+ lora_name,
882
+ child_module,
883
+ self.multiplier,
884
+ dim,
885
+ alpha,
886
+ dropout=dropout,
887
+ rank_dropout=rank_dropout,
888
+ module_dropout=module_dropout,
889
+ )
890
+ loras.append(lora)
891
+ return loras, skipped
892
+
893
+ text_encoders = text_encoder if type(text_encoder) == list else [text_encoder]
894
+
895
+ # create LoRA for text encoder
896
+ # 毎回すべてのモジュールを作るのは無駄なので要検討
897
+ self.text_encoder_loras = []
898
+ skipped_te = []
899
+ for i, text_encoder in enumerate(text_encoders):
900
+ if len(text_encoders) > 1:
901
+ index = i + 1
902
+ print(f"create LoRA for Text Encoder {index}:")
903
+ else:
904
+ index = None
905
+ print(f"create LoRA for Text Encoder:")
906
+
907
+ text_encoder_loras, skipped = create_modules(False, index, text_encoder, LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE)
908
+ self.text_encoder_loras.extend(text_encoder_loras)
909
+ skipped_te += skipped
910
+ print(f"create LoRA for Text Encoder: {len(self.text_encoder_loras)} modules.")
911
+
912
+ # extend U-Net target modules if conv2d 3x3 is enabled, or load from weights
913
+ target_modules = LoRANetwork.UNET_TARGET_REPLACE_MODULE
914
+ if modules_dim is not None or self.conv_lora_dim is not None or conv_block_dims is not None:
915
+ target_modules += LoRANetwork.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3
916
+
917
+ self.unet_loras, skipped_un = create_modules(True, None, unet, target_modules)
918
+ print(f"create LoRA for U-Net: {len(self.unet_loras)} modules.")
919
+
920
+ skipped = skipped_te + skipped_un
921
+ if varbose and len(skipped) > 0:
922
+ print(
923
+ f"because block_lr_weight is 0 or dim (rank) is 0, {len(skipped)} LoRA modules are skipped / block_lr_weightまたはdim (rank)が0の為、次の{len(skipped)}個のLoRAモジュールはスキップされます:"
924
+ )
925
+ for name in skipped:
926
+ print(f"\t{name}")
927
+
928
+ self.up_lr_weight: List[float] = None
929
+ self.down_lr_weight: List[float] = None
930
+ self.mid_lr_weight: float = None
931
+ self.block_lr = False
932
+
933
+ # assertion
934
+ names = set()
935
+ for lora in self.text_encoder_loras + self.unet_loras:
936
+ assert lora.lora_name not in names, f"duplicated lora name: {lora.lora_name}"
937
+ names.add(lora.lora_name)
938
+
939
+ def set_multiplier(self, multiplier):
940
+ self.multiplier = multiplier
941
+ for lora in self.text_encoder_loras + self.unet_loras:
942
+ lora.multiplier = self.multiplier
943
+
944
+ def load_weights(self, file):
945
+ if os.path.splitext(file)[1] == ".safetensors":
946
+ from safetensors.torch import load_file
947
+
948
+ weights_sd = load_file(file)
949
+ else:
950
+ weights_sd = torch.load(file, map_location="cpu")
951
+
952
+ info = self.load_state_dict(weights_sd, False)
953
+ return info
954
+
955
+ def apply_to(self, text_encoder, unet, apply_text_encoder=True, apply_unet=True):
956
+ if apply_text_encoder:
957
+ print("enable LoRA for text encoder")
958
+ else:
959
+ self.text_encoder_loras = []
960
+
961
+ if apply_unet:
962
+ print("enable LoRA for U-Net")
963
+ else:
964
+ self.unet_loras = []
965
+
966
+ for lora in self.text_encoder_loras + self.unet_loras:
967
+ lora.apply_to()
968
+ self.add_module(lora.lora_name, lora)
969
+
970
+ # マージできるかどうかを返す
971
+ def is_mergeable(self):
972
+ return True
973
+
974
+ # TODO refactor to common function with apply_to
975
+ def merge_to(self, text_encoder, unet, weights_sd, dtype, device):
976
+ apply_text_encoder = apply_unet = False
977
+ for key in weights_sd.keys():
978
+ if key.startswith(LoRANetwork.LORA_PREFIX_TEXT_ENCODER):
979
+ apply_text_encoder = True
980
+ elif key.startswith(LoRANetwork.LORA_PREFIX_UNET):
981
+ apply_unet = True
982
+
983
+ if apply_text_encoder:
984
+ print("enable LoRA for text encoder")
985
+ else:
986
+ self.text_encoder_loras = []
987
+
988
+ if apply_unet:
989
+ print("enable LoRA for U-Net")
990
+ else:
991
+ self.unet_loras = []
992
+
993
+ for lora in self.text_encoder_loras + self.unet_loras:
994
+ sd_for_lora = {}
995
+ for key in weights_sd.keys():
996
+ if key.startswith(lora.lora_name):
997
+ sd_for_lora[key[len(lora.lora_name) + 1 :]] = weights_sd[key]
998
+ lora.merge_to(sd_for_lora, dtype, device)
999
+
1000
+ print(f"weights are merged")
1001
+
1002
+ # 層別学習率用に層ごとの学習率に対する倍率を定義する 引数の順番が逆だがとりあえず気にしない
1003
+ def set_block_lr_weight(
1004
+ self,
1005
+ up_lr_weight: List[float] = None,
1006
+ mid_lr_weight: float = None,
1007
+ down_lr_weight: List[float] = None,
1008
+ ):
1009
+ self.block_lr = True
1010
+ self.down_lr_weight = down_lr_weight
1011
+ self.mid_lr_weight = mid_lr_weight
1012
+ self.up_lr_weight = up_lr_weight
1013
+
1014
+ def get_lr_weight(self, lora: LoRAModule) -> float:
1015
+ lr_weight = 1.0
1016
+ block_idx = get_block_index(lora.lora_name)
1017
+ if block_idx < 0:
1018
+ return lr_weight
1019
+
1020
+ if block_idx < LoRANetwork.NUM_OF_BLOCKS:
1021
+ if self.down_lr_weight != None:
1022
+ lr_weight = self.down_lr_weight[block_idx]
1023
+ elif block_idx == LoRANetwork.NUM_OF_BLOCKS:
1024
+ if self.mid_lr_weight != None:
1025
+ lr_weight = self.mid_lr_weight
1026
+ elif block_idx > LoRANetwork.NUM_OF_BLOCKS:
1027
+ if self.up_lr_weight != None:
1028
+ lr_weight = self.up_lr_weight[block_idx - LoRANetwork.NUM_OF_BLOCKS - 1]
1029
+
1030
+ return lr_weight
1031
+
1032
+ # 二つのText Encoderに別々の学習率を設定できるようにするといいかも
1033
+ def prepare_optimizer_params(self, text_encoder_lr, unet_lr, default_lr):
1034
+ self.requires_grad_(True)
1035
+ all_params = []
1036
+
1037
+ def enumerate_params(loras: List[LoRAModule]):
1038
+ params = []
1039
+ for lora in loras:
1040
+ # params.extend(lora.parameters())
1041
+ params.extend(lora.get_trainable_params())
1042
+ return params
1043
+
1044
+ if self.text_encoder_loras:
1045
+ param_data = {"params": enumerate_params(self.text_encoder_loras)}
1046
+ if text_encoder_lr is not None:
1047
+ param_data["lr"] = text_encoder_lr
1048
+ all_params.append(param_data)
1049
+
1050
+ if self.unet_loras:
1051
+ if self.block_lr:
1052
+ # 学習率のグラフをblockごとにしたいので、blockごとにloraを分類
1053
+ block_idx_to_lora = {}
1054
+ for lora in self.unet_loras:
1055
+ idx = get_block_index(lora.lora_name)
1056
+ if idx not in block_idx_to_lora:
1057
+ block_idx_to_lora[idx] = []
1058
+ block_idx_to_lora[idx].append(lora)
1059
+
1060
+ # blockごとにパラメータを設定する
1061
+ for idx, block_loras in block_idx_to_lora.items():
1062
+ param_data = {"params": enumerate_params(block_loras)}
1063
+
1064
+ if unet_lr is not None:
1065
+ param_data["lr"] = unet_lr * self.get_lr_weight(block_loras[0])
1066
+ elif default_lr is not None:
1067
+ param_data["lr"] = default_lr * self.get_lr_weight(block_loras[0])
1068
+ if ("lr" in param_data) and (param_data["lr"] == 0):
1069
+ continue
1070
+ all_params.append(param_data)
1071
+
1072
+ else:
1073
+ param_data = {"params": enumerate_params(self.unet_loras)}
1074
+ if unet_lr is not None:
1075
+ param_data["lr"] = unet_lr
1076
+ all_params.append(param_data)
1077
+
1078
+ return all_params
1079
+
1080
+ def enable_gradient_checkpointing(self):
1081
+ # not supported
1082
+ pass
1083
+
1084
+ def prepare_grad_etc(self, text_encoder, unet):
1085
+ self.requires_grad_(True)
1086
+
1087
+ def on_epoch_start(self, text_encoder, unet):
1088
+ self.train()
1089
+
1090
+ def get_trainable_params(self):
1091
+ return self.parameters()
1092
+
1093
+ def save_weights(self, file, dtype, metadata):
1094
+ if metadata is not None and len(metadata) == 0:
1095
+ metadata = None
1096
+
1097
+ state_dict = self.state_dict()
1098
+
1099
+ if dtype is not None:
1100
+ for key in list(state_dict.keys()):
1101
+ v = state_dict[key]
1102
+ v = v.detach().clone().to("cpu").to(dtype)
1103
+ state_dict[key] = v
1104
+
1105
+ if os.path.splitext(file)[1] == ".safetensors":
1106
+ from safetensors.torch import save_file
1107
+ from library import train_util
1108
+
1109
+ # Precalculate model hashes to save time on indexing
1110
+ if metadata is None:
1111
+ metadata = {}
1112
+ model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(state_dict, metadata)
1113
+ metadata["sshs_model_hash"] = model_hash
1114
+ metadata["sshs_legacy_hash"] = legacy_hash
1115
+
1116
+ save_file(state_dict, file, metadata)
1117
+ else:
1118
+ torch.save(state_dict, file)
1119
+
1120
+ # mask is a tensor with values from 0 to 1
1121
+ def set_region(self, sub_prompt_index, is_last_network, mask):
1122
+ if mask.max() == 0:
1123
+ mask = torch.ones_like(mask)
1124
+
1125
+ self.mask = mask
1126
+ self.sub_prompt_index = sub_prompt_index
1127
+ self.is_last_network = is_last_network
1128
+
1129
+ for lora in self.text_encoder_loras + self.unet_loras:
1130
+ lora.set_network(self)
1131
+
1132
+ def set_current_generation(self, batch_size, num_sub_prompts, width, height, shared):
1133
+ self.batch_size = batch_size
1134
+ self.num_sub_prompts = num_sub_prompts
1135
+ self.current_size = (height, width)
1136
+ self.shared = shared
1137
+
1138
+ # create masks
1139
+ mask = self.mask
1140
+ mask_dic = {}
1141
+ mask = mask.unsqueeze(0).unsqueeze(1) # b(1),c(1),h,w
1142
+ ref_weight = self.text_encoder_loras[0].lora_down.weight if self.text_encoder_loras else self.unet_loras[0].lora_down.weight
1143
+ dtype = ref_weight.dtype
1144
+ device = ref_weight.device
1145
+
1146
+ def resize_add(mh, mw):
1147
+ # print(mh, mw, mh * mw)
1148
+ m = torch.nn.functional.interpolate(mask, (mh, mw), mode="bilinear") # doesn't work in bf16
1149
+ m = m.to(device, dtype=dtype)
1150
+ mask_dic[mh * mw] = m
1151
+
1152
+ h = height // 8
1153
+ w = width // 8
1154
+ for _ in range(4):
1155
+ resize_add(h, w)
1156
+ if h % 2 == 1 or w % 2 == 1: # add extra shape if h/w is not divisible by 2
1157
+ resize_add(h + h % 2, w + w % 2)
1158
+ h = (h + 1) // 2
1159
+ w = (w + 1) // 2
1160
+
1161
+ self.mask_dic = mask_dic
1162
+
1163
+ def backup_weights(self):
1164
+ # 重みのバックアップを行う
1165
+ loras: List[LoRAInfModule] = self.text_encoder_loras + self.unet_loras
1166
+ for lora in loras:
1167
+ org_module = lora.org_module_ref[0]
1168
+ if not hasattr(org_module, "_lora_org_weight"):
1169
+ sd = org_module.state_dict()
1170
+ org_module._lora_org_weight = sd["weight"].detach().clone()
1171
+ org_module._lora_restored = True
1172
+
1173
+ def restore_weights(self):
1174
+ # 重みのリストアを行う
1175
+ loras: List[LoRAInfModule] = self.text_encoder_loras + self.unet_loras
1176
+ for lora in loras:
1177
+ org_module = lora.org_module_ref[0]
1178
+ if not org_module._lora_restored:
1179
+ sd = org_module.state_dict()
1180
+ sd["weight"] = org_module._lora_org_weight
1181
+ org_module.load_state_dict(sd)
1182
+ org_module._lora_restored = True
1183
+
1184
+ def pre_calculation(self):
1185
+ # 事前計算を行う
1186
+ loras: List[LoRAInfModule] = self.text_encoder_loras + self.unet_loras
1187
+ for lora in loras:
1188
+ org_module = lora.org_module_ref[0]
1189
+ sd = org_module.state_dict()
1190
+
1191
+ org_weight = sd["weight"]
1192
+ lora_weight = lora.get_weight().to(org_weight.device, dtype=org_weight.dtype)
1193
+ sd["weight"] = org_weight + lora_weight
1194
+ assert sd["weight"].shape == org_weight.shape
1195
+ org_module.load_state_dict(sd)
1196
+
1197
+ org_module._lora_restored = False
1198
+ lora.enabled = False
1199
+
1200
+ def apply_max_norm_regularization(self, max_norm_value, device):
1201
+ downkeys = []
1202
+ upkeys = []
1203
+ alphakeys = []
1204
+ norms = []
1205
+ keys_scaled = 0
1206
+
1207
+ state_dict = self.state_dict()
1208
+ for key in state_dict.keys():
1209
+ if "lora_down" in key and "weight" in key:
1210
+ downkeys.append(key)
1211
+ upkeys.append(key.replace("lora_down", "lora_up"))
1212
+ alphakeys.append(key.replace("lora_down.weight", "alpha"))
1213
+
1214
+ for i in range(len(downkeys)):
1215
+ down = state_dict[downkeys[i]].to(device)
1216
+ up = state_dict[upkeys[i]].to(device)
1217
+ alpha = state_dict[alphakeys[i]].to(device)
1218
+ dim = down.shape[0]
1219
+ scale = alpha / dim
1220
+
1221
+ if up.shape[2:] == (1, 1) and down.shape[2:] == (1, 1):
1222
+ updown = (up.squeeze(2).squeeze(2) @ down.squeeze(2).squeeze(2)).unsqueeze(2).unsqueeze(3)
1223
+ elif up.shape[2:] == (3, 3) or down.shape[2:] == (3, 3):
1224
+ updown = torch.nn.functional.conv2d(down.permute(1, 0, 2, 3), up).permute(1, 0, 2, 3)
1225
+ else:
1226
+ updown = up @ down
1227
+
1228
+ updown *= scale
1229
+
1230
+ norm = updown.norm().clamp(min=max_norm_value / 2)
1231
+ desired = torch.clamp(norm, max=max_norm_value)
1232
+ ratio = desired.cpu() / norm.cpu()
1233
+ sqrt_ratio = ratio**0.5
1234
+ if ratio != 1:
1235
+ keys_scaled += 1
1236
+ state_dict[upkeys[i]] *= sqrt_ratio
1237
+ state_dict[downkeys[i]] *= sqrt_ratio
1238
+ scalednorm = updown.norm() * ratio
1239
+ norms.append(scalednorm.item())
1240
+
1241
+ return keys_scaled, sum(norms) / len(norms), max(norms)
external/llite/networks/lora_interrogator.py ADDED
@@ -0,0 +1,139 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+
3
+ from tqdm import tqdm
4
+ from library import model_util
5
+ import library.train_util as train_util
6
+ import argparse
7
+ from transformers import CLIPTokenizer
8
+ import torch
9
+
10
+ import library.model_util as model_util
11
+ import lora
12
+
13
+ TOKENIZER_PATH = "openai/clip-vit-large-patch14"
14
+ V2_STABLE_DIFFUSION_PATH = "stabilityai/stable-diffusion-2" # ここからtokenizerだけ使う
15
+
16
+ DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
17
+
18
+
19
+ def interrogate(args):
20
+ weights_dtype = torch.float16
21
+
22
+ # いろいろ準備する
23
+ print(f"loading SD model: {args.sd_model}")
24
+ args.pretrained_model_name_or_path = args.sd_model
25
+ args.vae = None
26
+ text_encoder, vae, unet, _ = train_util._load_target_model(args,weights_dtype, DEVICE)
27
+
28
+ print(f"loading LoRA: {args.model}")
29
+ network, weights_sd = lora.create_network_from_weights(1.0, args.model, vae, text_encoder, unet)
30
+
31
+ # text encoder向けの重みがあるかチェックする:本当はlora側でやるのがいい
32
+ has_te_weight = False
33
+ for key in weights_sd.keys():
34
+ if 'lora_te' in key:
35
+ has_te_weight = True
36
+ break
37
+ if not has_te_weight:
38
+ print("This LoRA does not have modules for Text Encoder, cannot interrogate / このLoRAはText Encoder向けのモジュールがないため調査できません")
39
+ return
40
+ del vae
41
+
42
+ print("loading tokenizer")
43
+ if args.v2:
44
+ tokenizer: CLIPTokenizer = CLIPTokenizer.from_pretrained(V2_STABLE_DIFFUSION_PATH, subfolder="tokenizer")
45
+ else:
46
+ tokenizer: CLIPTokenizer = CLIPTokenizer.from_pretrained(TOKENIZER_PATH) # , model_max_length=max_token_length + 2)
47
+
48
+ text_encoder.to(DEVICE, dtype=weights_dtype)
49
+ text_encoder.eval()
50
+ unet.to(DEVICE, dtype=weights_dtype)
51
+ unet.eval() # U-Netは呼び出さないので不要だけど
52
+
53
+ # トークンをひとつひとつ当たっていく
54
+ token_id_start = 0
55
+ token_id_end = max(tokenizer.all_special_ids)
56
+ print(f"interrogate tokens are: {token_id_start} to {token_id_end}")
57
+
58
+ def get_all_embeddings(text_encoder):
59
+ embs = []
60
+ with torch.no_grad():
61
+ for token_id in tqdm(range(token_id_start, token_id_end + 1, args.batch_size)):
62
+ batch = []
63
+ for tid in range(token_id, min(token_id_end + 1, token_id + args.batch_size)):
64
+ tokens = [tokenizer.bos_token_id, tid, tokenizer.eos_token_id]
65
+ # tokens = [tid] # こちらは結果がいまひとつ
66
+ batch.append(tokens)
67
+
68
+ # batch_embs = text_encoder(torch.tensor(batch).to(DEVICE))[0].to("cpu") # bos/eosも含めたほうが差が出るようだ [:, 1]
69
+ # clip skip対応
70
+ batch = torch.tensor(batch).to(DEVICE)
71
+ if args.clip_skip is None:
72
+ encoder_hidden_states = text_encoder(batch)[0]
73
+ else:
74
+ enc_out = text_encoder(batch, output_hidden_states=True, return_dict=True)
75
+ encoder_hidden_states = enc_out['hidden_states'][-args.clip_skip]
76
+ encoder_hidden_states = text_encoder.text_model.final_layer_norm(encoder_hidden_states)
77
+ encoder_hidden_states = encoder_hidden_states.to("cpu")
78
+
79
+ embs.extend(encoder_hidden_states)
80
+ return torch.stack(embs)
81
+
82
+ print("get original text encoder embeddings.")
83
+ orig_embs = get_all_embeddings(text_encoder)
84
+
85
+ network.apply_to(text_encoder, unet, True, len(network.unet_loras) > 0)
86
+ info = network.load_state_dict(weights_sd, strict=False)
87
+ print(f"Loading LoRA weights: {info}")
88
+
89
+ network.to(DEVICE, dtype=weights_dtype)
90
+ network.eval()
91
+
92
+ del unet
93
+
94
+ print("You can ignore warning messages start with '_IncompatibleKeys' (LoRA model does not have alpha because trained by older script) / '_IncompatibleKeys'の警告は無視して構いません(以前のスクリプトで学習されたLoRAモデルのためalphaの定義がありません)")
95
+ print("get text encoder embeddings with lora.")
96
+ lora_embs = get_all_embeddings(text_encoder)
97
+
98
+ # 比べる:とりあえず単純に差分の絶対値で
99
+ print("comparing...")
100
+ diffs = {}
101
+ for i, (orig_emb, lora_emb) in enumerate(zip(orig_embs, tqdm(lora_embs))):
102
+ diff = torch.mean(torch.abs(orig_emb - lora_emb))
103
+ # diff = torch.mean(torch.cosine_similarity(orig_emb, lora_emb, dim=1)) # うまく検出できない
104
+ diff = float(diff.detach().to('cpu').numpy())
105
+ diffs[token_id_start + i] = diff
106
+
107
+ diffs_sorted = sorted(diffs.items(), key=lambda x: -x[1])
108
+
109
+ # 結果を表示する
110
+ print("top 100:")
111
+ for i, (token, diff) in enumerate(diffs_sorted[:100]):
112
+ # if diff < 1e-6:
113
+ # break
114
+ string = tokenizer.convert_tokens_to_string(tokenizer.convert_ids_to_tokens([token]))
115
+ print(f"[{i:3d}]: {token:5d} {string:<20s}: {diff:.5f}")
116
+
117
+
118
+ def setup_parser() -> argparse.ArgumentParser:
119
+ parser = argparse.ArgumentParser()
120
+
121
+ parser.add_argument("--v2", action='store_true',
122
+ help='load Stable Diffusion v2.x model / Stable Diffusion 2.xのモデルを読み込む')
123
+ parser.add_argument("--sd_model", type=str, default=None,
124
+ help="Stable Diffusion model to load: ckpt or safetensors file / 読み込むSDのモデル、ckptまたはsafetensors")
125
+ parser.add_argument("--model", type=str, default=None,
126
+ help="LoRA model to interrogate: ckpt or safetensors file / 調査するLoRAモデル、ckptまたはsafetensors")
127
+ parser.add_argument("--batch_size", type=int, default=16,
128
+ help="batch size for processing with Text Encoder / Text Encoderで処理するときのバッチサイズ")
129
+ parser.add_argument("--clip_skip", type=int, default=None,
130
+ help="use output of nth layer from back of text encoder (n>=1) / text encoderの後ろからn番目の層の出力を用いる(nは1以上)")
131
+
132
+ return parser
133
+
134
+
135
+ if __name__ == '__main__':
136
+ parser = setup_parser()
137
+
138
+ args = parser.parse_args()
139
+ interrogate(args)
external/llite/networks/merge_lora.py ADDED
@@ -0,0 +1,357 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import argparse
3
+ import os
4
+ import time
5
+ import torch
6
+ from safetensors.torch import load_file, save_file
7
+ from library import sai_model_spec, train_util
8
+ import library.model_util as model_util
9
+ import lora
10
+
11
+
12
+ def load_state_dict(file_name, dtype):
13
+ if os.path.splitext(file_name)[1] == ".safetensors":
14
+ sd = load_file(file_name)
15
+ metadata = train_util.load_metadata_from_safetensors(file_name)
16
+ else:
17
+ sd = torch.load(file_name, map_location="cpu")
18
+ metadata = {}
19
+
20
+ for key in list(sd.keys()):
21
+ if type(sd[key]) == torch.Tensor:
22
+ sd[key] = sd[key].to(dtype)
23
+
24
+ return sd, metadata
25
+
26
+
27
+ def save_to_file(file_name, model, state_dict, dtype, metadata):
28
+ if dtype is not None:
29
+ for key in list(state_dict.keys()):
30
+ if type(state_dict[key]) == torch.Tensor:
31
+ state_dict[key] = state_dict[key].to(dtype)
32
+
33
+ if os.path.splitext(file_name)[1] == ".safetensors":
34
+ save_file(model, file_name, metadata=metadata)
35
+ else:
36
+ torch.save(model, file_name)
37
+
38
+
39
+ def merge_to_sd_model(text_encoder, unet, models, ratios, merge_dtype):
40
+ text_encoder.to(merge_dtype)
41
+ unet.to(merge_dtype)
42
+
43
+ # create module map
44
+ name_to_module = {}
45
+ for i, root_module in enumerate([text_encoder, unet]):
46
+ if i == 0:
47
+ prefix = lora.LoRANetwork.LORA_PREFIX_TEXT_ENCODER
48
+ target_replace_modules = lora.LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE
49
+ else:
50
+ prefix = lora.LoRANetwork.LORA_PREFIX_UNET
51
+ target_replace_modules = (
52
+ lora.LoRANetwork.UNET_TARGET_REPLACE_MODULE + lora.LoRANetwork.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3
53
+ )
54
+
55
+ for name, module in root_module.named_modules():
56
+ if module.__class__.__name__ in target_replace_modules:
57
+ for child_name, child_module in module.named_modules():
58
+ if child_module.__class__.__name__ == "Linear" or child_module.__class__.__name__ == "Conv2d":
59
+ lora_name = prefix + "." + name + "." + child_name
60
+ lora_name = lora_name.replace(".", "_")
61
+ name_to_module[lora_name] = child_module
62
+
63
+ for model, ratio in zip(models, ratios):
64
+ print(f"loading: {model}")
65
+ lora_sd, _ = load_state_dict(model, merge_dtype)
66
+
67
+ print(f"merging...")
68
+ for key in lora_sd.keys():
69
+ if "lora_down" in key:
70
+ up_key = key.replace("lora_down", "lora_up")
71
+ alpha_key = key[: key.index("lora_down")] + "alpha"
72
+
73
+ # find original module for this lora
74
+ module_name = ".".join(key.split(".")[:-2]) # remove trailing ".lora_down.weight"
75
+ if module_name not in name_to_module:
76
+ print(f"no module found for LoRA weight: {key}")
77
+ continue
78
+ module = name_to_module[module_name]
79
+ # print(f"apply {key} to {module}")
80
+
81
+ down_weight = lora_sd[key]
82
+ up_weight = lora_sd[up_key]
83
+
84
+ dim = down_weight.size()[0]
85
+ alpha = lora_sd.get(alpha_key, dim)
86
+ scale = alpha / dim
87
+
88
+ # W <- W + U * D
89
+ weight = module.weight
90
+ if len(weight.size()) == 2:
91
+ # linear
92
+ if len(up_weight.size()) == 4: # use linear projection mismatch
93
+ up_weight = up_weight.squeeze(3).squeeze(2)
94
+ down_weight = down_weight.squeeze(3).squeeze(2)
95
+ weight = weight + ratio * (up_weight @ down_weight) * scale
96
+ elif down_weight.size()[2:4] == (1, 1):
97
+ # conv2d 1x1
98
+ weight = (
99
+ weight
100
+ + ratio
101
+ * (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3)
102
+ * scale
103
+ )
104
+ else:
105
+ # conv2d 3x3
106
+ conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3)
107
+ # print(conved.size(), weight.size(), module.stride, module.padding)
108
+ weight = weight + ratio * conved * scale
109
+
110
+ module.weight = torch.nn.Parameter(weight)
111
+
112
+
113
+ def merge_lora_models(models, ratios, merge_dtype, concat=False, shuffle=False):
114
+ base_alphas = {} # alpha for merged model
115
+ base_dims = {}
116
+
117
+ merged_sd = {}
118
+ v2 = None
119
+ base_model = None
120
+ for model, ratio in zip(models, ratios):
121
+ print(f"loading: {model}")
122
+ lora_sd, lora_metadata = load_state_dict(model, merge_dtype)
123
+
124
+ if lora_metadata is not None:
125
+ if v2 is None:
126
+ v2 = lora_metadata.get(train_util.SS_METADATA_KEY_V2, None) # return string
127
+ if base_model is None:
128
+ base_model = lora_metadata.get(train_util.SS_METADATA_KEY_BASE_MODEL_VERSION, None)
129
+
130
+ # get alpha and dim
131
+ alphas = {} # alpha for current model
132
+ dims = {} # dims for current model
133
+ for key in lora_sd.keys():
134
+ if "alpha" in key:
135
+ lora_module_name = key[: key.rfind(".alpha")]
136
+ alpha = float(lora_sd[key].detach().numpy())
137
+ alphas[lora_module_name] = alpha
138
+ if lora_module_name not in base_alphas:
139
+ base_alphas[lora_module_name] = alpha
140
+ elif "lora_down" in key:
141
+ lora_module_name = key[: key.rfind(".lora_down")]
142
+ dim = lora_sd[key].size()[0]
143
+ dims[lora_module_name] = dim
144
+ if lora_module_name not in base_dims:
145
+ base_dims[lora_module_name] = dim
146
+
147
+ for lora_module_name in dims.keys():
148
+ if lora_module_name not in alphas:
149
+ alpha = dims[lora_module_name]
150
+ alphas[lora_module_name] = alpha
151
+ if lora_module_name not in base_alphas:
152
+ base_alphas[lora_module_name] = alpha
153
+
154
+ print(f"dim: {list(set(dims.values()))}, alpha: {list(set(alphas.values()))}")
155
+
156
+ # merge
157
+ print(f"merging...")
158
+ for key in lora_sd.keys():
159
+ if "alpha" in key:
160
+ continue
161
+ if "lora_up" in key and concat:
162
+ concat_dim = 1
163
+ elif "lora_down" in key and concat:
164
+ concat_dim = 0
165
+ else:
166
+ concat_dim = None
167
+
168
+ lora_module_name = key[: key.rfind(".lora_")]
169
+
170
+ base_alpha = base_alphas[lora_module_name]
171
+ alpha = alphas[lora_module_name]
172
+
173
+ scale = math.sqrt(alpha / base_alpha) * ratio
174
+ scale = abs(scale) if "lora_up" in key else scale # マイナスの重みに対応する。
175
+
176
+ if key in merged_sd:
177
+ assert (
178
+ merged_sd[key].size() == lora_sd[key].size() or concat_dim is not None
179
+ ), f"weights shape mismatch merging v1 and v2, different dims? / 重みのサイズが合いません。v1とv2、または次元数の異なるモデルはマージできません"
180
+ if concat_dim is not None:
181
+ merged_sd[key] = torch.cat([merged_sd[key], lora_sd[key] * scale], dim=concat_dim)
182
+ else:
183
+ merged_sd[key] = merged_sd[key] + lora_sd[key] * scale
184
+ else:
185
+ merged_sd[key] = lora_sd[key] * scale
186
+
187
+ # set alpha to sd
188
+ for lora_module_name, alpha in base_alphas.items():
189
+ key = lora_module_name + ".alpha"
190
+ merged_sd[key] = torch.tensor(alpha)
191
+ if shuffle:
192
+ key_down = lora_module_name + ".lora_down.weight"
193
+ key_up = lora_module_name + ".lora_up.weight"
194
+ dim = merged_sd[key_down].shape[0]
195
+ perm = torch.randperm(dim)
196
+ merged_sd[key_down] = merged_sd[key_down][perm]
197
+ merged_sd[key_up] = merged_sd[key_up][:,perm]
198
+
199
+ print("merged model")
200
+ print(f"dim: {list(set(base_dims.values()))}, alpha: {list(set(base_alphas.values()))}")
201
+
202
+ # check all dims are same
203
+ dims_list = list(set(base_dims.values()))
204
+ alphas_list = list(set(base_alphas.values()))
205
+ all_same_dims = True
206
+ all_same_alphas = True
207
+ for dims in dims_list:
208
+ if dims != dims_list[0]:
209
+ all_same_dims = False
210
+ break
211
+ for alphas in alphas_list:
212
+ if alphas != alphas_list[0]:
213
+ all_same_alphas = False
214
+ break
215
+
216
+ # build minimum metadata
217
+ dims = f"{dims_list[0]}" if all_same_dims else "Dynamic"
218
+ alphas = f"{alphas_list[0]}" if all_same_alphas else "Dynamic"
219
+ metadata = train_util.build_minimum_network_metadata(v2, base_model, "networks.lora", dims, alphas, None)
220
+
221
+ return merged_sd, metadata, v2 == "True"
222
+
223
+
224
+ def merge(args):
225
+ assert len(args.models) == len(args.ratios), f"number of models must be equal to number of ratios / モデルの数と重みの数は合わせてください"
226
+
227
+ def str_to_dtype(p):
228
+ if p == "float":
229
+ return torch.float
230
+ if p == "fp16":
231
+ return torch.float16
232
+ if p == "bf16":
233
+ return torch.bfloat16
234
+ return None
235
+
236
+ merge_dtype = str_to_dtype(args.precision)
237
+ save_dtype = str_to_dtype(args.save_precision)
238
+ if save_dtype is None:
239
+ save_dtype = merge_dtype
240
+
241
+ if args.sd_model is not None:
242
+ print(f"loading SD model: {args.sd_model}")
243
+
244
+ text_encoder, vae, unet = model_util.load_models_from_stable_diffusion_checkpoint(args.v2, args.sd_model)
245
+
246
+ merge_to_sd_model(text_encoder, unet, args.models, args.ratios, merge_dtype)
247
+
248
+ if args.no_metadata:
249
+ sai_metadata = None
250
+ else:
251
+ merged_from = sai_model_spec.build_merged_from([args.sd_model] + args.models)
252
+ title = os.path.splitext(os.path.basename(args.save_to))[0]
253
+ sai_metadata = sai_model_spec.build_metadata(
254
+ None,
255
+ args.v2,
256
+ args.v2,
257
+ False,
258
+ False,
259
+ False,
260
+ time.time(),
261
+ title=title,
262
+ merged_from=merged_from,
263
+ is_stable_diffusion_ckpt=True,
264
+ )
265
+ if args.v2:
266
+ # TODO read sai modelspec
267
+ print(
268
+ "Cannot determine if model is for v-prediction, so save metadata as v-prediction / modelがv-prediction用か否か不明なため、仮にv-prediction用としてmetadataを保存します"
269
+ )
270
+
271
+ print(f"saving SD model to: {args.save_to}")
272
+ model_util.save_stable_diffusion_checkpoint(
273
+ args.v2, args.save_to, text_encoder, unet, args.sd_model, 0, 0, sai_metadata, save_dtype, vae
274
+ )
275
+ else:
276
+ state_dict, metadata, v2 = merge_lora_models(args.models, args.ratios, merge_dtype, args.concat, args.shuffle)
277
+
278
+ print(f"calculating hashes and creating metadata...")
279
+
280
+ model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(state_dict, metadata)
281
+ metadata["sshs_model_hash"] = model_hash
282
+ metadata["sshs_legacy_hash"] = legacy_hash
283
+
284
+ if not args.no_metadata:
285
+ merged_from = sai_model_spec.build_merged_from(args.models)
286
+ title = os.path.splitext(os.path.basename(args.save_to))[0]
287
+ sai_metadata = sai_model_spec.build_metadata(
288
+ state_dict, v2, v2, False, True, False, time.time(), title=title, merged_from=merged_from
289
+ )
290
+ if v2:
291
+ # TODO read sai modelspec
292
+ print(
293
+ "Cannot determine if LoRA is for v-prediction, so save metadata as v-prediction / LoRAがv-prediction用か否か不明なため、仮にv-prediction用としてmetadataを保存します"
294
+ )
295
+ metadata.update(sai_metadata)
296
+
297
+ print(f"saving model to: {args.save_to}")
298
+ save_to_file(args.save_to, state_dict, state_dict, save_dtype, metadata)
299
+
300
+
301
+ def setup_parser() -> argparse.ArgumentParser:
302
+ parser = argparse.ArgumentParser()
303
+ parser.add_argument("--v2", action="store_true", help="load Stable Diffusion v2.x model / Stable Diffusion 2.xのモデルを読み込む")
304
+ parser.add_argument(
305
+ "--save_precision",
306
+ type=str,
307
+ default=None,
308
+ choices=[None, "float", "fp16", "bf16"],
309
+ help="precision in saving, same to merging if omitted / 保存時に精度を変更して保存する、省略時はマージ時の精度と同じ",
310
+ )
311
+ parser.add_argument(
312
+ "--precision",
313
+ type=str,
314
+ default="float",
315
+ choices=["float", "fp16", "bf16"],
316
+ help="precision in merging (float is recommended) / マージの計算時の精度(floatを推奨)",
317
+ )
318
+ parser.add_argument(
319
+ "--sd_model",
320
+ type=str,
321
+ default=None,
322
+ help="Stable Diffusion model to load: ckpt or safetensors file, merge LoRA models if omitted / 読み込むモデル、ckptまたはsafetensors。省略時はLoRAモデル同士をマージする",
323
+ )
324
+ parser.add_argument(
325
+ "--save_to", type=str, default=None, help="destination file name: ckpt or safetensors file / 保存先のファイル名、ckptまたはsafetensors"
326
+ )
327
+ parser.add_argument(
328
+ "--models", type=str, nargs="*", help="LoRA models to merge: ckpt or safetensors file / マージするLoRAモデル、ckptまたはsafetensors"
329
+ )
330
+ parser.add_argument("--ratios", type=float, nargs="*", help="ratios for each model / それぞれのLoRAモデルの比率")
331
+ parser.add_argument(
332
+ "--no_metadata",
333
+ action="store_true",
334
+ help="do not save sai modelspec metadata (minimum ss_metadata for LoRA is saved) / "
335
+ + "sai modelspecのメタデータを保存しない(LoRAの最低限のss_metadataは保存される)",
336
+ )
337
+ parser.add_argument(
338
+ "--concat",
339
+ action="store_true",
340
+ help="concat lora instead of merge (The dim(rank) of the output LoRA is the sum of the input dims) / "
341
+ + "マージの代わりに結合する(LoRAのdim(rank)は入力dimの合計になる)",
342
+ )
343
+ parser.add_argument(
344
+ "--shuffle",
345
+ action="store_true",
346
+ help="shuffle lora weight./ "
347
+ + "LoRAの重みをシャッフルする",
348
+ )
349
+
350
+ return parser
351
+
352
+
353
+ if __name__ == "__main__":
354
+ parser = setup_parser()
355
+
356
+ args = parser.parse_args()
357
+ merge(args)
external/llite/networks/merge_lora_old.py ADDED
@@ -0,0 +1,185 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+
3
+ import argparse
4
+ import os
5
+ import torch
6
+ from safetensors.torch import load_file, save_file
7
+ import library.model_util as model_util
8
+ import lora
9
+
10
+
11
+ def load_state_dict(file_name, dtype):
12
+ if os.path.splitext(file_name)[1] == '.safetensors':
13
+ sd = load_file(file_name)
14
+ else:
15
+ sd = torch.load(file_name, map_location='cpu')
16
+ for key in list(sd.keys()):
17
+ if type(sd[key]) == torch.Tensor:
18
+ sd[key] = sd[key].to(dtype)
19
+ return sd
20
+
21
+
22
+ def save_to_file(file_name, model, state_dict, dtype):
23
+ if dtype is not None:
24
+ for key in list(state_dict.keys()):
25
+ if type(state_dict[key]) == torch.Tensor:
26
+ state_dict[key] = state_dict[key].to(dtype)
27
+
28
+ if os.path.splitext(file_name)[1] == '.safetensors':
29
+ save_file(model, file_name)
30
+ else:
31
+ torch.save(model, file_name)
32
+
33
+
34
+ def merge_to_sd_model(text_encoder, unet, models, ratios, merge_dtype):
35
+ text_encoder.to(merge_dtype)
36
+ unet.to(merge_dtype)
37
+
38
+ # create module map
39
+ name_to_module = {}
40
+ for i, root_module in enumerate([text_encoder, unet]):
41
+ if i == 0:
42
+ prefix = lora.LoRANetwork.LORA_PREFIX_TEXT_ENCODER
43
+ target_replace_modules = lora.LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE
44
+ else:
45
+ prefix = lora.LoRANetwork.LORA_PREFIX_UNET
46
+ target_replace_modules = lora.LoRANetwork.UNET_TARGET_REPLACE_MODULE
47
+
48
+ for name, module in root_module.named_modules():
49
+ if module.__class__.__name__ in target_replace_modules:
50
+ for child_name, child_module in module.named_modules():
51
+ if child_module.__class__.__name__ == "Linear" or (child_module.__class__.__name__ == "Conv2d" and child_module.kernel_size == (1, 1)):
52
+ lora_name = prefix + '.' + name + '.' + child_name
53
+ lora_name = lora_name.replace('.', '_')
54
+ name_to_module[lora_name] = child_module
55
+
56
+ for model, ratio in zip(models, ratios):
57
+ print(f"loading: {model}")
58
+ lora_sd = load_state_dict(model, merge_dtype)
59
+
60
+ print(f"merging...")
61
+ for key in lora_sd.keys():
62
+ if "lora_down" in key:
63
+ up_key = key.replace("lora_down", "lora_up")
64
+ alpha_key = key[:key.index("lora_down")] + 'alpha'
65
+
66
+ # find original module for this lora
67
+ module_name = '.'.join(key.split('.')[:-2]) # remove trailing ".lora_down.weight"
68
+ if module_name not in name_to_module:
69
+ print(f"no module found for LoRA weight: {key}")
70
+ continue
71
+ module = name_to_module[module_name]
72
+ # print(f"apply {key} to {module}")
73
+
74
+ down_weight = lora_sd[key]
75
+ up_weight = lora_sd[up_key]
76
+
77
+ dim = down_weight.size()[0]
78
+ alpha = lora_sd.get(alpha_key, dim)
79
+ scale = alpha / dim
80
+
81
+ # W <- W + U * D
82
+ weight = module.weight
83
+ if len(weight.size()) == 2:
84
+ # linear
85
+ weight = weight + ratio * (up_weight @ down_weight) * scale
86
+ else:
87
+ # conv2d
88
+ weight = weight + ratio * (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3) * scale
89
+
90
+ module.weight = torch.nn.Parameter(weight)
91
+
92
+
93
+ def merge_lora_models(models, ratios, merge_dtype):
94
+ merged_sd = {}
95
+
96
+ alpha = None
97
+ dim = None
98
+ for model, ratio in zip(models, ratios):
99
+ print(f"loading: {model}")
100
+ lora_sd = load_state_dict(model, merge_dtype)
101
+
102
+ print(f"merging...")
103
+ for key in lora_sd.keys():
104
+ if 'alpha' in key:
105
+ if key in merged_sd:
106
+ assert merged_sd[key] == lora_sd[key], f"alpha mismatch / alphaが異なる場合、現時点ではマージできません"
107
+ else:
108
+ alpha = lora_sd[key].detach().numpy()
109
+ merged_sd[key] = lora_sd[key]
110
+ else:
111
+ if key in merged_sd:
112
+ assert merged_sd[key].size() == lora_sd[key].size(
113
+ ), f"weights shape mismatch merging v1 and v2, different dims? / 重みのサイズが合いません。v1とv2、または次元数の異なるモデルはマージできません"
114
+ merged_sd[key] = merged_sd[key] + lora_sd[key] * ratio
115
+ else:
116
+ if "lora_down" in key:
117
+ dim = lora_sd[key].size()[0]
118
+ merged_sd[key] = lora_sd[key] * ratio
119
+
120
+ print(f"dim (rank): {dim}, alpha: {alpha}")
121
+ if alpha is None:
122
+ alpha = dim
123
+
124
+ return merged_sd, dim, alpha
125
+
126
+
127
+ def merge(args):
128
+ assert len(args.models) == len(args.ratios), f"number of models must be equal to number of ratios / モデルの数と重みの数は合わせてください"
129
+
130
+ def str_to_dtype(p):
131
+ if p == 'float':
132
+ return torch.float
133
+ if p == 'fp16':
134
+ return torch.float16
135
+ if p == 'bf16':
136
+ return torch.bfloat16
137
+ return None
138
+
139
+ merge_dtype = str_to_dtype(args.precision)
140
+ save_dtype = str_to_dtype(args.save_precision)
141
+ if save_dtype is None:
142
+ save_dtype = merge_dtype
143
+
144
+ if args.sd_model is not None:
145
+ print(f"loading SD model: {args.sd_model}")
146
+
147
+ text_encoder, vae, unet = model_util.load_models_from_stable_diffusion_checkpoint(args.v2, args.sd_model)
148
+
149
+ merge_to_sd_model(text_encoder, unet, args.models, args.ratios, merge_dtype)
150
+
151
+ print(f"\nsaving SD model to: {args.save_to}")
152
+ model_util.save_stable_diffusion_checkpoint(args.v2, args.save_to, text_encoder, unet,
153
+ args.sd_model, 0, 0, save_dtype, vae)
154
+ else:
155
+ state_dict, _, _ = merge_lora_models(args.models, args.ratios, merge_dtype)
156
+
157
+ print(f"\nsaving model to: {args.save_to}")
158
+ save_to_file(args.save_to, state_dict, state_dict, save_dtype)
159
+
160
+
161
+ def setup_parser() -> argparse.ArgumentParser:
162
+ parser = argparse.ArgumentParser()
163
+ parser.add_argument("--v2", action='store_true',
164
+ help='load Stable Diffusion v2.x model / Stable Diffusion 2.xのモデルを読み込む')
165
+ parser.add_argument("--save_precision", type=str, default=None,
166
+ choices=[None, "float", "fp16", "bf16"], help="precision in saving, same to merging if omitted / 保存時に精度を変更して保存する、省略時はマージ時の精度と同じ")
167
+ parser.add_argument("--precision", type=str, default="float",
168
+ choices=["float", "fp16", "bf16"], help="precision in merging (float is recommended) / マージの計算時の精度(floatを推奨)")
169
+ parser.add_argument("--sd_model", type=str, default=None,
170
+ help="Stable Diffusion model to load: ckpt or safetensors file, merge LoRA models if omitted / 読み込むモデル、ckptまたはsafetensors。省略時はLoRAモデル同士をマージする")
171
+ parser.add_argument("--save_to", type=str, default=None,
172
+ help="destination file name: ckpt or safetensors file / 保存先のファイル名、ckptまたはsafetensors")
173
+ parser.add_argument("--models", type=str, nargs='*',
174
+ help="LoRA models to merge: ckpt or safetensors file / マージするLoRAモデル、ckptまたはsafetensors")
175
+ parser.add_argument("--ratios", type=float, nargs='*',
176
+ help="ratios for each model / それぞれのLoRAモデルの比率")
177
+
178
+ return parser
179
+
180
+
181
+ if __name__ == '__main__':
182
+ parser = setup_parser()
183
+
184
+ args = parser.parse_args()
185
+ merge(args)
external/llite/networks/oft.py ADDED
@@ -0,0 +1,430 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # OFT network module
2
+
3
+ import math
4
+ import os
5
+ from typing import Dict, List, Optional, Tuple, Type, Union
6
+ from diffusers import AutoencoderKL
7
+ from transformers import CLIPTextModel
8
+ import numpy as np
9
+ import torch
10
+ import re
11
+
12
+
13
+ RE_UPDOWN = re.compile(r"(up|down)_blocks_(\d+)_(resnets|upsamplers|downsamplers|attentions)_(\d+)_")
14
+
15
+
16
+ class OFTModule(torch.nn.Module):
17
+ """
18
+ replaces forward method of the original Linear, instead of replacing the original Linear module.
19
+ """
20
+
21
+ def __init__(
22
+ self,
23
+ oft_name,
24
+ org_module: torch.nn.Module,
25
+ multiplier=1.0,
26
+ dim=4,
27
+ alpha=1,
28
+ ):
29
+ """
30
+ dim -> num blocks
31
+ alpha -> constraint
32
+ """
33
+ super().__init__()
34
+ self.oft_name = oft_name
35
+
36
+ self.num_blocks = dim
37
+
38
+ if "Linear" in org_module.__class__.__name__:
39
+ out_dim = org_module.out_features
40
+ elif "Conv" in org_module.__class__.__name__:
41
+ out_dim = org_module.out_channels
42
+
43
+ if type(alpha) == torch.Tensor:
44
+ alpha = alpha.detach().numpy()
45
+ self.constraint = alpha * out_dim
46
+ self.register_buffer("alpha", torch.tensor(alpha))
47
+
48
+ self.block_size = out_dim // self.num_blocks
49
+ self.oft_blocks = torch.nn.Parameter(torch.zeros(self.num_blocks, self.block_size, self.block_size))
50
+
51
+ self.out_dim = out_dim
52
+ self.shape = org_module.weight.shape
53
+
54
+ self.multiplier = multiplier
55
+ self.org_module = [org_module] # moduleにならないようにlistに入れる
56
+
57
+ def apply_to(self):
58
+ self.org_forward = self.org_module[0].forward
59
+ self.org_module[0].forward = self.forward
60
+
61
+ def get_weight(self, multiplier=None):
62
+ if multiplier is None:
63
+ multiplier = self.multiplier
64
+
65
+ block_Q = self.oft_blocks - self.oft_blocks.transpose(1, 2)
66
+ norm_Q = torch.norm(block_Q.flatten())
67
+ new_norm_Q = torch.clamp(norm_Q, max=self.constraint)
68
+ block_Q = block_Q * ((new_norm_Q + 1e-8) / (norm_Q + 1e-8))
69
+ I = torch.eye(self.block_size, device=self.oft_blocks.device).unsqueeze(0).repeat(self.num_blocks, 1, 1)
70
+ block_R = torch.matmul(I + block_Q, (I - block_Q).inverse())
71
+
72
+ block_R_weighted = self.multiplier * block_R + (1 - self.multiplier) * I
73
+ R = torch.block_diag(*block_R_weighted)
74
+
75
+ return R
76
+
77
+ def forward(self, x, scale=None):
78
+ x = self.org_forward(x)
79
+ if self.multiplier == 0.0:
80
+ return x
81
+
82
+ R = self.get_weight().to(x.device, dtype=x.dtype)
83
+ if x.dim() == 4:
84
+ x = x.permute(0, 2, 3, 1)
85
+ x = torch.matmul(x, R)
86
+ x = x.permute(0, 3, 1, 2)
87
+ else:
88
+ x = torch.matmul(x, R)
89
+ return x
90
+
91
+
92
+ class OFTInfModule(OFTModule):
93
+ def __init__(
94
+ self,
95
+ oft_name,
96
+ org_module: torch.nn.Module,
97
+ multiplier=1.0,
98
+ dim=4,
99
+ alpha=1,
100
+ **kwargs,
101
+ ):
102
+ # no dropout for inference
103
+ super().__init__(oft_name, org_module, multiplier, dim, alpha)
104
+ self.enabled = True
105
+ self.network: OFTNetwork = None
106
+
107
+ def set_network(self, network):
108
+ self.network = network
109
+
110
+ def forward(self, x, scale=None):
111
+ if not self.enabled:
112
+ return self.org_forward(x)
113
+ return super().forward(x, scale)
114
+
115
+ def merge_to(self, multiplier=None, sign=1):
116
+ R = self.get_weight(multiplier) * sign
117
+
118
+ # get org weight
119
+ org_sd = self.org_module[0].state_dict()
120
+ org_weight = org_sd["weight"]
121
+ R = R.to(org_weight.device, dtype=org_weight.dtype)
122
+
123
+ if org_weight.dim() == 4:
124
+ weight = torch.einsum("oihw, op -> pihw", org_weight, R)
125
+ else:
126
+ weight = torch.einsum("oi, op -> pi", org_weight, R)
127
+
128
+ # set weight to org_module
129
+ org_sd["weight"] = weight
130
+ self.org_module[0].load_state_dict(org_sd)
131
+
132
+
133
+ def create_network(
134
+ multiplier: float,
135
+ network_dim: Optional[int],
136
+ network_alpha: Optional[float],
137
+ vae: AutoencoderKL,
138
+ text_encoder: Union[CLIPTextModel, List[CLIPTextModel]],
139
+ unet,
140
+ neuron_dropout: Optional[float] = None,
141
+ **kwargs,
142
+ ):
143
+ if network_dim is None:
144
+ network_dim = 4 # default
145
+ if network_alpha is None:
146
+ network_alpha = 1.0
147
+
148
+ enable_all_linear = kwargs.get("enable_all_linear", None)
149
+ enable_conv = kwargs.get("enable_conv", None)
150
+ if enable_all_linear is not None:
151
+ enable_all_linear = bool(enable_all_linear)
152
+ if enable_conv is not None:
153
+ enable_conv = bool(enable_conv)
154
+
155
+ network = OFTNetwork(
156
+ text_encoder,
157
+ unet,
158
+ multiplier=multiplier,
159
+ dim=network_dim,
160
+ alpha=network_alpha,
161
+ enable_all_linear=enable_all_linear,
162
+ enable_conv=enable_conv,
163
+ varbose=True,
164
+ )
165
+ return network
166
+
167
+
168
+ # Create network from weights for inference, weights are not loaded here (because can be merged)
169
+ def create_network_from_weights(multiplier, file, vae, text_encoder, unet, weights_sd=None, for_inference=False, **kwargs):
170
+ if weights_sd is None:
171
+ if os.path.splitext(file)[1] == ".safetensors":
172
+ from safetensors.torch import load_file, safe_open
173
+
174
+ weights_sd = load_file(file)
175
+ else:
176
+ weights_sd = torch.load(file, map_location="cpu")
177
+
178
+ # check dim, alpha and if weights have for conv2d
179
+ dim = None
180
+ alpha = None
181
+ has_conv2d = None
182
+ all_linear = None
183
+ for name, param in weights_sd.items():
184
+ if name.endswith(".alpha"):
185
+ if alpha is None:
186
+ alpha = param.item()
187
+ else:
188
+ if dim is None:
189
+ dim = param.size()[0]
190
+ if has_conv2d is None and param.dim() == 4:
191
+ has_conv2d = True
192
+ if all_linear is None:
193
+ if param.dim() == 3 and "attn" not in name:
194
+ all_linear = True
195
+ if dim is not None and alpha is not None and has_conv2d is not None:
196
+ break
197
+ if has_conv2d is None:
198
+ has_conv2d = False
199
+ if all_linear is None:
200
+ all_linear = False
201
+
202
+ module_class = OFTInfModule if for_inference else OFTModule
203
+ network = OFTNetwork(
204
+ text_encoder,
205
+ unet,
206
+ multiplier=multiplier,
207
+ dim=dim,
208
+ alpha=alpha,
209
+ enable_all_linear=all_linear,
210
+ enable_conv=has_conv2d,
211
+ module_class=module_class,
212
+ )
213
+ return network, weights_sd
214
+
215
+
216
+ class OFTNetwork(torch.nn.Module):
217
+ UNET_TARGET_REPLACE_MODULE_ATTN_ONLY = ["CrossAttention"]
218
+ UNET_TARGET_REPLACE_MODULE_ALL_LINEAR = ["Transformer2DModel"]
219
+ UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 = ["ResnetBlock2D", "Downsample2D", "Upsample2D"]
220
+ OFT_PREFIX_UNET = "oft_unet" # これ変えないほうがいいかな
221
+
222
+ def __init__(
223
+ self,
224
+ text_encoder: Union[List[CLIPTextModel], CLIPTextModel],
225
+ unet,
226
+ multiplier: float = 1.0,
227
+ dim: int = 4,
228
+ alpha: float = 1,
229
+ enable_all_linear: Optional[bool] = False,
230
+ enable_conv: Optional[bool] = False,
231
+ module_class: Type[object] = OFTModule,
232
+ varbose: Optional[bool] = False,
233
+ ) -> None:
234
+ super().__init__()
235
+ self.multiplier = multiplier
236
+
237
+ self.dim = dim
238
+ self.alpha = alpha
239
+
240
+ print(
241
+ f"create OFT network. num blocks: {self.dim}, constraint: {self.alpha}, multiplier: {self.multiplier}, enable_conv: {enable_conv}"
242
+ )
243
+
244
+ # create module instances
245
+ def create_modules(
246
+ root_module: torch.nn.Module,
247
+ target_replace_modules: List[torch.nn.Module],
248
+ ) -> List[OFTModule]:
249
+ prefix = self.OFT_PREFIX_UNET
250
+ ofts = []
251
+ for name, module in root_module.named_modules():
252
+ if module.__class__.__name__ in target_replace_modules:
253
+ for child_name, child_module in module.named_modules():
254
+ is_linear = "Linear" in child_module.__class__.__name__
255
+ is_conv2d = "Conv2d" in child_module.__class__.__name__
256
+ is_conv2d_1x1 = is_conv2d and child_module.kernel_size == (1, 1)
257
+
258
+ if is_linear or is_conv2d_1x1 or (is_conv2d and enable_conv):
259
+ oft_name = prefix + "." + name + "." + child_name
260
+ oft_name = oft_name.replace(".", "_")
261
+ # print(oft_name)
262
+
263
+ oft = module_class(
264
+ oft_name,
265
+ child_module,
266
+ self.multiplier,
267
+ dim,
268
+ alpha,
269
+ )
270
+ ofts.append(oft)
271
+ return ofts
272
+
273
+ # extend U-Net target modules if conv2d 3x3 is enabled, or load from weights
274
+ if enable_all_linear:
275
+ target_modules = OFTNetwork.UNET_TARGET_REPLACE_MODULE_ALL_LINEAR
276
+ else:
277
+ target_modules = OFTNetwork.UNET_TARGET_REPLACE_MODULE_ATTN_ONLY
278
+ if enable_conv:
279
+ target_modules += OFTNetwork.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3
280
+
281
+ self.unet_ofts: List[OFTModule] = create_modules(unet, target_modules)
282
+ print(f"create OFT for U-Net: {len(self.unet_ofts)} modules.")
283
+
284
+ # assertion
285
+ names = set()
286
+ for oft in self.unet_ofts:
287
+ assert oft.oft_name not in names, f"duplicated oft name: {oft.oft_name}"
288
+ names.add(oft.oft_name)
289
+
290
+ def set_multiplier(self, multiplier):
291
+ self.multiplier = multiplier
292
+ for oft in self.unet_ofts:
293
+ oft.multiplier = self.multiplier
294
+
295
+ def load_weights(self, file):
296
+ if os.path.splitext(file)[1] == ".safetensors":
297
+ from safetensors.torch import load_file
298
+
299
+ weights_sd = load_file(file)
300
+ else:
301
+ weights_sd = torch.load(file, map_location="cpu")
302
+
303
+ info = self.load_state_dict(weights_sd, False)
304
+ return info
305
+
306
+ def apply_to(self, text_encoder, unet, apply_text_encoder=True, apply_unet=True):
307
+ assert apply_unet, "apply_unet must be True"
308
+
309
+ for oft in self.unet_ofts:
310
+ oft.apply_to()
311
+ self.add_module(oft.oft_name, oft)
312
+
313
+ # マージできるかどうかを返す
314
+ def is_mergeable(self):
315
+ return True
316
+
317
+ # TODO refactor to common function with apply_to
318
+ def merge_to(self, text_encoder, unet, weights_sd, dtype, device):
319
+ print("enable OFT for U-Net")
320
+
321
+ for oft in self.unet_ofts:
322
+ sd_for_lora = {}
323
+ for key in weights_sd.keys():
324
+ if key.startswith(oft.oft_name):
325
+ sd_for_lora[key[len(oft.oft_name) + 1 :]] = weights_sd[key]
326
+ oft.load_state_dict(sd_for_lora, False)
327
+ oft.merge_to()
328
+
329
+ print(f"weights are merged")
330
+
331
+ # 二つのText Encoderに別々の学習率を設定できるようにするといいかも
332
+ def prepare_optimizer_params(self, text_encoder_lr, unet_lr, default_lr):
333
+ self.requires_grad_(True)
334
+ all_params = []
335
+
336
+ def enumerate_params(ofts):
337
+ params = []
338
+ for oft in ofts:
339
+ params.extend(oft.parameters())
340
+
341
+ # print num of params
342
+ num_params = 0
343
+ for p in params:
344
+ num_params += p.numel()
345
+ print(f"OFT params: {num_params}")
346
+ return params
347
+
348
+ param_data = {"params": enumerate_params(self.unet_ofts)}
349
+ if unet_lr is not None:
350
+ param_data["lr"] = unet_lr
351
+ all_params.append(param_data)
352
+
353
+ return all_params
354
+
355
+ def enable_gradient_checkpointing(self):
356
+ # not supported
357
+ pass
358
+
359
+ def prepare_grad_etc(self, text_encoder, unet):
360
+ self.requires_grad_(True)
361
+
362
+ def on_epoch_start(self, text_encoder, unet):
363
+ self.train()
364
+
365
+ def get_trainable_params(self):
366
+ return self.parameters()
367
+
368
+ def save_weights(self, file, dtype, metadata):
369
+ if metadata is not None and len(metadata) == 0:
370
+ metadata = None
371
+
372
+ state_dict = self.state_dict()
373
+
374
+ if dtype is not None:
375
+ for key in list(state_dict.keys()):
376
+ v = state_dict[key]
377
+ v = v.detach().clone().to("cpu").to(dtype)
378
+ state_dict[key] = v
379
+
380
+ if os.path.splitext(file)[1] == ".safetensors":
381
+ from safetensors.torch import save_file
382
+ from library import train_util
383
+
384
+ # Precalculate model hashes to save time on indexing
385
+ if metadata is None:
386
+ metadata = {}
387
+ model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(state_dict, metadata)
388
+ metadata["sshs_model_hash"] = model_hash
389
+ metadata["sshs_legacy_hash"] = legacy_hash
390
+
391
+ save_file(state_dict, file, metadata)
392
+ else:
393
+ torch.save(state_dict, file)
394
+
395
+ def backup_weights(self):
396
+ # 重みのバックアップを行う
397
+ ofts: List[OFTInfModule] = self.unet_ofts
398
+ for oft in ofts:
399
+ org_module = oft.org_module[0]
400
+ if not hasattr(org_module, "_lora_org_weight"):
401
+ sd = org_module.state_dict()
402
+ org_module._lora_org_weight = sd["weight"].detach().clone()
403
+ org_module._lora_restored = True
404
+
405
+ def restore_weights(self):
406
+ # 重みのリストアを行う
407
+ ofts: List[OFTInfModule] = self.unet_ofts
408
+ for oft in ofts:
409
+ org_module = oft.org_module[0]
410
+ if not org_module._lora_restored:
411
+ sd = org_module.state_dict()
412
+ sd["weight"] = org_module._lora_org_weight
413
+ org_module.load_state_dict(sd)
414
+ org_module._lora_restored = True
415
+
416
+ def pre_calculation(self):
417
+ # 事前計算を行う
418
+ ofts: List[OFTInfModule] = self.unet_ofts
419
+ for oft in ofts:
420
+ org_module = oft.org_module[0]
421
+ oft.merge_to()
422
+ # sd = org_module.state_dict()
423
+ # org_weight = sd["weight"]
424
+ # lora_weight = oft.get_weight().to(org_weight.device, dtype=org_weight.dtype)
425
+ # sd["weight"] = org_weight + lora_weight
426
+ # assert sd["weight"].shape == org_weight.shape
427
+ # org_module.load_state_dict(sd)
428
+
429
+ org_module._lora_restored = False
430
+ oft.enabled = False
external/llite/networks/resize_lora.py ADDED
@@ -0,0 +1,362 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Convert LoRA to different rank approximation (should only be used to go to lower rank)
2
+ # This code is based off the extract_lora_from_models.py file which is based on https://github.com/cloneofsimo/lora/blob/develop/lora_diffusion/cli_svd.py
3
+ # Thanks to cloneofsimo
4
+
5
+ import argparse
6
+ import torch
7
+ from safetensors.torch import load_file, save_file, safe_open
8
+ from tqdm import tqdm
9
+ from library import train_util, model_util
10
+ import numpy as np
11
+
12
+ MIN_SV = 1e-6
13
+
14
+ # Model save and load functions
15
+
16
+ def load_state_dict(file_name, dtype):
17
+ if model_util.is_safetensors(file_name):
18
+ sd = load_file(file_name)
19
+ with safe_open(file_name, framework="pt") as f:
20
+ metadata = f.metadata()
21
+ else:
22
+ sd = torch.load(file_name, map_location='cpu')
23
+ metadata = None
24
+
25
+ for key in list(sd.keys()):
26
+ if type(sd[key]) == torch.Tensor:
27
+ sd[key] = sd[key].to(dtype)
28
+
29
+ return sd, metadata
30
+
31
+
32
+ def save_to_file(file_name, model, state_dict, dtype, metadata):
33
+ if dtype is not None:
34
+ for key in list(state_dict.keys()):
35
+ if type(state_dict[key]) == torch.Tensor:
36
+ state_dict[key] = state_dict[key].to(dtype)
37
+
38
+ if model_util.is_safetensors(file_name):
39
+ save_file(model, file_name, metadata)
40
+ else:
41
+ torch.save(model, file_name)
42
+
43
+
44
+ # Indexing functions
45
+
46
+ def index_sv_cumulative(S, target):
47
+ original_sum = float(torch.sum(S))
48
+ cumulative_sums = torch.cumsum(S, dim=0)/original_sum
49
+ index = int(torch.searchsorted(cumulative_sums, target)) + 1
50
+ index = max(1, min(index, len(S)-1))
51
+
52
+ return index
53
+
54
+
55
+ def index_sv_fro(S, target):
56
+ S_squared = S.pow(2)
57
+ s_fro_sq = float(torch.sum(S_squared))
58
+ sum_S_squared = torch.cumsum(S_squared, dim=0)/s_fro_sq
59
+ index = int(torch.searchsorted(sum_S_squared, target**2)) + 1
60
+ index = max(1, min(index, len(S)-1))
61
+
62
+ return index
63
+
64
+
65
+ def index_sv_ratio(S, target):
66
+ max_sv = S[0]
67
+ min_sv = max_sv/target
68
+ index = int(torch.sum(S > min_sv).item())
69
+ index = max(1, min(index, len(S)-1))
70
+
71
+ return index
72
+
73
+
74
+ # Modified from Kohaku-blueleaf's extract/merge functions
75
+ def extract_conv(weight, lora_rank, dynamic_method, dynamic_param, device, scale=1):
76
+ out_size, in_size, kernel_size, _ = weight.size()
77
+ U, S, Vh = torch.linalg.svd(weight.reshape(out_size, -1).to(device))
78
+
79
+ param_dict = rank_resize(S, lora_rank, dynamic_method, dynamic_param, scale)
80
+ lora_rank = param_dict["new_rank"]
81
+
82
+ U = U[:, :lora_rank]
83
+ S = S[:lora_rank]
84
+ U = U @ torch.diag(S)
85
+ Vh = Vh[:lora_rank, :]
86
+
87
+ param_dict["lora_down"] = Vh.reshape(lora_rank, in_size, kernel_size, kernel_size).cpu()
88
+ param_dict["lora_up"] = U.reshape(out_size, lora_rank, 1, 1).cpu()
89
+ del U, S, Vh, weight
90
+ return param_dict
91
+
92
+
93
+ def extract_linear(weight, lora_rank, dynamic_method, dynamic_param, device, scale=1):
94
+ out_size, in_size = weight.size()
95
+
96
+ U, S, Vh = torch.linalg.svd(weight.to(device))
97
+
98
+ param_dict = rank_resize(S, lora_rank, dynamic_method, dynamic_param, scale)
99
+ lora_rank = param_dict["new_rank"]
100
+
101
+ U = U[:, :lora_rank]
102
+ S = S[:lora_rank]
103
+ U = U @ torch.diag(S)
104
+ Vh = Vh[:lora_rank, :]
105
+
106
+ param_dict["lora_down"] = Vh.reshape(lora_rank, in_size).cpu()
107
+ param_dict["lora_up"] = U.reshape(out_size, lora_rank).cpu()
108
+ del U, S, Vh, weight
109
+ return param_dict
110
+
111
+
112
+ def merge_conv(lora_down, lora_up, device):
113
+ in_rank, in_size, kernel_size, k_ = lora_down.shape
114
+ out_size, out_rank, _, _ = lora_up.shape
115
+ assert in_rank == out_rank and kernel_size == k_, f"rank {in_rank} {out_rank} or kernel {kernel_size} {k_} mismatch"
116
+
117
+ lora_down = lora_down.to(device)
118
+ lora_up = lora_up.to(device)
119
+
120
+ merged = lora_up.reshape(out_size, -1) @ lora_down.reshape(in_rank, -1)
121
+ weight = merged.reshape(out_size, in_size, kernel_size, kernel_size)
122
+ del lora_up, lora_down
123
+ return weight
124
+
125
+
126
+ def merge_linear(lora_down, lora_up, device):
127
+ in_rank, in_size = lora_down.shape
128
+ out_size, out_rank = lora_up.shape
129
+ assert in_rank == out_rank, f"rank {in_rank} {out_rank} mismatch"
130
+
131
+ lora_down = lora_down.to(device)
132
+ lora_up = lora_up.to(device)
133
+
134
+ weight = lora_up @ lora_down
135
+ del lora_up, lora_down
136
+ return weight
137
+
138
+
139
+ # Calculate new rank
140
+
141
+ def rank_resize(S, rank, dynamic_method, dynamic_param, scale=1):
142
+ param_dict = {}
143
+
144
+ if dynamic_method=="sv_ratio":
145
+ # Calculate new dim and alpha based off ratio
146
+ new_rank = index_sv_ratio(S, dynamic_param) + 1
147
+ new_alpha = float(scale*new_rank)
148
+
149
+ elif dynamic_method=="sv_cumulative":
150
+ # Calculate new dim and alpha based off cumulative sum
151
+ new_rank = index_sv_cumulative(S, dynamic_param) + 1
152
+ new_alpha = float(scale*new_rank)
153
+
154
+ elif dynamic_method=="sv_fro":
155
+ # Calculate new dim and alpha based off sqrt sum of squares
156
+ new_rank = index_sv_fro(S, dynamic_param) + 1
157
+ new_alpha = float(scale*new_rank)
158
+ else:
159
+ new_rank = rank
160
+ new_alpha = float(scale*new_rank)
161
+
162
+
163
+ if S[0] <= MIN_SV: # Zero matrix, set dim to 1
164
+ new_rank = 1
165
+ new_alpha = float(scale*new_rank)
166
+ elif new_rank > rank: # cap max rank at rank
167
+ new_rank = rank
168
+ new_alpha = float(scale*new_rank)
169
+
170
+
171
+ # Calculate resize info
172
+ s_sum = torch.sum(torch.abs(S))
173
+ s_rank = torch.sum(torch.abs(S[:new_rank]))
174
+
175
+ S_squared = S.pow(2)
176
+ s_fro = torch.sqrt(torch.sum(S_squared))
177
+ s_red_fro = torch.sqrt(torch.sum(S_squared[:new_rank]))
178
+ fro_percent = float(s_red_fro/s_fro)
179
+
180
+ param_dict["new_rank"] = new_rank
181
+ param_dict["new_alpha"] = new_alpha
182
+ param_dict["sum_retained"] = (s_rank)/s_sum
183
+ param_dict["fro_retained"] = fro_percent
184
+ param_dict["max_ratio"] = S[0]/S[new_rank - 1]
185
+
186
+ return param_dict
187
+
188
+
189
+ def resize_lora_model(lora_sd, new_rank, save_dtype, device, dynamic_method, dynamic_param, verbose):
190
+ network_alpha = None
191
+ network_dim = None
192
+ verbose_str = "\n"
193
+ fro_list = []
194
+
195
+ # Extract loaded lora dim and alpha
196
+ for key, value in lora_sd.items():
197
+ if network_alpha is None and 'alpha' in key:
198
+ network_alpha = value
199
+ if network_dim is None and 'lora_down' in key and len(value.size()) == 2:
200
+ network_dim = value.size()[0]
201
+ if network_alpha is not None and network_dim is not None:
202
+ break
203
+ if network_alpha is None:
204
+ network_alpha = network_dim
205
+
206
+ scale = network_alpha/network_dim
207
+
208
+ if dynamic_method:
209
+ print(f"Dynamically determining new alphas and dims based off {dynamic_method}: {dynamic_param}, max rank is {new_rank}")
210
+
211
+ lora_down_weight = None
212
+ lora_up_weight = None
213
+
214
+ o_lora_sd = lora_sd.copy()
215
+ block_down_name = None
216
+ block_up_name = None
217
+
218
+ with torch.no_grad():
219
+ for key, value in tqdm(lora_sd.items()):
220
+ weight_name = None
221
+ if 'lora_down' in key:
222
+ block_down_name = key.rsplit('.lora_down', 1)[0]
223
+ weight_name = key.rsplit(".", 1)[-1]
224
+ lora_down_weight = value
225
+ else:
226
+ continue
227
+
228
+ # find corresponding lora_up and alpha
229
+ block_up_name = block_down_name
230
+ lora_up_weight = lora_sd.get(block_up_name + '.lora_up.' + weight_name, None)
231
+ lora_alpha = lora_sd.get(block_down_name + '.alpha', None)
232
+
233
+ weights_loaded = (lora_down_weight is not None and lora_up_weight is not None)
234
+
235
+ if weights_loaded:
236
+
237
+ conv2d = (len(lora_down_weight.size()) == 4)
238
+ if lora_alpha is None:
239
+ scale = 1.0
240
+ else:
241
+ scale = lora_alpha/lora_down_weight.size()[0]
242
+
243
+ if conv2d:
244
+ full_weight_matrix = merge_conv(lora_down_weight, lora_up_weight, device)
245
+ param_dict = extract_conv(full_weight_matrix, new_rank, dynamic_method, dynamic_param, device, scale)
246
+ else:
247
+ full_weight_matrix = merge_linear(lora_down_weight, lora_up_weight, device)
248
+ param_dict = extract_linear(full_weight_matrix, new_rank, dynamic_method, dynamic_param, device, scale)
249
+
250
+ if verbose:
251
+ max_ratio = param_dict['max_ratio']
252
+ sum_retained = param_dict['sum_retained']
253
+ fro_retained = param_dict['fro_retained']
254
+ if not np.isnan(fro_retained):
255
+ fro_list.append(float(fro_retained))
256
+
257
+ verbose_str+=f"{block_down_name:75} | "
258
+ verbose_str+=f"sum(S) retained: {sum_retained:.1%}, fro retained: {fro_retained:.1%}, max(S) ratio: {max_ratio:0.1f}"
259
+
260
+ if verbose and dynamic_method:
261
+ verbose_str+=f", dynamic | dim: {param_dict['new_rank']}, alpha: {param_dict['new_alpha']}\n"
262
+ else:
263
+ verbose_str+=f"\n"
264
+
265
+ new_alpha = param_dict['new_alpha']
266
+ o_lora_sd[block_down_name + "." + "lora_down.weight"] = param_dict["lora_down"].to(save_dtype).contiguous()
267
+ o_lora_sd[block_up_name + "." + "lora_up.weight"] = param_dict["lora_up"].to(save_dtype).contiguous()
268
+ o_lora_sd[block_up_name + "." "alpha"] = torch.tensor(param_dict['new_alpha']).to(save_dtype)
269
+
270
+ block_down_name = None
271
+ block_up_name = None
272
+ lora_down_weight = None
273
+ lora_up_weight = None
274
+ weights_loaded = False
275
+ del param_dict
276
+
277
+ if verbose:
278
+ print(verbose_str)
279
+
280
+ print(f"Average Frobenius norm retention: {np.mean(fro_list):.2%} | std: {np.std(fro_list):0.3f}")
281
+ print("resizing complete")
282
+ return o_lora_sd, network_dim, new_alpha
283
+
284
+
285
+ def resize(args):
286
+ if args.save_to is None or not (args.save_to.endswith('.ckpt') or args.save_to.endswith('.pt') or args.save_to.endswith('.pth') or args.save_to.endswith('.safetensors')):
287
+ raise Exception("The --save_to argument must be specified and must be a .ckpt , .pt, .pth or .safetensors file.")
288
+
289
+
290
+ def str_to_dtype(p):
291
+ if p == 'float':
292
+ return torch.float
293
+ if p == 'fp16':
294
+ return torch.float16
295
+ if p == 'bf16':
296
+ return torch.bfloat16
297
+ return None
298
+
299
+ if args.dynamic_method and not args.dynamic_param:
300
+ raise Exception("If using dynamic_method, then dynamic_param is required")
301
+
302
+ merge_dtype = str_to_dtype('float') # matmul method above only seems to work in float32
303
+ save_dtype = str_to_dtype(args.save_precision)
304
+ if save_dtype is None:
305
+ save_dtype = merge_dtype
306
+
307
+ print("loading Model...")
308
+ lora_sd, metadata = load_state_dict(args.model, merge_dtype)
309
+
310
+ print("Resizing Lora...")
311
+ state_dict, old_dim, new_alpha = resize_lora_model(lora_sd, args.new_rank, save_dtype, args.device, args.dynamic_method, args.dynamic_param, args.verbose)
312
+
313
+ # update metadata
314
+ if metadata is None:
315
+ metadata = {}
316
+
317
+ comment = metadata.get("ss_training_comment", "")
318
+
319
+ if not args.dynamic_method:
320
+ metadata["ss_training_comment"] = f"dimension is resized from {old_dim} to {args.new_rank}; {comment}"
321
+ metadata["ss_network_dim"] = str(args.new_rank)
322
+ metadata["ss_network_alpha"] = str(new_alpha)
323
+ else:
324
+ metadata["ss_training_comment"] = f"Dynamic resize with {args.dynamic_method}: {args.dynamic_param} from {old_dim}; {comment}"
325
+ metadata["ss_network_dim"] = 'Dynamic'
326
+ metadata["ss_network_alpha"] = 'Dynamic'
327
+
328
+ model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(state_dict, metadata)
329
+ metadata["sshs_model_hash"] = model_hash
330
+ metadata["sshs_legacy_hash"] = legacy_hash
331
+
332
+ print(f"saving model to: {args.save_to}")
333
+ save_to_file(args.save_to, state_dict, state_dict, save_dtype, metadata)
334
+
335
+
336
+ def setup_parser() -> argparse.ArgumentParser:
337
+ parser = argparse.ArgumentParser()
338
+
339
+ parser.add_argument("--save_precision", type=str, default=None,
340
+ choices=[None, "float", "fp16", "bf16"], help="precision in saving, float if omitted / 保存時の精度、未指定時はfloat")
341
+ parser.add_argument("--new_rank", type=int, default=4,
342
+ help="Specify rank of output LoRA / 出力するLoRAのrank (dim)")
343
+ parser.add_argument("--save_to", type=str, default=None,
344
+ help="destination file name: ckpt or safetensors file / 保存先のファイル名、ckptまたはsafetensors")
345
+ parser.add_argument("--model", type=str, default=None,
346
+ help="LoRA model to resize at to new rank: ckpt or safetensors file / 読み込むLoRAモデル、ckptまたはsafetensors")
347
+ parser.add_argument("--device", type=str, default=None, help="device to use, cuda for GPU / 計算を行うデバイス、cuda でGPUを使う")
348
+ parser.add_argument("--verbose", action="store_true",
349
+ help="Display verbose resizing information / rank変更時の詳細情報を出力する")
350
+ parser.add_argument("--dynamic_method", type=str, default=None, choices=[None, "sv_ratio", "sv_fro", "sv_cumulative"],
351
+ help="Specify dynamic resizing method, --new_rank is used as a hard limit for max rank")
352
+ parser.add_argument("--dynamic_param", type=float, default=None,
353
+ help="Specify target for dynamic reduction")
354
+
355
+ return parser
356
+
357
+
358
+ if __name__ == '__main__':
359
+ parser = setup_parser()
360
+
361
+ args = parser.parse_args()
362
+ resize(args)
external/llite/networks/sdxl_merge_lora.py ADDED
@@ -0,0 +1,348 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import argparse
3
+ import os
4
+ import time
5
+ import torch
6
+ from safetensors.torch import load_file, save_file
7
+ from tqdm import tqdm
8
+ from library import sai_model_spec, sdxl_model_util, train_util
9
+ import library.model_util as model_util
10
+ import lora
11
+
12
+
13
+ def load_state_dict(file_name, dtype):
14
+ if os.path.splitext(file_name)[1] == ".safetensors":
15
+ sd = load_file(file_name)
16
+ metadata = train_util.load_metadata_from_safetensors(file_name)
17
+ else:
18
+ sd = torch.load(file_name, map_location="cpu")
19
+ metadata = {}
20
+
21
+ for key in list(sd.keys()):
22
+ if type(sd[key]) == torch.Tensor:
23
+ sd[key] = sd[key].to(dtype)
24
+
25
+ return sd, metadata
26
+
27
+
28
+ def save_to_file(file_name, model, state_dict, dtype, metadata):
29
+ if dtype is not None:
30
+ for key in list(state_dict.keys()):
31
+ if type(state_dict[key]) == torch.Tensor:
32
+ state_dict[key] = state_dict[key].to(dtype)
33
+
34
+ if os.path.splitext(file_name)[1] == ".safetensors":
35
+ save_file(model, file_name, metadata=metadata)
36
+ else:
37
+ torch.save(model, file_name)
38
+
39
+
40
+ def merge_to_sd_model(text_encoder1, text_encoder2, unet, models, ratios, merge_dtype):
41
+ text_encoder1.to(merge_dtype)
42
+ text_encoder1.to(merge_dtype)
43
+ unet.to(merge_dtype)
44
+
45
+ # create module map
46
+ name_to_module = {}
47
+ for i, root_module in enumerate([text_encoder1, text_encoder2, unet]):
48
+ if i <= 1:
49
+ if i == 0:
50
+ prefix = lora.LoRANetwork.LORA_PREFIX_TEXT_ENCODER1
51
+ else:
52
+ prefix = lora.LoRANetwork.LORA_PREFIX_TEXT_ENCODER2
53
+ target_replace_modules = lora.LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE
54
+ else:
55
+ prefix = lora.LoRANetwork.LORA_PREFIX_UNET
56
+ target_replace_modules = (
57
+ lora.LoRANetwork.UNET_TARGET_REPLACE_MODULE + lora.LoRANetwork.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3
58
+ )
59
+
60
+ for name, module in root_module.named_modules():
61
+ if module.__class__.__name__ in target_replace_modules:
62
+ for child_name, child_module in module.named_modules():
63
+ if child_module.__class__.__name__ == "Linear" or child_module.__class__.__name__ == "Conv2d":
64
+ lora_name = prefix + "." + name + "." + child_name
65
+ lora_name = lora_name.replace(".", "_")
66
+ name_to_module[lora_name] = child_module
67
+
68
+ for model, ratio in zip(models, ratios):
69
+ print(f"loading: {model}")
70
+ lora_sd, _ = load_state_dict(model, merge_dtype)
71
+
72
+ print(f"merging...")
73
+ for key in tqdm(lora_sd.keys()):
74
+ if "lora_down" in key:
75
+ up_key = key.replace("lora_down", "lora_up")
76
+ alpha_key = key[: key.index("lora_down")] + "alpha"
77
+
78
+ # find original module for this lora
79
+ module_name = ".".join(key.split(".")[:-2]) # remove trailing ".lora_down.weight"
80
+ if module_name not in name_to_module:
81
+ print(f"no module found for LoRA weight: {key}")
82
+ continue
83
+ module = name_to_module[module_name]
84
+ # print(f"apply {key} to {module}")
85
+
86
+ down_weight = lora_sd[key]
87
+ up_weight = lora_sd[up_key]
88
+
89
+ dim = down_weight.size()[0]
90
+ alpha = lora_sd.get(alpha_key, dim)
91
+ scale = alpha / dim
92
+
93
+ # W <- W + U * D
94
+ weight = module.weight
95
+ # print(module_name, down_weight.size(), up_weight.size())
96
+ if len(weight.size()) == 2:
97
+ # linear
98
+ weight = weight + ratio * (up_weight @ down_weight) * scale
99
+ elif down_weight.size()[2:4] == (1, 1):
100
+ # conv2d 1x1
101
+ weight = (
102
+ weight
103
+ + ratio
104
+ * (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3)
105
+ * scale
106
+ )
107
+ else:
108
+ # conv2d 3x3
109
+ conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3)
110
+ # print(conved.size(), weight.size(), module.stride, module.padding)
111
+ weight = weight + ratio * conved * scale
112
+
113
+ module.weight = torch.nn.Parameter(weight)
114
+
115
+
116
+ def merge_lora_models(models, ratios, merge_dtype, concat=False, shuffle=False):
117
+ base_alphas = {} # alpha for merged model
118
+ base_dims = {}
119
+
120
+ merged_sd = {}
121
+ v2 = None
122
+ base_model = None
123
+ for model, ratio in zip(models, ratios):
124
+ print(f"loading: {model}")
125
+ lora_sd, lora_metadata = load_state_dict(model, merge_dtype)
126
+
127
+ if lora_metadata is not None:
128
+ if v2 is None:
129
+ v2 = lora_metadata.get(train_util.SS_METADATA_KEY_V2, None) # returns string, SDXLはv2がないのでFalseのはず
130
+ if base_model is None:
131
+ base_model = lora_metadata.get(train_util.SS_METADATA_KEY_BASE_MODEL_VERSION, None)
132
+
133
+ # get alpha and dim
134
+ alphas = {} # alpha for current model
135
+ dims = {} # dims for current model
136
+ for key in lora_sd.keys():
137
+ if "alpha" in key:
138
+ lora_module_name = key[: key.rfind(".alpha")]
139
+ alpha = float(lora_sd[key].detach().numpy())
140
+ alphas[lora_module_name] = alpha
141
+ if lora_module_name not in base_alphas:
142
+ base_alphas[lora_module_name] = alpha
143
+ elif "lora_down" in key:
144
+ lora_module_name = key[: key.rfind(".lora_down")]
145
+ dim = lora_sd[key].size()[0]
146
+ dims[lora_module_name] = dim
147
+ if lora_module_name not in base_dims:
148
+ base_dims[lora_module_name] = dim
149
+
150
+ for lora_module_name in dims.keys():
151
+ if lora_module_name not in alphas:
152
+ alpha = dims[lora_module_name]
153
+ alphas[lora_module_name] = alpha
154
+ if lora_module_name not in base_alphas:
155
+ base_alphas[lora_module_name] = alpha
156
+
157
+ print(f"dim: {list(set(dims.values()))}, alpha: {list(set(alphas.values()))}")
158
+
159
+ # merge
160
+ print(f"merging...")
161
+ for key in tqdm(lora_sd.keys()):
162
+ if "alpha" in key:
163
+ continue
164
+
165
+ if "lora_up" in key and concat:
166
+ concat_dim = 1
167
+ elif "lora_down" in key and concat:
168
+ concat_dim = 0
169
+ else:
170
+ concat_dim = None
171
+
172
+ lora_module_name = key[: key.rfind(".lora_")]
173
+
174
+ base_alpha = base_alphas[lora_module_name]
175
+ alpha = alphas[lora_module_name]
176
+
177
+ scale = math.sqrt(alpha / base_alpha) * ratio
178
+ scale = abs(scale) if "lora_up" in key else scale # マイナスの重みに対応する。
179
+
180
+ if key in merged_sd:
181
+ assert (
182
+ merged_sd[key].size() == lora_sd[key].size() or concat_dim is not None
183
+ ), f"weights shape mismatch merging v1 and v2, different dims? / 重みのサイズが合いません。v1とv2、または次元数の異なるモデルはマージできません"
184
+ if concat_dim is not None:
185
+ merged_sd[key] = torch.cat([merged_sd[key], lora_sd[key] * scale], dim=concat_dim)
186
+ else:
187
+ merged_sd[key] = merged_sd[key] + lora_sd[key] * scale
188
+ else:
189
+ merged_sd[key] = lora_sd[key] * scale
190
+
191
+ # set alpha to sd
192
+ for lora_module_name, alpha in base_alphas.items():
193
+ key = lora_module_name + ".alpha"
194
+ merged_sd[key] = torch.tensor(alpha)
195
+ if shuffle:
196
+ key_down = lora_module_name + ".lora_down.weight"
197
+ key_up = lora_module_name + ".lora_up.weight"
198
+ dim = merged_sd[key_down].shape[0]
199
+ perm = torch.randperm(dim)
200
+ merged_sd[key_down] = merged_sd[key_down][perm]
201
+ merged_sd[key_up] = merged_sd[key_up][:,perm]
202
+
203
+ print("merged model")
204
+ print(f"dim: {list(set(base_dims.values()))}, alpha: {list(set(base_alphas.values()))}")
205
+
206
+ # check all dims are same
207
+ dims_list = list(set(base_dims.values()))
208
+ alphas_list = list(set(base_alphas.values()))
209
+ all_same_dims = True
210
+ all_same_alphas = True
211
+ for dims in dims_list:
212
+ if dims != dims_list[0]:
213
+ all_same_dims = False
214
+ break
215
+ for alphas in alphas_list:
216
+ if alphas != alphas_list[0]:
217
+ all_same_alphas = False
218
+ break
219
+
220
+ # build minimum metadata
221
+ dims = f"{dims_list[0]}" if all_same_dims else "Dynamic"
222
+ alphas = f"{alphas_list[0]}" if all_same_alphas else "Dynamic"
223
+ metadata = train_util.build_minimum_network_metadata(v2, base_model, "networks.lora", dims, alphas, None)
224
+
225
+ return merged_sd, metadata
226
+
227
+
228
+ def merge(args):
229
+ assert len(args.models) == len(args.ratios), f"number of models must be equal to number of ratios / モデルの数と重みの数は合わせてください"
230
+
231
+ def str_to_dtype(p):
232
+ if p == "float":
233
+ return torch.float
234
+ if p == "fp16":
235
+ return torch.float16
236
+ if p == "bf16":
237
+ return torch.bfloat16
238
+ return None
239
+
240
+ merge_dtype = str_to_dtype(args.precision)
241
+ save_dtype = str_to_dtype(args.save_precision)
242
+ if save_dtype is None:
243
+ save_dtype = merge_dtype
244
+
245
+ if args.sd_model is not None:
246
+ print(f"loading SD model: {args.sd_model}")
247
+
248
+ (
249
+ text_model1,
250
+ text_model2,
251
+ vae,
252
+ unet,
253
+ logit_scale,
254
+ ckpt_info,
255
+ ) = sdxl_model_util.load_models_from_sdxl_checkpoint(sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, args.sd_model, "cpu")
256
+
257
+ merge_to_sd_model(text_model1, text_model2, unet, args.models, args.ratios, merge_dtype)
258
+
259
+ if args.no_metadata:
260
+ sai_metadata = None
261
+ else:
262
+ merged_from = sai_model_spec.build_merged_from([args.sd_model] + args.models)
263
+ title = os.path.splitext(os.path.basename(args.save_to))[0]
264
+ sai_metadata = sai_model_spec.build_metadata(
265
+ None, False, False, True, False, False, time.time(), title=title, merged_from=merged_from
266
+ )
267
+
268
+ print(f"saving SD model to: {args.save_to}")
269
+ sdxl_model_util.save_stable_diffusion_checkpoint(
270
+ args.save_to, text_model1, text_model2, unet, 0, 0, ckpt_info, vae, logit_scale, sai_metadata, save_dtype
271
+ )
272
+ else:
273
+ state_dict, metadata = merge_lora_models(args.models, args.ratios, merge_dtype, args.concat, args.shuffle)
274
+
275
+ print(f"calculating hashes and creating metadata...")
276
+
277
+ model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(state_dict, metadata)
278
+ metadata["sshs_model_hash"] = model_hash
279
+ metadata["sshs_legacy_hash"] = legacy_hash
280
+
281
+ if not args.no_metadata:
282
+ merged_from = sai_model_spec.build_merged_from(args.models)
283
+ title = os.path.splitext(os.path.basename(args.save_to))[0]
284
+ sai_metadata = sai_model_spec.build_metadata(
285
+ state_dict, False, False, True, True, False, time.time(), title=title, merged_from=merged_from
286
+ )
287
+ metadata.update(sai_metadata)
288
+
289
+ print(f"saving model to: {args.save_to}")
290
+ save_to_file(args.save_to, state_dict, state_dict, save_dtype, metadata)
291
+
292
+
293
+ def setup_parser() -> argparse.ArgumentParser:
294
+ parser = argparse.ArgumentParser()
295
+ parser.add_argument(
296
+ "--save_precision",
297
+ type=str,
298
+ default=None,
299
+ choices=[None, "float", "fp16", "bf16"],
300
+ help="precision in saving, same to merging if omitted / 保存時に精度を変更して保存する、省略時はマージ時の精度と同じ",
301
+ )
302
+ parser.add_argument(
303
+ "--precision",
304
+ type=str,
305
+ default="float",
306
+ choices=["float", "fp16", "bf16"],
307
+ help="precision in merging (float is recommended) / マージの計算時の精度(floatを推奨)",
308
+ )
309
+ parser.add_argument(
310
+ "--sd_model",
311
+ type=str,
312
+ default=None,
313
+ help="Stable Diffusion model to load: ckpt or safetensors file, merge LoRA models if omitted / 読み込むモデル、ckptまたはsafetensors。省略時はLoRAモデル同士をマージする",
314
+ )
315
+ parser.add_argument(
316
+ "--save_to", type=str, default=None, help="destination file name: ckpt or safetensors file / 保存先のファイル名、ckptまたはsafetensors"
317
+ )
318
+ parser.add_argument(
319
+ "--models", type=str, nargs="*", help="LoRA models to merge: ckpt or safetensors file / マージするLoRAモデル、ckptまたはsafetensors"
320
+ )
321
+ parser.add_argument("--ratios", type=float, nargs="*", help="ratios for each model / それぞれのLoRAモデルの比率")
322
+ parser.add_argument(
323
+ "--no_metadata",
324
+ action="store_true",
325
+ help="do not save sai modelspec metadata (minimum ss_metadata for LoRA is saved) / "
326
+ + "sai modelspecのメタデータを保存しない(LoRAの最低限のss_metadataは保存される)",
327
+ )
328
+ parser.add_argument(
329
+ "--concat",
330
+ action="store_true",
331
+ help="concat lora instead of merge (The dim(rank) of the output LoRA is the sum of the input dims) / "
332
+ + "マージの代わりに結合する(LoRAのdim(rank)は入力dimの合計になる)",
333
+ )
334
+ parser.add_argument(
335
+ "--shuffle",
336
+ action="store_true",
337
+ help="shuffle lora weight./ "
338
+ + "LoRAの重みをシャッフルする",
339
+ )
340
+
341
+ return parser
342
+
343
+
344
+ if __name__ == "__main__":
345
+ parser = setup_parser()
346
+
347
+ args = parser.parse_args()
348
+ merge(args)
external/llite/networks/svd_merge_lora.py ADDED
@@ -0,0 +1,260 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import argparse
3
+ import os
4
+ import time
5
+ import torch
6
+ from safetensors.torch import load_file, save_file
7
+ from tqdm import tqdm
8
+ from library import sai_model_spec, train_util
9
+ import library.model_util as model_util
10
+ import lora
11
+
12
+
13
+ CLAMP_QUANTILE = 0.99
14
+
15
+
16
+ def load_state_dict(file_name, dtype):
17
+ if os.path.splitext(file_name)[1] == ".safetensors":
18
+ sd = load_file(file_name)
19
+ metadata = train_util.load_metadata_from_safetensors(file_name)
20
+ else:
21
+ sd = torch.load(file_name, map_location="cpu")
22
+ metadata = {}
23
+
24
+ for key in list(sd.keys()):
25
+ if type(sd[key]) == torch.Tensor:
26
+ sd[key] = sd[key].to(dtype)
27
+
28
+ return sd, metadata
29
+
30
+
31
+ def save_to_file(file_name, state_dict, dtype, metadata):
32
+ if dtype is not None:
33
+ for key in list(state_dict.keys()):
34
+ if type(state_dict[key]) == torch.Tensor:
35
+ state_dict[key] = state_dict[key].to(dtype)
36
+
37
+ if os.path.splitext(file_name)[1] == ".safetensors":
38
+ save_file(state_dict, file_name, metadata=metadata)
39
+ else:
40
+ torch.save(state_dict, file_name)
41
+
42
+
43
+ def merge_lora_models(models, ratios, new_rank, new_conv_rank, device, merge_dtype):
44
+ print(f"new rank: {new_rank}, new conv rank: {new_conv_rank}")
45
+ merged_sd = {}
46
+ v2 = None
47
+ base_model = None
48
+ for model, ratio in zip(models, ratios):
49
+ print(f"loading: {model}")
50
+ lora_sd, lora_metadata = load_state_dict(model, merge_dtype)
51
+
52
+ if lora_metadata is not None:
53
+ if v2 is None:
54
+ v2 = lora_metadata.get(train_util.SS_METADATA_KEY_V2, None) # return string
55
+ if base_model is None:
56
+ base_model = lora_metadata.get(train_util.SS_METADATA_KEY_BASE_MODEL_VERSION, None)
57
+
58
+ # merge
59
+ print(f"merging...")
60
+ for key in tqdm(list(lora_sd.keys())):
61
+ if "lora_down" not in key:
62
+ continue
63
+
64
+ lora_module_name = key[: key.rfind(".lora_down")]
65
+
66
+ down_weight = lora_sd[key]
67
+ network_dim = down_weight.size()[0]
68
+
69
+ up_weight = lora_sd[lora_module_name + ".lora_up.weight"]
70
+ alpha = lora_sd.get(lora_module_name + ".alpha", network_dim)
71
+
72
+ in_dim = down_weight.size()[1]
73
+ out_dim = up_weight.size()[0]
74
+ conv2d = len(down_weight.size()) == 4
75
+ kernel_size = None if not conv2d else down_weight.size()[2:4]
76
+ # print(lora_module_name, network_dim, alpha, in_dim, out_dim, kernel_size)
77
+
78
+ # make original weight if not exist
79
+ if lora_module_name not in merged_sd:
80
+ weight = torch.zeros((out_dim, in_dim, *kernel_size) if conv2d else (out_dim, in_dim), dtype=merge_dtype)
81
+ if device:
82
+ weight = weight.to(device)
83
+ else:
84
+ weight = merged_sd[lora_module_name]
85
+
86
+ # merge to weight
87
+ if device:
88
+ up_weight = up_weight.to(device)
89
+ down_weight = down_weight.to(device)
90
+
91
+ # W <- W + U * D
92
+ scale = alpha / network_dim
93
+
94
+ if device: # and isinstance(scale, torch.Tensor):
95
+ scale = scale.to(device)
96
+
97
+ if not conv2d: # linear
98
+ weight = weight + ratio * (up_weight @ down_weight) * scale
99
+ elif kernel_size == (1, 1):
100
+ weight = (
101
+ weight
102
+ + ratio
103
+ * (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3)
104
+ * scale
105
+ )
106
+ else:
107
+ conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3)
108
+ weight = weight + ratio * conved * scale
109
+
110
+ merged_sd[lora_module_name] = weight
111
+
112
+ # extract from merged weights
113
+ print("extract new lora...")
114
+ merged_lora_sd = {}
115
+ with torch.no_grad():
116
+ for lora_module_name, mat in tqdm(list(merged_sd.items())):
117
+ conv2d = len(mat.size()) == 4
118
+ kernel_size = None if not conv2d else mat.size()[2:4]
119
+ conv2d_3x3 = conv2d and kernel_size != (1, 1)
120
+ out_dim, in_dim = mat.size()[0:2]
121
+
122
+ if conv2d:
123
+ if conv2d_3x3:
124
+ mat = mat.flatten(start_dim=1)
125
+ else:
126
+ mat = mat.squeeze()
127
+
128
+ module_new_rank = new_conv_rank if conv2d_3x3 else new_rank
129
+ module_new_rank = min(module_new_rank, in_dim, out_dim) # LoRA rank cannot exceed the original dim
130
+
131
+ U, S, Vh = torch.linalg.svd(mat)
132
+
133
+ U = U[:, :module_new_rank]
134
+ S = S[:module_new_rank]
135
+ U = U @ torch.diag(S)
136
+
137
+ Vh = Vh[:module_new_rank, :]
138
+
139
+ dist = torch.cat([U.flatten(), Vh.flatten()])
140
+ hi_val = torch.quantile(dist, CLAMP_QUANTILE)
141
+ low_val = -hi_val
142
+
143
+ U = U.clamp(low_val, hi_val)
144
+ Vh = Vh.clamp(low_val, hi_val)
145
+
146
+ if conv2d:
147
+ U = U.reshape(out_dim, module_new_rank, 1, 1)
148
+ Vh = Vh.reshape(module_new_rank, in_dim, kernel_size[0], kernel_size[1])
149
+
150
+ up_weight = U
151
+ down_weight = Vh
152
+
153
+ merged_lora_sd[lora_module_name + ".lora_up.weight"] = up_weight.to("cpu").contiguous()
154
+ merged_lora_sd[lora_module_name + ".lora_down.weight"] = down_weight.to("cpu").contiguous()
155
+ merged_lora_sd[lora_module_name + ".alpha"] = torch.tensor(module_new_rank)
156
+
157
+ # build minimum metadata
158
+ dims = f"{new_rank}"
159
+ alphas = f"{new_rank}"
160
+ if new_conv_rank is not None:
161
+ network_args = {"conv_dim": new_conv_rank, "conv_alpha": new_conv_rank}
162
+ else:
163
+ network_args = None
164
+ metadata = train_util.build_minimum_network_metadata(v2, base_model, "networks.lora", dims, alphas, network_args)
165
+
166
+ return merged_lora_sd, metadata, v2 == "True", base_model
167
+
168
+
169
+ def merge(args):
170
+ assert len(args.models) == len(args.ratios), f"number of models must be equal to number of ratios / モデルの数と重みの数は合わせてください"
171
+
172
+ def str_to_dtype(p):
173
+ if p == "float":
174
+ return torch.float
175
+ if p == "fp16":
176
+ return torch.float16
177
+ if p == "bf16":
178
+ return torch.bfloat16
179
+ return None
180
+
181
+ merge_dtype = str_to_dtype(args.precision)
182
+ save_dtype = str_to_dtype(args.save_precision)
183
+ if save_dtype is None:
184
+ save_dtype = merge_dtype
185
+
186
+ new_conv_rank = args.new_conv_rank if args.new_conv_rank is not None else args.new_rank
187
+ state_dict, metadata, v2, base_model = merge_lora_models(
188
+ args.models, args.ratios, args.new_rank, new_conv_rank, args.device, merge_dtype
189
+ )
190
+
191
+ print(f"calculating hashes and creating metadata...")
192
+
193
+ model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(state_dict, metadata)
194
+ metadata["sshs_model_hash"] = model_hash
195
+ metadata["sshs_legacy_hash"] = legacy_hash
196
+
197
+ if not args.no_metadata:
198
+ is_sdxl = base_model is not None and base_model.lower().startswith("sdxl")
199
+ merged_from = sai_model_spec.build_merged_from(args.models)
200
+ title = os.path.splitext(os.path.basename(args.save_to))[0]
201
+ sai_metadata = sai_model_spec.build_metadata(
202
+ state_dict, v2, v2, is_sdxl, True, False, time.time(), title=title, merged_from=merged_from
203
+ )
204
+ if v2:
205
+ # TODO read sai modelspec
206
+ print(
207
+ "Cannot determine if LoRA is for v-prediction, so save metadata as v-prediction / LoRAがv-prediction用か否か不明なため、仮にv-prediction用としてmetadataを保存します"
208
+ )
209
+ metadata.update(sai_metadata)
210
+
211
+ print(f"saving model to: {args.save_to}")
212
+ save_to_file(args.save_to, state_dict, save_dtype, metadata)
213
+
214
+
215
+ def setup_parser() -> argparse.ArgumentParser:
216
+ parser = argparse.ArgumentParser()
217
+ parser.add_argument(
218
+ "--save_precision",
219
+ type=str,
220
+ default=None,
221
+ choices=[None, "float", "fp16", "bf16"],
222
+ help="precision in saving, same to merging if omitted / 保存時に精度を変更して保存する、省略時はマージ時の精度と同じ",
223
+ )
224
+ parser.add_argument(
225
+ "--precision",
226
+ type=str,
227
+ default="float",
228
+ choices=["float", "fp16", "bf16"],
229
+ help="precision in merging (float is recommended) / マージの計算時の精度(floatを推奨)",
230
+ )
231
+ parser.add_argument(
232
+ "--save_to", type=str, default=None, help="destination file name: ckpt or safetensors file / 保存先のファイル名、ckptまたはsafetensors"
233
+ )
234
+ parser.add_argument(
235
+ "--models", type=str, nargs="*", help="LoRA models to merge: ckpt or safetensors file / マージするLoRAモデル、ckptまたはsafetensors"
236
+ )
237
+ parser.add_argument("--ratios", type=float, nargs="*", help="ratios for each model / それぞれのLoRAモデルの比率")
238
+ parser.add_argument("--new_rank", type=int, default=4, help="Specify rank of output LoRA / 出力するLoRAのrank (dim)")
239
+ parser.add_argument(
240
+ "--new_conv_rank",
241
+ type=int,
242
+ default=None,
243
+ help="Specify rank of output LoRA for Conv2d 3x3, None for same as new_rank / 出力するConv2D 3x3 LoRAのrank (dim)、Noneでnew_rankと同じ",
244
+ )
245
+ parser.add_argument("--device", type=str, default=None, help="device to use, cuda for GPU / 計算を行うデバイス、cuda でGPUを使う")
246
+ parser.add_argument(
247
+ "--no_metadata",
248
+ action="store_true",
249
+ help="do not save sai modelspec metadata (minimum ss_metadata for LoRA is saved) / "
250
+ + "sai modelspecのメタデータを保存しない(LoRAの最低限のss_metadataは保存される)",
251
+ )
252
+
253
+ return parser
254
+
255
+
256
+ if __name__ == "__main__":
257
+ parser = setup_parser()
258
+
259
+ args = parser.parse_args()
260
+ merge(args)
external/llite/tools/cache_latents.py ADDED
@@ -0,0 +1,194 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # latentsのdiskへの事前キャッシュを行う / cache latents to disk
2
+
3
+ import argparse
4
+ import math
5
+ from multiprocessing import Value
6
+ import os
7
+
8
+ from accelerate.utils import set_seed
9
+ import torch
10
+ from tqdm import tqdm
11
+
12
+ from library import config_util
13
+ from library import train_util
14
+ from library import sdxl_train_util
15
+ from library.config_util import (
16
+ ConfigSanitizer,
17
+ BlueprintGenerator,
18
+ )
19
+
20
+
21
+ def cache_to_disk(args: argparse.Namespace) -> None:
22
+ train_util.prepare_dataset_args(args, True)
23
+
24
+ # check cache latents arg
25
+ assert args.cache_latents_to_disk, "cache_latents_to_disk must be True / cache_latents_to_diskはTrueである必要があります"
26
+
27
+ use_dreambooth_method = args.in_json is None
28
+
29
+ if args.seed is not None:
30
+ set_seed(args.seed) # 乱数系列を初期化する
31
+
32
+ # tokenizerを準備する:datasetを動かすために必要
33
+ if args.sdxl:
34
+ tokenizer1, tokenizer2 = sdxl_train_util.load_tokenizers(args)
35
+ tokenizers = [tokenizer1, tokenizer2]
36
+ else:
37
+ tokenizer = train_util.load_tokenizer(args)
38
+ tokenizers = [tokenizer]
39
+
40
+ # データセットを準備する
41
+ if args.dataset_class is None:
42
+ blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, False, True))
43
+ if args.dataset_config is not None:
44
+ print(f"Load dataset config from {args.dataset_config}")
45
+ user_config = config_util.load_user_config(args.dataset_config)
46
+ ignored = ["train_data_dir", "in_json"]
47
+ if any(getattr(args, attr) is not None for attr in ignored):
48
+ print(
49
+ "ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(
50
+ ", ".join(ignored)
51
+ )
52
+ )
53
+ else:
54
+ if use_dreambooth_method:
55
+ print("Using DreamBooth method.")
56
+ user_config = {
57
+ "datasets": [
58
+ {
59
+ "subsets": config_util.generate_dreambooth_subsets_config_by_subdirs(
60
+ args.train_data_dir, args.reg_data_dir
61
+ )
62
+ }
63
+ ]
64
+ }
65
+ else:
66
+ print("Training with captions.")
67
+ user_config = {
68
+ "datasets": [
69
+ {
70
+ "subsets": [
71
+ {
72
+ "image_dir": args.train_data_dir,
73
+ "metadata_file": args.in_json,
74
+ }
75
+ ]
76
+ }
77
+ ]
78
+ }
79
+
80
+ blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizers)
81
+ train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
82
+ else:
83
+ train_dataset_group = train_util.load_arbitrary_dataset(args, tokenizers)
84
+
85
+ # datasetのcache_latentsを呼ばなければ、生の画像が返る
86
+
87
+ current_epoch = Value("i", 0)
88
+ current_step = Value("i", 0)
89
+ ds_for_collator = train_dataset_group if args.max_data_loader_n_workers == 0 else None
90
+ collator = train_util.collator_class(current_epoch, current_step, ds_for_collator)
91
+
92
+ # acceleratorを準備する
93
+ print("prepare accelerator")
94
+ accelerator = train_util.prepare_accelerator(args)
95
+
96
+ # mixed precisionに対応した型を用意しておき適宜castする
97
+ weight_dtype, _ = train_util.prepare_dtype(args)
98
+ vae_dtype = torch.float32 if args.no_half_vae else weight_dtype
99
+
100
+ # モデルを読み込む
101
+ print("load model")
102
+ if args.sdxl:
103
+ (_, _, _, vae, _, _, _) = sdxl_train_util.load_target_model(args, accelerator, "sdxl", weight_dtype)
104
+ else:
105
+ _, vae, _, _ = train_util.load_target_model(args, weight_dtype, accelerator)
106
+
107
+ if torch.__version__ >= "2.0.0": # PyTorch 2.0.0 以上対応のxformersなら以下が使える
108
+ vae.set_use_memory_efficient_attention_xformers(args.xformers)
109
+ vae.to(accelerator.device, dtype=vae_dtype)
110
+ vae.requires_grad_(False)
111
+ vae.eval()
112
+
113
+ # dataloaderを準備する
114
+ train_dataset_group.set_caching_mode("latents")
115
+
116
+ # DataLoaderのプロセス数:0はメインプロセスになる
117
+ n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1) # cpu_count-1 ただし最大で指定された数まで
118
+
119
+ train_dataloader = torch.utils.data.DataLoader(
120
+ train_dataset_group,
121
+ batch_size=1,
122
+ shuffle=True,
123
+ collate_fn=collator,
124
+ num_workers=n_workers,
125
+ persistent_workers=args.persistent_data_loader_workers,
126
+ )
127
+
128
+ # acceleratorを使ってモデルを準備する:マルチGPUで使えるようになるはず
129
+ train_dataloader = accelerator.prepare(train_dataloader)
130
+
131
+ # データ取得のためのループ
132
+ for batch in tqdm(train_dataloader):
133
+ b_size = len(batch["images"])
134
+ vae_batch_size = b_size if args.vae_batch_size is None else args.vae_batch_size
135
+ flip_aug = batch["flip_aug"]
136
+ random_crop = batch["random_crop"]
137
+ bucket_reso = batch["bucket_reso"]
138
+
139
+ # バッチを分割して処理する
140
+ for i in range(0, b_size, vae_batch_size):
141
+ images = batch["images"][i : i + vae_batch_size]
142
+ absolute_paths = batch["absolute_paths"][i : i + vae_batch_size]
143
+ resized_sizes = batch["resized_sizes"][i : i + vae_batch_size]
144
+
145
+ image_infos = []
146
+ for i, (image, absolute_path, resized_size) in enumerate(zip(images, absolute_paths, resized_sizes)):
147
+ image_info = train_util.ImageInfo(absolute_path, 1, "dummy", False, absolute_path)
148
+ image_info.image = image
149
+ image_info.bucket_reso = bucket_reso
150
+ image_info.resized_size = resized_size
151
+ image_info.latents_npz = os.path.splitext(absolute_path)[0] + ".npz"
152
+
153
+ if args.skip_existing:
154
+ if train_util.is_disk_cached_latents_is_expected(image_info.bucket_reso, image_info.latents_npz, flip_aug):
155
+ print(f"Skipping {image_info.latents_npz} because it already exists.")
156
+ continue
157
+
158
+ image_infos.append(image_info)
159
+
160
+ if len(image_infos) > 0:
161
+ train_util.cache_batch_latents(vae, True, image_infos, flip_aug, random_crop)
162
+
163
+ accelerator.wait_for_everyone()
164
+ accelerator.print(f"Finished caching latents for {len(train_dataset_group)} batches.")
165
+
166
+
167
+ def setup_parser() -> argparse.ArgumentParser:
168
+ parser = argparse.ArgumentParser()
169
+
170
+ train_util.add_sd_models_arguments(parser)
171
+ train_util.add_training_arguments(parser, True)
172
+ train_util.add_dataset_arguments(parser, True, True, True)
173
+ config_util.add_config_arguments(parser)
174
+ parser.add_argument("--sdxl", action="store_true", help="Use SDXL model / SDXLモデルを使用する")
175
+ parser.add_argument(
176
+ "--no_half_vae",
177
+ action="store_true",
178
+ help="do not use fp16/bf16 VAE in mixed precision (use float VAE) / mixed precisionでも fp16/bf16 VAEを使わずfloat VAEを使う",
179
+ )
180
+ parser.add_argument(
181
+ "--skip_existing",
182
+ action="store_true",
183
+ help="skip images if npz already exists (both normal and flipped exists if flip_aug is enabled) / npzが既に存在する画像をスキップする(flip_aug有効時は通常、反転の両方が存在する画像をスキップ)",
184
+ )
185
+ return parser
186
+
187
+
188
+ if __name__ == "__main__":
189
+ parser = setup_parser()
190
+
191
+ args = parser.parse_args()
192
+ args = train_util.read_config_from_file(args, parser)
193
+
194
+ cache_to_disk(args)
external/llite/tools/cache_text_encoder_outputs.py ADDED
@@ -0,0 +1,191 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # text encoder出力のdiskへの事前キャッシュを行う / cache text encoder outputs to disk in advance
2
+
3
+ import argparse
4
+ import math
5
+ from multiprocessing import Value
6
+ import os
7
+
8
+ from accelerate.utils import set_seed
9
+ import torch
10
+ from tqdm import tqdm
11
+
12
+ from library import config_util
13
+ from library import train_util
14
+ from library import sdxl_train_util
15
+ from library.config_util import (
16
+ ConfigSanitizer,
17
+ BlueprintGenerator,
18
+ )
19
+
20
+
21
+ def cache_to_disk(args: argparse.Namespace) -> None:
22
+ train_util.prepare_dataset_args(args, True)
23
+
24
+ # check cache arg
25
+ assert (
26
+ args.cache_text_encoder_outputs_to_disk
27
+ ), "cache_text_encoder_outputs_to_disk must be True / cache_text_encoder_outputs_to_diskはTrueである必要があります"
28
+
29
+ # できるだけ準備はしておくが今のところSDXLのみしか動かない
30
+ assert (
31
+ args.sdxl
32
+ ), "cache_text_encoder_outputs_to_disk is only available for SDXL / cache_text_encoder_outputs_to_diskはSDXLのみ利用可能です"
33
+
34
+ use_dreambooth_method = args.in_json is None
35
+
36
+ if args.seed is not None:
37
+ set_seed(args.seed) # 乱数系列を初期化する
38
+
39
+ # tokenizerを準備する:datasetを動かすために必要
40
+ if args.sdxl:
41
+ tokenizer1, tokenizer2 = sdxl_train_util.load_tokenizers(args)
42
+ tokenizers = [tokenizer1, tokenizer2]
43
+ else:
44
+ tokenizer = train_util.load_tokenizer(args)
45
+ tokenizers = [tokenizer]
46
+
47
+ # データセットを準備する
48
+ if args.dataset_class is None:
49
+ blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, False, True))
50
+ if args.dataset_config is not None:
51
+ print(f"Load dataset config from {args.dataset_config}")
52
+ user_config = config_util.load_user_config(args.dataset_config)
53
+ ignored = ["train_data_dir", "in_json"]
54
+ if any(getattr(args, attr) is not None for attr in ignored):
55
+ print(
56
+ "ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(
57
+ ", ".join(ignored)
58
+ )
59
+ )
60
+ else:
61
+ if use_dreambooth_method:
62
+ print("Using DreamBooth method.")
63
+ user_config = {
64
+ "datasets": [
65
+ {
66
+ "subsets": config_util.generate_dreambooth_subsets_config_by_subdirs(
67
+ args.train_data_dir, args.reg_data_dir
68
+ )
69
+ }
70
+ ]
71
+ }
72
+ else:
73
+ print("Training with captions.")
74
+ user_config = {
75
+ "datasets": [
76
+ {
77
+ "subsets": [
78
+ {
79
+ "image_dir": args.train_data_dir,
80
+ "metadata_file": args.in_json,
81
+ }
82
+ ]
83
+ }
84
+ ]
85
+ }
86
+
87
+ blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizers)
88
+ train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
89
+ else:
90
+ train_dataset_group = train_util.load_arbitrary_dataset(args, tokenizers)
91
+
92
+ current_epoch = Value("i", 0)
93
+ current_step = Value("i", 0)
94
+ ds_for_collator = train_dataset_group if args.max_data_loader_n_workers == 0 else None
95
+ collator = train_util.collator_class(current_epoch, current_step, ds_for_collator)
96
+
97
+ # acceleratorを準備する
98
+ print("prepare accelerator")
99
+ accelerator = train_util.prepare_accelerator(args)
100
+
101
+ # mixed precisionに対応した型を用意しておき適宜castする
102
+ weight_dtype, _ = train_util.prepare_dtype(args)
103
+
104
+ # モデルを読み込む
105
+ print("load model")
106
+ if args.sdxl:
107
+ (_, text_encoder1, text_encoder2, _, _, _, _) = sdxl_train_util.load_target_model(args, accelerator, "sdxl", weight_dtype)
108
+ text_encoders = [text_encoder1, text_encoder2]
109
+ else:
110
+ text_encoder1, _, _, _ = train_util.load_target_model(args, weight_dtype, accelerator)
111
+ text_encoders = [text_encoder1]
112
+
113
+ for text_encoder in text_encoders:
114
+ text_encoder.to(accelerator.device, dtype=weight_dtype)
115
+ text_encoder.requires_grad_(False)
116
+ text_encoder.eval()
117
+
118
+ # dataloaderを準備する
119
+ train_dataset_group.set_caching_mode("text")
120
+
121
+ # DataLoaderのプロセス数:0はメインプロセスになる
122
+ n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1) # cpu_count-1 ただし最大で指定された数まで
123
+
124
+ train_dataloader = torch.utils.data.DataLoader(
125
+ train_dataset_group,
126
+ batch_size=1,
127
+ shuffle=True,
128
+ collate_fn=collator,
129
+ num_workers=n_workers,
130
+ persistent_workers=args.persistent_data_loader_workers,
131
+ )
132
+
133
+ # acceleratorを使ってモデルを準備する:マルチGPUで使えるようになるはず
134
+ train_dataloader = accelerator.prepare(train_dataloader)
135
+
136
+ # データ取得のためのループ
137
+ for batch in tqdm(train_dataloader):
138
+ absolute_paths = batch["absolute_paths"]
139
+ input_ids1_list = batch["input_ids1_list"]
140
+ input_ids2_list = batch["input_ids2_list"]
141
+
142
+ image_infos = []
143
+ for absolute_path, input_ids1, input_ids2 in zip(absolute_paths, input_ids1_list, input_ids2_list):
144
+ image_info = train_util.ImageInfo(absolute_path, 1, "dummy", False, absolute_path)
145
+ image_info.text_encoder_outputs_npz = os.path.splitext(absolute_path)[0] + train_util.TEXT_ENCODER_OUTPUTS_CACHE_SUFFIX
146
+ image_info
147
+
148
+ if args.skip_existing:
149
+ if os.path.exists(image_info.text_encoder_outputs_npz):
150
+ print(f"Skipping {image_info.text_encoder_outputs_npz} because it already exists.")
151
+ continue
152
+
153
+ image_info.input_ids1 = input_ids1
154
+ image_info.input_ids2 = input_ids2
155
+ image_infos.append(image_info)
156
+
157
+ if len(image_infos) > 0:
158
+ b_input_ids1 = torch.stack([image_info.input_ids1 for image_info in image_infos])
159
+ b_input_ids2 = torch.stack([image_info.input_ids2 for image_info in image_infos])
160
+ train_util.cache_batch_text_encoder_outputs(
161
+ image_infos, tokenizers, text_encoders, args.max_token_length, True, b_input_ids1, b_input_ids2, weight_dtype
162
+ )
163
+
164
+ accelerator.wait_for_everyone()
165
+ accelerator.print(f"Finished caching latents for {len(train_dataset_group)} batches.")
166
+
167
+
168
+ def setup_parser() -> argparse.ArgumentParser:
169
+ parser = argparse.ArgumentParser()
170
+
171
+ train_util.add_sd_models_arguments(parser)
172
+ train_util.add_training_arguments(parser, True)
173
+ train_util.add_dataset_arguments(parser, True, True, True)
174
+ config_util.add_config_arguments(parser)
175
+ sdxl_train_util.add_sdxl_training_arguments(parser)
176
+ parser.add_argument("--sdxl", action="store_true", help="Use SDXL model / SDXLモデルを使用する")
177
+ parser.add_argument(
178
+ "--skip_existing",
179
+ action="store_true",
180
+ help="skip images if npz already exists (both normal and flipped exists if flip_aug is enabled) / npzが既に存在する画像をスキップする(flip_aug有効時は通常、反転の両方が存在する画像をスキップ)",
181
+ )
182
+ return parser
183
+
184
+
185
+ if __name__ == "__main__":
186
+ parser = setup_parser()
187
+
188
+ args = parser.parse_args()
189
+ args = train_util.read_config_from_file(args, parser)
190
+
191
+ cache_to_disk(args)
external/llite/tools/canny.py ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import cv2
3
+
4
+
5
+ def canny(args):
6
+ img = cv2.imread(args.input)
7
+ img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
8
+
9
+ canny_img = cv2.Canny(img, args.thres1, args.thres2)
10
+ # canny_img = 255 - canny_img
11
+
12
+ cv2.imwrite(args.output, canny_img)
13
+ print("done!")
14
+
15
+
16
+ def setup_parser() -> argparse.ArgumentParser:
17
+ parser = argparse.ArgumentParser()
18
+ parser.add_argument("--input", type=str, default=None, help="input path")
19
+ parser.add_argument("--output", type=str, default=None, help="output path")
20
+ parser.add_argument("--thres1", type=int, default=32, help="thres1")
21
+ parser.add_argument("--thres2", type=int, default=224, help="thres2")
22
+
23
+ return parser
24
+
25
+
26
+ if __name__ == '__main__':
27
+ parser = setup_parser()
28
+
29
+ args = parser.parse_args()
30
+ canny(args)
external/llite/tools/convert_diffusers20_original_sd.py ADDED
@@ -0,0 +1,160 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # convert Diffusers v1.x/v2.0 model to original Stable Diffusion
2
+
3
+ import argparse
4
+ import os
5
+ import torch
6
+ from diffusers import StableDiffusionPipeline
7
+
8
+ import library.model_util as model_util
9
+
10
+
11
+ def convert(args):
12
+ # 引数を確認する
13
+ load_dtype = torch.float16 if args.fp16 else None
14
+
15
+ save_dtype = None
16
+ if args.fp16 or args.save_precision_as == "fp16":
17
+ save_dtype = torch.float16
18
+ elif args.bf16 or args.save_precision_as == "bf16":
19
+ save_dtype = torch.bfloat16
20
+ elif args.float or args.save_precision_as == "float":
21
+ save_dtype = torch.float
22
+
23
+ is_load_ckpt = os.path.isfile(args.model_to_load)
24
+ is_save_ckpt = len(os.path.splitext(args.model_to_save)[1]) > 0
25
+
26
+ assert not is_load_ckpt or args.v1 != args.v2, "v1 or v2 is required to load checkpoint / checkpointの読み込みにはv1/v2指定が必要です"
27
+ # assert (
28
+ # is_save_ckpt or args.reference_model is not None
29
+ # ), f"reference model is required to save as Diffusers / Diffusers形式での保存には参照モデルが必要です"
30
+
31
+ # モデルを読み込む
32
+ msg = "checkpoint" if is_load_ckpt else ("Diffusers" + (" as fp16" if args.fp16 else ""))
33
+ print(f"loading {msg}: {args.model_to_load}")
34
+
35
+ if is_load_ckpt:
36
+ v2_model = args.v2
37
+ text_encoder, vae, unet = model_util.load_models_from_stable_diffusion_checkpoint(
38
+ v2_model, args.model_to_load, unet_use_linear_projection_in_v2=args.unet_use_linear_projection
39
+ )
40
+ else:
41
+ pipe = StableDiffusionPipeline.from_pretrained(
42
+ args.model_to_load, torch_dtype=load_dtype, tokenizer=None, safety_checker=None, variant=args.variant
43
+ )
44
+ text_encoder = pipe.text_encoder
45
+ vae = pipe.vae
46
+ unet = pipe.unet
47
+
48
+ if args.v1 == args.v2:
49
+ # 自動判定する
50
+ v2_model = unet.config.cross_attention_dim == 1024
51
+ print("checking model version: model is " + ("v2" if v2_model else "v1"))
52
+ else:
53
+ v2_model = not args.v1
54
+
55
+ # 変換して保存する
56
+ msg = ("checkpoint" + ("" if save_dtype is None else f" in {save_dtype}")) if is_save_ckpt else "Diffusers"
57
+ print(f"converting and saving as {msg}: {args.model_to_save}")
58
+
59
+ if is_save_ckpt:
60
+ original_model = args.model_to_load if is_load_ckpt else None
61
+ key_count = model_util.save_stable_diffusion_checkpoint(
62
+ v2_model,
63
+ args.model_to_save,
64
+ text_encoder,
65
+ unet,
66
+ original_model,
67
+ args.epoch,
68
+ args.global_step,
69
+ None if args.metadata is None else eval(args.metadata),
70
+ save_dtype=save_dtype,
71
+ vae=vae,
72
+ )
73
+ print(f"model saved. total converted state_dict keys: {key_count}")
74
+ else:
75
+ print(
76
+ f"copy scheduler/tokenizer config from: {args.reference_model if args.reference_model is not None else 'default model'}"
77
+ )
78
+ model_util.save_diffusers_checkpoint(
79
+ v2_model, args.model_to_save, text_encoder, unet, args.reference_model, vae, args.use_safetensors
80
+ )
81
+ print("model saved.")
82
+
83
+
84
+ def setup_parser() -> argparse.ArgumentParser:
85
+ parser = argparse.ArgumentParser()
86
+ parser.add_argument(
87
+ "--v1", action="store_true", help="load v1.x model (v1 or v2 is required to load checkpoint) / 1.xのモデルを読み込む"
88
+ )
89
+ parser.add_argument(
90
+ "--v2", action="store_true", help="load v2.0 model (v1 or v2 is required to load checkpoint) / 2.0のモデルを読み込む"
91
+ )
92
+ parser.add_argument(
93
+ "--unet_use_linear_projection",
94
+ action="store_true",
95
+ help="When saving v2 model as Diffusers, set U-Net config to `use_linear_projection=true` (to match stabilityai's model) / Diffusers形式でv2モデルを保存するときにU-Netの設定を`use_linear_projection=true`にする(stabilityaiのモデルと合わせる)",
96
+ )
97
+ parser.add_argument(
98
+ "--fp16",
99
+ action="store_true",
100
+ help="load as fp16 (Diffusers only) and save as fp16 (checkpoint only) / fp16形式で読み込み(Diffusers形式のみ対応)、保存する(checkpointのみ対応)",
101
+ )
102
+ parser.add_argument("--bf16", action="store_true", help="save as bf16 (checkpoint only) / bf16形式で保存する(checkpointのみ対応)")
103
+ parser.add_argument(
104
+ "--float", action="store_true", help="save as float (checkpoint only) / float(float32)形式で保存する(checkpointのみ対応)"
105
+ )
106
+ parser.add_argument(
107
+ "--save_precision_as",
108
+ type=str,
109
+ default="no",
110
+ choices=["fp16", "bf16", "float"],
111
+ help="save precision, do not specify with --fp16/--bf16/--float / 保存する精度、--fp16/--bf16/--floatと併用しないでください",
112
+ )
113
+ parser.add_argument("--epoch", type=int, default=0, help="epoch to write to checkpoint / checkpointに記録するepoch数の値")
114
+ parser.add_argument(
115
+ "--global_step", type=int, default=0, help="global_step to write to checkpoint / checkpointに記録するglobal_stepの値"
116
+ )
117
+ parser.add_argument(
118
+ "--metadata",
119
+ type=str,
120
+ default=None,
121
+ help='モデルに保存されるメタデータ、Pythonの辞書形式で指定 / metadata: metadata written in to the model in Python Dictionary. Example metadata: \'{"name": "model_name", "resolution": "512x512"}\'',
122
+ )
123
+ parser.add_argument(
124
+ "--variant",
125
+ type=str,
126
+ default=None,
127
+ help="読む込むDiffusersのvariantを指定する、例: fp16 / variant: Diffusers variant to load. Example: fp16",
128
+ )
129
+ parser.add_argument(
130
+ "--reference_model",
131
+ type=str,
132
+ default=None,
133
+ help="scheduler/tokenizerのコピー元Diffusersモデル、Diffusers形式で保存するときに使用される、省略時は`runwayml/stable-diffusion-v1-5` または `stabilityai/stable-diffusion-2-1` / reference Diffusers model to copy scheduler/tokenizer config from, used when saving as Diffusers format, default is `runwayml/stable-diffusion-v1-5` or `stabilityai/stable-diffusion-2-1`",
134
+ )
135
+ parser.add_argument(
136
+ "--use_safetensors",
137
+ action="store_true",
138
+ help="use safetensors format to save Diffusers model (checkpoint depends on the file extension) / Duffusersモデルをsafetensors形式で保存する(checkpointは拡張子で自動判定)",
139
+ )
140
+
141
+ parser.add_argument(
142
+ "model_to_load",
143
+ type=str,
144
+ default=None,
145
+ help="model to load: checkpoint file or Diffusers model's directory / 読み込むモデル、checkpointかDiffusers形式モデルのディレクトリ",
146
+ )
147
+ parser.add_argument(
148
+ "model_to_save",
149
+ type=str,
150
+ default=None,
151
+ help="model to save: checkpoint (with extension) or Diffusers model's directory (without extension) / 変換後のモデル、拡張子がある場合はcheckpoint、ない場合はDiffusesモデルとして保存",
152
+ )
153
+ return parser
154
+
155
+
156
+ if __name__ == "__main__":
157
+ parser = setup_parser()
158
+
159
+ args = parser.parse_args()
160
+ convert(args)
external/llite/tools/detect_face_rotate.py ADDED
@@ -0,0 +1,246 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # このスクリプトのライセンスは、train_dreambooth.pyと同じくApache License 2.0とします
2
+ # (c) 2022 Kohya S. @kohya_ss
3
+
4
+ # 横長の画像から顔検出して正立するように回転し、そこを中心に正方形に切り出す
5
+
6
+ # v2: extract max face if multiple faces are found
7
+ # v3: add crop_ratio option
8
+ # v4: add multiple faces extraction and min/max size
9
+
10
+ import argparse
11
+ import math
12
+ import cv2
13
+ import glob
14
+ import os
15
+ from anime_face_detector import create_detector
16
+ from tqdm import tqdm
17
+ import numpy as np
18
+
19
+ KP_REYE = 11
20
+ KP_LEYE = 19
21
+
22
+ SCORE_THRES = 0.90
23
+
24
+
25
+ def detect_faces(detector, image, min_size):
26
+ preds = detector(image) # bgr
27
+ # print(len(preds))
28
+
29
+ faces = []
30
+ for pred in preds:
31
+ bb = pred['bbox']
32
+ score = bb[-1]
33
+ if score < SCORE_THRES:
34
+ continue
35
+
36
+ left, top, right, bottom = bb[:4]
37
+ cx = int((left + right) / 2)
38
+ cy = int((top + bottom) / 2)
39
+ fw = int(right - left)
40
+ fh = int(bottom - top)
41
+
42
+ lex, ley = pred['keypoints'][KP_LEYE, 0:2]
43
+ rex, rey = pred['keypoints'][KP_REYE, 0:2]
44
+ angle = math.atan2(ley - rey, lex - rex)
45
+ angle = angle / math.pi * 180
46
+
47
+ faces.append((cx, cy, fw, fh, angle))
48
+
49
+ faces.sort(key=lambda x: max(x[2], x[3]), reverse=True) # 大きい順
50
+ return faces
51
+
52
+
53
+ def rotate_image(image, angle, cx, cy):
54
+ h, w = image.shape[0:2]
55
+ rot_mat = cv2.getRotationMatrix2D((cx, cy), angle, 1.0)
56
+
57
+ # # 回転する分、すこし画像サイズを大きくする→とりあえず無効化
58
+ # nh = max(h, int(w * math.sin(angle)))
59
+ # nw = max(w, int(h * math.sin(angle)))
60
+ # if nh > h or nw > w:
61
+ # pad_y = nh - h
62
+ # pad_t = pad_y // 2
63
+ # pad_x = nw - w
64
+ # pad_l = pad_x // 2
65
+ # m = np.array([[0, 0, pad_l],
66
+ # [0, 0, pad_t]])
67
+ # rot_mat = rot_mat + m
68
+ # h, w = nh, nw
69
+ # cx += pad_l
70
+ # cy += pad_t
71
+
72
+ result = cv2.warpAffine(image, rot_mat, (w, h), flags=cv2.INTER_LINEAR, borderMode=cv2.BORDER_REFLECT)
73
+ return result, cx, cy
74
+
75
+
76
+ def process(args):
77
+ assert (not args.resize_fit) or args.resize_face_size is None, f"resize_fit and resize_face_size can't be specified both / resize_fitとresize_face_sizeはどちらか片方しか指定できません"
78
+ assert args.crop_ratio is None or args.resize_face_size is None, f"crop_ratio指定時はresize_face_sizeは指定できません"
79
+
80
+ # アニメ顔検出モデルを読み込む
81
+ print("loading face detector.")
82
+ detector = create_detector('yolov3')
83
+
84
+ # cropの引数を解析する
85
+ if args.crop_size is None:
86
+ crop_width = crop_height = None
87
+ else:
88
+ tokens = args.crop_size.split(',')
89
+ assert len(tokens) == 2, f"crop_size must be 'width,height' / crop_sizeは'幅,高さ'で指定してください"
90
+ crop_width, crop_height = [int(t) for t in tokens]
91
+
92
+ if args.crop_ratio is None:
93
+ crop_h_ratio = crop_v_ratio = None
94
+ else:
95
+ tokens = args.crop_ratio.split(',')
96
+ assert len(tokens) == 2, f"crop_ratio must be 'horizontal,vertical' / crop_ratioは'幅,高さ'の倍率で指定してください"
97
+ crop_h_ratio, crop_v_ratio = [float(t) for t in tokens]
98
+
99
+ # 画像を処理する
100
+ print("processing.")
101
+ output_extension = ".png"
102
+
103
+ os.makedirs(args.dst_dir, exist_ok=True)
104
+ paths = glob.glob(os.path.join(args.src_dir, "*.png")) + glob.glob(os.path.join(args.src_dir, "*.jpg")) + \
105
+ glob.glob(os.path.join(args.src_dir, "*.webp"))
106
+ for path in tqdm(paths):
107
+ basename = os.path.splitext(os.path.basename(path))[0]
108
+
109
+ # image = cv2.imread(path) # 日本語ファイル名でエラーになる
110
+ image = cv2.imdecode(np.fromfile(path, np.uint8), cv2.IMREAD_UNCHANGED)
111
+ if len(image.shape) == 2:
112
+ image = cv2.cvtColor(image, cv2.COLOR_GRAY2BGR)
113
+ if image.shape[2] == 4:
114
+ print(f"image has alpha. ignore / 画像の透明度が設定されているため無視します: {path}")
115
+ image = image[:, :, :3].copy() # copyをしないと内部的に透明度情報が付いたままになるらしい
116
+
117
+ h, w = image.shape[:2]
118
+
119
+ faces = detect_faces(detector, image, args.multiple_faces)
120
+ for i, face in enumerate(faces):
121
+ cx, cy, fw, fh, angle = face
122
+ face_size = max(fw, fh)
123
+ if args.min_size is not None and face_size < args.min_size:
124
+ continue
125
+ if args.max_size is not None and face_size >= args.max_size:
126
+ continue
127
+ face_suffix = f"_{i+1:02d}" if args.multiple_faces else ""
128
+
129
+ # オプション指定があれば回転する
130
+ face_img = image
131
+ if args.rotate:
132
+ face_img, cx, cy = rotate_image(face_img, angle, cx, cy)
133
+
134
+ # オプション指定があれば顔を中心に切り出す
135
+ if crop_width is not None or crop_h_ratio is not None:
136
+ cur_crop_width, cur_crop_height = crop_width, crop_height
137
+ if crop_h_ratio is not None:
138
+ cur_crop_width = int(face_size * crop_h_ratio + .5)
139
+ cur_crop_height = int(face_size * crop_v_ratio + .5)
140
+
141
+ # リサイズを必要なら行う
142
+ scale = 1.0
143
+ if args.resize_face_size is not None:
144
+ # 顔サイズを基準にリサイズする
145
+ scale = args.resize_face_size / face_size
146
+ if scale < cur_crop_width / w:
147
+ print(
148
+ f"image width too small in face size based resizing / 顔を基準にリサイズすると画像の幅がcrop sizeより小さい(顔が相対的に大きすぎる)ので顔サイズが変わります: {path}")
149
+ scale = cur_crop_width / w
150
+ if scale < cur_crop_height / h:
151
+ print(
152
+ f"image height too small in face size based resizing / 顔を基準にリサイズすると画像の高さがcrop sizeより小さい(顔が相対的に大きすぎる)ので顔サイズが変わります: {path}")
153
+ scale = cur_crop_height / h
154
+ elif crop_h_ratio is not None:
155
+ # 倍率指定の時にはリサイズしない
156
+ pass
157
+ else:
158
+ # 切り出しサイズ指定あり
159
+ if w < cur_crop_width:
160
+ print(f"image width too small/ 画像の幅がcrop sizeより小さいので画質が劣化します: {path}")
161
+ scale = cur_crop_width / w
162
+ if h < cur_crop_height:
163
+ print(f"image height too small/ 画像の高さがcrop sizeより小さいので画質が劣化します: {path}")
164
+ scale = cur_crop_height / h
165
+ if args.resize_fit:
166
+ scale = max(cur_crop_width / w, cur_crop_height / h)
167
+
168
+ if scale != 1.0:
169
+ w = int(w * scale + .5)
170
+ h = int(h * scale + .5)
171
+ face_img = cv2.resize(face_img, (w, h), interpolation=cv2.INTER_AREA if scale < 1.0 else cv2.INTER_LANCZOS4)
172
+ cx = int(cx * scale + .5)
173
+ cy = int(cy * scale + .5)
174
+ fw = int(fw * scale + .5)
175
+ fh = int(fh * scale + .5)
176
+
177
+ cur_crop_width = min(cur_crop_width, face_img.shape[1])
178
+ cur_crop_height = min(cur_crop_height, face_img.shape[0])
179
+
180
+ x = cx - cur_crop_width // 2
181
+ cx = cur_crop_width // 2
182
+ if x < 0:
183
+ cx = cx + x
184
+ x = 0
185
+ elif x + cur_crop_width > w:
186
+ cx = cx + (x + cur_crop_width - w)
187
+ x = w - cur_crop_width
188
+ face_img = face_img[:, x:x+cur_crop_width]
189
+
190
+ y = cy - cur_crop_height // 2
191
+ cy = cur_crop_height // 2
192
+ if y < 0:
193
+ cy = cy + y
194
+ y = 0
195
+ elif y + cur_crop_height > h:
196
+ cy = cy + (y + cur_crop_height - h)
197
+ y = h - cur_crop_height
198
+ face_img = face_img[y:y + cur_crop_height]
199
+
200
+ # # debug
201
+ # print(path, cx, cy, angle)
202
+ # crp = cv2.resize(image, (image.shape[1]//8, image.shape[0]//8))
203
+ # cv2.imshow("image", crp)
204
+ # if cv2.waitKey() == 27:
205
+ # break
206
+ # cv2.destroyAllWindows()
207
+
208
+ # debug
209
+ if args.debug:
210
+ cv2.rectangle(face_img, (cx-fw//2, cy-fh//2), (cx+fw//2, cy+fh//2), (255, 0, 255), fw//20)
211
+
212
+ _, buf = cv2.imencode(output_extension, face_img)
213
+ with open(os.path.join(args.dst_dir, f"{basename}{face_suffix}_{cx:04d}_{cy:04d}_{fw:04d}_{fh:04d}{output_extension}"), "wb") as f:
214
+ buf.tofile(f)
215
+
216
+
217
+ def setup_parser() -> argparse.ArgumentParser:
218
+ parser = argparse.ArgumentParser()
219
+ parser.add_argument("--src_dir", type=str, help="directory to load images / 画像を読み込むディレクトリ")
220
+ parser.add_argument("--dst_dir", type=str, help="directory to save images / 画像を保存するディレクトリ")
221
+ parser.add_argument("--rotate", action="store_true", help="rotate images to align faces / 顔が正立するように画像を回転する")
222
+ parser.add_argument("--resize_fit", action="store_true",
223
+ help="resize to fit smaller side after cropping / 切り出し後の画像の短辺がcrop_sizeにあうようにリサイズする")
224
+ parser.add_argument("--resize_face_size", type=int, default=None,
225
+ help="resize image before cropping by face size / 切り出し前に顔がこのサイズになるようにリサイズする")
226
+ parser.add_argument("--crop_size", type=str, default=None,
227
+ help="crop images with 'width,height' pixels, face centered / 顔を中心として'幅,高さ'のサイズで切り出す")
228
+ parser.add_argument("--crop_ratio", type=str, default=None,
229
+ help="crop images with 'horizontal,vertical' ratio to face, face centered / 顔を中心として顔サイズの'幅倍率,高さ倍率'のサイズで切り出す")
230
+ parser.add_argument("--min_size", type=int, default=None,
231
+ help="minimum face size to output (included) / 処理対象とする顔の最小サイズ(この値以上)")
232
+ parser.add_argument("--max_size", type=int, default=None,
233
+ help="maximum face size to output (excluded) / 処理対象とする顔の最大サイズ(この値未満)")
234
+ parser.add_argument("--multiple_faces", action="store_true",
235
+ help="output each faces / 複数の顔が見つかった場合、それぞれを切り出す")
236
+ parser.add_argument("--debug", action="store_true", help="render rect for face / 処理後画像の顔位置に矩形を描画します")
237
+
238
+ return parser
239
+
240
+
241
+ if __name__ == '__main__':
242
+ parser = setup_parser()
243
+
244
+ args = parser.parse_args()
245
+
246
+ process(args)
external/llite/tools/latent_upscaler.py ADDED
@@ -0,0 +1,348 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # 外部から簡単にupscalerを呼ぶためのスクリプト
2
+ # 単体で動くようにモデル定義も含めている
3
+
4
+ import argparse
5
+ import glob
6
+ import os
7
+ import cv2
8
+ from diffusers import AutoencoderKL
9
+
10
+ from typing import Dict, List
11
+ import numpy as np
12
+
13
+ import torch
14
+ from torch import nn
15
+ from tqdm import tqdm
16
+ from PIL import Image
17
+
18
+
19
+ class ResidualBlock(nn.Module):
20
+ def __init__(self, in_channels, out_channels=None, kernel_size=3, stride=1, padding=1):
21
+ super(ResidualBlock, self).__init__()
22
+
23
+ if out_channels is None:
24
+ out_channels = in_channels
25
+
26
+ self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, bias=False)
27
+ self.bn1 = nn.BatchNorm2d(out_channels)
28
+ self.relu1 = nn.ReLU(inplace=True)
29
+
30
+ self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size, stride, padding, bias=False)
31
+ self.bn2 = nn.BatchNorm2d(out_channels)
32
+
33
+ self.relu2 = nn.ReLU(inplace=True) # このReLUはresidualに足す前にかけるほうがいいかも
34
+
35
+ # initialize weights
36
+ self._initialize_weights()
37
+
38
+ def _initialize_weights(self):
39
+ for m in self.modules():
40
+ if isinstance(m, nn.Conv2d):
41
+ nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
42
+ if m.bias is not None:
43
+ nn.init.constant_(m.bias, 0)
44
+ elif isinstance(m, nn.BatchNorm2d):
45
+ nn.init.constant_(m.weight, 1)
46
+ nn.init.constant_(m.bias, 0)
47
+ elif isinstance(m, nn.Linear):
48
+ nn.init.normal_(m.weight, 0, 0.01)
49
+ nn.init.constant_(m.bias, 0)
50
+
51
+ def forward(self, x):
52
+ residual = x
53
+
54
+ out = self.conv1(x)
55
+ out = self.bn1(out)
56
+ out = self.relu1(out)
57
+
58
+ out = self.conv2(out)
59
+ out = self.bn2(out)
60
+
61
+ out += residual
62
+
63
+ out = self.relu2(out)
64
+
65
+ return out
66
+
67
+
68
+ class Upscaler(nn.Module):
69
+ def __init__(self):
70
+ super(Upscaler, self).__init__()
71
+
72
+ # define layers
73
+ # latent has 4 channels
74
+
75
+ self.conv1 = nn.Conv2d(4, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
76
+ self.bn1 = nn.BatchNorm2d(128)
77
+ self.relu1 = nn.ReLU(inplace=True)
78
+
79
+ # resblocks
80
+ # 数の暴力で20個:次元数を増やすよりもブロックを増やしたほうがreceptive fieldが広がるはずだぞ
81
+ self.resblock1 = ResidualBlock(128)
82
+ self.resblock2 = ResidualBlock(128)
83
+ self.resblock3 = ResidualBlock(128)
84
+ self.resblock4 = ResidualBlock(128)
85
+ self.resblock5 = ResidualBlock(128)
86
+ self.resblock6 = ResidualBlock(128)
87
+ self.resblock7 = ResidualBlock(128)
88
+ self.resblock8 = ResidualBlock(128)
89
+ self.resblock9 = ResidualBlock(128)
90
+ self.resblock10 = ResidualBlock(128)
91
+ self.resblock11 = ResidualBlock(128)
92
+ self.resblock12 = ResidualBlock(128)
93
+ self.resblock13 = ResidualBlock(128)
94
+ self.resblock14 = ResidualBlock(128)
95
+ self.resblock15 = ResidualBlock(128)
96
+ self.resblock16 = ResidualBlock(128)
97
+ self.resblock17 = ResidualBlock(128)
98
+ self.resblock18 = ResidualBlock(128)
99
+ self.resblock19 = ResidualBlock(128)
100
+ self.resblock20 = ResidualBlock(128)
101
+
102
+ # last convs
103
+ self.conv2 = nn.Conv2d(128, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
104
+ self.bn2 = nn.BatchNorm2d(64)
105
+ self.relu2 = nn.ReLU(inplace=True)
106
+
107
+ self.conv3 = nn.Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
108
+ self.bn3 = nn.BatchNorm2d(64)
109
+ self.relu3 = nn.ReLU(inplace=True)
110
+
111
+ # final conv: output 4 channels
112
+ self.conv_final = nn.Conv2d(64, 4, kernel_size=(1, 1), stride=(1, 1), padding=(0, 0))
113
+
114
+ # initialize weights
115
+ self._initialize_weights()
116
+
117
+ def _initialize_weights(self):
118
+ for m in self.modules():
119
+ if isinstance(m, nn.Conv2d):
120
+ nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
121
+ if m.bias is not None:
122
+ nn.init.constant_(m.bias, 0)
123
+ elif isinstance(m, nn.BatchNorm2d):
124
+ nn.init.constant_(m.weight, 1)
125
+ nn.init.constant_(m.bias, 0)
126
+ elif isinstance(m, nn.Linear):
127
+ nn.init.normal_(m.weight, 0, 0.01)
128
+ nn.init.constant_(m.bias, 0)
129
+
130
+ # initialize final conv weights to 0: 流行りのzero conv
131
+ nn.init.constant_(self.conv_final.weight, 0)
132
+
133
+ def forward(self, x):
134
+ inp = x
135
+
136
+ x = self.conv1(x)
137
+ x = self.bn1(x)
138
+ x = self.relu1(x)
139
+
140
+ # いくつかのresblockを通した後に、residualを足すことで精度向上と学習速度向上が見込めるはず
141
+ residual = x
142
+ x = self.resblock1(x)
143
+ x = self.resblock2(x)
144
+ x = self.resblock3(x)
145
+ x = self.resblock4(x)
146
+ x = x + residual
147
+ residual = x
148
+ x = self.resblock5(x)
149
+ x = self.resblock6(x)
150
+ x = self.resblock7(x)
151
+ x = self.resblock8(x)
152
+ x = x + residual
153
+ residual = x
154
+ x = self.resblock9(x)
155
+ x = self.resblock10(x)
156
+ x = self.resblock11(x)
157
+ x = self.resblock12(x)
158
+ x = x + residual
159
+ residual = x
160
+ x = self.resblock13(x)
161
+ x = self.resblock14(x)
162
+ x = self.resblock15(x)
163
+ x = self.resblock16(x)
164
+ x = x + residual
165
+ residual = x
166
+ x = self.resblock17(x)
167
+ x = self.resblock18(x)
168
+ x = self.resblock19(x)
169
+ x = self.resblock20(x)
170
+ x = x + residual
171
+
172
+ x = self.conv2(x)
173
+ x = self.bn2(x)
174
+ x = self.relu2(x)
175
+ x = self.conv3(x)
176
+ x = self.bn3(x)
177
+
178
+ # ここにreluを入れないほうがいい気がする
179
+
180
+ x = self.conv_final(x)
181
+
182
+ # network estimates the difference between the input and the output
183
+ x = x + inp
184
+
185
+ return x
186
+
187
+ def support_latents(self) -> bool:
188
+ return False
189
+
190
+ def upscale(
191
+ self,
192
+ vae: AutoencoderKL,
193
+ lowreso_images: List[Image.Image],
194
+ lowreso_latents: torch.Tensor,
195
+ dtype: torch.dtype,
196
+ width: int,
197
+ height: int,
198
+ batch_size: int = 1,
199
+ vae_batch_size: int = 1,
200
+ ):
201
+ # assertion
202
+ assert lowreso_images is not None, "Upscaler requires lowreso image"
203
+
204
+ # make upsampled image with lanczos4
205
+ upsampled_images = []
206
+ for lowreso_image in lowreso_images:
207
+ upsampled_image = np.array(lowreso_image.resize((width, height), Image.LANCZOS))
208
+ upsampled_images.append(upsampled_image)
209
+
210
+ # convert to tensor: this tensor is too large to be converted to cuda
211
+ upsampled_images = [torch.from_numpy(upsampled_image).permute(2, 0, 1).float() for upsampled_image in upsampled_images]
212
+ upsampled_images = torch.stack(upsampled_images, dim=0)
213
+ upsampled_images = upsampled_images.to(dtype)
214
+
215
+ # normalize to [-1, 1]
216
+ upsampled_images = upsampled_images / 127.5 - 1.0
217
+
218
+ # convert upsample images to latents with batch size
219
+ # print("Encoding upsampled (LANCZOS4) images...")
220
+ upsampled_latents = []
221
+ for i in tqdm(range(0, upsampled_images.shape[0], vae_batch_size)):
222
+ batch = upsampled_images[i : i + vae_batch_size].to(vae.device)
223
+ with torch.no_grad():
224
+ batch = vae.encode(batch).latent_dist.sample()
225
+ upsampled_latents.append(batch)
226
+
227
+ upsampled_latents = torch.cat(upsampled_latents, dim=0)
228
+
229
+ # upscale (refine) latents with this model with batch size
230
+ print("Upscaling latents...")
231
+ upscaled_latents = []
232
+ for i in range(0, upsampled_latents.shape[0], batch_size):
233
+ with torch.no_grad():
234
+ upscaled_latents.append(self.forward(upsampled_latents[i : i + batch_size]))
235
+ upscaled_latents = torch.cat(upscaled_latents, dim=0)
236
+
237
+ return upscaled_latents * 0.18215
238
+
239
+
240
+ # external interface: returns a model
241
+ def create_upscaler(**kwargs):
242
+ weights = kwargs["weights"]
243
+ model = Upscaler()
244
+
245
+ print(f"Loading weights from {weights}...")
246
+ if os.path.splitext(weights)[1] == ".safetensors":
247
+ from safetensors.torch import load_file
248
+
249
+ sd = load_file(weights)
250
+ else:
251
+ sd = torch.load(weights, map_location=torch.device("cpu"))
252
+ model.load_state_dict(sd)
253
+ return model
254
+
255
+
256
+ # another interface: upscale images with a model for given images from command line
257
+ def upscale_images(args: argparse.Namespace):
258
+ DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
259
+ us_dtype = torch.float16 # TODO: support fp32/bf16
260
+ os.makedirs(args.output_dir, exist_ok=True)
261
+
262
+ # load VAE with Diffusers
263
+ assert args.vae_path is not None, "VAE path is required"
264
+ print(f"Loading VAE from {args.vae_path}...")
265
+ vae = AutoencoderKL.from_pretrained(args.vae_path, subfolder="vae")
266
+ vae.to(DEVICE, dtype=us_dtype)
267
+
268
+ # prepare model
269
+ print("Preparing model...")
270
+ upscaler: Upscaler = create_upscaler(weights=args.weights)
271
+ # print("Loading weights from", args.weights)
272
+ # upscaler.load_state_dict(torch.load(args.weights))
273
+ upscaler.eval()
274
+ upscaler.to(DEVICE, dtype=us_dtype)
275
+
276
+ # load images
277
+ image_paths = glob.glob(args.image_pattern)
278
+ images = []
279
+ for image_path in image_paths:
280
+ image = Image.open(image_path)
281
+ image = image.convert("RGB")
282
+
283
+ # make divisible by 8
284
+ width = image.width
285
+ height = image.height
286
+ if width % 8 != 0:
287
+ width = width - (width % 8)
288
+ if height % 8 != 0:
289
+ height = height - (height % 8)
290
+ if width != image.width or height != image.height:
291
+ image = image.crop((0, 0, width, height))
292
+
293
+ images.append(image)
294
+
295
+ # debug output
296
+ if args.debug:
297
+ for image, image_path in zip(images, image_paths):
298
+ image_debug = image.resize((image.width * 2, image.height * 2), Image.LANCZOS)
299
+
300
+ basename = os.path.basename(image_path)
301
+ basename_wo_ext, ext = os.path.splitext(basename)
302
+ dest_file_name = os.path.join(args.output_dir, f"{basename_wo_ext}_lanczos4{ext}")
303
+ image_debug.save(dest_file_name)
304
+
305
+ # upscale
306
+ print("Upscaling...")
307
+ upscaled_latents = upscaler.upscale(
308
+ vae, images, None, us_dtype, width * 2, height * 2, batch_size=args.batch_size, vae_batch_size=args.vae_batch_size
309
+ )
310
+ upscaled_latents /= 0.18215
311
+
312
+ # decode with batch
313
+ print("Decoding...")
314
+ upscaled_images = []
315
+ for i in tqdm(range(0, upscaled_latents.shape[0], args.vae_batch_size)):
316
+ with torch.no_grad():
317
+ batch = vae.decode(upscaled_latents[i : i + args.vae_batch_size]).sample
318
+ batch = batch.to("cpu")
319
+ upscaled_images.append(batch)
320
+ upscaled_images = torch.cat(upscaled_images, dim=0)
321
+
322
+ # tensor to numpy
323
+ upscaled_images = upscaled_images.permute(0, 2, 3, 1).numpy()
324
+ upscaled_images = (upscaled_images + 1.0) * 127.5
325
+ upscaled_images = upscaled_images.clip(0, 255).astype(np.uint8)
326
+
327
+ upscaled_images = upscaled_images[..., ::-1]
328
+
329
+ # save images
330
+ for i, image in enumerate(upscaled_images):
331
+ basename = os.path.basename(image_paths[i])
332
+ basename_wo_ext, ext = os.path.splitext(basename)
333
+ dest_file_name = os.path.join(args.output_dir, f"{basename_wo_ext}_upscaled{ext}")
334
+ cv2.imwrite(dest_file_name, image)
335
+
336
+
337
+ if __name__ == "__main__":
338
+ parser = argparse.ArgumentParser()
339
+ parser.add_argument("--vae_path", type=str, default=None, help="VAE path")
340
+ parser.add_argument("--weights", type=str, default=None, help="Weights path")
341
+ parser.add_argument("--image_pattern", type=str, default=None, help="Image pattern")
342
+ parser.add_argument("--output_dir", type=str, default=".", help="Output directory")
343
+ parser.add_argument("--batch_size", type=int, default=4, help="Batch size")
344
+ parser.add_argument("--vae_batch_size", type=int, default=1, help="VAE batch size")
345
+ parser.add_argument("--debug", action="store_true", help="Debug mode")
346
+
347
+ args = parser.parse_args()
348
+ upscale_images(args)
external/llite/tools/merge_models.py ADDED
@@ -0,0 +1,168 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import os
3
+
4
+ import torch
5
+ from safetensors import safe_open
6
+ from safetensors.torch import load_file, save_file
7
+ from tqdm import tqdm
8
+
9
+
10
+ def is_unet_key(key):
11
+ # VAE or TextEncoder, the last one is for SDXL
12
+ return not ("first_stage_model" in key or "cond_stage_model" in key or "conditioner." in key)
13
+
14
+
15
+ TEXT_ENCODER_KEY_REPLACEMENTS = [
16
+ ("cond_stage_model.transformer.embeddings.", "cond_stage_model.transformer.text_model.embeddings."),
17
+ ("cond_stage_model.transformer.encoder.", "cond_stage_model.transformer.text_model.encoder."),
18
+ ("cond_stage_model.transformer.final_layer_norm.", "cond_stage_model.transformer.text_model.final_layer_norm."),
19
+ ]
20
+
21
+
22
+ # support for models with different text encoder keys
23
+ def replace_text_encoder_key(key):
24
+ for rep_from, rep_to in TEXT_ENCODER_KEY_REPLACEMENTS:
25
+ if key.startswith(rep_from):
26
+ return True, rep_to + key[len(rep_from) :]
27
+ return False, key
28
+
29
+
30
+ def merge(args):
31
+ if args.precision == "fp16":
32
+ dtype = torch.float16
33
+ elif args.precision == "bf16":
34
+ dtype = torch.bfloat16
35
+ else:
36
+ dtype = torch.float
37
+
38
+ if args.saving_precision == "fp16":
39
+ save_dtype = torch.float16
40
+ elif args.saving_precision == "bf16":
41
+ save_dtype = torch.bfloat16
42
+ else:
43
+ save_dtype = torch.float
44
+
45
+ # check if all models are safetensors
46
+ for model in args.models:
47
+ if not model.endswith("safetensors"):
48
+ print(f"Model {model} is not a safetensors model")
49
+ exit()
50
+ if not os.path.isfile(model):
51
+ print(f"Model {model} does not exist")
52
+ exit()
53
+
54
+ assert args.ratios is None or len(args.models) == len(args.ratios), "ratios must be the same length as models"
55
+
56
+ # load and merge
57
+ ratio = 1.0 / len(args.models) # default
58
+ supplementary_key_ratios = {} # [key] = ratio, for keys not in all models, add later
59
+
60
+ merged_sd = None
61
+ first_model_keys = set() # check missing keys in other models
62
+ for i, model in enumerate(args.models):
63
+ if args.ratios is not None:
64
+ ratio = args.ratios[i]
65
+
66
+ if merged_sd is None:
67
+ # load first model
68
+ print(f"Loading model {model}, ratio = {ratio}...")
69
+ merged_sd = {}
70
+ with safe_open(model, framework="pt", device=args.device) as f:
71
+ for key in tqdm(f.keys()):
72
+ value = f.get_tensor(key)
73
+ _, key = replace_text_encoder_key(key)
74
+
75
+ first_model_keys.add(key)
76
+
77
+ if not is_unet_key(key) and args.unet_only:
78
+ supplementary_key_ratios[key] = 1.0 # use first model's value for VAE or TextEncoder
79
+ continue
80
+
81
+ value = ratio * value.to(dtype) # first model's value * ratio
82
+ merged_sd[key] = value
83
+
84
+ print(f"Model has {len(merged_sd)} keys " + ("(UNet only)" if args.unet_only else ""))
85
+ continue
86
+
87
+ # load other models
88
+ print(f"Loading model {model}, ratio = {ratio}...")
89
+
90
+ with safe_open(model, framework="pt", device=args.device) as f:
91
+ model_keys = f.keys()
92
+ for key in tqdm(model_keys):
93
+ _, new_key = replace_text_encoder_key(key)
94
+ if new_key not in merged_sd:
95
+ if args.show_skipped and new_key not in first_model_keys:
96
+ print(f"Skip: {new_key}")
97
+ continue
98
+
99
+ value = f.get_tensor(key)
100
+ merged_sd[new_key] = merged_sd[new_key] + ratio * value.to(dtype)
101
+
102
+ # enumerate keys not in this model
103
+ model_keys = set(model_keys)
104
+ for key in merged_sd.keys():
105
+ if key in model_keys:
106
+ continue
107
+ print(f"Key {key} not in model {model}, use first model's value")
108
+ if key in supplementary_key_ratios:
109
+ supplementary_key_ratios[key] += ratio
110
+ else:
111
+ supplementary_key_ratios[key] = ratio
112
+
113
+ # add supplementary keys' value (including VAE and TextEncoder)
114
+ if len(supplementary_key_ratios) > 0:
115
+ print("add first model's value")
116
+ with safe_open(args.models[0], framework="pt", device=args.device) as f:
117
+ for key in tqdm(f.keys()):
118
+ _, new_key = replace_text_encoder_key(key)
119
+ if new_key not in supplementary_key_ratios:
120
+ continue
121
+
122
+ if is_unet_key(new_key): # not VAE or TextEncoder
123
+ print(f"Key {new_key} not in all models, ratio = {supplementary_key_ratios[new_key]}")
124
+
125
+ value = f.get_tensor(key) # original key
126
+
127
+ if new_key not in merged_sd:
128
+ merged_sd[new_key] = supplementary_key_ratios[new_key] * value.to(dtype)
129
+ else:
130
+ merged_sd[new_key] = merged_sd[new_key] + supplementary_key_ratios[new_key] * value.to(dtype)
131
+
132
+ # save
133
+ output_file = args.output
134
+ if not output_file.endswith(".safetensors"):
135
+ output_file = output_file + ".safetensors"
136
+
137
+ print(f"Saving to {output_file}...")
138
+
139
+ # convert to save_dtype
140
+ for k in merged_sd.keys():
141
+ merged_sd[k] = merged_sd[k].to(save_dtype)
142
+
143
+ save_file(merged_sd, output_file)
144
+
145
+ print("Done!")
146
+
147
+
148
+ if __name__ == "__main__":
149
+ parser = argparse.ArgumentParser(description="Merge models")
150
+ parser.add_argument("--models", nargs="+", type=str, help="Models to merge")
151
+ parser.add_argument("--output", type=str, help="Output model")
152
+ parser.add_argument("--ratios", nargs="+", type=float, help="Ratios of models, default is equal, total = 1.0")
153
+ parser.add_argument("--unet_only", action="store_true", help="Only merge unet")
154
+ parser.add_argument("--device", type=str, default="cpu", help="Device to use, default is cpu")
155
+ parser.add_argument(
156
+ "--precision", type=str, default="float", choices=["float", "fp16", "bf16"], help="Calculation precision, default is float"
157
+ )
158
+ parser.add_argument(
159
+ "--saving_precision",
160
+ type=str,
161
+ default="float",
162
+ choices=["float", "fp16", "bf16"],
163
+ help="Saving precision, default is float",
164
+ )
165
+ parser.add_argument("--show_skipped", action="store_true", help="Show skipped keys (keys not in first model)")
166
+
167
+ args = parser.parse_args()
168
+ merge(args)
external/llite/tools/original_control_net.py ADDED
@@ -0,0 +1,337 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List, NamedTuple, Any
2
+ import numpy as np
3
+ import cv2
4
+ import torch
5
+ from safetensors.torch import load_file
6
+
7
+ from library.original_unet import UNet2DConditionModel, SampleOutput
8
+
9
+ import library.model_util as model_util
10
+
11
+
12
+ class ControlNetInfo(NamedTuple):
13
+ unet: Any
14
+ net: Any
15
+ prep: Any
16
+ weight: float
17
+ ratio: float
18
+
19
+
20
+ class ControlNet(torch.nn.Module):
21
+ def __init__(self) -> None:
22
+ super().__init__()
23
+
24
+ # make control model
25
+ self.control_model = torch.nn.Module()
26
+
27
+ dims = [320, 320, 320, 320, 640, 640, 640, 1280, 1280, 1280, 1280, 1280]
28
+ zero_convs = torch.nn.ModuleList()
29
+ for i, dim in enumerate(dims):
30
+ sub_list = torch.nn.ModuleList([torch.nn.Conv2d(dim, dim, 1)])
31
+ zero_convs.append(sub_list)
32
+ self.control_model.add_module("zero_convs", zero_convs)
33
+
34
+ middle_block_out = torch.nn.Conv2d(1280, 1280, 1)
35
+ self.control_model.add_module("middle_block_out", torch.nn.ModuleList([middle_block_out]))
36
+
37
+ dims = [16, 16, 32, 32, 96, 96, 256, 320]
38
+ strides = [1, 1, 2, 1, 2, 1, 2, 1]
39
+ prev_dim = 3
40
+ input_hint_block = torch.nn.Sequential()
41
+ for i, (dim, stride) in enumerate(zip(dims, strides)):
42
+ input_hint_block.append(torch.nn.Conv2d(prev_dim, dim, 3, stride, 1))
43
+ if i < len(dims) - 1:
44
+ input_hint_block.append(torch.nn.SiLU())
45
+ prev_dim = dim
46
+ self.control_model.add_module("input_hint_block", input_hint_block)
47
+
48
+
49
+ def load_control_net(v2, unet, model):
50
+ device = unet.device
51
+
52
+ # control sdからキー変換しつつU-Netに対応する部分のみ取り出し、DiffusersのU-Netに読み込む
53
+ # state dictを読み込む
54
+ print(f"ControlNet: loading control SD model : {model}")
55
+
56
+ if model_util.is_safetensors(model):
57
+ ctrl_sd_sd = load_file(model)
58
+ else:
59
+ ctrl_sd_sd = torch.load(model, map_location="cpu")
60
+ ctrl_sd_sd = ctrl_sd_sd.pop("state_dict", ctrl_sd_sd)
61
+
62
+ # 重みをU-Netに読み込めるようにする。ControlNetはSD版のstate dictなので、それを読み込む
63
+ is_difference = "difference" in ctrl_sd_sd
64
+ print("ControlNet: loading difference:", is_difference)
65
+
66
+ # ControlNetには存在しないキーがあるので、まず現在のU-NetでSD版の全keyを作っておく
67
+ # またTransfer Controlの元weightとなる
68
+ ctrl_unet_sd_sd = model_util.convert_unet_state_dict_to_sd(v2, unet.state_dict())
69
+
70
+ # 元のU-Netに影響しないようにコピーする。またprefixが付いていないので付ける
71
+ for key in list(ctrl_unet_sd_sd.keys()):
72
+ ctrl_unet_sd_sd["model.diffusion_model." + key] = ctrl_unet_sd_sd.pop(key).clone()
73
+
74
+ zero_conv_sd = {}
75
+ for key in list(ctrl_sd_sd.keys()):
76
+ if key.startswith("control_"):
77
+ unet_key = "model.diffusion_" + key[len("control_") :]
78
+ if unet_key not in ctrl_unet_sd_sd: # zero conv
79
+ zero_conv_sd[key] = ctrl_sd_sd[key]
80
+ continue
81
+ if is_difference: # Transfer Control
82
+ ctrl_unet_sd_sd[unet_key] += ctrl_sd_sd[key].to(device, dtype=unet.dtype)
83
+ else:
84
+ ctrl_unet_sd_sd[unet_key] = ctrl_sd_sd[key].to(device, dtype=unet.dtype)
85
+
86
+ unet_config = model_util.create_unet_diffusers_config(v2)
87
+ ctrl_unet_du_sd = model_util.convert_ldm_unet_checkpoint(v2, ctrl_unet_sd_sd, unet_config) # DiffUsers版ControlNetのstate dict
88
+
89
+ # ControlNetのU-Netを作成する
90
+ ctrl_unet = UNet2DConditionModel(**unet_config)
91
+ info = ctrl_unet.load_state_dict(ctrl_unet_du_sd)
92
+ print("ControlNet: loading Control U-Net:", info)
93
+
94
+ # U-Net以外のControlNetを作成する
95
+ # TODO support middle only
96
+ ctrl_net = ControlNet()
97
+ info = ctrl_net.load_state_dict(zero_conv_sd)
98
+ print("ControlNet: loading ControlNet:", info)
99
+
100
+ ctrl_unet.to(unet.device, dtype=unet.dtype)
101
+ ctrl_net.to(unet.device, dtype=unet.dtype)
102
+ return ctrl_unet, ctrl_net
103
+
104
+
105
+ def load_preprocess(prep_type: str):
106
+ if prep_type is None or prep_type.lower() == "none":
107
+ return None
108
+
109
+ if prep_type.startswith("canny"):
110
+ args = prep_type.split("_")
111
+ th1 = int(args[1]) if len(args) >= 2 else 63
112
+ th2 = int(args[2]) if len(args) >= 3 else 191
113
+
114
+ def canny(img):
115
+ img = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
116
+ return cv2.Canny(img, th1, th2)
117
+
118
+ return canny
119
+
120
+ print("Unsupported prep type:", prep_type)
121
+ return None
122
+
123
+
124
+ def preprocess_ctrl_net_hint_image(image):
125
+ image = np.array(image).astype(np.float32) / 255.0
126
+ # ControlNetのサンプルはcv2を使っているが、読み込みはGradioなので実はRGBになっている
127
+ # image = image[:, :, ::-1].copy() # rgb to bgr
128
+ image = image[None].transpose(0, 3, 1, 2) # nchw
129
+ image = torch.from_numpy(image)
130
+ return image # 0 to 1
131
+
132
+
133
+ def get_guided_hints(control_nets: List[ControlNetInfo], num_latent_input, b_size, hints):
134
+ guided_hints = []
135
+ for i, cnet_info in enumerate(control_nets):
136
+ # hintは 1枚目の画像のcnet1, 1枚目の画像のcnet2, 1枚目の画像のcnet3, 2枚目の画像のcnet1, 2枚目の画像のcnet2 ... と並んでいること
137
+ b_hints = []
138
+ if len(hints) == 1: # すべて同じ画像をhintとして使う
139
+ hint = hints[0]
140
+ if cnet_info.prep is not None:
141
+ hint = cnet_info.prep(hint)
142
+ hint = preprocess_ctrl_net_hint_image(hint)
143
+ b_hints = [hint for _ in range(b_size)]
144
+ else:
145
+ for bi in range(b_size):
146
+ hint = hints[(bi * len(control_nets) + i) % len(hints)]
147
+ if cnet_info.prep is not None:
148
+ hint = cnet_info.prep(hint)
149
+ hint = preprocess_ctrl_net_hint_image(hint)
150
+ b_hints.append(hint)
151
+ b_hints = torch.cat(b_hints, dim=0)
152
+ b_hints = b_hints.to(cnet_info.unet.device, dtype=cnet_info.unet.dtype)
153
+
154
+ guided_hint = cnet_info.net.control_model.input_hint_block(b_hints)
155
+ guided_hints.append(guided_hint)
156
+ return guided_hints
157
+
158
+
159
+ def call_unet_and_control_net(
160
+ step,
161
+ num_latent_input,
162
+ original_unet,
163
+ control_nets: List[ControlNetInfo],
164
+ guided_hints,
165
+ current_ratio,
166
+ sample,
167
+ timestep,
168
+ encoder_hidden_states,
169
+ encoder_hidden_states_for_control_net,
170
+ ):
171
+ # ControlNet
172
+ # 複数のControlNetの場合は、出力をマージするのではなく交互に適用する
173
+ cnet_cnt = len(control_nets)
174
+ cnet_idx = step % cnet_cnt
175
+ cnet_info = control_nets[cnet_idx]
176
+
177
+ # print(current_ratio, cnet_info.prep, cnet_info.weight, cnet_info.ratio)
178
+ if cnet_info.ratio < current_ratio:
179
+ return original_unet(sample, timestep, encoder_hidden_states)
180
+
181
+ guided_hint = guided_hints[cnet_idx]
182
+ guided_hint = guided_hint.repeat((num_latent_input, 1, 1, 1))
183
+ outs = unet_forward(True, cnet_info.net, cnet_info.unet, guided_hint, None, sample, timestep, encoder_hidden_states_for_control_net)
184
+ outs = [o * cnet_info.weight for o in outs]
185
+
186
+ # U-Net
187
+ return unet_forward(False, cnet_info.net, original_unet, None, outs, sample, timestep, encoder_hidden_states)
188
+
189
+
190
+ """
191
+ # これはmergeのバージョン
192
+ # ControlNet
193
+ cnet_outs_list = []
194
+ for i, cnet_info in enumerate(control_nets):
195
+ # print(current_ratio, cnet_info.prep, cnet_info.weight, cnet_info.ratio)
196
+ if cnet_info.ratio < current_ratio:
197
+ continue
198
+ guided_hint = guided_hints[i]
199
+ outs = unet_forward(True, cnet_info.net, cnet_info.unet, guided_hint, None, sample, timestep, encoder_hidden_states)
200
+ for i in range(len(outs)):
201
+ outs[i] *= cnet_info.weight
202
+
203
+ cnet_outs_list.append(outs)
204
+
205
+ count = len(cnet_outs_list)
206
+ if count == 0:
207
+ return original_unet(sample, timestep, encoder_hidden_states)
208
+
209
+ # sum of controlnets
210
+ for i in range(1, count):
211
+ cnet_outs_list[0] += cnet_outs_list[i]
212
+
213
+ # U-Net
214
+ return unet_forward(False, cnet_info.net, original_unet, None, cnet_outs_list[0], sample, timestep, encoder_hidden_states)
215
+ """
216
+
217
+
218
+ def unet_forward(
219
+ is_control_net,
220
+ control_net: ControlNet,
221
+ unet: UNet2DConditionModel,
222
+ guided_hint,
223
+ ctrl_outs,
224
+ sample,
225
+ timestep,
226
+ encoder_hidden_states,
227
+ ):
228
+ # copy from UNet2DConditionModel
229
+ default_overall_up_factor = 2**unet.num_upsamplers
230
+
231
+ forward_upsample_size = False
232
+ upsample_size = None
233
+
234
+ if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
235
+ print("Forward upsample size to force interpolation output size.")
236
+ forward_upsample_size = True
237
+
238
+ # 1. time
239
+ timesteps = timestep
240
+ if not torch.is_tensor(timesteps):
241
+ # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
242
+ # This would be a good case for the `match` statement (Python 3.10+)
243
+ is_mps = sample.device.type == "mps"
244
+ if isinstance(timestep, float):
245
+ dtype = torch.float32 if is_mps else torch.float64
246
+ else:
247
+ dtype = torch.int32 if is_mps else torch.int64
248
+ timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
249
+ elif len(timesteps.shape) == 0:
250
+ timesteps = timesteps[None].to(sample.device)
251
+
252
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
253
+ timesteps = timesteps.expand(sample.shape[0])
254
+
255
+ t_emb = unet.time_proj(timesteps)
256
+
257
+ # timesteps does not contain any weights and will always return f32 tensors
258
+ # but time_embedding might actually be running in fp16. so we need to cast here.
259
+ # there might be better ways to encapsulate this.
260
+ t_emb = t_emb.to(dtype=unet.dtype)
261
+ emb = unet.time_embedding(t_emb)
262
+
263
+ outs = [] # output of ControlNet
264
+ zc_idx = 0
265
+
266
+ # 2. pre-process
267
+ sample = unet.conv_in(sample)
268
+ if is_control_net:
269
+ sample += guided_hint
270
+ outs.append(control_net.control_model.zero_convs[zc_idx][0](sample)) # , emb, encoder_hidden_states))
271
+ zc_idx += 1
272
+
273
+ # 3. down
274
+ down_block_res_samples = (sample,)
275
+ for downsample_block in unet.down_blocks:
276
+ if downsample_block.has_cross_attention:
277
+ sample, res_samples = downsample_block(
278
+ hidden_states=sample,
279
+ temb=emb,
280
+ encoder_hidden_states=encoder_hidden_states,
281
+ )
282
+ else:
283
+ sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
284
+ if is_control_net:
285
+ for rs in res_samples:
286
+ outs.append(control_net.control_model.zero_convs[zc_idx][0](rs)) # , emb, encoder_hidden_states))
287
+ zc_idx += 1
288
+
289
+ down_block_res_samples += res_samples
290
+
291
+ # 4. mid
292
+ sample = unet.mid_block(sample, emb, encoder_hidden_states=encoder_hidden_states)
293
+ if is_control_net:
294
+ outs.append(control_net.control_model.middle_block_out[0](sample))
295
+ return outs
296
+
297
+ if not is_control_net:
298
+ sample += ctrl_outs.pop()
299
+
300
+ # 5. up
301
+ for i, upsample_block in enumerate(unet.up_blocks):
302
+ is_final_block = i == len(unet.up_blocks) - 1
303
+
304
+ res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
305
+ down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
306
+
307
+ if not is_control_net and len(ctrl_outs) > 0:
308
+ res_samples = list(res_samples)
309
+ apply_ctrl_outs = ctrl_outs[-len(res_samples) :]
310
+ ctrl_outs = ctrl_outs[: -len(res_samples)]
311
+ for j in range(len(res_samples)):
312
+ res_samples[j] = res_samples[j] + apply_ctrl_outs[j]
313
+ res_samples = tuple(res_samples)
314
+
315
+ # if we have not reached the final block and need to forward the
316
+ # upsample size, we do it here
317
+ if not is_final_block and forward_upsample_size:
318
+ upsample_size = down_block_res_samples[-1].shape[2:]
319
+
320
+ if upsample_block.has_cross_attention:
321
+ sample = upsample_block(
322
+ hidden_states=sample,
323
+ temb=emb,
324
+ res_hidden_states_tuple=res_samples,
325
+ encoder_hidden_states=encoder_hidden_states,
326
+ upsample_size=upsample_size,
327
+ )
328
+ else:
329
+ sample = upsample_block(
330
+ hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size
331
+ )
332
+ # 6. post-process
333
+ sample = unet.conv_norm_out(sample)
334
+ sample = unet.conv_act(sample)
335
+ sample = unet.conv_out(sample)
336
+
337
+ return SampleOutput(sample=sample)
external/llite/tools/resize_images_to_resolution.py ADDED
@@ -0,0 +1,128 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import glob
2
+ import os
3
+ import cv2
4
+ import argparse
5
+ import shutil
6
+ import math
7
+ from PIL import Image
8
+ import numpy as np
9
+
10
+
11
+ def resize_images(src_img_folder, dst_img_folder, max_resolution="512x512", divisible_by=2, interpolation=None, save_as_png=False, copy_associated_files=False):
12
+ # Split the max_resolution string by "," and strip any whitespaces
13
+ max_resolutions = [res.strip() for res in max_resolution.split(',')]
14
+
15
+ # # Calculate max_pixels from max_resolution string
16
+ # max_pixels = int(max_resolution.split("x")[0]) * int(max_resolution.split("x")[1])
17
+
18
+ # Create destination folder if it does not exist
19
+ if not os.path.exists(dst_img_folder):
20
+ os.makedirs(dst_img_folder)
21
+
22
+ # Select interpolation method
23
+ if interpolation == 'lanczos4':
24
+ cv2_interpolation = cv2.INTER_LANCZOS4
25
+ elif interpolation == 'cubic':
26
+ cv2_interpolation = cv2.INTER_CUBIC
27
+ else:
28
+ cv2_interpolation = cv2.INTER_AREA
29
+
30
+ # Iterate through all files in src_img_folder
31
+ img_exts = (".png", ".jpg", ".jpeg", ".webp", ".bmp") # copy from train_util.py
32
+ for filename in os.listdir(src_img_folder):
33
+ # Check if the image is png, jpg or webp etc...
34
+ if not filename.endswith(img_exts):
35
+ # Copy the file to the destination folder if not png, jpg or webp etc (.txt or .caption or etc.)
36
+ shutil.copy(os.path.join(src_img_folder, filename), os.path.join(dst_img_folder, filename))
37
+ continue
38
+
39
+ # Load image
40
+ # img = cv2.imread(os.path.join(src_img_folder, filename))
41
+ image = Image.open(os.path.join(src_img_folder, filename))
42
+ if not image.mode == "RGB":
43
+ image = image.convert("RGB")
44
+ img = np.array(image, np.uint8)
45
+
46
+ base, _ = os.path.splitext(filename)
47
+ for max_resolution in max_resolutions:
48
+ # Calculate max_pixels from max_resolution string
49
+ max_pixels = int(max_resolution.split("x")[0]) * int(max_resolution.split("x")[1])
50
+
51
+ # Calculate current number of pixels
52
+ current_pixels = img.shape[0] * img.shape[1]
53
+
54
+ # Check if the image needs resizing
55
+ if current_pixels > max_pixels:
56
+ # Calculate scaling factor
57
+ scale_factor = max_pixels / current_pixels
58
+
59
+ # Calculate new dimensions
60
+ new_height = int(img.shape[0] * math.sqrt(scale_factor))
61
+ new_width = int(img.shape[1] * math.sqrt(scale_factor))
62
+
63
+ # Resize image
64
+ img = cv2.resize(img, (new_width, new_height), interpolation=cv2_interpolation)
65
+ else:
66
+ new_height, new_width = img.shape[0:2]
67
+
68
+ # Calculate the new height and width that are divisible by divisible_by (with/without resizing)
69
+ new_height = new_height if new_height % divisible_by == 0 else new_height - new_height % divisible_by
70
+ new_width = new_width if new_width % divisible_by == 0 else new_width - new_width % divisible_by
71
+
72
+ # Center crop the image to the calculated dimensions
73
+ y = int((img.shape[0] - new_height) / 2)
74
+ x = int((img.shape[1] - new_width) / 2)
75
+ img = img[y:y + new_height, x:x + new_width]
76
+
77
+ # Split filename into base and extension
78
+ new_filename = base + '+' + max_resolution + ('.png' if save_as_png else '.jpg')
79
+
80
+ # Save resized image in dst_img_folder
81
+ # cv2.imwrite(os.path.join(dst_img_folder, new_filename), img, [cv2.IMWRITE_JPEG_QUALITY, 100])
82
+ image = Image.fromarray(img)
83
+ image.save(os.path.join(dst_img_folder, new_filename), quality=100)
84
+
85
+ proc = "Resized" if current_pixels > max_pixels else "Saved"
86
+ print(f"{proc} image: {filename} with size {img.shape[0]}x{img.shape[1]} as {new_filename}")
87
+
88
+ # If other files with same basename, copy them with resolution suffix
89
+ if copy_associated_files:
90
+ asoc_files = glob.glob(os.path.join(src_img_folder, base + ".*"))
91
+ for asoc_file in asoc_files:
92
+ ext = os.path.splitext(asoc_file)[1]
93
+ if ext in img_exts:
94
+ continue
95
+ for max_resolution in max_resolutions:
96
+ new_asoc_file = base + '+' + max_resolution + ext
97
+ print(f"Copy {asoc_file} as {new_asoc_file}")
98
+ shutil.copy(os.path.join(src_img_folder, asoc_file), os.path.join(dst_img_folder, new_asoc_file))
99
+
100
+
101
+ def setup_parser() -> argparse.ArgumentParser:
102
+ parser = argparse.ArgumentParser(
103
+ description='Resize images in a folder to a specified max resolution(s) / 指定されたフォルダ内の画像を指定した最大画像サイズ(面積)以下にアスペクト比を維持したままリサイズします')
104
+ parser.add_argument('src_img_folder', type=str, help='Source folder containing the images / 元画像のフォルダ')
105
+ parser.add_argument('dst_img_folder', type=str, help='Destination folder to save the resized images / リサイズ後の画像を保存するフォルダ')
106
+ parser.add_argument('--max_resolution', type=str,
107
+ help='Maximum resolution(s) in the format "512x512,384x384, etc, etc" / 最大画像サイズをカンマ区切りで指定 ("512x512,384x384, etc, etc" など)', default="512x512,384x384,256x256,128x128")
108
+ parser.add_argument('--divisible_by', type=int,
109
+ help='Ensure new dimensions are divisible by this value / リサイズ後の画像のサイズをこの値で割り切れるようにします', default=1)
110
+ parser.add_argument('--interpolation', type=str, choices=['area', 'cubic', 'lanczos4'],
111
+ default='area', help='Interpolation method for resizing / リサイズ時の補完方法')
112
+ parser.add_argument('--save_as_png', action='store_true', help='Save as png format / png形式で保存')
113
+ parser.add_argument('--copy_associated_files', action='store_true',
114
+ help='Copy files with same base name to images (captions etc) / 画像と同じファイル名(拡張子を除く)のファイルもコピーする')
115
+
116
+ return parser
117
+
118
+
119
+ def main():
120
+ parser = setup_parser()
121
+
122
+ args = parser.parse_args()
123
+ resize_images(args.src_img_folder, args.dst_img_folder, args.max_resolution,
124
+ args.divisible_by, args.interpolation, args.save_as_png, args.copy_associated_files)
125
+
126
+
127
+ if __name__ == '__main__':
128
+ main()
external/llite/tools/show_metadata.py ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import argparse
3
+ from safetensors import safe_open
4
+
5
+ parser = argparse.ArgumentParser()
6
+ parser.add_argument("--model", type=str, required=True)
7
+ args = parser.parse_args()
8
+
9
+ with safe_open(args.model, framework="pt") as f:
10
+ metadata = f.metadata()
11
+
12
+ if metadata is None:
13
+ print("No metadata found")
14
+ else:
15
+ # metadata is json dict, but not pretty printed
16
+ # sort by key and pretty print
17
+ print(json.dumps(metadata, indent=4, sort_keys=True))
18
+
19
+
inference.py CHANGED
@@ -468,24 +468,46 @@ def img2img(task: Task):
468
 
469
  width, height = get_intermediate_dimension(task)
470
 
471
- lora_patcher = lora_style.get_patcher(
472
- [img2img_pipe.pipe, high_res.pipe], task.get_style()
473
- )
474
- lora_patcher.patch()
475
-
476
  torch.manual_seed(task.get_seed())
477
 
478
- kwargs = {
479
- "prompt": prompt,
480
- "imageUrl": task.get_imageUrl(),
481
- "negative_prompt": [task.get_negative_prompt()] * num_return_sequences,
482
- "num_inference_steps": task.get_steps(),
483
- "width": width,
484
- "height": height,
485
- **task.i2i_kwargs(),
486
- **lora_patcher.kwargs(),
487
- }
488
- images, has_nsfw = img2img_pipe.process(**kwargs)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
489
 
490
  if task.get_high_res_fix():
491
  kwargs = {
 
468
 
469
  width, height = get_intermediate_dimension(task)
470
 
 
 
 
 
 
471
  torch.manual_seed(task.get_seed())
472
 
473
+ if get_is_sdxl():
474
+ # we run lineart for img2img
475
+ controlnet.load_model("linearart")
476
+
477
+ lora_patcher = lora_style.get_patcher(
478
+ [controlnet.pipe2, high_res.pipe], task.get_style()
479
+ )
480
+ lora_patcher.patch()
481
+
482
+ kwargs = {
483
+ "imageUrl": task.get_imageUrl(),
484
+ "seed": task.get_seed(),
485
+ "num_inference_steps": task.get_steps(),
486
+ "width": width,
487
+ "height": height,
488
+ "prompt": prompt,
489
+ "negative_prompt": [task.get_negative_prompt()] * num_return_sequences,
490
+ **task.cnl_kwargs(),
491
+ "adapter_conditioning_scale": 0.3,
492
+ }
493
+ images, has_nsfw = controlnet.process(**kwargs)
494
+ else:
495
+ lora_patcher = lora_style.get_patcher(
496
+ [img2img_pipe.pipe, high_res.pipe], task.get_style()
497
+ )
498
+ lora_patcher.patch()
499
+
500
+ kwargs = {
501
+ "prompt": prompt,
502
+ "imageUrl": task.get_imageUrl(),
503
+ "negative_prompt": [task.get_negative_prompt()] * num_return_sequences,
504
+ "num_inference_steps": task.get_steps(),
505
+ "width": width,
506
+ "height": height,
507
+ **task.i2i_kwargs(),
508
+ **lora_patcher.kwargs(),
509
+ }
510
+ images, has_nsfw = img2img_pipe.process(**kwargs)
511
 
512
  if task.get_high_res_fix():
513
  kwargs = {