File size: 6,720 Bytes
0163a2c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
### https://gist.github.com/kohya-ss/3f774da220df102548093a7abc8538ed

from pathlib import Path
from omegaconf import DictConfig, OmegaConf
from modules import scripts, script_callbacks
import gradio as gr
import torch

# Location of the YAML settings file: 'config.yaml' one directory above the
# directory that contains this script. The '..' segment is appended after
# resolve(), so it is kept in the path as-is rather than being normalized.
CONFIG_PATH = Path(__file__).parent.resolve() / '../config.yaml'


class Scaler(torch.nn.Module):
    """Module wrapper that resizes its input before calling the wrapped block.

    Attributes:
        scale: value passed to ``interpolate`` as ``scale_factor``.
        block: the wrapped callable/module that receives the resized input.
        scaler: value passed to ``interpolate`` as ``mode``.
    """

    def __init__(self, scale, block, scaler):
        """Store the scale factor, the wrapped block and the interpolation mode."""
        super().__init__()
        self.scale = scale
        self.block = block
        self.scaler = scaler

    def forward(self, x, *args):
        """Resize ``x`` and return ``self.block(resized, *args)``.

        Any extra positional arguments are forwarded to the wrapped block
        untouched; only ``x`` is interpolated.
        """
        interpolate = torch.nn.functional.interpolate
        resized = interpolate(
            x,
            scale_factor=self.scale,
            mode=self.scaler,
        )
        return self.block(resized, *args)
    
    
class KohyaHiresFix(scripts.Script):
    """Always-visible script that temporarily rescales diffusion model blocks.

    While the sampling step is below a configurable fraction of the total
    steps, one input block of ``p.sd_model.model.diffusion_model`` is wrapped
    in ``Scaler`` with the downsampling scale and one output block is wrapped
    with the upsampling scale. Two (stop, depth) stages can be configured.
    The wrappers are removed again once a stage has expired and, as a safety
    net, in ``postprocess``. The settings of the last run are written to
    ``CONFIG_PATH`` and used as UI defaults when the script is constructed.
    """

    def __init__(self):
        """Load saved settings (if any) and reset per-run state."""
        super().__init__()
        try:
            self.config: DictConfig = OmegaConf.load(CONFIG_PATH)
        except Exception:
            # Best effort: a missing or unreadable config file simply means
            # the UI falls back to the hard-coded defaults in ui().
            self.config = DictConfig({})
        self.disable = False
        # Sampling steps below this value are ignored by the denoiser callback.
        self.step_limit = 0
        self.infotext_fields = []

    def title(self):
        """Return the display name of the script."""
        return "Kohya Hires.fix"

    def show(self, is_img2img):
        """Show the script in every tab, regardless of ``is_img2img``."""
        return scripts.AlwaysVisible

    def ui(self, is_img2img):
        """Build the accordion UI and register the infotext fields.

        Returns:
            The list of components whose values are passed to ``process`` in
            this exact order: enable, only_one_pass, d1, d2, s1, s2, scaler,
            downscale, upscale, smooth_scaling, early_out.
        """
        with gr.Accordion(label='Kohya Hires.fix', open=False):
            with gr.Row():
                enable = gr.Checkbox(label='Enable extension', value=False)
            with gr.Row():
                s1 = gr.Slider(minimum=0, maximum=0.5, step=0.01, label="Stop at", value=self.config.get('s1', 0.15))
                d1 = gr.Slider(minimum=1, maximum=10, step=1, label="Depth", value=self.config.get('d1', 3))
            with gr.Row():
                s2 = gr.Slider(minimum=0, maximum=0.5, step=0.01, label="Stop at", value=self.config.get('s2', 0.3))
                d2 = gr.Slider(minimum=1, maximum=10, step=1, label="Depth", value=self.config.get('d2', 4))
            with gr.Row():
                scaler = gr.Dropdown(['bicubic', 'bilinear', 'nearest', 'nearest-exact'], label='Layer scaler',
                                     value=self.config.get('scaler', 'bicubic'))
                downscale = gr.Slider(minimum=0.1, maximum=1.0, step=0.05, label="Downsampling scale",
                                      value=self.config.get('downscale', 0.5))
                upscale = gr.Slider(minimum=1.0, maximum=4.0, step=0.1, label="Upsampling scale",
                                    value=self.config.get('upscale', 2.0))
            with gr.Row():
                smooth_scaling = gr.Checkbox(label="Smooth scaling", value=self.config.get('smooth_scaling', True))
                early_out = gr.Checkbox(label="Early upsampling", value=self.config.get('early_out', False))
                only_one_pass = gr.Checkbox(label='Disable for additional passes',
                                            value=self.config.get('only_one_pass', True))

        ui = [enable, only_one_pass, d1, d2, s1, s2, scaler, downscale, upscale, smooth_scaling, early_out]
        # Defaults come from CONFIG_PATH, so mark every component to be left
        # out of the web UI's own config saving.
        for elem in ui:
            elem.do_not_save_to_config = True

        parameters = {
            'DSHF_s1': s1,
            'DSHF_d1': d1,
            'DSHF_s2': s2,
            'DSHF_d2': d2,
            'DSHF_scaler': scaler,
            'DSHF_down': downscale,
            'DSHF_up': upscale,
            'DSHF_smooth': smooth_scaling,
            'DSHF_early': early_out,
            'DSHF_one': only_one_pass,
        }
        # using "DSHF_s1" as key to check if extension is enabled
        self.infotext_fields.append((enable, lambda d: d.get('DSHF_s1', False)))
        for k, element in parameters.items():
            self.infotext_fields.append((element, k))

        return ui

    def process(self, p, enable, only_one_pass, d1, d2, s1, s2, scaler, downscale, upscale, smooth_scaling, early_out):
        """Snapshot the settings and install the per-step denoiser callback.

        Args:
            p: processing object; ``p.sd_model.model.diffusion_model`` is
                patched and ``p.extra_generation_params`` receives the
                ``DSHF_*`` entries.
            enable: when falsy, no callback is installed for this run.
            only_one_pass: keep the blocks untouched on later passes once
                both stages have expired (see ``step_limit`` handling below).
            d1, d2: 1-based block depths of the two stages.
            s1, s2: fraction of total sampling steps during which the
                respective stage is active.
            scaler: interpolation mode handed to ``Scaler``.
            downscale: scale factor for the wrapped input block.
            upscale: scale factor for the wrapped output block.
            smooth_scaling: move the scales gradually towards 1.0 / the
                compensating factor as the stage approaches its stop step.
            early_out: wrap output block ``d`` instead of ``-(d + 1)``.
        """
        # Explicit snapshot of every UI value (same keys and order as the
        # argument list) so it can be persisted in postprocess().
        self.config = DictConfig({
            'enable': enable,
            'only_one_pass': only_one_pass,
            'd1': d1,
            'd2': d2,
            's1': s1,
            's2': s2,
            'scaler': scaler,
            'downscale': downscale,
            'upscale': upscale,
            'smooth_scaling': smooth_scaling,
            'early_out': early_out,
        })
        # Always drop callbacks registered by previous runs of this script.
        # Without this, every enabled generation would add one more denoiser
        # callback, each keeping a reference to the model it captured.
        script_callbacks.remove_current_script_callbacks()
        if not enable or self.disable:
            return
        model = p.sd_model.model.diffusion_model
        if s1 > s2:
            # Only the persisted value is adjusted; self.p2 below still uses
            # the s2 that was passed in.
            self.config.s2 = s1
        # (stop fraction, zero-based block index) for each stage.
        self.p1 = (s1, d1 - 1)
        self.p2 = (s2, d2 - 1)
        self.step_limit = 0

        def denoiser_callback(params: script_callbacks.CFGDenoiserParams):
            # Steps below the limit recorded after both stages expired are
            # skipped entirely.
            if params.sampling_step < self.step_limit:
                return
            for s, d in [self.p1, self.p2]:
                out_d = d if self.config.early_out else -(d + 1)
                if params.sampling_step < params.total_sampling_steps * s:
                    if not isinstance(model.input_blocks[d], Scaler):
                        # Stage is active and not yet installed: wrap both blocks.
                        model.input_blocks[d] = Scaler(self.config.downscale, model.input_blocks[d], self.config.scaler)
                        model.output_blocks[out_d] = Scaler(self.config.upscale, model.output_blocks[out_d], self.config.scaler)
                    elif self.config.smooth_scaling:
                        # Interpolate the downscale towards 1.0 (capped) and
                        # shrink the upscale by the same ratio.
                        scale_ratio = params.sampling_step / (params.total_sampling_steps * s)
                        current_down = min((1 - self.config.downscale) * scale_ratio + self.config.downscale, 1.0)
                        model.input_blocks[d].scale = current_down
                        model.output_blocks[out_d].scale = self.config.upscale * (self.config.downscale / current_down)
                    # The first active stage wins; later stages are not examined.
                    return
                elif isinstance(model.input_blocks[d], Scaler) and (self.p1[1] != self.p2[1] or s == self.p2[0]):
                    # Stage expired: unwrap. When both stages share a depth,
                    # the wrapper is only removed once the stop value equals
                    # the second stage's stop value.
                    model.input_blocks[d] = model.input_blocks[d].block
                    model.output_blocks[out_d] = model.output_blocks[out_d].block
            # Reached only when no stage is active any more.
            self.step_limit = params.sampling_step if self.config.only_one_pass else 0

        script_callbacks.on_cfg_denoiser(denoiser_callback)

        parameters = {
            'DSHF_s1': s1,
            'DSHF_d1': d1,
            'DSHF_s2': s2,
            'DSHF_d2': d2,
            'DSHF_scaler': scaler,
            'DSHF_down': downscale,
            'DSHF_up': upscale,
            'DSHF_smooth': smooth_scaling,
            'DSHF_early': early_out,
            'DSHF_one': only_one_pass,
        }
        for k, v in parameters.items():
            p.extra_generation_params[k] = v

    def postprocess(self, p, processed, *args):
        """Unwrap any remaining ``Scaler`` blocks and persist the settings."""
        unet = p.sd_model.model.diffusion_model
        for blocks in (unet.input_blocks, unet.output_blocks):
            for i, b in enumerate(blocks):
                if isinstance(b, Scaler):
                    blocks[i] = b.block
        OmegaConf.save(self.config, CONFIG_PATH)

    def process_batch(self, p, *args, **kwargs):
        """Reset the step limit so each batch starts with the stages enabled."""
        self.step_limit = 0