File size: 5,467 Bytes
0163a2c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 |
import argparse
import inspect
import os
from pathlib import Path
import toml
from kohya_ss.library import train_util, config_util
import gradio as gr
from scripts.shared import ROOT_DIR
from scripts.utilities import gradio_to_args
# Root directory for presets. The load/save/delete helpers in this module build
# paths of the form PRESET_DIR/<key>/<name>.toml beneath it.
PRESET_DIR = os.path.join(ROOT_DIR, "presets")
# NOTE(review): not referenced anywhere else in the visible code; possibly a
# leftover from a single-file JSON preset store. Confirm before removing.
PRESET_PATH = os.path.join(ROOT_DIR, "presets.json")
def get_arg_templates(fn):
    """Collect the option names that ``fn`` registers on an argparse parser.

    ``fn`` is called with a fresh ``argparse.ArgumentParser`` as its first
    positional argument and ``True`` for every remaining parameter in its
    signature. The option strings registered on the parser are then read back.

    Args:
        fn: Callable that adds arguments to the parser passed as its first
            argument.

    Returns:
        A ``(keys, category)`` tuple. ``keys`` lists the registered option
        strings with every ``"--"`` removed, minus the built-in ``help`` and
        ``-h`` entries (other short options keep their single dash).
        ``category`` is ``fn.__name__`` with ``"add_"`` removed.
    """
    parser = argparse.ArgumentParser()

    # Everything after the parser itself is filled with a True placeholder.
    placeholder_count = len(inspect.signature(fn).parameters) - 1
    fn(parser, *([True] * placeholder_count))

    builtin_options = ("help", "-h")
    option_names = []
    # argparse keeps every registered option string in this private mapping.
    for option_string in parser._option_string_actions:
        stripped = option_string.replace("--", "")
        if stripped in builtin_options:
            continue
        option_names.append(stripped)

    category = fn.__name__.replace("add_", "")
    return option_names, category
# Argument-registering functions whose option names are harvested below.
# The list order is also the order in which save_preset() searches for the
# category an argument belongs to (first match wins).
arguments_functions = [
train_util.add_dataset_arguments,
train_util.add_optimizer_arguments,
train_util.add_sd_models_arguments,
train_util.add_sd_saving_arguments,
train_util.add_training_arguments,
config_util.add_config_arguments,
]
# One (option_names, category) tuple per function above. Computed once at
# import time, so each listed function is invoked when this module is loaded.
arg_templates = [get_arg_templates(x) for x in arguments_functions]
def load_presets():
    """Load every preset stored under ``PRESET_DIR``.

    ``PRESET_DIR`` is created if it does not exist yet. Each sub-directory is
    treated as one preset group, and each ``*.toml`` file inside it as one
    preset of that group.

    Returns:
        A dict mapping group name (sub-directory name) to a dict that maps
        preset name (file name without the ``.toml`` suffix) to the flattened
        preset contents returned by ``load_preset``.
    """
    os.makedirs(PRESET_DIR, exist_ok=True)
    presets = {}
    for preset_name in os.listdir(PRESET_DIR):
        preset_path = os.path.join(PRESET_DIR, preset_name)
        # A stray plain file (e.g. .DS_Store, a README) is not a preset group;
        # calling os.listdir() on it would raise NotADirectoryError.
        if not os.path.isdir(preset_path):
            continue
        presets[preset_name] = {}
        for filename in os.listdir(preset_path):
            # Only strip a real ".toml" suffix. Other files are ignored rather
            # than being listed as phantom presets that load as empty dicts.
            stem, extension = os.path.splitext(filename)
            if extension != ".toml":
                continue
            presets[preset_name][stem] = load_preset(preset_name, stem)
    return presets
def load_preset(key, name):
    """Read one preset file and return its values as a single flat dict.

    Args:
        key: Sub-directory of ``PRESET_DIR`` that holds the preset.
        name: Preset name, i.e. the file name without the ``.toml`` suffix.

    Returns:
        An empty dict when ``PRESET_DIR/<key>/<name>.toml`` does not exist.
        Otherwise the parsed TOML with one level of nesting removed: entries
        of every top-level table are merged into the result, and top-level
        non-table values are copied as they are. Later entries overwrite
        earlier ones that share a name.
    """
    filepath = os.path.join(PRESET_DIR, key, name + ".toml")
    if not os.path.exists(filepath):
        return {}

    # NOTE(review): the file is opened with the platform default encoding;
    # save_preset() writes it the same way.
    with open(filepath, mode="r") as f:
        document = toml.load(f)

    flat = {}
    for top_key, top_value in document.items():
        if isinstance(top_value, dict):
            # A table: lift its entries to the top level.
            flat.update(top_value)
        else:
            flat[top_key] = top_value
    return flat
def save_preset(key, name, value):
    """Write a preset to ``PRESET_DIR/<key>/<name>.toml``.

    Each entry of ``value`` is placed in a table named after the first
    category in ``arg_templates`` whose option names contain the entry's key.
    Entries that match no category are written at the top level.
    ``pathlib.Path`` values are converted to ``str`` before being stored.

    Args:
        key: Sub-directory of ``PRESET_DIR`` to write into (created if needed).
        name: Preset name; ``.toml`` is appended to form the file name.
        value: Mapping of argument name to argument value.
    """
    document = {}
    for arg_name, arg_value in value.items():
        if isinstance(arg_value, Path):
            arg_value = str(arg_value)

        # First category whose option list mentions this argument, if any.
        matched_category = next(
            (
                category
                for (template, category) in arg_templates
                if arg_name in template
            ),
            None,
        )
        if matched_category is None:
            document[arg_name] = arg_value
        else:
            document.setdefault(matched_category, {})[arg_name] = arg_value

    filepath = os.path.join(PRESET_DIR, key, name + ".toml")
    os.makedirs(os.path.dirname(filepath), exist_ok=True)
    with open(filepath, mode="w") as f:
        toml.dump(document, f)
def delete_preset(key, name):
    """Remove ``PRESET_DIR/<key>/<name>.toml`` if it exists.

    Args:
        key: Sub-directory of ``PRESET_DIR`` that holds the preset.
        name: Preset name, i.e. the file name without the ``.toml`` suffix.
    """
    target = os.path.join(PRESET_DIR, key, name + ".toml")
    if not os.path.exists(target):
        # Nothing to delete.
        return
    os.remove(target)
def create_ui(key, tmpls, opts):
    """Build the preset load/save/delete controls for one UI section.

    Must be called inside an active Gradio Blocks context, because the
    components are created immediately.

    Args:
        key: Sub-directory of ``PRESET_DIR`` whose presets this section lists,
            loads, saves and deletes.
        tmpls: Templates forwarded to ``gradio_to_args``, or a zero-argument
            callable returning them.
        opts: Mapping whose ``.values()`` are the Gradio components to save
            and restore, or a zero-argument callable returning such a mapping.

    Returns:
        A zero-argument ``init`` function that attaches the click handlers.
        ``opts`` is first resolved when ``init`` (or a handler) runs, so the
        caller can invoke ``init`` after the option components exist.
    """

    def get_templates():
        return tmpls() if callable(tmpls) else tmpls

    def get_options():
        return opts() if callable(opts) else opts

    def list_preset_names():
        # Re-read from disk every time so external changes are picked up.
        return list(load_presets().get(key, {}))

    with gr.Box():
        with gr.Row():
            with gr.Column():
                load_preset_button = gr.Button("Load preset", variant="primary")
                delete_preset_button = gr.Button("Delete preset")
            with gr.Column():
                load_preset_name = gr.Dropdown(
                    list_preset_names(), show_label=False
                ).style(container=False)
                reload_presets_button = gr.Button("🔄️")
            with gr.Column() as c:
                c.scale = 0.5
                save_preset_name = gr.Textbox(
                    "", placeholder="Preset name", lines=1, show_label=False
                ).style(container=False)
                save_preset_button = gr.Button("Save preset", variant="primary")

    def update_dropdown():
        """Return an update that refreshes the preset dropdown's choices."""
        return gr.Dropdown.update(choices=list_preset_names())

    def _save_preset(args):
        """Save the current option values under the name in the textbox."""
        # `args` is keyed by component because the inputs are passed as a set.
        name = args[save_preset_name]
        if not name:
            return update_dropdown()
        args = gradio_to_args(get_templates(), get_options(), args)
        save_preset(key, name, args)
        return update_dropdown()

    def _load_preset(args):
        """Return new values for the option components from the chosen preset."""
        name = args[load_preset_name]
        if not name:
            # This handler's outputs are the option components, not the
            # dropdown, so a dropdown update would be applied to the wrong
            # component(s). Leave every option untouched instead.
            unchanged = [gr.update() for _ in get_options().values()]
            return unchanged[0] if len(unchanged) == 1 else unchanged
        args = gradio_to_args(get_templates(), get_options(), args)
        preset = load_preset(key, name)
        result = []
        for k, _ in args.items():
            if k == load_preset_name:
                continue
            if k not in preset:
                # Not stored in the preset: clear the component.
                result.append(None)
                continue
            v = preset[k]
            if isinstance(v, list):
                # List values are shown as one space-separated string; str()
                # keeps non-string TOML arrays from raising in join().
                v = " ".join(str(item) for item in v)
            result.append(v)
        # With a single output, Gradio expects a bare value rather than a list.
        return result[0] if len(result) == 1 else result

    def _delete_preset(name):
        """Delete the selected preset and refresh the dropdown's choices."""
        if not name:
            return update_dropdown()
        delete_preset(key, name)
        return update_dropdown()

    def init():
        """Attach the click handlers to the buttons created above."""
        # Passing inputs as a set makes Gradio hand the handler one dict
        # keyed by component.
        save_preset_button.click(
            _save_preset,
            {save_preset_name, *get_options().values()},
            [load_preset_name],
        )
        load_preset_button.click(
            _load_preset,
            {load_preset_name, *get_options().values()},
            [*get_options().values()],
        )
        delete_preset_button.click(
            _delete_preset, load_preset_name, [load_preset_name]
        )
        reload_presets_button.click(
            update_dropdown, inputs=[], outputs=[load_preset_name]
        )

    return init
|