Upload tilevae.py

tilevae.py (ADDED, +753 -0)
'''
# ------------------------------------------------------------------------
#
#   Tiled VAE
#
#   Introducing a revolutionary new optimization designed to make
#   the VAE work with giant images on limited VRAM!
#   Say goodbye to the frustration of OOM and hello to seamless output!
#
# ------------------------------------------------------------------------
#
#   This script is a wild hack that splits the image into tiles,
#   encodes each tile separately, and merges the result back together.
#
#   Advantages:
#   - The VAE can now work with giant images on limited VRAM
#     (~10 GB for 8K images!)
#   - The merged output is completely seamless without any post-processing.
#
#   Drawbacks:
#   - NaNs always appear for 8K images when you use an fp16 (half) VAE.
#     You must use --no-half-vae to disable the half VAE for such giant images.
#   - Gradient calculation is not compatible with this hack. It will break
#     any backward() or torch.autograd.grad() that passes through the VAE.
#     (But you can still use the VAE to generate training data.)
#
#   How it works:
#   1. The image is split into tiles, which are then padded with 11/32 pixels
#      in the decoder/encoder, respectively.
#   2. When Fast Mode is disabled:
#      1. The original VAE forward is decomposed into a task queue and a task
#         worker, which starts to process each tile.
#      2. When a GroupNorm is needed, the worker suspends, stores the current
#         GroupNorm mean and var, sends everything to RAM, and turns to the
#         next tile.
#      3. After the GroupNorm means and vars of all tiles are summarized, it
#         applies group norm to the tiles and continues.
#      4. A zigzag execution order is used to reduce unnecessary data transfer.
#   3. When Fast Mode is enabled:
#      1. The original input is downsampled and passed to a separate task queue.
#      2. Its group norm parameters are recorded and used by all tiles' task queues.
#      3. Each tile is processed separately without any RAM-VRAM data transfer.
#   4. After all tiles are processed, they are written to a result buffer and returned.
#   Encoder color fix = only estimate GroupNorm before downsampling, i.e., run
#   in a semi-fast mode.
#
#   Enjoy!
#
#   @Author: LI YI @ Nanyang Technological University - Singapore
#   @Date: 2023-03-02
#   @License: CC BY-NC-SA 4.0
#
#   Please give me a star if you like this project!
#
# -------------------------------------------------------------------------
'''
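
# Padding arithmetic, for reference: the decoder runs in latent space, where
# each latent pixel maps to 8 image pixels (see crop_valid_region below), so
# its 11-pixel pad amounts to roughly 88 image pixels of overlap between
# neighboring tiles; the encoder runs in image space and pads 32 pixels
# directly. Each tile computes this overlapping margin and later discards it,
# which is why the merged result needs no seam post-processing.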

import gc
import math
from time import time
from tqdm import tqdm

import torch
import torch.version
import torch.nn.functional as F
import gradio as gr

import modules.scripts as scripts
import modules.devices as devices
from modules.shared import state
from modules.ui import gr_show
from modules.processing import opt_f
from modules.sd_vae_approx import cheap_approximation
from ldm.modules.diffusionmodules.model import AttnBlock, MemoryEfficientAttnBlock

from tile_utils.attn import get_attn_func
from tile_utils.typing import Processing
+
def get_rcmd_enc_tsize():
|
75 |
+
if torch.cuda.is_available() and devices.device not in ['cpu', devices.cpu]:
|
76 |
+
total_memory = torch.cuda.get_device_properties(devices.device).total_memory // 2**20
|
77 |
+
if total_memory > 16*1000: ENCODER_TILE_SIZE = 3072
|
78 |
+
elif total_memory > 12*1000: ENCODER_TILE_SIZE = 2048
|
79 |
+
elif total_memory > 8*1000: ENCODER_TILE_SIZE = 1536
|
80 |
+
else: ENCODER_TILE_SIZE = 960
|
81 |
+
else: ENCODER_TILE_SIZE = 512
|
82 |
+
return ENCODER_TILE_SIZE
|
83 |
+
|
84 |
+
|
85 |
+
def get_rcmd_dec_tsize():
|
86 |
+
if torch.cuda.is_available() and devices.device not in ['cpu', devices.cpu]:
|
87 |
+
total_memory = torch.cuda.get_device_properties(devices.device).total_memory // 2**20
|
88 |
+
if total_memory > 30*1000: DECODER_TILE_SIZE = 256
|
89 |
+
elif total_memory > 16*1000: DECODER_TILE_SIZE = 192
|
90 |
+
elif total_memory > 12*1000: DECODER_TILE_SIZE = 128
|
91 |
+
elif total_memory > 8*1000: DECODER_TILE_SIZE = 96
|
92 |
+
else: DECODER_TILE_SIZE = 64
|
93 |
+
else: DECODER_TILE_SIZE = 64
|
94 |
+
return DECODER_TILE_SIZE
|
95 |
+
|
96 |
+
|
97 |
+
def inplace_nonlinearity(x):
|
98 |
+
# Test: fix for Nans
|
99 |
+
return F.silu(x, inplace=True)
|
100 |
+
|
101 |
+
|
102 |
+
def attn2task(task_queue, net):
|
103 |
+
attn_forward = get_attn_func()
|
104 |
+
task_queue.append(('store_res', lambda x: x))
|
105 |
+
task_queue.append(('pre_norm', net.norm))
|
106 |
+
task_queue.append(('attn', lambda x, net=net: attn_forward(net, x)))
|
107 |
+
task_queue.append(['add_res', None])
|
108 |
+
|
109 |
+
|
110 |
+
def resblock2task(queue, block):
|
111 |
+
"""
|
112 |
+
Turn a ResNetBlock into a sequence of tasks and append to the task queue
|
113 |
+
|
114 |
+
@param queue: the target task queue
|
115 |
+
@param block: ResNetBlock
|
116 |
+
|
117 |
+
"""
|
118 |
+
if block.in_channels != block.out_channels:
|
119 |
+
if block.use_conv_shortcut:
|
120 |
+
queue.append(('store_res', block.conv_shortcut))
|
121 |
+
else:
|
122 |
+
queue.append(('store_res', block.nin_shortcut))
|
123 |
+
else:
|
124 |
+
queue.append(('store_res', lambda x: x))
|
125 |
+
queue.append(('pre_norm', block.norm1))
|
126 |
+
queue.append(('silu', inplace_nonlinearity))
|
127 |
+
queue.append(('conv1', block.conv1))
|
128 |
+
queue.append(('pre_norm', block.norm2))
|
129 |
+
queue.append(('silu', inplace_nonlinearity))
|
130 |
+
queue.append(('conv2', block.conv2))
|
131 |
+
queue.append(['add_res', None])
|
132 |
+
|
133 |
+
|
134 |
+
def build_sampling(task_queue, net, is_decoder):
|
135 |
+
"""
|
136 |
+
Build the sampling part of a task queue
|
137 |
+
@param task_queue: the target task queue
|
138 |
+
@param net: the network
|
139 |
+
@param is_decoder: currently building decoder or encoder
|
140 |
+
"""
|
141 |
+
if is_decoder:
|
142 |
+
resblock2task(task_queue, net.mid.block_1)
|
143 |
+
attn2task(task_queue, net.mid.attn_1)
|
144 |
+
resblock2task(task_queue, net.mid.block_2)
|
145 |
+
resolution_iter = reversed(range(net.num_resolutions))
|
146 |
+
block_ids = net.num_res_blocks + 1
|
147 |
+
condition = 0
|
148 |
+
module = net.up
|
149 |
+
func_name = 'upsample'
|
150 |
+
else:
|
151 |
+
resolution_iter = range(net.num_resolutions)
|
152 |
+
block_ids = net.num_res_blocks
|
153 |
+
condition = net.num_resolutions - 1
|
154 |
+
module = net.down
|
155 |
+
func_name = 'downsample'
|
156 |
+
|
157 |
+
for i_level in resolution_iter:
|
158 |
+
for i_block in range(block_ids):
|
159 |
+
resblock2task(task_queue, module[i_level].block[i_block])
|
160 |
+
if i_level != condition:
|
161 |
+
task_queue.append((func_name, getattr(module[i_level], func_name)))
|
162 |
+
|
163 |
+
if not is_decoder:
|
164 |
+
resblock2task(task_queue, net.mid.block_1)
|
165 |
+
attn2task(task_queue, net.mid.attn_1)
|
166 |
+
resblock2task(task_queue, net.mid.block_2)
|
167 |
+
|
168 |
+
|
169 |
+
def build_task_queue(net, is_decoder):
|
170 |
+
"""
|
171 |
+
Build a single task queue for the encoder or decoder
|
172 |
+
@param net: the VAE decoder or encoder network
|
173 |
+
@param is_decoder: currently building decoder or encoder
|
174 |
+
@return: the task queue
|
175 |
+
"""
|
176 |
+
task_queue = []
|
177 |
+
task_queue.append(('conv_in', net.conv_in))
|
178 |
+
|
179 |
+
# construct the sampling part of the task queue
|
180 |
+
# because encoder and decoder share the same architecture, we extract the sampling part
|
181 |
+
build_sampling(task_queue, net, is_decoder)
|
182 |
+
|
183 |
+
if not is_decoder or not net.give_pre_end:
|
184 |
+
task_queue.append(('pre_norm', net.norm_out))
|
185 |
+
task_queue.append(('silu', inplace_nonlinearity))
|
186 |
+
task_queue.append(('conv_out', net.conv_out))
|
187 |
+
if is_decoder and net.tanh_out:
|
188 |
+
task_queue.append(('tanh', torch.tanh))
|
189 |
+
|
190 |
+
return task_queue
|
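
# For reference, a freshly built task queue is a list of (name, fn) pairs,
# roughly of this shape (exact contents depend on the network config):
#
#   [('conv_in', <Conv2d>),
#    ('store_res', <shortcut fn>), ('pre_norm', <GroupNorm>), ('silu', ...),
#    ('conv1', <Conv2d>), ..., ['add_res', None],
#    ...,
#    ('conv_out', <Conv2d>)]
#
# 'store_res'/'add_res' pairs implement residual connections (the mutable
# ['add_res', None] slot is filled with the stored tensor at run time), and
# 'pre_norm' entries are placeholders that are later replaced by 'apply_norm'
# closures carrying the aggregated statistics.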


def clone_task_queue(task_queue):
    """
    Clone a task queue

    @param task_queue: the task queue to be cloned
    @return: the cloned task queue
    """
    return [[item for item in task] for task in task_queue]


def get_var_mean(input, num_groups, eps=1e-6):
    """
    Get mean and var for group norm
    """
    b, c = input.size(0), input.size(1)
    channel_in_group = int(c/num_groups)
    input_reshaped = input.contiguous().view(1, int(b * num_groups), channel_in_group, *input.size()[2:])
    var, mean = torch.var_mean(input_reshaped, dim=[0, 2, 3, 4], unbiased=False)
    return var, mean


def custom_group_norm(input, num_groups, mean, var, weight=None, bias=None, eps=1e-6):
    """
    Custom group norm with fixed mean and var

    @param input: input tensor
    @param num_groups: number of groups. by default, num_groups = 32
    @param mean: mean, must be pre-calculated by get_var_mean
    @param var: var, must be pre-calculated by get_var_mean
    @param weight: weight, should be fetched from the original group norm
    @param bias: bias, should be fetched from the original group norm
    @param eps: epsilon, by default, eps = 1e-6 to match the original group norm

    @return: normalized tensor
    """
    b, c = input.size(0), input.size(1)
    channel_in_group = int(c/num_groups)
    input_reshaped = input.contiguous().view(
        1, int(b * num_groups), channel_in_group, *input.size()[2:])

    out = F.batch_norm(input_reshaped, mean, var, weight=None, bias=None, training=False, momentum=0, eps=eps)
    out = out.view(b, c, *input.size()[2:])

    # post affine transform
    if weight is not None:
        out *= weight.view(1, -1, 1, 1)
    if bias is not None:
        out += bias.view(1, -1, 1, 1)
    return out
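
# Sanity-check sketch (an illustrative assumption, not called anywhere: `x` is
# any float NCHW tensor with C divisible by 32, and `gn` is a
# torch.nn.GroupNorm(32, C, eps=1e-6)). When the statistics come from the very
# same tensor, this custom path reproduces PyTorch's fused group norm:
#
#   var, mean = get_var_mean(x, 32)
#   y_custom = custom_group_norm(x, 32, mean, var, gn.weight, gn.bias)
#   y_torch  = F.group_norm(x, 32, gn.weight, gn.bias, eps=1e-6)
#   assert torch.allclose(y_custom, y_torch, atol=1e-5)
#
# The value of the indirection is that mean/var can instead be aggregated
# across tiles (see GroupNormParam below) or taken from a downsampled pass.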


def crop_valid_region(x, input_bbox, target_bbox, is_decoder):
    """
    Crop the valid region from the tile

    @param x: input tile
    @param input_bbox: original input bounding box
    @param target_bbox: output bounding box
    @param is_decoder: whether the tile is decoder output (x8 scale) or encoder output (/8 scale)
    @return: cropped tile
    """
    padded_bbox = [i * 8 if is_decoder else i // 8 for i in input_bbox]
    margin = [target_bbox[i] - padded_bbox[i] for i in range(4)]
    return x[:, :, margin[2]:x.size(2)+margin[3], margin[0]:x.size(3)+margin[1]]


# ↓↓↓ https://github.com/Kahsolt/stable-diffusion-webui-vae-tile-infer ↓↓↓

def perfcount(fn):
    def wrapper(*args, **kwargs):
        ts = time()

        if torch.cuda.is_available():
            torch.cuda.reset_peak_memory_stats(devices.device)
        devices.torch_gc()
        gc.collect()

        ret = fn(*args, **kwargs)

        devices.torch_gc()
        gc.collect()
        if torch.cuda.is_available():
            vram = torch.cuda.max_memory_allocated(devices.device) / 2**20
            print(f'[Tiled VAE]: Done in {time() - ts:.3f}s, max VRAM alloc {vram:.3f} MB')
        else:
            print(f'[Tiled VAE]: Done in {time() - ts:.3f}s')

        return ret
    return wrapper

# ↑↑↑ https://github.com/Kahsolt/stable-diffusion-webui-vae-tile-infer ↑↑↑


class GroupNormParam:

    def __init__(self):
        self.var_list = []
        self.mean_list = []
        self.pixel_list = []
        self.weight = None
        self.bias = None

    def add_tile(self, tile, layer):
        var, mean = get_var_mean(tile, 32)
        # For giant images, the variance can be larger than max float16.
        # In that case we compute the stats on a float32 copy.
        if var.dtype == torch.float16 and var.isinf().any():
            fp32_tile = tile.float()
            var, mean = get_var_mean(fp32_tile, 32)
        # ============= DEBUG: test for infinite =============
        # if torch.isinf(var).any():
        #     print('var: ', var)
        # ====================================================
        self.var_list.append(var)
        self.mean_list.append(mean)
        self.pixel_list.append(tile.shape[2] * tile.shape[3])
        if hasattr(layer, 'weight'):
            self.weight = layer.weight
            self.bias = layer.bias
        else:
            self.weight = None
            self.bias = None

    def summary(self):
        """
        Summarize the mean and var and return a function
        that applies group norm on each tile
        """
        if len(self.var_list) == 0: return None

        var = torch.vstack(self.var_list)
        mean = torch.vstack(self.mean_list)
        max_value = max(self.pixel_list)
        pixels = torch.tensor(self.pixel_list, dtype=torch.float32, device=devices.device) / max_value
        sum_pixels = torch.sum(pixels)
        pixels = pixels.unsqueeze(1) / sum_pixels
        var = torch.sum(var * pixels, dim=0)
        mean = torch.sum(mean * pixels, dim=0)
        return lambda x: custom_group_norm(x, 32, mean, var, self.weight, self.bias)
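
    # summary() computes a pixel-count-weighted average of the per-tile
    # statistics: with p_i the pixel count of tile i,
    #     mean = sum_i(p_i * mean_i) / sum_i(p_i)
    #     var  ≈ sum_i(p_i * var_i)  / sum_i(p_i)
    # The weighted mean is exact; the variance is an approximation because
    # each var_i is measured around its own tile mean rather than the global
    # mean. Tiles of one image share similar statistics, so the error is small.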

    @staticmethod
    def from_tile(tile, norm):
        """
        Create a group norm function from a single tile without a summary
        """
        var, mean = get_var_mean(tile, 32)
        if var.dtype == torch.float16 and var.isinf().any():
            fp32_tile = tile.float()
            var, mean = get_var_mean(fp32_tile, 32)
            # on MPS (e.g. a MacBook), we need to convert back to float16
            if var.device.type == 'mps':
                # clamp to avoid overflow
                var = torch.clamp(var, 0, 60000)
                var = var.half()
                mean = mean.half()
        if hasattr(norm, 'weight'):
            weight = norm.weight
            bias = norm.bias
        else:
            weight = None
            bias = None

        def group_norm_func(x, mean=mean, var=var, weight=weight, bias=bias):
            return custom_group_norm(x, 32, mean, var, weight, bias, 1e-6)
        return group_norm_func


class VAEHook:

    def __init__(self, net, tile_size, is_decoder:bool, fast_decoder:bool, fast_encoder:bool, color_fix:bool, to_gpu:bool=False):
        self.net = net                  # encoder | decoder
        self.tile_size = tile_size
        self.is_decoder = is_decoder
        self.fast_mode = (fast_encoder and not is_decoder) or (fast_decoder and is_decoder)
        self.color_fix = color_fix and not is_decoder
        self.to_gpu = to_gpu
        self.pad = 11 if is_decoder else 32     # FIXME: magic number

    def __call__(self, x):
        original_device = next(self.net.parameters()).device
        try:
            if self.to_gpu:
                self.net = self.net.to(devices.get_optimal_device())

            B, C, H, W = x.shape
            if max(H, W) <= self.pad * 2 + self.tile_size:
                print("[Tiled VAE]: the input size is tiny and unnecessary to tile.")
                return self.net.original_forward(x)
            else:
                return self.vae_tile_forward(x)
        finally:
            self.net = self.net.to(original_device)

    def get_best_tile_size(self, lowerbound, upperbound):
        """
        Get the best tile size for GPU memory
        """
        divider = 32
        while divider >= 2:
            remainder = lowerbound % divider
            if remainder == 0:
                return lowerbound
            candidate = lowerbound - remainder + divider
            if candidate <= upperbound:
                return candidate
            divider //= 2
        return lowerbound
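
    # Worked example: get_best_tile_size(53, 96) tries divider 32 first;
    # 53 % 32 == 21, and rounding up gives 53 - 21 + 32 = 64 <= 96, so 64 is
    # returned. The goal is a size near the lower bound that is divisible by
    # the largest possible power of two, keeping the up/down-sampling stages
    # of the VAE aligned with the tile grid.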

    def split_tiles(self, h, w):
        """
        Tool function to split the image into tiles

        @param h: height of the image
        @param w: width of the image
        @return: tile_input_bboxes, tile_output_bboxes
        """
        tile_input_bboxes, tile_output_bboxes = [], []
        tile_size = self.tile_size
        pad = self.pad
        num_height_tiles = math.ceil((h - 2 * pad) / tile_size)
        num_width_tiles = math.ceil((w - 2 * pad) / tile_size)
        # If either number is 0, we let it be 1.
        # This is to deal with long and thin images.
        num_height_tiles = max(num_height_tiles, 1)
        num_width_tiles = max(num_width_tiles, 1)

        # Suggestion from https://github.com/Kahsolt: auto-shrink the tile size
        real_tile_height = math.ceil((h - 2 * pad) / num_height_tiles)
        real_tile_width = math.ceil((w - 2 * pad) / num_width_tiles)
        real_tile_height = self.get_best_tile_size(real_tile_height, tile_size)
        real_tile_width = self.get_best_tile_size(real_tile_width, tile_size)

        print(f'[Tiled VAE]: split to {num_height_tiles}x{num_width_tiles} = {num_height_tiles*num_width_tiles} tiles. ' +
              f'Optimal tile size {real_tile_width}x{real_tile_height}, original tile size {tile_size}x{tile_size}')

        for i in range(num_height_tiles):
            for j in range(num_width_tiles):
                # bbox: [x1, x2, y1, y2]
                # The padding is unnecessary at the image borders, so we start directly from (pad, pad).
                input_bbox = [
                    pad + j * real_tile_width,
                    min(pad + (j + 1) * real_tile_width, w),
                    pad + i * real_tile_height,
                    min(pad + (i + 1) * real_tile_height, h),
                ]

                # if the output bbox is close to the image boundary, we extend it to the image boundary
                output_bbox = [
                    input_bbox[0] if input_bbox[0] > pad else 0,
                    input_bbox[1] if input_bbox[1] < w - pad else w,
                    input_bbox[2] if input_bbox[2] > pad else 0,
                    input_bbox[3] if input_bbox[3] < h - pad else h,
                ]

                # scale to get the final output bbox
                output_bbox = [x * 8 if self.is_decoder else x // 8 for x in output_bbox]
                tile_output_bboxes.append(output_bbox)

                # indistinguishably expand the input bbox by pad pixels
                tile_input_bboxes.append([
                    max(0, input_bbox[0] - pad),
                    min(w, input_bbox[1] + pad),
                    max(0, input_bbox[2] - pad),
                    min(h, input_bbox[3] + pad),
                ])

        return tile_input_bboxes, tile_output_bboxes
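
    # Worked example (decoder: tile_size=96, pad=11) for a 128x128 latent:
    # ceil((128 - 22) / 96) = 2 tiles per side, and the per-tile extent is
    # get_best_tile_size(ceil(106 / 2), 96) = 64. The first tile's input bbox
    # [11, 75, 11, 75] is expanded by the pad to [0, 86, 0, 86], while its
    # output bbox snaps to the image border and scales by 8 to [0, 600, 0, 600].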

    @torch.no_grad()
    def estimate_group_norm(self, z, task_queue, color_fix):
        device = z.device
        tile = z
        last_id = len(task_queue) - 1
        while last_id >= 0 and task_queue[last_id][0] != 'pre_norm':
            last_id -= 1
        if last_id <= 0 or task_queue[last_id][0] != 'pre_norm':
            raise ValueError('No group norm found in the task queue')
        # estimate until the last group norm
        for i in range(last_id + 1):
            task = task_queue[i]
            if task[0] == 'pre_norm':
                group_norm_func = GroupNormParam.from_tile(tile, task[1])
                task_queue[i] = ('apply_norm', group_norm_func)
                if i == last_id:
                    return True
                tile = group_norm_func(tile)
            elif task[0] == 'store_res':
                task_id = i + 1
                while task_id < last_id and task_queue[task_id][0] != 'add_res':
                    task_id += 1
                if task_id >= last_id:
                    continue
                task_queue[task_id][1] = task[1](tile)
            elif task[0] == 'add_res':
                tile += task[1].to(device)
                task[1] = None
            elif color_fix and task[0] == 'downsample':
                for j in range(i, last_id + 1):
                    if task_queue[j][0] == 'store_res':
                        task_queue[j] = ('store_res_cpu', task_queue[j][1])
                return True
            else:
                tile = task[1](tile)
                try:
                    devices.test_for_nans(tile, "vae")
                except Exception:
                    print('NaN detected in fast mode estimation. Fast mode disabled.')
                    return False

        raise IndexError('Should not reach here')
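
    # Note on the color fix: stopping the estimation at the first 'downsample'
    # (and moving later residuals to CPU via 'store_res_cpu') means only the
    # group norms before downsampling are pre-baked from the low-res pass; the
    # remaining 'pre_norm' tasks fall back to the slower exact per-tile
    # aggregation. This is the "semi-fast mode" mentioned in the header.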

    @perfcount
    @torch.no_grad()
    def vae_tile_forward(self, z):
        """
        Decode a latent vector z into an image (or encode an image into a latent) in a tiled manner.

        @param z: input tensor
        @return: output tensor
        """
        device = next(self.net.parameters()).device
        net = self.net
        tile_size = self.tile_size
        is_decoder = self.is_decoder

        z = z.detach() # detach the input to avoid backprop

        N, height, width = z.shape[0], z.shape[2], z.shape[3]
        net.last_z_shape = z.shape

        # Split the input into tiles and build a task queue for each tile
        print(f'[Tiled VAE]: input_size: {z.shape}, tile_size: {tile_size}, padding: {self.pad}')

        in_bboxes, out_bboxes = self.split_tiles(height, width)

        # Prepare tiles by splitting the input latents
        tiles = []
        for input_bbox in in_bboxes:
            tile = z[:, :, input_bbox[2]:input_bbox[3], input_bbox[0]:input_bbox[1]].cpu()
            tiles.append(tile)

        num_tiles = len(tiles)
        num_completed = 0

        # Build task queues
        single_task_queue = build_task_queue(net, is_decoder)
        if self.fast_mode:
            # Fast mode: downsample the input image to the tile size,
            # then estimate the group norm parameters on the downsampled image
            scale_factor = tile_size / max(height, width)
            z = z.to(device)
            downsampled_z = F.interpolate(z, scale_factor=scale_factor, mode='nearest-exact')
            # use nearest-exact to keep the statistics as close as possible
            print(f'[Tiled VAE]: Fast mode enabled, estimating group norm parameters on {downsampled_z.shape[3]} x {downsampled_z.shape[2]} image')

            # ======= Special thanks to @Kahsolt for the distribution shift fix ======= #
            # The downsampling heavily distorts the mean and std, so we need to recover them.
            std_old, mean_old = torch.std_mean(z, dim=[0, 2, 3], keepdim=True)
            std_new, mean_new = torch.std_mean(downsampled_z, dim=[0, 2, 3], keepdim=True)
            downsampled_z = (downsampled_z - mean_new) / std_new * std_old + mean_old
            del std_old, mean_old, std_new, mean_new
            # occasionally std_new is too small or too large, which exceeds the range of float16,
            # so we need to clamp the result to z's range.
            downsampled_z = torch.clamp_(downsampled_z, min=z.min(), max=z.max())
            estimate_task_queue = clone_task_queue(single_task_queue)
            if self.estimate_group_norm(downsampled_z, estimate_task_queue, color_fix=self.color_fix):
                single_task_queue = estimate_task_queue
            del downsampled_z

        task_queues = [clone_task_queue(single_task_queue) for _ in range(num_tiles)]

        # Dummy result
        result = None
        result_approx = None
        try:
            with devices.autocast():
                result_approx = torch.cat([F.interpolate(cheap_approximation(x).unsqueeze(0), scale_factor=opt_f, mode='nearest-exact') for x in z], dim=0).cpu()
        except Exception: pass
        # Free memory of the input latent tensor
        del z

        # Task queue execution
        pbar = tqdm(total=num_tiles * len(task_queues[0]), desc=f"[Tiled VAE]: Executing {'Decoder' if is_decoder else 'Encoder'} Task Queue: ")

        # Execute the tasks back and forth when switching tiles, so that we always
        # keep one tile on the GPU to reduce unnecessary data transfer.
        forward = True
        interrupted = False
        #state.interrupted = interrupted
        while True:
            if state.interrupted: interrupted = True ; break

            group_norm_param = GroupNormParam()
            for i in range(num_tiles) if forward else reversed(range(num_tiles)):
                if state.interrupted: interrupted = True ; break

                tile = tiles[i].to(device)
                input_bbox = in_bboxes[i]
                task_queue = task_queues[i]

                interrupted = False
                while len(task_queue) > 0:
                    if state.interrupted: interrupted = True ; break

                    # DEBUG: current task
                    # print('Running task: ', task_queue[0][0], ' on tile ', i, '/', num_tiles, ' with shape ', tile.shape)
                    task = task_queue.pop(0)
                    if task[0] == 'pre_norm':
                        group_norm_param.add_tile(tile, task[1])
                        break
                    elif task[0] == 'store_res' or task[0] == 'store_res_cpu':
                        task_id = 0
                        res = task[1](tile)
                        if not self.fast_mode or task[0] == 'store_res_cpu':
                            res = res.cpu()
                        while task_queue[task_id][0] != 'add_res':
                            task_id += 1
                        task_queue[task_id][1] = res
                    elif task[0] == 'add_res':
                        tile += task[1].to(device)
                        task[1] = None
                    else:
                        tile = task[1](tile)
                    pbar.update(1)

                if interrupted: break

                # Check for NaNs in the tile.
                # If there are NaNs, we abort the process to save the user's time.
                devices.test_for_nans(tile, "vae")

                if len(task_queue) == 0:
                    tiles[i] = None
                    num_completed += 1
                    if result is None: # NOTE: dim C varies between cases, so the result can only be initialized dynamically
                        result = torch.zeros((N, tile.shape[1], height * 8 if is_decoder else height // 8, width * 8 if is_decoder else width // 8), device=device, requires_grad=False)
                    result[:, :, out_bboxes[i][2]:out_bboxes[i][3], out_bboxes[i][0]:out_bboxes[i][1]] = crop_valid_region(tile, in_bboxes[i], out_bboxes[i], is_decoder)
                    del tile
                elif i == num_tiles - 1 and forward:
                    forward = False
                    tiles[i] = tile
                elif i == 0 and not forward:
                    forward = True
                    tiles[i] = tile
                else:
                    tiles[i] = tile.cpu()
                    del tile

            if interrupted: break
            if num_completed == num_tiles: break

            # insert the group norm task at the head of each task queue
            group_norm_func = group_norm_param.summary()
            if group_norm_func is not None:
                for i in range(num_tiles):
                    task_queue = task_queues[i]
                    task_queue.insert(0, ('apply_norm', group_norm_func))

        # Done!
        pbar.close()
        return result if result is not None else result_approx.to(device)
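
# Minimal usage sketch (illustrative assumption: `vae` is an ldm
# AutoencoderKL-style model with .encoder/.decoder, hooked the same way the
# Script class below does it inside the webui):
#
#   vae.decoder.original_forward = vae.decoder.forward
#   vae.decoder.forward = VAEHook(vae.decoder, tile_size=96, is_decoder=True,
#                                 fast_decoder=True, fast_encoder=False,
#                                 color_fix=False, to_gpu=True)
#   image = vae.decode(latent)   # now decodes tile by tile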


class Script(scripts.Script):

    def __init__(self):
        self.hooked = False

    def title(self):
        return "Tiled VAE"

    def show(self, is_img2img):
        return scripts.AlwaysVisible

    def ui(self, is_img2img):
        tab = 't2i' if not is_img2img else 'i2i'
        uid = lambda name: f'MD-{tab}-{name}'

        with gr.Accordion('Tiled VAE', open=False, elem_id=f'MDV-{tab}'):
            with gr.Row() as tab_enable:
                enabled = gr.Checkbox(label='Enable Tiled VAE', value=False, elem_id=uid('enable'))
                vae_to_gpu = gr.Checkbox(label='Move VAE to GPU (if possible)', value=True, elem_id=uid('vae2gpu'))

            gr.HTML('<p style="margin-bottom:0.8em"> It is recommended to use the largest tile size that does not trigger a CUDA out-of-memory error. </p>')
            with gr.Row() as tab_size:
                encoder_tile_size = gr.Slider(label='Encoder Tile Size', minimum=256, maximum=4096, step=16, value=get_rcmd_enc_tsize(), elem_id=uid('enc-size'))
                decoder_tile_size = gr.Slider(label='Decoder Tile Size', minimum=48, maximum=512, step=16, value=get_rcmd_dec_tsize(), elem_id=uid('dec-size'))
                reset = gr.Button(value='↻ Reset', variant='tool')
                reset.click(fn=lambda: [get_rcmd_enc_tsize(), get_rcmd_dec_tsize()], outputs=[encoder_tile_size, decoder_tile_size], show_progress=False)

            with gr.Row() as tab_param:
                fast_encoder = gr.Checkbox(label='Fast Encoder', value=True, elem_id=uid('fastenc'))
                color_fix = gr.Checkbox(label='Fast Encoder Color Fix', value=False, visible=True, elem_id=uid('fastenc-colorfix'))
                fast_decoder = gr.Checkbox(label='Fast Decoder', value=True, elem_id=uid('fastdec'))

            fast_encoder.change(fn=gr_show, inputs=fast_encoder, outputs=color_fix, show_progress=False)

        return [
            enabled,
            encoder_tile_size, decoder_tile_size,
            vae_to_gpu, fast_decoder, fast_encoder, color_fix,
        ]

    def process(self, p:Processing,
            enabled:bool,
            encoder_tile_size:int, decoder_tile_size:int,
            vae_to_gpu:bool, fast_decoder:bool, fast_encoder:bool, color_fix:bool
        ):
        # for shorthand
        vae = p.sd_model.first_stage_model
        encoder = vae.encoder
        decoder = vae.decoder

        # undo the hijack if disabled (in case the last run crashed)
        if not enabled:
            if self.hooked:
                if isinstance(encoder.forward, VAEHook):
                    encoder.forward.net = None
                    encoder.forward = encoder.original_forward
                if isinstance(decoder.forward, VAEHook):
                    decoder.forward.net = None
                    decoder.forward = decoder.original_forward
                self.hooked = False
            return

        if devices.get_optimal_device_name().startswith('cuda') and vae.device == devices.cpu and not vae_to_gpu:
            print("[Tiled VAE] warn: VAE is not on GPU, check 'Move VAE to GPU' if possible.")

        # do the hijack
        kwargs = {
            'fast_decoder': fast_decoder,
            'fast_encoder': fast_encoder,
            'color_fix': color_fix,
            'to_gpu': vae_to_gpu,
        }

        # save the original forward (only once)
        if not hasattr(encoder, 'original_forward'): setattr(encoder, 'original_forward', encoder.forward)
        if not hasattr(decoder, 'original_forward'): setattr(decoder, 'original_forward', decoder.forward)

        self.hooked = True

        encoder.forward = VAEHook(encoder, encoder_tile_size, is_decoder=False, **kwargs)
        decoder.forward = VAEHook(decoder, decoder_tile_size, is_decoder=True, **kwargs)

    def postprocess(self, p:Processing, processed, enabled:bool, *args):
        if not enabled: return

        vae = p.sd_model.first_stage_model
        encoder = vae.encoder
        decoder = vae.decoder
        if isinstance(encoder.forward, VAEHook):
            encoder.forward.net = None
            encoder.forward = encoder.original_forward
        if isinstance(decoder.forward, VAEHook):
            decoder.forward.net = None
            decoder.forward = decoder.original_forward