ehristoforu's picture
Upload folder using huggingface_hub
0163a2c verified
raw
history blame
3.47 kB
import argparse
import gradio as gr
from kohya_ss.library import train_util, config_util
from scripts import presets, ui, ui_overrides
from scripts.runner import initialize_runner
from scripts.utilities import args_to_gradio, load_args_template, options_to_gradio
def title():
    """Return the label under which this training tab is listed."""
    label = "Train dreambooth"
    return label
def create_ui():
sd_models_arguments = argparse.ArgumentParser()
dataset_arguments = argparse.ArgumentParser()
training_arguments = argparse.ArgumentParser()
sd_saving_arguments = argparse.ArgumentParser()
optimizer_arguments = argparse.ArgumentParser()
config_arguments = argparse.ArgumentParser()
train_util.add_sd_models_arguments(sd_models_arguments)
train_util.add_dataset_arguments(dataset_arguments, True, False, True)
train_util.add_training_arguments(training_arguments, True)
train_util.add_sd_saving_arguments(sd_saving_arguments)
train_util.add_optimizer_arguments(optimizer_arguments)
config_util.add_config_arguments(config_arguments)
sd_models_options = {}
dataset_options = {}
training_options = {}
sd_saving_options = {}
optimizer_options = {}
config_options = {}
dreambooth_options = {}
templates, script_file = load_args_template("train_db.py")
get_options = lambda: {
**sd_models_options,
**dataset_options,
**training_options,
**sd_saving_options,
**optimizer_options,
**config_options,
**dreambooth_options,
}
get_templates = lambda: {
**sd_models_arguments.__dict__["_option_string_actions"],
**dataset_arguments.__dict__["_option_string_actions"],
**training_arguments.__dict__["_option_string_actions"],
**sd_saving_arguments.__dict__["_option_string_actions"],
**optimizer_arguments.__dict__["_option_string_actions"],
**config_arguments.__dict__["_option_string_actions"],
**templates,
}
with gr.Column():
init_runner = initialize_runner(script_file, get_templates, get_options)
with gr.Box():
with gr.Row():
init_ui = presets.create_ui("train_db", get_templates, get_options)
with gr.Row():
with gr.Group():
with gr.Box():
ui.title("Dreambooth options")
options_to_gradio(templates, dreambooth_options)
with gr.Box():
ui.title("Model options")
args_to_gradio(sd_models_arguments, sd_models_options)
with gr.Box():
ui.title("Dataset options")
args_to_gradio(dataset_arguments, dataset_options)
with gr.Box():
ui.title("Dataset Config options")
args_to_gradio(config_arguments, config_options)
with gr.Box():
ui.title("Training options")
args_to_gradio(training_arguments, training_options)
with gr.Group():
with gr.Box():
ui.title("Save options")
args_to_gradio(sd_saving_arguments, sd_saving_options)
with gr.Box():
ui.title("Optimizer options")
args_to_gradio(
optimizer_arguments,
optimizer_options,
ui_overrides.OPTIMIZER_OPTIONS,
)
init_runner()
init_ui()