Spaces:
Runtime error
Runtime error
Commit
·
001c876
0
Parent(s):
Duplicate from lint/sdpipe_webui
Browse filesCo-authored-by: lint <[email protected]>
- .gitattributes +34 -0
- .gitignore +4 -0
- README.md +22 -0
- app.py +225 -0
- html/footer.html +15 -0
- html/header.html +25 -0
- html/style.css +38 -0
- model_ids.txt +6 -0
- requirements.txt +12 -0
- test.ipynb +73 -0
- utils/__init__.py +0 -0
- utils/functions.py +273 -0
- utils/inpaint_pipeline.py +288 -0
- utils/shared.py +16 -0
- utils/textual_inversion.py +916 -0
.gitattributes
ADDED
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
*.7z filter=lfs diff=lfs merge=lfs -text
|
2 |
+
*.arrow filter=lfs diff=lfs merge=lfs -text
|
3 |
+
*.bin filter=lfs diff=lfs merge=lfs -text
|
4 |
+
*.bz2 filter=lfs diff=lfs merge=lfs -text
|
5 |
+
*.ckpt filter=lfs diff=lfs merge=lfs -text
|
6 |
+
*.ftz filter=lfs diff=lfs merge=lfs -text
|
7 |
+
*.gz filter=lfs diff=lfs merge=lfs -text
|
8 |
+
*.h5 filter=lfs diff=lfs merge=lfs -text
|
9 |
+
*.joblib filter=lfs diff=lfs merge=lfs -text
|
10 |
+
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
11 |
+
*.mlmodel filter=lfs diff=lfs merge=lfs -text
|
12 |
+
*.model filter=lfs diff=lfs merge=lfs -text
|
13 |
+
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
14 |
+
*.npy filter=lfs diff=lfs merge=lfs -text
|
15 |
+
*.npz filter=lfs diff=lfs merge=lfs -text
|
16 |
+
*.onnx filter=lfs diff=lfs merge=lfs -text
|
17 |
+
*.ot filter=lfs diff=lfs merge=lfs -text
|
18 |
+
*.parquet filter=lfs diff=lfs merge=lfs -text
|
19 |
+
*.pb filter=lfs diff=lfs merge=lfs -text
|
20 |
+
*.pickle filter=lfs diff=lfs merge=lfs -text
|
21 |
+
*.pkl filter=lfs diff=lfs merge=lfs -text
|
22 |
+
*.pt filter=lfs diff=lfs merge=lfs -text
|
23 |
+
*.pth filter=lfs diff=lfs merge=lfs -text
|
24 |
+
*.rar filter=lfs diff=lfs merge=lfs -text
|
25 |
+
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
26 |
+
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
27 |
+
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
28 |
+
*.tflite filter=lfs diff=lfs merge=lfs -text
|
29 |
+
*.tgz filter=lfs diff=lfs merge=lfs -text
|
30 |
+
*.wasm filter=lfs diff=lfs merge=lfs -text
|
31 |
+
*.xz filter=lfs diff=lfs merge=lfs -text
|
32 |
+
*.zip filter=lfs diff=lfs merge=lfs -text
|
33 |
+
*.zst filter=lfs diff=lfs merge=lfs -text
|
34 |
+
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
.gitignore
ADDED
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
1 |
+
__pycache__/
|
2 |
+
WIP/
|
3 |
+
concept_images/
|
4 |
+
output_model/
|
README.md
ADDED
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
---
|
2 |
+
title: Sdpipe Webui
|
3 |
+
emoji: 🍌
|
4 |
+
colorFrom: gray
|
5 |
+
colorTo: green
|
6 |
+
sdk: gradio
|
7 |
+
sdk_version: 3.16.2
|
8 |
+
app_file: app.py
|
9 |
+
pinned: false
|
10 |
+
license: openrail
|
11 |
+
duplicated_from: lint/sdpipe_webui
|
12 |
+
---
|
13 |
+
|
14 |
+
# **Stable Diffusion Pipeline Web UI**
|
15 |
+
|
16 |
+
Stable Diffusion WebUI with first class support for HuggingFace Diffusers Pipelines and Diffusion Schedulers, made in the style of Automatic1111's WebUI and Evel_Space.
|
17 |
+
|
18 |
+
Supports Huggingface `Text-to-Image`, `Image to Image`, and `Inpainting` pipelines, with fast switching between pipeline modes by reusing loaded model weights already in memory.
|
19 |
+
|
20 |
+
Install requirements with `pip install -r requirements.txt`
|
21 |
+
|
22 |
+
Run with `python app.py`
|
app.py
ADDED
@@ -0,0 +1,225 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
from multiprocessing import cpu_count
|
3 |
+
from utils.functions import generate, train_textual_inversion
|
4 |
+
from utils.shared import model_ids, scheduler_names, default_scheduler
|
5 |
+
|
6 |
+
default_img_size = 512
|
7 |
+
|
8 |
+
with open("html/header.html") as fp:
|
9 |
+
header = fp.read()
|
10 |
+
|
11 |
+
with open("html/footer.html") as fp:
|
12 |
+
footer = fp.read()
|
13 |
+
|
14 |
+
with gr.Blocks(css="html/style.css") as demo:
|
15 |
+
|
16 |
+
pipe_state = gr.State(lambda: 1)
|
17 |
+
|
18 |
+
gr.HTML(header)
|
19 |
+
|
20 |
+
with gr.Row():
|
21 |
+
|
22 |
+
with gr.Column(scale=70):
|
23 |
+
|
24 |
+
# with gr.Row():
|
25 |
+
prompt = gr.Textbox(
|
26 |
+
label="Prompt", placeholder="<Shift+Enter> to generate", lines=2
|
27 |
+
)
|
28 |
+
neg_prompt = gr.Textbox(label="Negative Prompt", placeholder="", lines=2)
|
29 |
+
|
30 |
+
with gr.Column(scale=30):
|
31 |
+
model_name = gr.Dropdown(
|
32 |
+
label="Model", choices=model_ids, value=model_ids[0]
|
33 |
+
)
|
34 |
+
scheduler_name = gr.Dropdown(
|
35 |
+
label="Scheduler", choices=scheduler_names, value=default_scheduler
|
36 |
+
)
|
37 |
+
generate_button = gr.Button(value="Generate", elem_id="generate-button")
|
38 |
+
|
39 |
+
with gr.Row():
|
40 |
+
|
41 |
+
with gr.Column():
|
42 |
+
|
43 |
+
with gr.Tab("Text to Image") as tab:
|
44 |
+
tab.select(lambda: 1, [], pipe_state)
|
45 |
+
|
46 |
+
with gr.Tab("Image to image") as tab:
|
47 |
+
tab.select(lambda: 2, [], pipe_state)
|
48 |
+
|
49 |
+
image = gr.Image(
|
50 |
+
label="Image to Image",
|
51 |
+
source="upload",
|
52 |
+
tool="editor",
|
53 |
+
type="pil",
|
54 |
+
elem_id="image_upload",
|
55 |
+
).style(height=default_img_size)
|
56 |
+
strength = gr.Slider(
|
57 |
+
label="Denoising strength",
|
58 |
+
minimum=0,
|
59 |
+
maximum=1,
|
60 |
+
step=0.02,
|
61 |
+
value=0.8,
|
62 |
+
)
|
63 |
+
|
64 |
+
with gr.Tab("Inpainting") as tab:
|
65 |
+
tab.select(lambda: 3, [], pipe_state)
|
66 |
+
|
67 |
+
inpaint_image = gr.Image(
|
68 |
+
label="Inpainting",
|
69 |
+
source="upload",
|
70 |
+
tool="sketch",
|
71 |
+
type="pil",
|
72 |
+
elem_id="image_upload",
|
73 |
+
).style(height=default_img_size)
|
74 |
+
inpaint_strength = gr.Slider(
|
75 |
+
label="Denoising strength",
|
76 |
+
minimum=0,
|
77 |
+
maximum=1,
|
78 |
+
step=0.02,
|
79 |
+
value=0.8,
|
80 |
+
)
|
81 |
+
inpaint_options = [
|
82 |
+
"preserve non-masked portions of image",
|
83 |
+
"output entire inpainted image",
|
84 |
+
]
|
85 |
+
inpaint_radio = gr.Radio(
|
86 |
+
inpaint_options,
|
87 |
+
value=inpaint_options[0],
|
88 |
+
show_label=False,
|
89 |
+
interactive=True,
|
90 |
+
)
|
91 |
+
|
92 |
+
with gr.Tab("Textual Inversion") as tab:
|
93 |
+
tab.select(lambda: 4, [], pipe_state)
|
94 |
+
|
95 |
+
type_of_thing = gr.Dropdown(
|
96 |
+
label="What would you like to train?",
|
97 |
+
choices=["object", "person", "style"],
|
98 |
+
value="object",
|
99 |
+
interactive=True,
|
100 |
+
)
|
101 |
+
|
102 |
+
text_train_bsz = gr.Slider(
|
103 |
+
label="Training Batch Size",
|
104 |
+
minimum=1,
|
105 |
+
maximum=8,
|
106 |
+
step=1,
|
107 |
+
value=1,
|
108 |
+
)
|
109 |
+
|
110 |
+
files = gr.File(
|
111 |
+
label=f"""Upload the images for your concept""",
|
112 |
+
file_count="multiple",
|
113 |
+
interactive=True,
|
114 |
+
visible=True,
|
115 |
+
)
|
116 |
+
|
117 |
+
text_train_steps = gr.Number(label="How many steps", value=1000)
|
118 |
+
|
119 |
+
text_learning_rate = gr.Number(label="Learning Rate", value=5.0e-4)
|
120 |
+
|
121 |
+
concept_word = gr.Textbox(
|
122 |
+
label=f"""concept word - use a unique, made up word to avoid collisions"""
|
123 |
+
)
|
124 |
+
init_word = gr.Textbox(
|
125 |
+
label=f"""initial word - to init the concept embedding"""
|
126 |
+
)
|
127 |
+
|
128 |
+
textual_inversion_button = gr.Button(value="Train Textual Inversion")
|
129 |
+
|
130 |
+
training_status = gr.Text(label="Training Status")
|
131 |
+
|
132 |
+
with gr.Row():
|
133 |
+
batch_size = gr.Slider(
|
134 |
+
label="Batch Size", value=1, minimum=1, maximum=8, step=1
|
135 |
+
)
|
136 |
+
seed = gr.Slider(-1, 2147483647, label="Seed", value=-1, step=1)
|
137 |
+
|
138 |
+
with gr.Row():
|
139 |
+
guidance = gr.Slider(
|
140 |
+
label="Guidance scale", value=7.5, minimum=0, maximum=20
|
141 |
+
)
|
142 |
+
steps = gr.Slider(
|
143 |
+
label="Steps", value=20, minimum=1, maximum=100, step=1
|
144 |
+
)
|
145 |
+
|
146 |
+
with gr.Row():
|
147 |
+
width = gr.Slider(
|
148 |
+
label="Width",
|
149 |
+
value=default_img_size,
|
150 |
+
minimum=64,
|
151 |
+
maximum=1024,
|
152 |
+
step=32,
|
153 |
+
)
|
154 |
+
height = gr.Slider(
|
155 |
+
label="Height",
|
156 |
+
value=default_img_size,
|
157 |
+
minimum=64,
|
158 |
+
maximum=1024,
|
159 |
+
step=32,
|
160 |
+
)
|
161 |
+
|
162 |
+
with gr.Column():
|
163 |
+
gallery = gr.Gallery(
|
164 |
+
label="Generated images", show_label=False, elem_id="gallery"
|
165 |
+
).style(height=default_img_size, grid=2)
|
166 |
+
|
167 |
+
generation_details = gr.Markdown()
|
168 |
+
|
169 |
+
pipe_kwargs = gr.Textbox(label="Pipe kwargs", value="{\n\t\n}")
|
170 |
+
|
171 |
+
# if torch.cuda.is_available():
|
172 |
+
# giga = 2**30
|
173 |
+
# vram_guage = gr.Slider(0, torch.cuda.memory_reserved(0)/giga, label='VRAM Allocated to Reserved (GB)', value=0, step=1)
|
174 |
+
# demo.load(lambda : torch.cuda.memory_allocated(0)/giga, inputs=[], outputs=vram_guage, every=0.5, show_progress=False)
|
175 |
+
|
176 |
+
gr.HTML(footer)
|
177 |
+
|
178 |
+
inputs = [
|
179 |
+
model_name,
|
180 |
+
scheduler_name,
|
181 |
+
prompt,
|
182 |
+
guidance,
|
183 |
+
steps,
|
184 |
+
batch_size,
|
185 |
+
width,
|
186 |
+
height,
|
187 |
+
seed,
|
188 |
+
image,
|
189 |
+
strength,
|
190 |
+
inpaint_image,
|
191 |
+
inpaint_strength,
|
192 |
+
inpaint_radio,
|
193 |
+
neg_prompt,
|
194 |
+
pipe_state,
|
195 |
+
pipe_kwargs,
|
196 |
+
]
|
197 |
+
outputs = [gallery, generation_details]
|
198 |
+
|
199 |
+
prompt.submit(generate, inputs=inputs, outputs=outputs)
|
200 |
+
generate_button.click(generate, inputs=inputs, outputs=outputs)
|
201 |
+
|
202 |
+
textual_inversion_inputs = [
|
203 |
+
model_name,
|
204 |
+
scheduler_name,
|
205 |
+
type_of_thing,
|
206 |
+
files,
|
207 |
+
concept_word,
|
208 |
+
init_word,
|
209 |
+
text_train_steps,
|
210 |
+
text_train_bsz,
|
211 |
+
text_learning_rate,
|
212 |
+
]
|
213 |
+
|
214 |
+
textual_inversion_button.click(
|
215 |
+
train_textual_inversion,
|
216 |
+
inputs=textual_inversion_inputs,
|
217 |
+
outputs=[training_status],
|
218 |
+
)
|
219 |
+
|
220 |
+
|
221 |
+
# demo = gr.TabbedInterface([demo, dreambooth_tab], ["Main", "Dreambooth"])
|
222 |
+
|
223 |
+
demo.queue(concurrency_count=cpu_count())
|
224 |
+
|
225 |
+
demo.launch()
|
html/footer.html
ADDED
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
<!-- based on https://huggingface.co/spaces/stabilityai/stable-diffusion/blob/main/app.py -->
|
3 |
+
|
4 |
+
|
5 |
+
<div class="footer">
|
6 |
+
<p>Model Architecture by <a href="https://huggingface.co/stabilityai" style="text-decoration: underline;" target="_blank">StabilityAI</a> - Pipelines by 🤗 Hugging Face
|
7 |
+
</p>
|
8 |
+
</div>
|
9 |
+
<div class="acknowledgments">
|
10 |
+
<p><h4>LICENSE</h4>
|
11 |
+
The model is licensed with a <a href="https://huggingface.co/stabilityai/stable-diffusion-2/blob/main/LICENSE-MODEL" style="text-decoration: underline;" target="_blank">CreativeML OpenRAIL++</a> license. The authors claim no rights on the outputs you generate, you are free to use them and are accountable for their use which must not go against the provisions set in this license. The license forbids you from sharing any content that violates any laws, produce any harm to a person, disseminate any personal information that would be meant for harm, spread misinformation and target vulnerable groups. For the full list of restrictions please <a href="https://huggingface.co/spaces/CompVis/stable-diffusion-license" target="_blank" style="text-decoration: underline;" target="_blank">read the license</a></p>
|
12 |
+
<p><h4>Biases and content acknowledgment</h4>
|
13 |
+
Despite how impressive being able to turn text into image is, beware to the fact that this model may output content that reinforces or exacerbates societal biases, as well as realistic faces, pornography and violence. The model was trained on the <a href="https://laion.ai/blog/laion-5b/" style="text-decoration: underline;" target="_blank">LAION-5B dataset</a>, which scraped non-curated image-text-pairs from the internet (the exception being the removal of illegal content) and is meant for research purposes. You can read more in the <a href="https://huggingface.co/CompVis/stable-diffusion-v1-4" style="text-decoration: underline;" target="_blank">model card</a></p>
|
14 |
+
</div>
|
15 |
+
|
html/header.html
ADDED
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
<!-- based on https://huggingface.co/spaces/stabilityai/stable-diffusion/blob/main/app.py -->
|
3 |
+
|
4 |
+
<div style="text-align: center; margin: 0 auto;">
|
5 |
+
<div
|
6 |
+
style="
|
7 |
+
display: inline-flex;
|
8 |
+
align-items: center;
|
9 |
+
gap: 0.8rem;
|
10 |
+
font-size: 1.75rem;
|
11 |
+
"
|
12 |
+
>
|
13 |
+
<svg version="1.1" id="Layer_1" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" x="0px" y="0px" viewBox="0 0 32 32" style="enable-background:new 0 0 512 512;" xml:space="preserve" width="32" height="32"><path style="fill:#FCD577;" d="M29.545 29.791V2.21c-1.22 0 -2.21 -0.99 -2.21 -2.21H4.665c0 1.22 -0.99 2.21 -2.21 2.21v27.581c1.22 0 2.21 0.99 2.21 2.21H27.335C27.335 30.779 28.325 29.791 29.545 29.791z"/><path x="98.205" y="58.928" style="fill:#99B6C6;" width="315.577" height="394.144" d="M6.138 3.683H25.861V28.317H6.138V3.683z"/><path x="98.205" y="58.928" style="fill:#7BD4EF;" width="315.577" height="131.317" d="M6.138 3.683H25.861V11.89H6.138V3.683z"/><g><path style="fill:#7190A5;" d="M14.498 10.274c0 1.446 0.983 1.155 1.953 1.502l0.504 5.317c0 0 -5.599 0.989 -6.026 2.007l0.27 -2.526c0.924 -1.462 1.286 -4.864 1.419 -6.809l0.086 0.006C12.697 9.876 14.498 10.166 14.498 10.274z"/><path style="fill:#7190A5;" d="M21.96 17.647c0 0 -0.707 1.458 -1.716 1.903c0 0 -1.502 -0.827 -1.502 -0.827c-2.276 -1.557 -2.366 -8.3 -2.366 -8.3c0 -1.718 -0.185 -1.615 -1.429 -1.615c-1.167 0 -2.127 -0.606 -2.242 0.963l-0.086 -0.006c0.059 -0.859 0.074 -1.433 0.074 -1.433c0 -1.718 1.449 -3.11 3.237 -3.11s3.237 1.392 3.237 3.11C19.168 8.332 19.334 15.617 21.96 17.647z"/></g><path style="fill:#6C8793;" d="M12.248 24.739c1.538 0.711 3.256 1.591 3.922 2.258c-1.374 0.354 -2.704 0.798 -3.513 1.32h-2.156c-1.096 -0.606 -2.011 -1.472 -2.501 -2.702c-1.953 -4.907 2.905 -8.664 2.905 -8.664c0.001 -0.001 0.002 -0.002 0.003 -0.003c0.213 -0.214 0.523 -0.301 0.811 -0.21l0.02 0.006c-0.142 0.337 -0.03 0.71 0.517 1.108c1.264 0.919 3.091 1.131 4.416 1.143c-1.755 1.338 -3.42 3.333 -4.367 5.618L12.248 24.739z"/><path style="fill:#577484;" d="M16.17 26.997c-0.666 -0.666 -2.385 -1.548 -3.922 -2.258l0.059 -0.126c0.947 -2.284 2.612 -4.28 4.367 -5.618c0.001 0 0.001 0 0.001 0c0.688 -0.525 1.391 -0.948 2.068 -1.247c0.001 0 0.001 0 0.001 0c1.009 -0.446 1.964 -0.617 2.742 -0.44c0.61 0.138 1.109 0.492 1.439 1.095c1.752 3.205 0.601 9.913 0.601 9.913H12.657C13.466 27.796 14.796 27.352 16.17 26.997z"/><path style="fill:#F7DEB0;" d="M14.38 13.1c-0.971 -0.347 -1.687 -1.564 -1.687 -3.01c0 -0.107 0.004 -0.213 0.011 -0.318c0.116 -1.569 1.075 -2.792 2.242 -2.792c1.244 0 2.253 1.392 2.253 3.11c0 0 -0.735 6.103 1.542 7.66c-0.677 0.299 -1.38 0.722 -2.068 1.247c0 0 0 0 -0.001 0c-1.326 -0.012 -3.152 -0.223 -4.416 -1.143c-0.547 -0.398 -0.659 -0.771 -0.517 -1.108c0.426 -1.018 3.171 -1.697 3.171 -1.697L14.38 13.1z"/><path style="fill:#E5CA9E;" d="M14.38 13.1c0 0 1.019 0.216 1.544 -0.309c0 0 -0.401 1.04 -1.346 1.04"/><g><path style="fill:#EAC36E;" points="437.361,0 413.79,58.926 472.717,35.356 " d="M27.335 0L25.862 3.683L29.545 2.21"/><path style="fill:#EAC36E;" points="437.361,512 413.79,453.074 472.717,476.644 " d="M27.335 32L25.862 28.317L29.545 29.791"/><path style="fill:#EAC36E;" points="74.639,512 98.21,453.074 39.283,476.644 " d="M4.665 32L6.138 28.317L2.455 29.791"/><path style="fill:#EAC36E;" points="39.283,35.356 98.21,58.926 74.639,0 " d="M2.455 2.21L6.138 3.683L4.665 0"/><path style="fill:#EAC36E;" d="M26.425 28.881H5.574V3.119h20.851v25.761H26.425zM6.702 27.754h18.597V4.246H6.702V27.754z"/></g><g><path style="fill:#486572;" d="M12.758 21.613c-0.659 0.767 -1.245 1.613 -1.722 2.531l0.486 0.202C11.82 23.401 12.241 22.483 12.758 21.613z"/><path style="fill:#486572;" d="M21.541 25.576l-0.37 0.068c-0.553 0.101 -1.097 0.212 -1.641 0.331l-0.071 -0.201l-0.059 -0.167c-0.019 -0.056 -0.035 -0.112 -0.052 -0.169l-0.104 -0.338l-0.088 -0.342c-0.112 -0.457 -0.197 -0.922 -0.235 -1.393c-0.035 -0.47 -0.032 -0.947 0.042 -1.417c0.072 -0.47 0.205 -0.935 0.422 -1.369c-0.272 0.402 -0.469 0.856 -0.606 1.329c-0.138 0.473 -0.207 0.967 -0.234 1.462c-0.024 0.496 0.002 0.993 0.057 1.487l0.046 0.37l0.063 0.367c0.011 0.061 0.02 0.123 0.033 0.184l0.039 0.182l0.037 0.174c-0.677 0.157 -1.351 0.327 -2.019 0.514c-0.131 0.037 -0.262 0.075 -0.392 0.114l0.004 -0.004c-0.117 -0.095 -0.232 -0.197 -0.35 -0.275c-0.059 -0.041 -0.117 -0.084 -0.177 -0.122l-0.179 -0.112c-0.239 -0.147 -0.482 -0.279 -0.727 -0.406c-0.489 -0.252 -0.985 -0.479 -1.484 -0.697c-0.998 -0.433 -2.01 -0.825 -3.026 -1.196c0.973 0.475 1.937 0.969 2.876 1.499c0.469 0.266 0.932 0.539 1.379 0.832c0.223 0.146 0.442 0.297 0.648 0.456l0.154 0.119c0.05 0.041 0.097 0.083 0.145 0.124c0.002 0.002 0.004 0.003 0.005 0.005c-0.339 0.109 -0.675 0.224 -1.009 0.349c-0.349 0.132 -0.696 0.273 -1.034 0.431c-0.338 0.159 -0.668 0.337 -0.973 0.549c0.322 -0.186 0.662 -0.334 1.01 -0.463c0.347 -0.129 0.701 -0.239 1.056 -0.34c0.394 -0.111 0.79 -0.208 1.19 -0.297c0.006 0.006 0.013 0.013 0.019 0.019l0.03 -0.03c0.306 -0.068 0.614 -0.132 0.922 -0.192c0.727 -0.14 1.457 -0.258 2.189 -0.362c0.731 -0.103 1.469 -0.195 2.197 -0.265l0.374 -0.036L21.541 25.576z"/></g></svg>
|
14 |
+
|
15 |
+
<h1 style="font-weight: 1000; margin-bottom: 8px;margin-top:8px">
|
16 |
+
Stable Diffusion Pipeline UI
|
17 |
+
</h1>
|
18 |
+
</div>
|
19 |
+
<p style="margin-bottom: 4px; font-size: 100%; line-height: 24px;">
|
20 |
+
Stable Diffusion WebUI with first class support for HuggingFace Diffusers Pipelines and Diffusion Schedulers, made in the style of <a style="text-decoration: underline;" href="https://github.com/AUTOMATIC1111/stable-diffusion-webui">Automatic1111's WebUI</a> and <a style="text-decoration: underline;" href="https://huggingface.co/spaces/Evel/Evel_Space">Evel_Space</a>.
|
21 |
+
</p>
|
22 |
+
<p> Supports Text-to-Image, Image to Image, and Inpainting modes, with fast switching between pipeline modes by reusing loaded model weights already in memory.
|
23 |
+
</p>
|
24 |
+
</div>
|
25 |
+
|
html/style.css
ADDED
@@ -0,0 +1,38 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
#image_upload{min-height: 512px}
|
3 |
+
#image_upload [data-testid="image"], #image_upload [data-testid="image"] > div{min-height: 512px}
|
4 |
+
#image_upload .touch-none{display: flex}
|
5 |
+
|
6 |
+
|
7 |
+
#generate-button {
|
8 |
+
color:white;
|
9 |
+
border-color: orangered;
|
10 |
+
background: orange;
|
11 |
+
height: 45px;
|
12 |
+
}
|
13 |
+
|
14 |
+
.footer {
|
15 |
+
margin-bottom: 45px;
|
16 |
+
margin-top: 35px;
|
17 |
+
text-align: center;
|
18 |
+
border-bottom: 1px solid #e5e5e5;
|
19 |
+
}
|
20 |
+
.footer>p {
|
21 |
+
font-size: .8rem;
|
22 |
+
display: inline-block;
|
23 |
+
padding: 0 10px;
|
24 |
+
transform: translateY(10px);
|
25 |
+
background: white;
|
26 |
+
}
|
27 |
+
.dark .footer {
|
28 |
+
border-color: #303030;
|
29 |
+
}
|
30 |
+
.dark .footer>p {
|
31 |
+
background: #0b0f19;
|
32 |
+
}
|
33 |
+
.acknowledgments h4{
|
34 |
+
margin: 1.25em 0 .25em 0;
|
35 |
+
font-weight: bold;
|
36 |
+
font-size: 115%;
|
37 |
+
}
|
38 |
+
|
model_ids.txt
ADDED
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
andite/anything-v4.0
|
2 |
+
hakurei/waifu-diffusion
|
3 |
+
prompthero/openjourney-v2
|
4 |
+
runwayml/stable-diffusion-v1-5
|
5 |
+
johnslegers/epic-diffusion
|
6 |
+
stabilityai/stable-diffusion-2-1
|
requirements.txt
ADDED
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
accelerate==0.15.0
|
2 |
+
datasets==2.3.2
|
3 |
+
diffusers==0.11.1
|
4 |
+
gradio==3.16.2
|
5 |
+
huggingface_hub==0.11.1
|
6 |
+
numpy==1.23.3
|
7 |
+
packaging==23.0
|
8 |
+
Pillow==9.4.0
|
9 |
+
torch
|
10 |
+
torchvision
|
11 |
+
tqdm==4.64.0
|
12 |
+
transformers==4.25.1
|
test.ipynb
ADDED
@@ -0,0 +1,73 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"cells": [
|
3 |
+
{
|
4 |
+
"cell_type": "code",
|
5 |
+
"execution_count": 3,
|
6 |
+
"metadata": {},
|
7 |
+
"outputs": [],
|
8 |
+
"source": [
|
9 |
+
"with open('model_ids.txt', 'r') as fp:\n",
|
10 |
+
" model_ids = fp.read().splitlines() "
|
11 |
+
]
|
12 |
+
},
|
13 |
+
{
|
14 |
+
"cell_type": "code",
|
15 |
+
"execution_count": 4,
|
16 |
+
"metadata": {},
|
17 |
+
"outputs": [
|
18 |
+
{
|
19 |
+
"data": {
|
20 |
+
"text/plain": [
|
21 |
+
"['andite/anything-v4.0',\n",
|
22 |
+
" 'hakurei/waifu-diffusion',\n",
|
23 |
+
" 'prompthero/openjourney-v2',\n",
|
24 |
+
" 'runwayml/stable-diffusion-v1-5',\n",
|
25 |
+
" 'johnslegers/epic-diffusion',\n",
|
26 |
+
" 'stabilityai/stable-diffusion-2-1']"
|
27 |
+
]
|
28 |
+
},
|
29 |
+
"execution_count": 4,
|
30 |
+
"metadata": {},
|
31 |
+
"output_type": "execute_result"
|
32 |
+
}
|
33 |
+
],
|
34 |
+
"source": [
|
35 |
+
"model_ids"
|
36 |
+
]
|
37 |
+
},
|
38 |
+
{
|
39 |
+
"cell_type": "code",
|
40 |
+
"execution_count": null,
|
41 |
+
"metadata": {},
|
42 |
+
"outputs": [],
|
43 |
+
"source": []
|
44 |
+
}
|
45 |
+
],
|
46 |
+
"metadata": {
|
47 |
+
"kernelspec": {
|
48 |
+
"display_name": "ml",
|
49 |
+
"language": "python",
|
50 |
+
"name": "python3"
|
51 |
+
},
|
52 |
+
"language_info": {
|
53 |
+
"codemirror_mode": {
|
54 |
+
"name": "ipython",
|
55 |
+
"version": 3
|
56 |
+
},
|
57 |
+
"file_extension": ".py",
|
58 |
+
"mimetype": "text/x-python",
|
59 |
+
"name": "python",
|
60 |
+
"nbconvert_exporter": "python",
|
61 |
+
"pygments_lexer": "ipython3",
|
62 |
+
"version": "3.10.8"
|
63 |
+
},
|
64 |
+
"orig_nbformat": 4,
|
65 |
+
"vscode": {
|
66 |
+
"interpreter": {
|
67 |
+
"hash": "cbbcdde725e9a65f1cb734ac4223fed46e03daf1eb62d8ccb3c48face3871521"
|
68 |
+
}
|
69 |
+
}
|
70 |
+
},
|
71 |
+
"nbformat": 4,
|
72 |
+
"nbformat_minor": 2
|
73 |
+
}
|
utils/__init__.py
ADDED
File without changes
|
utils/functions.py
ADDED
@@ -0,0 +1,273 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
import torch
|
3 |
+
import random
|
4 |
+
from PIL import Image
|
5 |
+
import os
|
6 |
+
import argparse
|
7 |
+
import shutil
|
8 |
+
import gc
|
9 |
+
import importlib
|
10 |
+
import json
|
11 |
+
|
12 |
+
from diffusers import (
|
13 |
+
StableDiffusionPipeline,
|
14 |
+
StableDiffusionImg2ImgPipeline,
|
15 |
+
)
|
16 |
+
|
17 |
+
|
18 |
+
from .inpaint_pipeline import SDInpaintPipeline as StableDiffusionInpaintPipelineLegacy
|
19 |
+
|
20 |
+
from .textual_inversion import main as run_textual_inversion
|
21 |
+
from .shared import default_scheduler, scheduler_dict, model_ids
|
22 |
+
|
23 |
+
|
24 |
+
_xformers_available = importlib.util.find_spec("xformers") is not None
|
25 |
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
26 |
+
# device = 'cpu'
|
27 |
+
dtype = torch.float16 if device == "cuda" else torch.float32
|
28 |
+
low_vram_mode = False
|
29 |
+
|
30 |
+
|
31 |
+
tab_to_pipeline = {
|
32 |
+
1: StableDiffusionPipeline,
|
33 |
+
2: StableDiffusionImg2ImgPipeline,
|
34 |
+
3: StableDiffusionInpaintPipelineLegacy,
|
35 |
+
}
|
36 |
+
|
37 |
+
|
38 |
+
def load_pipe(model_id, scheduler_name, tab_index=1, pipe_kwargs="{}"):
|
39 |
+
global pipe, loaded_model_id
|
40 |
+
|
41 |
+
scheduler = scheduler_dict[scheduler_name]
|
42 |
+
|
43 |
+
pipe_class = tab_to_pipeline[tab_index]
|
44 |
+
|
45 |
+
# load new weights from disk only when changing model_id
|
46 |
+
if model_id != loaded_model_id:
|
47 |
+
pipe = pipe_class.from_pretrained(
|
48 |
+
model_id,
|
49 |
+
torch_dtype=dtype,
|
50 |
+
safety_checker=None,
|
51 |
+
requires_safety_checker=False,
|
52 |
+
scheduler=scheduler.from_pretrained(model_id, subfolder="scheduler"),
|
53 |
+
**json.loads(pipe_kwargs),
|
54 |
+
)
|
55 |
+
loaded_model_id = model_id
|
56 |
+
|
57 |
+
# if same model_id, instantiate new pipeline with same underlying pytorch objects to avoid reloading weights from disk
|
58 |
+
elif pipe_class != pipe.__class__ or not isinstance(pipe.scheduler, scheduler):
|
59 |
+
pipe.components["scheduler"] = scheduler.from_pretrained(
|
60 |
+
model_id, subfolder="scheduler"
|
61 |
+
)
|
62 |
+
pipe = pipe_class(**pipe.components)
|
63 |
+
|
64 |
+
if device == "cuda":
|
65 |
+
pipe = pipe.to(device)
|
66 |
+
if _xformers_available:
|
67 |
+
pipe.enable_xformers_memory_efficient_attention()
|
68 |
+
print("using xformers")
|
69 |
+
if low_vram_mode:
|
70 |
+
pipe.enable_attention_slicing()
|
71 |
+
print("using attention slicing to lower VRAM")
|
72 |
+
|
73 |
+
return pipe
|
74 |
+
|
75 |
+
|
76 |
+
pipe = None
|
77 |
+
loaded_model_id = ""
|
78 |
+
pipe = load_pipe(model_ids[0], default_scheduler)
|
79 |
+
|
80 |
+
|
81 |
+
def pad_image(image):
|
82 |
+
w, h = image.size
|
83 |
+
if w == h:
|
84 |
+
return image
|
85 |
+
elif w > h:
|
86 |
+
new_image = Image.new(image.mode, (w, w), (0, 0, 0))
|
87 |
+
new_image.paste(image, (0, (w - h) // 2))
|
88 |
+
return new_image
|
89 |
+
else:
|
90 |
+
new_image = Image.new(image.mode, (h, h), (0, 0, 0))
|
91 |
+
new_image.paste(image, ((h - w) // 2, 0))
|
92 |
+
return new_image
|
93 |
+
|
94 |
+
|
95 |
+
@torch.no_grad()
|
96 |
+
def generate(
|
97 |
+
model_name,
|
98 |
+
scheduler_name,
|
99 |
+
prompt,
|
100 |
+
guidance,
|
101 |
+
steps,
|
102 |
+
n_images=1,
|
103 |
+
width=512,
|
104 |
+
height=512,
|
105 |
+
seed=0,
|
106 |
+
image=None,
|
107 |
+
strength=0.5,
|
108 |
+
inpaint_image=None,
|
109 |
+
inpaint_strength=0.5,
|
110 |
+
inpaint_radio="",
|
111 |
+
neg_prompt="",
|
112 |
+
tab_index=1,
|
113 |
+
pipe_kwargs="{}",
|
114 |
+
progress=gr.Progress(track_tqdm=True),
|
115 |
+
):
|
116 |
+
|
117 |
+
if seed == -1:
|
118 |
+
seed = random.randint(0, 2147483647)
|
119 |
+
|
120 |
+
generator = torch.Generator(device).manual_seed(seed)
|
121 |
+
|
122 |
+
pipe = load_pipe(
|
123 |
+
model_id=model_name,
|
124 |
+
scheduler_name=scheduler_name,
|
125 |
+
tab_index=tab_index,
|
126 |
+
pipe_kwargs=pipe_kwargs,
|
127 |
+
)
|
128 |
+
|
129 |
+
status_message = f"Prompt: '{prompt}' | Seed: {seed} | Guidance: {guidance} | Scheduler: {scheduler_name} | Steps: {steps}"
|
130 |
+
|
131 |
+
if tab_index == 1:
|
132 |
+
status_message = "Text to Image " + status_message
|
133 |
+
|
134 |
+
result = pipe(
|
135 |
+
prompt,
|
136 |
+
negative_prompt=neg_prompt,
|
137 |
+
num_images_per_prompt=n_images,
|
138 |
+
num_inference_steps=int(steps),
|
139 |
+
guidance_scale=guidance,
|
140 |
+
width=width,
|
141 |
+
height=height,
|
142 |
+
generator=generator,
|
143 |
+
)
|
144 |
+
|
145 |
+
elif tab_index == 2:
|
146 |
+
|
147 |
+
status_message = "Image to Image " + status_message
|
148 |
+
print(image.size)
|
149 |
+
image = image.resize((width, height))
|
150 |
+
print(image.size)
|
151 |
+
|
152 |
+
result = pipe(
|
153 |
+
prompt,
|
154 |
+
negative_prompt=neg_prompt,
|
155 |
+
num_images_per_prompt=n_images,
|
156 |
+
image=image,
|
157 |
+
num_inference_steps=int(steps),
|
158 |
+
strength=strength,
|
159 |
+
guidance_scale=guidance,
|
160 |
+
generator=generator,
|
161 |
+
)
|
162 |
+
|
163 |
+
elif tab_index == 3:
|
164 |
+
status_message = "Inpainting " + status_message
|
165 |
+
|
166 |
+
init_image = inpaint_image["image"].resize((width, height))
|
167 |
+
mask = inpaint_image["mask"].resize((width, height))
|
168 |
+
|
169 |
+
result = pipe(
|
170 |
+
prompt,
|
171 |
+
negative_prompt=neg_prompt,
|
172 |
+
num_images_per_prompt=n_images,
|
173 |
+
image=init_image,
|
174 |
+
mask_image=mask,
|
175 |
+
num_inference_steps=int(steps),
|
176 |
+
strength=inpaint_strength,
|
177 |
+
preserve_unmasked_image=(
|
178 |
+
inpaint_radio == "preserve non-masked portions of image"
|
179 |
+
),
|
180 |
+
guidance_scale=guidance,
|
181 |
+
generator=generator,
|
182 |
+
)
|
183 |
+
|
184 |
+
else:
|
185 |
+
return None, f"Unhandled tab index: {tab_index}"
|
186 |
+
|
187 |
+
return result.images, status_message
|
188 |
+
|
189 |
+
|
190 |
+
# based on lvkaokao/textual-inversion-training
|
191 |
+
def train_textual_inversion(
|
192 |
+
model_name,
|
193 |
+
scheduler_name,
|
194 |
+
type_of_thing,
|
195 |
+
files,
|
196 |
+
concept_word,
|
197 |
+
init_word,
|
198 |
+
text_train_steps,
|
199 |
+
text_train_bsz,
|
200 |
+
text_learning_rate,
|
201 |
+
progress=gr.Progress(track_tqdm=True),
|
202 |
+
):
|
203 |
+
|
204 |
+
if device == "cpu":
|
205 |
+
raise gr.Error("Textual inversion training not supported on CPU")
|
206 |
+
|
207 |
+
pipe = load_pipe(
|
208 |
+
model_id=model_name,
|
209 |
+
scheduler_name=scheduler_name,
|
210 |
+
tab_index=1,
|
211 |
+
)
|
212 |
+
|
213 |
+
pipe.disable_xformers_memory_efficient_attention() # xformers handled by textual inversion script
|
214 |
+
|
215 |
+
concept_dir = "concept_images"
|
216 |
+
output_dir = "output_model"
|
217 |
+
training_resolution = 512
|
218 |
+
|
219 |
+
if os.path.exists(output_dir):
|
220 |
+
shutil.rmtree("output_model")
|
221 |
+
if os.path.exists(concept_dir):
|
222 |
+
shutil.rmtree("concept_images")
|
223 |
+
|
224 |
+
os.makedirs(concept_dir, exist_ok=True)
|
225 |
+
os.makedirs(output_dir, exist_ok=True)
|
226 |
+
|
227 |
+
gc.collect()
|
228 |
+
torch.cuda.empty_cache()
|
229 |
+
|
230 |
+
if concept_word == "" or concept_word == None:
|
231 |
+
raise gr.Error("You forgot to define your concept prompt")
|
232 |
+
|
233 |
+
for j, file_temp in enumerate(files):
|
234 |
+
file = Image.open(file_temp.name)
|
235 |
+
image = pad_image(file)
|
236 |
+
image = image.resize((training_resolution, training_resolution))
|
237 |
+
extension = file_temp.name.split(".")[1]
|
238 |
+
image = image.convert("RGB")
|
239 |
+
image.save(f"{concept_dir}/{j+1}.{extension}", quality=100)
|
240 |
+
|
241 |
+
args_general = argparse.Namespace(
|
242 |
+
train_data_dir=concept_dir,
|
243 |
+
learnable_property=type_of_thing,
|
244 |
+
placeholder_token=concept_word,
|
245 |
+
initializer_token=init_word,
|
246 |
+
resolution=training_resolution,
|
247 |
+
train_batch_size=text_train_bsz,
|
248 |
+
gradient_accumulation_steps=1,
|
249 |
+
gradient_checkpointing=True,
|
250 |
+
mixed_precision="fp16",
|
251 |
+
use_bf16=False,
|
252 |
+
max_train_steps=int(text_train_steps),
|
253 |
+
learning_rate=text_learning_rate,
|
254 |
+
scale_lr=True,
|
255 |
+
lr_scheduler="constant",
|
256 |
+
lr_warmup_steps=0,
|
257 |
+
output_dir=output_dir,
|
258 |
+
)
|
259 |
+
|
260 |
+
try:
|
261 |
+
final_result = run_textual_inversion(pipe, args_general)
|
262 |
+
except Exception as e:
|
263 |
+
raise gr.Error(e)
|
264 |
+
|
265 |
+
pipe.text_encoder = pipe.text_encoder.eval().to(device, dtype=dtype)
|
266 |
+
pipe.unet = pipe.unet.eval().to(device, dtype=dtype)
|
267 |
+
|
268 |
+
gc.collect()
|
269 |
+
torch.cuda.empty_cache()
|
270 |
+
|
271 |
+
return (
|
272 |
+
f"Finished training! Check the {output_dir} directory for saved model weights"
|
273 |
+
)
|
utils/inpaint_pipeline.py
ADDED
@@ -0,0 +1,288 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
2 |
+
# you may not use this file except in compliance with the License.
|
3 |
+
# You may obtain a copy of the License at
|
4 |
+
#
|
5 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
6 |
+
#
|
7 |
+
# Unless required by applicable law or agreed to in writing, software
|
8 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
9 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
10 |
+
# See the License for the specific language governing permissions and
|
11 |
+
# limitations under the License.
|
12 |
+
|
13 |
+
import torch
|
14 |
+
from typing import Optional, Union, List, Callable
|
15 |
+
import PIL
|
16 |
+
import numpy as np
|
17 |
+
|
18 |
+
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_inpaint_legacy import (
|
19 |
+
preprocess_image,
|
20 |
+
deprecate,
|
21 |
+
StableDiffusionInpaintPipelineLegacy,
|
22 |
+
StableDiffusionPipelineOutput,
|
23 |
+
PIL_INTERPOLATION,
|
24 |
+
)
|
25 |
+
|
26 |
+
|
27 |
+
def preprocess_mask(mask, scale_factor=8):
|
28 |
+
mask = mask.convert("L")
|
29 |
+
w, h = mask.size
|
30 |
+
w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32
|
31 |
+
|
32 |
+
# input_mask = mask.resize((w, h), resample=PIL_INTERPOLATION["nearest"])
|
33 |
+
input_mask = np.array(mask).astype(np.float32) / 255.0
|
34 |
+
input_mask = np.tile(input_mask, (3, 1, 1))
|
35 |
+
input_mask = input_mask[None].transpose(0, 1, 2, 3) # add batch dimension
|
36 |
+
input_mask = 1 - input_mask # repaint white, keep black
|
37 |
+
input_mask = torch.round(torch.from_numpy(input_mask))
|
38 |
+
|
39 |
+
mask = mask.resize(
|
40 |
+
(w // scale_factor, h // scale_factor), resample=PIL_INTERPOLATION["nearest"]
|
41 |
+
)
|
42 |
+
mask = np.array(mask).astype(np.float32) / 255.0
|
43 |
+
mask = np.tile(mask, (4, 1, 1))
|
44 |
+
mask = mask[None].transpose(0, 1, 2, 3) # add batch dimension
|
45 |
+
mask = 1 - mask # repaint white, keep black
|
46 |
+
mask = torch.round(torch.from_numpy(mask))
|
47 |
+
|
48 |
+
return mask, input_mask
|
49 |
+
|
50 |
+
|
51 |
+
class SDInpaintPipeline(StableDiffusionInpaintPipelineLegacy):
|
52 |
+
|
53 |
+
# forward call is same as StableDiffusionInpaintPipelineLegacy, but with line added to avoid noise added to final latents right before decoding step
|
54 |
+
@torch.no_grad()
|
55 |
+
def __call__(
|
56 |
+
self,
|
57 |
+
prompt: Union[str, List[str]],
|
58 |
+
image: Union[torch.FloatTensor, PIL.Image.Image] = None,
|
59 |
+
mask_image: Union[torch.FloatTensor, PIL.Image.Image] = None,
|
60 |
+
strength: float = 0.8,
|
61 |
+
num_inference_steps: Optional[int] = 50,
|
62 |
+
guidance_scale: Optional[float] = 7.5,
|
63 |
+
negative_prompt: Optional[Union[str, List[str]]] = None,
|
64 |
+
num_images_per_prompt: Optional[int] = 1,
|
65 |
+
add_predicted_noise: Optional[bool] = False,
|
66 |
+
eta: Optional[float] = 0.0,
|
67 |
+
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
68 |
+
output_type: Optional[str] = "pil",
|
69 |
+
return_dict: bool = True,
|
70 |
+
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
|
71 |
+
callback_steps: Optional[int] = 1,
|
72 |
+
preserve_unmasked_image: bool = True,
|
73 |
+
**kwargs,
|
74 |
+
):
|
75 |
+
r"""
|
76 |
+
Function invoked when calling the pipeline for generation.
|
77 |
+
|
78 |
+
Args:
|
79 |
+
prompt (`str` or `List[str]`):
|
80 |
+
The prompt or prompts to guide the image generation.
|
81 |
+
image (`torch.FloatTensor` or `PIL.Image.Image`):
|
82 |
+
`Image`, or tensor representing an image batch, that will be used as the starting point for the
|
83 |
+
process. This is the image whose masked region will be inpainted.
|
84 |
+
mask_image (`torch.FloatTensor` or `PIL.Image.Image`):
|
85 |
+
`Image`, or tensor representing an image batch, to mask `image`. White pixels in the mask will be
|
86 |
+
replaced by noise and therefore repainted, while black pixels will be preserved. If `mask_image` is a
|
87 |
+
PIL image, it will be converted to a single channel (luminance) before use. If it's a tensor, it should
|
88 |
+
contain one color channel (L) instead of 3, so the expected shape would be `(B, H, W, 1)`.
|
89 |
+
strength (`float`, *optional*, defaults to 0.8):
|
90 |
+
Conceptually, indicates how much to inpaint the masked area. Must be between 0 and 1. When `strength`
|
91 |
+
is 1, the denoising process will be run on the masked area for the full number of iterations specified
|
92 |
+
in `num_inference_steps`. `image` will be used as a reference for the masked area, adding more noise to
|
93 |
+
that region the larger the `strength`. If `strength` is 0, no inpainting will occur.
|
94 |
+
num_inference_steps (`int`, *optional*, defaults to 50):
|
95 |
+
The reference number of denoising steps. More denoising steps usually lead to a higher quality image at
|
96 |
+
the expense of slower inference. This parameter will be modulated by `strength`, as explained above.
|
97 |
+
guidance_scale (`float`, *optional*, defaults to 7.5):
|
98 |
+
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
|
99 |
+
`guidance_scale` is defined as `w` of equation 2. of [Imagen
|
100 |
+
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
|
101 |
+
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
|
102 |
+
usually at the expense of lower image quality.
|
103 |
+
negative_prompt (`str` or `List[str]`, *optional*):
|
104 |
+
The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
|
105 |
+
if `guidance_scale` is less than `1`).
|
106 |
+
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
107 |
+
The number of images to generate per prompt.
|
108 |
+
add_predicted_noise (`bool`, *optional*, defaults to True):
|
109 |
+
Use predicted noise instead of random noise when constructing noisy versions of the original image in
|
110 |
+
the reverse diffusion process
|
111 |
+
eta (`float`, *optional*, defaults to 0.0):
|
112 |
+
Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
|
113 |
+
[`schedulers.DDIMScheduler`], will be ignored for others.
|
114 |
+
generator (`torch.Generator`, *optional*):
|
115 |
+
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
|
116 |
+
to make generation deterministic.
|
117 |
+
output_type (`str`, *optional*, defaults to `"pil"`):
|
118 |
+
The output format of the generate image. Choose between
|
119 |
+
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
|
120 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
121 |
+
Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
|
122 |
+
plain tuple.
|
123 |
+
callback (`Callable`, *optional*):
|
124 |
+
A function that will be called every `callback_steps` steps during inference. The function will be
|
125 |
+
called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
|
126 |
+
callback_steps (`int`, *optional*, defaults to 1):
|
127 |
+
The frequency at which the `callback` function will be called. If not specified, the callback will be
|
128 |
+
called at every step.
|
129 |
+
preserve_unmasked_image (`bool`, *optional*, defaults to `True`):
|
130 |
+
Whether or not to preserve the unmasked portions of the original image in the inpainted output. If False,
|
131 |
+
inpainting of the masked latents may produce noticeable distortion of unmasked portions of the decoded
|
132 |
+
image.
|
133 |
+
|
134 |
+
Returns:
|
135 |
+
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
|
136 |
+
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple.
|
137 |
+
When returning a tuple, the first element is a list with the generated images, and the second element is a
|
138 |
+
list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
|
139 |
+
(nsfw) content, according to the `safety_checker`.
|
140 |
+
"""
|
141 |
+
message = "Please use `image` instead of `init_image`."
|
142 |
+
init_image = deprecate("init_image", "0.13.0", message, take_from=kwargs)
|
143 |
+
image = init_image or image
|
144 |
+
|
145 |
+
# 1. Check inputs
|
146 |
+
self.check_inputs(prompt, strength, callback_steps)
|
147 |
+
|
148 |
+
# 2. Define call parameters
|
149 |
+
batch_size = 1 if isinstance(prompt, str) else len(prompt)
|
150 |
+
device = self._execution_device
|
151 |
+
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
|
152 |
+
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
|
153 |
+
# corresponds to doing no classifier free guidance.
|
154 |
+
do_classifier_free_guidance = guidance_scale > 1.0
|
155 |
+
|
156 |
+
# 3. Encode input prompt
|
157 |
+
text_embeddings = self._encode_prompt(
|
158 |
+
prompt,
|
159 |
+
device,
|
160 |
+
num_images_per_prompt,
|
161 |
+
do_classifier_free_guidance,
|
162 |
+
negative_prompt,
|
163 |
+
)
|
164 |
+
|
165 |
+
# 4. Preprocess image and mask
|
166 |
+
if not isinstance(image, torch.FloatTensor):
|
167 |
+
image = preprocess_image(image)
|
168 |
+
|
169 |
+
# get mask corresponding to input latents as well as image
|
170 |
+
if not isinstance(mask_image, torch.FloatTensor):
|
171 |
+
mask_image, input_mask_image = preprocess_mask(
|
172 |
+
mask_image, self.vae_scale_factor
|
173 |
+
)
|
174 |
+
|
175 |
+
# 5. set timesteps
|
176 |
+
self.scheduler.set_timesteps(num_inference_steps, device=device)
|
177 |
+
timesteps, num_inference_steps = self.get_timesteps(
|
178 |
+
num_inference_steps, strength, device
|
179 |
+
)
|
180 |
+
latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
|
181 |
+
|
182 |
+
# 6. Prepare latent variables
|
183 |
+
# encode the init image into latents and scale the latents
|
184 |
+
latents, init_latents_orig, noise = self.prepare_latents(
|
185 |
+
image,
|
186 |
+
latent_timestep,
|
187 |
+
batch_size,
|
188 |
+
num_images_per_prompt,
|
189 |
+
text_embeddings.dtype,
|
190 |
+
device,
|
191 |
+
generator,
|
192 |
+
)
|
193 |
+
|
194 |
+
# 7. Prepare mask latent
|
195 |
+
mask = mask_image.to(device=self.device, dtype=latents.dtype)
|
196 |
+
mask = torch.cat([mask] * batch_size * num_images_per_prompt)
|
197 |
+
|
198 |
+
# 8. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
|
199 |
+
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
|
200 |
+
|
201 |
+
# 9. Denoising loop
|
202 |
+
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
|
203 |
+
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
204 |
+
for i, t in enumerate(timesteps):
|
205 |
+
|
206 |
+
# expand the latents if we are doing classifier free guidance
|
207 |
+
latent_model_input = (
|
208 |
+
torch.cat([latents] * 2) if do_classifier_free_guidance else latents
|
209 |
+
)
|
210 |
+
latent_model_input = self.scheduler.scale_model_input(
|
211 |
+
latent_model_input, t
|
212 |
+
)
|
213 |
+
|
214 |
+
# predict the noise residual
|
215 |
+
noise_pred = self.unet(
|
216 |
+
latent_model_input, t, encoder_hidden_states=text_embeddings
|
217 |
+
).sample
|
218 |
+
|
219 |
+
# perform guidance
|
220 |
+
if do_classifier_free_guidance:
|
221 |
+
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
222 |
+
noise_pred = noise_pred_uncond + guidance_scale * (
|
223 |
+
noise_pred_text - noise_pred_uncond
|
224 |
+
)
|
225 |
+
|
226 |
+
# compute the previous noisy sample x_t -> x_t-1
|
227 |
+
latents = self.scheduler.step(
|
228 |
+
noise_pred, t, latents, **extra_step_kwargs
|
229 |
+
).prev_sample
|
230 |
+
# masking
|
231 |
+
if add_predicted_noise:
|
232 |
+
init_latents_proper = self.scheduler.add_noise(
|
233 |
+
init_latents_orig, noise_pred_uncond, torch.tensor([t])
|
234 |
+
)
|
235 |
+
else:
|
236 |
+
init_latents_proper = self.scheduler.add_noise(
|
237 |
+
init_latents_orig, noise, torch.tensor([t])
|
238 |
+
)
|
239 |
+
|
240 |
+
latents = (init_latents_proper * mask) + (latents * (1 - mask))
|
241 |
+
|
242 |
+
# call the callback, if provided
|
243 |
+
if i == len(timesteps) - 1 or (
|
244 |
+
(i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0
|
245 |
+
):
|
246 |
+
progress_bar.update()
|
247 |
+
if callback is not None and i % callback_steps == 0:
|
248 |
+
callback(i, t, latents)
|
249 |
+
|
250 |
+
# use original latents corresponding to unmasked portions of the image
|
251 |
+
# necessary step because noise is still added to "init_latents_proper" after final denoising step
|
252 |
+
latents = (init_latents_orig * mask) + (latents * (1 - mask))
|
253 |
+
|
254 |
+
# 10. Post-processing
|
255 |
+
if preserve_unmasked_image:
|
256 |
+
# decode latents
|
257 |
+
latents = 1 / 0.18215 * latents
|
258 |
+
inpaint_image = self.vae.decode(latents).sample
|
259 |
+
|
260 |
+
# restore unmasked parts of image with original image
|
261 |
+
input_mask_image = input_mask_image.to(inpaint_image)
|
262 |
+
image = image.to(inpaint_image)
|
263 |
+
image = (image * input_mask_image) + (
|
264 |
+
inpaint_image * (1 - input_mask_image)
|
265 |
+
) # use original unmasked portions of image to avoid degradation
|
266 |
+
|
267 |
+
# post-processing of image
|
268 |
+
image = (image / 2 + 0.5).clamp(0, 1)
|
269 |
+
# we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16
|
270 |
+
image = image.cpu().permute(0, 2, 3, 1).float().numpy()
|
271 |
+
else:
|
272 |
+
image = self.decode_latents(latents)
|
273 |
+
|
274 |
+
# 11. Run safety checker
|
275 |
+
image, has_nsfw_concept = self.run_safety_checker(
|
276 |
+
image, device, text_embeddings.dtype
|
277 |
+
)
|
278 |
+
|
279 |
+
# 12. Convert to PIL
|
280 |
+
if output_type == "pil":
|
281 |
+
image = self.numpy_to_pil(image)
|
282 |
+
|
283 |
+
if not return_dict:
|
284 |
+
return (image, has_nsfw_concept)
|
285 |
+
|
286 |
+
return StableDiffusionPipelineOutput(
|
287 |
+
images=image, nsfw_content_detected=has_nsfw_concept
|
288 |
+
)
|
utils/shared.py
ADDED
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import diffusers.schedulers
|
2 |
+
|
3 |
+
# scheduler dict includes superclass SchedulerMixin (it still generates reasonable images)
|
4 |
+
scheduler_dict = {
|
5 |
+
k: v
|
6 |
+
for k, v in diffusers.schedulers.__dict__.items()
|
7 |
+
if "Scheduler" in k and "Flax" not in k
|
8 |
+
}
|
9 |
+
scheduler_dict.pop(
|
10 |
+
"VQDiffusionScheduler"
|
11 |
+
) # requires unique parameter, unlike other schedulers
|
12 |
+
scheduler_names = list(scheduler_dict.keys())
|
13 |
+
default_scheduler = scheduler_names[3] # expected to be DPM Multistep
|
14 |
+
|
15 |
+
with open("model_ids.txt", "r") as fp:
|
16 |
+
model_ids = fp.read().splitlines()
|
utils/textual_inversion.py
ADDED
@@ -0,0 +1,916 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python
|
2 |
+
# coding=utf-8
|
3 |
+
# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
|
4 |
+
#
|
5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
6 |
+
# you may not use this file except in compliance with the License.
|
7 |
+
# You may obtain a copy of the License at
|
8 |
+
#
|
9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
10 |
+
#
|
11 |
+
# Unless required by applicable law or agreed to in writing, software
|
12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
14 |
+
# See the License for the specific language governing permissions and
|
15 |
+
|
16 |
+
import argparse
|
17 |
+
import logging
|
18 |
+
import math
|
19 |
+
import os
|
20 |
+
import random
|
21 |
+
from pathlib import Path
|
22 |
+
from typing import Optional
|
23 |
+
|
24 |
+
import numpy as np
|
25 |
+
import torch
|
26 |
+
import torch.nn.functional as F
|
27 |
+
import torch.utils.checkpoint
|
28 |
+
from torch.utils.data import Dataset
|
29 |
+
|
30 |
+
import datasets
|
31 |
+
import diffusers
|
32 |
+
import PIL
|
33 |
+
import transformers
|
34 |
+
from accelerate import Accelerator
|
35 |
+
from accelerate.logging import get_logger
|
36 |
+
from accelerate.utils import set_seed
|
37 |
+
from diffusers import (
|
38 |
+
AutoencoderKL,
|
39 |
+
DDPMScheduler,
|
40 |
+
StableDiffusionPipeline,
|
41 |
+
UNet2DConditionModel,
|
42 |
+
)
|
43 |
+
from diffusers.optimization import get_scheduler
|
44 |
+
from diffusers.utils import check_min_version
|
45 |
+
from diffusers.utils.import_utils import is_xformers_available
|
46 |
+
from huggingface_hub import HfFolder, Repository, create_repo, whoami
|
47 |
+
|
48 |
+
# TODO: remove and import from diffusers.utils when the new version of diffusers is released
|
49 |
+
from packaging import version
|
50 |
+
from PIL import Image
|
51 |
+
from torchvision import transforms
|
52 |
+
from tqdm.auto import tqdm
|
53 |
+
from transformers import CLIPTextModel, CLIPTokenizer
|
54 |
+
|
55 |
+
|
56 |
+
if version.parse(version.parse(PIL.__version__).base_version) >= version.parse("9.1.0"):
|
57 |
+
PIL_INTERPOLATION = {
|
58 |
+
"linear": PIL.Image.Resampling.BILINEAR,
|
59 |
+
"bilinear": PIL.Image.Resampling.BILINEAR,
|
60 |
+
"bicubic": PIL.Image.Resampling.BICUBIC,
|
61 |
+
"lanczos": PIL.Image.Resampling.LANCZOS,
|
62 |
+
"nearest": PIL.Image.Resampling.NEAREST,
|
63 |
+
}
|
64 |
+
else:
|
65 |
+
PIL_INTERPOLATION = {
|
66 |
+
"linear": PIL.Image.LINEAR,
|
67 |
+
"bilinear": PIL.Image.BILINEAR,
|
68 |
+
"bicubic": PIL.Image.BICUBIC,
|
69 |
+
"lanczos": PIL.Image.LANCZOS,
|
70 |
+
"nearest": PIL.Image.NEAREST,
|
71 |
+
}
|
72 |
+
# ------------------------------------------------------------------------------
|
73 |
+
|
74 |
+
|
75 |
+
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
76 |
+
check_min_version("0.10.0.dev0")
|
77 |
+
|
78 |
+
|
79 |
+
logger = get_logger(__name__)
|
80 |
+
|
81 |
+
|
82 |
+
def save_progress(text_encoder, placeholder_token_id, accelerator, args, save_path):
|
83 |
+
logger.info("Saving embeddings")
|
84 |
+
learned_embeds = (
|
85 |
+
accelerator.unwrap_model(text_encoder)
|
86 |
+
.get_input_embeddings()
|
87 |
+
.weight[placeholder_token_id]
|
88 |
+
)
|
89 |
+
learned_embeds_dict = {args.placeholder_token: learned_embeds.detach().cpu()}
|
90 |
+
torch.save(learned_embeds_dict, save_path)
|
91 |
+
|
92 |
+
|
93 |
+
def parse_args():
|
94 |
+
parser = argparse.ArgumentParser(description="Simple example of a training script.")
|
95 |
+
parser.add_argument(
|
96 |
+
"--save_steps",
|
97 |
+
type=int,
|
98 |
+
default=500,
|
99 |
+
help="Save learned_embeds.bin every X updates steps.",
|
100 |
+
)
|
101 |
+
parser.add_argument(
|
102 |
+
"--only_save_embeds",
|
103 |
+
action="store_true",
|
104 |
+
default=False,
|
105 |
+
help="Save only the embeddings for the new concept.",
|
106 |
+
)
|
107 |
+
parser.add_argument(
|
108 |
+
"--pretrained_model_name_or_path",
|
109 |
+
type=str,
|
110 |
+
default=None,
|
111 |
+
help="Path to pretrained model or model identifier from huggingface.co/models.",
|
112 |
+
)
|
113 |
+
parser.add_argument(
|
114 |
+
"--revision",
|
115 |
+
type=str,
|
116 |
+
default=None,
|
117 |
+
help="Revision of pretrained model identifier from huggingface.co/models.",
|
118 |
+
)
|
119 |
+
parser.add_argument(
|
120 |
+
"--tokenizer_name",
|
121 |
+
type=str,
|
122 |
+
default=None,
|
123 |
+
help="Pretrained tokenizer name or path if not the same as model_name",
|
124 |
+
)
|
125 |
+
parser.add_argument(
|
126 |
+
"--train_data_dir",
|
127 |
+
type=str,
|
128 |
+
default=None,
|
129 |
+
help="A folder containing the training data.",
|
130 |
+
)
|
131 |
+
parser.add_argument(
|
132 |
+
"--placeholder_token",
|
133 |
+
type=str,
|
134 |
+
default=None,
|
135 |
+
help="A token to use as a placeholder for the concept.",
|
136 |
+
)
|
137 |
+
parser.add_argument(
|
138 |
+
"--initializer_token",
|
139 |
+
type=str,
|
140 |
+
default=None,
|
141 |
+
help="A token to use as initializer word.",
|
142 |
+
)
|
143 |
+
|
144 |
+
parser.add_argument(
|
145 |
+
"--learnable_property",
|
146 |
+
type=str,
|
147 |
+
default="object",
|
148 |
+
help="Choose between 'object' and 'style'",
|
149 |
+
)
|
150 |
+
parser.add_argument(
|
151 |
+
"--repeats",
|
152 |
+
type=int,
|
153 |
+
default=100,
|
154 |
+
help="How many times to repeat the training data.",
|
155 |
+
)
|
156 |
+
parser.add_argument(
|
157 |
+
"--output_dir",
|
158 |
+
type=str,
|
159 |
+
default="text-inversion-model",
|
160 |
+
help="The output directory where the model predictions and checkpoints will be written.",
|
161 |
+
)
|
162 |
+
parser.add_argument(
|
163 |
+
"--seed", type=int, default=None, help="A seed for reproducible training."
|
164 |
+
)
|
165 |
+
parser.add_argument(
|
166 |
+
"--resolution",
|
167 |
+
type=int,
|
168 |
+
default=512,
|
169 |
+
help=(
|
170 |
+
"The resolution for input images, all the images in the train/validation dataset will be resized to this"
|
171 |
+
" resolution"
|
172 |
+
),
|
173 |
+
)
|
174 |
+
parser.add_argument(
|
175 |
+
"--center_crop",
|
176 |
+
action="store_true",
|
177 |
+
help="Whether to center crop images before resizing to resolution",
|
178 |
+
)
|
179 |
+
parser.add_argument(
|
180 |
+
"--train_batch_size",
|
181 |
+
type=int,
|
182 |
+
default=16,
|
183 |
+
help="Batch size (per device) for the training dataloader.",
|
184 |
+
)
|
185 |
+
parser.add_argument("--num_train_epochs", type=int, default=100)
|
186 |
+
parser.add_argument(
|
187 |
+
"--max_train_steps",
|
188 |
+
type=int,
|
189 |
+
default=5000,
|
190 |
+
help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
|
191 |
+
)
|
192 |
+
parser.add_argument(
|
193 |
+
"--gradient_accumulation_steps",
|
194 |
+
type=int,
|
195 |
+
default=1,
|
196 |
+
help="Number of updates steps to accumulate before performing a backward/update pass.",
|
197 |
+
)
|
198 |
+
parser.add_argument(
|
199 |
+
"--gradient_checkpointing",
|
200 |
+
action="store_true",
|
201 |
+
help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
|
202 |
+
)
|
203 |
+
parser.add_argument(
|
204 |
+
"--learning_rate",
|
205 |
+
type=float,
|
206 |
+
default=1e-4,
|
207 |
+
help="Initial learning rate (after the potential warmup period) to use.",
|
208 |
+
)
|
209 |
+
parser.add_argument(
|
210 |
+
"--scale_lr",
|
211 |
+
action="store_true",
|
212 |
+
default=False,
|
213 |
+
help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
|
214 |
+
)
|
215 |
+
parser.add_argument(
|
216 |
+
"--lr_scheduler",
|
217 |
+
type=str,
|
218 |
+
default="constant",
|
219 |
+
help=(
|
220 |
+
'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
|
221 |
+
' "constant", "constant_with_warmup"]'
|
222 |
+
),
|
223 |
+
)
|
224 |
+
parser.add_argument(
|
225 |
+
"--lr_warmup_steps",
|
226 |
+
type=int,
|
227 |
+
default=500,
|
228 |
+
help="Number of steps for the warmup in the lr scheduler.",
|
229 |
+
)
|
230 |
+
parser.add_argument(
|
231 |
+
"--adam_beta1",
|
232 |
+
type=float,
|
233 |
+
default=0.9,
|
234 |
+
help="The beta1 parameter for the Adam optimizer.",
|
235 |
+
)
|
236 |
+
parser.add_argument(
|
237 |
+
"--adam_beta2",
|
238 |
+
type=float,
|
239 |
+
default=0.999,
|
240 |
+
help="The beta2 parameter for the Adam optimizer.",
|
241 |
+
)
|
242 |
+
parser.add_argument(
|
243 |
+
"--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use."
|
244 |
+
)
|
245 |
+
parser.add_argument(
|
246 |
+
"--adam_epsilon",
|
247 |
+
type=float,
|
248 |
+
default=1e-08,
|
249 |
+
help="Epsilon value for the Adam optimizer",
|
250 |
+
)
|
251 |
+
parser.add_argument(
|
252 |
+
"--push_to_hub",
|
253 |
+
action="store_true",
|
254 |
+
help="Whether or not to push the model to the Hub.",
|
255 |
+
)
|
256 |
+
parser.add_argument(
|
257 |
+
"--hub_token",
|
258 |
+
type=str,
|
259 |
+
default=None,
|
260 |
+
help="The token to use to push to the Model Hub.",
|
261 |
+
)
|
262 |
+
parser.add_argument(
|
263 |
+
"--hub_model_id",
|
264 |
+
type=str,
|
265 |
+
default=None,
|
266 |
+
help="The name of the repository to keep in sync with the local `output_dir`.",
|
267 |
+
)
|
268 |
+
parser.add_argument(
|
269 |
+
"--logging_dir",
|
270 |
+
type=str,
|
271 |
+
default="logs",
|
272 |
+
help=(
|
273 |
+
"[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
|
274 |
+
" *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
|
275 |
+
),
|
276 |
+
)
|
277 |
+
parser.add_argument(
|
278 |
+
"--mixed_precision",
|
279 |
+
type=str,
|
280 |
+
default="no",
|
281 |
+
choices=["no", "fp16", "bf16"],
|
282 |
+
help=(
|
283 |
+
"Whether to use mixed precision. Choose"
|
284 |
+
"between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >= 1.10."
|
285 |
+
"and an Nvidia Ampere GPU."
|
286 |
+
),
|
287 |
+
)
|
288 |
+
parser.add_argument(
|
289 |
+
"--allow_tf32",
|
290 |
+
action="store_true",
|
291 |
+
help=(
|
292 |
+
"Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
|
293 |
+
" https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
|
294 |
+
),
|
295 |
+
)
|
296 |
+
parser.add_argument(
|
297 |
+
"--report_to",
|
298 |
+
type=str,
|
299 |
+
default="tensorboard",
|
300 |
+
help=(
|
301 |
+
'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
|
302 |
+
' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
|
303 |
+
),
|
304 |
+
)
|
305 |
+
parser.add_argument(
|
306 |
+
"--local_rank",
|
307 |
+
type=int,
|
308 |
+
default=-1,
|
309 |
+
help="For distributed training: local_rank",
|
310 |
+
)
|
311 |
+
parser.add_argument(
|
312 |
+
"--checkpointing_steps",
|
313 |
+
type=int,
|
314 |
+
default=500,
|
315 |
+
help=(
|
316 |
+
"Save a checkpoint of the training state every X updates. These checkpoints are only suitable for resuming"
|
317 |
+
" training using `--resume_from_checkpoint`."
|
318 |
+
),
|
319 |
+
)
|
320 |
+
parser.add_argument(
|
321 |
+
"--resume_from_checkpoint",
|
322 |
+
type=str,
|
323 |
+
default=None,
|
324 |
+
help=(
|
325 |
+
"Whether training should be resumed from a previous checkpoint. Use a path saved by"
|
326 |
+
' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
|
327 |
+
),
|
328 |
+
)
|
329 |
+
parser.add_argument(
|
330 |
+
"--enable_xformers_memory_efficient_attention",
|
331 |
+
action="store_true",
|
332 |
+
help="Whether or not to use xformers.",
|
333 |
+
)
|
334 |
+
|
335 |
+
args = parser.parse_args()
|
336 |
+
env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
|
337 |
+
if env_local_rank != -1 and env_local_rank != args.local_rank:
|
338 |
+
args.local_rank = env_local_rank
|
339 |
+
|
340 |
+
# if args.train_data_dir is None:
|
341 |
+
# raise ValueError("You must specify a train data directory.")
|
342 |
+
|
343 |
+
return args
|
344 |
+
|
345 |
+
|
346 |
+
imagenet_templates_small = [
|
347 |
+
"a photo of a {}",
|
348 |
+
"a rendering of a {}",
|
349 |
+
"a cropped photo of the {}",
|
350 |
+
"the photo of a {}",
|
351 |
+
"a photo of a clean {}",
|
352 |
+
"a photo of a dirty {}",
|
353 |
+
"a dark photo of the {}",
|
354 |
+
"a photo of my {}",
|
355 |
+
"a photo of the cool {}",
|
356 |
+
"a close-up photo of a {}",
|
357 |
+
"a bright photo of the {}",
|
358 |
+
"a cropped photo of a {}",
|
359 |
+
"a photo of the {}",
|
360 |
+
"a good photo of the {}",
|
361 |
+
"a photo of one {}",
|
362 |
+
"a close-up photo of the {}",
|
363 |
+
"a rendition of the {}",
|
364 |
+
"a photo of the clean {}",
|
365 |
+
"a rendition of a {}",
|
366 |
+
"a photo of a nice {}",
|
367 |
+
"a good photo of a {}",
|
368 |
+
"a photo of the nice {}",
|
369 |
+
"a photo of the small {}",
|
370 |
+
"a photo of the weird {}",
|
371 |
+
"a photo of the large {}",
|
372 |
+
"a photo of a cool {}",
|
373 |
+
"a photo of a small {}",
|
374 |
+
]
|
375 |
+
|
376 |
+
imagenet_style_templates_small = [
|
377 |
+
"a painting in the style of {}",
|
378 |
+
"a rendering in the style of {}",
|
379 |
+
"a cropped painting in the style of {}",
|
380 |
+
"the painting in the style of {}",
|
381 |
+
"a clean painting in the style of {}",
|
382 |
+
"a dirty painting in the style of {}",
|
383 |
+
"a dark painting in the style of {}",
|
384 |
+
"a picture in the style of {}",
|
385 |
+
"a cool painting in the style of {}",
|
386 |
+
"a close-up painting in the style of {}",
|
387 |
+
"a bright painting in the style of {}",
|
388 |
+
"a cropped painting in the style of {}",
|
389 |
+
"a good painting in the style of {}",
|
390 |
+
"a close-up painting in the style of {}",
|
391 |
+
"a rendition in the style of {}",
|
392 |
+
"a nice painting in the style of {}",
|
393 |
+
"a small painting in the style of {}",
|
394 |
+
"a weird painting in the style of {}",
|
395 |
+
"a large painting in the style of {}",
|
396 |
+
]
|
397 |
+
|
398 |
+
|
399 |
+
class TextualInversionDataset(Dataset):
|
400 |
+
def __init__(
|
401 |
+
self,
|
402 |
+
data_root,
|
403 |
+
tokenizer,
|
404 |
+
learnable_property="object", # [object, style]
|
405 |
+
size=512,
|
406 |
+
repeats=100,
|
407 |
+
interpolation="bicubic",
|
408 |
+
flip_p=0.5,
|
409 |
+
set="train",
|
410 |
+
placeholder_token="*",
|
411 |
+
center_crop=False,
|
412 |
+
):
|
413 |
+
self.data_root = data_root
|
414 |
+
self.tokenizer = tokenizer
|
415 |
+
self.learnable_property = learnable_property
|
416 |
+
self.size = size
|
417 |
+
self.placeholder_token = placeholder_token
|
418 |
+
self.center_crop = center_crop
|
419 |
+
self.flip_p = flip_p
|
420 |
+
|
421 |
+
self.image_paths = [
|
422 |
+
os.path.join(self.data_root, file_path)
|
423 |
+
for file_path in os.listdir(self.data_root)
|
424 |
+
]
|
425 |
+
|
426 |
+
self.num_images = len(self.image_paths)
|
427 |
+
self._length = self.num_images
|
428 |
+
|
429 |
+
if set == "train":
|
430 |
+
self._length = self.num_images * repeats
|
431 |
+
|
432 |
+
self.interpolation = {
|
433 |
+
"linear": PIL_INTERPOLATION["linear"],
|
434 |
+
"bilinear": PIL_INTERPOLATION["bilinear"],
|
435 |
+
"bicubic": PIL_INTERPOLATION["bicubic"],
|
436 |
+
"lanczos": PIL_INTERPOLATION["lanczos"],
|
437 |
+
}[interpolation]
|
438 |
+
|
439 |
+
self.templates = (
|
440 |
+
imagenet_style_templates_small
|
441 |
+
if learnable_property == "style"
|
442 |
+
else imagenet_templates_small
|
443 |
+
)
|
444 |
+
self.flip_transform = transforms.RandomHorizontalFlip(p=self.flip_p)
|
445 |
+
|
446 |
+
def __len__(self):
|
447 |
+
return self._length
|
448 |
+
|
449 |
+
def __getitem__(self, i):
|
450 |
+
example = {}
|
451 |
+
image = Image.open(self.image_paths[i % self.num_images])
|
452 |
+
|
453 |
+
if not image.mode == "RGB":
|
454 |
+
image = image.convert("RGB")
|
455 |
+
|
456 |
+
placeholder_string = self.placeholder_token
|
457 |
+
text = random.choice(self.templates).format(placeholder_string)
|
458 |
+
|
459 |
+
example["input_ids"] = self.tokenizer(
|
460 |
+
text,
|
461 |
+
padding="max_length",
|
462 |
+
truncation=True,
|
463 |
+
max_length=self.tokenizer.model_max_length,
|
464 |
+
return_tensors="pt",
|
465 |
+
).input_ids[0]
|
466 |
+
|
467 |
+
# default to score-sde preprocessing
|
468 |
+
img = np.array(image).astype(np.uint8)
|
469 |
+
|
470 |
+
if self.center_crop:
|
471 |
+
crop = min(img.shape[0], img.shape[1])
|
472 |
+
(h, w,) = (
|
473 |
+
img.shape[0],
|
474 |
+
img.shape[1],
|
475 |
+
)
|
476 |
+
img = img[
|
477 |
+
(h - crop) // 2 : (h + crop) // 2, (w - crop) // 2 : (w + crop) // 2
|
478 |
+
]
|
479 |
+
|
480 |
+
image = Image.fromarray(img)
|
481 |
+
image = image.resize((self.size, self.size), resample=self.interpolation)
|
482 |
+
|
483 |
+
image = self.flip_transform(image)
|
484 |
+
image = np.array(image).astype(np.uint8)
|
485 |
+
image = (image / 127.5 - 1.0).astype(np.float32)
|
486 |
+
|
487 |
+
example["pixel_values"] = torch.from_numpy(image).permute(2, 0, 1)
|
488 |
+
return example
|
489 |
+
|
490 |
+
|
491 |
+
def get_full_repo_name(
|
492 |
+
model_id: str, organization: Optional[str] = None, token: Optional[str] = None
|
493 |
+
):
|
494 |
+
if token is None:
|
495 |
+
token = HfFolder.get_token()
|
496 |
+
if organization is None:
|
497 |
+
username = whoami(token)["name"]
|
498 |
+
return f"{username}/{model_id}"
|
499 |
+
else:
|
500 |
+
return f"{organization}/{model_id}"
|
501 |
+
|
502 |
+
|
503 |
+
def main(pipe, args_imported):
|
504 |
+
|
505 |
+
args = parse_args()
|
506 |
+
vars(args).update(vars(args_imported))
|
507 |
+
|
508 |
+
print(args)
|
509 |
+
logging_dir = os.path.join(args.output_dir, args.logging_dir)
|
510 |
+
|
511 |
+
accelerator = Accelerator(
|
512 |
+
gradient_accumulation_steps=args.gradient_accumulation_steps,
|
513 |
+
mixed_precision=args.mixed_precision,
|
514 |
+
log_with=args.report_to,
|
515 |
+
logging_dir=logging_dir,
|
516 |
+
)
|
517 |
+
|
518 |
+
# Make one log on every process with the configuration for debugging.
|
519 |
+
logging.basicConfig(
|
520 |
+
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
|
521 |
+
datefmt="%m/%d/%Y %H:%M:%S",
|
522 |
+
level=logging.INFO,
|
523 |
+
)
|
524 |
+
logger.info(accelerator.state, main_process_only=False)
|
525 |
+
if accelerator.is_local_main_process:
|
526 |
+
datasets.utils.logging.set_verbosity_warning()
|
527 |
+
transformers.utils.logging.set_verbosity_warning()
|
528 |
+
diffusers.utils.logging.set_verbosity_info()
|
529 |
+
else:
|
530 |
+
datasets.utils.logging.set_verbosity_error()
|
531 |
+
transformers.utils.logging.set_verbosity_error()
|
532 |
+
diffusers.utils.logging.set_verbosity_error()
|
533 |
+
|
534 |
+
# If passed along, set the training seed now.
|
535 |
+
if args.seed is not None:
|
536 |
+
set_seed(args.seed)
|
537 |
+
|
538 |
+
# Handle the repository creation
|
539 |
+
if accelerator.is_main_process:
|
540 |
+
if args.push_to_hub:
|
541 |
+
if args.hub_model_id is None:
|
542 |
+
repo_name = get_full_repo_name(
|
543 |
+
Path(args.output_dir).name, token=args.hub_token
|
544 |
+
)
|
545 |
+
else:
|
546 |
+
repo_name = args.hub_model_id
|
547 |
+
create_repo(repo_name, exist_ok=True, token=args.hub_token)
|
548 |
+
repo = Repository(
|
549 |
+
args.output_dir, clone_from=repo_name, token=args.hub_token
|
550 |
+
)
|
551 |
+
|
552 |
+
with open(os.path.join(args.output_dir, ".gitignore"), "w+") as gitignore:
|
553 |
+
if "step_*" not in gitignore:
|
554 |
+
gitignore.write("step_*\n")
|
555 |
+
if "epoch_*" not in gitignore:
|
556 |
+
gitignore.write("epoch_*\n")
|
557 |
+
elif args.output_dir is not None:
|
558 |
+
os.makedirs(args.output_dir, exist_ok=True)
|
559 |
+
|
560 |
+
# Load tokenizer
|
561 |
+
tokenizer = pipe.tokenizer
|
562 |
+
|
563 |
+
# Load scheduler and models
|
564 |
+
noise_scheduler = pipe.scheduler
|
565 |
+
text_encoder = pipe.text_encoder
|
566 |
+
vae = pipe.vae
|
567 |
+
unet = pipe.unet
|
568 |
+
|
569 |
+
# Add the placeholder token in tokenizer
|
570 |
+
num_added_tokens = tokenizer.add_tokens(args.placeholder_token)
|
571 |
+
if num_added_tokens == 0:
|
572 |
+
raise ValueError(
|
573 |
+
f"The tokenizer already contains the token {args.placeholder_token}. Please pass a different"
|
574 |
+
" `placeholder_token` that is not already in the tokenizer."
|
575 |
+
)
|
576 |
+
|
577 |
+
# Convert the initializer_token, placeholder_token to ids
|
578 |
+
token_ids = tokenizer.encode(args.initializer_token, add_special_tokens=False)
|
579 |
+
# Check if initializer_token is a single token or a sequence of tokens
|
580 |
+
if len(token_ids) > 1:
|
581 |
+
raise ValueError("The initializer token must be a single token.")
|
582 |
+
|
583 |
+
initializer_token_id = token_ids[0]
|
584 |
+
placeholder_token_id = tokenizer.convert_tokens_to_ids(args.placeholder_token)
|
585 |
+
|
586 |
+
# Resize the token embeddings as we are adding new special tokens to the tokenizer
|
587 |
+
text_encoder.resize_token_embeddings(len(tokenizer))
|
588 |
+
|
589 |
+
# Initialise the newly added placeholder token with the embeddings of the initializer token
|
590 |
+
token_embeds = text_encoder.get_input_embeddings().weight.data
|
591 |
+
token_embeds[placeholder_token_id] = token_embeds[initializer_token_id]
|
592 |
+
|
593 |
+
# Freeze vae and unet
|
594 |
+
vae.requires_grad_(False)
|
595 |
+
unet.requires_grad_(False)
|
596 |
+
# Freeze all parameters except for the token embeddings in text encoder
|
597 |
+
text_encoder.text_model.encoder.requires_grad_(False)
|
598 |
+
text_encoder.text_model.final_layer_norm.requires_grad_(False)
|
599 |
+
text_encoder.text_model.embeddings.position_embedding.requires_grad_(False)
|
600 |
+
|
601 |
+
if args.gradient_checkpointing:
|
602 |
+
# Keep unet in train mode if we are using gradient checkpointing to save memory.
|
603 |
+
# The dropout cannot be != 0 so it doesn't matter if we are in eval or train mode.
|
604 |
+
unet.train()
|
605 |
+
text_encoder.gradient_checkpointing_enable()
|
606 |
+
unet.enable_gradient_checkpointing()
|
607 |
+
|
608 |
+
if args.enable_xformers_memory_efficient_attention:
|
609 |
+
if is_xformers_available():
|
610 |
+
unet.enable_xformers_memory_efficient_attention()
|
611 |
+
else:
|
612 |
+
raise ValueError(
|
613 |
+
"xformers is not available. Make sure it is installed correctly"
|
614 |
+
)
|
615 |
+
|
616 |
+
# Enable TF32 for faster training on Ampere GPUs,
|
617 |
+
# cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
|
618 |
+
if args.allow_tf32:
|
619 |
+
torch.backends.cuda.matmul.allow_tf32 = True
|
620 |
+
|
621 |
+
if args.scale_lr:
|
622 |
+
args.learning_rate = (
|
623 |
+
args.learning_rate
|
624 |
+
* args.gradient_accumulation_steps
|
625 |
+
* args.train_batch_size
|
626 |
+
* accelerator.num_processes
|
627 |
+
)
|
628 |
+
|
629 |
+
# Initialize the optimizer
|
630 |
+
optimizer = torch.optim.AdamW(
|
631 |
+
text_encoder.get_input_embeddings().parameters(), # only optimize the embeddings
|
632 |
+
lr=args.learning_rate,
|
633 |
+
betas=(args.adam_beta1, args.adam_beta2),
|
634 |
+
weight_decay=args.adam_weight_decay,
|
635 |
+
eps=args.adam_epsilon,
|
636 |
+
)
|
637 |
+
|
638 |
+
# Dataset and DataLoaders creation:
|
639 |
+
train_dataset = TextualInversionDataset(
|
640 |
+
data_root=args.train_data_dir,
|
641 |
+
tokenizer=tokenizer,
|
642 |
+
size=args.resolution,
|
643 |
+
placeholder_token=args.placeholder_token,
|
644 |
+
repeats=args.repeats,
|
645 |
+
learnable_property=args.learnable_property,
|
646 |
+
center_crop=args.center_crop,
|
647 |
+
set="train",
|
648 |
+
)
|
649 |
+
train_dataloader = torch.utils.data.DataLoader(
|
650 |
+
train_dataset, batch_size=args.train_batch_size, shuffle=True
|
651 |
+
)
|
652 |
+
|
653 |
+
# Scheduler and math around the number of training steps.
|
654 |
+
overrode_max_train_steps = False
|
655 |
+
num_update_steps_per_epoch = math.ceil(
|
656 |
+
len(train_dataloader) / args.gradient_accumulation_steps
|
657 |
+
)
|
658 |
+
if args.max_train_steps is None:
|
659 |
+
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
|
660 |
+
overrode_max_train_steps = True
|
661 |
+
|
662 |
+
lr_scheduler = get_scheduler(
|
663 |
+
args.lr_scheduler,
|
664 |
+
optimizer=optimizer,
|
665 |
+
num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
|
666 |
+
num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
|
667 |
+
)
|
668 |
+
|
669 |
+
# Prepare everything with our `accelerator`.
|
670 |
+
text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
|
671 |
+
text_encoder, optimizer, train_dataloader, lr_scheduler
|
672 |
+
)
|
673 |
+
|
674 |
+
# For mixed precision training we cast the text_encoder and vae weights to half-precision
|
675 |
+
# as these models are only used for inference, keeping weights in full precision is not required.
|
676 |
+
weight_dtype = torch.float32
|
677 |
+
if accelerator.mixed_precision == "fp16":
|
678 |
+
weight_dtype = torch.float16
|
679 |
+
elif accelerator.mixed_precision == "bf16":
|
680 |
+
weight_dtype = torch.bfloat16
|
681 |
+
|
682 |
+
# Move vae and unet to device and cast to weight_dtype
|
683 |
+
unet.to(accelerator.device, dtype=weight_dtype)
|
684 |
+
vae.to(accelerator.device, dtype=weight_dtype)
|
685 |
+
text_encoder.to(accelerator.device, dtype=torch.float32)
|
686 |
+
|
687 |
+
# We need to recalculate our total training steps as the size of the training dataloader may have changed.
|
688 |
+
num_update_steps_per_epoch = math.ceil(
|
689 |
+
len(train_dataloader) / args.gradient_accumulation_steps
|
690 |
+
)
|
691 |
+
if overrode_max_train_steps:
|
692 |
+
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
|
693 |
+
# Afterwards we recalculate our number of training epochs
|
694 |
+
args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
|
695 |
+
|
696 |
+
# We need to initialize the trackers we use, and also store our configuration.
|
697 |
+
# The trackers initializes automatically on the main process.
|
698 |
+
if accelerator.is_main_process:
|
699 |
+
accelerator.init_trackers("textual_inversion", config=vars(args))
|
700 |
+
|
701 |
+
# Train!
|
702 |
+
total_batch_size = (
|
703 |
+
args.train_batch_size
|
704 |
+
* accelerator.num_processes
|
705 |
+
* args.gradient_accumulation_steps
|
706 |
+
)
|
707 |
+
|
708 |
+
logger.info("***** Running training *****")
|
709 |
+
logger.info(f" Num examples = {len(train_dataset)}")
|
710 |
+
logger.info(f" Num Epochs = {args.num_train_epochs}")
|
711 |
+
logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
|
712 |
+
logger.info(
|
713 |
+
f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}"
|
714 |
+
)
|
715 |
+
logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
|
716 |
+
logger.info(f" Total optimization steps = {args.max_train_steps}")
|
717 |
+
global_step = 0
|
718 |
+
first_epoch = 0
|
719 |
+
|
720 |
+
# Potentially load in the weights and states from a previous save
|
721 |
+
if args.resume_from_checkpoint:
|
722 |
+
if args.resume_from_checkpoint != "latest":
|
723 |
+
path = os.path.basename(args.resume_from_checkpoint)
|
724 |
+
else:
|
725 |
+
# Get the most recent checkpoint
|
726 |
+
dirs = os.listdir(args.output_dir)
|
727 |
+
dirs = [d for d in dirs if d.startswith("checkpoint")]
|
728 |
+
dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
|
729 |
+
path = dirs[-1]
|
730 |
+
accelerator.print(f"Resuming from checkpoint {path}")
|
731 |
+
accelerator.load_state(os.path.join(args.output_dir, path))
|
732 |
+
global_step = int(path.split("-")[1])
|
733 |
+
|
734 |
+
resume_global_step = global_step * args.gradient_accumulation_steps
|
735 |
+
first_epoch = resume_global_step // num_update_steps_per_epoch
|
736 |
+
resume_step = resume_global_step % num_update_steps_per_epoch
|
737 |
+
|
738 |
+
# Only show the progress bar once on each machine.
|
739 |
+
progress_bar = tqdm(
|
740 |
+
range(global_step, args.max_train_steps),
|
741 |
+
disable=not accelerator.is_local_main_process,
|
742 |
+
)
|
743 |
+
progress_bar.set_description("Steps")
|
744 |
+
|
745 |
+
# keep original embeddings as reference
|
746 |
+
orig_embeds_params = (
|
747 |
+
accelerator.unwrap_model(text_encoder)
|
748 |
+
.get_input_embeddings()
|
749 |
+
.weight.data.clone()
|
750 |
+
)
|
751 |
+
|
752 |
+
for epoch in range(first_epoch, args.num_train_epochs):
|
753 |
+
text_encoder.train()
|
754 |
+
for step, batch in enumerate(train_dataloader):
|
755 |
+
# Skip steps until we reach the resumed step
|
756 |
+
if (
|
757 |
+
args.resume_from_checkpoint
|
758 |
+
and epoch == first_epoch
|
759 |
+
and step < resume_step
|
760 |
+
):
|
761 |
+
if step % args.gradient_accumulation_steps == 0:
|
762 |
+
progress_bar.update(1)
|
763 |
+
continue
|
764 |
+
|
765 |
+
with accelerator.accumulate(text_encoder):
|
766 |
+
# Convert images to latent space
|
767 |
+
latents = (
|
768 |
+
vae.encode(batch["pixel_values"].to(dtype=weight_dtype))
|
769 |
+
.latent_dist.sample()
|
770 |
+
.detach()
|
771 |
+
)
|
772 |
+
latents = latents * 0.18215
|
773 |
+
|
774 |
+
# Sample noise that we'll add to the latents
|
775 |
+
noise = torch.randn_like(latents)
|
776 |
+
bsz = latents.shape[0]
|
777 |
+
# Sample a random timestep for each image
|
778 |
+
timesteps = torch.randint(
|
779 |
+
0,
|
780 |
+
noise_scheduler.config.num_train_timesteps,
|
781 |
+
(bsz,),
|
782 |
+
device=latents.device,
|
783 |
+
)
|
784 |
+
timesteps = timesteps.long()
|
785 |
+
|
786 |
+
# Add noise to the latents according to the noise magnitude at each timestep
|
787 |
+
# (this is the forward diffusion process)
|
788 |
+
noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
|
789 |
+
|
790 |
+
# Get the text embedding for conditioning
|
791 |
+
encoder_hidden_states = text_encoder(batch["input_ids"])[0].to(
|
792 |
+
dtype=weight_dtype
|
793 |
+
)
|
794 |
+
|
795 |
+
# Predict the noise residual
|
796 |
+
model_pred = unet(
|
797 |
+
noisy_latents, timesteps, encoder_hidden_states
|
798 |
+
).sample
|
799 |
+
|
800 |
+
# Get the target for loss depending on the prediction type
|
801 |
+
if noise_scheduler.config.prediction_type == "epsilon":
|
802 |
+
target = noise
|
803 |
+
elif noise_scheduler.config.prediction_type == "v_prediction":
|
804 |
+
target = noise_scheduler.get_velocity(latents, noise, timesteps)
|
805 |
+
else:
|
806 |
+
raise ValueError(
|
807 |
+
f"Unknown prediction type {noise_scheduler.config.prediction_type}"
|
808 |
+
)
|
809 |
+
|
810 |
+
loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
|
811 |
+
|
812 |
+
accelerator.backward(loss)
|
813 |
+
|
814 |
+
if accelerator.num_processes > 1:
|
815 |
+
grads = text_encoder.module.get_input_embeddings().weight.grad
|
816 |
+
else:
|
817 |
+
grads = text_encoder.get_input_embeddings().weight.grad
|
818 |
+
# Get the index for tokens that we want to zero the grads for
|
819 |
+
index_grads_to_zero = (
|
820 |
+
torch.arange(len(tokenizer)) != placeholder_token_id
|
821 |
+
)
|
822 |
+
grads.data[index_grads_to_zero, :] = grads.data[
|
823 |
+
index_grads_to_zero, :
|
824 |
+
].fill_(0)
|
825 |
+
|
826 |
+
optimizer.step()
|
827 |
+
lr_scheduler.step()
|
828 |
+
optimizer.zero_grad()
|
829 |
+
|
830 |
+
# Let's make sure we don't update any embedding weights besides the newly added token
|
831 |
+
index_no_updates = torch.arange(len(tokenizer)) != placeholder_token_id
|
832 |
+
with torch.no_grad():
|
833 |
+
accelerator.unwrap_model(
|
834 |
+
text_encoder
|
835 |
+
).get_input_embeddings().weight[
|
836 |
+
index_no_updates
|
837 |
+
] = orig_embeds_params[
|
838 |
+
index_no_updates
|
839 |
+
]
|
840 |
+
|
841 |
+
# Checks if the accelerator has performed an optimization step behind the scenes
|
842 |
+
if accelerator.sync_gradients:
|
843 |
+
progress_bar.update(1)
|
844 |
+
global_step += 1
|
845 |
+
if global_step % args.save_steps == 0:
|
846 |
+
save_path = os.path.join(
|
847 |
+
args.output_dir, f"{args.placeholder_token}-{global_step}.bin"
|
848 |
+
)
|
849 |
+
save_progress(
|
850 |
+
text_encoder, placeholder_token_id, accelerator, args, save_path
|
851 |
+
)
|
852 |
+
|
853 |
+
if global_step % args.checkpointing_steps == 0:
|
854 |
+
if accelerator.is_main_process:
|
855 |
+
save_path = os.path.join(
|
856 |
+
args.output_dir, f"checkpoint-{global_step}"
|
857 |
+
)
|
858 |
+
accelerator.save_state(save_path)
|
859 |
+
logger.info(f"Saved state to {save_path}")
|
860 |
+
|
861 |
+
logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
|
862 |
+
progress_bar.set_postfix(**logs)
|
863 |
+
accelerator.log(logs, step=global_step)
|
864 |
+
|
865 |
+
if global_step >= args.max_train_steps:
|
866 |
+
break
|
867 |
+
|
868 |
+
# Create the pipeline using using the trained modules and save it.
|
869 |
+
accelerator.wait_for_everyone()
|
870 |
+
if accelerator.is_main_process:
|
871 |
+
if args.push_to_hub and args.only_save_embeds:
|
872 |
+
logger.warn(
|
873 |
+
"Enabling full model saving because --push_to_hub=True was specified."
|
874 |
+
)
|
875 |
+
save_full_model = True
|
876 |
+
else:
|
877 |
+
save_full_model = not args.only_save_embeds
|
878 |
+
if save_full_model:
|
879 |
+
pipe.save_pretrained(args.output_dir)
|
880 |
+
# Save the newly trained embeddings
|
881 |
+
save_path = os.path.join(args.output_dir, "learned_embeds.bin")
|
882 |
+
save_progress(text_encoder, placeholder_token_id, accelerator, args, save_path)
|
883 |
+
|
884 |
+
if args.push_to_hub:
|
885 |
+
repo.push_to_hub(
|
886 |
+
commit_message="End of training", blocking=False, auto_lfs_prune=True
|
887 |
+
)
|
888 |
+
|
889 |
+
accelerator.end_training()
|
890 |
+
|
891 |
+
|
892 |
+
if __name__ == "__main__":
|
893 |
+
pipeline = StableDiffusionPipeline.from_pretrained(
|
894 |
+
"andite/anything-v4.0", torch_dtype=torch.float16
|
895 |
+
)
|
896 |
+
|
897 |
+
imported_args = argparse.Namespace(
|
898 |
+
train_data_dir="concept_images",
|
899 |
+
learnable_property="object",
|
900 |
+
placeholder_token="redeyegirl",
|
901 |
+
initializer_token="girl",
|
902 |
+
resolution=512,
|
903 |
+
train_batch_size=1,
|
904 |
+
gradient_accumulation_steps=1,
|
905 |
+
gradient_checkpointing=True,
|
906 |
+
mixed_precision="fp16",
|
907 |
+
use_bf16=False,
|
908 |
+
max_train_steps=1000,
|
909 |
+
learning_rate=5.0e-4,
|
910 |
+
scale_lr=False,
|
911 |
+
lr_scheduler="constant",
|
912 |
+
lr_warmup_steps=0,
|
913 |
+
output_dir="output_model",
|
914 |
+
)
|
915 |
+
|
916 |
+
main(pipeline, imported_args)
|