|
|
|
|
|
from pathlib import Path |
|
from omegaconf import DictConfig, OmegaConf |
|
from modules import scripts, script_callbacks |
|
import gradio as gr |
|
import torch |
|
|
|
CONFIG_PATH = Path(__file__).parent.resolve() / '../config.yaml' |
|
|
|
|
|
class Scaler(torch.nn.Module): |
|
def __init__(self, scale, block, scaler): |
|
super().__init__() |
|
self.scale = scale |
|
self.block = block |
|
self.scaler = scaler |
|
|
|
def forward(self, x, *args): |
|
x = torch.nn.functional.interpolate(x, scale_factor=self.scale, mode=self.scaler) |
|
return self.block(x, *args) |
|
|
|
|
|
class KohyaHiresFix(scripts.Script): |
|
def __init__(self): |
|
super().__init__() |
|
try: |
|
self.config: DictConfig = OmegaConf.load(CONFIG_PATH) |
|
except Exception: |
|
self.config = DictConfig({}) |
|
self.disable = False |
|
self.step_limit = 0 |
|
self.infotext_fields = [] |
|
|
|
def title(self): |
|
return "Kohya Hires.fix" |
|
|
|
def show(self, is_img2img): |
|
return scripts.AlwaysVisible |
|
|
|
def ui(self, is_img2img): |
|
with gr.Accordion(label='Kohya Hires.fix', open=False): |
|
with gr.Row(): |
|
enable = gr.Checkbox(label='Enable extension', value=False) |
|
with gr.Row(): |
|
s1 = gr.Slider(minimum=0, maximum=0.5, step=0.01, label="Stop at", value=self.config.get('s1', 0.15)) |
|
d1 = gr.Slider(minimum=1, maximum=10, step=1, label="Depth", value=self.config.get('d1', 3)) |
|
with gr.Row(): |
|
s2 = gr.Slider(minimum=0, maximum=0.5, step=0.01, label="Stop at", value=self.config.get('s2', 0.3)) |
|
d2 = gr.Slider(minimum=1, maximum=10, step=1, label="Depth", value=self.config.get('d2', 4)) |
|
with gr.Row(): |
|
scaler = gr.Dropdown(['bicubic', 'bilinear', 'nearest', 'nearest-exact'], label='Layer scaler', |
|
value=self.config.get('scaler', 'bicubic')) |
|
downscale = gr.Slider(minimum=0.1, maximum=1.0, step=0.05, label="Downsampling scale", |
|
value=self.config.get('downscale', 0.5)) |
|
upscale = gr.Slider(minimum=1.0, maximum=4.0, step=0.1, label="Upsampling scale", |
|
value=self.config.get('upscale', 2.0)) |
|
with gr.Row(): |
|
smooth_scaling = gr.Checkbox(label="Smooth scaling", value=self.config.get('smooth_scaling', True)) |
|
early_out = gr.Checkbox(label="Early upsampling", value=self.config.get('early_out', False)) |
|
only_one_pass = gr.Checkbox(label='Disable for additional passes', |
|
value=self.config.get('only_one_pass', True)) |
|
|
|
ui = [enable, only_one_pass, d1, d2, s1, s2, scaler, downscale, upscale, smooth_scaling, early_out] |
|
for elem in ui: |
|
setattr(elem, "do_not_save_to_config", True) |
|
|
|
parameters = { |
|
'DSHF_s1': s1, |
|
'DSHF_d1': d1, |
|
'DSHF_s2': s2, |
|
'DSHF_d2': d2, |
|
'DSHF_scaler': scaler, |
|
'DSHF_down': downscale, |
|
'DSHF_up': upscale, |
|
'DSHF_smooth': smooth_scaling, |
|
'DSHF_early': early_out, |
|
'DSHF_one': only_one_pass, |
|
} |
|
|
|
self.infotext_fields.append((enable, lambda d: d.get('DSHF_s1', False))) |
|
for k, element in parameters.items(): |
|
self.infotext_fields.append((element, k)) |
|
|
|
return ui |
|
|
|
def process(self, p, enable, only_one_pass, d1, d2, s1, s2, scaler, downscale, upscale, smooth_scaling, early_out): |
|
self.config = DictConfig({name: var for name, var in locals().items() if name not in ['self', 'p']}) |
|
if not enable or self.disable: |
|
script_callbacks.remove_current_script_callbacks() |
|
return |
|
model = p.sd_model.model.diffusion_model |
|
if s1 > s2: self.config.s2 = s1 |
|
self.p1 = (s1, d1 - 1) |
|
self.p2 = (s2, d2 - 1) |
|
self.step_limit = 0 |
|
|
|
def denoiser_callback(params: script_callbacks.CFGDenoiserParams): |
|
if params.sampling_step < self.step_limit: return |
|
for s, d in [self.p1, self.p2]: |
|
out_d = d if self.config.early_out else -(d + 1) |
|
if params.sampling_step < params.total_sampling_steps * s: |
|
if not isinstance(model.input_blocks[d], Scaler): |
|
model.input_blocks[d] = Scaler(self.config.downscale, model.input_blocks[d], self.config.scaler) |
|
model.output_blocks[out_d] = Scaler(self.config.upscale, model.output_blocks[out_d], self.config.scaler) |
|
elif self.config.smooth_scaling: |
|
scale_ratio = params.sampling_step / (params.total_sampling_steps * s) |
|
downscale = min((1 - self.config.downscale) * scale_ratio + self.config.downscale, 1.0) |
|
model.input_blocks[d].scale = downscale |
|
model.output_blocks[out_d].scale = self.config.upscale * (self.config.downscale / downscale) |
|
return |
|
elif isinstance(model.input_blocks[d], Scaler) and (self.p1[1] != self.p2[1] or s == self.p2[0]): |
|
model.input_blocks[d] = model.input_blocks[d].block |
|
model.output_blocks[out_d] = model.output_blocks[out_d].block |
|
self.step_limit = params.sampling_step if self.config.only_one_pass else 0 |
|
|
|
script_callbacks.on_cfg_denoiser(denoiser_callback) |
|
|
|
parameters = { |
|
'DSHF_s1': s1, |
|
'DSHF_d1': d1, |
|
'DSHF_s2': s2, |
|
'DSHF_d2': d2, |
|
'DSHF_scaler': scaler, |
|
'DSHF_down': downscale, |
|
'DSHF_up': upscale, |
|
'DSHF_smooth': smooth_scaling, |
|
'DSHF_early': early_out, |
|
'DSHF_one': only_one_pass, |
|
} |
|
for k, v in parameters.items(): |
|
p.extra_generation_params[k] = v |
|
|
|
def postprocess(self, p, processed, *args): |
|
for i, b in enumerate(p.sd_model.model.diffusion_model.input_blocks): |
|
if isinstance(b, Scaler): |
|
p.sd_model.model.diffusion_model.input_blocks[i] = b.block |
|
for i, b in enumerate(p.sd_model.model.diffusion_model.output_blocks): |
|
if isinstance(b, Scaler): |
|
p.sd_model.model.diffusion_model.output_blocks[i] = b.block |
|
OmegaConf.save(self.config, CONFIG_PATH) |
|
|
|
def process_batch(self, p, *args, **kwargs): |
|
self.step_limit = 0 |
|
|