File size: 5,467 Bytes
0163a2c |
|
import argparse
import inspect
import os
from pathlib import Path
import toml
from kohya_ss.library import train_util, config_util
import gradio as gr
from scripts.shared import ROOT_DIR
from scripts.utilities import gradio_to_args
# Saved presets live under PRESET_DIR as "<key>/<name>.toml" (see load_preset).
PRESET_DIR = os.path.join(ROOT_DIR, "presets")
# "presets.json" sitting next to PRESET_DIR (same string as joining ROOT_DIR
# with "presets.json"). NOTE(review): not referenced in this part of the file;
# presumably used elsewhere or kept for legacy data — confirm.
PRESET_PATH = PRESET_DIR + ".json"
def get_arg_templates(fn):
    """Collect the option names that *fn* registers on a fresh parser.

    *fn* is called with a new ``argparse.ArgumentParser`` as its first
    positional argument and ``True`` for every remaining parameter.

    Returns:
        A ``(keys, category)`` tuple. ``keys`` lists every registered option
        string with all ``"--"`` removed, minus the ``-h`` / ``--help``
        entries. ``category`` is ``fn.__name__`` with ``"add_"`` removed.
    """
    parser = argparse.ArgumentParser()
    extra_param_count = len(inspect.signature(fn).parameters) - 1
    fn(parser, *([True] * extra_param_count))

    excluded = ("help", "-h")
    option_names = (
        option.replace("--", "") for option in parser._option_string_actions
    )
    keys = [name for name in option_names if name not in excluded]
    return keys, fn.__name__.replace("add_", "")
# Argument-registration helpers whose options save_preset groups into one TOML
# table per helper (table name = helper name without "add_").
arguments_functions = [
    train_util.add_dataset_arguments,
    train_util.add_optimizer_arguments,
    train_util.add_sd_models_arguments,
    train_util.add_sd_saving_arguments,
    train_util.add_training_arguments,
    config_util.add_config_arguments,
]
# (option_names, category) pairs, computed once at import time.
arg_templates = list(map(get_arg_templates, arguments_functions))
def load_presets():
    """Load every saved preset, grouped by preset key.

    ``PRESET_DIR`` is created if missing. Each sub-directory of ``PRESET_DIR``
    is a preset key, and each ``<name>.toml`` file inside it is one preset.
    Other entries (stray files, non-TOML files) are ignored.

    Returns:
        ``{key: {name: preset_dict}}``, where each ``preset_dict`` is what
        ``load_preset(key, name)`` returns. Keys and names are inserted in
        sorted order so the UI lists them deterministically.
    """
    obj = {}
    os.makedirs(PRESET_DIR, exist_ok=True)
    for preset_key in sorted(os.listdir(PRESET_DIR)):
        key_dir = os.path.join(PRESET_DIR, preset_key)
        # A stray file next to the key directories (e.g. .DS_Store) would
        # make the os.listdir() below raise NotADirectoryError.
        if not os.path.isdir(key_dir):
            continue
        obj[preset_key] = {}
        for filename in sorted(os.listdir(key_dir)):
            name, ext = os.path.splitext(filename)
            # Only "<name>.toml" files are presets. Anything else used to be
            # listed as an empty preset because load_preset() found no
            # "<filename>.toml" for it.
            if ext != ".toml" or not os.path.isfile(os.path.join(key_dir, filename)):
                continue
            obj[preset_key][name] = load_preset(preset_key, name)
    return obj
def load_preset(key, name):
    """Read preset *name* stored under *key* and flatten it into one dict.

    The file read is ``PRESET_DIR/<key>/<name>.toml``. ``save_preset`` nests
    options in one TOML table per argument category. Those tables are merged
    back into a single flat ``{option: value}`` mapping here, and top-level
    (non-table) values are copied as-is.

    Returns:
        The flattened dict, or ``{}`` when the file does not exist.
    """
    filepath = os.path.join(PRESET_DIR, key, name + ".toml")
    if not os.path.exists(filepath):
        return {}
    try:
        # TOML documents are UTF-8 by specification; "utf-8-sig" also
        # tolerates a leading BOM added by some Windows editors.
        with open(filepath, mode="r", encoding="utf-8-sig") as f:
            obj = toml.load(f)
    except UnicodeDecodeError:
        # Presets saved before writes were forced to UTF-8 used the locale's
        # default encoding (e.g. cp932 on Japanese Windows); keep them readable.
        with open(filepath, mode="r") as f:
            obj = toml.load(f)
    flatten = {}
    for k, v in obj.items():
        if isinstance(v, dict):
            # One table per argument category: merge its entries (a later
            # duplicate key overwrites an earlier one).
            flatten.update(v)
        else:
            flatten[k] = v
    return flatten


def save_preset(key, name, value):
    """Write the option mapping *value* as preset *name* under *key*.

    Each option listed in one of ``arg_templates`` is nested in a TOML table
    named after that template's category. Any other option is written at the
    top level. ``Path`` values are converted to ``str`` first. The file
    ``PRESET_DIR/<key>/<name>.toml`` (and its directory) is created or
    overwritten.
    """
    obj = {}
    for k, v in value.items():
        if isinstance(v, Path):
            v = str(v)
        for template, category in arg_templates:
            if k in template:
                obj.setdefault(category, {})[k] = v
                break
        else:
            # Not owned by any known argument category.
            obj[k] = v
    filepath = os.path.join(PRESET_DIR, key, name + ".toml")
    os.makedirs(os.path.dirname(filepath), exist_ok=True)
    # TOML must be UTF-8. The platform default (e.g. cp1252/cp932 on Windows)
    # would mangle non-ASCII paths for any reader expecting valid TOML.
    with open(filepath, mode="w", encoding="utf-8") as f:
        toml.dump(obj, f)
def delete_preset(key, name):
    """Delete ``PRESET_DIR/<key>/<name>.toml``; a missing file is ignored."""
    filepath = os.path.join(PRESET_DIR, key, name + ".toml")
    try:
        os.remove(filepath)
    except FileNotFoundError:
        # Already gone: never saved, or removed by a concurrent request.
        # Checking os.path.exists() first would race with the latter.
        pass
def create_ui(key, tmpls, opts):
    """Build the load / save / delete preset controls for preset *key*.

    Args:
        key: Preset namespace; files live under ``PRESET_DIR/<key>``.
        tmpls: Templates passed to ``gradio_to_args``, or a zero-argument
            callable returning them (resolved on every use).
        opts: Mapping whose values are the gradio components a preset reads
            from and fills, or a zero-argument callable returning it.

    Returns:
        An ``init`` callable that wires the button events when invoked.
        NOTE(review): presumably deferred so components referenced by ``opts``
        can be created after this call — confirm with callers.
    """

    def get_templates():
        return tmpls() if callable(tmpls) else tmpls

    def get_options():
        return opts() if callable(opts) else opts

    def _pack(values):
        # gradio expects a bare value, not a one-element list, for a single output.
        return values[0] if len(values) == 1 else values

    presets = load_presets()
    if key not in presets:
        presets[key] = {}
    with gr.Box():
        with gr.Row():
            with gr.Column():
                load_preset_button = gr.Button("Load preset", variant="primary")
                delete_preset_button = gr.Button("Delete preset")
            with gr.Column():
                load_preset_name = gr.Dropdown(
                    list(presets[key].keys()), show_label=False
                ).style(container=False)
                reload_presets_button = gr.Button("🔄️")
            with gr.Column() as c:
                c.scale = 0.5
                save_preset_name = gr.Textbox(
                    "", placeholder="Preset name", lines=1, show_label=False
                ).style(container=False)
                save_preset_button = gr.Button("Save preset", variant="primary")

    def update_dropdown():
        # Re-scan the preset directory so saves/deletes show up in the list.
        presets = load_presets()
        if key not in presets:
            presets[key] = {}
        return gr.Dropdown.update(choices=list(presets[key].keys()))

    def _save_preset(args):
        # args: {component: value}, because the event inputs are a set.
        name = args[save_preset_name]
        if not name:
            return update_dropdown()
        args = gradio_to_args(get_templates(), get_options(), args)
        save_preset(key, name, args)
        return update_dropdown()

    def _load_preset(args):
        name = args[load_preset_name]
        if not name:
            # This event's outputs are the option components, not the
            # dropdown, so a dropdown update would hit the wrong component.
            # Leave every output unchanged instead.
            return _pack([gr.update() for _ in get_options()])
        args = gradio_to_args(get_templates(), get_options(), args)
        preset = load_preset(key, name)
        result = []
        # Assumes args iterates in the same order as the event outputs
        # (the values of get_options()).
        for k in args:
            if k == load_preset_name:
                continue
            if k not in preset:
                result.append(None)
                continue
            v = preset[k]
            if isinstance(v, list):
                # Lists are shown as one space-separated string; str() keeps
                # non-string items (e.g. numbers) from breaking join().
                v = " ".join(map(str, v))
            result.append(v)
        return _pack(result)

    def _delete_preset(name):
        if not name:
            return update_dropdown()
        delete_preset(key, name)
        return update_dropdown()

    def init():
        save_preset_button.click(
            _save_preset,
            set([save_preset_name, *get_options().values()]),
            [load_preset_name],
        )
        load_preset_button.click(
            _load_preset,
            set([load_preset_name, *get_options().values()]),
            [*get_options().values()],
        )
        delete_preset_button.click(_delete_preset, load_preset_name, [load_preset_name])
        reload_presets_button.click(
            update_dropdown, inputs=[], outputs=[load_preset_name]
        )

    return init
|