ehristoforu's picture
Upload folder using huggingface_hub
0163a2c verified
raw
history blame
6.72 kB
### https://gist.github.com/kohya-ss/3f774da220df102548093a7abc8538ed
from pathlib import Path
from omegaconf import DictConfig, OmegaConf
from modules import scripts, script_callbacks
import gradio as gr
import torch
CONFIG_PATH = Path(__file__).parent.resolve() / '../config.yaml'
class Scaler(torch.nn.Module):
def __init__(self, scale, block, scaler):
super().__init__()
self.scale = scale
self.block = block
self.scaler = scaler
def forward(self, x, *args):
x = torch.nn.functional.interpolate(x, scale_factor=self.scale, mode=self.scaler)
return self.block(x, *args)
class KohyaHiresFix(scripts.Script):
def __init__(self):
super().__init__()
try:
self.config: DictConfig = OmegaConf.load(CONFIG_PATH)
except Exception:
self.config = DictConfig({})
self.disable = False
self.step_limit = 0
self.infotext_fields = []
def title(self):
return "Kohya Hires.fix"
def show(self, is_img2img):
return scripts.AlwaysVisible
def ui(self, is_img2img):
with gr.Accordion(label='Kohya Hires.fix', open=False):
with gr.Row():
enable = gr.Checkbox(label='Enable extension', value=False)
with gr.Row():
s1 = gr.Slider(minimum=0, maximum=0.5, step=0.01, label="Stop at", value=self.config.get('s1', 0.15))
d1 = gr.Slider(minimum=1, maximum=10, step=1, label="Depth", value=self.config.get('d1', 3))
with gr.Row():
s2 = gr.Slider(minimum=0, maximum=0.5, step=0.01, label="Stop at", value=self.config.get('s2', 0.3))
d2 = gr.Slider(minimum=1, maximum=10, step=1, label="Depth", value=self.config.get('d2', 4))
with gr.Row():
scaler = gr.Dropdown(['bicubic', 'bilinear', 'nearest', 'nearest-exact'], label='Layer scaler',
value=self.config.get('scaler', 'bicubic'))
downscale = gr.Slider(minimum=0.1, maximum=1.0, step=0.05, label="Downsampling scale",
value=self.config.get('downscale', 0.5))
upscale = gr.Slider(minimum=1.0, maximum=4.0, step=0.1, label="Upsampling scale",
value=self.config.get('upscale', 2.0))
with gr.Row():
smooth_scaling = gr.Checkbox(label="Smooth scaling", value=self.config.get('smooth_scaling', True))
early_out = gr.Checkbox(label="Early upsampling", value=self.config.get('early_out', False))
only_one_pass = gr.Checkbox(label='Disable for additional passes',
value=self.config.get('only_one_pass', True))
ui = [enable, only_one_pass, d1, d2, s1, s2, scaler, downscale, upscale, smooth_scaling, early_out]
for elem in ui:
setattr(elem, "do_not_save_to_config", True)
parameters = {
'DSHF_s1': s1,
'DSHF_d1': d1,
'DSHF_s2': s2,
'DSHF_d2': d2,
'DSHF_scaler': scaler,
'DSHF_down': downscale,
'DSHF_up': upscale,
'DSHF_smooth': smooth_scaling,
'DSHF_early': early_out,
'DSHF_one': only_one_pass,
}
# using "DSHF_s1" as key to check if extension is enabled
self.infotext_fields.append((enable, lambda d: d.get('DSHF_s1', False)))
for k, element in parameters.items():
self.infotext_fields.append((element, k))
return ui
def process(self, p, enable, only_one_pass, d1, d2, s1, s2, scaler, downscale, upscale, smooth_scaling, early_out):
self.config = DictConfig({name: var for name, var in locals().items() if name not in ['self', 'p']})
if not enable or self.disable:
script_callbacks.remove_current_script_callbacks()
return
model = p.sd_model.model.diffusion_model
if s1 > s2: self.config.s2 = s1
self.p1 = (s1, d1 - 1)
self.p2 = (s2, d2 - 1)
self.step_limit = 0
def denoiser_callback(params: script_callbacks.CFGDenoiserParams):
if params.sampling_step < self.step_limit: return
for s, d in [self.p1, self.p2]:
out_d = d if self.config.early_out else -(d + 1)
if params.sampling_step < params.total_sampling_steps * s:
if not isinstance(model.input_blocks[d], Scaler):
model.input_blocks[d] = Scaler(self.config.downscale, model.input_blocks[d], self.config.scaler)
model.output_blocks[out_d] = Scaler(self.config.upscale, model.output_blocks[out_d], self.config.scaler)
elif self.config.smooth_scaling:
scale_ratio = params.sampling_step / (params.total_sampling_steps * s)
downscale = min((1 - self.config.downscale) * scale_ratio + self.config.downscale, 1.0)
model.input_blocks[d].scale = downscale
model.output_blocks[out_d].scale = self.config.upscale * (self.config.downscale / downscale)
return
elif isinstance(model.input_blocks[d], Scaler) and (self.p1[1] != self.p2[1] or s == self.p2[0]):
model.input_blocks[d] = model.input_blocks[d].block
model.output_blocks[out_d] = model.output_blocks[out_d].block
self.step_limit = params.sampling_step if self.config.only_one_pass else 0
script_callbacks.on_cfg_denoiser(denoiser_callback)
parameters = {
'DSHF_s1': s1,
'DSHF_d1': d1,
'DSHF_s2': s2,
'DSHF_d2': d2,
'DSHF_scaler': scaler,
'DSHF_down': downscale,
'DSHF_up': upscale,
'DSHF_smooth': smooth_scaling,
'DSHF_early': early_out,
'DSHF_one': only_one_pass,
}
for k, v in parameters.items():
p.extra_generation_params[k] = v
def postprocess(self, p, processed, *args):
for i, b in enumerate(p.sd_model.model.diffusion_model.input_blocks):
if isinstance(b, Scaler):
p.sd_model.model.diffusion_model.input_blocks[i] = b.block
for i, b in enumerate(p.sd_model.model.diffusion_model.output_blocks):
if isinstance(b, Scaler):
p.sd_model.model.diffusion_model.output_blocks[i] = b.block
OmegaConf.save(self.config, CONFIG_PATH)
def process_batch(self, p, *args, **kwargs):
self.step_limit = 0