# Source provenance (from the hosting page): author "lint",
# commit 0d0a1c2 "fix and revise app", file size 8.8 kB.
import gradio as gr
from multiprocessing import cpu_count
from pathlib import Path
from src.ui_shared import (
model_ids,
scheduler_names,
default_scheduler,
controlnet_ids,
assets_directory,
)
from src.ui_functions import generate, run_training
# Default width/height (pixels) for generated images and the gallery height.
default_img_size = 512

# Markdown rendered above and below the main UI. Read explicitly as UTF-8 so
# non-ASCII characters in the asset files do not depend on the platform's
# default locale encoding (e.g. cp1252 on Windows).
header = (Path(assets_directory) / "header.MD").read_text(encoding="utf-8")
footer = (Path(assets_directory) / "footer.MD").read_text(encoding="utf-8")

# Shared visual theme for the whole Blocks app.
theme = gr.themes.Soft(
    primary_hue="blue",
    neutral_hue="slate",
)
with gr.Blocks(theme=theme) as demo:
    header_component = gr.Markdown(header)

    # --- Top row: prompts on the left, model/scheduler selection on the right.
    with gr.Row().style(equal_height=True):
        with gr.Column(scale=70):
            prompt = gr.Textbox(
                label="Prompt", placeholder="Press <Shift+Enter> to generate", lines=2
            )
            neg_prompt = gr.Textbox(label="Negative Prompt", placeholder="", lines=2)
            with gr.Row():
                controlnet_prompt = gr.Textbox(
                    label="Controlnet Prompt",
                    placeholder="If empty, defaults to base `Prompt`",
                    lines=2,
                )
                controlnet_negative_prompt = gr.Textbox(
                    label="Controlnet Negative Prompt",
                    placeholder="If empty, defaults to base `Negative Prompt`",
                    lines=2,
                )
        with gr.Column(scale=30):
            model_name = gr.Dropdown(
                label="Model", choices=model_ids, value=model_ids[0], allow_custom_value=True
            )
            controlnet_name = gr.Dropdown(
                label="Controlnet", choices=controlnet_ids, value=controlnet_ids[0], allow_custom_value=True
            )
            scheduler_name = gr.Dropdown(
                label="Scheduler", choices=scheduler_names, value=default_scheduler, allow_custom_value=True
            )
            with gr.Row():
                generate_button = gr.Button(value="Generate", variant="primary")
                dark_mode_btn = gr.Button("Dark Mode", variant="secondary")

    # --- Main row: settings tabs on the left, output gallery on the right.
    with gr.Row():
        with gr.Column():
            with gr.Tab("Inference"):
                guidance_image = gr.Image(
                    label="Guidance Image",
                    source="upload",
                    tool="editor",
                    type="pil",
                ).style(height=256)
                with gr.Row():
                    controlnet_cond_scale = gr.Slider(
                        label="Controlnet Weight",
                        value=0.5,
                        minimum=0.0,
                        maximum=1.0,
                        step=0.1,
                    )
                with gr.Row():
                    batch_size = gr.Slider(
                        label="Batch Size", value=1, minimum=1, maximum=8, step=1
                    )
                    # -1 is offered as the default seed value; its meaning is
                    # decided by `generate` (not visible here).
                    seed = gr.Slider(-1, 2147483647, label="Seed", value=-1, step=1)
                with gr.Row():
                    guidance = gr.Slider(
                        label="Guidance scale", value=7.5, minimum=0, maximum=20
                    )
                    steps = gr.Slider(
                        label="Steps", value=20, minimum=1, maximum=100, step=1
                    )
                with gr.Row():
                    width = gr.Slider(
                        label="Width",
                        value=default_img_size,
                        minimum=64,
                        maximum=1024,
                        step=32,
                    )
                    height = gr.Slider(
                        label="Height",
                        value=default_img_size,
                        minimum=64,
                        maximum=1024,
                        step=32,
                    )
            with gr.Tab("Train Anime ControlNet"):
                with gr.Row():
                    train_batch_size = gr.Slider(
                        label="Training Batch Size",
                        minimum=1,
                        maximum=8,
                        step=1,
                        value=1,
                    )
                    gradient_accumulation_steps = gr.Slider(
                        label="Gradient Accumulation steps",
                        minimum=1,
                        maximum=6,
                        step=1,
                        value=4,
                    )
                # Count-like fields use precision=0 so Gradio delivers them to
                # `run_training` as ints instead of floats (the default for
                # gr.Number with precision=None).
                with gr.Row():
                    num_train_epochs = gr.Number(
                        label="Total training epochs", value=2, precision=0
                    )
                    train_learning_rate = gr.Number(label="Learning Rate", value=5.0e-6)
                with gr.Row():
                    checkpointing_steps = gr.Number(
                        label="Steps between saving checkpoints", value=4000, precision=0
                    )
                    image_logging_steps = gr.Number(
                        label="Steps between logging example images (pass 0 to disable)",
                        value=0,
                        precision=0,
                    )
                with gr.Row():
                    train_data_dir = gr.Textbox(
                        label="Path to training image folder",
                        value="lint/anybooru",
                    )
                    valid_data_dir = gr.Textbox(
                        label="Path to validation image folder",
                        value="",
                    )
                with gr.Row():
                    controlnet_weights_path = gr.Textbox(
                        label="Repo for initializing Controlnet Weights",
                        value="lint/anime_control/anime_merge",
                    )
                    output_dir = gr.Textbox(
                        label="Output directory for trained weights", value="./models"
                    )
                with gr.Row():
                    train_whole_controlnet = gr.Checkbox(
                        label="Train whole controlnet", value=True
                    )
                    save_whole_pipeline = gr.Checkbox(
                        label="Save whole pipeline", value=True
                    )
                training_button = gr.Button(
                    value="Train Style ControlNet", variant="primary"
                )
                training_status = gr.Text(label="Training Status")
        with gr.Column():
            gallery = gr.Gallery(
                label="Generated images", show_label=False, elem_id="gallery"
            ).style(height=default_img_size, grid=2)
            generation_details = gr.Markdown()

    # pipe_kwargs = gr.Textbox(label="Pipe kwargs", value="{\n\t\n}", visible=False)
    # if torch.cuda.is_available():
    #     giga = 2**30
    #     vram_guage = gr.Slider(0, torch.cuda.memory_reserved(0)/giga, label='VRAM Allocated to Reserved (GB)', value=0, step=1)
    #     demo.load(lambda : torch.cuda.memory_allocated(0)/giga, inputs=[], outputs=vram_guage, every=0.5, show_progress=False)

    footer_component = gr.Markdown(footer)

    # --- Event wiring. Listeners must be registered inside the Blocks context.
    # The order of `inputs` must match the positional parameters of `generate`.
    inputs = [
        model_name,
        guidance_image,
        controlnet_name,
        scheduler_name,
        prompt,
        guidance,
        steps,
        batch_size,
        width,
        height,
        seed,
        neg_prompt,
        controlnet_prompt,
        controlnet_negative_prompt,
        controlnet_cond_scale,
        # pipe_kwargs,
    ]
    outputs = [gallery, generation_details]
    prompt.submit(generate, inputs=inputs, outputs=outputs)
    generate_button.click(generate, inputs=inputs, outputs=outputs)

    # The order of `training_inputs` must match the positional parameters of
    # `run_training`.
    training_inputs = [
        model_name,
        controlnet_weights_path,
        train_data_dir,
        valid_data_dir,
        train_batch_size,
        train_whole_controlnet,
        gradient_accumulation_steps,
        num_train_epochs,
        train_learning_rate,
        output_dir,
        checkpointing_steps,
        image_logging_steps,
        save_whole_pipeline,
    ]
    training_button.click(
        run_training,
        inputs=training_inputs,
        outputs=[training_status],
    )

    # Client-side-only theme toggle (no Python function runs): removes every
    # `.dark` class if one is present, otherwise adds `dark` to <body>.
    # from gradio.themes.builder
    toggle_dark_mode_args = dict(
        fn=None,
        inputs=None,
        outputs=None,
        _js="""() => {
if (document.querySelectorAll('.dark').length) {
document.querySelectorAll('.dark').forEach(el => el.classList.remove('dark'));
} else {
document.querySelector('body').classList.add('dark');
}
}""",
    )
    # Running the toggle on page load flips whatever theme the page starts in;
    # when no `.dark` class is present initially, the app opens in dark mode.
    demo.load(**toggle_dark_mode_args)
    dark_mode_btn.click(**toggle_dark_mode_args)
if __name__ == "__main__":
    # `favicon_path` used to be passed here without ever being defined or
    # imported, which raised NameError at launch. Resolve it from the assets
    # directory instead and fall back to Gradio's default icon when absent.
    # NOTE(review): assumes the icon, if shipped, is named `favicon.ico` —
    # TODO confirm the actual filename in the assets directory.
    favicon_file = Path(assets_directory) / "favicon.ico"
    favicon_path = str(favicon_file) if favicon_file.is_file() else None
    demo.queue(concurrency_count=cpu_count()).launch(favicon_path=favicon_path)