"""Gradio UI for ControlNet-guided image generation and ControlNet training."""

import gradio as gr
from multiprocessing import cpu_count
from pathlib import Path

from src.ui_shared import (
    model_ids,
    scheduler_names,
    default_scheduler,
    controlnet_ids,
    assets_directory,
)
from src.ui_functions import generate, run_training

default_img_size = 512

# Markdown snippets rendered above and below the UI
with open(f"{assets_directory}/header.MD") as fp:
    header = fp.read()

with open(f"{assets_directory}/footer.MD") as fp:
    footer = fp.read()

theme = gr.themes.Soft(
    primary_hue="blue",
    neutral_hue="slate",
)

with gr.Blocks(theme=theme) as demo:
    header_component = gr.Markdown(header)

    # Prompt inputs and model/controlnet/scheduler selection
    with gr.Row().style(equal_height=True):
        with gr.Column(scale=70):
            prompt = gr.Textbox(
                label="Prompt", placeholder="Press Enter to generate", lines=2
            )
            neg_prompt = gr.Textbox(label="Negative Prompt", placeholder="", lines=2)
            with gr.Row():
                controlnet_prompt = gr.Textbox(
                    label="Controlnet Prompt",
                    placeholder="If empty, defaults to base `Prompt`",
                    lines=2,
                )
                controlnet_negative_prompt = gr.Textbox(
                    label="Controlnet Negative Prompt",
                    placeholder="If empty, defaults to base `Negative Prompt`",
                    lines=2,
                )

        with gr.Column(scale=30):
            model_name = gr.Dropdown(
                label="Model",
                choices=model_ids,
                value=model_ids[0],
                allow_custom_value=True,
            )
            controlnet_name = gr.Dropdown(
                label="Controlnet",
                choices=controlnet_ids,
                value=controlnet_ids[0],
                allow_custom_value=True,
            )
            scheduler_name = gr.Dropdown(
                label="Scheduler",
                choices=scheduler_names,
                value=default_scheduler,
                allow_custom_value=True,
            )

    with gr.Row():
        generate_button = gr.Button(value="Generate", variant="primary")
        dark_mode_btn = gr.Button("Dark Mode", variant="secondary")

    with gr.Row():
        with gr.Column():
            # Inference tab: guidance image and sampling controls
            with gr.Tab("Inference") as tab:
                guidance_image = gr.Image(
                    label="Guidance Image",
                    source="upload",
                    tool="editor",
                    type="pil",
                ).style(height=256)

                with gr.Row():
                    controlnet_cond_scale = gr.Slider(
                        label="Controlnet Weight",
                        value=0.5,
                        minimum=0.0,
                        maximum=1.0,
                        step=0.1,
                    )

                with gr.Row():
                    batch_size = gr.Slider(
                        label="Batch Size", value=1, minimum=1, maximum=8, step=1
                    )
                    seed = gr.Slider(-1, 2147483647, label="Seed", value=-1, step=1)

                with gr.Row():
                    guidance = gr.Slider(
                        label="Guidance scale", value=7.5, minimum=0, maximum=20
                    )
                    steps = gr.Slider(
                        label="Steps", value=20, minimum=1, maximum=100, step=1
                    )

                with gr.Row():
                    width = gr.Slider(
                        label="Width",
                        value=default_img_size,
                        minimum=64,
                        maximum=1024,
                        step=32,
                    )
                    height = gr.Slider(
                        label="Height",
                        value=default_img_size,
                        minimum=64,
                        maximum=1024,
                        step=32,
                    )

            # Training tab: hyperparameters, data paths, and output location
            with gr.Tab("Train Anime ControlNet") as tab:
                with gr.Row():
                    train_batch_size = gr.Slider(
                        label="Training Batch Size",
                        minimum=1,
                        maximum=8,
                        step=1,
                        value=1,
                    )
                    gradient_accumulation_steps = gr.Slider(
                        label="Gradient Accumulation steps",
                        minimum=1,
                        maximum=6,
                        step=1,
                        value=4,
                    )
                with gr.Row():
                    num_train_epochs = gr.Number(
                        label="Total training epochs", value=2
                    )
                    train_learning_rate = gr.Number(label="Learning Rate", value=5.0e-6)
                with gr.Row():
                    checkpointing_steps = gr.Number(
                        label="Steps between saving checkpoints", value=4000
                    )
                    image_logging_steps = gr.Number(
                        label="Steps between logging example images (pass 0 to disable)",
                        value=0,
                    )
                with gr.Row():
                    train_data_dir = gr.Textbox(
                        label="Path to training image folder",
                        value="lint/anybooru",
                    )
                    valid_data_dir = gr.Textbox(
                        label="Path to validation image folder",
                        value="",
                    )
                with gr.Row():
                    controlnet_weights_path = gr.Textbox(
                        label="Repo for initializing Controlnet Weights",
                        value="lint/anime_control/anime_merge",
                    )
                    output_dir = gr.Textbox(
                        label="Output directory for trained weights", value="./models"
                    )
                with gr.Row():
                    train_whole_controlnet = gr.Checkbox(
                        label="Train whole controlnet", value=True
                    )
                    save_whole_pipeline = gr.Checkbox(
                        label="Save whole pipeline", value=True
                    )

                training_button = gr.Button(
                    value="Train Style ControlNet", variant="primary"
                )
                training_status = gr.Text(label="Training Status")

        # Output gallery and generation metadata
        with gr.Column():
            gallery = gr.Gallery(
                label="Generated images", show_label=False, elem_id="gallery"
            ).style(height=default_img_size, grid=2)
            generation_details = gr.Markdown()
            # pipe_kwargs = gr.Textbox(label="Pipe kwargs", value="{\n\t\n}", visible=False)

            # if torch.cuda.is_available():
            #     giga = 2**30
            #     vram_gauge = gr.Slider(0, torch.cuda.memory_reserved(0)/giga, label='VRAM Allocated to Reserved (GB)', value=0, step=1)
            #     demo.load(lambda: torch.cuda.memory_allocated(0)/giga, inputs=[], outputs=vram_gauge, every=0.5, show_progress=False)

    footer_component = gr.Markdown(footer)

    # Wire prompt submission and the Generate button to the inference function
    inputs = [
        model_name,
        guidance_image,
        controlnet_name,
        scheduler_name,
        prompt,
        guidance,
        steps,
        batch_size,
        width,
        height,
        seed,
        neg_prompt,
        controlnet_prompt,
        controlnet_negative_prompt,
        controlnet_cond_scale,
        # pipe_kwargs,
    ]
    outputs = [gallery, generation_details]
    prompt.submit(generate, inputs=inputs, outputs=outputs)
    generate_button.click(generate, inputs=inputs, outputs=outputs)

    # Wire the training button to the training entry point
    training_inputs = [
        model_name,
        controlnet_weights_path,
        train_data_dir,
        valid_data_dir,
        train_batch_size,
        train_whole_controlnet,
        gradient_accumulation_steps,
        num_train_epochs,
        train_learning_rate,
        output_dir,
        checkpointing_steps,
        image_logging_steps,
        save_whole_pipeline,
    ]
    training_button.click(
        run_training,
        inputs=training_inputs,
        outputs=[training_status],
    )

    # from gradio.themes.builder
    toggle_dark_mode_args = dict(
        fn=None,
        inputs=None,
        outputs=None,
        _js="""() => {
            if (document.querySelectorAll('.dark').length) {
                document.querySelectorAll('.dark').forEach(el => el.classList.remove('dark'));
            } else {
                document.querySelector('body').classList.add('dark');
            }
        }""",
    )
    demo.load(**toggle_dark_mode_args)
    dark_mode_btn.click(**toggle_dark_mode_args)

# NOTE: the original launch call referenced `favicon_path` without defining it;
# this assumes the favicon lives alongside the other UI assets.
favicon_path = str(Path(assets_directory) / "favicon.ico")

if __name__ == "__main__":
    demo.queue(concurrency_count=cpu_count()).launch(favicon_path=favicon_path)