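# Hugging Face Space: Gradio app for ControlNet image generation and anime-ControlNet training.
# (Space status shown when this file was captured: "Runtime error".)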
import gradio as gr | |
from multiprocessing import cpu_count | |
from pathlib import Path | |
from src.ui_shared import ( | |
model_ids, | |
scheduler_names, | |
default_scheduler, | |
controlnet_ids, | |
assets_directory, | |
) | |
from src.ui_functions import generate, run_training | |
# Default width/height (px) for the generation size sliders; also used as the
# gallery height in the layout.
default_img_size = 512

# Markdown shown above and below the UI. Read explicitly as UTF-8 so decoding
# does not depend on the host's locale encoding (plain open() without an
# encoding can raise UnicodeDecodeError or produce mojibake on non-UTF-8 hosts).
header = Path(assets_directory, "header.MD").read_text(encoding="utf-8")
footer = Path(assets_directory, "footer.MD").read_text(encoding="utf-8")
# Gradio "Soft" theme: slate neutrals with blue as the primary accent colour.
theme = gr.themes.Soft(neutral_hue="slate", primary_hue="blue")
# Build the whole UI: a prompt/model row, a tabbed Inference/Training panel
# beside the output gallery, and the event wiring. Gradio requires event
# listeners (.click/.submit/.load) to be registered inside the Blocks context,
# so everything below lives under this `with`.
with gr.Blocks(theme=theme) as demo:
    header_component = gr.Markdown(header)

    # ---- Prompts (left) and model / scheduler selection (right) ----------
    with gr.Row().style(equal_height=True):
        with gr.Column(scale=70):
            prompt = gr.Textbox(
                label="Prompt", placeholder="Press <Shift+Enter> to generate", lines=2
            )
            neg_prompt = gr.Textbox(label="Negative Prompt", placeholder="", lines=2)
            with gr.Row():
                controlnet_prompt = gr.Textbox(
                    label="Controlnet Prompt",
                    placeholder="If empty, defaults to base `Prompt`",
                    lines=2,
                )
                controlnet_negative_prompt = gr.Textbox(
                    label="Controlnet Negative Prompt",
                    placeholder="If empty, defaults to base `Negative Prompt`",
                    lines=2,
                )
        with gr.Column(scale=30):
            model_name = gr.Dropdown(
                label="Model", choices=model_ids, value=model_ids[0], allow_custom_value=True
            )
            controlnet_name = gr.Dropdown(
                label="Controlnet", choices=controlnet_ids, value=controlnet_ids[0], allow_custom_value=True
            )
            scheduler_name = gr.Dropdown(
                label="Scheduler", choices=scheduler_names, value=default_scheduler, allow_custom_value=True
            )
            with gr.Row():
                generate_button = gr.Button(value="Generate", variant="primary")
                dark_mode_btn = gr.Button("Dark Mode", variant="secondary")

    # ---- Settings tabs (left) and generated output (right) ---------------
    with gr.Row():
        with gr.Column():
            with gr.Tab("Inference"):
                guidance_image = gr.Image(
                    label="Guidance Image",
                    source="upload",
                    tool="editor",
                    type="pil",
                ).style(height=256)
                with gr.Row():
                    controlnet_cond_scale = gr.Slider(
                        label="Controlnet Weight",
                        value=1.0,
                        minimum=0.0,
                        maximum=2.0,
                        step=0.1,
                    )
                with gr.Row():
                    batch_size = gr.Slider(
                        label="Batch Size", value=1, minimum=1, maximum=4, step=1
                    )
                    # -1 presumably means "random seed" — confirm in generate().
                    seed = gr.Slider(-1, 2147483647, label="Seed", value=-1, step=1)
                with gr.Row():
                    guidance = gr.Slider(
                        label="Guidance scale", value=7.5, minimum=0, maximum=20
                    )
                    steps = gr.Slider(
                        label="Steps", value=20, minimum=1, maximum=100, step=1
                    )
                with gr.Row():
                    width = gr.Slider(
                        label="Width",
                        value=default_img_size,
                        minimum=64,
                        maximum=1024,
                        step=32,
                    )
                    height = gr.Slider(
                        label="Height",
                        value=default_img_size,
                        minimum=64,
                        maximum=1024,
                        step=32,
                    )
            with gr.Tab("Train Anime ControlNet"):
                with gr.Row():
                    train_batch_size = gr.Slider(
                        label="Training Batch Size",
                        minimum=1,
                        maximum=8,
                        step=1,
                        value=1,
                    )
                    gradient_accumulation_steps = gr.Slider(
                        label="Gradient Accumulation steps",
                        minimum=1,
                        maximum=6,
                        step=1,
                        value=4,
                    )
                with gr.Row():
                    num_train_epochs = gr.Number(
                        label="Total training epochs", value=2
                    )
                    train_learning_rate = gr.Number(label="Learning Rate", value=5.0e-6)
                with gr.Row():
                    checkpointing_steps = gr.Number(
                        label="Steps between saving checkpoints", value=4000
                    )
                    image_logging_steps = gr.Number(
                        label="Steps between logging example images (pass 0 to disable)",
                        value=0,
                    )
                with gr.Row():
                    train_data_dir = gr.Textbox(
                        label="Path to training image folder",
                        value="lint/anybooru",
                    )
                    valid_data_dir = gr.Textbox(
                        label="Path to validation image folder",
                        value="",
                    )
                with gr.Row():
                    controlnet_weights_path = gr.Textbox(
                        label="Repo for initializing Controlnet Weights",
                        value="lint/anime_control/anime_merge",
                    )
                    output_dir = gr.Textbox(
                        label="Output directory for trained weights", value="./models"
                    )
                with gr.Row():
                    train_whole_controlnet = gr.Checkbox(
                        label="Train whole controlnet", value=True
                    )
                    save_whole_pipeline = gr.Checkbox(
                        label="Save whole pipeline", value=True
                    )
                training_button = gr.Button(
                    value="Train Anime ControlNet", variant="primary"
                )
                training_status = gr.Text(label="Training Status")
        with gr.Column():
            gallery = gr.Gallery(
                label="Generated images", show_label=False, elem_id="gallery"
            ).style(height=default_img_size, grid=2)
            generation_details = gr.Markdown()
            # Disabled experiments (the VRAM gauge would also need `import torch`):
            # pipe_kwargs = gr.Textbox(label="Pipe kwargs", value="{\n\t\n}", visible=False)
            # if torch.cuda.is_available():
            #     giga = 2**30
            #     vram_guage = gr.Slider(0, torch.cuda.memory_reserved(0)/giga, label='VRAM Allocated to Reserved (GB)', value=0, step=1)
            #     demo.load(lambda : torch.cuda.memory_allocated(0)/giga, inputs=[], outputs=vram_guage, every=0.5, show_progress=False)

    footer_component = gr.Markdown(footer)

    # ---- Event wiring ------------------------------------------------------
    # Gradio passes input values to the handler positionally, so this order
    # must match the parameter order of `generate` in src.ui_functions.
    inputs = [
        model_name,
        guidance_image,
        controlnet_name,
        scheduler_name,
        prompt,
        guidance,
        steps,
        batch_size,
        width,
        height,
        seed,
        neg_prompt,
        controlnet_prompt,
        controlnet_negative_prompt,
        controlnet_cond_scale,
        # pipe_kwargs,
    ]
    outputs = [gallery, generation_details]
    # Shift+Enter in the prompt box and the Generate button trigger the same call.
    prompt.submit(generate, inputs=inputs, outputs=outputs)
    generate_button.click(generate, inputs=inputs, outputs=outputs)

    # Same positional contract as above, for `run_training`.
    training_inputs = [
        model_name,
        controlnet_weights_path,
        train_data_dir,
        valid_data_dir,
        train_batch_size,
        train_whole_controlnet,
        gradient_accumulation_steps,
        num_train_epochs,
        train_learning_rate,
        output_dir,
        checkpointing_steps,
        image_logging_steps,
        save_whole_pipeline,
    ]
    training_button.click(
        run_training,
        inputs=training_inputs,
        outputs=[training_status],
    )

    # Client-side-only dark-mode toggle (no Python fn; adapted from
    # gradio.themes.builder): if any element carries the `dark` class, remove it
    # everywhere; otherwise add `dark` to <body>.
    toggle_dark_mode_args = dict(
        fn=None,
        inputs=None,
        outputs=None,
        _js="""() => {
        if (document.querySelectorAll('.dark').length) {
            document.querySelectorAll('.dark').forEach(el => el.classList.remove('dark'));
        } else {
            document.querySelector('body').classList.add('dark');
        }
    }""",
    )
    # Running the toggle once on page load flips the initial theme, so a page
    # that starts light opens in dark mode.
    demo.load(**toggle_dark_mode_args)
    dark_mode_btn.click(**toggle_dark_mode_args)
def find_favicon_path(directory, names=("favicon.ico", "favicon.png")):
    """Return the first existing favicon file in *directory*, or None.

    Args:
        directory: Folder to search (str or path-like).
        names: Candidate file names, tried in order. The actual favicon file
            name used by this Space is not visible here — adjust if needed.

    Returns:
        The path of the first candidate that is an existing regular file, as a
        string. Returns None if no candidate exists. Passing None to
        ``launch(favicon_path=...)`` makes Gradio use its default icon.
    """
    for name in names:
        candidate = Path(directory) / name
        if candidate.is_file():
            return str(candidate)
    return None


if __name__ == "__main__":
    # `favicon_path` was previously referenced without ever being defined,
    # which raised NameError at startup; resolve it from the assets folder and
    # fall back to Gradio's default icon when no favicon file is present.
    favicon_path = find_favicon_path(assets_directory)
    # NOTE(review): concurrency_count=cpu_count() lets that many queued
    # generate/train jobs run at once — confirm GPU memory can sustain it.
    demo.queue(concurrency_count=cpu_count()).launch(favicon_path=favicon_path)