|
import json |
|
import gradio as gr |
|
from modules import scripts, script_callbacks, processing, shared |
|
from lib_free_u import global_state, unet, xyz_grid |
|
|
|
|
|
# "Steps" slider of each tab. Both stay ``None`` until ``on_after_component``
# sees the matching component being created.
txt2img_steps_component = None
img2img_steps_component = None

# Callables queued by ``FreeUScript.ui`` when the matching slider does not exist
# yet; each one is called with the slider component once it is created.
# The two lists are deliberately separate objects, one per tab.
txt2img_steps_callbacks = []
img2img_steps_callbacks = []
|
|
|
|
|
class FreeUScript(scripts.Script):
    """Always-visible script that builds the FreeU controls and stores their values.

    ``ui`` creates the Gradio components and wires the preset buttons and the
    infotext paste handlers. ``process`` turns the submitted values into
    ``global_state.instance`` and records the generation parameters, and
    ``process_batch`` resets the sampling-step counter.
    """

    def title(self):
        """Return the script name, also used as the label of the accordion."""
        return "FreeU"

    def show(self, is_img2img):
        """Return ``scripts.AlwaysVisible`` regardless of the tab."""
        return scripts.AlwaysVisible

    def ui(self, is_img2img):
        """Create the FreeU components and register their event handlers.

        Args:
            is_img2img: selects which tab's "Steps" slider (and pending-callback
                list) the schedule infotext handler is bound to.

        Returns:
            ``(enabled, start_ratio, stop_ratio, transition_smoothness, version,
            *flat_stage_infos)``; ``process`` reads its ``*args`` in this order.
        """
        global_state.reload_presets()
        # The first preset provides the initial slider values.
        default_stage_infos = next(iter(global_state.all_presets.values())).stage_infos

        # NOTE(review): the container nesting below was reconstructed from a copy
        # of this file that had lost its indentation -- confirm against the
        # rendered layout.
        with gr.Accordion(open=False, label=self.title()):
            with gr.Row():
                with gr.Row():
                    enabled = gr.Checkbox(
                        label="Enable",
                        value=False,
                    )

                    version = gr.Dropdown(
                        show_label=False,
                        elem_id=self.elem_id("version"),
                        choices=list(global_state.all_versions.keys()),
                        # The last registered version is preselected.
                        value=next(iter(reversed(global_state.all_versions.keys()))),
                    )

                preset_name = gr.Dropdown(
                    show_label=False,
                    choices=list(global_state.all_presets.keys()),
                    value=next(iter(global_state.all_presets.keys())),
                    type="value",
                    elem_id=self.elem_id("preset_name"),
                    allow_custom_value=True,
                    tooltip="Apply button loads settings\nWrite custom name to enable save\nDelete automatically will save to file",
                    size="sm",
                )

                is_custom_preset = preset_name.value not in global_state.default_presets
                preset_exists = preset_name.value in global_state.all_presets

                # Button labels are written as escapes so that a wrong source
                # encoding cannot corrupt them.
                apply_preset = gr.Button(
                    value="\u2705",  # white heavy check mark
                    size="lg",
                    elem_classes="tool",
                    interactive=preset_exists,
                )
                save_preset = gr.Button(
                    value="\U0001f4be",  # floppy disk
                    size="lg",
                    elem_classes="tool",
                    interactive=is_custom_preset,
                )
                refresh_presets = gr.Button(
                    value="\U0001f504",  # anticlockwise arrows
                    size="lg",
                    elem_classes="tool"
                )
                delete_preset = gr.Button(
                    value="\U0001f5d1\ufe0f",  # wastebasket
                    size="lg",
                    elem_classes="tool",
                    interactive=is_custom_preset and preset_exists,
                )

            with gr.Row():
                start_ratio = gr.Slider(
                    label="Start At Step",
                    elem_id=self.elem_id("start_at_step"),
                    minimum=0,
                    maximum=1,
                    value=0,
                )

                stop_ratio = gr.Slider(
                    label="Stop At Step",
                    elem_id=self.elem_id("stop_at_step"),
                    minimum=0,
                    maximum=1,
                    value=1,
                )

                transition_smoothness = gr.Slider(
                    label="Transition Smoothness",
                    elem_id=self.elem_id("transition_smoothness"),
                    minimum=0,
                    maximum=1,
                    value=0,
                )

            # Every stage contributes six sliders, appended in a fixed order.
            flat_stage_infos = []

            for index in range(global_state.STAGES_COUNT):
                stage_n = index + 1
                default_stage_info = default_stage_infos[index]

                # Only the first two stages start expanded.
                with gr.Accordion(open=index < 2, label=f"Stage {stage_n}"):
                    with gr.Row():
                        backbone_scale = gr.Slider(
                            label=f"Backbone {stage_n} Scale",
                            elem_id=self.elem_id(f"backbone_scale_{stage_n}"),
                            minimum=-1,
                            maximum=3,
                            value=default_stage_info.backbone_factor,
                        )

                        backbone_offset = gr.Slider(
                            label=f"Backbone {stage_n} Offset",
                            elem_id=self.elem_id(f"backbone_offset_{stage_n}"),
                            minimum=0,
                            maximum=1,
                            value=default_stage_info.backbone_offset,
                        )

                        backbone_width = gr.Slider(
                            label=f"Backbone {stage_n} Width",
                            elem_id=self.elem_id(f"backbone_width_{stage_n}"),
                            minimum=0,
                            maximum=1,
                            value=default_stage_info.backbone_width,
                        )

                    with gr.Row():
                        skip_scale = gr.Slider(
                            label=f"Skip {stage_n} Scale",
                            elem_id=self.elem_id(f"skip_scale_{stage_n}"),
                            minimum=-1,
                            maximum=3,
                            value=default_stage_info.skip_factor,
                        )

                        skip_high_end_scale = gr.Slider(
                            label=f"Skip {stage_n} High End Scale",
                            elem_id=self.elem_id(f"skip_high_end_scale_{stage_n}"),
                            minimum=-1,
                            maximum=3,
                            value=default_stage_info.skip_high_end_factor,
                        )

                        skip_cutoff = gr.Slider(
                            label=f"Skip {stage_n} Cutoff",
                            elem_id=self.elem_id(f"skip_cutoff_{stage_n}"),
                            minimum=0.0,
                            maximum=1.0,
                            value=default_stage_info.skip_cutoff,
                        )

                # This order differs from the on-screen order.
                # NOTE(review): presumably it mirrors the key order of
                # ``StageInfo.to_dict(include_default=True)``, whose values are
                # written back positionally below -- confirm.
                flat_stage_infos.extend([
                    backbone_scale,
                    skip_scale,
                    backbone_offset,
                    backbone_width,
                    skip_cutoff,
                    skip_high_end_scale,
                ])

            def on_preset_name_change(preset_name):
                """Return interactivity updates for the apply, save and delete buttons."""
                is_custom_preset = preset_name not in global_state.default_presets
                preset_exists = preset_name in global_state.all_presets
                return (
                    gr.Button.update(interactive=preset_exists),
                    gr.Button.update(interactive=is_custom_preset),
                    gr.Button.update(interactive=is_custom_preset and preset_exists),
                )

            preset_name.change(
                fn=on_preset_name_change,
                inputs=[preset_name],
                outputs=[apply_preset, save_preset, delete_preset],
            )

            def on_apply_click(user_settings_name):
                """Return slider updates carrying the values of the named preset."""
                preset = global_state.all_presets[user_settings_name]
                return (
                    gr.Slider.update(value=preset.start_ratio),
                    gr.Slider.update(value=preset.stop_ratio),
                    gr.Slider.update(value=preset.transition_smoothness),
                    *[
                        gr.update(value=v)
                        for stage_info in preset.stage_infos
                        for v in stage_info.to_dict(include_default=True).values()
                    ],
                )

            apply_preset.click(
                fn=on_apply_click,
                inputs=[preset_name],
                outputs=[start_ratio, stop_ratio, transition_smoothness, *flat_stage_infos],
            )

            def on_save_click(preset_name, start_ratio, stop_ratio, transition_smoothness, *flat_stage_infos):
                """Store the current values under ``preset_name`` and save the presets."""
                global_state.all_presets[preset_name] = global_state.State(
                    stage_infos=flat_stage_infos,
                    start_ratio=start_ratio,
                    stop_ratio=stop_ratio,
                    transition_smoothness=transition_smoothness,
                )
                global_state.save_presets()

                # The preset exists now, so apply and delete become usable.
                return (
                    gr.Dropdown.update(choices=list(global_state.all_presets.keys())),
                    gr.Button.update(interactive=True),
                    gr.Button.update(interactive=True),
                )

            save_preset.click(
                fn=on_save_click,
                inputs=[preset_name, start_ratio, stop_ratio, transition_smoothness, *flat_stage_infos],
                outputs=[preset_name, apply_preset, delete_preset],
            )

            def on_refresh_click(preset_name):
                """Reload the presets and refresh the dropdown choices and button states."""
                global_state.reload_presets()
                is_custom_preset = preset_name not in global_state.default_presets
                preset_exists = preset_name in global_state.all_presets

                return (
                    gr.Dropdown.update(value=preset_name, choices=list(global_state.all_presets.keys())),
                    gr.Button.update(interactive=preset_exists),
                    gr.Button.update(interactive=is_custom_preset),
                    gr.Button.update(interactive=is_custom_preset and preset_exists),
                )

            refresh_presets.click(
                fn=on_refresh_click,
                inputs=[preset_name],
                outputs=[preset_name, apply_preset, save_preset, delete_preset],
            )

            def on_delete_click(preset_name):
                """Delete the named preset, save, and select the preset now at its index."""
                preset_name_index = list(global_state.all_presets.keys()).index(preset_name)
                del global_state.all_presets[preset_name]
                global_state.save_presets()

                # Clamp the index in case the deleted preset was the last entry.
                preset_name_index = min(len(global_state.all_presets) - 1, preset_name_index)
                preset_names = list(global_state.all_presets.keys())
                preset_name = preset_names[preset_name_index]

                is_custom_preset = preset_name not in global_state.default_presets
                preset_exists = preset_name in global_state.all_presets
                return (
                    gr.Dropdown.update(value=preset_name, choices=preset_names),
                    gr.Button.update(interactive=preset_exists),
                    gr.Button.update(interactive=is_custom_preset),
                    gr.Button.update(interactive=is_custom_preset and preset_exists),
                )

            delete_preset.click(
                fn=on_delete_click,
                inputs=[preset_name],
                outputs=[preset_name, apply_preset, save_preset, delete_preset],
            )

            # Hidden components listed in ``infotext_fields``; their change
            # handlers copy a pasted value into the visible controls.
            schedule_infotext = gr.HTML(visible=False, interactive=False)
            stages_infotext = gr.HTML(visible=False, interactive=False)
            version_infotext = gr.HTML(visible=False, interactive=False)

            def register_schedule_infotext_change(steps_component):
                """Bind the schedule handler once the "Steps" slider is available."""
                schedule_infotext.change(
                    fn=self.on_schedule_infotext_update,
                    inputs=[schedule_infotext, steps_component],
                    outputs=[schedule_infotext, start_ratio, stop_ratio, transition_smoothness],
                )

            steps_component, steps_callbacks = (
                (img2img_steps_component, img2img_steps_callbacks)
                if is_img2img else
                (txt2img_steps_component, txt2img_steps_callbacks)
            )

            # Defer the binding when the slider has not been created yet.
            if steps_component is None:
                steps_callbacks.append(register_schedule_infotext_change)
            else:
                register_schedule_infotext_change(steps_component)

            stages_infotext.change(
                fn=self.on_stages_infotext_update,
                inputs=[stages_infotext],
                outputs=[stages_infotext, enabled, *flat_stage_infos],
            )

            version_infotext.change(
                fn=self.on_version_infotext_update,
                inputs=[version_infotext],
                outputs=[version_infotext, version]
            )

        self.infotext_fields = [
            (schedule_infotext, "FreeU Schedule"),
            (stages_infotext, "FreeU Stages"),
            (version_infotext, "FreeU Version"),
        ]
        self.paste_field_names = [f for _, f in self.infotext_fields]

        return enabled, start_ratio, stop_ratio, transition_smoothness, version, *flat_stage_infos

    def on_schedule_infotext_update(self, infotext, steps):
        """Translate a pasted "FreeU Schedule" value into slider updates.

        Args:
            infotext: ``", "``-separated text whose first three items are the
                start ratio, stop ratio and transition smoothness.
            steps: current value of the tab's "Steps" slider.

        Returns:
            Four updates: one clearing the hidden field, then start ratio, stop
            ratio and transition smoothness. Four ``gr.skip()`` values are
            returned instead when ``infotext`` is empty.
        """
        if not infotext:
            return (gr.skip(),) * 4

        start_ratio, stop_ratio, transition_smoothness, *_ = infotext.split(", ")

        return (
            # Clear the hidden field once its value has been consumed.
            gr.update(value=""),
            gr.update(value=unet.to_denoising_step(xyz_grid.int_or_float(start_ratio), steps) / steps),
            gr.update(value=unet.to_denoising_step(xyz_grid.int_or_float(stop_ratio), steps) / steps),
            gr.update(value=float(transition_smoothness)),
        )

    def on_stages_infotext_update(self, infotext):
        """Translate a pasted "FreeU Stages" JSON value into component updates.

        Args:
            infotext: JSON list with one object of ``StageInfo`` keyword
                arguments per stage.

        Returns:
            An update clearing the hidden field, an update for the "Enable"
            checkbox taken from the ``freeu_png_info_auto_enable`` option, then
            one update per stage value. Only ``gr.skip()`` values are returned
            when ``infotext`` is empty.
        """
        if not infotext:
            return (gr.skip(),) * (2 + global_state.STAGES_COUNT * global_state.STAGE_INFO_ARGS_LEN)

        stage_infos = json.loads(infotext)
        stage_infos = [
            global_state.StageInfo(**stage_info)
            for stage_info in stage_infos
        ]
        # Pad with default stages when fewer than STAGES_COUNT were stored.
        stage_infos.extend([
            global_state.StageInfo()
            for _ in range(global_state.STAGES_COUNT - len(stage_infos))
        ])

        return (
            gr.update(value=""),
            gr.update(value=shared.opts.data.get("freeu_png_info_auto_enable", True)),
            *(
                gr.update(value=v)
                for stage_info in stage_infos
                for v in stage_info.to_dict(include_default=True).values()
            )
        )

    def on_version_infotext_update(self, infotext):
        """Translate a pasted "FreeU Version" value into a dropdown update.

        Returns:
            An update clearing the hidden field and an update for the version
            dropdown. The value is looked up in
            ``global_state.reversed_all_versions`` and used unchanged when it is
            not found there. Two ``gr.skip()`` values are returned when
            ``infotext`` is empty.
        """
        if not infotext:
            return (gr.skip(),) * 2

        return (
            gr.update(value=""),
            gr.update(value=global_state.reversed_all_versions.get(infotext, infotext)),
        )

    def process(
        self,
        p: processing.StableDiffusionProcessing,
        *args
    ):
        """Build ``global_state.instance`` from ``args`` and record the infotext.

        Args:
            p: processing object whose ``extra_generation_params`` receives the
                "FreeU Stages", "FreeU Schedule" and "FreeU Version" entries.
            *args: either a single dict of ``State`` keyword arguments, or the
                component values in the order returned by ``ui``.

        Raises:
            TypeError: if ``args[0]`` is neither a dict nor a bool.
        """
        if isinstance(args[0], dict):
            global_state.instance = global_state.State(**args[0])
        elif isinstance(args[0], bool):
            # Positional layout: enable flag, float schedule values, version,
            # then the flat per-stage values.
            stage_infos_begin = global_state.STATE_ARGS_LEN - 1
            global_state.instance = global_state.State(
                args[0],
                *[float(n) for n in args[1:stage_infos_begin-1]],
                args[stage_infos_begin-1],
                args[stage_infos_begin:],
            )
        else:
            raise TypeError(f"Unrecognized args sequence starting with type {type(args[0])}")

        global_state.apply_xyz()
        global_state.xyz_attrs.clear()
        if not global_state.instance.enable:
            return

        # Walk the stages backwards and keep everything from the first stage
        # with a non-empty dict onwards, which drops the trailing stages whose
        # ``to_dict()`` is empty; the result is then put back in order.
        last_d = False
        p.extra_generation_params["FreeU Stages"] = json.dumps(list(reversed([
            stage_info.to_dict()
            for stage_info in reversed(global_state.instance.stage_infos)
            if last_d or stage_info.to_dict() and (last_d := True)
        ])))
        p.extra_generation_params["FreeU Schedule"] = ", ".join([
            str(global_state.instance.start_ratio),
            str(global_state.instance.stop_ratio),
            str(global_state.instance.transition_smoothness),
        ])
        p.extra_generation_params["FreeU Version"] = global_state.instance.version

    def process_batch(self, p, *args, **kwargs):
        """Reset ``global_state.current_sampling_step`` to 0."""
        global_state.current_sampling_step = 0
|
|
|
|
|
def increment_sampling_step(*_args, **_kwargs):
    """Advance ``global_state.current_sampling_step`` by one; arguments are ignored."""
    global_state.current_sampling_step = global_state.current_sampling_step + 1


try:
    script_callbacks.on_cfg_after_cfg(increment_sampling_step)
except AttributeError:
    # NOTE(review): presumably ``on_cfg_after_cfg`` is missing from some web UI
    # versions, hence the fallback to ``on_cfg_denoised`` -- confirm.
    script_callbacks.on_cfg_denoised(increment_sampling_step)
|
|
|
|
|
def on_after_component(component, **kwargs):
    """Remember the "Steps" slider of each tab and run the callbacks waiting for it.

    Args:
        component: the component that was just created.
        **kwargs: its constructor arguments; only ``elem_id`` is inspected.
    """
    global txt2img_steps_component, img2img_steps_component

    elem_id = kwargs.get("elem_id", None)

    if elem_id == "img2img_steps":
        img2img_steps_component = component
        for notify in img2img_steps_callbacks:
            notify(component)
    elif elem_id == "txt2img_steps":
        txt2img_steps_component = component
        for notify in txt2img_steps_callbacks:
            notify(component)


script_callbacks.on_after_component(on_after_component)
|
|
|
|
|
def on_ui_settings():
    """Add the ``freeu_png_info_auto_enable`` option to the "FreeU" settings section."""
    auto_enable_option = shared.OptionInfo(
        default=True,
        label="Auto enable when loading the PNG Info of a generation that used FreeU",
        section=("freeu", "FreeU"),
    )
    shared.opts.add_option("freeu_png_info_auto_enable", auto_enable_option)


script_callbacks.on_ui_settings(on_ui_settings)
|
|
|
|
|
# Apply the ``lib_free_u`` patches when this module is imported; what each patch
# does is defined in the respective module.
unet.patch()
xyz_grid.patch()
|
|