|
import argparse |
|
import inspect |
|
import os |
|
from pathlib import Path |
|
import toml |
|
from kohya_ss.library import train_util, config_util |
|
|
|
import gradio as gr |
|
|
|
from scripts.shared import ROOT_DIR |
|
from scripts.utilities import gradio_to_args |
|
|
|
# Root directory for presets: one sub-directory per preset key, each holding
# "<name>.toml" preset files (read by load_presets/load_preset, written by
# save_preset).
PRESET_DIR = os.path.join(ROOT_DIR, "presets")

# NOTE(review): not referenced by the code visible in this file; presumably a
# leftover single-file preset location — confirm before removing.
PRESET_PATH = os.path.join(ROOT_DIR, "presets.json")
|
|
|
|
|
def get_arg_templates(fn):
    """Discover the CLI option names registered by an ``add_*`` helper.

    ``fn`` is called with a fresh ``argparse.ArgumentParser`` as its first
    argument and ``True`` for every remaining parameter, so that it
    registers all of its options on that parser.

    Returns:
        A ``(keys, category)`` tuple. ``keys`` is the list of option strings
        with ``--`` removed, excluding the built-in help options.
        ``category`` is ``fn.__name__`` with ``add_`` removed.
    """
    parser = argparse.ArgumentParser()
    filler_count = len(inspect.signature(fn).parameters) - 1
    fn(parser, *([True] * filler_count))

    excluded = ("help", "-h")
    stripped = (opt.replace("--", "") for opt in parser._option_string_actions)
    templates = [opt for opt in stripped if opt not in excluded]

    category = fn.__name__.replace("add_", "")
    return templates, category
|
|
|
|
|
# kohya "add_*_arguments" helpers; each registers one group of CLI options.
# save_preset uses the resulting groups to pick the TOML table for an option.
arguments_functions = [
    train_util.add_dataset_arguments,
    train_util.add_optimizer_arguments,
    train_util.add_sd_models_arguments,
    train_util.add_sd_saving_arguments,
    train_util.add_training_arguments,
    config_util.add_config_arguments,
]

# One (option names, group name) pair per helper above, in the same order.
arg_templates = list(map(get_arg_templates, arguments_functions))
|
|
|
|
|
def load_presets():
    """Load every preset on disk, grouped by preset key.

    ``PRESET_DIR`` is created if missing. Each sub-directory of it is a
    preset key, and each ``<name>.toml`` file inside that sub-directory is
    one preset. Other entries are ignored.

    Returns:
        dict mapping preset key -> {preset name -> flattened preset dict as
        returned by ``load_preset``}. Keys and names are inserted in sorted
        order, so UI dropdowns built from them are stable.
    """
    obj = {}
    os.makedirs(PRESET_DIR, exist_ok=True)

    for preset_key in sorted(os.listdir(PRESET_DIR)):
        preset_path = os.path.join(PRESET_DIR, preset_key)
        # Stray files at the top level are not preset groups, and calling
        # os.listdir on them would raise NotADirectoryError.
        if not os.path.isdir(preset_path):
            continue

        obj[preset_key] = {}
        for filename in sorted(os.listdir(preset_path)):
            stem, ext = os.path.splitext(filename)
            # Only regular *.toml files are presets. Anything else would show
            # up as a bogus empty preset, because load_preset returns {} for
            # a missing "<stem>.toml".
            if ext != ".toml" or not os.path.isfile(os.path.join(preset_path, filename)):
                continue
            obj[preset_key][stem] = load_preset(preset_key, stem)

    return obj
|
|
|
|
|
def load_preset(key, name):
    """Load one preset and flatten its TOML tables into a single dict.

    Args:
        key: preset group, i.e. a sub-directory of ``PRESET_DIR``.
        name: preset name, without the ``.toml`` suffix.

    Returns:
        dict of option name -> value. Entries of top-level tables are hoisted
        one level up. Returns ``{}`` when the preset file does not exist.
    """
    filepath = os.path.join(PRESET_DIR, key, name + ".toml")
    if not os.path.exists(filepath):
        return {}

    obj = _read_toml(filepath)

    flatten = {}
    for k, v in obj.items():
        if isinstance(v, dict):
            # save_preset groups options into one table per argument
            # category; undo that grouping here.
            flatten.update(v)
        else:
            flatten[k] = v
    return flatten


def _read_toml(filepath):
    """Parse the TOML file at ``filepath``, preferring UTF-8.

    TOML documents must be UTF-8; "utf-8-sig" also accepts a leading BOM.
    Files that were written with the platform's locale encoding (which is
    what an ``open()`` without ``encoding`` produces) may fail to decode as
    UTF-8. Those are re-read with the locale default so they keep loading.
    """
    try:
        with open(filepath, mode="r", encoding="utf-8-sig") as f:
            return toml.load(f)
    except UnicodeDecodeError:
        with open(filepath, mode="r") as f:
            return toml.load(f)


def save_preset(key, name, value):
    """Write ``value`` to ``PRESET_DIR/<key>/<name>.toml``.

    An option listed by one of the argument groups in ``arg_templates`` is
    stored under a table named after that group. All other options are
    stored at the top level. ``Path`` values are stored as strings. The file
    is written as UTF-8, as TOML requires, so non-ASCII values survive on
    every platform.

    Args:
        key: preset group (sub-directory, created if missing).
        name: preset name, without the ``.toml`` suffix.
        value: mapping of option name -> value.
    """
    obj = {}
    for k, v in value.items():
        if isinstance(v, Path):
            v = str(v)
        for template, category in arg_templates:
            if k in template:
                obj.setdefault(category, {})[k] = v
                break
        else:
            # No argument group claims this option; keep it top-level.
            obj[k] = v

    filepath = os.path.join(PRESET_DIR, key, name + ".toml")
    os.makedirs(os.path.dirname(filepath), exist_ok=True)
    with open(filepath, mode="w", encoding="utf-8") as f:
        toml.dump(obj, f)
|
|
|
|
|
def delete_preset(key, name):
    """Delete the stored preset ``<name>.toml`` from the ``key`` group.

    Does nothing when that preset file is absent.
    """
    target = os.path.join(PRESET_DIR, key, name + ".toml")
    if not os.path.exists(target):
        return
    os.remove(target)
|
|
|
|
|
def create_ui(key, tmpls, opts):
    """Build the preset load/save/delete controls for one preset group.

    Args:
        key: preset group; presets live under ``PRESET_DIR/<key>``.
        tmpls: templates passed to ``gradio_to_args``, or a zero-argument
            callable returning them. Resolved on every event.
        opts: mapping whose values are the option gradio components, or a
            zero-argument callable returning it. Those components are the
            inputs read when saving and the outputs written when loading.

    Returns:
        An ``init`` callable that wires up the button events. ``opts`` is
        only resolved inside ``init``, presumably so it can be called after
        the option components exist — confirm against callers.
    """
    get_templates = lambda: tmpls() if callable(tmpls) else tmpls
    get_options = lambda: opts() if callable(opts) else opts

    def preset_names():
        # Names of the presets currently on disk for this key.
        return list(load_presets().get(key, {}).keys())

    with gr.Box():
        with gr.Row():
            with gr.Column():
                load_preset_button = gr.Button("Load preset", variant="primary")
                delete_preset_button = gr.Button("Delete preset")
            with gr.Column():
                load_preset_name = gr.Dropdown(
                    preset_names(), show_label=False
                ).style(container=False)
                reload_presets_button = gr.Button("🔄️")
            with gr.Column() as c:
                c.scale = 0.5
                save_preset_name = gr.Textbox(
                    "", placeholder="Preset name", lines=1, show_label=False
                ).style(container=False)
                save_preset_button = gr.Button("Save preset", variant="primary")

    def update_dropdown():
        return gr.Dropdown.update(choices=preset_names())

    def _save_preset(args):
        name = args[save_preset_name]
        if not name:
            return update_dropdown()
        args = gradio_to_args(get_templates(), get_options(), args)
        save_preset(key, name, args)
        return update_dropdown()

    def _load_preset(args):
        name = args[load_preset_name]
        if not name:
            # This event's outputs are the option components, not the
            # dropdown, so return one "no change" update per output.
            no_change = [gr.update() for _ in get_options().values()]
            return no_change[0] if len(no_change) == 1 else no_change

        args = gradio_to_args(get_templates(), get_options(), args)
        preset = load_preset(key, name)

        result = []
        for k in args:
            if k == load_preset_name:
                continue
            if k not in preset:
                # Options absent from the preset are set to None.
                result.append(None)
                continue
            v = preset[k]
            if isinstance(v, list):
                # List-valued options are shown as one space-separated string;
                # str() keeps non-string items (e.g. numbers) from crashing.
                v = " ".join(map(str, v))
            result.append(v)
        # A single output must be returned bare, not wrapped in a list.
        return result[0] if len(result) == 1 else result

    def _delete_preset(name):
        if not name:
            return update_dropdown()
        delete_preset(key, name)
        return update_dropdown()

    def init():
        # Set-valued inputs make gradio pass a {component: value} dict,
        # which is what _save_preset / _load_preset index into.
        save_preset_button.click(
            _save_preset,
            set([save_preset_name, *get_options().values()]),
            [load_preset_name],
        )
        load_preset_button.click(
            _load_preset,
            set([load_preset_name, *get_options().values()]),
            [*get_options().values()],
        )
        delete_preset_button.click(_delete_preset, load_preset_name, [load_preset_name])
        reload_presets_button.click(
            update_dropdown, inputs=[], outputs=[load_preset_name]
        )

    return init
|
|