"""Model Converter tab for the web UI.

Builds a Gradio tab that collects conversion options (input model, precision,
pruning, output formats, VAE baking, per-part handling) and forwards them to
``convert.convert_warp`` when the Convert button is pressed.
"""
import gradio as gr

from modules import script_callbacks
from modules import sd_models, sd_vae
from modules.ui import create_refresh_button

from scripts import convert


def gr_show(visible=True):
    """Return a Gradio update payload that sets a component's visibility.

    Args:
        visible: Whether the target component should be shown.

    Returns:
        A dict in Gradio's ``"__type__": "update"`` form that carries only the
        ``visible`` property.
    """
    return {"visible": visible, "__type__": "update"}


def add_tab():
    """Build the "Model Converter" tab and wire up its event handlers.

    Returns:
        A one-element list holding a ``(blocks, tab_title, elem_id)`` tuple.
        This is the value given back to ``script_callbacks.on_ui_tabs``.
    """
    with gr.Blocks(analytics_enabled=False) as ui:
        with gr.Row(equal_height=True):
            # Left column: every input and option for the conversion.
            with gr.Column(variant='panel'):
                gr.HTML(value="<p>Converted checkpoints will be saved in your checkpoint directory.</p>")

                # Three alternative ways to pick the input model(s). Which one
                # takes effect is decided inside convert.convert_warp, which is
                # not visible here.
                with gr.Tabs():
                    with gr.TabItem(label='Single process'):
                        with gr.Row():
                            model_name = gr.Dropdown(
                                sd_models.checkpoint_tiles(),
                                elem_id="model_converter_model_name",
                                label="Model",
                            )
                            create_refresh_button(
                                model_name,
                                sd_models.list_models,
                                lambda: {"choices": sd_models.checkpoint_tiles()},
                                "refresh_checkpoint_Z",
                            )
                            custom_name = gr.Textbox(label="Custom Name (Optional)")

                    with gr.TabItem(label='Input file path'):
                        with gr.Row():
                            model_path = gr.Textbox(label="model path")

                    with gr.TabItem(label='Batch from directory'):
                        with gr.Row():
                            input_directory = gr.Textbox(label="Input Directory")

                with gr.Row():
                    precision = gr.Radio(
                        choices=["fp32", "fp16", "bf16"],
                        value="fp16",
                        label="Precision",
                    )
                    m_type = gr.Radio(
                        choices=["disabled", "no-ema", "ema-only"],
                        value="disabled",
                        label="Pruning Methods",
                    )

                with gr.Row():
                    checkpoint_formats = gr.CheckboxGroup(
                        choices=["ckpt", "safetensors"],
                        value=["safetensors"],
                        label="Checkpoint Format",
                    )
                    show_extra_options = gr.Checkbox(label="Show extra options", value=False)

                with gr.Row():
                    # "None" is a sentinel meaning "do not bake a VAE in".
                    bake_in_vae = gr.Dropdown(
                        choices=["None"] + list(sd_vae.vae_dict),
                        value="None",
                        label="Bake in VAE",
                    )
                    create_refresh_button(
                        bake_in_vae,
                        sd_vae.refresh_vae_list,
                        lambda: {"choices": ["None"] + list(sd_vae.vae_dict)},
                        "model_converter_refresh_bake_in_vae",
                    )

                with gr.Row():
                    force_position_id = gr.Checkbox(
                        label="Force CLIP position_id to int64 before convert",
                        value=True,
                    )
                    fix_clip = gr.Checkbox(label="Fix clip", value=False)
                    delete_known_junk_data = gr.Checkbox(label="Delete known junk data", value=False)

                # Per-part handling (copy / convert / delete). This row stays
                # hidden until "Show extra options" is ticked.
                with gr.Row(visible=False) as extra_options:
                    specific_part_conv = ["copy", "convert", "delete"]
                    unet_conv = gr.Dropdown(specific_part_conv, value="convert", label="unet")
                    text_encoder_conv = gr.Dropdown(specific_part_conv, value="convert", label="text encoder")
                    vae_conv = gr.Dropdown(specific_part_conv, value="convert", label="vae")
                    others_conv = gr.Dropdown(specific_part_conv, value="convert", label="others")

                # gr.Button takes its caption from `value`; it has no `label`
                # parameter, so the original `label="Convert"` never set the
                # button text.
                model_converter_convert = gr.Button(
                    elem_id="model_converter_convert",
                    value="Convert",
                    variant='primary',
                )

            # Right column: text output of the conversion.
            with gr.Column(variant='panel'):
                submit_result = gr.Textbox(elem_id="model_converter_result", show_label=False)

        # gr_show already has the one-argument shape the listener needs, so no
        # wrapping lambda is required.
        show_extra_options.change(
            fn=gr_show,
            inputs=[show_extra_options],
            outputs=[extra_options],
        )

        # Inputs are passed positionally. This order must match
        # convert.convert_warp's parameter order (assumed; that signature is
        # not visible here).
        model_converter_convert.click(
            fn=convert.convert_warp,
            inputs=[
                model_name,
                model_path,
                input_directory,
                checkpoint_formats,
                precision,
                m_type,
                custom_name,
                bake_in_vae,
                unet_conv,
                text_encoder_conv,
                vae_conv,
                others_conv,
                fix_clip,
                force_position_id,
                delete_known_junk_data,
            ],
            outputs=[submit_result],
        )

    return [(ui, "Model Converter", "model_converter")]


script_callbacks.on_ui_tabs(add_tab)