|
|
|
"""Unofficial demo app for https://github.com/cloneofsimo/lora. |
|
|
|
The code in this repo is partly adapted from the following repository: |
|
https://huggingface.co/spaces/multimodalart/dreambooth-training/tree/a00184917aa273c6d8adab08d5deb9b39b997938 |
|
The license of the original code is MIT, which is specified in the README.md. |
|
""" |
|
|
|
from __future__ import annotations |
|
|
|
import os |
|
import pathlib |
|
|
|
import gradio as gr |
|
import torch |
|
|
|
from inference import InferencePipeline |
|
from trainer import Trainer |
|
from uploader import upload |
|
|
|
# Markdown rendered at the top of the app.
TITLE = '# LoRA + StableDiffusion Training UI'
DESCRIPTION = 'This is an unofficial demo for [https://github.com/cloneofsimo/lora](https://github.com/cloneofsimo/lora).'

# The Space this app was originally published as; SPACE_ID falls back to it
# when the SPACE_ID environment variable is not set.
ORIGINAL_SPACE_ID = 'hysts/LoRA-SD-training'
SPACE_ID = os.getenv('SPACE_ID', ORIGINAL_SPACE_ID)

# Warning shown when the IS_SHARED_UI environment variable is set.
SHARED_UI_WARNING = f'''# Attention - This Space doesn't work in this shared UI. You can duplicate and use it with a paid private T4 GPU.

<center><a class="duplicate-button" style="display:inline-block" target="_blank" href="https://huggingface.co/spaces/{SPACE_ID}?duplicate=true"><img src="https://img.shields.io/badge/-Duplicate%20Space-blue?labelColor=white&style=flat&logo=&logoWidth=14" alt="Duplicate Space"></a></center>
'''

# "Settings" becomes a link to the Space settings page only when SYSTEM is
# 'spaces' and the Space ID differs from the original one; otherwise it is
# plain text.
SETTINGS = (
    f'<a href="https://huggingface.co/spaces/{SPACE_ID}/settings">Settings</a>'
    if os.getenv('SYSTEM') == 'spaces' and SPACE_ID != ORIGINAL_SPACE_ID
    else 'Settings'
)

# Warning shown when torch reports that CUDA is unavailable.
CUDA_NOT_AVAILABLE_WARNING = f'''# Attention - Running on CPU.
<center>
You can assign a GPU in the {SETTINGS} tab if you are running this on HF Spaces.
"T4 small" is sufficient to run this demo.
</center>
'''
|
|
|
|
|
def show_warning(warning_text: str) -> gr.Blocks:
    """Build a Blocks fragment showing ``warning_text`` as Markdown in a box.

    Args:
        warning_text: Markdown source to render.

    Returns:
        The ``gr.Blocks`` object that wraps the boxed Markdown component.
    """
    with gr.Blocks() as warning_block, gr.Box():
        gr.Markdown(warning_text)
    return warning_block
|
|
|
|
|
def update_output_files() -> dict:
    """Return a Gradio update listing the ``*.pt`` files under ``results``.

    Only files matched by a non-recursive ``*.pt`` glob of the relative
    ``results`` directory are considered; they are sorted and converted to
    POSIX-style strings.

    Returns:
        ``gr.update(value=...)`` carrying the list of paths, or ``None`` as
        the value when no file matched.
    """
    weight_paths = [
        weight.as_posix()
        for weight in sorted(pathlib.Path('results').glob('*.pt'))
    ]
    if not weight_paths:
        return gr.update(value=None)
    return gr.update(value=weight_paths)
|
|
|
|
|
def create_training_demo(trainer: Trainer,
                         pipe: InferencePipeline) -> gr.Blocks:
    """Build the "Train" UI and wire its buttons to ``trainer`` and ``pipe``.

    The layout has hidden base-model/resolution dropdowns, a row with the
    training-data box and the training-parameter box, a start button, and a
    status area with a message field and a file list.

    Args:
        trainer: Object whose ``run`` is called with the form values when
            "Start Training" is pressed, and whose ``check_if_running`` is
            called by "Check Training Status".
        pipe: Object whose ``clear`` is also called when "Start Training" is
            pressed.

    Returns:
        The ``gr.Blocks`` object containing the training UI.
    """
    with gr.Blocks() as demo:
        # Single-choice dropdowns kept invisible; their values are still
        # passed to trainer.run.
        base_model = gr.Dropdown(
            choices=['stabilityai/stable-diffusion-2-1-base'],
            value='stabilityai/stable-diffusion-2-1-base',
            label='Base Model',
            visible=False,
        )
        resolution = gr.Dropdown(
            choices=['512'],
            value='512',
            label='Resolution',
            visible=False,
        )

        with gr.Row():
            with gr.Box():
                gr.Markdown('Training Data')
                concept_images = gr.Files(label='Images for your concept')
                concept_prompt = gr.Textbox(label='Concept Prompt',
                                            max_lines=1)
                gr.Markdown('''
- Upload images of the style you are planning on training on.
- For a concept prompt, use a unique, made up word to avoid collisions.
''')
            with gr.Box():
                gr.Markdown('Training Parameters')
                num_training_steps = gr.Number(
                    label='Number of Training Steps',
                    value=1000,
                    precision=0,
                )
                learning_rate = gr.Number(label='Learning Rate', value=0.0001)
                train_text_encoder = gr.Checkbox(label='Train Text Encoder',
                                                 value=True)
                learning_rate_text = gr.Number(
                    label='Learning Rate for Text Encoder',
                    value=0.00005,
                )
                gradient_accumulation = gr.Number(
                    label='Number of Gradient Accumulation',
                    value=1,
                    precision=0,
                )
                fp16 = gr.Checkbox(label='FP16', value=True)
                use_8bit_adam = gr.Checkbox(label='Use 8bit Adam', value=True)
                gr.Markdown('''
- It will take about 8 minutes to train for 1000 steps with a T4 GPU.
- You may want to try a small number of steps first, like 1, to see if everything works fine in your environment.
- Note that your trained models will be deleted when the second training is started. You can upload your trained model in the "Upload" tab.
''')

        run_button = gr.Button('Start Training')
        with gr.Box():
            with gr.Row():
                check_status_button = gr.Button('Check Training Status')
                with gr.Column():
                    with gr.Box():
                        gr.Markdown('Message')
                        training_status = gr.Markdown()
                    output_files = gr.Files(label='Trained Weight Files')

        # Components handed to trainer.run, in the positional order it
        # receives them.
        training_inputs = [
            base_model,
            resolution,
            concept_images,
            concept_prompt,
            num_training_steps,
            learning_rate,
            train_text_encoder,
            learning_rate_text,
            gradient_accumulation,
            fp16,
            use_8bit_adam,
        ]
        training_outputs = [training_status, output_files]

        # "Start Training" has two listeners: pipe.clear and trainer.run.
        run_button.click(fn=pipe.clear)
        run_button.click(fn=trainer.run,
                         inputs=training_inputs,
                         outputs=training_outputs,
                         queue=False)

        # "Check Training Status" refreshes the message and the file list.
        check_status_button.click(fn=trainer.check_if_running,
                                  inputs=None,
                                  outputs=training_status,
                                  queue=False)
        check_status_button.click(fn=update_output_files,
                                  inputs=None,
                                  outputs=output_files,
                                  queue=False)
    return demo
|
|
|
|
|
def find_weight_files() -> list[str]:
    """List ``*.pt`` files found recursively next to this source file.

    Files whose stem ends with ``.text_encoder`` are left out.

    Returns:
        POSIX-style paths relative to this file's directory, following the
        sorted order of the matched paths.
    """
    root = pathlib.Path(__file__).parent
    return [
        candidate.relative_to(root).as_posix()
        for candidate in sorted(root.rglob('*.pt'))
        if not candidate.stem.endswith('.text_encoder')
    ]
|
|
|
|
|
def reload_lora_weight_list() -> dict:
    """Return a Gradio update whose ``choices`` are the current weight files.

    The choices come from ``find_weight_files()``, evaluated at call time.
    """
    available_weights = find_weight_files()
    return gr.update(choices=available_weights)
|
|
|
|
|
def create_inference_demo(pipe: InferencePipeline) -> gr.Blocks:
    """Build the "Test" UI and wire image generation to ``pipe.run``.

    The left column holds the weight selector, prompt, sliders and the
    "Generate" button; the right column holds the result image.

    Args:
        pipe: Object whose ``run`` is called with the form values, both when
            the prompt is submitted and when "Generate" is pressed.

    Returns:
        The ``gr.Blocks`` object containing the inference UI.
    """
    with gr.Blocks() as demo:
        with gr.Row():
            with gr.Column():
                # Single-choice dropdown kept invisible; its value is still
                # passed to pipe.run.
                base_model = gr.Dropdown(
                    choices=['stabilityai/stable-diffusion-2-1-base'],
                    value='stabilityai/stable-diffusion-2-1-base',
                    label='Base Model',
                    visible=False,
                )
                reload_button = gr.Button('Reload Weight List')
                lora_weight_name = gr.Dropdown(
                    choices=find_weight_files(),
                    value='lora/lora_disney.pt',
                    label='LoRA Weight File',
                )
                prompt = gr.Textbox(
                    label='Prompt',
                    max_lines=1,
                    placeholder='Example: "style of sks, baby lion"',
                )
                alpha = gr.Slider(
                    label='Alpha',
                    minimum=0,
                    maximum=2,
                    step=0.05,
                    value=1,
                )
                alpha_for_text = gr.Slider(
                    label='Alpha for Text Encoder',
                    minimum=0,
                    maximum=2,
                    step=0.05,
                    value=1,
                )
                seed = gr.Slider(
                    label='Seed',
                    minimum=0,
                    maximum=100000,
                    step=1,
                    value=1,
                )
                with gr.Accordion('Other Parameters', open=False):
                    num_steps = gr.Slider(
                        label='Number of Steps',
                        minimum=0,
                        maximum=100,
                        step=1,
                        value=50,
                    )
                    guidance_scale = gr.Slider(
                        label='CFG Scale',
                        minimum=0,
                        maximum=50,
                        step=0.1,
                        value=7,
                    )

                run_button = gr.Button('Generate')

                gr.Markdown('''
- Models with names starting with "lora/" are the pretrained models provided in the [original repo](https://github.com/cloneofsimo/lora), and the ones with names starting with "results/" are your trained models.
- After training, you can press "Reload Weight List" button to load your trained model names.
- The pretrained models for "disney", "illust" and "pop" are trained with the concept prompt "style of sks".
- The pretrained model for "kiriko" is trained with the concept prompt "game character bnha". For this model, the text encoder is also trained.
''')
            with gr.Column():
                result = gr.Image(label='Result')

        reload_button.click(fn=reload_lora_weight_list,
                            inputs=None,
                            outputs=lora_weight_name)

        # Submitting the prompt and pressing "Generate" trigger the same
        # call, so the listener arguments are defined once.
        generation = dict(
            fn=pipe.run,
            inputs=[
                base_model,
                lora_weight_name,
                prompt,
                alpha,
                alpha_for_text,
                seed,
                num_steps,
                guidance_scale,
            ],
            outputs=result,
            queue=False,
        )
        prompt.submit(**generation)
        run_button.click(**generation)
    return demo
|
|
|
|
|
def create_upload_demo() -> gr.Blocks:
    """Build the "Upload" UI and wire its button to ``upload``.

    The form collects a model name and a Hugging Face token. Pressing
    "Upload" calls ``upload(model_name, hf_token)`` and shows its return
    value in the message area.

    The token field uses ``type='password'`` so the write token is masked
    while typing instead of being displayed in clear text; the value passed
    to ``upload`` is unchanged.

    Returns:
        The ``gr.Blocks`` object containing the upload UI.
    """
    with gr.Blocks() as demo:
        model_name = gr.Textbox(label='Model Name')
        # Mask the secret on screen (screen sharing, shoulder surfing).
        hf_token = gr.Textbox(
            label='Hugging Face Token (with write permission)',
            type='password')
        upload_button = gr.Button('Upload')
        with gr.Box():
            gr.Markdown('Message')
            result = gr.Markdown()
        # Plain (non-f) string: the braces below are literal placeholders.
        gr.Markdown('''
- You can upload your trained model to your private Model repo (i.e. https://huggingface.co/{your_username}/{model_name}).
- You can find your Hugging Face token [here](https://huggingface.co/settings/tokens).
''')

        upload_button.click(fn=upload,
                            inputs=[model_name, hf_token],
                            outputs=result)

    return demo
|
|
|
|
|
# Shared objects handed to the tab builders below.
pipe = InferencePipeline()
trainer = Trainer()

with gr.Blocks(css='style.css') as demo:
    # Optional banners: one when IS_SHARED_UI is set, one when CUDA is absent.
    if os.getenv('IS_SHARED_UI'):
        show_warning(SHARED_UI_WARNING)
    if not torch.cuda.is_available():
        show_warning(CUDA_NOT_AVAILABLE_WARNING)

    gr.Markdown(TITLE)
    gr.Markdown(DESCRIPTION)

    # Tab label -> callable that builds the tab's contents, in display order.
    tab_builders = (
        ('Train', lambda: create_training_demo(trainer, pipe)),
        ('Test', lambda: create_inference_demo(pipe)),
        ('Upload', lambda: create_upload_demo()),
    )
    with gr.Tabs():
        for tab_label, build_tab in tab_builders:
            with gr.TabItem(tab_label):
                build_tab()

# queue() is configured first, then the app is launched without a share link.
queued_demo = demo.queue(default_enabled=False)
queued_demo.launch(share=False)
|
|