Jackflack09 and lint committed
Commit 001c876 · 0 Parent(s)

Duplicate from lint/sdpipe_webui


Co-authored-by: lint <[email protected]>

.gitattributes ADDED
@@ -0,0 +1,34 @@
+ *.7z filter=lfs diff=lfs merge=lfs -text
+ *.arrow filter=lfs diff=lfs merge=lfs -text
+ *.bin filter=lfs diff=lfs merge=lfs -text
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
+ *.ftz filter=lfs diff=lfs merge=lfs -text
+ *.gz filter=lfs diff=lfs merge=lfs -text
+ *.h5 filter=lfs diff=lfs merge=lfs -text
+ *.joblib filter=lfs diff=lfs merge=lfs -text
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
+ *.model filter=lfs diff=lfs merge=lfs -text
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
+ *.npy filter=lfs diff=lfs merge=lfs -text
+ *.npz filter=lfs diff=lfs merge=lfs -text
+ *.onnx filter=lfs diff=lfs merge=lfs -text
+ *.ot filter=lfs diff=lfs merge=lfs -text
+ *.parquet filter=lfs diff=lfs merge=lfs -text
+ *.pb filter=lfs diff=lfs merge=lfs -text
+ *.pickle filter=lfs diff=lfs merge=lfs -text
+ *.pkl filter=lfs diff=lfs merge=lfs -text
+ *.pt filter=lfs diff=lfs merge=lfs -text
+ *.pth filter=lfs diff=lfs merge=lfs -text
+ *.rar filter=lfs diff=lfs merge=lfs -text
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
+ *.tflite filter=lfs diff=lfs merge=lfs -text
+ *.tgz filter=lfs diff=lfs merge=lfs -text
+ *.wasm filter=lfs diff=lfs merge=lfs -text
+ *.xz filter=lfs diff=lfs merge=lfs -text
+ *.zip filter=lfs diff=lfs merge=lfs -text
+ *.zst filter=lfs diff=lfs merge=lfs -text
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,4 @@
+ __pycache__/
+ WIP/
+ concept_images/
+ output_model/
README.md ADDED
@@ -0,0 +1,22 @@
+ ---
+ title: Sdpipe Webui
+ emoji: 🍌
+ colorFrom: gray
+ colorTo: green
+ sdk: gradio
+ sdk_version: 3.16.2
+ app_file: app.py
+ pinned: false
+ license: openrail
+ duplicated_from: lint/sdpipe_webui
+ ---
+
+ # **Stable Diffusion Pipeline Web UI**
+
+ A Stable Diffusion web UI with first-class support for Hugging Face Diffusers pipelines and diffusion schedulers, made in the style of Automatic1111's WebUI and Evel_Space.
+
+ Supports the Hugging Face `Text-to-Image`, `Image-to-Image`, and `Inpainting` pipelines, with fast switching between pipeline modes by reusing model weights already loaded in memory.
+
+ Install requirements with `pip install -r requirements.txt`
+
+ Run with `python app.py`
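
The "fast switching" described in the README comes from rebuilding a pipeline of a different class out of components that are already in memory instead of reloading weights from disk (see `load_pipe` in `utils/functions.py` further below). A minimal sketch of that idea, using the `components` property exposed by Diffusers pipelines and one of the model IDs listed in `model_ids.txt`; this is an illustrative example, not code from the repo:

```python
from diffusers import StableDiffusionPipeline, StableDiffusionImg2ImgPipeline

# Load the text-to-image weights once (model ID taken from model_ids.txt).
pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")

# `pipe.components` holds the already-loaded UNet, VAE, text encoder,
# tokenizer, scheduler, etc. Constructing a different pipeline class from
# those same components switches modes without re-reading weights from disk.
img2img = StableDiffusionImg2ImgPipeline(**pipe.components)
```

This is the same mechanism `load_pipe` uses via `pipe_class(**pipe.components)` when the requested model ID has not changed.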
app.py ADDED
@@ -0,0 +1,225 @@
1
+ import gradio as gr
2
+ from multiprocessing import cpu_count
3
+ from utils.functions import generate, train_textual_inversion
4
+ from utils.shared import model_ids, scheduler_names, default_scheduler
5
+
6
+ default_img_size = 512
7
+
8
+ with open("html/header.html") as fp:
9
+ header = fp.read()
10
+
11
+ with open("html/footer.html") as fp:
12
+ footer = fp.read()
13
+
14
+ with gr.Blocks(css="html/style.css") as demo:
15
+
16
+ pipe_state = gr.State(lambda: 1)
17
+
18
+ gr.HTML(header)
19
+
20
+ with gr.Row():
21
+
22
+ with gr.Column(scale=70):
23
+
24
+ # with gr.Row():
25
+ prompt = gr.Textbox(
26
+ label="Prompt", placeholder="<Shift+Enter> to generate", lines=2
27
+ )
28
+ neg_prompt = gr.Textbox(label="Negative Prompt", placeholder="", lines=2)
29
+
30
+ with gr.Column(scale=30):
31
+ model_name = gr.Dropdown(
32
+ label="Model", choices=model_ids, value=model_ids[0]
33
+ )
34
+ scheduler_name = gr.Dropdown(
35
+ label="Scheduler", choices=scheduler_names, value=default_scheduler
36
+ )
37
+ generate_button = gr.Button(value="Generate", elem_id="generate-button")
38
+
39
+ with gr.Row():
40
+
41
+ with gr.Column():
42
+
43
+ with gr.Tab("Text to Image") as tab:
44
+ tab.select(lambda: 1, [], pipe_state)
45
+
46
+ with gr.Tab("Image to Image") as tab:
47
+ tab.select(lambda: 2, [], pipe_state)
48
+
49
+ image = gr.Image(
50
+ label="Image to Image",
51
+ source="upload",
52
+ tool="editor",
53
+ type="pil",
54
+ elem_id="image_upload",
55
+ ).style(height=default_img_size)
56
+ strength = gr.Slider(
57
+ label="Denoising strength",
58
+ minimum=0,
59
+ maximum=1,
60
+ step=0.02,
61
+ value=0.8,
62
+ )
63
+
64
+ with gr.Tab("Inpainting") as tab:
65
+ tab.select(lambda: 3, [], pipe_state)
66
+
67
+ inpaint_image = gr.Image(
68
+ label="Inpainting",
69
+ source="upload",
70
+ tool="sketch",
71
+ type="pil",
72
+ elem_id="image_upload",
73
+ ).style(height=default_img_size)
74
+ inpaint_strength = gr.Slider(
75
+ label="Denoising strength",
76
+ minimum=0,
77
+ maximum=1,
78
+ step=0.02,
79
+ value=0.8,
80
+ )
81
+ inpaint_options = [
82
+ "preserve non-masked portions of image",
83
+ "output entire inpainted image",
84
+ ]
85
+ inpaint_radio = gr.Radio(
86
+ inpaint_options,
87
+ value=inpaint_options[0],
88
+ show_label=False,
89
+ interactive=True,
90
+ )
91
+
92
+ with gr.Tab("Textual Inversion") as tab:
93
+ tab.select(lambda: 4, [], pipe_state)
94
+
95
+ type_of_thing = gr.Dropdown(
96
+ label="What would you like to train?",
97
+ choices=["object", "person", "style"],
98
+ value="object",
99
+ interactive=True,
100
+ )
101
+
102
+ text_train_bsz = gr.Slider(
103
+ label="Training Batch Size",
104
+ minimum=1,
105
+ maximum=8,
106
+ step=1,
107
+ value=1,
108
+ )
109
+
110
+ files = gr.File(
111
+ label=f"""Upload the images for your concept""",
112
+ file_count="multiple",
113
+ interactive=True,
114
+ visible=True,
115
+ )
116
+
117
+ text_train_steps = gr.Number(label="How many steps", value=1000)
118
+
119
+ text_learning_rate = gr.Number(label="Learning Rate", value=5.0e-4)
120
+
121
+ concept_word = gr.Textbox(
122
+ label=f"""concept word - use a unique, made up word to avoid collisions"""
123
+ )
124
+ init_word = gr.Textbox(
125
+ label=f"""initial word - to init the concept embedding"""
126
+ )
127
+
128
+ textual_inversion_button = gr.Button(value="Train Textual Inversion")
129
+
130
+ training_status = gr.Text(label="Training Status")
131
+
132
+ with gr.Row():
133
+ batch_size = gr.Slider(
134
+ label="Batch Size", value=1, minimum=1, maximum=8, step=1
135
+ )
136
+ seed = gr.Slider(-1, 2147483647, label="Seed", value=-1, step=1)
137
+
138
+ with gr.Row():
139
+ guidance = gr.Slider(
140
+ label="Guidance scale", value=7.5, minimum=0, maximum=20
141
+ )
142
+ steps = gr.Slider(
143
+ label="Steps", value=20, minimum=1, maximum=100, step=1
144
+ )
145
+
146
+ with gr.Row():
147
+ width = gr.Slider(
148
+ label="Width",
149
+ value=default_img_size,
150
+ minimum=64,
151
+ maximum=1024,
152
+ step=32,
153
+ )
154
+ height = gr.Slider(
155
+ label="Height",
156
+ value=default_img_size,
157
+ minimum=64,
158
+ maximum=1024,
159
+ step=32,
160
+ )
161
+
162
+ with gr.Column():
163
+ gallery = gr.Gallery(
164
+ label="Generated images", show_label=False, elem_id="gallery"
165
+ ).style(height=default_img_size, grid=2)
166
+
167
+ generation_details = gr.Markdown()
168
+
169
+ pipe_kwargs = gr.Textbox(label="Pipe kwargs", value="{\n\t\n}")
170
+
171
+ # if torch.cuda.is_available():
172
+ # giga = 2**30
173
+ # vram_guage = gr.Slider(0, torch.cuda.memory_reserved(0)/giga, label='VRAM Allocated to Reserved (GB)', value=0, step=1)
174
+ # demo.load(lambda : torch.cuda.memory_allocated(0)/giga, inputs=[], outputs=vram_guage, every=0.5, show_progress=False)
175
+
176
+ gr.HTML(footer)
177
+
178
+ inputs = [
179
+ model_name,
180
+ scheduler_name,
181
+ prompt,
182
+ guidance,
183
+ steps,
184
+ batch_size,
185
+ width,
186
+ height,
187
+ seed,
188
+ image,
189
+ strength,
190
+ inpaint_image,
191
+ inpaint_strength,
192
+ inpaint_radio,
193
+ neg_prompt,
194
+ pipe_state,
195
+ pipe_kwargs,
196
+ ]
197
+ outputs = [gallery, generation_details]
198
+
199
+ prompt.submit(generate, inputs=inputs, outputs=outputs)
200
+ generate_button.click(generate, inputs=inputs, outputs=outputs)
201
+
202
+ textual_inversion_inputs = [
203
+ model_name,
204
+ scheduler_name,
205
+ type_of_thing,
206
+ files,
207
+ concept_word,
208
+ init_word,
209
+ text_train_steps,
210
+ text_train_bsz,
211
+ text_learning_rate,
212
+ ]
213
+
214
+ textual_inversion_button.click(
215
+ train_textual_inversion,
216
+ inputs=textual_inversion_inputs,
217
+ outputs=[training_status],
218
+ )
219
+
220
+
221
+ # demo = gr.TabbedInterface([demo, dreambooth_tab], ["Main", "Dreambooth"])
222
+
223
+ demo.queue(concurrency_count=cpu_count())
224
+
225
+ demo.launch()
html/footer.html ADDED
@@ -0,0 +1,15 @@
+
+ <!-- based on https://huggingface.co/spaces/stabilityai/stable-diffusion/blob/main/app.py -->
+
+
+ <div class="footer">
+ <p>Model Architecture by <a href="https://huggingface.co/stabilityai" style="text-decoration: underline;" target="_blank">StabilityAI</a> - Pipelines by 🤗 Hugging Face
+ </p>
+ </div>
+ <div class="acknowledgments">
+ <p><h4>LICENSE</h4>
+ The model is licensed with a <a href="https://huggingface.co/stabilityai/stable-diffusion-2/blob/main/LICENSE-MODEL" style="text-decoration: underline;" target="_blank">CreativeML OpenRAIL++</a> license. The authors claim no rights on the outputs you generate; you are free to use them, but you are accountable for their use, which must not go against the provisions set in this license. The license forbids you from sharing any content that violates any laws, produces harm to a person, disseminates personal information intended to cause harm, spreads misinformation, or targets vulnerable groups. For the full list of restrictions please <a href="https://huggingface.co/spaces/CompVis/stable-diffusion-license" target="_blank" style="text-decoration: underline;">read the license</a></p>
+ <p><h4>Biases and content acknowledgment</h4>
+ Despite how impressive it is to turn text into images, be aware that this model may output content that reinforces or exacerbates societal biases, as well as realistic faces, pornography and violence. The model was trained on the <a href="https://laion.ai/blog/laion-5b/" style="text-decoration: underline;" target="_blank">LAION-5B dataset</a>, which scraped non-curated image-text pairs from the internet (with the exception of illegal content, which was removed) and is meant for research purposes. You can read more in the <a href="https://huggingface.co/CompVis/stable-diffusion-v1-4" style="text-decoration: underline;" target="_blank">model card</a>.</p>
+ </div>
+
html/header.html ADDED
@@ -0,0 +1,25 @@
+
+ <!-- based on https://huggingface.co/spaces/stabilityai/stable-diffusion/blob/main/app.py -->
+
+ <div style="text-align: center; margin: 0 auto;">
+ <div
+ style="
+ display: inline-flex;
+ align-items: center;
+ gap: 0.8rem;
+ font-size: 1.75rem;
+ "
+ >
+ <svg version="1.1" id="Layer_1" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" x="0px" y="0px" viewBox="0 0 32 32" style="enable-background:new 0 0 512 512;" xml:space="preserve" width="32" height="32"><path style="fill:#FCD577;" d="M29.545 29.791V2.21c-1.22 0 -2.21 -0.99 -2.21 -2.21H4.665c0 1.22 -0.99 2.21 -2.21 2.21v27.581c1.22 0 2.21 0.99 2.21 2.21H27.335C27.335 30.779 28.325 29.791 29.545 29.791z"/><path x="98.205" y="58.928" style="fill:#99B6C6;" width="315.577" height="394.144" d="M6.138 3.683H25.861V28.317H6.138V3.683z"/><path x="98.205" y="58.928" style="fill:#7BD4EF;" width="315.577" height="131.317" d="M6.138 3.683H25.861V11.89H6.138V3.683z"/><g><path style="fill:#7190A5;" d="M14.498 10.274c0 1.446 0.983 1.155 1.953 1.502l0.504 5.317c0 0 -5.599 0.989 -6.026 2.007l0.27 -2.526c0.924 -1.462 1.286 -4.864 1.419 -6.809l0.086 0.006C12.697 9.876 14.498 10.166 14.498 10.274z"/><path style="fill:#7190A5;" d="M21.96 17.647c0 0 -0.707 1.458 -1.716 1.903c0 0 -1.502 -0.827 -1.502 -0.827c-2.276 -1.557 -2.366 -8.3 -2.366 -8.3c0 -1.718 -0.185 -1.615 -1.429 -1.615c-1.167 0 -2.127 -0.606 -2.242 0.963l-0.086 -0.006c0.059 -0.859 0.074 -1.433 0.074 -1.433c0 -1.718 1.449 -3.11 3.237 -3.11s3.237 1.392 3.237 3.11C19.168 8.332 19.334 15.617 21.96 17.647z"/></g><path style="fill:#6C8793;" d="M12.248 24.739c1.538 0.711 3.256 1.591 3.922 2.258c-1.374 0.354 -2.704 0.798 -3.513 1.32h-2.156c-1.096 -0.606 -2.011 -1.472 -2.501 -2.702c-1.953 -4.907 2.905 -8.664 2.905 -8.664c0.001 -0.001 0.002 -0.002 0.003 -0.003c0.213 -0.214 0.523 -0.301 0.811 -0.21l0.02 0.006c-0.142 0.337 -0.03 0.71 0.517 1.108c1.264 0.919 3.091 1.131 4.416 1.143c-1.755 1.338 -3.42 3.333 -4.367 5.618L12.248 24.739z"/><path style="fill:#577484;" d="M16.17 26.997c-0.666 -0.666 -2.385 -1.548 -3.922 -2.258l0.059 -0.126c0.947 -2.284 2.612 -4.28 4.367 -5.618c0.001 0 0.001 0 0.001 0c0.688 -0.525 1.391 -0.948 2.068 -1.247c0.001 0 0.001 0 0.001 0c1.009 -0.446 1.964 -0.617 2.742 -0.44c0.61 0.138 1.109 0.492 1.439 1.095c1.752 3.205 0.601 9.913 0.601 9.913H12.657C13.466 27.796 14.796 27.352 16.17 26.997z"/><path style="fill:#F7DEB0;" d="M14.38 13.1c-0.971 -0.347 -1.687 -1.564 -1.687 -3.01c0 -0.107 0.004 -0.213 0.011 -0.318c0.116 -1.569 1.075 -2.792 2.242 -2.792c1.244 0 2.253 1.392 2.253 3.11c0 0 -0.735 6.103 1.542 7.66c-0.677 0.299 -1.38 0.722 -2.068 1.247c0 0 0 0 -0.001 0c-1.326 -0.012 -3.152 -0.223 -4.416 -1.143c-0.547 -0.398 -0.659 -0.771 -0.517 -1.108c0.426 -1.018 3.171 -1.697 3.171 -1.697L14.38 13.1z"/><path style="fill:#E5CA9E;" d="M14.38 13.1c0 0 1.019 0.216 1.544 -0.309c0 0 -0.401 1.04 -1.346 1.04"/><g><path style="fill:#EAC36E;" points="437.361,0 413.79,58.926 472.717,35.356 " d="M27.335 0L25.862 3.683L29.545 2.21"/><path style="fill:#EAC36E;" points="437.361,512 413.79,453.074 472.717,476.644 " d="M27.335 32L25.862 28.317L29.545 29.791"/><path style="fill:#EAC36E;" points="74.639,512 98.21,453.074 39.283,476.644 " d="M4.665 32L6.138 28.317L2.455 29.791"/><path style="fill:#EAC36E;" points="39.283,35.356 98.21,58.926 74.639,0 " d="M2.455 2.21L6.138 3.683L4.665 0"/><path style="fill:#EAC36E;" d="M26.425 28.881H5.574V3.119h20.851v25.761H26.425zM6.702 27.754h18.597V4.246H6.702V27.754z"/></g><g><path style="fill:#486572;" d="M12.758 21.613c-0.659 0.767 -1.245 1.613 -1.722 2.531l0.486 0.202C11.82 23.401 12.241 22.483 12.758 21.613z"/><path style="fill:#486572;" d="M21.541 25.576l-0.37 0.068c-0.553 0.101 -1.097 0.212 -1.641 0.331l-0.071 -0.201l-0.059 -0.167c-0.019 -0.056 -0.035 -0.112 -0.052 -0.169l-0.104 -0.338l-0.088 
-0.342c-0.112 -0.457 -0.197 -0.922 -0.235 -1.393c-0.035 -0.47 -0.032 -0.947 0.042 -1.417c0.072 -0.47 0.205 -0.935 0.422 -1.369c-0.272 0.402 -0.469 0.856 -0.606 1.329c-0.138 0.473 -0.207 0.967 -0.234 1.462c-0.024 0.496 0.002 0.993 0.057 1.487l0.046 0.37l0.063 0.367c0.011 0.061 0.02 0.123 0.033 0.184l0.039 0.182l0.037 0.174c-0.677 0.157 -1.351 0.327 -2.019 0.514c-0.131 0.037 -0.262 0.075 -0.392 0.114l0.004 -0.004c-0.117 -0.095 -0.232 -0.197 -0.35 -0.275c-0.059 -0.041 -0.117 -0.084 -0.177 -0.122l-0.179 -0.112c-0.239 -0.147 -0.482 -0.279 -0.727 -0.406c-0.489 -0.252 -0.985 -0.479 -1.484 -0.697c-0.998 -0.433 -2.01 -0.825 -3.026 -1.196c0.973 0.475 1.937 0.969 2.876 1.499c0.469 0.266 0.932 0.539 1.379 0.832c0.223 0.146 0.442 0.297 0.648 0.456l0.154 0.119c0.05 0.041 0.097 0.083 0.145 0.124c0.002 0.002 0.004 0.003 0.005 0.005c-0.339 0.109 -0.675 0.224 -1.009 0.349c-0.349 0.132 -0.696 0.273 -1.034 0.431c-0.338 0.159 -0.668 0.337 -0.973 0.549c0.322 -0.186 0.662 -0.334 1.01 -0.463c0.347 -0.129 0.701 -0.239 1.056 -0.34c0.394 -0.111 0.79 -0.208 1.19 -0.297c0.006 0.006 0.013 0.013 0.019 0.019l0.03 -0.03c0.306 -0.068 0.614 -0.132 0.922 -0.192c0.727 -0.14 1.457 -0.258 2.189 -0.362c0.731 -0.103 1.469 -0.195 2.197 -0.265l0.374 -0.036L21.541 25.576z"/></g></svg>
+
+ <h1 style="font-weight: 1000; margin-bottom: 8px;margin-top:8px">
+ Stable Diffusion Pipeline UI
+ </h1>
+ </div>
+ <p style="margin-bottom: 4px; font-size: 100%; line-height: 24px;">
+ A Stable Diffusion web UI with first-class support for Hugging Face Diffusers pipelines and diffusion schedulers, made in the style of <a style="text-decoration: underline;" href="https://github.com/AUTOMATIC1111/stable-diffusion-webui">Automatic1111's WebUI</a> and <a style="text-decoration: underline;" href="https://huggingface.co/spaces/Evel/Evel_Space">Evel_Space</a>.
+ </p>
+ <p>Supports Text-to-Image, Image-to-Image, and Inpainting modes, with fast switching between pipeline modes by reusing model weights already loaded in memory.
+ </p>
+ </div>
+
html/style.css ADDED
@@ -0,0 +1,38 @@
+
+ #image_upload{min-height: 512px}
+ #image_upload [data-testid="image"], #image_upload [data-testid="image"] > div{min-height: 512px}
+ #image_upload .touch-none{display: flex}
+
+
+ #generate-button {
+ color:white;
+ border-color: orangered;
+ background: orange;
+ height: 45px;
+ }
+
+ .footer {
+ margin-bottom: 45px;
+ margin-top: 35px;
+ text-align: center;
+ border-bottom: 1px solid #e5e5e5;
+ }
+ .footer>p {
+ font-size: .8rem;
+ display: inline-block;
+ padding: 0 10px;
+ transform: translateY(10px);
+ background: white;
+ }
+ .dark .footer {
+ border-color: #303030;
+ }
+ .dark .footer>p {
+ background: #0b0f19;
+ }
+ .acknowledgments h4{
+ margin: 1.25em 0 .25em 0;
+ font-weight: bold;
+ font-size: 115%;
+ }
+
model_ids.txt ADDED
@@ -0,0 +1,6 @@
+ andite/anything-v4.0
+ hakurei/waifu-diffusion
+ prompthero/openjourney-v2
+ runwayml/stable-diffusion-v1-5
+ johnslegers/epic-diffusion
+ stabilityai/stable-diffusion-2-1
requirements.txt ADDED
@@ -0,0 +1,12 @@
+ accelerate==0.15.0
+ datasets==2.3.2
+ diffusers==0.11.1
+ gradio==3.16.2
+ huggingface_hub==0.11.1
+ numpy==1.23.3
+ packaging==23.0
+ Pillow==9.4.0
+ torch
+ torchvision
+ tqdm==4.64.0
+ transformers==4.25.1
test.ipynb ADDED
@@ -0,0 +1,73 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 3,
6
+ "metadata": {},
7
+ "outputs": [],
8
+ "source": [
9
+ "with open('model_ids.txt', 'r') as fp:\n",
10
+ " model_ids = fp.read().splitlines() "
11
+ ]
12
+ },
13
+ {
14
+ "cell_type": "code",
15
+ "execution_count": 4,
16
+ "metadata": {},
17
+ "outputs": [
18
+ {
19
+ "data": {
20
+ "text/plain": [
21
+ "['andite/anything-v4.0',\n",
22
+ " 'hakurei/waifu-diffusion',\n",
23
+ " 'prompthero/openjourney-v2',\n",
24
+ " 'runwayml/stable-diffusion-v1-5',\n",
25
+ " 'johnslegers/epic-diffusion',\n",
26
+ " 'stabilityai/stable-diffusion-2-1']"
27
+ ]
28
+ },
29
+ "execution_count": 4,
30
+ "metadata": {},
31
+ "output_type": "execute_result"
32
+ }
33
+ ],
34
+ "source": [
35
+ "model_ids"
36
+ ]
37
+ },
38
+ {
39
+ "cell_type": "code",
40
+ "execution_count": null,
41
+ "metadata": {},
42
+ "outputs": [],
43
+ "source": []
44
+ }
45
+ ],
46
+ "metadata": {
47
+ "kernelspec": {
48
+ "display_name": "ml",
49
+ "language": "python",
50
+ "name": "python3"
51
+ },
52
+ "language_info": {
53
+ "codemirror_mode": {
54
+ "name": "ipython",
55
+ "version": 3
56
+ },
57
+ "file_extension": ".py",
58
+ "mimetype": "text/x-python",
59
+ "name": "python",
60
+ "nbconvert_exporter": "python",
61
+ "pygments_lexer": "ipython3",
62
+ "version": "3.10.8"
63
+ },
64
+ "orig_nbformat": 4,
65
+ "vscode": {
66
+ "interpreter": {
67
+ "hash": "cbbcdde725e9a65f1cb734ac4223fed46e03daf1eb62d8ccb3c48face3871521"
68
+ }
69
+ }
70
+ },
71
+ "nbformat": 4,
72
+ "nbformat_minor": 2
73
+ }
utils/__init__.py ADDED
File without changes
utils/functions.py ADDED
@@ -0,0 +1,273 @@
1
+ import gradio as gr
2
+ import torch
3
+ import random
4
+ from PIL import Image
5
+ import os
6
+ import argparse
7
+ import shutil
8
+ import gc
9
+ import importlib
10
+ import json
11
+
12
+ from diffusers import (
13
+ StableDiffusionPipeline,
14
+ StableDiffusionImg2ImgPipeline,
15
+ )
16
+
17
+
18
+ from .inpaint_pipeline import SDInpaintPipeline as StableDiffusionInpaintPipelineLegacy
19
+
20
+ from .textual_inversion import main as run_textual_inversion
21
+ from .shared import default_scheduler, scheduler_dict, model_ids
22
+
23
+
24
+ _xformers_available = importlib.util.find_spec("xformers") is not None
25
+ device = "cuda" if torch.cuda.is_available() else "cpu"
26
+ # device = 'cpu'
27
+ dtype = torch.float16 if device == "cuda" else torch.float32
28
+ low_vram_mode = False
29
+
30
+
31
+ tab_to_pipeline = {
32
+ 1: StableDiffusionPipeline,
33
+ 2: StableDiffusionImg2ImgPipeline,
34
+ 3: StableDiffusionInpaintPipelineLegacy,
35
+ }
36
+
37
+
38
+ def load_pipe(model_id, scheduler_name, tab_index=1, pipe_kwargs="{}"):
39
+ global pipe, loaded_model_id
40
+
41
+ scheduler = scheduler_dict[scheduler_name]
42
+
43
+ pipe_class = tab_to_pipeline[tab_index]
44
+
45
+ # load new weights from disk only when changing model_id
46
+ if model_id != loaded_model_id:
47
+ pipe = pipe_class.from_pretrained(
48
+ model_id,
49
+ torch_dtype=dtype,
50
+ safety_checker=None,
51
+ requires_safety_checker=False,
52
+ scheduler=scheduler.from_pretrained(model_id, subfolder="scheduler"),
53
+ **json.loads(pipe_kwargs),
54
+ )
55
+ loaded_model_id = model_id
56
+
57
+ # if same model_id, instantiate new pipeline with same underlying pytorch objects to avoid reloading weights from disk
58
+ elif pipe_class != pipe.__class__ or not isinstance(pipe.scheduler, scheduler):
59
+ pipe.components["scheduler"] = scheduler.from_pretrained(
60
+ model_id, subfolder="scheduler"
61
+ )
62
+ pipe = pipe_class(**pipe.components)
63
+
64
+ if device == "cuda":
65
+ pipe = pipe.to(device)
66
+ if _xformers_available:
67
+ pipe.enable_xformers_memory_efficient_attention()
68
+ print("using xformers")
69
+ if low_vram_mode:
70
+ pipe.enable_attention_slicing()
71
+ print("using attention slicing to lower VRAM")
72
+
73
+ return pipe
74
+
75
+
76
+ pipe = None
77
+ loaded_model_id = ""
78
+ pipe = load_pipe(model_ids[0], default_scheduler)
79
+
80
+
81
+ def pad_image(image):
82
+ w, h = image.size
83
+ if w == h:
84
+ return image
85
+ elif w > h:
86
+ new_image = Image.new(image.mode, (w, w), (0, 0, 0))
87
+ new_image.paste(image, (0, (w - h) // 2))
88
+ return new_image
89
+ else:
90
+ new_image = Image.new(image.mode, (h, h), (0, 0, 0))
91
+ new_image.paste(image, ((h - w) // 2, 0))
92
+ return new_image
93
+
94
+
95
+ @torch.no_grad()
96
+ def generate(
97
+ model_name,
98
+ scheduler_name,
99
+ prompt,
100
+ guidance,
101
+ steps,
102
+ n_images=1,
103
+ width=512,
104
+ height=512,
105
+ seed=0,
106
+ image=None,
107
+ strength=0.5,
108
+ inpaint_image=None,
109
+ inpaint_strength=0.5,
110
+ inpaint_radio="",
111
+ neg_prompt="",
112
+ tab_index=1,
113
+ pipe_kwargs="{}",
114
+ progress=gr.Progress(track_tqdm=True),
115
+ ):
116
+
117
+ if seed == -1:
118
+ seed = random.randint(0, 2147483647)
119
+
120
+ generator = torch.Generator(device).manual_seed(seed)
121
+
122
+ pipe = load_pipe(
123
+ model_id=model_name,
124
+ scheduler_name=scheduler_name,
125
+ tab_index=tab_index,
126
+ pipe_kwargs=pipe_kwargs,
127
+ )
128
+
129
+ status_message = f"Prompt: '{prompt}' | Seed: {seed} | Guidance: {guidance} | Scheduler: {scheduler_name} | Steps: {steps}"
130
+
131
+ if tab_index == 1:
132
+ status_message = "Text to Image " + status_message
133
+
134
+ result = pipe(
135
+ prompt,
136
+ negative_prompt=neg_prompt,
137
+ num_images_per_prompt=n_images,
138
+ num_inference_steps=int(steps),
139
+ guidance_scale=guidance,
140
+ width=width,
141
+ height=height,
142
+ generator=generator,
143
+ )
144
+
145
+ elif tab_index == 2:
146
+
147
+ status_message = "Image to Image " + status_message
148
+ print(image.size)
149
+ image = image.resize((width, height))
150
+ print(image.size)
151
+
152
+ result = pipe(
153
+ prompt,
154
+ negative_prompt=neg_prompt,
155
+ num_images_per_prompt=n_images,
156
+ image=image,
157
+ num_inference_steps=int(steps),
158
+ strength=strength,
159
+ guidance_scale=guidance,
160
+ generator=generator,
161
+ )
162
+
163
+ elif tab_index == 3:
164
+ status_message = "Inpainting " + status_message
165
+
166
+ init_image = inpaint_image["image"].resize((width, height))
167
+ mask = inpaint_image["mask"].resize((width, height))
168
+
169
+ result = pipe(
170
+ prompt,
171
+ negative_prompt=neg_prompt,
172
+ num_images_per_prompt=n_images,
173
+ image=init_image,
174
+ mask_image=mask,
175
+ num_inference_steps=int(steps),
176
+ strength=inpaint_strength,
177
+ preserve_unmasked_image=(
178
+ inpaint_radio == "preserve non-masked portions of image"
179
+ ),
180
+ guidance_scale=guidance,
181
+ generator=generator,
182
+ )
183
+
184
+ else:
185
+ return None, f"Unhandled tab index: {tab_index}"
186
+
187
+ return result.images, status_message
188
+
189
+
190
+ # based on lvkaokao/textual-inversion-training
191
+ def train_textual_inversion(
192
+ model_name,
193
+ scheduler_name,
194
+ type_of_thing,
195
+ files,
196
+ concept_word,
197
+ init_word,
198
+ text_train_steps,
199
+ text_train_bsz,
200
+ text_learning_rate,
201
+ progress=gr.Progress(track_tqdm=True),
202
+ ):
203
+
204
+ if device == "cpu":
205
+ raise gr.Error("Textual inversion training not supported on CPU")
206
+
207
+ pipe = load_pipe(
208
+ model_id=model_name,
209
+ scheduler_name=scheduler_name,
210
+ tab_index=1,
211
+ )
212
+
213
+ pipe.disable_xformers_memory_efficient_attention() # xformers handled by textual inversion script
214
+
215
+ concept_dir = "concept_images"
216
+ output_dir = "output_model"
217
+ training_resolution = 512
218
+
219
+ if os.path.exists(output_dir):
220
+ shutil.rmtree("output_model")
221
+ if os.path.exists(concept_dir):
222
+ shutil.rmtree("concept_images")
223
+
224
+ os.makedirs(concept_dir, exist_ok=True)
225
+ os.makedirs(output_dir, exist_ok=True)
226
+
227
+ gc.collect()
228
+ torch.cuda.empty_cache()
229
+
230
+ if concept_word == "" or concept_word == None:
231
+ raise gr.Error("You forgot to define your concept prompt")
232
+
233
+ for j, file_temp in enumerate(files):
234
+ file = Image.open(file_temp.name)
235
+ image = pad_image(file)
236
+ image = image.resize((training_resolution, training_resolution))
237
+ extension = file_temp.name.split(".")[-1]  # use the last segment so names containing extra dots keep the correct extension
238
+ image = image.convert("RGB")
239
+ image.save(f"{concept_dir}/{j+1}.{extension}", quality=100)
240
+
241
+ args_general = argparse.Namespace(
242
+ train_data_dir=concept_dir,
243
+ learnable_property=type_of_thing,
244
+ placeholder_token=concept_word,
245
+ initializer_token=init_word,
246
+ resolution=training_resolution,
247
+ train_batch_size=text_train_bsz,
248
+ gradient_accumulation_steps=1,
249
+ gradient_checkpointing=True,
250
+ mixed_precision="fp16",
251
+ use_bf16=False,
252
+ max_train_steps=int(text_train_steps),
253
+ learning_rate=text_learning_rate,
254
+ scale_lr=True,
255
+ lr_scheduler="constant",
256
+ lr_warmup_steps=0,
257
+ output_dir=output_dir,
258
+ )
259
+
260
+ try:
261
+ final_result = run_textual_inversion(pipe, args_general)
262
+ except Exception as e:
263
+ raise gr.Error(e)
264
+
265
+ pipe.text_encoder = pipe.text_encoder.eval().to(device, dtype=dtype)
266
+ pipe.unet = pipe.unet.eval().to(device, dtype=dtype)
267
+
268
+ gc.collect()
269
+ torch.cuda.empty_cache()
270
+
271
+ return (
272
+ f"Finished training! Check the {output_dir} directory for saved model weights"
273
+ )
utils/inpaint_pipeline.py ADDED
@@ -0,0 +1,288 @@
1
+ # Licensed under the Apache License, Version 2.0 (the "License");
2
+ # you may not use this file except in compliance with the License.
3
+ # You may obtain a copy of the License at
4
+ #
5
+ # http://www.apache.org/licenses/LICENSE-2.0
6
+ #
7
+ # Unless required by applicable law or agreed to in writing, software
8
+ # distributed under the License is distributed on an "AS IS" BASIS,
9
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
10
+ # See the License for the specific language governing permissions and
11
+ # limitations under the License.
12
+
13
+ import torch
14
+ from typing import Optional, Union, List, Callable
15
+ import PIL
16
+ import numpy as np
17
+
18
+ from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_inpaint_legacy import (
19
+ preprocess_image,
20
+ deprecate,
21
+ StableDiffusionInpaintPipelineLegacy,
22
+ StableDiffusionPipelineOutput,
23
+ PIL_INTERPOLATION,
24
+ )
25
+
26
+
27
+ def preprocess_mask(mask, scale_factor=8):
28
+ mask = mask.convert("L")
29
+ w, h = mask.size
30
+ w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32
31
+
32
+ # input_mask = mask.resize((w, h), resample=PIL_INTERPOLATION["nearest"])
33
+ input_mask = np.array(mask).astype(np.float32) / 255.0
34
+ input_mask = np.tile(input_mask, (3, 1, 1))
35
+ input_mask = input_mask[None].transpose(0, 1, 2, 3) # add batch dimension
36
+ input_mask = 1 - input_mask # repaint white, keep black
37
+ input_mask = torch.round(torch.from_numpy(input_mask))
38
+
39
+ mask = mask.resize(
40
+ (w // scale_factor, h // scale_factor), resample=PIL_INTERPOLATION["nearest"]
41
+ )
42
+ mask = np.array(mask).astype(np.float32) / 255.0
43
+ mask = np.tile(mask, (4, 1, 1))
44
+ mask = mask[None].transpose(0, 1, 2, 3) # add batch dimension
45
+ mask = 1 - mask # repaint white, keep black
46
+ mask = torch.round(torch.from_numpy(mask))
47
+
48
+ return mask, input_mask
49
+
50
+
51
+ class SDInpaintPipeline(StableDiffusionInpaintPipelineLegacy):
52
+
53
+ # forward call is same as StableDiffusionInpaintPipelineLegacy, but with line added to avoid noise added to final latents right before decoding step
54
+ @torch.no_grad()
55
+ def __call__(
56
+ self,
57
+ prompt: Union[str, List[str]],
58
+ image: Union[torch.FloatTensor, PIL.Image.Image] = None,
59
+ mask_image: Union[torch.FloatTensor, PIL.Image.Image] = None,
60
+ strength: float = 0.8,
61
+ num_inference_steps: Optional[int] = 50,
62
+ guidance_scale: Optional[float] = 7.5,
63
+ negative_prompt: Optional[Union[str, List[str]]] = None,
64
+ num_images_per_prompt: Optional[int] = 1,
65
+ add_predicted_noise: Optional[bool] = False,
66
+ eta: Optional[float] = 0.0,
67
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
68
+ output_type: Optional[str] = "pil",
69
+ return_dict: bool = True,
70
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
71
+ callback_steps: Optional[int] = 1,
72
+ preserve_unmasked_image: bool = True,
73
+ **kwargs,
74
+ ):
75
+ r"""
76
+ Function invoked when calling the pipeline for generation.
77
+
78
+ Args:
79
+ prompt (`str` or `List[str]`):
80
+ The prompt or prompts to guide the image generation.
81
+ image (`torch.FloatTensor` or `PIL.Image.Image`):
82
+ `Image`, or tensor representing an image batch, that will be used as the starting point for the
83
+ process. This is the image whose masked region will be inpainted.
84
+ mask_image (`torch.FloatTensor` or `PIL.Image.Image`):
85
+ `Image`, or tensor representing an image batch, to mask `image`. White pixels in the mask will be
86
+ replaced by noise and therefore repainted, while black pixels will be preserved. If `mask_image` is a
87
+ PIL image, it will be converted to a single channel (luminance) before use. If it's a tensor, it should
88
+ contain one color channel (L) instead of 3, so the expected shape would be `(B, H, W, 1)`.
89
+ strength (`float`, *optional*, defaults to 0.8):
90
+ Conceptually, indicates how much to inpaint the masked area. Must be between 0 and 1. When `strength`
91
+ is 1, the denoising process will be run on the masked area for the full number of iterations specified
92
+ in `num_inference_steps`. `image` will be used as a reference for the masked area, adding more noise to
93
+ that region the larger the `strength`. If `strength` is 0, no inpainting will occur.
94
+ num_inference_steps (`int`, *optional*, defaults to 50):
95
+ The reference number of denoising steps. More denoising steps usually lead to a higher quality image at
96
+ the expense of slower inference. This parameter will be modulated by `strength`, as explained above.
97
+ guidance_scale (`float`, *optional*, defaults to 7.5):
98
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
99
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
100
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
101
+ 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
102
+ usually at the expense of lower image quality.
103
+ negative_prompt (`str` or `List[str]`, *optional*):
104
+ The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
105
+ if `guidance_scale` is less than `1`).
106
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
107
+ The number of images to generate per prompt.
108
+ add_predicted_noise (`bool`, *optional*, defaults to False):
109
+ Use predicted noise instead of random noise when constructing noisy versions of the original image in
110
+ the reverse diffusion process
111
+ eta (`float`, *optional*, defaults to 0.0):
112
+ Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
113
+ [`schedulers.DDIMScheduler`], will be ignored for others.
114
+ generator (`torch.Generator`, *optional*):
115
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
116
+ to make generation deterministic.
117
+ output_type (`str`, *optional*, defaults to `"pil"`):
118
+ The output format of the generate image. Choose between
119
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
120
+ return_dict (`bool`, *optional*, defaults to `True`):
121
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
122
+ plain tuple.
123
+ callback (`Callable`, *optional*):
124
+ A function that will be called every `callback_steps` steps during inference. The function will be
125
+ called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
126
+ callback_steps (`int`, *optional*, defaults to 1):
127
+ The frequency at which the `callback` function will be called. If not specified, the callback will be
128
+ called at every step.
129
+ preserve_unmasked_image (`bool`, *optional*, defaults to `True`):
130
+ Whether or not to preserve the unmasked portions of the original image in the inpainted output. If False,
131
+ inpainting of the masked latents may produce noticeable distortion of unmasked portions of the decoded
132
+ image.
133
+
134
+ Returns:
135
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
136
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple.
137
+ When returning a tuple, the first element is a list with the generated images, and the second element is a
138
+ list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
139
+ (nsfw) content, according to the `safety_checker`.
140
+ """
141
+ message = "Please use `image` instead of `init_image`."
142
+ init_image = deprecate("init_image", "0.13.0", message, take_from=kwargs)
143
+ image = init_image or image
144
+
145
+ # 1. Check inputs
146
+ self.check_inputs(prompt, strength, callback_steps)
147
+
148
+ # 2. Define call parameters
149
+ batch_size = 1 if isinstance(prompt, str) else len(prompt)
150
+ device = self._execution_device
151
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
152
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
153
+ # corresponds to doing no classifier free guidance.
154
+ do_classifier_free_guidance = guidance_scale > 1.0
155
+
156
+ # 3. Encode input prompt
157
+ text_embeddings = self._encode_prompt(
158
+ prompt,
159
+ device,
160
+ num_images_per_prompt,
161
+ do_classifier_free_guidance,
162
+ negative_prompt,
163
+ )
164
+
165
+ # 4. Preprocess image and mask
166
+ if not isinstance(image, torch.FloatTensor):
167
+ image = preprocess_image(image)
168
+
169
+ # get mask corresponding to input latents as well as image
170
+ if not isinstance(mask_image, torch.FloatTensor):
171
+ mask_image, input_mask_image = preprocess_mask(
172
+ mask_image, self.vae_scale_factor
173
+ )
174
+
175
+ # 5. set timesteps
176
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
177
+ timesteps, num_inference_steps = self.get_timesteps(
178
+ num_inference_steps, strength, device
179
+ )
180
+ latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
181
+
182
+ # 6. Prepare latent variables
183
+ # encode the init image into latents and scale the latents
184
+ latents, init_latents_orig, noise = self.prepare_latents(
185
+ image,
186
+ latent_timestep,
187
+ batch_size,
188
+ num_images_per_prompt,
189
+ text_embeddings.dtype,
190
+ device,
191
+ generator,
192
+ )
193
+
194
+ # 7. Prepare mask latent
195
+ mask = mask_image.to(device=self.device, dtype=latents.dtype)
196
+ mask = torch.cat([mask] * batch_size * num_images_per_prompt)
197
+
198
+ # 8. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
199
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
200
+
201
+ # 9. Denoising loop
202
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
203
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
204
+ for i, t in enumerate(timesteps):
205
+
206
+ # expand the latents if we are doing classifier free guidance
207
+ latent_model_input = (
208
+ torch.cat([latents] * 2) if do_classifier_free_guidance else latents
209
+ )
210
+ latent_model_input = self.scheduler.scale_model_input(
211
+ latent_model_input, t
212
+ )
213
+
214
+ # predict the noise residual
215
+ noise_pred = self.unet(
216
+ latent_model_input, t, encoder_hidden_states=text_embeddings
217
+ ).sample
218
+
219
+ # perform guidance
220
+ if do_classifier_free_guidance:
221
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
222
+ noise_pred = noise_pred_uncond + guidance_scale * (
223
+ noise_pred_text - noise_pred_uncond
224
+ )
225
+
226
+ # compute the previous noisy sample x_t -> x_t-1
227
+ latents = self.scheduler.step(
228
+ noise_pred, t, latents, **extra_step_kwargs
229
+ ).prev_sample
230
+ # masking
231
+ if add_predicted_noise:
232
+ init_latents_proper = self.scheduler.add_noise(
233
+ init_latents_orig, noise_pred_uncond, torch.tensor([t])
234
+ )
235
+ else:
236
+ init_latents_proper = self.scheduler.add_noise(
237
+ init_latents_orig, noise, torch.tensor([t])
238
+ )
239
+
240
+ latents = (init_latents_proper * mask) + (latents * (1 - mask))
241
+
242
+ # call the callback, if provided
243
+ if i == len(timesteps) - 1 or (
244
+ (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0
245
+ ):
246
+ progress_bar.update()
247
+ if callback is not None and i % callback_steps == 0:
248
+ callback(i, t, latents)
249
+
250
+ # use original latents corresponding to unmasked portions of the image
251
+ # necessary step because noise is still added to "init_latents_proper" after final denoising step
252
+ latents = (init_latents_orig * mask) + (latents * (1 - mask))
253
+
254
+ # 10. Post-processing
255
+ if preserve_unmasked_image:
256
+ # decode latents
257
+ latents = 1 / 0.18215 * latents
258
+ inpaint_image = self.vae.decode(latents).sample
259
+
260
+ # restore unmasked parts of image with original image
261
+ input_mask_image = input_mask_image.to(inpaint_image)
262
+ image = image.to(inpaint_image)
263
+ image = (image * input_mask_image) + (
264
+ inpaint_image * (1 - input_mask_image)
265
+ ) # use original unmasked portions of image to avoid degradation
266
+
267
+ # post-processing of image
268
+ image = (image / 2 + 0.5).clamp(0, 1)
269
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
270
+ image = image.cpu().permute(0, 2, 3, 1).float().numpy()
271
+ else:
272
+ image = self.decode_latents(latents)
273
+
274
+ # 11. Run safety checker
275
+ image, has_nsfw_concept = self.run_safety_checker(
276
+ image, device, text_embeddings.dtype
277
+ )
278
+
279
+ # 12. Convert to PIL
280
+ if output_type == "pil":
281
+ image = self.numpy_to_pil(image)
282
+
283
+ if not return_dict:
284
+ return (image, has_nsfw_concept)
285
+
286
+ return StableDiffusionPipelineOutput(
287
+ images=image, nsfw_content_detected=has_nsfw_concept
288
+ )
utils/shared.py ADDED
@@ -0,0 +1,16 @@
+ import diffusers.schedulers
+
+ # scheduler dict includes superclass SchedulerMixin (it still generates reasonable images)
+ scheduler_dict = {
+ k: v
+ for k, v in diffusers.schedulers.__dict__.items()
+ if "Scheduler" in k and "Flax" not in k
+ }
+ scheduler_dict.pop(
+ "VQDiffusionScheduler"
+ ) # requires unique parameter, unlike other schedulers
+ scheduler_names = list(scheduler_dict.keys())
+ default_scheduler = scheduler_names[3] # expected to be DPM Multistep
+
+ with open("model_ids.txt", "r") as fp:
+ model_ids = fp.read().splitlines()
utils/textual_inversion.py ADDED
@@ -0,0 +1,916 @@
1
+ #!/usr/bin/env python
2
+ # coding=utf-8
3
+ # Copyright 2022 The HuggingFace Inc. team. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+
16
+ import argparse
17
+ import logging
18
+ import math
19
+ import os
20
+ import random
21
+ from pathlib import Path
22
+ from typing import Optional
23
+
24
+ import numpy as np
25
+ import torch
26
+ import torch.nn.functional as F
27
+ import torch.utils.checkpoint
28
+ from torch.utils.data import Dataset
29
+
30
+ import datasets
31
+ import diffusers
32
+ import PIL
33
+ import transformers
34
+ from accelerate import Accelerator
35
+ from accelerate.logging import get_logger
36
+ from accelerate.utils import set_seed
37
+ from diffusers import (
38
+ AutoencoderKL,
39
+ DDPMScheduler,
40
+ StableDiffusionPipeline,
41
+ UNet2DConditionModel,
42
+ )
43
+ from diffusers.optimization import get_scheduler
44
+ from diffusers.utils import check_min_version
45
+ from diffusers.utils.import_utils import is_xformers_available
46
+ from huggingface_hub import HfFolder, Repository, create_repo, whoami
47
+
48
+ # TODO: remove and import from diffusers.utils when the new version of diffusers is released
49
+ from packaging import version
50
+ from PIL import Image
51
+ from torchvision import transforms
52
+ from tqdm.auto import tqdm
53
+ from transformers import CLIPTextModel, CLIPTokenizer
54
+
55
+
56
+ if version.parse(version.parse(PIL.__version__).base_version) >= version.parse("9.1.0"):
57
+ PIL_INTERPOLATION = {
58
+ "linear": PIL.Image.Resampling.BILINEAR,
59
+ "bilinear": PIL.Image.Resampling.BILINEAR,
60
+ "bicubic": PIL.Image.Resampling.BICUBIC,
61
+ "lanczos": PIL.Image.Resampling.LANCZOS,
62
+ "nearest": PIL.Image.Resampling.NEAREST,
63
+ }
64
+ else:
65
+ PIL_INTERPOLATION = {
66
+ "linear": PIL.Image.LINEAR,
67
+ "bilinear": PIL.Image.BILINEAR,
68
+ "bicubic": PIL.Image.BICUBIC,
69
+ "lanczos": PIL.Image.LANCZOS,
70
+ "nearest": PIL.Image.NEAREST,
71
+ }
72
+ # ------------------------------------------------------------------------------
73
+
74
+
75
+ # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
76
+ check_min_version("0.10.0.dev0")
77
+
78
+
79
+ logger = get_logger(__name__)
80
+
81
+
82
+ def save_progress(text_encoder, placeholder_token_id, accelerator, args, save_path):
83
+ logger.info("Saving embeddings")
84
+ learned_embeds = (
85
+ accelerator.unwrap_model(text_encoder)
86
+ .get_input_embeddings()
87
+ .weight[placeholder_token_id]
88
+ )
89
+ learned_embeds_dict = {args.placeholder_token: learned_embeds.detach().cpu()}
90
+ torch.save(learned_embeds_dict, save_path)
91
+
92
+
93
+ def parse_args():
94
+ parser = argparse.ArgumentParser(description="Simple example of a training script.")
95
+ parser.add_argument(
96
+ "--save_steps",
97
+ type=int,
98
+ default=500,
99
+ help="Save learned_embeds.bin every X updates steps.",
100
+ )
101
+ parser.add_argument(
102
+ "--only_save_embeds",
103
+ action="store_true",
104
+ default=False,
105
+ help="Save only the embeddings for the new concept.",
106
+ )
107
+ parser.add_argument(
108
+ "--pretrained_model_name_or_path",
109
+ type=str,
110
+ default=None,
111
+ help="Path to pretrained model or model identifier from huggingface.co/models.",
112
+ )
113
+ parser.add_argument(
114
+ "--revision",
115
+ type=str,
116
+ default=None,
117
+ help="Revision of pretrained model identifier from huggingface.co/models.",
118
+ )
119
+ parser.add_argument(
120
+ "--tokenizer_name",
121
+ type=str,
122
+ default=None,
123
+ help="Pretrained tokenizer name or path if not the same as model_name",
124
+ )
125
+ parser.add_argument(
126
+ "--train_data_dir",
127
+ type=str,
128
+ default=None,
129
+ help="A folder containing the training data.",
130
+ )
131
+ parser.add_argument(
132
+ "--placeholder_token",
133
+ type=str,
134
+ default=None,
135
+ help="A token to use as a placeholder for the concept.",
136
+ )
137
+ parser.add_argument(
138
+ "--initializer_token",
139
+ type=str,
140
+ default=None,
141
+ help="A token to use as initializer word.",
142
+ )
143
+
144
+ parser.add_argument(
145
+ "--learnable_property",
146
+ type=str,
147
+ default="object",
148
+ help="Choose between 'object' and 'style'",
149
+ )
150
+ parser.add_argument(
151
+ "--repeats",
152
+ type=int,
153
+ default=100,
154
+ help="How many times to repeat the training data.",
155
+ )
156
+ parser.add_argument(
157
+ "--output_dir",
158
+ type=str,
159
+ default="text-inversion-model",
160
+ help="The output directory where the model predictions and checkpoints will be written.",
161
+ )
162
+ parser.add_argument(
163
+ "--seed", type=int, default=None, help="A seed for reproducible training."
164
+ )
165
+ parser.add_argument(
166
+ "--resolution",
167
+ type=int,
168
+ default=512,
169
+ help=(
170
+ "The resolution for input images, all the images in the train/validation dataset will be resized to this"
171
+ " resolution"
172
+ ),
173
+ )
174
+ parser.add_argument(
175
+ "--center_crop",
176
+ action="store_true",
177
+ help="Whether to center crop images before resizing to resolution",
178
+ )
179
+ parser.add_argument(
180
+ "--train_batch_size",
181
+ type=int,
182
+ default=16,
183
+ help="Batch size (per device) for the training dataloader.",
184
+ )
185
+ parser.add_argument("--num_train_epochs", type=int, default=100)
186
+ parser.add_argument(
187
+ "--max_train_steps",
188
+ type=int,
189
+ default=5000,
190
+ help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
191
+ )
192
+ parser.add_argument(
193
+ "--gradient_accumulation_steps",
194
+ type=int,
195
+ default=1,
196
+ help="Number of updates steps to accumulate before performing a backward/update pass.",
197
+ )
198
+ parser.add_argument(
199
+ "--gradient_checkpointing",
200
+ action="store_true",
201
+ help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
202
+ )
203
+ parser.add_argument(
204
+ "--learning_rate",
205
+ type=float,
206
+ default=1e-4,
207
+ help="Initial learning rate (after the potential warmup period) to use.",
208
+ )
209
+ parser.add_argument(
210
+ "--scale_lr",
211
+ action="store_true",
212
+ default=False,
213
+ help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
214
+ )
215
+ parser.add_argument(
216
+ "--lr_scheduler",
217
+ type=str,
218
+ default="constant",
219
+ help=(
220
+ 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
221
+ ' "constant", "constant_with_warmup"]'
222
+ ),
223
+ )
224
+ parser.add_argument(
225
+ "--lr_warmup_steps",
226
+ type=int,
227
+ default=500,
228
+ help="Number of steps for the warmup in the lr scheduler.",
229
+ )
230
+ parser.add_argument(
231
+ "--adam_beta1",
232
+ type=float,
233
+ default=0.9,
234
+ help="The beta1 parameter for the Adam optimizer.",
235
+ )
236
+ parser.add_argument(
237
+ "--adam_beta2",
238
+ type=float,
239
+ default=0.999,
240
+ help="The beta2 parameter for the Adam optimizer.",
241
+ )
242
+ parser.add_argument(
243
+ "--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use."
244
+ )
245
+ parser.add_argument(
246
+ "--adam_epsilon",
247
+ type=float,
248
+ default=1e-08,
249
+ help="Epsilon value for the Adam optimizer",
250
+ )
251
+ parser.add_argument(
252
+ "--push_to_hub",
253
+ action="store_true",
254
+ help="Whether or not to push the model to the Hub.",
255
+ )
256
+ parser.add_argument(
257
+ "--hub_token",
258
+ type=str,
259
+ default=None,
260
+ help="The token to use to push to the Model Hub.",
261
+ )
262
+ parser.add_argument(
263
+ "--hub_model_id",
264
+ type=str,
265
+ default=None,
266
+ help="The name of the repository to keep in sync with the local `output_dir`.",
267
+ )
268
+ parser.add_argument(
269
+ "--logging_dir",
270
+ type=str,
271
+ default="logs",
272
+ help=(
273
+ "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
274
+ " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
275
+ ),
276
+ )
277
+ parser.add_argument(
278
+ "--mixed_precision",
279
+ type=str,
280
+ default="no",
281
+ choices=["no", "fp16", "bf16"],
282
+ help=(
283
+ "Whether to use mixed precision. Choose"
284
+ "between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >= 1.10."
285
+ "and an Nvidia Ampere GPU."
286
+ ),
287
+ )
288
+ parser.add_argument(
289
+ "--allow_tf32",
290
+ action="store_true",
291
+ help=(
292
+ "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
293
+ " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
294
+ ),
295
+ )
296
+ parser.add_argument(
297
+ "--report_to",
298
+ type=str,
299
+ default="tensorboard",
300
+ help=(
301
+ 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
302
+ ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
303
+ ),
304
+ )
305
+ parser.add_argument(
306
+ "--local_rank",
307
+ type=int,
308
+ default=-1,
309
+ help="For distributed training: local_rank",
310
+ )
311
+ parser.add_argument(
312
+ "--checkpointing_steps",
313
+ type=int,
314
+ default=500,
315
+ help=(
316
+ "Save a checkpoint of the training state every X updates. These checkpoints are only suitable for resuming"
317
+ " training using `--resume_from_checkpoint`."
318
+ ),
319
+ )
320
+ parser.add_argument(
321
+ "--resume_from_checkpoint",
322
+ type=str,
323
+ default=None,
324
+ help=(
325
+ "Whether training should be resumed from a previous checkpoint. Use a path saved by"
326
+ ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
327
+ ),
328
+ )
329
+ parser.add_argument(
330
+ "--enable_xformers_memory_efficient_attention",
331
+ action="store_true",
332
+ help="Whether or not to use xformers.",
333
+ )
334
+
335
+ args = parser.parse_args()
336
+ env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
337
+ if env_local_rank != -1 and env_local_rank != args.local_rank:
338
+ args.local_rank = env_local_rank
339
+
340
+ # if args.train_data_dir is None:
341
+ # raise ValueError("You must specify a train data directory.")
342
+
343
+ return args
344
+
345
+
346
+ imagenet_templates_small = [
347
+ "a photo of a {}",
348
+ "a rendering of a {}",
349
+ "a cropped photo of the {}",
350
+ "the photo of a {}",
351
+ "a photo of a clean {}",
352
+ "a photo of a dirty {}",
353
+ "a dark photo of the {}",
354
+ "a photo of my {}",
355
+ "a photo of the cool {}",
356
+ "a close-up photo of a {}",
357
+ "a bright photo of the {}",
358
+ "a cropped photo of a {}",
359
+ "a photo of the {}",
360
+ "a good photo of the {}",
361
+ "a photo of one {}",
362
+ "a close-up photo of the {}",
363
+ "a rendition of the {}",
364
+ "a photo of the clean {}",
365
+ "a rendition of a {}",
366
+ "a photo of a nice {}",
367
+ "a good photo of a {}",
368
+ "a photo of the nice {}",
369
+ "a photo of the small {}",
370
+ "a photo of the weird {}",
371
+ "a photo of the large {}",
372
+ "a photo of a cool {}",
373
+ "a photo of a small {}",
374
+ ]
375
+
376
+ imagenet_style_templates_small = [
377
+ "a painting in the style of {}",
378
+ "a rendering in the style of {}",
379
+ "a cropped painting in the style of {}",
380
+ "the painting in the style of {}",
381
+ "a clean painting in the style of {}",
382
+ "a dirty painting in the style of {}",
383
+ "a dark painting in the style of {}",
384
+ "a picture in the style of {}",
385
+ "a cool painting in the style of {}",
386
+ "a close-up painting in the style of {}",
387
+ "a bright painting in the style of {}",
388
+ "a cropped painting in the style of {}",
389
+ "a good painting in the style of {}",
390
+ "a close-up painting in the style of {}",
391
+ "a rendition in the style of {}",
392
+ "a nice painting in the style of {}",
393
+ "a small painting in the style of {}",
394
+ "a weird painting in the style of {}",
395
+ "a large painting in the style of {}",
396
+ ]
397
+
398
+
399
+ class TextualInversionDataset(Dataset):
400
+ def __init__(
401
+ self,
402
+ data_root,
403
+ tokenizer,
404
+ learnable_property="object", # [object, style]
405
+ size=512,
406
+ repeats=100,
407
+ interpolation="bicubic",
408
+ flip_p=0.5,
409
+ set="train",
410
+ placeholder_token="*",
411
+ center_crop=False,
412
+ ):
413
+ self.data_root = data_root
414
+ self.tokenizer = tokenizer
415
+ self.learnable_property = learnable_property
416
+ self.size = size
417
+ self.placeholder_token = placeholder_token
418
+ self.center_crop = center_crop
419
+ self.flip_p = flip_p
420
+
421
+ self.image_paths = [
422
+ os.path.join(self.data_root, file_path)
423
+ for file_path in os.listdir(self.data_root)
424
+ ]
425
+
426
+ self.num_images = len(self.image_paths)
427
+ self._length = self.num_images
428
+
429
+ if set == "train":
430
+ self._length = self.num_images * repeats
431
+
432
+ self.interpolation = {
433
+ "linear": PIL_INTERPOLATION["linear"],
434
+ "bilinear": PIL_INTERPOLATION["bilinear"],
435
+ "bicubic": PIL_INTERPOLATION["bicubic"],
436
+ "lanczos": PIL_INTERPOLATION["lanczos"],
437
+ }[interpolation]
438
+
439
+ self.templates = (
440
+ imagenet_style_templates_small
441
+ if learnable_property == "style"
442
+ else imagenet_templates_small
443
+ )
444
+ self.flip_transform = transforms.RandomHorizontalFlip(p=self.flip_p)
445
+
446
+ def __len__(self):
447
+ return self._length
448
+
449
+ def __getitem__(self, i):
450
+ example = {}
451
+ image = Image.open(self.image_paths[i % self.num_images])
452
+
453
+ if image.mode != "RGB":
454
+ image = image.convert("RGB")
455
+
456
+ placeholder_string = self.placeholder_token
457
+ text = random.choice(self.templates).format(placeholder_string)
458
+
459
+ example["input_ids"] = self.tokenizer(
460
+ text,
461
+ padding="max_length",
462
+ truncation=True,
463
+ max_length=self.tokenizer.model_max_length,
464
+ return_tensors="pt",
465
+ ).input_ids[0]
466
+
467
+ # default to score-sde preprocessing
468
+ img = np.array(image).astype(np.uint8)
469
+
470
+ if self.center_crop:
471
+ crop = min(img.shape[0], img.shape[1])
472
+ (h, w,) = (
473
+ img.shape[0],
474
+ img.shape[1],
475
+ )
476
+ img = img[
477
+ (h - crop) // 2 : (h + crop) // 2, (w - crop) // 2 : (w + crop) // 2
478
+ ]
479
+
480
+ image = Image.fromarray(img)
481
+ image = image.resize((self.size, self.size), resample=self.interpolation)
482
+
483
+ image = self.flip_transform(image)
484
+ image = np.array(image).astype(np.uint8)
485
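+ # Rescale pixel values from [0, 255] to [-1, 1], the range the VAE expects.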
+ image = (image / 127.5 - 1.0).astype(np.float32)
486
+
487
+ example["pixel_values"] = torch.from_numpy(image).permute(2, 0, 1)
488
+ return example
489
+
490
+
491
+ def get_full_repo_name(
492
+ model_id: str, organization: Optional[str] = None, token: Optional[str] = None
493
+ ):
494
+ if token is None:
495
+ token = HfFolder.get_token()
496
+ if organization is None:
497
+ username = whoami(token)["name"]
498
+ return f"{username}/{model_id}"
499
+ else:
500
+ return f"{organization}/{model_id}"
501
+
502
+
503
+ def main(pipe, args_imported):
504
+
505
+ args = parse_args()
506
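+ # Overlay the caller-supplied arguments on top of the argparse defaults.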
+ vars(args).update(vars(args_imported))
507
+
508
+ print(args)
509
+ logging_dir = os.path.join(args.output_dir, args.logging_dir)
510
+
511
+ accelerator = Accelerator(
512
+ gradient_accumulation_steps=args.gradient_accumulation_steps,
513
+ mixed_precision=args.mixed_precision,
514
+ log_with=args.report_to,
515
+ logging_dir=logging_dir,
516
+ )
517
+
518
+ # Make one log on every process with the configuration for debugging.
519
+ logging.basicConfig(
520
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
521
+ datefmt="%m/%d/%Y %H:%M:%S",
522
+ level=logging.INFO,
523
+ )
524
+ logger.info(accelerator.state, main_process_only=False)
525
+ if accelerator.is_local_main_process:
526
+ datasets.utils.logging.set_verbosity_warning()
527
+ transformers.utils.logging.set_verbosity_warning()
528
+ diffusers.utils.logging.set_verbosity_info()
529
+ else:
530
+ datasets.utils.logging.set_verbosity_error()
531
+ transformers.utils.logging.set_verbosity_error()
532
+ diffusers.utils.logging.set_verbosity_error()
533
+
534
+ # If passed along, set the training seed now.
535
+ if args.seed is not None:
536
+ set_seed(args.seed)
537
+
538
+ # Handle the repository creation
539
+ if accelerator.is_main_process:
540
+ if args.push_to_hub:
541
+ if args.hub_model_id is None:
542
+ repo_name = get_full_repo_name(
543
+ Path(args.output_dir).name, token=args.hub_token
544
+ )
545
+ else:
546
+ repo_name = args.hub_model_id
547
+ create_repo(repo_name, exist_ok=True, token=args.hub_token)
548
+ repo = Repository(
549
+ args.output_dir, clone_from=repo_name, token=args.hub_token
550
+ )
551
+
552
+ with open(os.path.join(args.output_dir, ".gitignore"), "w+") as gitignore:
553
+ if "step_*" not in gitignore:
554
+ gitignore.write("step_*\n")
555
+ if "epoch_*" not in gitignore:
556
+ gitignore.write("epoch_*\n")
557
+ elif args.output_dir is not None:
558
+ os.makedirs(args.output_dir, exist_ok=True)
559
+
560
+ # Load tokenizer
561
+ tokenizer = pipe.tokenizer
562
+
563
+ # Load scheduler and models
564
+ noise_scheduler = pipe.scheduler
565
+ text_encoder = pipe.text_encoder
566
+ vae = pipe.vae
567
+ unet = pipe.unet
568
+
569
+ # Add the placeholder token to the tokenizer
570
+ num_added_tokens = tokenizer.add_tokens(args.placeholder_token)
571
+ if num_added_tokens == 0:
572
+ raise ValueError(
573
+ f"The tokenizer already contains the token {args.placeholder_token}. Please pass a different"
574
+ " `placeholder_token` that is not already in the tokenizer."
575
+ )
576
+
577
+ # Convert the initializer_token, placeholder_token to ids
578
+ token_ids = tokenizer.encode(args.initializer_token, add_special_tokens=False)
579
+ # Check if initializer_token is a single token or a sequence of tokens
580
+ if len(token_ids) > 1:
581
+ raise ValueError("The initializer token must be a single token.")
582
+
583
+ initializer_token_id = token_ids[0]
584
+ placeholder_token_id = tokenizer.convert_tokens_to_ids(args.placeholder_token)
585
+
586
+ # Resize the token embeddings as we are adding new special tokens to the tokenizer
587
+ text_encoder.resize_token_embeddings(len(tokenizer))
588
+
589
+ # Initialise the newly added placeholder token with the embeddings of the initializer token
590
+ token_embeds = text_encoder.get_input_embeddings().weight.data
591
+ token_embeds[placeholder_token_id] = token_embeds[initializer_token_id]
592
+
593
+ # Freeze vae and unet
594
+ vae.requires_grad_(False)
595
+ unet.requires_grad_(False)
596
+ # Freeze all parameters except for the token embeddings in text encoder
597
+ text_encoder.text_model.encoder.requires_grad_(False)
598
+ text_encoder.text_model.final_layer_norm.requires_grad_(False)
599
+ text_encoder.text_model.embeddings.position_embedding.requires_grad_(False)
600
+
601
+ if args.gradient_checkpointing:
602
+ # Keep unet in train mode if we are using gradient checkpointing to save memory.
603
+ # Dropout is 0 in the unet, so being in train rather than eval mode does not change its outputs.
604
+ unet.train()
605
+ text_encoder.gradient_checkpointing_enable()
606
+ unet.enable_gradient_checkpointing()
607
+
608
+ if args.enable_xformers_memory_efficient_attention:
609
+ if is_xformers_available():
610
+ unet.enable_xformers_memory_efficient_attention()
611
+ else:
612
+ raise ValueError(
613
+ "xformers is not available. Make sure it is installed correctly"
614
+ )
615
+
616
+ # Enable TF32 for faster training on Ampere GPUs,
617
+ # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
618
+ if args.allow_tf32:
619
+ torch.backends.cuda.matmul.allow_tf32 = True
620
+
621
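+ # Optionally scale the base learning rate by the effective batch size.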
+ if args.scale_lr:
622
+ args.learning_rate = (
623
+ args.learning_rate
624
+ * args.gradient_accumulation_steps
625
+ * args.train_batch_size
626
+ * accelerator.num_processes
627
+ )
628
+
629
+ # Initialize the optimizer
630
+ optimizer = torch.optim.AdamW(
631
+ text_encoder.get_input_embeddings().parameters(), # only optimize the embeddings
632
+ lr=args.learning_rate,
633
+ betas=(args.adam_beta1, args.adam_beta2),
634
+ weight_decay=args.adam_weight_decay,
635
+ eps=args.adam_epsilon,
636
+ )
637
+
638
+ # Dataset and DataLoaders creation:
639
+ train_dataset = TextualInversionDataset(
640
+ data_root=args.train_data_dir,
641
+ tokenizer=tokenizer,
642
+ size=args.resolution,
643
+ placeholder_token=args.placeholder_token,
644
+ repeats=args.repeats,
645
+ learnable_property=args.learnable_property,
646
+ center_crop=args.center_crop,
647
+ set="train",
648
+ )
649
+ train_dataloader = torch.utils.data.DataLoader(
650
+ train_dataset, batch_size=args.train_batch_size, shuffle=True
651
+ )
652
+
653
+ # Scheduler and math around the number of training steps.
654
+ overrode_max_train_steps = False
655
+ num_update_steps_per_epoch = math.ceil(
656
+ len(train_dataloader) / args.gradient_accumulation_steps
657
+ )
658
+ if args.max_train_steps is None:
659
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
660
+ overrode_max_train_steps = True
661
+
662
+ lr_scheduler = get_scheduler(
663
+ args.lr_scheduler,
664
+ optimizer=optimizer,
665
+ num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
666
+ num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
667
+ )
668
+
669
+ # Prepare everything with our `accelerator`.
670
+ text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
671
+ text_encoder, optimizer, train_dataloader, lr_scheduler
672
+ )
673
+
674
+ # For mixed precision training we cast the unet and vae weights to half-precision,
675
+ # as these models are only used for inference; keeping them in full precision is not required.
676
+ weight_dtype = torch.float32
677
+ if accelerator.mixed_precision == "fp16":
678
+ weight_dtype = torch.float16
679
+ elif accelerator.mixed_precision == "bf16":
680
+ weight_dtype = torch.bfloat16
681
+
682
+ # Move vae and unet to device and cast to weight_dtype
683
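+ # The text encoder stays in float32 since its token embeddings are the weights being trained.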
+ unet.to(accelerator.device, dtype=weight_dtype)
684
+ vae.to(accelerator.device, dtype=weight_dtype)
685
+ text_encoder.to(accelerator.device, dtype=torch.float32)
686
+
687
+ # We need to recalculate our total training steps as the size of the training dataloader may have changed.
688
+ num_update_steps_per_epoch = math.ceil(
689
+ len(train_dataloader) / args.gradient_accumulation_steps
690
+ )
691
+ if overrode_max_train_steps:
692
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
693
+ # Afterwards we recalculate our number of training epochs
694
+ args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
695
+
696
+ # We need to initialize the trackers we use, and also store our configuration.
697
+ # The trackers initialize automatically on the main process.
698
+ if accelerator.is_main_process:
699
+ accelerator.init_trackers("textual_inversion", config=vars(args))
700
+
701
+ # Train!
702
+ total_batch_size = (
703
+ args.train_batch_size
704
+ * accelerator.num_processes
705
+ * args.gradient_accumulation_steps
706
+ )
707
+
708
+ logger.info("***** Running training *****")
709
+ logger.info(f" Num examples = {len(train_dataset)}")
710
+ logger.info(f" Num Epochs = {args.num_train_epochs}")
711
+ logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
712
+ logger.info(
713
+ f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}"
714
+ )
715
+ logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
716
+ logger.info(f" Total optimization steps = {args.max_train_steps}")
717
+ global_step = 0
718
+ first_epoch = 0
719
+
720
+ # Potentially load in the weights and states from a previous save
721
+ if args.resume_from_checkpoint:
722
+ if args.resume_from_checkpoint != "latest":
723
+ path = os.path.basename(args.resume_from_checkpoint)
724
+ else:
725
+ # Get the most recent checkpoint
726
+ dirs = os.listdir(args.output_dir)
727
+ dirs = [d for d in dirs if d.startswith("checkpoint")]
728
+ dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
729
+ path = dirs[-1]
730
+ accelerator.print(f"Resuming from checkpoint {path}")
731
+ accelerator.load_state(os.path.join(args.output_dir, path))
732
+ global_step = int(path.split("-")[1])
733
+
734
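+ # Translate the saved optimizer-step count back into an epoch and in-epoch step to resume from.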
+ resume_global_step = global_step * args.gradient_accumulation_steps
735
+ first_epoch = resume_global_step // num_update_steps_per_epoch
736
+ resume_step = resume_global_step % num_update_steps_per_epoch
737
+
738
+ # Only show the progress bar once on each machine.
739
+ progress_bar = tqdm(
740
+ range(global_step, args.max_train_steps),
741
+ disable=not accelerator.is_local_main_process,
742
+ )
743
+ progress_bar.set_description("Steps")
744
+
745
+ # keep original embeddings as reference
746
+ orig_embeds_params = (
747
+ accelerator.unwrap_model(text_encoder)
748
+ .get_input_embeddings()
749
+ .weight.data.clone()
750
+ )
751
+
752
+ for epoch in range(first_epoch, args.num_train_epochs):
753
+ text_encoder.train()
754
+ for step, batch in enumerate(train_dataloader):
755
+ # Skip steps until we reach the resumed step
756
+ if (
757
+ args.resume_from_checkpoint
758
+ and epoch == first_epoch
759
+ and step < resume_step
760
+ ):
761
+ if step % args.gradient_accumulation_steps == 0:
762
+ progress_bar.update(1)
763
+ continue
764
+
765
+ with accelerator.accumulate(text_encoder):
766
+ # Convert images to latent space
767
+ latents = (
768
+ vae.encode(batch["pixel_values"].to(dtype=weight_dtype))
769
+ .latent_dist.sample()
770
+ .detach()
771
+ )
772
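+ # 0.18215 is the Stable Diffusion VAE latent scaling factor.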
+ latents = latents * 0.18215
773
+
774
+ # Sample noise that we'll add to the latents
775
+ noise = torch.randn_like(latents)
776
+ bsz = latents.shape[0]
777
+ # Sample a random timestep for each image
778
+ timesteps = torch.randint(
779
+ 0,
780
+ noise_scheduler.config.num_train_timesteps,
781
+ (bsz,),
782
+ device=latents.device,
783
+ )
784
+ timesteps = timesteps.long()
785
+
786
+ # Add noise to the latents according to the noise magnitude at each timestep
787
+ # (this is the forward diffusion process)
788
+ noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
789
+
790
+ # Get the text embedding for conditioning
791
+ encoder_hidden_states = text_encoder(batch["input_ids"])[0].to(
792
+ dtype=weight_dtype
793
+ )
794
+
795
+ # Predict the noise residual
796
+ model_pred = unet(
797
+ noisy_latents, timesteps, encoder_hidden_states
798
+ ).sample
799
+
800
+ # Get the target for loss depending on the prediction type
801
+ if noise_scheduler.config.prediction_type == "epsilon":
802
+ target = noise
803
+ elif noise_scheduler.config.prediction_type == "v_prediction":
804
+ target = noise_scheduler.get_velocity(latents, noise, timesteps)
805
+ else:
806
+ raise ValueError(
807
+ f"Unknown prediction type {noise_scheduler.config.prediction_type}"
808
+ )
809
+
810
+ loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
811
+
812
+ accelerator.backward(loss)
813
+
814
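+ # In multi-process training the text encoder is wrapped, so its embeddings live under .module;
+ # grab the embedding gradients so every row except the placeholder token can be zeroed below.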
+ if accelerator.num_processes > 1:
815
+ grads = text_encoder.module.get_input_embeddings().weight.grad
816
+ else:
817
+ grads = text_encoder.get_input_embeddings().weight.grad
818
+ # Get the index for tokens that we want to zero the grads for
819
+ index_grads_to_zero = (
820
+ torch.arange(len(tokenizer)) != placeholder_token_id
821
+ )
822
+ grads.data[index_grads_to_zero, :] = grads.data[
823
+ index_grads_to_zero, :
824
+ ].fill_(0)
825
+
826
+ optimizer.step()
827
+ lr_scheduler.step()
828
+ optimizer.zero_grad()
829
+
830
+ # Let's make sure we don't update any embedding weights besides the newly added token
831
+ index_no_updates = torch.arange(len(tokenizer)) != placeholder_token_id
832
+ with torch.no_grad():
833
+ accelerator.unwrap_model(
834
+ text_encoder
835
+ ).get_input_embeddings().weight[
836
+ index_no_updates
837
+ ] = orig_embeds_params[
838
+ index_no_updates
839
+ ]
840
+
841
+ # Checks if the accelerator has performed an optimization step behind the scenes
842
+ if accelerator.sync_gradients:
843
+ progress_bar.update(1)
844
+ global_step += 1
845
+ if global_step % args.save_steps == 0:
846
+ save_path = os.path.join(
847
+ args.output_dir, f"{args.placeholder_token}-{global_step}.bin"
848
+ )
849
+ save_progress(
850
+ text_encoder, placeholder_token_id, accelerator, args, save_path
851
+ )
852
+
853
+ if global_step % args.checkpointing_steps == 0:
854
+ if accelerator.is_main_process:
855
+ save_path = os.path.join(
856
+ args.output_dir, f"checkpoint-{global_step}"
857
+ )
858
+ accelerator.save_state(save_path)
859
+ logger.info(f"Saved state to {save_path}")
860
+
861
+ logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
862
+ progress_bar.set_postfix(**logs)
863
+ accelerator.log(logs, step=global_step)
864
+
865
+ if global_step >= args.max_train_steps:
866
+ break
867
+
868
+ # Create the pipeline using the trained modules and save it.
869
+ accelerator.wait_for_everyone()
870
+ if accelerator.is_main_process:
871
+ if args.push_to_hub and args.only_save_embeds:
872
+ logger.warn(
873
+ "Enabling full model saving because --push_to_hub=True was specified."
874
+ )
875
+ save_full_model = True
876
+ else:
877
+ save_full_model = not args.only_save_embeds
878
+ if save_full_model:
879
+ pipe.save_pretrained(args.output_dir)
880
+ # Save the newly trained embeddings
881
+ save_path = os.path.join(args.output_dir, "learned_embeds.bin")
882
+ save_progress(text_encoder, placeholder_token_id, accelerator, args, save_path)
883
+
884
+ if args.push_to_hub:
885
+ repo.push_to_hub(
886
+ commit_message="End of training", blocking=False, auto_lfs_prune=True
887
+ )
888
+
889
+ accelerator.end_training()
890
+
891
+
892
+ if __name__ == "__main__":
893
+ pipeline = StableDiffusionPipeline.from_pretrained(
894
+ "andite/anything-v4.0", torch_dtype=torch.float16
895
+ )
896
+
897
+ imported_args = argparse.Namespace(
898
+ train_data_dir="concept_images",
899
+ learnable_property="object",
900
+ placeholder_token="redeyegirl",
901
+ initializer_token="girl",
902
+ resolution=512,
903
+ train_batch_size=1,
904
+ gradient_accumulation_steps=1,
905
+ gradient_checkpointing=True,
906
+ mixed_precision="fp16",
907
+ use_bf16=False,
908
+ max_train_steps=1000,
909
+ learning_rate=5.0e-4,
910
+ scale_lr=False,
911
+ lr_scheduler="constant",
912
+ lr_warmup_steps=0,
913
+ output_dir="output_model",
914
+ )
915
+
916
+ main(pipeline, imported_args)
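
Once training finishes, the embedding written to `output_model/learned_embeds.bin` can be loaded back into a pipeline for inference. The sketch below is illustrative only, not part of this commit: it assumes the file follows the usual textual-inversion format of `{placeholder_token: embedding_tensor}` (as `save_progress` writes it above) and that a CUDA device is available.

```python
import torch
from diffusers import StableDiffusionPipeline

# Load the same base model the embedding was trained against.
pipe = StableDiffusionPipeline.from_pretrained(
    "andite/anything-v4.0", torch_dtype=torch.float16
).to("cuda")

# Assumed format: {placeholder_token: embedding_tensor}, e.g. {"redeyegirl": tensor([...])}.
learned = torch.load("output_model/learned_embeds.bin", map_location="cpu")
token, embedding = next(iter(learned.items()))

# Register the placeholder token and copy its trained embedding into the text encoder.
pipe.tokenizer.add_tokens(token)
token_id = pipe.tokenizer.convert_tokens_to_ids(token)
pipe.text_encoder.resize_token_embeddings(len(pipe.tokenizer))
pipe.text_encoder.get_input_embeddings().weight.data[token_id] = embedding.to(
    device=pipe.device, dtype=pipe.text_encoder.dtype
)

image = pipe(f"a photo of {token}").images[0]
image.save("textual_inversion_sample.png")
```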