ManglerFTW committed
Commit 3a18eba
Parent(s): ba2acac
Upload 12 files

- StableTuner_RunPod_Fix/captionBuddy.py +967 -0
- StableTuner_RunPod_Fix/clip_segmentation.py +325 -0
- StableTuner_RunPod_Fix/configuration_gui.py +0 -0
- StableTuner_RunPod_Fix/convert_diffusers_to_sd_cli.py +22 -0
- StableTuner_RunPod_Fix/converters.py +120 -0
- StableTuner_RunPod_Fix/dataloaders_util.py +1331 -0
- StableTuner_RunPod_Fix/discriminator.py +764 -0
- StableTuner_RunPod_Fix/lion_pytorch.py +88 -0
- StableTuner_RunPod_Fix/lora_utils.py +236 -0
- StableTuner_RunPod_Fix/model_util.py +1543 -0
- StableTuner_RunPod_Fix/trainer.py +1750 -0
- StableTuner_RunPod_Fix/trainer_util.py +435 -0
StableTuner_RunPod_Fix/captionBuddy.py
ADDED
@@ -0,0 +1,967 @@
import tkinter as tk
from tkinter import ttk, Menu
import tkinter.messagebox
import tkinter.filedialog as fd
import os
import subprocess
import sys
import json
import random
from PIL import Image, ImageTk, ImageDraw
from torchvision import transforms
from torchvision.transforms.functional import InterpolationMode
import torch
import numpy as np
import requests
import customtkinter as ctk
from customtkinter import ThemeManager

from clip_segmentation import ClipSeg

#main class
ctk.set_appearance_mode("dark")
ctk.set_default_color_theme("blue")

class BatchMaskWindow(ctk.CTkToplevel):
    def __init__(self, parent, path, *args, **kwargs):
        ctk.CTkToplevel.__init__(self, parent, *args, **kwargs)
        self.parent = parent

        self.title("Batch process masks")
        self.geometry("320x310")
        self.resizable(False, False)
        self.wait_visibility()
        self.grab_set()
        self.focus_set()

        self.mode_var = tk.StringVar(self, "Create if absent")
        self.modes = ["Replace all masks", "Create if absent", "Add to existing", "Subtract from existing"]

        self.frame = ctk.CTkFrame(self, width=600, height=300)
        self.frame.grid(row=0, column=0, sticky="nsew", padx=10, pady=10)

        self.path_label = ctk.CTkLabel(self.frame, text="Folder", width=100)
        self.path_label.grid(row=0, column=0, sticky="w", padx=5, pady=5)
        self.path_entry = ctk.CTkEntry(self.frame, width=150)
        self.path_entry.insert(0, path)
        self.path_entry.grid(row=0, column=1, sticky="w", padx=5, pady=5)
        self.path_button = ctk.CTkButton(self.frame, width=30, text="...", command=lambda: self.browse_for_path(self.path_entry))
        self.path_button.grid(row=0, column=1, sticky="e", padx=5, pady=5)

        self.prompt_label = ctk.CTkLabel(self.frame, text="Prompt", width=100)
        self.prompt_label.grid(row=1, column=0, sticky="w", padx=5, pady=5)
        self.prompt_entry = ctk.CTkEntry(self.frame, width=200)
        self.prompt_entry.grid(row=1, column=1, sticky="w", padx=5, pady=5)

        self.mode_label = ctk.CTkLabel(self.frame, text="Mode", width=100)
        self.mode_label.grid(row=2, column=0, sticky="w", padx=5, pady=5)
        self.mode_dropdown = ctk.CTkOptionMenu(self.frame, variable=self.mode_var, values=self.modes, dynamic_resizing=False, width=200)
        self.mode_dropdown.grid(row=2, column=1, sticky="w", padx=5, pady=5)

        self.threshold_label = ctk.CTkLabel(self.frame, text="Threshold", width=100)
        self.threshold_label.grid(row=3, column=0, sticky="w", padx=5, pady=5)
        self.threshold_entry = ctk.CTkEntry(self.frame, width=200, placeholder_text="0.0 - 1.0")
        self.threshold_entry.insert(0, "0.3")
        self.threshold_entry.grid(row=3, column=1, sticky="w", padx=5, pady=5)

        self.smooth_label = ctk.CTkLabel(self.frame, text="Smooth", width=100)
        self.smooth_label.grid(row=4, column=0, sticky="w", padx=5, pady=5)
        self.smooth_entry = ctk.CTkEntry(self.frame, width=200, placeholder_text="5")
        self.smooth_entry.insert(0, 5)
        self.smooth_entry.grid(row=4, column=1, sticky="w", padx=5, pady=5)

        self.expand_label = ctk.CTkLabel(self.frame, text="Expand", width=100)
        self.expand_label.grid(row=5, column=0, sticky="w", padx=5, pady=5)
        self.expand_entry = ctk.CTkEntry(self.frame, width=200, placeholder_text="10")
        self.expand_entry.insert(0, 10)
        self.expand_entry.grid(row=5, column=1, sticky="w", padx=5, pady=5)

        self.progress_label = ctk.CTkLabel(self.frame, text="Progress: 0/0", width=100)
        self.progress_label.grid(row=6, column=0, sticky="w", padx=5, pady=5)
        self.progress = ctk.CTkProgressBar(self.frame, orientation="horizontal", mode="determinate", width=200)
        self.progress.grid(row=6, column=1, sticky="w", padx=5, pady=5)

        self.create_masks_button = ctk.CTkButton(self.frame, text="Create Masks", width=310, command=self.create_masks)
        self.create_masks_button.grid(row=7, column=0, columnspan=2, sticky="w", padx=5, pady=5)

        self.frame.pack(fill="both", expand=True)

    def browse_for_path(self, entry_box):
        # get the path from the user
        path = fd.askdirectory()
        # set the path to the entry box
        # delete entry box text
        entry_box.focus_set()
        entry_box.delete(0, tk.END)
        entry_box.insert(0, path)
        self.focus_set()

    def set_progress(self, value, max_value):
        progress = value / max_value
        self.progress.set(progress)
        self.progress_label.configure(text="{0}/{1}".format(value, max_value))
        self.progress.update()

    def create_masks(self):
        self.parent.load_clip_seg_model()

        mode = {
            "Replace all masks": "replace",
            "Create if absent": "fill",
            "Add to existing": "add",
            "Subtract from existing": "subtract"
        }[self.mode_var.get()]

        self.parent.clip_seg.mask_folder(
            sample_dir=self.path_entry.get(),
            prompts=[self.prompt_entry.get()],
            mode=mode,
            threshold=float(self.threshold_entry.get()),
            smooth_pixels=int(self.smooth_entry.get()),
            expand_pixels=int(self.expand_entry.get()),
            progress_callback=self.set_progress,
        )
        self.parent.load_image()


def _check_file_type(f: str) -> bool:
    return f.lower().endswith(('.jpg', '.jpeg', '.png', '.webp', ".bmp", ".tiff"))


class ImageBrowser(ctk.CTkToplevel):
    def __init__(self, mainProcess=None):
        super().__init__()
        if not os.path.exists("scripts/BLIP"):
            print("Getting BLIP from GitHub.")
            subprocess.run(["git", "clone", "https://github.com/salesforce/BLIP", "scripts/BLIP"])
        #if not os.path.exists("scripts/CLIP"):
        #    print("Getting CLIP from GitHub.")
        #    subprocess.run(["git", "clone", "https://github.com/pharmapsychotic/clip-interrogator.git", "scripts/CLIP"])
        blip_path = "scripts/BLIP"
        sys.path.append(blip_path)
        #clip_path = "scripts/CLIP"
        #sys.path.append(clip_path)
        self.mainProcess = mainProcess
        self.captioner_folder = os.path.dirname(os.path.realpath(__file__))
        self.clip_seg = None
        self.PILimage = None
        self.PILmask = None
        self.mask_draw_x = 0
        self.mask_draw_y = 0
        self.mask_draw_radius = 20
        #self = master
        #self.overrideredirect(True)
        #self.title_bar = TitleBar(self)
        #self.title_bar.pack(side="top", fill="x")
        #make not user resizable
        self.title("Caption Buddy")
        #self.resizable(False, False)
        self.geometry("720x820")
        self.top_frame = ctk.CTkFrame(self, fg_color='transparent')
        self.top_frame.pack(side="top", fill="x", expand=False)
        self.top_subframe = ctk.CTkFrame(self.top_frame, fg_color='transparent')
        self.top_subframe.pack(side="bottom", fill="x", pady=10)
        self.top_subframe.grid_columnconfigure(0, weight=1)
        self.top_subframe.grid_columnconfigure(1, weight=1)
        self.tip_frame = ctk.CTkFrame(self, fg_color='transparent')
        self.tip_frame.pack(side="top")
        self.dark_mode_var = "#202020"
        #self.dark_purple_mode_var = "#1B0F1B"
        self.dark_mode_title_var = "#286aff"
        self.dark_mode_button_pressed_var = "#BB91B6"
        self.dark_mode_button_var = "#8ea0e1"
        self.dark_mode_text_var = "#c6c7c8"
        #self.configure(bg_color=self.dark_mode_var)
        self.canvas = ctk.CTkLabel(self, text='', width=600, height=600)
        #self.canvas.configure(bg_color=self.dark_mode_var)
        #create temporary image for canvas
        self.canvas.pack()
        self.cur_img_index = 0
        self.image_count = 0
        #make a frame with a grid under the canvas
        self.frame = ctk.CTkFrame(self)
        #grid
        self.frame.grid_columnconfigure(0, weight=1)
        self.frame.grid_columnconfigure(1, weight=100)
        self.frame.grid_columnconfigure(2, weight=1)
        self.frame.grid_rowconfigure(0, weight=1)

        #show the frame
        self.frame.pack(side="bottom", fill="x")
        #bottom frame
        self.bottom_frame = ctk.CTkFrame(self)
        #make grid
        self.bottom_frame.grid_columnconfigure(0, weight=0)
        self.bottom_frame.grid_columnconfigure(1, weight=2)
        self.bottom_frame.grid_columnconfigure(2, weight=0)
        self.bottom_frame.grid_columnconfigure(3, weight=2)
        self.bottom_frame.grid_columnconfigure(4, weight=0)
        self.bottom_frame.grid_columnconfigure(5, weight=2)
        self.bottom_frame.grid_rowconfigure(0, weight=1)
        #show the frame
        self.bottom_frame.pack(side="bottom", fill="x")

        self.image_index = 0
        self.image_list = []
        self.caption = ''
        self.caption_file = ''
        self.caption_file_path = ''
        self.caption_file_name = ''
        self.caption_file_ext = ''
        self.caption_file_name_no_ext = ''
        self.output_format = 'text'
        #check if bad_files.txt exists
        if os.path.exists("bad_files.txt"):
            #delete it
            os.remove("bad_files.txt")
        self.use_blip = True
        self.debug = False
        self.create_widgets()
        self.load_blip_model()
        self.load_options()
        #self.open_folder()

        self.canvas.focus_force()
        self.canvas.bind("<Alt-Right>", self.next_image)
        self.canvas.bind("<Alt-Left>", self.prev_image)
        #on close window
        self.protocol("WM_DELETE_WINDOW", self.on_closing)

    def on_closing(self):
        #self.save_options()
        self.mainProcess.deiconify()
        self.destroy()

    def create_widgets(self):
        self.output_folder = ''

        # add a checkbox to toggle auto generate caption
        self.auto_generate_caption = tk.BooleanVar(self.top_subframe)
        self.auto_generate_caption.set(True)
        self.auto_generate_caption_checkbox = ctk.CTkCheckBox(self.top_subframe, text="Auto Generate Caption", variable=self.auto_generate_caption, width=50)
        self.auto_generate_caption_checkbox.pack(side="left", fill="x", expand=True, padx=10)

        # add a checkbox to skip auto generating captions if they already exist
        self.auto_generate_caption_text_override = tk.BooleanVar(self.top_subframe)
        self.auto_generate_caption_text_override.set(False)
        self.auto_generate_caption_checkbox_text_override = ctk.CTkCheckBox(self.top_subframe, text="Skip Auto Generate If Text Caption Exists", variable=self.auto_generate_caption_text_override, width=50)
        self.auto_generate_caption_checkbox_text_override.pack(side="left", fill="x", expand=True, padx=10)

        # add a checkbox to enable mask editing
        self.enable_mask_editing = tk.BooleanVar(self.top_subframe)
        self.enable_mask_editing.set(False)
        self.enable_mask_editing_checkbox = ctk.CTkCheckBox(self.top_subframe, text="Enable Mask Editing", variable=self.enable_mask_editing, width=50)
        self.enable_mask_editing_checkbox.pack(side="left", fill="x", expand=True, padx=10)

        self.open_button = ctk.CTkButton(self.top_frame, text="Load Folder", fg_color=("gray75", "gray25"), command=self.open_folder, width=50)
        #self.open_button.grid(row=0, column=1)
        self.open_button.pack(side="left", fill="x", expand=True, padx=10)
        #add a batch folder button
        self.batch_folder_caption_button = ctk.CTkButton(self.top_frame, text="Batch Folder Caption", fg_color=("gray75", "gray25"), command=self.batch_folder_caption, width=50)
        self.batch_folder_caption_button.pack(side="left", fill="x", expand=True, padx=10)
        self.batch_folder_mask_button = ctk.CTkButton(self.top_frame, text="Batch Folder Mask", fg_color=("gray75", "gray25"), command=self.batch_folder_mask, width=50)
        self.batch_folder_mask_button.pack(side="left", fill="x", expand=True, padx=10)

        #add an options button to the same row as the open button
        self.options_button = ctk.CTkButton(self.top_frame, text="Options", fg_color=("gray75", "gray25"), command=self.open_options, width=50)
        self.options_button.pack(side="left", fill="x", expand=True, padx=10)
        #add generate caption button
        self.generate_caption_button = ctk.CTkButton(self.top_frame, text="Generate Caption", fg_color=("gray75", "gray25"), command=self.generate_caption, width=50)
        self.generate_caption_button.pack(side="left", fill="x", expand=True, padx=10)

        #add a label for tips under the buttons
        self.tips_label = ctk.CTkLabel(self.tip_frame, text="Use Alt with left and right arrow keys to navigate images, enter to save the caption.")
        self.tips_label.pack(side="top")
        #add image count label
        self.image_count_label = ctk.CTkLabel(self.tip_frame, text=f"Image {self.cur_img_index} of {self.image_count}")
        self.image_count_label.pack(side="top")

        self.image_label = ctk.CTkLabel(self.canvas, text='', width=100, height=100)
        self.image_label.grid(row=0, column=0, sticky="nsew")
        #self.image_label.bind("<Button-3>", self.click_canvas)
        self.image_label.bind("<Motion>", self.draw_mask)
        self.image_label.bind("<Button-1>", self.draw_mask)
        self.image_label.bind("<Button-3>", self.draw_mask)
        self.image_label.bind("<MouseWheel>", self.draw_mask_radius)
        #self.image_label.pack(side="top")
        #previous button
        self.prev_button = ctk.CTkButton(self.frame, text="Previous", command=lambda event=None: self.prev_image(event), width=50)
        #grid
        self.prev_button.grid(row=1, column=0, sticky="w", padx=5, pady=10)
        #self.prev_button.pack(side="left")
        #self.prev_button.bind("<Left>", self.prev_image)
        self.caption_entry = ctk.CTkEntry(self.frame)
        #grid
        self.caption_entry.grid(row=1, column=1, rowspan=3, sticky="nsew", pady=10)
        #bind to enter key
        self.caption_entry.bind("<Return>", self.save)
        self.canvas.bind("<Return>", self.save)
        self.caption_entry.bind("<Alt-Right>", self.next_image)
        self.caption_entry.bind("<Alt-Left>", self.prev_image)
        self.caption_entry.bind("<Control-BackSpace>", self.delete_word)
        #next button

        self.next_button = ctk.CTkButton(self.frame, text='Next', command=lambda event=None: self.next_image(event), width=50)
        #self.next_button["text"] = "Next"
        #grid
        self.next_button.grid(row=1, column=2, sticky="e", padx=5, pady=10)
        #add two entry boxes and labels in the style of: replace _ with _
        #create replace string variable
        self.replace_label = ctk.CTkLabel(self.bottom_frame, text="Replace:")
        self.replace_label.grid(row=0, column=0, sticky="w", padx=5)
        self.replace_entry = ctk.CTkEntry(self.bottom_frame)
        self.replace_entry.grid(row=0, column=1, sticky="nsew", padx=5)
        self.replace_entry.bind("<Return>", self.save)
        #self.replace_entry.bind("<Tab>", self.replace)
        #with label
        #create with string variable
        self.with_label = ctk.CTkLabel(self.bottom_frame, text="With:")
        self.with_label.grid(row=0, column=2, sticky="w", padx=5)
        self.with_entry = ctk.CTkEntry(self.bottom_frame)
        self.with_entry.grid(row=0, column=3, sticky="nswe", padx=5)
        self.with_entry.bind("<Return>", self.save)
        #add another entry with label, add suffix

        #create prefix string var
        self.prefix_label = ctk.CTkLabel(self.bottom_frame, text="Add to start:")
        self.prefix_label.grid(row=0, column=4, sticky="w", padx=5)
        self.prefix_entry = ctk.CTkEntry(self.bottom_frame)
        self.prefix_entry.grid(row=0, column=5, sticky="nsew", padx=5)
        self.prefix_entry.bind("<Return>", self.save)

        #create suffix string var
        self.suffix_label = ctk.CTkLabel(self.bottom_frame, text="Add to end:")
        self.suffix_label.grid(row=0, column=6, sticky="w", padx=5)
        self.suffix_entry = ctk.CTkEntry(self.bottom_frame)
        self.suffix_entry.grid(row=0, column=7, sticky="nsew", padx=5)
        self.suffix_entry.bind("<Return>", self.save)
        self.all_entries = [self.replace_entry, self.with_entry, self.suffix_entry, self.caption_entry, self.prefix_entry]
        #bind right click menu to all entries
        for entry in self.all_entries:
            entry.bind("<Button-3>", self.create_right_click_menu)

    def batch_folder_caption(self):
        #show imgs in folder askdirectory
        #ask user if to batch current folder or select folder
        #if bad_files.txt exists, delete it
        self.bad_files = []
        if os.path.exists('bad_files.txt'):
            os.remove('bad_files.txt')
        try:
            #check if self.folder is set
            self.folder
        except AttributeError:
            self.folder = ''
        if self.folder == '':
            self.folder = fd.askdirectory(title="Select Folder to Batch Process", initialdir=os.getcwd())
            batch_input_dir = self.folder
        else:
            ask = tk.messagebox.askquestion("Batch Folder", "Batch current folder?")
            if ask == 'yes':
                batch_input_dir = self.folder
            else:
                batch_input_dir = fd.askdirectory(title="Select Folder to Batch Process", initialdir=os.getcwd())
        ask2 = tk.messagebox.askquestion("Batch Folder", "Save output to same directory?")
        if ask2 == 'yes':
            batch_output_dir = batch_input_dir
        else:
            batch_output_dir = fd.askdirectory(title="Select Folder to Save Batch Processed Images", initialdir=os.getcwd())
        if batch_input_dir == '':
            return
        if batch_output_dir == '':
            batch_output_dir = batch_input_dir

        self.caption_file_name = os.path.basename(batch_input_dir)
        self.image_list = []
        for file in os.listdir(batch_input_dir):
            if _check_file_type(file) and not file.endswith('-masklabel.png'):
                self.image_list.append(os.path.join(batch_input_dir, file))
        self.image_index = 0
        #use progress bar class
        #pba = tk.Tk()
        #pba.title("Batch Processing")
        #remove icon
        #pba.wm_attributes('-toolwindow','True')
        pb = ProgressbarWithCancel(max=len(self.image_list))
        #pb.set_max(len(self.image_list))
        pb.set_progress(0)

        #if batch_output_dir doesn't exist, create it
        if not os.path.exists(batch_output_dir):
            os.makedirs(batch_output_dir)
        for i in range(len(self.image_list)):
            random_chance = random.randint(0, 25)
            if random_chance == 0:
                pb.set_random_label()
            if pb.is_cancelled():
                pb.destroy()
                return
            self.image_index = i
            #get float value of progress between 0 and 1 according to the image index and the total number of images
            progress = i / len(self.image_list)
            pb.set_progress(progress)
            self.update()
            try:
                img = Image.open(self.image_list[i]).convert("RGB")
            except:
                self.bad_files.append(self.image_list[i])
                #skip file
                continue
            tensor = transforms.Compose([
                transforms.Resize((self.blipSize, self.blipSize), interpolation=InterpolationMode.BICUBIC),
                transforms.ToTensor(),
                transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
            ])
            torch_image = tensor(img).unsqueeze(0).to(torch.device("cuda"))
            if self.nucleus_sampling:
                captions = self.blip_decoder.generate(torch_image, sample=True, top_p=self.q_factor)
            else:
                captions = self.blip_decoder.generate(torch_image, sample=False, num_beams=16, min_length=self.min_length,
                                                      max_length=48, repetition_penalty=self.q_factor)
            caption = captions[0]
            self.replace = self.replace_entry.get()
            self.replace_with = self.with_entry.get()
            self.suffix_var = self.suffix_entry.get()
            self.prefix = self.prefix_entry.get()
            #prepare the caption
            if self.suffix_var.startswith(',') or self.suffix_var.startswith(' '):
                self.suffix_var = self.suffix_var
            else:
                self.suffix_var = ' ' + self.suffix_var
            caption = caption.replace(self.replace, self.replace_with)
            if self.prefix != '':
                if self.prefix.endswith(' '):
                    self.prefix = self.prefix[:-1]
                if not self.prefix.endswith(','):
                    self.prefix = self.prefix + ','
                caption = self.prefix + ' ' + caption
            if caption.endswith(',') or caption.endswith('.'):
                caption = caption[:-1]
                caption = caption + ', ' + self.suffix_var
            else:
                caption = caption + self.suffix_var
            #saving the captioned image
            if self.output_format == 'text':
                #text file with same name as image
                imgName = os.path.basename(self.image_list[self.image_index])
                imgName = imgName[:imgName.rfind('.')]
                caption_file = os.path.join(batch_output_dir, imgName + '.txt')
                with open(caption_file, 'w') as f:
                    f.write(caption)
            elif self.output_format == 'filename':
                #duplicate image with caption as file name
                img.save(os.path.join(batch_output_dir, caption + '.png'))
            progress = (i + 1) / len(self.image_list)
            pb.set_progress(progress)
        #show message box when done
        pb.destroy()
        donemsg = tk.messagebox.showinfo("Batch Folder", "Batching complete!", parent=self.master)
        if len(self.bad_files) > 0:
            bad_files_msg = tk.messagebox.showinfo("Bad Files", "Couldn't process " + str(len(self.bad_files)) + " files.\nFor a list of problematic files see bad_files.txt", parent=self.master)
            with open('bad_files.txt', 'w') as f:
                for item in self.bad_files:
                    f.write(item + '\n')

        #ask user if we should load the batch output folder
        ask3 = tk.messagebox.askquestion("Batch Folder", "Load batch output folder?")
        if ask3 == 'yes':
            self.image_index = 0
            self.open_folder(folder=batch_output_dir)
        #focus on donemsg
        #donemsg.focus_force()

    def generate_caption(self):
        #get the image
        tensor = transforms.Compose([
            #transforms.CenterCrop(SIZE),
            transforms.Resize((self.blipSize, self.blipSize), interpolation=InterpolationMode.BICUBIC),
            transforms.ToTensor(),
            transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
        ])
        torch_image = tensor(self.PILimage).unsqueeze(0).to(torch.device("cuda"))
        if self.nucleus_sampling:
            captions = self.blip_decoder.generate(torch_image, sample=True, top_p=self.q_factor)
        else:
            captions = self.blip_decoder.generate(torch_image, sample=False, num_beams=16, min_length=self.min_length,
                                                  max_length=48, repetition_penalty=self.q_factor)
        self.caption = captions[0]
        self.caption_entry.delete(0, tk.END)
        self.caption_entry.insert(0, self.caption)
        #change the caption entry color to red
        self.caption_entry.configure(fg_color='red')

    def load_blip_model(self):
        self.blipSize = 384
        blip_model_url = 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_caption_capfilt_large.pth'
        #check if options file exists
        if os.path.exists(os.path.join(self.captioner_folder, 'options.json')):
            with open(os.path.join(self.captioner_folder, 'options.json'), 'r') as f:
                #parse the file once instead of calling json.load on the same handle per key
                options = json.load(f)
            self.nucleus_sampling = options['nucleus_sampling']
            self.q_factor = options['q_factor']
            self.min_length = options['min_length']
        else:
            self.nucleus_sampling = False
            self.q_factor = 1.0
            self.min_length = 22
        config_path = os.path.join(self.captioner_folder, "BLIP/configs/med_config.json")
        cache_folder = os.path.join(self.captioner_folder, "BLIP/cache")
        model_path = os.path.join(self.captioner_folder, "BLIP/models/model_base_caption_capfilt_large.pth")
        if not os.path.exists(cache_folder):
            os.makedirs(cache_folder)

        if not os.path.exists(model_path):
            print(f"Downloading BLIP to {cache_folder}")
            with requests.get(blip_model_url, stream=True) as session:
                session.raise_for_status()
                with open(model_path, 'wb') as f:
                    for chunk in session.iter_content(chunk_size=1024):
                        f.write(chunk)
            print('Download complete')
        else:
            print("Found BLIP model")
        import models.blip
        blip_decoder = models.blip.blip_decoder(pretrained=model_path, image_size=self.blipSize, vit='base', med_config=config_path)
        blip_decoder.eval()
        self.blip_decoder = blip_decoder.to(torch.device("cuda"))

    def batch_folder_mask(self):
        folder = ''
        try:
            # check if self.folder is set
            folder = self.folder
        except:
            pass

        dialog = BatchMaskWindow(self, folder)
        dialog.mainloop()

    def load_clip_seg_model(self):
        if self.clip_seg is None:
            self.clip_seg = ClipSeg()

    def open_folder(self, folder=None):
        if folder is None:
            self.folder = fd.askdirectory()
        else:
            self.folder = folder
        if self.folder == '':
            return
        self.output_folder = self.folder
        self.image_list = [os.path.join(self.folder, f) for f in os.listdir(self.folder) if _check_file_type(f) and not f.endswith('-masklabel.png') and not f.endswith('-depth.png')]
        #self.image_list.sort()
        #sort the image list alphabetically so that the images are in the same order every time
        self.image_list.sort(key=lambda x: x.lower())

        self.image_count = len(self.image_list)
        if self.image_count == 0:
            tk.messagebox.showinfo("No Images", "No images found in the selected folder")
            return
        #update the image count label

        self.image_index = 0
        self.image_count_label.configure(text=f'Image {self.image_index+1} of {self.image_count}')
        self.output_folder = self.folder
        self.load_image()
        self.caption_entry.focus_set()

    def draw_mask(self, event):
        if not self.enable_mask_editing.get():
            return

        if event.widget != self.image_label.children["!label"]:
            return

        start_x = int(event.x / self.image_size[0] * self.PILimage.width)
        start_y = int(event.y / self.image_size[1] * self.PILimage.height)
        end_x = int(self.mask_draw_x / self.image_size[0] * self.PILimage.width)
        end_y = int(self.mask_draw_y / self.image_size[1] * self.PILimage.height)

        self.mask_draw_x = event.x
        self.mask_draw_y = event.y

        color = None

        if event.state & 0x0100 or event.num == 1:  # left mouse button
            color = (255, 255, 255)
        elif event.state & 0x0400 or event.num == 3:  # right mouse button
            color = (0, 0, 0)

        if color is not None:
            if self.PILmask is None:
                self.PILmask = Image.new('RGB', size=self.PILimage.size, color=(0, 0, 0))

            draw = ImageDraw.Draw(self.PILmask)
            draw.line((start_x, start_y, end_x, end_y), fill=color, width=self.mask_draw_radius + self.mask_draw_radius + 1)
            draw.ellipse((start_x - self.mask_draw_radius, start_y - self.mask_draw_radius, start_x + self.mask_draw_radius, start_y + self.mask_draw_radius), fill=color, outline=None)
            draw.ellipse((end_x - self.mask_draw_radius, end_y - self.mask_draw_radius, end_x + self.mask_draw_radius, end_y + self.mask_draw_radius), fill=color, outline=None)

            self.compose_masked_image()
            self.display_image()

    def draw_mask_radius(self, event):
        if event.widget != self.image_label.children["!label"]:
            return

        delta = -np.sign(event.delta) * 5
        self.mask_draw_radius += delta

    def compose_masked_image(self):
        np_image = np.array(self.PILimage).astype(np.float32) / 255.0
        np_mask = np.array(self.PILmask).astype(np.float32) / 255.0
        np_mask = np.clip(np_mask, 0.4, 1.0)
        np_masked_image = (np_image * np_mask * 255.0).astype(np.uint8)
        self.image = Image.fromarray(np_masked_image, mode='RGB')

    def display_image(self):
        #resize to fit 600x600 while maintaining aspect ratio
        width, height = self.image.size
        if width > height:
            new_width = 600
            new_height = int(600 * height / width)
        else:
            new_height = 600
            new_width = int(600 * width / height)
        self.image_size = (new_width, new_height)
        self.image = self.image.resize(self.image_size, Image.Resampling.LANCZOS)
        self.image = ctk.CTkImage(self.image, size=self.image_size)
        self.image_label.configure(image=self.image)

    def load_image(self):
        try:
            self.PILimage = Image.open(self.image_list[self.image_index]).convert('RGB')
        except:
            print(f'Error opening image {self.image_list[self.image_index]}')
            print('Logged path to bad_files.txt')
            #if bad_files.txt doesn't exist, create it
            if not os.path.exists('bad_files.txt'):
                with open('bad_files.txt', 'w') as f:
                    f.write(self.image_list[self.image_index] + '\n')
            else:
                with open('bad_files.txt', 'a') as f:
                    f.write(self.image_list[self.image_index] + '\n')
            return

        self.image = self.PILimage.copy()

        try:
            self.PILmask = None
            mask_filename = os.path.splitext(self.image_list[self.image_index])[0] + '-masklabel.png'
            if os.path.exists(mask_filename):
                self.PILmask = Image.open(mask_filename).convert('RGB')
                self.compose_masked_image()
        except Exception as e:
            print(f'Error opening mask for {self.image_list[self.image_index]}')
            print('Logged path to bad_files.txt')
            #if bad_files.txt doesn't exist, create it
            if not os.path.exists('bad_files.txt'):
                with open('bad_files.txt', 'w') as f:
                    f.write(self.image_list[self.image_index] + '\n')
            else:
                with open('bad_files.txt', 'a') as f:
                    f.write(self.image_list[self.image_index] + '\n')
            return

        self.display_image()

        self.caption_file_path = self.image_list[self.image_index]
        self.caption_file_name = os.path.basename(self.caption_file_path)
        self.caption_file_ext = os.path.splitext(self.caption_file_name)[1]
        self.caption_file_name_no_ext = os.path.splitext(self.caption_file_name)[0]
        self.caption_file = os.path.join(self.folder, self.caption_file_name_no_ext + '.txt')
        if os.path.isfile(self.caption_file) and self.auto_generate_caption.get() == False or os.path.isfile(self.caption_file) and self.auto_generate_caption.get() == True and self.auto_generate_caption_text_override.get() == True:
            with open(self.caption_file, 'r') as f:
                self.caption = f.read()
            self.caption_entry.delete(0, tk.END)
            self.caption_entry.insert(0, self.caption)
            self.caption_entry.configure(fg_color=ThemeManager.theme["CTkEntry"]["fg_color"])
            self.use_blip = False
        elif os.path.isfile(self.caption_file) and self.auto_generate_caption.get() == True and self.auto_generate_caption_text_override.get() == False or os.path.isfile(self.caption_file) == False and self.auto_generate_caption.get() == True and self.auto_generate_caption_text_override.get() == True:
            self.use_blip = True
            self.caption_entry.delete(0, tk.END)
        elif os.path.isfile(self.caption_file) == False and self.auto_generate_caption.get() == False:
            self.caption_entry.delete(0, tk.END)
            return
        if self.use_blip and self.debug == False:
            tensor = transforms.Compose([
                transforms.Resize((self.blipSize, self.blipSize), interpolation=InterpolationMode.BICUBIC),
                transforms.ToTensor(),
                transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
            ])
            torch_image = tensor(self.PILimage).unsqueeze(0).to(torch.device("cuda"))
            if self.nucleus_sampling:
                captions = self.blip_decoder.generate(torch_image, sample=True, top_p=self.q_factor)
            else:
                captions = self.blip_decoder.generate(torch_image, sample=False, num_beams=16, min_length=self.min_length,
                                                      max_length=48, repetition_penalty=self.q_factor)
            self.caption = captions[0]
            self.caption_entry.delete(0, tk.END)
            self.caption_entry.insert(0, self.caption)
            #change the caption entry color to red
            self.caption_entry.configure(fg_color='red')

    def save(self, event):
        self.save_caption()

        if self.enable_mask_editing.get():
            self.save_mask()

    def save_mask(self):
        mask_filename = os.path.splitext(self.image_list[self.image_index])[0] + '-masklabel.png'
        if self.PILmask is not None:
            self.PILmask.save(mask_filename)

    def save_caption(self):
        self.caption = self.caption_entry.get()
        self.replace = self.replace_entry.get()
        self.replace_with = self.with_entry.get()
        self.suffix_var = self.suffix_entry.get()
        self.prefix = self.prefix_entry.get()
        #prepare the caption
        self.caption = self.caption.replace(self.replace, self.replace_with)
        if self.suffix_var.startswith(',') or self.suffix_var.startswith(' '):
            self.suffix_var = self.suffix_var
        else:
            self.suffix_var = ' ' + self.suffix_var
        if self.prefix != '':
            if self.prefix.endswith(' '):
                self.prefix = self.prefix[:-1]
            if not self.prefix.endswith(','):
                self.prefix = self.prefix + ','
            self.caption = self.prefix + ' ' + self.caption
        if self.caption.endswith(',') or self.caption.endswith('.'):
            self.caption = self.caption[:-1]
            self.caption = self.caption + ', ' + self.suffix_var
        else:
            self.caption = self.caption + self.suffix_var
        self.caption = self.caption.strip()
        if self.output_folder != self.folder:
            outputFolder = self.output_folder
        else:
            outputFolder = self.folder
        if self.output_format == 'text':
            #text file with same name as image
            #image name
            #print('test')
            imgName = os.path.basename(self.image_list[self.image_index])
            imgName = imgName[:imgName.rfind('.')]
            self.caption_file = os.path.join(outputFolder, imgName + '.txt')
            with open(self.caption_file, 'w') as f:
                f.write(self.caption)
        elif self.output_format == 'filename':
            #duplicate image with caption as file name
            #make sure self.caption doesn't contain any illegal characters
            illegal_chars = ['/', '\\', ':', '*', '?', '"', "'", '<', '>', '|', '.']
            for char in illegal_chars:
                self.caption = self.caption.replace(char, '')
            self.PILimage.save(os.path.join(outputFolder, self.caption + '.png'))
        self.caption_entry.delete(0, tk.END)
        self.caption_entry.insert(0, self.caption)
        self.caption_entry.configure(fg_color='green')

        self.caption_entry.focus_force()

    def delete_word(self, event):
        ent = event.widget
        end_idx = ent.index(tk.INSERT)
        start_idx = ent.get().rfind(" ", None, end_idx)
        ent.selection_range(start_idx, end_idx)

    def prev_image(self, event):
        if self.image_index > 0:
            self.image_index -= 1
            self.image_count_label.configure(text=f'Image {self.image_index+1} of {self.image_count}')
            self.load_image()
            self.caption_entry.focus_set()
            self.caption_entry.focus_force()

    def next_image(self, event):
        if self.image_index < len(self.image_list) - 1:
            self.image_index += 1
            self.image_count_label.configure(text=f'Image {self.image_index+1} of {self.image_count}')
            self.load_image()
            self.caption_entry.focus_set()
            self.caption_entry.focus_force()

    def open_options(self):
        self.options_window = ctk.CTkToplevel(self)
        self.options_window.title("Options")
        self.options_window.geometry("320x550")
        #disable resize
        self.options_window.resizable(False, False)
        self.options_window.focus_force()
        self.options_window.grab_set()
        self.options_window.transient(self)
        self.options_window.protocol("WM_DELETE_WINDOW", self.close_options)
        #add title label
        self.options_title_label = ctk.CTkLabel(self.options_window, text="Options", font=ctk.CTkFont(size=20, weight="bold"))
        self.options_title_label.pack(side="top", pady=5)
        #add an entry with a button to select a folder as output folder
        self.output_folder_label = ctk.CTkLabel(self.options_window, text="Output Folder")
        self.output_folder_label.pack(side="top", pady=5)
        self.output_folder_entry = ctk.CTkEntry(self.options_window)
        self.output_folder_entry.pack(side="top", fill="x", expand=False, padx=15, pady=5)
        self.output_folder_entry.insert(0, self.output_folder)
        self.output_folder_button = ctk.CTkButton(self.options_window, text="Select Folder", command=self.select_output_folder, fg_color=("gray75", "gray25"))
        self.output_folder_button.pack(side="top", pady=5)
        #add radio buttons to select the output format between text and filename
        self.output_format_label = ctk.CTkLabel(self.options_window, text="Output Format")
        self.output_format_label.pack(side="top", pady=5)
        self.output_format_var = tk.StringVar(self.options_window)
        self.output_format_var.set(self.output_format)
        self.output_format_text = ctk.CTkRadioButton(self.options_window, text="Text File", variable=self.output_format_var, value="text")
        self.output_format_text.pack(side="top", pady=5)
        self.output_format_filename = ctk.CTkRadioButton(self.options_window, text="File name", variable=self.output_format_var, value="filename")
        self.output_format_filename.pack(side="top", pady=5)
        #add BLIP settings section
        self.blip_settings_label = ctk.CTkLabel(self.options_window, text="BLIP Settings", font=ctk.CTkFont(size=20, weight="bold"))
        self.blip_settings_label.pack(side="top", pady=10)
        #add a checkbox to use nucleus sampling or not
        self.nucleus_sampling_var = tk.IntVar(self.options_window)
        self.nucleus_sampling_checkbox = ctk.CTkCheckBox(self.options_window, text="Use nucleus sampling", variable=self.nucleus_sampling_var)
        self.nucleus_sampling_checkbox.pack(side="top", pady=5)
        if self.debug:
            self.nucleus_sampling = 0
            self.q_factor = 0.5
            self.min_length = 10
        self.nucleus_sampling_var.set(self.nucleus_sampling)
        #add a float entry to set the q factor
        self.q_factor_label = ctk.CTkLabel(self.options_window, text="Q Factor")
        self.q_factor_label.pack(side="top", pady=5)
        self.q_factor_entry = ctk.CTkEntry(self.options_window)
        self.q_factor_entry.insert(0, self.q_factor)
        self.q_factor_entry.pack(side="top", pady=5)
        #add an int entry to set the minimum length
        self.min_length_label = ctk.CTkLabel(self.options_window, text="Minimum Length")
        self.min_length_label.pack(side="top", pady=5)
        self.min_length_entry = ctk.CTkEntry(self.options_window)
        self.min_length_entry.insert(0, self.min_length)
        self.min_length_entry.pack(side="top", pady=5)
        #add a horizontal radio button to select between None, ViT-L-14/openai, ViT-H-14/laion2b_s32b_b79k
        #self.model_label = ctk.CTkLabel(self.options_window, text="CLIP Interrogation")
        #self.model_label.pack(side="top")
        #self.model_var = tk.StringVar(self.options_window)
        #self.model_var.set(self.model)
        #self.model_none = tk.Radiobutton(self.options_window, text="None", variable=self.model_var, value="None")
        #self.model_none.pack(side="top")
        #self.model_vit_l_14 = tk.Radiobutton(self.options_window, text="ViT-L-14/openai", variable=self.model_var, value="ViT-L-14/openai")
        #self.model_vit_l_14.pack(side="top")
        #self.model_vit_h_14 = tk.Radiobutton(self.options_window, text="ViT-H-14/laion2b_s32b_b79k", variable=self.model_var, value="ViT-H-14/laion2b_s32b_b79k")
        #self.model_vit_h_14.pack(side="top")

        #add a save button
        self.save_button = ctk.CTkButton(self.options_window, text="Save", command=self.save_options, fg_color=("gray75", "gray25"))
        self.save_button.pack(side="top", fill='x', pady=10, padx=10)
        #all entries list
        entries = [self.output_folder_entry, self.q_factor_entry, self.min_length_entry]
        #bind the right click to all entries
        for entry in entries:
            entry.bind("<Button-3>", self.create_right_click_menu)
        self.options_file = os.path.join(self.captioner_folder, 'captioner_options.json')
        if os.path.isfile(self.options_file):
            with open(self.options_file, 'r') as f:
                self.options = json.load(f)
            self.output_folder_entry.delete(0, tk.END)
            self.output_folder_entry.insert(0, self.output_folder)
            self.output_format_var.set(self.options['output_format'])
            self.nucleus_sampling_var.set(self.options['nucleus_sampling'])
            self.q_factor_entry.delete(0, tk.END)
            self.q_factor_entry.insert(0, self.options['q_factor'])
            self.min_length_entry.delete(0, tk.END)
            self.min_length_entry.insert(0, self.options['min_length'])

    def load_options(self):
        self.options_file = os.path.join(self.captioner_folder, 'captioner_options.json')
        if os.path.isfile(self.options_file):
            with open(self.options_file, 'r') as f:
                self.options = json.load(f)
            #self.output_folder = self.folder
            #self.output_folder = self.options['output_folder']
            if 'folder' in self.__dict__:
                self.output_folder = self.folder
            else:
                self.output_folder = ''
            self.output_format = self.options['output_format']
            self.nucleus_sampling = self.options['nucleus_sampling']
            self.q_factor = self.options['q_factor']
            self.min_length = self.options['min_length']
        else:
            #if self has folder, use it, otherwise use the current folder
            if 'folder' in self.__dict__:
                self.output_folder = self.folder
            else:
                self.output_folder = ''
            self.output_format = "text"
            self.nucleus_sampling = False
            self.q_factor = 0.9
            self.min_length = 22

    def save_options(self):
        self.output_folder = self.output_folder_entry.get()
        self.output_format = self.output_format_var.get()
        self.nucleus_sampling = self.nucleus_sampling_var.get()
        self.q_factor = float(self.q_factor_entry.get())
        self.min_length = int(self.min_length_entry.get())
        #save options to a file
        self.options_file = os.path.join(self.captioner_folder, 'captioner_options.json')
        with open(self.options_file, 'w') as f:
            json.dump({'output_folder': self.output_folder, 'output_format': self.output_format, 'nucleus_sampling': self.nucleus_sampling, 'q_factor': self.q_factor, 'min_length': self.min_length}, f)
        self.close_options()

    def select_output_folder(self):
        self.output_folder = fd.askdirectory()
        self.output_folder_entry.delete(0, tk.END)
        self.output_folder_entry.insert(0, self.output_folder)

    def close_options(self):
        self.options_window.destroy()
        self.caption_entry.focus_force()

    def create_right_click_menu(self, event):
        #create a menu
        self.menu = Menu(self, tearoff=0)
        #add commands to the menu
        self.menu.add_command(label="Cut", command=lambda: self.focus_get().event_generate("<<Cut>>"))
        self.menu.add_command(label="Copy", command=lambda: self.focus_get().event_generate("<<Copy>>"))
        self.menu.add_command(label="Paste", command=lambda: self.focus_get().event_generate("<<Paste>>"))
        self.menu.add_command(label="Select All", command=lambda: self.focus_get().event_generate("<<SelectAll>>"))
        #display the menu
        try:
            self.menu.tk_popup(event.x_root, event.y_root)
        finally:
            #make sure to release the grab (Tk 8.0a1 only)
            self.menu.grab_release()


#progress bar class with cancel button
class ProgressbarWithCancel(ctk.CTkToplevel):
    def __init__(self, max=None, **kw):
        super().__init__(**kw)
        self.title("Batching...")
        self.max = max
        self.possibleLabels = ['Searching for answers...', "I'm working, I promise.", 'ARE THOSE TENTACLES?!', 'Weird data man...', 'Another one bites the dust', "I think it's a cat?", 'Looking for the meaning of life', 'Dreaming of captions']

        self.label = ctk.CTkLabel(self, text="Searching for answers...")
        self.label.pack(side="top", fill="x", expand=True, padx=10, pady=10)
        self.progress = ctk.CTkProgressBar(self, orientation="horizontal", mode="determinate")
        self.progress.pack(side="left", fill="x", expand=True, padx=10, pady=10)
        self.cancel_button = ctk.CTkButton(self, text="Cancel", command=self.cancel)
        self.cancel_button.pack(side="right", padx=10, pady=10)
        self.cancelled = False
        self.count_label = ctk.CTkLabel(self, text="0/{0}".format(self.max))
        self.count_label.pack(side="right", padx=10, pady=10)

    def set_random_label(self):
        self.label.configure(text=random.choice(self.possibleLabels))
        #pop from list
        #self.possibleLabels.remove(self.label.cget("text"))

    def cancel(self):
        self.cancelled = True

    def set_progress(self, value):
        self.progress.set(value)
        self.count_label.configure(text="{0}/{1}".format(int(value * self.max), self.max))

    def get_progress(self):
        return self.progress.get()

    def set_max(self, value):
        self.max = value

    def get_max(self):
        return self.max

    def is_cancelled(self):
        return self.cancelled
    #quit the progress bar window


#run when executed directly
if __name__ == "__main__":

    #root = tk.Tk()
    app = ImageBrowser()
    app.mainloop()
StableTuner_RunPod_Fix/clip_segmentation.py
ADDED
@@ -0,0 +1,325 @@
import argparse
import os
from typing import Optional, Callable

import torch
from PIL import Image
from torch import Tensor, nn
from torchvision.transforms import transforms, functional
from tqdm.auto import tqdm
from transformers import CLIPSegProcessor, CLIPSegForImageSegmentation

DEVICE = "cuda"


def parse_args():
    parser = argparse.ArgumentParser(description="ClipSeg script.")
    parser.add_argument(
        "--sample_dir",
        type=str,
        required=True,
        help="directory where samples are located",
    )
    parser.add_argument(
        "--add_prompt",
        type=str,
        required=True,
        action="append",
        help="a prompt used to create a mask",
        dest="prompts",
    )
    parser.add_argument(
        "--mode",
        type=str,
        default='fill',
        required=False,
        help="Either replace, fill, add or subtract",
    )
    parser.add_argument(
        "--threshold",
        type=float,
        default=0.3,
        required=False,
        help="threshold for including pixels in the mask",
    )
    parser.add_argument(
        "--smooth_pixels",
        type=int,
        default=5,
        required=False,
        help="radius of a smoothing operation applied to the generated mask",
    )
    parser.add_argument(
        "--expand_pixels",
        type=int,
        default=10,
        required=False,
        help="amount of expansion of the generated mask in all directions",
    )

    args = parser.parse_args()
    return args


class MaskSample:
    def __init__(self, filename: str):
        self.image_filename = filename
        self.mask_filename = os.path.splitext(filename)[0] + "-masklabel.png"

        self.image = None
        self.mask_tensor = None

        self.height = 0
        self.width = 0

        self.image2Tensor = transforms.Compose([
            transforms.ToTensor(),
        ])

        self.tensor2Image = transforms.Compose([
            transforms.ToPILImage(),
        ])

    def get_image(self) -> Image:
        if self.image is None:
            self.image = Image.open(self.image_filename).convert('RGB')
            self.height = self.image.height
            self.width = self.image.width

        return self.image

    def get_mask_tensor(self) -> Tensor:
        if self.mask_tensor is None and os.path.exists(self.mask_filename):
            mask = Image.open(self.mask_filename).convert('L')
            mask = self.image2Tensor(mask)
            mask = mask.to(DEVICE)
            self.mask_tensor = mask.unsqueeze(0)

        return self.mask_tensor

    def set_mask_tensor(self, mask_tensor: Tensor):
        self.mask_tensor = mask_tensor

    def add_mask_tensor(self, mask_tensor: Tensor):
        mask = self.get_mask_tensor()
        if mask is None:
            mask = mask_tensor
        else:
            mask += mask_tensor
        mask = torch.clamp(mask, 0, 1)

        self.mask_tensor = mask

    def subtract_mask_tensor(self, mask_tensor: Tensor):
        mask = self.get_mask_tensor()
        if mask is None:
            mask = mask_tensor
        else:
            mask -= mask_tensor
        mask = torch.clamp(mask, 0, 1)

        self.mask_tensor = mask

    def save_mask(self):
        if self.mask_tensor is not None:
            mask = self.mask_tensor.cpu().squeeze()
            mask = self.tensor2Image(mask).convert('RGB')
            mask.save(self.mask_filename)


class ClipSeg:
    def __init__(self):
        self.processor = CLIPSegProcessor.from_pretrained("CIDAS/clipseg-rd64-refined")

        self.model = CLIPSegForImageSegmentation.from_pretrained("CIDAS/clipseg-rd64-refined")
        self.model.eval()
        self.model.to(DEVICE)

        self.smoothing_kernel_radius = None
        self.smoothing_kernel = self.__create_average_kernel(self.smoothing_kernel_radius)

        self.expand_kernel_radius = None
        self.expand_kernel = self.__create_average_kernel(self.expand_kernel_radius)

    @staticmethod
    def __create_average_kernel(kernel_radius: Optional[int]):
        if kernel_radius is None:
            return None

        kernel_size = kernel_radius * 2 + 1
        kernel_weights = torch.ones(1, 1, kernel_size, kernel_size) / (kernel_size * kernel_size)
        kernel = nn.Conv2d(in_channels=1, out_channels=1, kernel_size=kernel_size, bias=False, padding_mode='replicate', padding=kernel_radius)
        kernel.weight.data = kernel_weights
        kernel.requires_grad_(False)
        kernel.to(DEVICE)
        return kernel

    @staticmethod
    def __get_sample_filenames(sample_dir: str) -> [str]:
        filenames = []
        for filename in os.listdir(sample_dir):
for filename in os.listdir(sample_dir):
|
161 |
+
ext = os.path.splitext(filename)[1].lower()
|
162 |
+
if ext in ['.jpg', '.jpeg', '.png', '.bmp', '.webp'] and '-masklabel.png' not in filename:
|
163 |
+
filenames.append(os.path.join(sample_dir, filename))
|
164 |
+
|
165 |
+
return filenames
|
166 |
+
|
167 |
+
def __process_mask(self, mask: Tensor, target_height: int, target_width: int, threshold: float) -> Tensor:
|
168 |
+
while len(mask.shape) < 4:
|
169 |
+
mask = mask.unsqueeze(0)
|
170 |
+
|
171 |
+
mask = torch.sigmoid(mask)
|
172 |
+
mask = mask.sum(1).unsqueeze(1)
|
173 |
+
if self.smoothing_kernel is not None:
|
174 |
+
mask = self.smoothing_kernel(mask)
|
175 |
+
mask = functional.resize(mask, [target_height, target_width])
|
176 |
+
mask = (mask > threshold).float()
|
177 |
+
if self.expand_kernel is not None:
|
178 |
+
mask = self.expand_kernel(mask)
|
179 |
+
mask = (mask > 0).float()
|
180 |
+
|
181 |
+
return mask
|
182 |
+
|
183 |
+
def mask_image(self, filename: str, prompts: [str], mode: str = 'fill', threshold: float = 0.3, smooth_pixels: int = 5, expand_pixels: int = 10):
|
184 |
+
"""
|
185 |
+
Masks a sample
|
186 |
+
|
187 |
+
Parameters:
|
188 |
+
filename (`str`): a sample filename
|
189 |
+
prompts (`[str]`): a list of prompts used to create a mask
|
190 |
+
mode (`str`): can be one of
|
191 |
+
- replace: creates new masks for all samples, even if a mask already exists
|
192 |
+
- fill: creates new masks for all samples without a mask
|
193 |
+
- add: adds the new region to existing masks
|
194 |
+
- subtract: subtracts the new region from existing masks
|
195 |
+
threshold (`float`): threshold for including pixels in the mask
|
196 |
+
smooth_pixels (`int`): radius of a smoothing operation applied to the generated mask
|
197 |
+
expand_pixels (`int`): amount of expansion of the generated mask in all directions
|
198 |
+
"""
|
199 |
+
|
200 |
+
mask_sample = MaskSample(filename)
|
201 |
+
|
202 |
+
if mode == 'fill' and mask_sample.get_mask_tensor() is not None:
|
203 |
+
return
|
204 |
+
|
205 |
+
if self.smoothing_kernel_radius != smooth_pixels:
|
206 |
+
self.smoothing_kernel = self.__create_average_kernel(smooth_pixels)
|
207 |
+
self.smoothing_kernel_radius = smooth_pixels
|
208 |
+
|
209 |
+
if self.expand_kernel_radius != expand_pixels:
|
210 |
+
self.expand_kernel = self.__create_average_kernel(expand_pixels)
|
211 |
+
self.expand_kernel_radius = expand_pixels
|
212 |
+
|
213 |
+
inputs = self.processor(text=prompts, images=[mask_sample.get_image()] * len(prompts), padding="max_length", return_tensors="pt")
|
214 |
+
inputs.to(DEVICE)
|
215 |
+
with torch.no_grad():
|
216 |
+
outputs = self.model(**inputs)
|
217 |
+
predicted_mask = self.__process_mask(outputs.logits, mask_sample.height, mask_sample.width, threshold)
|
218 |
+
|
219 |
+
if mode == 'replace' or mode == 'fill':
|
220 |
+
mask_sample.set_mask_tensor(predicted_mask)
|
221 |
+
elif mode == 'add':
|
222 |
+
mask_sample.add_mask_tensor(predicted_mask)
|
223 |
+
elif mode == 'subtract':
|
224 |
+
mask_sample.subtract_mask_tensor(predicted_mask)
|
225 |
+
|
226 |
+
mask_sample.save_mask()
|
227 |
+
|
228 |
+
def mask_folder(
|
229 |
+
self,
|
230 |
+
sample_dir: str,
|
231 |
+
prompts: [str],
|
232 |
+
mode: str = 'fill',
|
233 |
+
threshold: float = 0.3,
|
234 |
+
smooth_pixels: int = 5,
|
235 |
+
expand_pixels: int = 10,
|
236 |
+
progress_callback: Callable[[int, int], None] = None,
|
237 |
+
error_callback: Callable[[str], None] = None,
|
238 |
+
):
|
239 |
+
"""
|
240 |
+
Masks all samples in a folder
|
241 |
+
|
242 |
+
Parameters:
|
243 |
+
sample_dir (`str`): directory where samples are located
|
244 |
+
prompts (`[str]`): a list of prompts used to create a mask
|
245 |
+
mode (`str`): can be one of
|
246 |
+
- replace: creates new masks for all samples, even if a mask already exists
|
247 |
+
- fill: creates new masks for all samples without a mask
|
248 |
+
- add: adds the new region to existing masks
|
249 |
+
- subtract: subtracts the new region from existing masks
|
250 |
+
threshold (`float`): threshold for including pixels in the mask
|
251 |
+
smooth_pixels (`int`): radius of a smoothing operation applied to the generated mask
|
252 |
+
expand_pixels (`int`): amount of expansion of the generated mask in all directions
|
253 |
+
progress_callback (`Callable[[int, int], None]`): called after every processed image
|
254 |
+
error_callback (`Callable[[str], None]`): called for every exception
|
255 |
+
"""
|
256 |
+
|
257 |
+
filenames = self.__get_sample_filenames(sample_dir)
|
258 |
+
self.mask_images(
|
259 |
+
filenames=filenames,
|
260 |
+
prompts=prompts,
|
261 |
+
mode=mode,
|
262 |
+
threshold=threshold,
|
263 |
+
smooth_pixels=smooth_pixels,
|
264 |
+
expand_pixels=expand_pixels,
|
265 |
+
progress_callback=progress_callback,
|
266 |
+
error_callback=error_callback,
|
267 |
+
)
|
268 |
+
|
269 |
+
def mask_images(
|
270 |
+
self,
|
271 |
+
filenames: [str],
|
272 |
+
prompts: [str],
|
273 |
+
mode: str = 'fill',
|
274 |
+
threshold: float = 0.3,
|
275 |
+
smooth_pixels: int = 5,
|
276 |
+
expand_pixels: int = 10,
|
277 |
+
progress_callback: Callable[[int, int], None] = None,
|
278 |
+
error_callback: Callable[[str], None] = None,
|
279 |
+
):
|
280 |
+
"""
|
281 |
+
Masks all samples in a list
|
282 |
+
|
283 |
+
Parameters:
|
284 |
+
filenames (`[str]`): a list of sample filenames
|
285 |
+
prompts (`[str]`): a list of prompts used to create a mask
|
286 |
+
mode (`str`): can be one of
|
287 |
+
- replace: creates new masks for all samples, even if a mask already exists
|
288 |
+
- fill: creates new masks for all samples without a mask
|
289 |
+
- add: adds the new region to existing masks
|
290 |
+
- subtract: subtracts the new region from existing masks
|
291 |
+
threshold (`float`): threshold for including pixels in the mask
|
292 |
+
smooth_pixels (`int`): radius of a smoothing operation applied to the generated mask
|
293 |
+
expand_pixels (`int`): amount of expansion of the generated mask in all directions
|
294 |
+
progress_callback (`Callable[[int, int], None]`): called after every processed image
|
295 |
+
error_callback (`Callable[[str], None]`): called for every exception
|
296 |
+
"""
|
297 |
+
|
298 |
+
if progress_callback is not None:
|
299 |
+
progress_callback(0, len(filenames))
|
300 |
+
for i, filename in enumerate(tqdm(filenames)):
|
301 |
+
try:
|
302 |
+
self.mask_image(filename, prompts, mode, threshold, smooth_pixels, expand_pixels)
|
303 |
+
except Exception as e:
|
304 |
+
if error_callback is not None:
|
305 |
+
error_callback(filename)
|
306 |
+
if progress_callback is not None:
|
307 |
+
progress_callback(i + 1, len(filenames))
|
308 |
+
|
309 |
+
|
310 |
+
def main():
|
311 |
+
args = parse_args()
|
312 |
+
clip_seg = ClipSeg()
|
313 |
+
clip_seg.mask_folder(
|
314 |
+
sample_dir=args.sample_dir,
|
315 |
+
prompts=args.prompts,
|
316 |
+
mode=args.mode,
|
317 |
+
threshold=args.threshold,
|
318 |
+
smooth_pixels=args.smooth_pixels,
|
319 |
+
expand_pixels=args.expand_pixels,
|
320 |
+
error_callback=lambda filename: print("Error while processing image " + filename)
|
321 |
+
)
|
322 |
+
|
323 |
+
|
324 |
+
if __name__ == "__main__":
|
325 |
+
main()
|
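A short usage sketch for the script above, both from the command line (arguments come from parse_args; --add_prompt can be repeated) and through the ClipSeg class directly. The directory and prompt values are placeholders:

# Command line:
#   python clip_segmentation.py --sample_dir /path/to/samples --add_prompt "a person" --mode fill --threshold 0.3

# The same through the Python API; each mask is written next to its image
# as <name>-masklabel.png by MaskSample.save_mask().
from clip_segmentation import ClipSeg

clip_seg = ClipSeg()
clip_seg.mask_folder(
    sample_dir="/path/to/samples",   # placeholder directory
    prompts=["a person"],            # one entry per --add_prompt
    mode="fill",                     # replace, fill, add or subtract
    threshold=0.3,
    smooth_pixels=5,
    expand_pixels=10,
)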
StableTuner_RunPod_Fix/configuration_gui.py
ADDED
The diff for this file is too large to render.
See raw diff
StableTuner_RunPod_Fix/convert_diffusers_to_sd_cli.py
ADDED
@@ -0,0 +1,22 @@
import sys
import os
try:
    import converters
except ImportError:

    #if there's a scripts folder where the script is, add it to the path
    if 'scripts' in os.listdir(os.path.dirname(os.path.abspath(__file__))):
        sys.path.append(os.path.dirname(os.path.abspath(__file__)) + '\\scripts')
    else:
        print('Could not find scripts folder. Please add it to the path manually or place this file in it.')
    import converters


if __name__ == '__main__':
    args = sys.argv[1:]
    if len(args) != 2:
        print('Usage: python3 convert_diffusers_to_sd.py <model_path> <output_path>')
        sys.exit(1)
    model_path = args[0]
    output_path = args[1]
    converters.Convert_Diffusers_to_SD(model_path, output_path)
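As the usage message above indicates, the wrapper expects exactly two positional arguments. A minimal sketch of both ways to run the conversion, with placeholder paths:

# From a shell:
#   python convert_diffusers_to_sd_cli.py /path/to/diffusers_model /path/to/output/model.ckpt

# Or directly from Python, which is all the wrapper does internally:
import converters

converters.Convert_Diffusers_to_SD("/path/to/diffusers_model", "/path/to/output/model.ckpt")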
StableTuner_RunPod_Fix/converters.py
ADDED
@@ -0,0 +1,120 @@
# coding=utf-8
# Copyright 2022 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import requests
import os
import os.path as osp
import torch
try:
    from omegaconf import OmegaConf
except ImportError:
    raise ImportError(
        "OmegaConf is required to convert the LDM checkpoints. Please install it with `pip install OmegaConf`."
    )

from diffusers import (
    AutoencoderKL,
    DDIMScheduler,
    DPMSolverMultistepScheduler,
    EulerAncestralDiscreteScheduler,
    EulerDiscreteScheduler,
    HeunDiscreteScheduler,
    LDMTextToImagePipeline,
    LMSDiscreteScheduler,
    PNDMScheduler,
    StableDiffusionPipeline,
    UNet2DConditionModel,
    DiffusionPipeline
)
from diffusers.pipelines.latent_diffusion.pipeline_latent_diffusion import LDMBertConfig, LDMBertModel
#from diffusers.pipelines.paint_by_example import PaintByExampleImageEncoder, PaintByExamplePipeline
from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
from transformers import AutoFeatureExtractor, BertTokenizerFast, CLIPTextModel, CLIPTokenizer, CLIPVisionConfig, CLIPTextConfig
import model_util

class Convert_SD_to_Diffusers():

    def __init__(self, checkpoint_path, output_path, prediction_type=None, img_size=None, original_config_file=None, extract_ema=False, num_in_channels=None, pipeline_type=None, scheduler_type=None, sd_version=None, half=None, version=None):
        self.checkpoint_path = checkpoint_path
        self.output_path = output_path
        self.prediction_type = prediction_type
        self.img_size = img_size
        self.original_config_file = original_config_file
        self.extract_ema = extract_ema
        self.num_in_channels = num_in_channels
        self.pipeline_type = pipeline_type
        self.scheduler_type = scheduler_type
        self.sd_version = sd_version
        self.half = half
        self.version = version
        self.main()

    def main(self):
        image_size = self.img_size
        prediction_type = self.prediction_type
        original_config_file = self.original_config_file
        num_in_channels = self.num_in_channels
        scheduler_type = self.scheduler_type
        pipeline_type = self.pipeline_type
        extract_ema = self.extract_ema
        reference_diffusers_model = None
        if self.version == 'v1':
            is_v1 = True
            is_v2 = False
        if self.version == 'v2':
            is_v1 = False
            is_v2 = True
        if is_v2 == True and prediction_type == 'vprediction':
            reference_diffusers_model = 'stabilityai/stable-diffusion-2'
        if is_v2 == True and prediction_type == 'epsilon':
            reference_diffusers_model = 'stabilityai/stable-diffusion-2-base'
        if is_v1 == True and prediction_type == 'epsilon':
            reference_diffusers_model = 'runwayml/stable-diffusion-v1-5'
        dtype = 'fp16' if self.half else None
        v2_model = True if is_v2 else False
        print(f"loading model from: {self.checkpoint_path}")
        #print(v2_model)
        text_encoder, vae, unet = model_util.load_models_from_stable_diffusion_checkpoint(v2_model, self.checkpoint_path)
        print(f"copy scheduler/tokenizer config from: {reference_diffusers_model}")
        model_util.save_diffusers_checkpoint(v2_model, self.output_path, text_encoder, unet, reference_diffusers_model, vae)
        print(f"Diffusers model saved.")


class Convert_Diffusers_to_SD():
    def __init__(self, model_path=None, output_path=None):
        pass
        def main(model_path: str, output_path: str):
            #print(model_path)
            #print(output_path)
            global_step = None
            epoch = None
            dtype = torch.float32
            pipe = DiffusionPipeline.from_pretrained(model_path, torch_dtype=dtype, tokenizer=None, safety_checker=None)
            text_encoder = pipe.text_encoder
            vae = pipe.vae
            if os.path.exists(os.path.join(model_path, "ema_unet")):
                pipe.unet = UNet2DConditionModel.from_pretrained(
                    model_path,
                    subfolder="ema_unet",
                    torch_dtype=dtype
                )
            unet = pipe.unet
            v2_model = unet.config.cross_attention_dim == 1024
            original_model = None
            key_count = model_util.save_stable_diffusion_checkpoint(v2_model, output_path, text_encoder, unet,
                                                                    original_model, epoch, global_step, dtype, vae)
            print(f"Saved model")
        return main(model_path, output_path)
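For the opposite direction, Convert_SD_to_Diffusers runs the conversion from its constructor (it calls self.main() at the end of __init__). A minimal sketch with placeholder paths; prediction_type and version select which reference repo the scheduler/tokenizer configs are copied from:

from converters import Convert_SD_to_Diffusers

Convert_SD_to_Diffusers(
    checkpoint_path="/path/to/model.ckpt",   # placeholder
    output_path="/path/to/diffusers_out",    # placeholder
    prediction_type="epsilon",               # 'epsilon' (v1 / v2-base) or 'vprediction' (v2-768)
    version="v1",                            # 'v1' or 'v2'
    half=False,                              # mapped to a 'fp16' dtype flag by the converter
)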
StableTuner_RunPod_Fix/dataloaders_util.py
ADDED
@@ -0,0 +1,1331 @@
import random
import math
import os
import torch
import torch.utils.checkpoint
from torch.utils.data import Dataset
from torchvision import transforms
from tqdm.auto import tqdm
import numpy as np
from PIL import Image
from trainer_util import *
from clip_segmentation import ClipSeg

class bcolors:
    HEADER = '\033[95m'
    OKBLUE = '\033[94m'
    OKCYAN = '\033[96m'
    OKGREEN = '\033[92m'
    WARNING = '\033[93m'
    FAIL = '\033[91m'
    ENDC = '\033[0m'
    BOLD = '\033[1m'
    UNDERLINE = '\033[4m'

ASPECT_2048 = [[2048, 2048], [2112, 1984], [1984, 2112], [2176, 1920], [1920, 2176], [2240, 1856], [1856, 2240], [2304, 1792], [1792, 2304], [2368, 1728], [1728, 2368], [2432, 1664], [1664, 2432], [2496, 1600], [1600, 2496], [2560, 1536], [1536, 2560], [2624, 1472], [1472, 2624]]
ASPECT_1984 = [[1984, 1984], [2048, 1920], [1920, 2048], [2112, 1856], [1856, 2112], [2176, 1792], [1792, 2176], [2240, 1728], [1728, 2240], [2304, 1664], [1664, 2304], [2368, 1600], [1600, 2368], [2432, 1536], [1536, 2432], [2496, 1472], [1472, 2496], [2560, 1408], [1408, 2560]]
ASPECT_1920 = [[1920, 1920], [1984, 1856], [1856, 1984], [2048, 1792], [1792, 2048], [2112, 1728], [1728, 2112], [2176, 1664], [1664, 2176], [2240, 1600], [1600, 2240], [2304, 1536], [1536, 2304], [2368, 1472], [1472, 2368], [2432, 1408], [1408, 2432], [2496, 1344], [1344, 2496]]
ASPECT_1856 = [[1856, 1856], [1920, 1792], [1792, 1920], [1984, 1728], [1728, 1984], [2048, 1664], [1664, 2048], [2112, 1600], [1600, 2112], [2176, 1536], [1536, 2176], [2240, 1472], [1472, 2240], [2304, 1408], [1408, 2304], [2368, 1344], [1344, 2368], [2432, 1280], [1280, 2432]]
ASPECT_1792 = [[1792, 1792], [1856, 1728], [1728, 1856], [1920, 1664], [1664, 1920], [1984, 1600], [1600, 1984], [2048, 1536], [1536, 2048], [2112, 1472], [1472, 2112], [2176, 1408], [1408, 2176], [2240, 1344], [1344, 2240], [2304, 1280], [1280, 2304], [2368, 1216], [1216, 2368]]
ASPECT_1728 = [[1728, 1728], [1792, 1664], [1664, 1792], [1856, 1600], [1600, 1856], [1920, 1536], [1536, 1920], [1984, 1472], [1472, 1984], [2048, 1408], [1408, 2048], [2112, 1344], [1344, 2112], [2176, 1280], [1280, 2176], [2240, 1216], [1216, 2240], [2304, 1152], [1152, 2304]]
ASPECT_1664 = [[1664, 1664], [1728, 1600], [1600, 1728], [1792, 1536], [1536, 1792], [1856, 1472], [1472, 1856], [1920, 1408], [1408, 1920], [1984, 1344], [1344, 1984], [2048, 1280], [1280, 2048], [2112, 1216], [1216, 2112], [2176, 1152], [1152, 2176], [2240, 1088], [1088, 2240]]
ASPECT_1600 = [[1600, 1600], [1664, 1536], [1536, 1664], [1728, 1472], [1472, 1728], [1792, 1408], [1408, 1792], [1856, 1344], [1344, 1856], [1920, 1280], [1280, 1920], [1984, 1216], [1216, 1984], [2048, 1152], [1152, 2048], [2112, 1088], [1088, 2112], [2176, 1024], [1024, 2176]]
ASPECT_1536 = [[1536, 1536], [1600, 1472], [1472, 1600], [1664, 1408], [1408, 1664], [1728, 1344], [1344, 1728], [1792, 1280], [1280, 1792], [1856, 1216], [1216, 1856], [1920, 1152], [1152, 1920], [1984, 1088], [1088, 1984], [2048, 1024], [1024, 2048], [2112, 960], [960, 2112]]
ASPECT_1472 = [[1472, 1472], [1536, 1408], [1408, 1536], [1600, 1344], [1344, 1600], [1664, 1280], [1280, 1664], [1728, 1216], [1216, 1728], [1792, 1152], [1152, 1792], [1856, 1088], [1088, 1856], [1920, 1024], [1024, 1920], [1984, 960], [960, 1984], [2048, 896], [896, 2048]]
ASPECT_1408 = [[1408, 1408], [1472, 1344], [1344, 1472], [1536, 1280], [1280, 1536], [1600, 1216], [1216, 1600], [1664, 1152], [1152, 1664], [1728, 1088], [1088, 1728], [1792, 1024], [1024, 1792], [1856, 960], [960, 1856], [1920, 896], [896, 1920], [1984, 832], [832, 1984]]
ASPECT_1344 = [[1344, 1344], [1408, 1280], [1280, 1408], [1472, 1216], [1216, 1472], [1536, 1152], [1152, 1536], [1600, 1088], [1088, 1600], [1664, 1024], [1024, 1664], [1728, 960], [960, 1728], [1792, 896], [896, 1792], [1856, 832], [832, 1856], [1920, 768], [768, 1920]]
ASPECT_1280 = [[1280, 1280], [1344, 1216], [1216, 1344], [1408, 1152], [1152, 1408], [1472, 1088], [1088, 1472], [1536, 1024], [1024, 1536], [1600, 960], [960, 1600], [1664, 896], [896, 1664], [1728, 832], [832, 1728], [1792, 768], [768, 1792], [1856, 704], [704, 1856]]
ASPECT_1216 = [[1216, 1216], [1280, 1152], [1152, 1280], [1344, 1088], [1088, 1344], [1408, 1024], [1024, 1408], [1472, 960], [960, 1472], [1536, 896], [896, 1536], [1600, 832], [832, 1600], [1664, 768], [768, 1664], [1728, 704], [704, 1728], [1792, 640], [640, 1792]]
ASPECT_1152 = [[1152, 1152], [1216, 1088], [1088, 1216], [1280, 1024], [1024, 1280], [1344, 960], [960, 1344], [1408, 896], [896, 1408], [1472, 832], [832, 1472], [1536, 768], [768, 1536], [1600, 704], [704, 1600], [1664, 640], [640, 1664], [1728, 576], [576, 1728]]
ASPECT_1088 = [[1088, 1088], [1152, 1024], [1024, 1152], [1216, 960], [960, 1216], [1280, 896], [896, 1280], [1344, 832], [832, 1344], [1408, 768], [768, 1408], [1472, 704], [704, 1472], [1536, 640], [640, 1536], [1600, 576], [576, 1600], [1664, 512], [512, 1664]]
ASPECT_832 = [[832, 832], [896, 768], [768, 896], [960, 704], [704, 960], [1024, 640], [640, 1024], [1152, 576], [576, 1152], [1280, 512], [512, 1280], [1344, 512], [512, 1344], [1408, 448], [448, 1408], [1472, 448], [448, 1472], [1536, 384], [384, 1536], [1600, 384], [384, 1600]]

ASPECT_896 = [[896, 896], [960, 832], [832, 960], [1024, 768], [768, 1024], [1088, 704], [704, 1088], [1152, 704], [704, 1152], [1216, 640], [640, 1216], [1280, 640], [640, 1280], [1344, 576], [576, 1344], [1408, 576], [576, 1408], [1472, 512], [512, 1472], [1536, 512], [512, 1536], [1600, 448], [448, 1600], [1664, 448], [448, 1664]]
ASPECT_960 = [[960, 960], [1024, 896], [896, 1024], [1088, 832], [832, 1088], [1152, 768], [768, 1152], [1216, 704], [704, 1216], [1280, 640], [640, 1280], [1344, 576], [576, 1344], [1408, 512], [512, 1408], [1472, 448], [448, 1472], [1536, 384], [384, 1536]]
ASPECT_1024 = [[1024, 1024], [1088, 960], [960, 1088], [1152, 896], [896, 1152], [1216, 832], [832, 1216], [1344, 768], [768, 1344], [1472, 704], [704, 1472], [1600, 640], [640, 1600], [1728, 576], [576, 1728], [1792, 576], [576, 1792]]

ASPECT_768 = [[768, 768],           # 589824 1:1
    [896, 640], [640, 896],         # 573440 1.4:1
    [832, 704], [704, 832],         # 585728 1.181:1
    [960, 576], [576, 960],         # 552960 1.6:1
    [1024, 576], [576, 1024],       # 524288 1.778:1
    [1088, 512], [512, 1088],       # 497664 2.125:1
    [1152, 512], [512, 1152],       # 589824 2.25:1
    [1216, 448], [448, 1216],       # 552960 2.714:1
    [1280, 448], [448, 1280],       # 573440 2.857:1
    [1344, 384], [384, 1344],       # 518400 3.5:1
    [1408, 384], [384, 1408],       # 540672 3.667:1
    [1472, 320], [320, 1472],       # 470400 4.6:1
    [1536, 320], [320, 1536],       # 491520 4.8:1
    ]

ASPECT_704 = [[704, 704],           # 501,376 1:1
    [768, 640], [640, 768],         # 491,520 1.2:1
    [832, 576], [576, 832],         # 458,752 1.444:1
    [896, 512], [512, 896],         # 458,752 1.75:1
    [960, 512], [512, 960],         # 491,520 1.875:1
    [1024, 448], [448, 1024],       # 458,752 2.286:1
    [1088, 448], [448, 1088],       # 487,424 2.429:1
    [1152, 384], [384, 1152],       # 442,368 3:1
    [1216, 384], [384, 1216],       # 466,944 3.125:1
    [1280, 384], [384, 1280],       # 491,520 3.333:1
    [1280, 320], [320, 1280],       # 409,600 4:1
    [1408, 320], [320, 1408],       # 450,560 4.4:1
    [1536, 320], [320, 1536],       # 491,520 4.8:1
    ]

ASPECT_640 = [[640, 640],           # 409600 1:1
    [704, 576], [576, 704],         # 405504 1.25:1
    [768, 512], [512, 768],         # 393216 1.5:1
    [896, 448], [448, 896],         # 401408 2:1
    [1024, 384], [384, 1024],       # 393216 2.667:1
    [1280, 320], [320, 1280],       # 409600 4:1
    [1408, 256], [256, 1408],       # 360448 5.5:1
    [1472, 256], [256, 1472],       # 376832 5.75:1
    [1536, 256], [256, 1536],       # 393216 6:1
    [1600, 256], [256, 1600],       # 409600 6.25:1
    ]

ASPECT_576 = [[576, 576],           # 331776 1:1
    [640, 512], [512, 640],         # 327680 1.25:1
    [640, 448], [448, 640],         # 286720 1.4286:1
    [704, 448], [448, 704],         # 314928 1.5625:1
    [832, 384], [384, 832],         # 317440 2.1667:1
    [1024, 320], [320, 1024],       # 327680 3.2:1
    [1280, 256], [256, 1280],       # 327680 5:1
    ]

ASPECT_512 = [[512, 512],           # 262144 1:1
    [576, 448], [448, 576],         # 258048 1.29:1
    [640, 384], [384, 640],         # 245760 1.667:1
    [768, 320], [320, 768],         # 245760 2.4:1
    [832, 256], [256, 832],         # 212992 3.25:1
    [896, 256], [256, 896],         # 229376 3.5:1
    [960, 256], [256, 960],         # 245760 3.75:1
    [1024, 256], [256, 1024],       # 245760 4:1
    ]

ASPECT_448 = [[448, 448],           # 200704 1:1
    [512, 384], [384, 512],         # 196608 1.33:1
    [576, 320], [320, 576],         # 184320 1.8:1
    [768, 256], [256, 768],         # 196608 3:1
    ]

ASPECT_384 = [[384, 384],           # 147456 1:1
    [448, 320], [320, 448],         # 143360 1.4:1
    [576, 256], [256, 576],         # 147456 2.25:1
    [768, 192], [192, 768],         # 147456 4:1
    ]

ASPECT_320 = [[320, 320],           # 102400 1:1
    [384, 256], [256, 384],         # 98304 1.5:1
    [512, 192], [192, 512],         # 98304 2.67:1
    ]

ASPECT_256 = [[256, 256],           # 65536 1:1
    [320, 192], [192, 320],         # 61440 1.67:1
    [512, 128], [128, 512],         # 65536 4:1
    ]

#failsafe aspects
ASPECTS = ASPECT_512

def get_aspect_buckets(resolution, mode=''):
    if resolution < 256:
        raise ValueError("Resolution must be at least 256")
    try:
        rounded_resolution = int(resolution / 64) * 64
        print(f" {bcolors.WARNING} Rounded resolution to: {rounded_resolution}{bcolors.ENDC}")
        all_image_sizes = __get_all_aspects()
        if mode == 'MJ':
            #truncate to the first 3 resolutions
            all_image_sizes = [x[0:3] for x in all_image_sizes]
        aspects = next(filter(lambda sizes: sizes[0][0] == rounded_resolution, all_image_sizes), None)
        ASPECTS = aspects
        #print(aspects)
        return aspects
    except Exception as e:
        print(f" {bcolors.FAIL} *** Could not find selected resolution: {rounded_resolution}{bcolors.ENDC}")

        raise e

def __get_all_aspects():
    return [ASPECT_256, ASPECT_320, ASPECT_384, ASPECT_448, ASPECT_512, ASPECT_576, ASPECT_640, ASPECT_704, ASPECT_768, ASPECT_832, ASPECT_896, ASPECT_960, ASPECT_1024, ASPECT_1088, ASPECT_1152, ASPECT_1216, ASPECT_1280, ASPECT_1344, ASPECT_1408, ASPECT_1472, ASPECT_1536, ASPECT_1600, ASPECT_1664, ASPECT_1728, ASPECT_1792, ASPECT_1856, ASPECT_1920, ASPECT_1984, ASPECT_2048]

class AutoBucketing(Dataset):
    def __init__(self,
                 concepts_list,
                 tokenizer=None,
                 flip_p=0.0,
                 repeats=1,
                 debug_level=0,
                 batch_size=1,
                 set='val',
                 resolution=512,
                 center_crop=False,
                 use_image_names_as_captions=True,
                 shuffle_captions=False,
                 add_class_images_to_dataset=None,
                 balance_datasets=False,
                 crop_jitter=20,
                 with_prior_loss=False,
                 use_text_files_as_captions=False,
                 aspect_mode='dynamic',
                 action_preference='dynamic',
                 seed=555,
                 model_variant='base',
                 extra_module=None,
                 mask_prompts=None,
                 load_mask=False,
                 ):

        self.debug_level = debug_level
        self.resolution = resolution
        self.center_crop = center_crop
        self.tokenizer = tokenizer
        self.batch_size = batch_size
        self.concepts_list = concepts_list
        self.use_image_names_as_captions = use_image_names_as_captions
        self.shuffle_captions = shuffle_captions
        self.num_train_images = 0
        self.num_reg_images = 0
        self.image_train_items = []
        self.image_reg_items = []
        self.add_class_images_to_dataset = add_class_images_to_dataset
        self.balance_datasets = balance_datasets
        self.crop_jitter = crop_jitter
        self.with_prior_loss = with_prior_loss
        self.use_text_files_as_captions = use_text_files_as_captions
        self.aspect_mode = aspect_mode
        self.action_preference = action_preference
        self.model_variant = model_variant
        self.extra_module = extra_module
        self.image_transforms = transforms.Compose(
            [
                transforms.ToTensor(),
                transforms.Normalize([0.5], [0.5]),
            ]
        )
        self.mask_transforms = transforms.Compose(
            [
                transforms.ToTensor(),
            ]
        )
        self.depth_image_transforms = transforms.Compose(
            [
                transforms.ToTensor(),
            ]
        )
        self.seed = seed
        #shared_dataloader = None
        print(f" {bcolors.WARNING}Creating Auto Bucketing Dataloader{bcolors.ENDC}")

        shared_dataloader = DataLoaderMultiAspect(concepts_list,
            debug_level=debug_level,
            resolution=self.resolution,
            seed=self.seed,
            batch_size=self.batch_size,
            flip_p=flip_p,
            use_image_names_as_captions=self.use_image_names_as_captions,
            add_class_images_to_dataset=self.add_class_images_to_dataset,
            balance_datasets=self.balance_datasets,
            with_prior_loss=self.with_prior_loss,
            use_text_files_as_captions=self.use_text_files_as_captions,
            aspect_mode=self.aspect_mode,
            action_preference=self.action_preference,
            model_variant=self.model_variant,
            extra_module=self.extra_module,
            mask_prompts=mask_prompts,
            load_mask=load_mask,
        )

        #print(self.image_train_items)
        if self.with_prior_loss and self.add_class_images_to_dataset == False:
            self.image_train_items, self.class_train_items = shared_dataloader.get_all_images()
            self.num_train_images = self.num_train_images + len(self.image_train_items)
            self.num_reg_images = self.num_reg_images + len(self.class_train_items)
            self._length = max(max(math.trunc(self.num_train_images * repeats), batch_size), math.trunc(self.num_reg_images * repeats), batch_size) - self.num_train_images % self.batch_size
            self.num_train_images = self.num_train_images + self.num_reg_images

        else:
            self.image_train_items = shared_dataloader.get_all_images()
            self.num_train_images = self.num_train_images + len(self.image_train_items)
            self._length = max(math.trunc(self.num_train_images * repeats), batch_size) - self.num_train_images % self.batch_size

        print()
        print(f" {bcolors.WARNING} ** Validation Set: {set}, steps: {self._length / batch_size:.0f}, repeats: {repeats} {bcolors.ENDC}")
        print()

    def __len__(self):
        return self._length

    def __getitem__(self, i):
        idx = i % self.num_train_images
        #print(idx)
        image_train_item = self.image_train_items[idx]

        example = self.__get_image_for_trainer(image_train_item, debug_level=self.debug_level)
        if self.with_prior_loss and self.add_class_images_to_dataset == False:
            idx = i % self.num_reg_images
            class_train_item = self.class_train_items[idx]
            example_class = self.__get_image_for_trainer(class_train_item, debug_level=self.debug_level, class_img=True)
            example = {**example, **example_class}

        #print the tensor shape
        #print(example['instance_images'].shape)
        #print(example.keys())
        return example

    def normalize8(self, I):
        mn = I.min()
        mx = I.max()

        mx -= mn

        I = ((I - mn) / mx) * 255
        return I.astype(np.uint8)

    def __get_image_for_trainer(self, image_train_item, debug_level=0, class_img=False):
        example = {}
        save = debug_level > 2

        if class_img == False:
            image_train_tmp = image_train_item.hydrate(crop=False, save=0, crop_jitter=self.crop_jitter)
            image_train_tmp_image = Image.fromarray(self.normalize8(image_train_tmp.image)).convert("RGB")

            instance_prompt = image_train_tmp.caption
            if self.shuffle_captions:
                caption_parts = instance_prompt.split(",")
                random.shuffle(caption_parts)
                instance_prompt = ",".join(caption_parts)

            example["instance_images"] = self.image_transforms(image_train_tmp_image)
            if image_train_tmp.mask is not None:
                image_train_tmp_mask = Image.fromarray(self.normalize8(image_train_tmp.mask)).convert("L")
                example["mask"] = self.mask_transforms(image_train_tmp_mask)
            if self.model_variant == 'depth2img':
                image_train_tmp_depth = Image.fromarray(self.normalize8(image_train_tmp.extra)).convert("L")
                example["instance_depth_images"] = self.depth_image_transforms(image_train_tmp_depth)
            #print(instance_prompt)
            example["instance_prompt_ids"] = self.tokenizer(
                instance_prompt,
                padding="do_not_pad",
                truncation=True,
                max_length=self.tokenizer.model_max_length,
            ).input_ids
            image_train_item.self_destruct()
            return example

        if class_img == True:
            image_train_tmp = image_train_item.hydrate(crop=False, save=4, crop_jitter=self.crop_jitter)
            image_train_tmp_image = Image.fromarray(self.normalize8(image_train_tmp.image)).convert("RGB")
            if self.model_variant == 'depth2img':
                image_train_tmp_depth = Image.fromarray(self.normalize8(image_train_tmp.extra)).convert("L")
                example["class_depth_images"] = self.depth_image_transforms(image_train_tmp_depth)
            example["class_images"] = self.image_transforms(image_train_tmp_image)
            example["class_prompt_ids"] = self.tokenizer(
                image_train_tmp.caption,
                padding="do_not_pad",
                truncation=True,
                max_length=self.tokenizer.model_max_length,
            ).input_ids
            image_train_item.self_destruct()
            return example

_RANDOM_TRIM = 0.04

class ImageTrainItem():
    """
    image: Image
    mask: Image
    extra: Image
    identifier: caption,
    target_aspect: (width, height),
    pathname: path to image file
    flip_p: probability of flipping image (0.0 to 1.0)
    """
    def __init__(self, image: Image, mask: Image, extra: Image, caption: str, target_wh: list, pathname: str, flip_p=0.0, model_variant='base', load_mask=False):
        self.caption = caption
        self.target_wh = target_wh
        self.pathname = pathname
        self.mask_pathname = os.path.splitext(pathname)[0] + "-masklabel.png"
        self.depth_pathname = os.path.splitext(pathname)[0] + "-depth.png"
        self.flip_p = flip_p
        self.flip = transforms.RandomHorizontalFlip(p=flip_p)
        self.cropped_img = None
        self.model_variant = model_variant
        self.load_mask = load_mask
        self.is_dupe = []
        self.variant_warning = False

        self.image = image
        self.mask = mask
        self.extra = extra

    def self_destruct(self):
        self.image = None
        self.mask = None
        self.extra = None
        self.cropped_img = None
        self.is_dupe.append(1)

    def load_image(self, pathname, crop, jitter_amount, flip):
        if len(self.is_dupe) > 0:
            self.flip = transforms.RandomHorizontalFlip(p=1.0 if flip else 0.0)
        image = Image.open(pathname).convert('RGB')

        width, height = image.size
        if crop:
            cropped_img = self.__autocrop(image)
            image = cropped_img.resize((512, 512), resample=Image.Resampling.LANCZOS)
        else:
            width, height = image.size

            if self.target_wh[0] == self.target_wh[1]:
                if width > height:
                    left = random.randint(0, width - height)
                    image = image.crop((left, 0, height + left, height))
                    width = height
                elif height > width:
                    top = random.randint(0, height - width)
                    image = image.crop((0, top, width, width + top))
                    height = width
                elif width > self.target_wh[0]:
                    slice = min(int(self.target_wh[0] * _RANDOM_TRIM), width - self.target_wh[0])
                    slicew_ratio = random.random()
                    left = int(slice * slicew_ratio)
                    right = width - int(slice * (1 - slicew_ratio))
                    sliceh_ratio = random.random()
                    top = int(slice * sliceh_ratio)
                    bottom = height - int(slice * (1 - sliceh_ratio))

                    image = image.crop((left, top, right, bottom))
            else:
                image_aspect = width / height
                target_aspect = self.target_wh[0] / self.target_wh[1]
                if image_aspect > target_aspect:
                    new_width = int(height * target_aspect)
                    jitter_amount = max(min(jitter_amount, int(abs(width - new_width) / 2)), 0)
                    left = jitter_amount
                    right = left + new_width
                    image = image.crop((left, 0, right, height))
                else:
                    new_height = int(width / target_aspect)
                    jitter_amount = max(min(jitter_amount, int(abs(height - new_height) / 2)), 0)
                    top = jitter_amount
                    bottom = top + new_height
                    image = image.crop((0, top, width, bottom))
            # LANCZOS resample
            image = image.resize(self.target_wh, resample=Image.Resampling.LANCZOS)
            # print the pixel count of the image
            # print path to image file
            # print(self.pathname)
            # print(self.image.size[0] * self.image.size[1])
        image = self.flip(image)
        return image

    def hydrate(self, crop=False, save=False, crop_jitter=20):
        """
        crop: hard center crop to 512x512
        save: save the cropped image to disk, for manual inspection of resize/crop
        crop_jitter: randomly shift crop by N pixels when using multiple aspect ratios to improve training quality
        """

        if self.image is None:
            chance = float(len(self.is_dupe)) / 10.0

            flip_p = self.flip_p + chance if chance < 1.0 else 1.0
            flip = random.uniform(0, 1) < flip_p

            if len(self.is_dupe) > 0:
                crop_jitter = crop_jitter + (len(self.is_dupe) * 10) if crop_jitter < 50 else 50

            jitter_amount = random.randint(0, crop_jitter)

            self.image = self.load_image(self.pathname, crop, jitter_amount, flip)

            if self.model_variant == "inpainting" or self.load_mask:
                if os.path.exists(self.mask_pathname) and self.load_mask:
                    self.mask = self.load_image(self.mask_pathname, crop, jitter_amount, flip)
                else:
                    if self.variant_warning == False:
                        print(f" {bcolors.FAIL} ** Warning: No mask found for an image, using an empty mask but make sure you're training the right model variant.{bcolors.ENDC}")
                        self.variant_warning = True
                    self.mask = Image.new('RGB', self.image.size, color="white").convert("L")

            if self.model_variant == "depth2img":
                if os.path.exists(self.depth_pathname):
                    self.extra = self.load_image(self.depth_pathname, crop, jitter_amount, flip)
                else:
                    if self.variant_warning == False:
                        print(f" {bcolors.FAIL} ** Warning: No depth found for an image, using an empty depth but make sure you're training the right model variant.{bcolors.ENDC}")
                        self.variant_warning = True
                    self.extra = Image.new('RGB', self.image.size, color="white").convert("L")
        if type(self.image) is not np.ndarray:
            if save:
                base_name = os.path.basename(self.pathname)
                if not os.path.exists("test/output"):
                    os.makedirs("test/output")
                self.image.save(f"test/output/{base_name}")

            self.image = np.array(self.image).astype(np.uint8)

            self.image = (self.image / 127.5 - 1.0).astype(np.float32)
        if self.mask is not None and type(self.mask) is not np.ndarray:
            self.mask = np.array(self.mask).astype(np.uint8)

            self.mask = (self.mask / 255.0).astype(np.float32)
        if self.extra is not None and type(self.extra) is not np.ndarray:
            self.extra = np.array(self.extra).astype(np.uint8)

            self.extra = (self.extra / 255.0).astype(np.float32)

        #print(self.image.shape)

        return self

class CachedLatentsDataset(Dataset):
    #stores paths and loads latents on the fly
    def __init__(self, cache_paths=(), batch_size=None, tokenizer=None, text_encoder=None, dtype=None, model_variant='base', shuffle_per_epoch=False, args=None):
        self.cache_paths = cache_paths
        self.tokenizer = tokenizer
        self.args = args
        self.text_encoder = text_encoder
        #get text encoder device
        text_encoder_device = next(self.text_encoder.parameters()).device
        self.empty_batch = [self.tokenizer('', padding="do_not_pad", truncation=True, max_length=self.tokenizer.model_max_length,).input_ids for i in range(batch_size)]
        #handle text encoder for empty tokens
        if self.args.train_text_encoder != True:
            self.empty_tokens = tokenizer.pad({"input_ids": self.empty_batch}, padding="max_length", max_length=tokenizer.model_max_length, return_tensors="pt",).to(text_encoder_device).input_ids
            self.empty_tokens.to(text_encoder_device, dtype=dtype)
            self.empty_tokens = self.text_encoder(self.empty_tokens)[0]
        else:
            self.empty_tokens = tokenizer.pad({"input_ids": self.empty_batch}, padding="max_length", max_length=tokenizer.model_max_length, return_tensors="pt",).input_ids
            self.empty_tokens.to(text_encoder_device, dtype=dtype)

        self.conditional_dropout = args.conditional_dropout
        self.conditional_indexes = []
        self.model_variant = model_variant
        self.shuffle_per_epoch = shuffle_per_epoch

    def __len__(self):
        return len(self.cache_paths)

    def __getitem__(self, index):
        if index == 0:
            if self.shuffle_per_epoch == True:
                self.cache_paths = tuple(random.sample(self.cache_paths, len(self.cache_paths)))
            if len(self.cache_paths) > 1:
                possible_indexes_extension = None
                possible_indexes = list(range(0, len(self.cache_paths)))
                #conditional dropout is a percentage of images to drop from the total cache_paths
                if self.conditional_dropout != None:
                    if len(self.conditional_indexes) == 0:
                        self.conditional_indexes = random.sample(possible_indexes, k=int(math.ceil(len(possible_indexes) * self.conditional_dropout)))
                    else:
                        #pick indexes from the remaining possible indexes
                        possible_indexes_extension = [i for i in possible_indexes if i not in self.conditional_indexes]
                        #duplicate all values in possible_indexes_extension
                        possible_indexes_extension = possible_indexes_extension + possible_indexes_extension
                        possible_indexes_extension = possible_indexes_extension + self.conditional_indexes
                        self.conditional_indexes = random.sample(possible_indexes_extension, k=int(math.ceil(len(possible_indexes) * self.conditional_dropout)))
                        #check for duplicates in conditional_indexes values
                        if len(self.conditional_indexes) != len(set(self.conditional_indexes)):
                            #remove duplicates
                            self.conditional_indexes_non_dupe = list(set(self.conditional_indexes))
                            #add a random value from possible_indexes_extension for each duplicate
                            for i in range(len(self.conditional_indexes) - len(self.conditional_indexes_non_dupe)):
                                while True:
                                    random_value = random.choice(possible_indexes_extension)
                                    if random_value not in self.conditional_indexes_non_dupe:
                                        self.conditional_indexes_non_dupe.append(random_value)
                                        break
                            self.conditional_indexes = self.conditional_indexes_non_dupe
        self.cache = torch.load(self.cache_paths[index])
        self.latents = self.cache.latents_cache[0]
        self.tokens = self.cache.tokens_cache[0]
        self.extra_cache = None
        self.mask_cache = None
        if self.cache.mask_cache is not None:
            self.mask_cache = self.cache.mask_cache[0]
        self.mask_mean_cache = None
        if self.cache.mask_mean_cache is not None:
            self.mask_mean_cache = self.cache.mask_mean_cache[0]
        if index in self.conditional_indexes:
            self.text_encoder = self.empty_tokens
        else:
            self.text_encoder = self.cache.text_encoder_cache[0]
        if self.model_variant != 'base':
            self.extra_cache = self.cache.extra_cache[0]
        del self.cache
        return self.latents, self.text_encoder, self.mask_cache, self.mask_mean_cache, self.extra_cache, self.tokens

    def add_pt_cache(self, cache_path):
        if len(self.cache_paths) == 0:
            self.cache_paths = (cache_path,)
        else:
            self.cache_paths += (cache_path,)

class LatentsDataset(Dataset):
    def __init__(self, latents_cache=None, text_encoder_cache=None, mask_cache=None, mask_mean_cache=None, extra_cache=None, tokens_cache=None):
        self.latents_cache = latents_cache
        self.text_encoder_cache = text_encoder_cache
        self.mask_cache = mask_cache
        self.mask_mean_cache = mask_mean_cache
        self.extra_cache = extra_cache
        self.tokens_cache = tokens_cache

    def add_latent(self, latent, text_encoder, cached_mask, cached_extra, tokens_cache):
        self.latents_cache.append(latent)
        self.text_encoder_cache.append(text_encoder)
        self.mask_cache.append(cached_mask)
        self.mask_mean_cache.append(None if cached_mask is None else cached_mask.mean())
        self.extra_cache.append(cached_extra)
        self.tokens_cache.append(tokens_cache)

    def __len__(self):
        return len(self.latents_cache)

    def __getitem__(self, index):
        return self.latents_cache[index], self.text_encoder_cache[index], self.mask_cache[index], self.mask_mean_cache[index], self.extra_cache[index], self.tokens_cache[index]

class DataLoaderMultiAspect():
    """
    Data loader for multi-aspect-ratio training and bucketing
    data_root: root folder of training data
    batch_size: number of images per batch
    flip_p: probability of flipping image horizontally (i.e. 0-0.5)
    """
    def __init__(
        self,
        concept_list,
        seed=555,
        debug_level=0,
        resolution=512,
        batch_size=1,
        flip_p=0.0,
        use_image_names_as_captions=True,
        add_class_images_to_dataset=False,
        balance_datasets=False,
        with_prior_loss=False,
        use_text_files_as_captions=False,
        aspect_mode='dynamic',
        action_preference='add',
        model_variant='base',
        extra_module=None,
        mask_prompts=None,
        load_mask=False,
    ):
        self.resolution = resolution
        self.debug_level = debug_level
        self.flip_p = flip_p
        self.use_image_names_as_captions = use_image_names_as_captions
        self.balance_datasets = balance_datasets
        self.with_prior_loss = with_prior_loss
        self.add_class_images_to_dataset = add_class_images_to_dataset
        self.use_text_files_as_captions = use_text_files_as_captions
        self.aspect_mode = aspect_mode
        self.action_preference = action_preference
        self.seed = seed
        self.model_variant = model_variant
        self.extra_module = extra_module
        self.load_mask = load_mask
        prepared_train_data = []

        self.aspects = get_aspect_buckets(resolution)
        #print(f"* DLMA resolution {resolution}, buckets: {self.aspects}")
        #process sub directories flag

        print(f" {bcolors.WARNING} Preloading images...{bcolors.ENDC}")

        if balance_datasets:
            print(f" {bcolors.WARNING} Balancing datasets...{bcolors.ENDC}")
            #get the concept with the least number of images in instance_data_dir
            for concept in concept_list:
                count = 0
                if 'use_sub_dirs' in concept:
                    if concept['use_sub_dirs'] == 1:
                        tot = 0
                        for root, dirs, files in os.walk(concept['instance_data_dir']):
                            tot += len(files)
                        count = tot
                    else:
                        count = len(os.listdir(concept['instance_data_dir']))
                else:
                    count = len(os.listdir(concept['instance_data_dir']))
                print(f"{concept['instance_data_dir']} has count of {count}")
                concept['count'] = count

            min_concept = min(concept_list, key=lambda x: x['count'])
            #get the number of images in the concept with the least number of images
            min_concept_num_images = min_concept['count']
            print(" Min concept: ", min_concept['instance_data_dir'], " with ", min_concept_num_images, " images")

            balance_cocnept_list = []
            for concept in concept_list:
                #if concept has a key do not balance it
                if 'do_not_balance' in concept:
                    if concept['do_not_balance'] == True:
                        balance_cocnept_list.append(-1)
                    else:
                        balance_cocnept_list.append(min_concept_num_images)
                else:
                    balance_cocnept_list.append(min_concept_num_images)
        for concept in concept_list:
            if 'use_sub_dirs' in concept:
                if concept['use_sub_dirs'] == True:
                    use_sub_dirs = True
                else:
                    use_sub_dirs = False
            else:
                use_sub_dirs = False
            self.image_paths = []
            #self.class_image_paths = []
            min_concept_num_images = None
            if balance_datasets:
                min_concept_num_images = balance_cocnept_list[concept_list.index(concept)]
            data_root = concept['instance_data_dir']
            data_root_class = concept['class_data_dir']
            concept_prompt = concept['instance_prompt']
            concept_class_prompt = concept['class_prompt']
            if 'flip_p' in concept.keys():
                flip_p = concept['flip_p']
                if flip_p == '':
                    flip_p = 0.0
                else:
                    flip_p = float(flip_p)
            self.__recurse_data_root(self=self, recurse_root=data_root, use_sub_dirs=use_sub_dirs)
|
869 |
+
random.Random(self.seed).shuffle(self.image_paths)
|
870 |
+
if self.model_variant == 'depth2img':
|
871 |
+
print(f" {bcolors.WARNING} ** Loading Depth2Img Pipeline To Process Dataset{bcolors.ENDC}")
|
872 |
+
self.vae_scale_factor = self.extra_module.depth_images(self.image_paths)
|
873 |
+
prepared_train_data.extend(self.__prescan_images(debug_level, self.image_paths, flip_p,use_image_names_as_captions,concept_prompt,use_text_files_as_captions=self.use_text_files_as_captions)[0:min_concept_num_images]) # ImageTrainItem[]
|
874 |
+
if add_class_images_to_dataset:
|
875 |
+
self.image_paths = []
|
876 |
+
self.__recurse_data_root(self=self, recurse_root=data_root_class,use_sub_dirs=use_sub_dirs)
|
877 |
+
random.Random(self.seed).shuffle(self.image_paths)
|
878 |
+
use_image_names_as_captions = False
|
879 |
+
prepared_train_data.extend(self.__prescan_images(debug_level, self.image_paths, flip_p,use_image_names_as_captions,concept_class_prompt,use_text_files_as_captions=self.use_text_files_as_captions)) # ImageTrainItem[]
|
880 |
+
|
881 |
+
self.image_caption_pairs = self.__bucketize_images(prepared_train_data, batch_size=batch_size, debug_level=debug_level,aspect_mode=self.aspect_mode,action_preference=self.action_preference)
|
882 |
+
if self.with_prior_loss and add_class_images_to_dataset == False:
|
883 |
+
self.class_image_caption_pairs = []
|
884 |
+
for concept in concept_list:
|
885 |
+
self.class_images_path = []
|
886 |
+
data_root_class = concept['class_data_dir']
|
887 |
+
concept_class_prompt = concept['class_prompt']
|
888 |
+
self.__recurse_data_root(self=self, recurse_root=data_root_class,use_sub_dirs=use_sub_dirs,class_images=True)
|
889 |
+
random.Random(seed).shuffle(self.image_paths)
|
890 |
+
if self.model_variant == 'depth2img':
|
891 |
+
print(f" {bcolors.WARNING} ** Depth2Img To Process Class Dataset{bcolors.ENDC}")
|
892 |
+
self.vae_scale_factor = self.extra_module.depth_images(self.image_paths)
|
893 |
+
use_image_names_as_captions = False
|
894 |
+
self.class_image_caption_pairs.extend(self.__prescan_images(debug_level, self.class_images_path, flip_p,use_image_names_as_captions,concept_class_prompt,use_text_files_as_captions=self.use_text_files_as_captions))
|
895 |
+
self.class_image_caption_pairs = self.__bucketize_images(self.class_image_caption_pairs, batch_size=batch_size, debug_level=debug_level,aspect_mode=self.aspect_mode,action_preference=self.action_preference)
|
896 |
+
if mask_prompts is not None:
|
897 |
+
print(f" {bcolors.WARNING} Checking and generating missing masks...{bcolors.ENDC}")
|
898 |
+
clip_seg = ClipSeg()
|
899 |
+
clip_seg.mask_images(self.image_paths, mask_prompts)
|
900 |
+
del clip_seg
|
901 |
+
if debug_level > 0: print(f" * DLMA Example: {self.image_caption_pairs[0]} images")
|
902 |
+
#print the length of image_caption_pairs
|
903 |
+
print(f" {bcolors.WARNING} Number of image-caption pairs: {len(self.image_caption_pairs)}{bcolors.ENDC}")
|
904 |
+
if len(self.image_caption_pairs) == 0:
|
905 |
+
raise Exception("All the buckets are empty. Please check your data or reduce the batch size.")
|
906 |
+
def get_all_images(self):
|
907 |
+
if self.with_prior_loss == False:
|
908 |
+
return self.image_caption_pairs
|
909 |
+
else:
|
910 |
+
return self.image_caption_pairs, self.class_image_caption_pairs
|
911 |
+
def __prescan_images(self,debug_level: int, image_paths: list, flip_p=0.0,use_image_names_as_captions=True,concept=None,use_text_files_as_captions=False):
|
912 |
+
"""
|
913 |
+
Create ImageTrainItem objects with metadata for hydration later
|
914 |
+
"""
|
915 |
+
decorated_image_train_items = []
|
916 |
+
|
917 |
+
for pathname in image_paths:
|
918 |
+
identifier = concept
|
919 |
+
if use_image_names_as_captions:
|
920 |
+
caption_from_filename = os.path.splitext(os.path.basename(pathname))[0].split("_")[0]
|
921 |
+
identifier = caption_from_filename
|
922 |
+
if use_text_files_as_captions:
|
923 |
+
txt_file_path = os.path.splitext(pathname)[0] + ".txt"
|
924 |
+
|
925 |
+
if os.path.exists(txt_file_path):
|
926 |
+
try:
|
927 |
+
with open(txt_file_path, 'r',encoding='utf-8',errors='ignore') as f:
|
928 |
+
identifier = f.readline().rstrip()
|
929 |
+
f.close()
|
930 |
+
if len(identifier) < 1:
|
931 |
+
raise ValueError(f" *** Could not find valid text in: {txt_file_path}")
|
932 |
+
|
933 |
+
except Exception as e:
|
934 |
+
print(f" {bcolors.FAIL} *** Error reading {txt_file_path} to get caption, falling back to filename{bcolors.ENDC}")
|
935 |
+
print(e)
|
936 |
+
identifier = caption_from_filename
|
937 |
+
pass
|
938 |
+
#print("identifier: ",identifier)
|
939 |
+
image = Image.open(pathname)
|
940 |
+
width, height = image.size
|
941 |
+
image_aspect = width / height
|
942 |
+
|
943 |
+
target_wh = min(self.aspects, key=lambda aspects:abs(aspects[0]/aspects[1] - image_aspect))
|
944 |
+
|
945 |
+
image_train_item = ImageTrainItem(image=None, mask=None, extra=None, caption=identifier, target_wh=target_wh, pathname=pathname, flip_p=flip_p,model_variant=self.model_variant, load_mask=self.load_mask)
|
946 |
+
|
947 |
+
decorated_image_train_items.append(image_train_item)
|
948 |
+
return decorated_image_train_items
|
949 |
+
|
950 |
+
@staticmethod
|
951 |
+
def __bucketize_images(prepared_train_data: list, batch_size=1, debug_level=0,aspect_mode='dynamic',action_preference='add'):
|
952 |
+
"""
|
953 |
+
Put images into buckets based on aspect ratio with batch_size*n images per bucket, discards remainder
|
954 |
+
"""
|
955 |
+
|
956 |
+
# TODO: this is not terribly efficient but at least linear time
|
957 |
+
buckets = {}
|
958 |
+
for image_caption_pair in prepared_train_data:
|
959 |
+
target_wh = image_caption_pair.target_wh
|
960 |
+
|
961 |
+
if (target_wh[0],target_wh[1]) not in buckets:
|
962 |
+
buckets[(target_wh[0],target_wh[1])] = []
|
963 |
+
buckets[(target_wh[0],target_wh[1])].append(image_caption_pair)
|
964 |
+
print(f" ** Number of buckets: {len(buckets)}")
|
965 |
+
for bucket in buckets:
|
966 |
+
bucket_len = len(buckets[bucket])
|
967 |
+
#real_len = len(buckets[bucket])+1
|
968 |
+
#print(real_len)
|
969 |
+
truncate_amount = bucket_len % batch_size
|
970 |
+
add_amount = batch_size - bucket_len % batch_size
|
971 |
+
action = None
|
972 |
+
#print(f" ** Bucket {bucket} has {bucket_len} images")
|
973 |
+
if aspect_mode == 'dynamic':
|
974 |
+
if batch_size == bucket_len:
|
975 |
+
action = None
|
976 |
+
elif add_amount < truncate_amount and add_amount != 0 and add_amount != batch_size or truncate_amount == 0:
|
977 |
+
action = 'add'
|
978 |
+
#print(f'should add {add_amount}')
|
979 |
+
elif truncate_amount < add_amount and truncate_amount != 0 and truncate_amount != batch_size and batch_size < bucket_len:
|
980 |
+
#print(f'should truncate {truncate_amount}')
|
981 |
+
action = 'truncate'
|
982 |
+
#truncate the bucket
|
983 |
+
elif truncate_amount == add_amount:
|
984 |
+
if action_preference == 'add':
|
985 |
+
action = 'add'
|
986 |
+
elif action_preference == 'truncate':
|
987 |
+
action = 'truncate'
|
988 |
+
elif batch_size > bucket_len:
|
989 |
+
action = 'add'
|
990 |
+
|
991 |
+
elif aspect_mode == 'add':
|
992 |
+
action = 'add'
|
993 |
+
elif aspect_mode == 'truncate':
|
994 |
+
action = 'truncate'
|
995 |
+
if action == None:
|
996 |
+
action = None
|
997 |
+
#print('no need to add or truncate')
|
998 |
+
if action == None:
|
999 |
+
#print('test')
|
1000 |
+
current_bucket_size = bucket_len
|
1001 |
+
print(f" ** Bucket {bucket} found {bucket_len}, nice!")
|
1002 |
+
elif action == 'add':
|
1003 |
+
#copy the bucket
|
1004 |
+
shuffleBucket = random.sample(buckets[bucket], bucket_len)
|
1005 |
+
#add the images to the bucket
|
1006 |
+
current_bucket_size = bucket_len
|
1007 |
+
truncate_count = (bucket_len) % batch_size
|
1008 |
+
#how many images to add to the bucket to fill the batch
|
1009 |
+
addAmount = batch_size - truncate_count
|
1010 |
+
if addAmount != batch_size:
|
1011 |
+
added=0
|
1012 |
+
while added != addAmount:
|
1013 |
+
randomIndex = random.randint(0,len(shuffleBucket)-1)
|
1014 |
+
#print(str(randomIndex))
|
1015 |
+
buckets[bucket].append(shuffleBucket[randomIndex])
|
1016 |
+
added+=1
|
1017 |
+
print(f" ** Bucket {bucket} found {bucket_len} images, will {bcolors.OKCYAN}duplicate {added} images{bcolors.ENDC} due to batch size {bcolors.WARNING}{batch_size}{bcolors.ENDC}")
|
1018 |
+
else:
|
1019 |
+
print(f" ** Bucket {bucket} found {bucket_len}, {bcolors.OKGREEN}nice!{bcolors.ENDC}")
|
1020 |
+
elif action == 'truncate':
|
1021 |
+
truncate_count = (bucket_len) % batch_size
|
1022 |
+
current_bucket_size = bucket_len
|
1023 |
+
buckets[bucket] = buckets[bucket][:current_bucket_size - truncate_count]
|
1024 |
+
print(f" ** Bucket {bucket} found {bucket_len} images, will {bcolors.FAIL}drop {truncate_count} images{bcolors.ENDC} due to batch size {bcolors.WARNING}{batch_size}{bcolors.ENDC}")
|
1025 |
+
|
1026 |
+
|
1027 |
+
# flatten the buckets
|
1028 |
+
image_caption_pairs = []
|
1029 |
+
for bucket in buckets:
|
1030 |
+
image_caption_pairs.extend(buckets[bucket])
|
1031 |
+
|
1032 |
+
return image_caption_pairs
|
1033 |
+
|
1034 |
+
@staticmethod
|
1035 |
+
def __recurse_data_root(self, recurse_root,use_sub_dirs=True,class_images=False):
|
1036 |
+
progress_bar = tqdm(os.listdir(recurse_root), desc=f" {bcolors.WARNING} ** Processing {recurse_root}{bcolors.ENDC}")
|
1037 |
+
for f in os.listdir(recurse_root):
|
1038 |
+
current = os.path.join(recurse_root, f)
|
1039 |
+
if os.path.isfile(current):
|
1040 |
+
ext = os.path.splitext(f)[1].lower()
|
1041 |
+
if '-depth' in f or '-masklabel' in f:
|
1042 |
+
progress_bar.update(1)
|
1043 |
+
continue
|
1044 |
+
if ext in ['.jpg', '.jpeg', '.png', '.bmp', '.webp']:
|
1045 |
+
#try to open the file to make sure it's a valid image
|
1046 |
+
try:
|
1047 |
+
img = Image.open(current)
|
1048 |
+
except:
|
1049 |
+
print(f" ** Skipping {current} because it failed to open, please check the file")
|
1050 |
+
progress_bar.update(1)
|
1051 |
+
continue
|
1052 |
+
del img
|
1053 |
+
if class_images == False:
|
1054 |
+
self.image_paths.append(current)
|
1055 |
+
else:
|
1056 |
+
self.class_images_path.append(current)
|
1057 |
+
progress_bar.update(1)
|
1058 |
+
if use_sub_dirs:
|
1059 |
+
sub_dirs = []
|
1060 |
+
|
1061 |
+
for d in os.listdir(recurse_root):
|
1062 |
+
current = os.path.join(recurse_root, d)
|
1063 |
+
if os.path.isdir(current):
|
1064 |
+
sub_dirs.append(current)
|
1065 |
+
|
1066 |
+
for dir in sub_dirs:
|
1067 |
+
self.__recurse_data_root(self=self, recurse_root=dir)
|
1068 |
+
|
1069 |
+
class NormalDataset(Dataset):
|
1070 |
+
"""
|
1071 |
+
A dataset to prepare the instance and class images with the prompts for fine-tuning the model.
|
1072 |
+
It pre-processes the images and the tokenizes prompts.
|
1073 |
+
"""
|
1074 |
+
|
1075 |
+
def __init__(
|
1076 |
+
self,
|
1077 |
+
concepts_list,
|
1078 |
+
tokenizer,
|
1079 |
+
with_prior_preservation=True,
|
1080 |
+
size=512,
|
1081 |
+
center_crop=False,
|
1082 |
+
num_class_images=None,
|
1083 |
+
use_image_names_as_captions=False,
|
1084 |
+
shuffle_captions=False,
|
1085 |
+
repeats=1,
|
1086 |
+
use_text_files_as_captions=False,
|
1087 |
+
seed=555,
|
1088 |
+
model_variant='base',
|
1089 |
+
extra_module=None,
|
1090 |
+
mask_prompts=None,
|
1091 |
+
load_mask=None,
|
1092 |
+
):
|
1093 |
+
self.use_image_names_as_captions = use_image_names_as_captions
|
1094 |
+
self.shuffle_captions = shuffle_captions
|
1095 |
+
self.size = size
|
1096 |
+
self.center_crop = center_crop
|
1097 |
+
self.tokenizer = tokenizer
|
1098 |
+
self.with_prior_preservation = with_prior_preservation
|
1099 |
+
self.use_text_files_as_captions = use_text_files_as_captions
|
1100 |
+
self.image_paths = []
|
1101 |
+
self.class_images_path = []
|
1102 |
+
self.seed = seed
|
1103 |
+
self.model_variant = model_variant
|
1104 |
+
self.variant_warning = False
|
1105 |
+
self.vae_scale_factor = None
|
1106 |
+
self.load_mask = load_mask
|
1107 |
+
for concept in concepts_list:
|
1108 |
+
if 'use_sub_dirs' in concept:
|
1109 |
+
if concept['use_sub_dirs'] == True:
|
1110 |
+
use_sub_dirs = True
|
1111 |
+
else:
|
1112 |
+
use_sub_dirs = False
|
1113 |
+
else:
|
1114 |
+
use_sub_dirs = False
|
1115 |
+
|
1116 |
+
for i in range(repeats):
|
1117 |
+
self.__recurse_data_root(self, concept,use_sub_dirs=use_sub_dirs)
|
1118 |
+
|
1119 |
+
if with_prior_preservation:
|
1120 |
+
for i in range(repeats):
|
1121 |
+
self.__recurse_data_root(self, concept,use_sub_dirs=False,class_images=True)
|
1122 |
+
if mask_prompts is not None:
|
1123 |
+
print(f" {bcolors.WARNING} Checking and generating missing masks{bcolors.ENDC}")
|
1124 |
+
clip_seg = ClipSeg()
|
1125 |
+
clip_seg.mask_images(self.image_paths, mask_prompts)
|
1126 |
+
del clip_seg
|
1127 |
+
|
1128 |
+
random.Random(seed).shuffle(self.image_paths)
|
1129 |
+
self.num_instance_images = len(self.image_paths)
|
1130 |
+
self._length = self.num_instance_images
|
1131 |
+
self.num_class_images = len(self.class_images_path)
|
1132 |
+
self._length = max(self.num_class_images, self.num_instance_images)
|
1133 |
+
if self.model_variant == 'depth2img':
|
1134 |
+
print(f" {bcolors.WARNING} ** Loading Depth2Img Pipeline To Process Dataset{bcolors.ENDC}")
|
1135 |
+
self.vae_scale_factor = extra_module.depth_images(self.image_paths)
|
1136 |
+
if self.with_prior_preservation:
|
1137 |
+
print(f" {bcolors.WARNING} ** Loading Depth2Img Class Processing{bcolors.ENDC}")
|
1138 |
+
extra_module.depth_images(self.class_images_path)
|
1139 |
+
print(f" {bcolors.WARNING} ** Dataset length: {self._length}, {int(self.num_instance_images / repeats)} images using {repeats} repeats{bcolors.ENDC}")
|
1140 |
+
|
1141 |
+
self.image_transforms = transforms.Compose(
|
1142 |
+
[
|
1143 |
+
transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR),
|
1144 |
+
transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size),
|
1145 |
+
transforms.ToTensor(),
|
1146 |
+
transforms.Normalize([0.5], [0.5]),
|
1147 |
+
]
|
1148 |
+
|
1149 |
+
)
|
1150 |
+
self.mask_transforms = transforms.Compose(
|
1151 |
+
[
|
1152 |
+
transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR),
|
1153 |
+
transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size),
|
1154 |
+
transforms.ToTensor(),
|
1155 |
+
])
|
1156 |
+
|
1157 |
+
self.depth_image_transforms = transforms.Compose(
|
1158 |
+
[
|
1159 |
+
transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR),
|
1160 |
+
transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size),
|
1161 |
+
transforms.ToTensor(),
|
1162 |
+
]
|
1163 |
+
)
|
1164 |
+
|
1165 |
+
@staticmethod
|
1166 |
+
def __recurse_data_root(self, recurse_root,use_sub_dirs=True,class_images=False):
|
1167 |
+
#if recurse root is a dict
|
1168 |
+
if isinstance(recurse_root, dict):
|
1169 |
+
if class_images == True:
|
1170 |
+
#print(f" {bcolors.WARNING} ** Processing class images: {recurse_root['class_data_dir']}{bcolors.ENDC}")
|
1171 |
+
concept_token = recurse_root['class_prompt']
|
1172 |
+
data = recurse_root['class_data_dir']
|
1173 |
+
else:
|
1174 |
+
#print(f" {bcolors.WARNING} ** Processing instance images: {recurse_root['instance_data_dir']}{bcolors.ENDC}")
|
1175 |
+
concept_token = recurse_root['instance_prompt']
|
1176 |
+
data = recurse_root['instance_data_dir']
|
1177 |
+
|
1178 |
+
|
1179 |
+
else:
|
1180 |
+
concept_token = None
|
1181 |
+
#progress bar
|
1182 |
+
progress_bar = tqdm(os.listdir(data), desc=f" {bcolors.WARNING} ** Processing {data}{bcolors.ENDC}")
|
1183 |
+
for f in os.listdir(data):
|
1184 |
+
current = os.path.join(data, f)
|
1185 |
+
if os.path.isfile(current):
|
1186 |
+
if '-depth' in f or '-masklabel' in f:
|
1187 |
+
continue
|
1188 |
+
ext = os.path.splitext(f)[1].lower()
|
1189 |
+
if ext in ['.jpg', '.jpeg', '.png', '.bmp', '.webp']:
|
1190 |
+
try:
|
1191 |
+
img = Image.open(current)
|
1192 |
+
except:
|
1193 |
+
print(f" ** Skipping {current} because it failed to open, please check the file")
|
1194 |
+
progress_bar.update(1)
|
1195 |
+
continue
|
1196 |
+
del img
|
1197 |
+
if class_images == False:
|
1198 |
+
self.image_paths.append([current,concept_token])
|
1199 |
+
else:
|
1200 |
+
self.class_images_path.append([current,concept_token])
|
1201 |
+
progress_bar.update(1)
|
1202 |
+
if use_sub_dirs:
|
1203 |
+
sub_dirs = []
|
1204 |
+
|
1205 |
+
for d in os.listdir(data):
|
1206 |
+
current = os.path.join(data, d)
|
1207 |
+
if os.path.isdir(current):
|
1208 |
+
sub_dirs.append(current)
|
1209 |
+
|
1210 |
+
for dir in sub_dirs:
|
1211 |
+
if class_images == False:
|
1212 |
+
self.__recurse_data_root(self=self, recurse_root={'instance_data_dir' : dir, 'instance_prompt' : concept_token})
|
1213 |
+
else:
|
1214 |
+
self.__recurse_data_root(self=self, recurse_root={'class_data_dir' : dir, 'class_prompt' : concept_token})
|
1215 |
+
|
1216 |
+
def __len__(self):
|
1217 |
+
return self._length
|
1218 |
+
|
1219 |
+
def __getitem__(self, index):
|
1220 |
+
example = {}
|
1221 |
+
instance_path, instance_prompt = self.image_paths[index % self.num_instance_images]
|
1222 |
+
og_prompt = instance_prompt
|
1223 |
+
instance_image = Image.open(instance_path)
|
1224 |
+
if self.model_variant == "inpainting" or self.load_mask:
|
1225 |
+
|
1226 |
+
mask_pathname = os.path.splitext(instance_path)[0] + "-masklabel.png"
|
1227 |
+
if os.path.exists(mask_pathname) and self.load_mask:
|
1228 |
+
mask = Image.open(mask_pathname).convert("L")
|
1229 |
+
else:
|
1230 |
+
if self.variant_warning == False:
|
1231 |
+
print(f" {bcolors.FAIL} ** Warning: No mask found for an image, using an empty mask but make sure you're training the right model variant.{bcolors.ENDC}")
|
1232 |
+
self.variant_warning = True
|
1233 |
+
size = instance_image.size
|
1234 |
+
mask = Image.new('RGB', size, color="white").convert("L")
|
1235 |
+
example["mask"] = self.mask_transforms(mask)
|
1236 |
+
if self.model_variant == "depth2img":
|
1237 |
+
depth_pathname = os.path.splitext(instance_path)[0] + "-depth.png"
|
1238 |
+
if os.path.exists(depth_pathname):
|
1239 |
+
depth_image = Image.open(depth_pathname).convert("L")
|
1240 |
+
else:
|
1241 |
+
if self.variant_warning == False:
|
1242 |
+
print(f" {bcolors.FAIL} ** Warning: No depth image found for an image, using an empty depth image but make sure you're training the right model variant.{bcolors.ENDC}")
|
1243 |
+
self.variant_warning = True
|
1244 |
+
size = instance_image.size
|
1245 |
+
depth_image = Image.new('RGB', size, color="white").convert("L")
|
1246 |
+
example["instance_depth_images"] = self.depth_image_transforms(depth_image)
|
1247 |
+
|
1248 |
+
if self.use_image_names_as_captions == True:
|
1249 |
+
instance_prompt = str(instance_path).split(os.sep)[-1].split('.')[0].split('_')[0]
|
1250 |
+
#else if there's a txt file with the same name as the image, read the caption from there
|
1251 |
+
if self.use_text_files_as_captions == True:
|
1252 |
+
#if there's a file with the same name as the image, but with a .txt extension, read the caption from there
|
1253 |
+
#get the last . in the file name
|
1254 |
+
last_dot = str(instance_path).rfind('.')
|
1255 |
+
#get the path up to the last dot
|
1256 |
+
txt_path = str(instance_path)[:last_dot] + '.txt'
|
1257 |
+
|
1258 |
+
#if txt_path exists, read the caption from there
|
1259 |
+
if os.path.exists(txt_path):
|
1260 |
+
with open(txt_path, encoding='utf-8') as f:
|
1261 |
+
instance_prompt = f.readline().rstrip()
|
1262 |
+
f.close()
|
1263 |
+
|
1264 |
+
if self.shuffle_captions:
|
1265 |
+
caption_parts = instance_prompt.split(",")
|
1266 |
+
random.shuffle(caption_parts)
|
1267 |
+
instance_prompt = ",".join(caption_parts)
|
1268 |
+
|
1269 |
+
#print('identifier: ' + instance_prompt)
|
1270 |
+
instance_image = instance_image.convert("RGB")
|
1271 |
+
example["instance_images"] = self.image_transforms(instance_image)
|
1272 |
+
example["instance_prompt_ids"] = self.tokenizer(
|
1273 |
+
instance_prompt,
|
1274 |
+
padding="do_not_pad",
|
1275 |
+
truncation=True,
|
1276 |
+
max_length=self.tokenizer.model_max_length,
|
1277 |
+
).input_ids
|
1278 |
+
if self.with_prior_preservation:
|
1279 |
+
class_path, class_prompt = self.class_images_path[index % self.num_class_images]
|
1280 |
+
class_image = Image.open(class_path)
|
1281 |
+
if not class_image.mode == "RGB":
|
1282 |
+
class_image = class_image.convert("RGB")
|
1283 |
+
|
1284 |
+
if self.model_variant == "inpainting":
|
1285 |
+
mask_pathname = os.path.splitext(class_path)[0] + "-masklabel.png"
|
1286 |
+
if os.path.exists(mask_pathname):
|
1287 |
+
mask = Image.open(mask_pathname).convert("L")
|
1288 |
+
else:
|
1289 |
+
if self.variant_warning == False:
|
1290 |
+
print(f" {bcolors.FAIL} ** Warning: No mask found for an image, using an empty mask but make sure you're training the right model variant.{bcolors.ENDC}")
|
1291 |
+
self.variant_warning = True
|
1292 |
+
size = instance_image.size
|
1293 |
+
mask = Image.new('RGB', size, color="white").convert("L")
|
1294 |
+
example["class_mask"] = self.mask_transforms(mask)
|
1295 |
+
if self.model_variant == "depth2img":
|
1296 |
+
depth_pathname = os.path.splitext(class_path)[0] + "-depth.png"
|
1297 |
+
if os.path.exists(depth_pathname):
|
1298 |
+
depth_image = Image.open(depth_pathname)
|
1299 |
+
else:
|
1300 |
+
if self.variant_warning == False:
|
1301 |
+
print(f" {bcolors.FAIL} ** Warning: No depth image found for an image, using an empty depth image but make sure you're training the right model variant.{bcolors.ENDC}")
|
1302 |
+
self.variant_warning = True
|
1303 |
+
size = instance_image.size
|
1304 |
+
depth_image = Image.new('RGB', size, color="white").convert("L")
|
1305 |
+
example["class_depth_images"] = self.depth_image_transforms(depth_image)
|
1306 |
+
example["class_images"] = self.image_transforms(class_image)
|
1307 |
+
example["class_prompt_ids"] = self.tokenizer(
|
1308 |
+
class_prompt,
|
1309 |
+
padding="do_not_pad",
|
1310 |
+
truncation=True,
|
1311 |
+
max_length=self.tokenizer.model_max_length,
|
1312 |
+
).input_ids
|
1313 |
+
|
1314 |
+
return example
|
1315 |
+
|
1316 |
+
|
1317 |
+
class PromptDataset(Dataset):
|
1318 |
+
"A simple dataset to prepare the prompts to generate class images on multiple GPUs."
|
1319 |
+
|
1320 |
+
def __init__(self, prompt, num_samples):
|
1321 |
+
self.prompt = prompt
|
1322 |
+
self.num_samples = num_samples
|
1323 |
+
|
1324 |
+
def __len__(self):
|
1325 |
+
return self.num_samples
|
1326 |
+
|
1327 |
+
def __getitem__(self, index):
|
1328 |
+
example = {}
|
1329 |
+
example["prompt"] = self.prompt
|
1330 |
+
example["index"] = index
|
1331 |
+
return example
|
StableTuner_RunPod_Fix/discriminator.py
ADDED
@@ -0,0 +1,764 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
import torch
|
3 |
+
import torch.nn as nn
|
4 |
+
import torch.nn.functional as F
|
5 |
+
import einops, einops.layers.torch
|
6 |
+
import diffusers
|
7 |
+
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
8 |
+
from typing import Tuple, Optional
|
9 |
+
|
10 |
+
import inspect
|
11 |
+
import os
|
12 |
+
from functools import partial
|
13 |
+
from typing import Callable, List, Optional, Tuple, Union
|
14 |
+
|
15 |
+
import torch
|
16 |
+
from torch import Tensor, device
|
17 |
+
|
18 |
+
|
19 |
+
|
20 |
+
class ModelMixin(torch.nn.Module):
|
21 |
+
r"""
|
22 |
+
Base class for all models.
|
23 |
+
[`ModelMixin`] takes care of storing the configuration of the models and handles methods for loading, downloading
|
24 |
+
and saving models.
|
25 |
+
- **config_name** ([`str`]) -- A filename under which the model should be stored when calling
|
26 |
+
[`~models.ModelMixin.save_pretrained`].
|
27 |
+
"""
|
28 |
+
config_name = "new"
|
29 |
+
_automatically_saved_args = ["_diffusers_version", "_class_name", "_name_or_path"]
|
30 |
+
_supports_gradient_checkpointing = False
|
31 |
+
|
32 |
+
def __init__(self):
|
33 |
+
super().__init__()
|
34 |
+
|
35 |
+
@property
|
36 |
+
def is_gradient_checkpointing(self) -> bool:
|
37 |
+
"""
|
38 |
+
Whether gradient checkpointing is activated for this model or not.
|
39 |
+
Note that in other frameworks this feature can be referred to as "activation checkpointing" or "checkpoint
|
40 |
+
activations".
|
41 |
+
"""
|
42 |
+
return any(hasattr(m, "gradient_checkpointing") and m.gradient_checkpointing for m in self.modules())
|
43 |
+
|
44 |
+
def enable_gradient_checkpointing(self):
|
45 |
+
"""
|
46 |
+
Activates gradient checkpointing for the current model.
|
47 |
+
Note that in other frameworks this feature can be referred to as "activation checkpointing" or "checkpoint
|
48 |
+
activations".
|
49 |
+
"""
|
50 |
+
if not self._supports_gradient_checkpointing:
|
51 |
+
raise ValueError(f"{self.__class__.__name__} does not support gradient checkpointing.")
|
52 |
+
self.apply(partial(self._set_gradient_checkpointing, value=True))
|
53 |
+
|
54 |
+
def disable_gradient_checkpointing(self):
|
55 |
+
"""
|
56 |
+
Deactivates gradient checkpointing for the current model.
|
57 |
+
Note that in other frameworks this feature can be referred to as "activation checkpointing" or "checkpoint
|
58 |
+
activations".
|
59 |
+
"""
|
60 |
+
if self._supports_gradient_checkpointing:
|
61 |
+
self.apply(partial(self._set_gradient_checkpointing, value=False))
|
62 |
+
|
63 |
+
def set_use_memory_efficient_attention_xformers(
|
64 |
+
self, valid: bool, attention_op: Optional[Callable] = None
|
65 |
+
) -> None:
|
66 |
+
# Recursively walk through all the children.
|
67 |
+
# Any children which exposes the set_use_memory_efficient_attention_xformers method
|
68 |
+
# gets the message
|
69 |
+
def fn_recursive_set_mem_eff(module: torch.nn.Module):
|
70 |
+
if hasattr(module, "set_use_memory_efficient_attention_xformers"):
|
71 |
+
module.set_use_memory_efficient_attention_xformers(valid, attention_op)
|
72 |
+
|
73 |
+
for child in module.children():
|
74 |
+
fn_recursive_set_mem_eff(child)
|
75 |
+
|
76 |
+
for module in self.children():
|
77 |
+
if isinstance(module, torch.nn.Module):
|
78 |
+
fn_recursive_set_mem_eff(module)
|
79 |
+
|
80 |
+
def enable_xformers_memory_efficient_attention(self, attention_op: Optional[Callable] = None):
|
81 |
+
r"""
|
82 |
+
Enable memory efficient attention as implemented in xformers.
|
83 |
+
When this option is enabled, you should observe lower GPU memory usage and a potential speed up at inference
|
84 |
+
time. Speed up at training time is not guaranteed.
|
85 |
+
Warning: When Memory Efficient Attention and Sliced attention are both enabled, the Memory Efficient Attention
|
86 |
+
is used.
|
87 |
+
Parameters:
|
88 |
+
attention_op (`Callable`, *optional*):
|
89 |
+
Override the default `None` operator for use as `op` argument to the
|
90 |
+
[`memory_efficient_attention()`](https://facebookresearch.github.io/xformers/components/ops.html#xformers.ops.memory_efficient_attention)
|
91 |
+
function of xFormers.
|
92 |
+
Examples:
|
93 |
+
```py
|
94 |
+
>>> import torch
|
95 |
+
>>> from diffusers import UNet2DConditionModel
|
96 |
+
>>> from xformers.ops import MemoryEfficientAttentionFlashAttentionOp
|
97 |
+
>>> model = UNet2DConditionModel.from_pretrained(
|
98 |
+
... "stabilityai/stable-diffusion-2-1", subfolder="unet", torch_dtype=torch.float16
|
99 |
+
... )
|
100 |
+
>>> model = model.to("cuda")
|
101 |
+
>>> model.enable_xformers_memory_efficient_attention(attention_op=MemoryEfficientAttentionFlashAttentionOp)
|
102 |
+
```
|
103 |
+
"""
|
104 |
+
self.set_use_memory_efficient_attention_xformers(True, attention_op)
|
105 |
+
|
106 |
+
def disable_xformers_memory_efficient_attention(self):
|
107 |
+
r"""
|
108 |
+
Disable memory efficient attention as implemented in xformers.
|
109 |
+
"""
|
110 |
+
self.set_use_memory_efficient_attention_xformers(False)
|
111 |
+
|
112 |
+
def save_pretrained(
|
113 |
+
self,
|
114 |
+
save_directory: Union[str, os.PathLike],
|
115 |
+
is_main_process: bool = True,
|
116 |
+
save_function: Callable = None,
|
117 |
+
safe_serialization: bool = False,
|
118 |
+
variant: Optional[str] = None,
|
119 |
+
):
|
120 |
+
"""
|
121 |
+
Save a model and its configuration file to a directory, so that it can be re-loaded using the
|
122 |
+
`[`~models.ModelMixin.from_pretrained`]` class method.
|
123 |
+
Arguments:
|
124 |
+
save_directory (`str` or `os.PathLike`):
|
125 |
+
Directory to which to save. Will be created if it doesn't exist.
|
126 |
+
is_main_process (`bool`, *optional*, defaults to `True`):
|
127 |
+
Whether the process calling this is the main process or not. Useful when in distributed training like
|
128 |
+
TPUs and need to call this function on all processes. In this case, set `is_main_process=True` only on
|
129 |
+
the main process to avoid race conditions.
|
130 |
+
save_function (`Callable`):
|
131 |
+
The function to use to save the state dictionary. Useful on distributed training like TPUs when one
|
132 |
+
need to replace `torch.save` by another method. Can be configured with the environment variable
|
133 |
+
`DIFFUSERS_SAVE_MODE`.
|
134 |
+
safe_serialization (`bool`, *optional*, defaults to `False`):
|
135 |
+
Whether to save the model using `safetensors` or the traditional PyTorch way (that uses `pickle`).
|
136 |
+
variant (`str`, *optional*):
|
137 |
+
If specified, weights are saved in the format pytorch_model.<variant>.bin.
|
138 |
+
"""
|
139 |
+
if safe_serialization and not is_safetensors_available():
|
140 |
+
raise ImportError("`safe_serialization` requires the `safetensors library: `pip install safetensors`.")
|
141 |
+
|
142 |
+
if os.path.isfile(save_directory):
|
143 |
+
logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
|
144 |
+
return
|
145 |
+
|
146 |
+
os.makedirs(save_directory, exist_ok=True)
|
147 |
+
|
148 |
+
model_to_save = self
|
149 |
+
|
150 |
+
# Attach architecture to the config
|
151 |
+
# Save the config
|
152 |
+
if is_main_process:
|
153 |
+
model_to_save.save_config(save_directory)
|
154 |
+
|
155 |
+
# Save the model
|
156 |
+
state_dict = model_to_save.state_dict()
|
157 |
+
|
158 |
+
weights_name = SAFETENSORS_WEIGHTS_NAME if safe_serialization else WEIGHTS_NAME
|
159 |
+
weights_name = _add_variant(weights_name, variant)
|
160 |
+
|
161 |
+
# Save the model
|
162 |
+
if safe_serialization:
|
163 |
+
safetensors.torch.save_file(
|
164 |
+
state_dict, os.path.join(save_directory, weights_name), metadata={"format": "pt"}
|
165 |
+
)
|
166 |
+
else:
|
167 |
+
torch.save(state_dict, os.path.join(save_directory, weights_name))
|
168 |
+
|
169 |
+
logger.info(f"Model weights saved in {os.path.join(save_directory, weights_name)}")
|
170 |
+
|
171 |
+
@classmethod
|
172 |
+
def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs):
|
173 |
+
r"""
|
174 |
+
Instantiate a pretrained pytorch model from a pre-trained model configuration.
|
175 |
+
The model is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated). To train
|
176 |
+
the model, you should first set it back in training mode with `model.train()`.
|
177 |
+
The warning *Weights from XXX not initialized from pretrained model* means that the weights of XXX do not come
|
178 |
+
pretrained with the rest of the model. It is up to you to train those weights with a downstream fine-tuning
|
179 |
+
task.
|
180 |
+
The warning *Weights from XXX not used in YYY* means that the layer XXX is not used by YYY, therefore those
|
181 |
+
weights are discarded.
|
182 |
+
Parameters:
|
183 |
+
pretrained_model_name_or_path (`str` or `os.PathLike`, *optional*):
|
184 |
+
Can be either:
|
185 |
+
- A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co.
|
186 |
+
Valid model ids should have an organization name, like `google/ddpm-celebahq-256`.
|
187 |
+
- A path to a *directory* containing model weights saved using [`~ModelMixin.save_config`], e.g.,
|
188 |
+
`./my_model_directory/`.
|
189 |
+
cache_dir (`Union[str, os.PathLike]`, *optional*):
|
190 |
+
Path to a directory in which a downloaded pretrained model configuration should be cached if the
|
191 |
+
standard cache should not be used.
|
192 |
+
torch_dtype (`str` or `torch.dtype`, *optional*):
|
193 |
+
Override the default `torch.dtype` and load the model under this dtype. If `"auto"` is passed the dtype
|
194 |
+
will be automatically derived from the model's weights.
|
195 |
+
force_download (`bool`, *optional*, defaults to `False`):
|
196 |
+
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
|
197 |
+
cached versions if they exist.
|
198 |
+
resume_download (`bool`, *optional*, defaults to `False`):
|
199 |
+
Whether or not to delete incompletely received files. Will attempt to resume the download if such a
|
200 |
+
file exists.
|
201 |
+
proxies (`Dict[str, str]`, *optional*):
|
202 |
+
A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
|
203 |
+
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
|
204 |
+
output_loading_info(`bool`, *optional*, defaults to `False`):
|
205 |
+
Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
|
206 |
+
local_files_only(`bool`, *optional*, defaults to `False`):
|
207 |
+
Whether or not to only look at local files (i.e., do not try to download the model).
|
208 |
+
use_auth_token (`str` or *bool*, *optional*):
|
209 |
+
The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
|
210 |
+
when running `diffusers-cli login` (stored in `~/.huggingface`).
|
211 |
+
revision (`str`, *optional*, defaults to `"main"`):
|
212 |
+
The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
|
213 |
+
git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
|
214 |
+
identifier allowed by git.
|
215 |
+
from_flax (`bool`, *optional*, defaults to `False`):
|
216 |
+
Load the model weights from a Flax checkpoint save file.
|
217 |
+
subfolder (`str`, *optional*, defaults to `""`):
|
218 |
+
In case the relevant files are located inside a subfolder of the model repo (either remote in
|
219 |
+
huggingface.co or downloaded locally), you can specify the folder name here.
|
220 |
+
mirror (`str`, *optional*):
|
221 |
+
Mirror source to accelerate downloads in China. If you are from China and have an accessibility
|
222 |
+
problem, you can set this option to resolve it. Note that we do not guarantee the timeliness or safety.
|
223 |
+
Please refer to the mirror site for more information.
|
224 |
+
device_map (`str` or `Dict[str, Union[int, str, torch.device]]`, *optional*):
|
225 |
+
A map that specifies where each submodule should go. It doesn't need to be refined to each
|
226 |
+
parameter/buffer name, once a given module name is inside, every submodule of it will be sent to the
|
227 |
+
same device.
|
228 |
+
To have Accelerate compute the most optimized `device_map` automatically, set `device_map="auto"`. For
|
229 |
+
more information about each option see [designing a device
|
230 |
+
map](https://hf.co/docs/accelerate/main/en/usage_guides/big_modeling#designing-a-device-map).
|
231 |
+
low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
|
232 |
+
Speed up model loading by not initializing the weights and only loading the pre-trained weights. This
|
233 |
+
also tries to not use more than 1x model size in CPU memory (including peak memory) while loading the
|
234 |
+
model. This is only supported when torch version >= 1.9.0. If you are using an older version of torch,
|
235 |
+
setting this argument to `True` will raise an error.
|
236 |
+
variant (`str`, *optional*):
|
237 |
+
If specified load weights from `variant` filename, *e.g.* pytorch_model.<variant>.bin. `variant` is
|
238 |
+
ignored when using `from_flax`.
|
239 |
+
use_safetensors (`bool`, *optional* ):
|
240 |
+
If set to `True`, the pipeline will forcibly load the models from `safetensors` weights. If set to
|
241 |
+
`None` (the default). The pipeline will load using `safetensors` if safetensors weights are available
|
242 |
+
*and* if `safetensors` is installed. If the to `False` the pipeline will *not* use `safetensors`.
|
243 |
+
<Tip>
|
244 |
+
It is required to be logged in (`huggingface-cli login`) when you want to use private or [gated
|
245 |
+
models](https://huggingface.co/docs/hub/models-gated#gated-models).
|
246 |
+
</Tip>
|
247 |
+
<Tip>
|
248 |
+
Activate the special ["offline-mode"](https://huggingface.co/diffusers/installation.html#offline-mode) to use
|
249 |
+
this method in a firewalled environment.
|
250 |
+
</Tip>
|
251 |
+
"""
|
252 |
+
cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
|
253 |
+
ignore_mismatched_sizes = kwargs.pop("ignore_mismatched_sizes", False)
|
254 |
+
force_download = kwargs.pop("force_download", False)
|
255 |
+
from_flax = kwargs.pop("from_flax", False)
|
256 |
+
resume_download = kwargs.pop("resume_download", False)
|
257 |
+
proxies = kwargs.pop("proxies", None)
|
258 |
+
output_loading_info = kwargs.pop("output_loading_info", False)
|
259 |
+
local_files_only = kwargs.pop("local_files_only", HF_HUB_OFFLINE)
|
260 |
+
use_auth_token = kwargs.pop("use_auth_token", None)
|
261 |
+
revision = kwargs.pop("revision", None)
|
262 |
+
torch_dtype = kwargs.pop("torch_dtype", None)
|
263 |
+
subfolder = kwargs.pop("subfolder", None)
|
264 |
+
device_map = kwargs.pop("device_map", None)
|
265 |
+
low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT)
|
266 |
+
variant = kwargs.pop("variant", None)
|
267 |
+
use_safetensors = kwargs.pop("use_safetensors", None)
|
268 |
+
|
269 |
+
if use_safetensors and not is_safetensors_available():
|
270 |
+
raise ValueError(
|
271 |
+
"`use_safetensors`=True but safetensors is not installed. Please install safetensors with `pip install safetenstors"
|
272 |
+
)
|
273 |
+
|
274 |
+
allow_pickle = False
|
275 |
+
if use_safetensors is None:
|
276 |
+
use_safetensors = is_safetensors_available()
|
277 |
+
allow_pickle = True
|
278 |
+
|
279 |
+
if low_cpu_mem_usage and not is_accelerate_available():
|
280 |
+
low_cpu_mem_usage = False
|
281 |
+
logger.warning(
|
282 |
+
"Cannot initialize model with low cpu memory usage because `accelerate` was not found in the"
|
283 |
+
" environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install"
|
284 |
+
" `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip"
|
285 |
+
" install accelerate\n```\n."
|
286 |
+
)
|
287 |
+
|
288 |
+
if device_map is not None and not is_accelerate_available():
|
289 |
+
raise NotImplementedError(
|
290 |
+
"Loading and dispatching requires `accelerate`. Please make sure to install accelerate or set"
|
291 |
+
" `device_map=None`. You can install accelerate with `pip install accelerate`."
|
292 |
+
)
|
293 |
+
|
294 |
+
# Check if we can handle device_map and dispatching the weights
|
295 |
+
if device_map is not None and not is_torch_version(">=", "1.9.0"):
|
296 |
+
raise NotImplementedError(
|
297 |
+
"Loading and dispatching requires torch >= 1.9.0. Please either update your PyTorch version or set"
|
298 |
+
" `device_map=None`."
|
299 |
+
)
|
300 |
+
|
301 |
+
if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"):
|
302 |
+
raise NotImplementedError(
|
303 |
+
"Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set"
|
304 |
+
" `low_cpu_mem_usage=False`."
|
305 |
+
)
|
306 |
+
|
307 |
+
if low_cpu_mem_usage is False and device_map is not None:
|
308 |
+
raise ValueError(
|
309 |
+
f"You cannot set `low_cpu_mem_usage` to `False` while using device_map={device_map} for loading and"
|
310 |
+
" dispatching. Please make sure to set `low_cpu_mem_usage=True`."
|
311 |
+
)
|
312 |
+
|
313 |
+
# Load config if we don't provide a configuration
|
314 |
+
config_path = pretrained_model_name_or_path
|
315 |
+
|
316 |
+
user_agent = {
|
317 |
+
"diffusers": __version__,
|
318 |
+
"file_type": "model",
|
319 |
+
"framework": "pytorch",
|
320 |
+
}
|
321 |
+
|
322 |
+
# load config
|
323 |
+
config, unused_kwargs, commit_hash = cls.load_config(
|
324 |
+
config_path,
|
325 |
+
cache_dir=cache_dir,
|
326 |
+
return_unused_kwargs=True,
|
327 |
+
return_commit_hash=True,
|
328 |
+
force_download=force_download,
|
329 |
+
resume_download=resume_download,
|
330 |
+
proxies=proxies,
|
331 |
+
local_files_only=local_files_only,
|
332 |
+
use_auth_token=use_auth_token,
|
333 |
+
revision=revision,
|
334 |
+
subfolder=subfolder,
|
335 |
+
device_map=device_map,
|
336 |
+
user_agent=user_agent,
|
337 |
+
**kwargs,
|
338 |
+
)
|
339 |
+
|
340 |
+
# load model
|
341 |
+
model_file = None
|
342 |
+
if from_flax:
|
343 |
+
model_file = _get_model_file(
|
344 |
+
pretrained_model_name_or_path,
|
345 |
+
weights_name=FLAX_WEIGHTS_NAME,
|
346 |
+
cache_dir=cache_dir,
|
347 |
+
force_download=force_download,
|
348 |
+
resume_download=resume_download,
|
349 |
+
proxies=proxies,
|
350 |
+
local_files_only=local_files_only,
|
351 |
+
use_auth_token=use_auth_token,
|
352 |
+
revision=revision,
|
353 |
+
subfolder=subfolder,
|
354 |
+
user_agent=user_agent,
|
355 |
+
commit_hash=commit_hash,
|
356 |
+
)
|
357 |
+
model = cls.from_config(config, **unused_kwargs)
|
358 |
+
|
359 |
+
# Convert the weights
|
360 |
+
from .modeling_pytorch_flax_utils import load_flax_checkpoint_in_pytorch_model
|
361 |
+
|
362 |
+
model = load_flax_checkpoint_in_pytorch_model(model, model_file)
|
363 |
+
else:
|
364 |
+
if use_safetensors:
|
365 |
+
try:
|
366 |
+
model_file = _get_model_file(
|
367 |
+
pretrained_model_name_or_path,
|
368 |
+
weights_name=_add_variant(SAFETENSORS_WEIGHTS_NAME, variant),
|
369 |
+
cache_dir=cache_dir,
|
370 |
+
force_download=force_download,
|
371 |
+
resume_download=resume_download,
|
372 |
+
proxies=proxies,
|
373 |
+
local_files_only=local_files_only,
|
374 |
+
use_auth_token=use_auth_token,
|
375 |
+
revision=revision,
|
376 |
+
subfolder=subfolder,
|
377 |
+
user_agent=user_agent,
|
378 |
+
commit_hash=commit_hash,
|
379 |
+
)
|
380 |
+
except IOError as e:
|
381 |
+
if not allow_pickle:
|
382 |
+
raise e
|
383 |
+
pass
|
384 |
+
if model_file is None:
|
385 |
+
model_file = _get_model_file(
|
386 |
+
pretrained_model_name_or_path,
|
387 |
+
weights_name=_add_variant(WEIGHTS_NAME, variant),
|
388 |
+
cache_dir=cache_dir,
|
389 |
+
force_download=force_download,
|
390 |
+
resume_download=resume_download,
|
391 |
+
proxies=proxies,
|
392 |
+
local_files_only=local_files_only,
|
393 |
+
use_auth_token=use_auth_token,
|
394 |
+
revision=revision,
|
395 |
+
subfolder=subfolder,
|
396 |
+
user_agent=user_agent,
|
397 |
+
commit_hash=commit_hash,
|
398 |
+
)
|
399 |
+
|
400 |
+
if low_cpu_mem_usage:
|
401 |
+
# Instantiate model with empty weights
|
402 |
+
with accelerate.init_empty_weights():
|
403 |
+
model = cls.from_config(config, **unused_kwargs)
|
404 |
+
|
405 |
+
# if device_map is None, load the state dict and move the params from meta device to the cpu
|
406 |
+
if device_map is None:
|
407 |
+
param_device = "cpu"
|
408 |
+
state_dict = load_state_dict(model_file, variant=variant)
|
409 |
+
# move the params from meta device to cpu
|
410 |
+
missing_keys = set(model.state_dict().keys()) - set(state_dict.keys())
|
411 |
+
if len(missing_keys) > 0:
|
412 |
+
raise ValueError(
|
413 |
+
f"Cannot load {cls} from {pretrained_model_name_or_path} because the following keys are"
|
414 |
+
f" missing: \n {', '.join(missing_keys)}. \n Please make sure to pass"
|
415 |
+
" `low_cpu_mem_usage=False` and `device_map=None` if you want to randomly initialize"
|
416 |
+
" those weights or else make sure your checkpoint file is correct."
|
417 |
+
)
|
418 |
+
|
419 |
+
empty_state_dict = model.state_dict()
|
420 |
+
for param_name, param in state_dict.items():
|
421 |
+
accepts_dtype = "dtype" in set(
|
422 |
+
inspect.signature(set_module_tensor_to_device).parameters.keys()
|
423 |
+
)
|
424 |
+
|
425 |
+
if empty_state_dict[param_name].shape != param.shape:
|
426 |
+
raise ValueError(
|
427 |
+
f"Cannot load {pretrained_model_name_or_path} because {param_name} expected shape {empty_state_dict[param_name]}, but got {param.shape}. If you want to instead overwrite randomly initialized weights, please make sure to pass both `low_cpu_mem_usage=False` and `ignore_mismatched_sizes=True`. For more information, see also: https://github.com/huggingface/diffusers/issues/1619#issuecomment-1345604389 as an example."
|
428 |
+
)
|
429 |
+
|
430 |
+
if accepts_dtype:
|
431 |
+
set_module_tensor_to_device(
|
432 |
+
model, param_name, param_device, value=param, dtype=torch_dtype
|
433 |
+
)
|
434 |
+
else:
|
435 |
+
set_module_tensor_to_device(model, param_name, param_device, value=param)
|
436 |
+
else: # else let accelerate handle loading and dispatching.
|
437 |
+
# Load weights and dispatch according to the device_map
|
438 |
+
# by default the device_map is None and the weights are loaded on the CPU
|
439 |
+
accelerate.load_checkpoint_and_dispatch(model, model_file, device_map, dtype=torch_dtype)
|
440 |
+
|
441 |
+
loading_info = {
|
442 |
+
"missing_keys": [],
|
443 |
+
"unexpected_keys": [],
|
444 |
+
"mismatched_keys": [],
|
445 |
+
"error_msgs": [],
|
446 |
+
}
|
447 |
+
else:
|
448 |
+
model = cls.from_config(config, **unused_kwargs)
|
449 |
+
|
450 |
+
state_dict = load_state_dict(model_file, variant=variant)
|
451 |
+
|
452 |
+
model, missing_keys, unexpected_keys, mismatched_keys, error_msgs = cls._load_pretrained_model(
|
453 |
+
model,
|
454 |
+
state_dict,
|
455 |
+
model_file,
|
456 |
+
pretrained_model_name_or_path,
|
457 |
+
ignore_mismatched_sizes=ignore_mismatched_sizes,
|
458 |
+
)
|
459 |
+
|
460 |
+
loading_info = {
|
461 |
+
"missing_keys": missing_keys,
|
462 |
+
"unexpected_keys": unexpected_keys,
|
463 |
+
"mismatched_keys": mismatched_keys,
|
464 |
+
"error_msgs": error_msgs,
|
465 |
+
}
|
466 |
+
|
467 |
+
if torch_dtype is not None and not isinstance(torch_dtype, torch.dtype):
|
468 |
+
raise ValueError(
|
469 |
+
f"{torch_dtype} needs to be of type `torch.dtype`, e.g. `torch.float16`, but is {type(torch_dtype)}."
|
470 |
+
)
|
471 |
+
elif torch_dtype is not None:
|
472 |
+
model = model.to(torch_dtype)
|
473 |
+
|
474 |
+
model.register_to_config(_name_or_path=pretrained_model_name_or_path)
|
475 |
+
|
476 |
+
# Set model in evaluation mode to deactivate DropOut modules by default
|
477 |
+
model.eval()
|
478 |
+
        if output_loading_info:
            return model, loading_info

        return model

    @classmethod
    def _load_pretrained_model(
        cls,
        model,
        state_dict,
        resolved_archive_file,
        pretrained_model_name_or_path,
        ignore_mismatched_sizes=False,
    ):
        # Retrieve missing & unexpected_keys
        model_state_dict = model.state_dict()
        loaded_keys = list(state_dict.keys())

        expected_keys = list(model_state_dict.keys())

        original_loaded_keys = loaded_keys

        missing_keys = list(set(expected_keys) - set(loaded_keys))
        unexpected_keys = list(set(loaded_keys) - set(expected_keys))

        # Make sure we are able to load base models as well as derived models (with heads)
        model_to_load = model

        def _find_mismatched_keys(
            state_dict,
            model_state_dict,
            loaded_keys,
            ignore_mismatched_sizes,
        ):
            mismatched_keys = []
            if ignore_mismatched_sizes:
                for checkpoint_key in loaded_keys:
                    model_key = checkpoint_key

                    if (
                        model_key in model_state_dict
                        and state_dict[checkpoint_key].shape != model_state_dict[model_key].shape
                    ):
                        mismatched_keys.append(
                            (checkpoint_key, state_dict[checkpoint_key].shape, model_state_dict[model_key].shape)
                        )
                        del state_dict[checkpoint_key]
            return mismatched_keys

        if state_dict is not None:
            # Whole checkpoint
            mismatched_keys = _find_mismatched_keys(
                state_dict,
                model_state_dict,
                original_loaded_keys,
                ignore_mismatched_sizes,
            )
            error_msgs = _load_state_dict_into_model(model_to_load, state_dict)

        if len(error_msgs) > 0:
            error_msg = "\n\t".join(error_msgs)
            if "size mismatch" in error_msg:
                error_msg += (
                    "\n\tYou may consider adding `ignore_mismatched_sizes=True` in the model `from_pretrained` method."
                )
            raise RuntimeError(f"Error(s) in loading state_dict for {model.__class__.__name__}:\n\t{error_msg}")

        if len(unexpected_keys) > 0:
            logger.warning(
                f"Some weights of the model checkpoint at {pretrained_model_name_or_path} were not used when"
                f" initializing {model.__class__.__name__}: {unexpected_keys}\n- This IS expected if you are"
                f" initializing {model.__class__.__name__} from the checkpoint of a model trained on another task"
                " or with another architecture (e.g. initializing a BertForSequenceClassification model from a"
                " BertForPreTraining model).\n- This IS NOT expected if you are initializing"
                f" {model.__class__.__name__} from the checkpoint of a model that you expect to be exactly"
                " identical (initializing a BertForSequenceClassification model from a"
                " BertForSequenceClassification model)."
            )
        else:
            logger.info(f"All model checkpoint weights were used when initializing {model.__class__.__name__}.\n")
        if len(missing_keys) > 0:
            logger.warning(
                f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at"
                f" {pretrained_model_name_or_path} and are newly initialized: {missing_keys}\nYou should probably"
                " TRAIN this model on a down-stream task to be able to use it for predictions and inference."
            )
        elif len(mismatched_keys) == 0:
            logger.info(
                f"All the weights of {model.__class__.__name__} were initialized from the model checkpoint at"
                f" {pretrained_model_name_or_path}.\nIf your task is similar to the task the model of the"
                f" checkpoint was trained on, you can already use {model.__class__.__name__} for predictions"
                " without further training."
            )
        if len(mismatched_keys) > 0:
            mismatched_warning = "\n".join(
                [
                    f"- {key}: found shape {shape1} in the checkpoint and {shape2} in the model instantiated"
                    for key, shape1, shape2 in mismatched_keys
                ]
            )
            logger.warning(
                f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at"
                f" {pretrained_model_name_or_path} and are newly initialized because the shapes did not"
                f" match:\n{mismatched_warning}\nYou should probably TRAIN this model on a down-stream task to be"
                " able to use it for predictions and inference."
            )

        return model, missing_keys, unexpected_keys, mismatched_keys, error_msgs

    @property
    def device(self) -> device:
        """
        `torch.device`: The device on which the module is (assuming that all the module parameters are on the same
        device).
        """
        return get_parameter_device(self)

    @property
    def dtype(self) -> torch.dtype:
        """
        `torch.dtype`: The dtype of the module (assuming that all the module parameters have the same dtype).
        """
        return get_parameter_dtype(self)

    def num_parameters(self, only_trainable: bool = False, exclude_embeddings: bool = False) -> int:
        """
        Get number of (optionally, trainable or non-embeddings) parameters in the module.
        Args:
            only_trainable (`bool`, *optional*, defaults to `False`):
                Whether or not to return only the number of trainable parameters
            exclude_embeddings (`bool`, *optional*, defaults to `False`):
                Whether or not to return only the number of non-embeddings parameters
        Returns:
            `int`: The number of parameters.
        """

        if exclude_embeddings:
            embedding_param_names = [
                f"{name}.weight"
                for name, module_type in self.named_modules()
                if isinstance(module_type, torch.nn.Embedding)
            ]
            non_embedding_parameters = [
                parameter for name, parameter in self.named_parameters() if name not in embedding_param_names
            ]
            return sum(p.numel() for p in non_embedding_parameters if p.requires_grad or not only_trainable)
        else:
            return sum(p.numel() for p in self.parameters() if p.requires_grad or not only_trainable)


def Downsample(dim, dim_out):
    return nn.Conv2d(dim, dim_out, 4, 2, 1)


class Residual(nn.Sequential):
    def forward(self, input):
        x = input
        for module in self:
            x = module(x)
        return x + input


def ConvLayer(dim, dim_out, *, kernel_size=3, groups=32):
    return nn.Sequential(
        nn.GroupNorm(groups, dim),
        nn.SiLU(),
        nn.Conv2d(dim, dim_out, kernel_size=kernel_size, padding=kernel_size // 2),
    )


def ResnetBlock(dim, *, kernel_size=3, groups=32):
    return Residual(
        ConvLayer(dim, dim, kernel_size=kernel_size, groups=groups),
        ConvLayer(dim, dim, kernel_size=kernel_size, groups=groups),
    )


class SelfAttention(nn.Module):
    def __init__(self, dim, out_dim, *, heads=8, key_dim=32, value_dim=32):
        super().__init__()
        self.dim = dim
        self.out_dim = dim
        self.heads = heads
        self.key_dim = key_dim

        self.to_k = nn.Linear(dim, key_dim)
        self.to_v = nn.Linear(dim, value_dim)
        self.to_q = nn.Linear(dim, key_dim * heads)
        self.to_out = nn.Linear(value_dim * heads, out_dim)

    def forward(self, x):
        shape = x.shape
        x = einops.rearrange(x, 'b c ... -> b (...) c')

        k = self.to_k(x)
        v = self.to_v(x)
        q = self.to_q(x)
        q = einops.rearrange(q, 'b n (h c) -> b (n h) c', h=self.heads)
        if hasattr(nn.functional, "scaled_dot_product_attention"):
            result = F.scaled_dot_product_attention(q, k, v)
        else:
            attention_scores = torch.bmm(q, k.transpose(-2, -1))
            attention_probs = torch.softmax(attention_scores.float() / math.sqrt(self.key_dim), dim=-1).type(attention_scores.dtype)
            result = torch.bmm(attention_probs, v)
        result = einops.rearrange(result, 'b (n h) c -> b n (h c)', h=self.heads)
        out = self.to_out(result)

        out = einops.rearrange(out, 'b n c -> b c n')
        out = torch.reshape(out, (shape[0], self.out_dim, *shape[2:]))
        return out


def SelfAttentionBlock(dim, attention_dim, *, heads=8, groups=32):
    if not attention_dim:
        attention_dim = dim // heads
    return Residual(
        nn.GroupNorm(groups, dim),
        SelfAttention(dim, dim, heads=heads, key_dim=attention_dim, value_dim=attention_dim),
    )


class Discriminator2D(ModelMixin, ConfigMixin):
    @register_to_config
    def __init__(
        self,
        in_channels: int = 8,
        out_channels: int = 1,
        block_out_channels: Tuple[int] = (128, 256, 512, 1024, 1024, 1024),
        block_repeats: Tuple[int] = (2, 2, 2, 2, 2),
        downsample_blocks: Tuple[int] = (0, 1, 2),
        attention_blocks: Tuple[int] = (1, 2, 3, 4),
        mlp_hidden_channels: Tuple[int] = (2048, 2048, 2048),
        mlp_uses_norm: bool = True,
        attention_dim: Optional[int] = None,
        attention_heads: int = 8,
        groups: int = 32,
        embedding_dim: int = 768,
    ):
        super().__init__()

        self.blocks = nn.ModuleList([])

        self.conv_in = nn.Conv2d(in_channels, block_out_channels[0], 7, padding=3)

        for i in range(0, len(block_out_channels) - 1):
            block_in = block_out_channels[i]
            block_out = block_out_channels[i + 1]
            block = nn.Sequential()
            for j in range(0, block_repeats[i]):
                if i in attention_blocks:
                    block.append(SelfAttentionBlock(block_in, attention_dim, heads=attention_heads, groups=groups))
                block.append(ResnetBlock(block_in, groups=groups))
            if i in downsample_blocks:
                block.append(Downsample(block_in, block_out))
            elif block_in != block_out:
                block.append(nn.Conv2d(block_in, block_out, 1))
            self.blocks.append(block)

        # A simple MLP to make the final decision based on statistics from
        # the output of every block
        self.to_out = nn.Sequential()
        d_channels = 2 * sum(block_out_channels[1:]) + embedding_dim
        for c in mlp_hidden_channels:
            self.to_out.append(nn.Linear(d_channels, c))
            if mlp_uses_norm:
                self.to_out.append(nn.GroupNorm(groups, c))
            self.to_out.append(nn.SiLU())
            d_channels = c
        self.to_out.append(nn.Linear(d_channels, out_channels))

        self.gradient_checkpointing = False

    def enable_gradient_checkpointing(self):
        self.gradient_checkpointing = True

    def disable_gradient_checkpointing(self):
        self.gradient_checkpointing = False

    def forward(self, x, encoder_hidden_states):
        x = self.conv_in(x)
        if self.config.embedding_dim != 0:
            d = einops.reduce(encoder_hidden_states, 'b n c -> b c', 'mean')
        else:
            d = torch.zeros([x.shape[0], 0], device=x.device, dtype=x.dtype)
        for block in self.blocks:
            if self.gradient_checkpointing:
                x = torch.utils.checkpoint.checkpoint(block, x)
            else:
                x = block(x)
            x_mean = einops.reduce(x, 'b c ... -> b c', 'mean')
            x_max = einops.reduce(x, 'b c ... -> b c', 'max')
            d = torch.cat([d, x_mean, x_max], dim=-1)
        return self.to_out(d)
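
For reference, a minimal usage sketch of the discriminator above (not part of the upload). The module path and tensor shapes are assumptions for illustration: the defaults expect 8 input channels, and the text embeddings are CLIP encoder hidden states that forward() averages over the token axis.

import torch
from discriminator import Discriminator2D  # assumed module name for the file above

disc = Discriminator2D()                    # default config: in_channels=8, embedding_dim=768
latents = torch.randn(2, 8, 64, 64)         # dummy latent-space input
text_embeddings = torch.randn(2, 77, 768)   # dummy CLIP hidden states
logits = disc(latents, text_embeddings)     # -> shape (2, 1)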
StableTuner_RunPod_Fix/lion_pytorch.py
ADDED
@@ -0,0 +1,88 @@
from typing import Tuple, Optional, Callable

import torch
from torch.optim.optimizer import Optimizer

# functions

def exists(val):
    return val is not None

# update functions

def update_fn(p, grad, exp_avg, lr, wd, beta1, beta2):
    # stepweight decay

    p.data.mul_(1 - lr * wd)

    # weight update

    update = exp_avg.clone().mul_(beta1).add(grad, alpha = 1 - beta1).sign_()
    p.add_(update, alpha = -lr)

    # decay the momentum running average coefficient

    exp_avg.mul_(beta2).add_(grad, alpha = 1 - beta2)

# class

class Lion(Optimizer):
    def __init__(
        self,
        params,
        lr: float = 1e-4,
        betas: Tuple[float, float] = (0.9, 0.99),
        weight_decay: float = 0.0,
        use_triton: bool = False
    ):
        assert lr > 0.
        assert all([0. <= beta <= 1. for beta in betas])

        defaults = dict(
            lr = lr,
            betas = betas,
            weight_decay = weight_decay
        )

        super().__init__(params, defaults)

        self.update_fn = update_fn

        if use_triton:
            from lion_pytorch.triton import update_fn as triton_update_fn
            self.update_fn = triton_update_fn

    @torch.no_grad()
    def step(
        self,
        closure: Optional[Callable] = None
    ):

        loss = None
        if exists(closure):
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            for p in filter(lambda p: exists(p.grad), group['params']):

                grad, lr, wd, beta1, beta2, state = p.grad, group['lr'], group['weight_decay'], *group['betas'], self.state[p]

                # init state - exponential moving average of gradient values

                if len(state) == 0:
                    state['exp_avg'] = torch.zeros_like(p)

                exp_avg = state['exp_avg']

                self.update_fn(
                    p,
                    grad,
                    exp_avg,
                    lr,
                    wd,
                    beta1,
                    beta2
                )

        return loss
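
A minimal usage sketch, assuming the file above is importable as lion_pytorch. Lion is a drop-in replacement for AdamW; in practice it is usually run with a smaller learning rate and a larger weight decay than AdamW.

import torch
from lion_pytorch import Lion  # assumed import path for the file above

model = torch.nn.Linear(10, 2)
optimizer = Lion(model.parameters(), lr=1e-4, betas=(0.9, 0.99), weight_decay=1e-2)

x, y = torch.randn(8, 10), torch.randn(8, 2)
loss = torch.nn.functional.mse_loss(model(x), y)
loss.backward()
optimizer.step()
optimizer.zero_grad()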
StableTuner_RunPod_Fix/lora_utils.py
ADDED
@@ -0,0 +1,236 @@
# LoRA network module
# reference:
# https://github.com/microsoft/LoRA/blob/main/loralib/layers.py
# https://github.com/cloneofsimo/lora/blob/master/lora_diffusion/lora.py

import math
import os
import torch

from trainer_util import *


class LoRAModule(torch.nn.Module):
    """
    replaces forward method of the original Linear, instead of replacing the original Linear module.
    """

    def __init__(self, lora_name, org_module: torch.nn.Module, multiplier=1.0, lora_dim=4, alpha=1):
        """ if alpha == 0 or None, alpha is rank (no scaling). """
        super().__init__()
        self.lora_name = lora_name
        self.lora_dim = lora_dim

        if org_module.__class__.__name__ == 'Conv2d':
            in_dim = org_module.in_channels
            out_dim = org_module.out_channels
            self.lora_down = torch.nn.Conv2d(in_dim, lora_dim, (1, 1), bias=False)
            self.lora_up = torch.nn.Conv2d(lora_dim, out_dim, (1, 1), bias=False)
        else:
            in_dim = org_module.in_features
            out_dim = org_module.out_features
            self.lora_down = torch.nn.Linear(in_dim, lora_dim, bias=False)
            self.lora_up = torch.nn.Linear(lora_dim, out_dim, bias=False)

        if type(alpha) == torch.Tensor:
            alpha = alpha.detach().float().numpy()  # without casting, bf16 causes error
        alpha = lora_dim if alpha is None or alpha == 0 else alpha
        self.scale = alpha / self.lora_dim
        self.register_buffer('alpha', torch.tensor(alpha))  # stored so it can be treated as a constant

        # same as microsoft's
        torch.nn.init.kaiming_uniform_(self.lora_down.weight, a=math.sqrt(5))
        torch.nn.init.zeros_(self.lora_up.weight)

        self.multiplier = multiplier
        self.org_module = org_module  # remove in applying

    def apply_to(self):
        self.org_forward = self.org_module.forward
        self.org_module.forward = self.forward
        del self.org_module

    def forward(self, x):
        return self.org_forward(x) + self.lora_up(self.lora_down(x)) * self.multiplier * self.scale


def create_network(multiplier, network_dim, network_alpha, vae, text_encoder, unet, **kwargs):
    if network_dim is None:
        network_dim = 4  # default
    network = LoRANetwork(text_encoder, unet, multiplier=multiplier, lora_dim=network_dim, alpha=network_alpha)
    return network


def create_network_from_weights(multiplier, file, vae, text_encoder, unet, **kwargs):
    if os.path.splitext(file)[1] == '.safetensors':
        from safetensors.torch import load_file, safe_open
        weights_sd = load_file(file)
    else:
        weights_sd = torch.load(file, map_location='cpu')

    # get dim (rank)
    network_alpha = None
    network_dim = None
    for key, value in weights_sd.items():
        if network_alpha is None and 'alpha' in key:
            network_alpha = value
        if network_dim is None and 'lora_down' in key and len(value.size()) == 2:
            network_dim = value.size()[0]

    if network_alpha is None:
        network_alpha = network_dim

    network = LoRANetwork(text_encoder, unet, multiplier=multiplier, lora_dim=network_dim, alpha=network_alpha)
    network.weights_sd = weights_sd
    return network


class LoRANetwork(torch.nn.Module):
    UNET_TARGET_REPLACE_MODULE = ["Transformer2DModel", "Attention"]
    TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPMLP"]
    LORA_PREFIX_UNET = 'lora_unet'
    LORA_PREFIX_TEXT_ENCODER = 'lora_te'

    def __init__(self, text_encoder, unet, multiplier=1.0, lora_dim=4, alpha=1) -> None:
        super().__init__()
        self.multiplier = multiplier
        self.lora_dim = lora_dim
        self.alpha = alpha

        # create module instances
        def create_modules(prefix, root_module: torch.nn.Module, target_replace_modules) -> list[LoRAModule]:
            loras = []
            for name, module in root_module.named_modules():
                if module.__class__.__name__ in target_replace_modules:
                    for child_name, child_module in module.named_modules():
                        if child_module.__class__.__name__ == "Linear" or (child_module.__class__.__name__ == "Conv2d" and child_module.kernel_size == (1, 1)):
                            lora_name = prefix + '.' + name + '.' + child_name
                            lora_name = lora_name.replace('.', '_')
                            lora = LoRAModule(lora_name, child_module, self.multiplier, self.lora_dim, self.alpha)
                            loras.append(lora)
            return loras

        self.text_encoder_loras = create_modules(LoRANetwork.LORA_PREFIX_TEXT_ENCODER,
                                                 text_encoder, LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE)
        print(f"create LoRA for Text Encoder: {len(self.text_encoder_loras)} modules.")

        self.unet_loras = create_modules(LoRANetwork.LORA_PREFIX_UNET, unet, LoRANetwork.UNET_TARGET_REPLACE_MODULE)
        print(f"create LoRA for U-Net: {len(self.unet_loras)} modules.")

        self.weights_sd = None

        # assertion
        names = set()
        for lora in self.text_encoder_loras + self.unet_loras:
            assert lora.lora_name not in names, f"duplicated lora name: {lora.lora_name}"
            names.add(lora.lora_name)

    def load_weights(self, file):
        if os.path.splitext(file)[1] == '.safetensors':
            from safetensors.torch import load_file, safe_open
            self.weights_sd = load_file(file)
        else:
            self.weights_sd = torch.load(file, map_location='cpu')

    def apply_to(self, text_encoder, unet, apply_text_encoder=None, apply_unet=None):
        if self.weights_sd:
            weights_has_text_encoder = weights_has_unet = False
            for key in self.weights_sd.keys():
                if key.startswith(LoRANetwork.LORA_PREFIX_TEXT_ENCODER):
                    weights_has_text_encoder = True
                elif key.startswith(LoRANetwork.LORA_PREFIX_UNET):
                    weights_has_unet = True

            if apply_text_encoder is None:
                apply_text_encoder = weights_has_text_encoder
            else:
                assert apply_text_encoder == weights_has_text_encoder, f"text encoder weights: {weights_has_text_encoder} but text encoder flag: {apply_text_encoder} / the loaded weights and the text encoder flag are inconsistent"

            if apply_unet is None:
                apply_unet = weights_has_unet
            else:
                assert apply_unet == weights_has_unet, f"u-net weights: {weights_has_unet} but u-net flag: {apply_unet} / the loaded weights and the U-Net flag are inconsistent"
        else:
            assert apply_text_encoder is not None and apply_unet is not None, f"internal error: flag not set"

        if apply_text_encoder:
            print("enable LoRA for text encoder")
        else:
            self.text_encoder_loras = []

        if apply_unet:
            print("enable LoRA for U-Net")
        else:
            self.unet_loras = []

        for lora in self.text_encoder_loras + self.unet_loras:
            lora.apply_to()
            self.add_module(lora.lora_name, lora)

        if self.weights_sd:
            # if some weights are not in state dict, it is ok because initial LoRA does nothing (lora_up is initialized by zeros)
            info = self.load_state_dict(self.weights_sd, False)
            print(f"weights are loaded: {info}")

    def enable_gradient_checkpointing(self):
        # not supported
        pass

    def prepare_optimizer_params(self, text_encoder_lr, unet_lr):
        def enumerate_params(loras):
            params = []
            for lora in loras:
                params.extend(lora.parameters())
            return params

        self.requires_grad_(True)
        all_params = []

        if self.text_encoder_loras:
            param_data = {'params': enumerate_params(self.text_encoder_loras)}
            if text_encoder_lr is not None:
                param_data['lr'] = text_encoder_lr
            all_params.append(param_data)

        if self.unet_loras:
            param_data = {'params': enumerate_params(self.unet_loras)}
            if unet_lr is not None:
                param_data['lr'] = unet_lr
            all_params.append(param_data)

        return all_params

    def prepare_grad_etc(self, text_encoder, unet):
        self.requires_grad_(True)

    def on_epoch_start(self, text_encoder, unet):
        self.train()

    def get_trainable_params(self):
        return self.parameters()

    def save_weights(self, file, dtype, metadata):
        if metadata is not None and len(metadata) == 0:
            metadata = None

        state_dict = self.state_dict()

        if dtype is not None:
            for key in list(state_dict.keys()):
                v = state_dict[key]
                v = v.detach().clone().to("cpu").to(dtype)
                state_dict[key] = v

        if os.path.splitext(file)[1] == '.safetensors':
            from safetensors.torch import save_file

            # Precalculate model hashes to save time on indexing
            if metadata is None:
                metadata = {}
            model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(state_dict, metadata)
            metadata["sshs_model_hash"] = model_hash
            metadata["sshs_legacy_hash"] = legacy_hash

            save_file(state_dict, file, metadata)
        else:
            torch.save(state_dict, file)
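
A minimal usage sketch, assuming the text encoder and U-Net are loaded from the v1 reference model named in model_util.py below; which modules get wrapped depends on the installed diffusers/transformers class names matching UNET_TARGET_REPLACE_MODULE and TEXT_ENCODER_TARGET_REPLACE_MODULE.

import torch
from transformers import CLIPTextModel
from diffusers import UNet2DConditionModel
from lora_utils import create_network  # assumed import path for the file above

model_id = 'runwayml/stable-diffusion-v1-5'
text_encoder = CLIPTextModel.from_pretrained(model_id, subfolder='text_encoder')
unet = UNet2DConditionModel.from_pretrained(model_id, subfolder='unet')

network = create_network(1.0, 4, 4, vae=None, text_encoder=text_encoder, unet=unet)
network.apply_to(text_encoder, unet, apply_text_encoder=True, apply_unet=True)

params = network.prepare_optimizer_params(text_encoder_lr=5e-5, unet_lr=1e-4)
optimizer = torch.optim.AdamW(params)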
StableTuner_RunPod_Fix/model_util.py
ADDED
@@ -0,0 +1,1543 @@
1 |
+
# v1: split from train_db_fixed.py.
|
2 |
+
# v2: support safetensors
|
3 |
+
|
4 |
+
import math
|
5 |
+
import os
|
6 |
+
import torch
|
7 |
+
from transformers import CLIPTextModel, CLIPTokenizer, CLIPTextConfig
|
8 |
+
from diffusers import (
|
9 |
+
AutoencoderKL,
|
10 |
+
DDIMScheduler,
|
11 |
+
StableDiffusionPipeline,
|
12 |
+
UNet2DConditionModel,
|
13 |
+
)
|
14 |
+
from safetensors.torch import load_file, save_file
|
15 |
+
|
16 |
+
# DiffUsers版StableDiffusionのモデルパラメータ
|
17 |
+
NUM_TRAIN_TIMESTEPS = 1000
|
18 |
+
BETA_START = 0.00085
|
19 |
+
BETA_END = 0.0120
|
20 |
+
|
21 |
+
UNET_PARAMS_MODEL_CHANNELS = 320
|
22 |
+
UNET_PARAMS_CHANNEL_MULT = [1, 2, 4, 4]
|
23 |
+
UNET_PARAMS_ATTENTION_RESOLUTIONS = [4, 2, 1]
|
24 |
+
UNET_PARAMS_IMAGE_SIZE = 32 # unused
|
25 |
+
UNET_PARAMS_IN_CHANNELS = 4
|
26 |
+
UNET_PARAMS_OUT_CHANNELS = 4
|
27 |
+
UNET_PARAMS_NUM_RES_BLOCKS = 2
|
28 |
+
UNET_PARAMS_CONTEXT_DIM = 768
|
29 |
+
UNET_PARAMS_NUM_HEADS = 8
|
30 |
+
|
31 |
+
VAE_PARAMS_Z_CHANNELS = 4
|
32 |
+
VAE_PARAMS_RESOLUTION = 256
|
33 |
+
VAE_PARAMS_IN_CHANNELS = 3
|
34 |
+
VAE_PARAMS_OUT_CH = 3
|
35 |
+
VAE_PARAMS_CH = 128
|
36 |
+
VAE_PARAMS_CH_MULT = [1, 2, 4, 4]
|
37 |
+
VAE_PARAMS_NUM_RES_BLOCKS = 2
|
38 |
+
|
39 |
+
# V2
|
40 |
+
V2_UNET_PARAMS_ATTENTION_HEAD_DIM = [5, 10, 20, 20]
|
41 |
+
V2_UNET_PARAMS_CONTEXT_DIM = 1024
|
42 |
+
|
43 |
+
# Diffusersの設定を読み込むための参照モデル
|
44 |
+
DIFFUSERS_REF_MODEL_ID_V1 = 'runwayml/stable-diffusion-v1-5'
|
45 |
+
DIFFUSERS_REF_MODEL_ID_V2 = 'stabilityai/stable-diffusion-2-1'
|
46 |
+
|
47 |
+
|
48 |
+
# region StableDiffusion->Diffusersの変換コード
|
49 |
+
# convert_original_stable_diffusion_to_diffusers をコピーして修正している(ASL 2.0)
|
50 |
+
|
51 |
+
|
52 |
+
def shave_segments(path, n_shave_prefix_segments=1):
|
53 |
+
"""
|
54 |
+
Removes segments. Positive values shave the first segments, negative shave the last segments.
|
55 |
+
"""
|
56 |
+
if n_shave_prefix_segments >= 0:
|
57 |
+
return '.'.join(path.split('.')[n_shave_prefix_segments:])
|
58 |
+
else:
|
59 |
+
return '.'.join(path.split('.')[:n_shave_prefix_segments])
|
60 |
+
|
61 |
+
|
62 |
+
def renew_resnet_paths(old_list, n_shave_prefix_segments=0):
|
63 |
+
"""
|
64 |
+
Updates paths inside resnets to the new naming scheme (local renaming)
|
65 |
+
"""
|
66 |
+
mapping = []
|
67 |
+
for old_item in old_list:
|
68 |
+
new_item = old_item.replace('in_layers.0', 'norm1')
|
69 |
+
new_item = new_item.replace('in_layers.2', 'conv1')
|
70 |
+
|
71 |
+
new_item = new_item.replace('out_layers.0', 'norm2')
|
72 |
+
new_item = new_item.replace('out_layers.3', 'conv2')
|
73 |
+
|
74 |
+
new_item = new_item.replace('emb_layers.1', 'time_emb_proj')
|
75 |
+
new_item = new_item.replace('skip_connection', 'conv_shortcut')
|
76 |
+
|
77 |
+
new_item = shave_segments(
|
78 |
+
new_item, n_shave_prefix_segments=n_shave_prefix_segments
|
79 |
+
)
|
80 |
+
|
81 |
+
mapping.append({'old': old_item, 'new': new_item})
|
82 |
+
|
83 |
+
return mapping
|
84 |
+
|
85 |
+
|
86 |
+
def renew_vae_resnet_paths(old_list, n_shave_prefix_segments=0):
|
87 |
+
"""
|
88 |
+
Updates paths inside resnets to the new naming scheme (local renaming)
|
89 |
+
"""
|
90 |
+
mapping = []
|
91 |
+
for old_item in old_list:
|
92 |
+
new_item = old_item
|
93 |
+
|
94 |
+
new_item = new_item.replace('nin_shortcut', 'conv_shortcut')
|
95 |
+
new_item = shave_segments(
|
96 |
+
new_item, n_shave_prefix_segments=n_shave_prefix_segments
|
97 |
+
)
|
98 |
+
|
99 |
+
mapping.append({'old': old_item, 'new': new_item})
|
100 |
+
|
101 |
+
return mapping
|
102 |
+
|
103 |
+
|
104 |
+
def renew_attention_paths(old_list, n_shave_prefix_segments=0):
|
105 |
+
"""
|
106 |
+
Updates paths inside attentions to the new naming scheme (local renaming)
|
107 |
+
"""
|
108 |
+
mapping = []
|
109 |
+
for old_item in old_list:
|
110 |
+
new_item = old_item
|
111 |
+
|
112 |
+
# new_item = new_item.replace('norm.weight', 'group_norm.weight')
|
113 |
+
# new_item = new_item.replace('norm.bias', 'group_norm.bias')
|
114 |
+
|
115 |
+
# new_item = new_item.replace('proj_out.weight', 'proj_attn.weight')
|
116 |
+
# new_item = new_item.replace('proj_out.bias', 'proj_attn.bias')
|
117 |
+
|
118 |
+
# new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
|
119 |
+
|
120 |
+
mapping.append({'old': old_item, 'new': new_item})
|
121 |
+
|
122 |
+
return mapping
|
123 |
+
|
124 |
+
|
125 |
+
def renew_vae_attention_paths(old_list, n_shave_prefix_segments=0):
|
126 |
+
"""
|
127 |
+
Updates paths inside attentions to the new naming scheme (local renaming)
|
128 |
+
"""
|
129 |
+
mapping = []
|
130 |
+
for old_item in old_list:
|
131 |
+
new_item = old_item
|
132 |
+
|
133 |
+
new_item = new_item.replace('norm.weight', 'group_norm.weight')
|
134 |
+
new_item = new_item.replace('norm.bias', 'group_norm.bias')
|
135 |
+
|
136 |
+
new_item = new_item.replace('q.weight', 'query.weight')
|
137 |
+
new_item = new_item.replace('q.bias', 'query.bias')
|
138 |
+
|
139 |
+
new_item = new_item.replace('k.weight', 'key.weight')
|
140 |
+
new_item = new_item.replace('k.bias', 'key.bias')
|
141 |
+
|
142 |
+
new_item = new_item.replace('v.weight', 'value.weight')
|
143 |
+
new_item = new_item.replace('v.bias', 'value.bias')
|
144 |
+
|
145 |
+
new_item = new_item.replace('proj_out.weight', 'proj_attn.weight')
|
146 |
+
new_item = new_item.replace('proj_out.bias', 'proj_attn.bias')
|
147 |
+
|
148 |
+
new_item = shave_segments(
|
149 |
+
new_item, n_shave_prefix_segments=n_shave_prefix_segments
|
150 |
+
)
|
151 |
+
|
152 |
+
mapping.append({'old': old_item, 'new': new_item})
|
153 |
+
|
154 |
+
return mapping
|
155 |
+
|
156 |
+
|
157 |
+
def assign_to_checkpoint(
|
158 |
+
paths,
|
159 |
+
checkpoint,
|
160 |
+
old_checkpoint,
|
161 |
+
attention_paths_to_split=None,
|
162 |
+
additional_replacements=None,
|
163 |
+
config=None,
|
164 |
+
):
|
165 |
+
"""
|
166 |
+
This does the final conversion step: take locally converted weights and apply a global renaming
|
167 |
+
to them. It splits attention layers, and takes into account additional replacements
|
168 |
+
that may arise.
|
169 |
+
Assigns the weights to the new checkpoint.
|
170 |
+
"""
|
171 |
+
assert isinstance(
|
172 |
+
paths, list
|
173 |
+
), "Paths should be a list of dicts containing 'old' and 'new' keys."
|
174 |
+
|
175 |
+
# Splits the attention layers into three variables.
|
176 |
+
if attention_paths_to_split is not None:
|
177 |
+
for path, path_map in attention_paths_to_split.items():
|
178 |
+
old_tensor = old_checkpoint[path]
|
179 |
+
channels = old_tensor.shape[0] // 3
|
180 |
+
|
181 |
+
target_shape = (
|
182 |
+
(-1, channels) if len(old_tensor.shape) == 3 else (-1)
|
183 |
+
)
|
184 |
+
|
185 |
+
num_heads = old_tensor.shape[0] // config['num_head_channels'] // 3
|
186 |
+
|
187 |
+
old_tensor = old_tensor.reshape(
|
188 |
+
(num_heads, 3 * channels // num_heads) + old_tensor.shape[1:]
|
189 |
+
)
|
190 |
+
query, key, value = old_tensor.split(channels // num_heads, dim=1)
|
191 |
+
|
192 |
+
checkpoint[path_map['query']] = query.reshape(target_shape)
|
193 |
+
checkpoint[path_map['key']] = key.reshape(target_shape)
|
194 |
+
checkpoint[path_map['value']] = value.reshape(target_shape)
|
195 |
+
|
196 |
+
for path in paths:
|
197 |
+
new_path = path['new']
|
198 |
+
|
199 |
+
# These have already been assigned
|
200 |
+
if (
|
201 |
+
attention_paths_to_split is not None
|
202 |
+
and new_path in attention_paths_to_split
|
203 |
+
):
|
204 |
+
continue
|
205 |
+
|
206 |
+
# Global renaming happens here
|
207 |
+
new_path = new_path.replace('middle_block.0', 'mid_block.resnets.0')
|
208 |
+
new_path = new_path.replace('middle_block.1', 'mid_block.attentions.0')
|
209 |
+
new_path = new_path.replace('middle_block.2', 'mid_block.resnets.1')
|
210 |
+
|
211 |
+
if additional_replacements is not None:
|
212 |
+
for replacement in additional_replacements:
|
213 |
+
new_path = new_path.replace(
|
214 |
+
replacement['old'], replacement['new']
|
215 |
+
)
|
216 |
+
|
217 |
+
# proj_attn.weight has to be converted from conv 1D to linear
|
218 |
+
if 'proj_attn.weight' in new_path:
|
219 |
+
checkpoint[new_path] = old_checkpoint[path['old']][:, :, 0]
|
220 |
+
else:
|
221 |
+
checkpoint[new_path] = old_checkpoint[path['old']]
|
222 |
+
|
223 |
+
|
224 |
+
def conv_attn_to_linear(checkpoint):
|
225 |
+
keys = list(checkpoint.keys())
|
226 |
+
attn_keys = ['query.weight', 'key.weight', 'value.weight']
|
227 |
+
for key in keys:
|
228 |
+
if '.'.join(key.split('.')[-2:]) in attn_keys:
|
229 |
+
if checkpoint[key].ndim > 2:
|
230 |
+
checkpoint[key] = checkpoint[key][:, :, 0, 0]
|
231 |
+
elif 'proj_attn.weight' in key:
|
232 |
+
if checkpoint[key].ndim > 2:
|
233 |
+
checkpoint[key] = checkpoint[key][:, :, 0]
|
234 |
+
|
235 |
+
|
236 |
+
def linear_transformer_to_conv(checkpoint):
|
237 |
+
keys = list(checkpoint.keys())
|
238 |
+
tf_keys = ['proj_in.weight', 'proj_out.weight']
|
239 |
+
for key in keys:
|
240 |
+
if '.'.join(key.split('.')[-2:]) in tf_keys:
|
241 |
+
if checkpoint[key].ndim == 2:
|
242 |
+
checkpoint[key] = checkpoint[key].unsqueeze(2).unsqueeze(2)
|
243 |
+
|
244 |
+
|
245 |
+
def convert_ldm_unet_checkpoint(v2, checkpoint, config):
|
246 |
+
"""
|
247 |
+
Takes a state dict and a config, and returns a converted checkpoint.
|
248 |
+
"""
|
249 |
+
|
250 |
+
# extract state_dict for UNet
|
251 |
+
unet_state_dict = {}
|
252 |
+
unet_key = 'model.diffusion_model.'
|
253 |
+
keys = list(checkpoint.keys())
|
254 |
+
for key in keys:
|
255 |
+
if key.startswith(unet_key):
|
256 |
+
unet_state_dict[key.replace(unet_key, '')] = checkpoint.pop(key)
|
257 |
+
|
258 |
+
new_checkpoint = {}
|
259 |
+
|
260 |
+
new_checkpoint['time_embedding.linear_1.weight'] = unet_state_dict[
|
261 |
+
'time_embed.0.weight'
|
262 |
+
]
|
263 |
+
new_checkpoint['time_embedding.linear_1.bias'] = unet_state_dict[
|
264 |
+
'time_embed.0.bias'
|
265 |
+
]
|
266 |
+
new_checkpoint['time_embedding.linear_2.weight'] = unet_state_dict[
|
267 |
+
'time_embed.2.weight'
|
268 |
+
]
|
269 |
+
new_checkpoint['time_embedding.linear_2.bias'] = unet_state_dict[
|
270 |
+
'time_embed.2.bias'
|
271 |
+
]
|
272 |
+
|
273 |
+
new_checkpoint['conv_in.weight'] = unet_state_dict[
|
274 |
+
'input_blocks.0.0.weight'
|
275 |
+
]
|
276 |
+
new_checkpoint['conv_in.bias'] = unet_state_dict['input_blocks.0.0.bias']
|
277 |
+
|
278 |
+
new_checkpoint['conv_norm_out.weight'] = unet_state_dict['out.0.weight']
|
279 |
+
new_checkpoint['conv_norm_out.bias'] = unet_state_dict['out.0.bias']
|
280 |
+
new_checkpoint['conv_out.weight'] = unet_state_dict['out.2.weight']
|
281 |
+
new_checkpoint['conv_out.bias'] = unet_state_dict['out.2.bias']
|
282 |
+
|
283 |
+
# Retrieves the keys for the input blocks only
|
284 |
+
num_input_blocks = len(
|
285 |
+
{
|
286 |
+
'.'.join(layer.split('.')[:2])
|
287 |
+
for layer in unet_state_dict
|
288 |
+
if 'input_blocks' in layer
|
289 |
+
}
|
290 |
+
)
|
291 |
+
input_blocks = {
|
292 |
+
layer_id: [
|
293 |
+
key
|
294 |
+
for key in unet_state_dict
|
295 |
+
if f'input_blocks.{layer_id}.' in key
|
296 |
+
]
|
297 |
+
for layer_id in range(num_input_blocks)
|
298 |
+
}
|
299 |
+
|
300 |
+
# Retrieves the keys for the middle blocks only
|
301 |
+
num_middle_blocks = len(
|
302 |
+
{
|
303 |
+
'.'.join(layer.split('.')[:2])
|
304 |
+
for layer in unet_state_dict
|
305 |
+
if 'middle_block' in layer
|
306 |
+
}
|
307 |
+
)
|
308 |
+
middle_blocks = {
|
309 |
+
layer_id: [
|
310 |
+
key
|
311 |
+
for key in unet_state_dict
|
312 |
+
if f'middle_block.{layer_id}.' in key
|
313 |
+
]
|
314 |
+
for layer_id in range(num_middle_blocks)
|
315 |
+
}
|
316 |
+
|
317 |
+
# Retrieves the keys for the output blocks only
|
318 |
+
num_output_blocks = len(
|
319 |
+
{
|
320 |
+
'.'.join(layer.split('.')[:2])
|
321 |
+
for layer in unet_state_dict
|
322 |
+
if 'output_blocks' in layer
|
323 |
+
}
|
324 |
+
)
|
325 |
+
output_blocks = {
|
326 |
+
layer_id: [
|
327 |
+
key
|
328 |
+
for key in unet_state_dict
|
329 |
+
if f'output_blocks.{layer_id}.' in key
|
330 |
+
]
|
331 |
+
for layer_id in range(num_output_blocks)
|
332 |
+
}
|
333 |
+
|
334 |
+
for i in range(1, num_input_blocks):
|
335 |
+
block_id = (i - 1) // (config['layers_per_block'] + 1)
|
336 |
+
layer_in_block_id = (i - 1) % (config['layers_per_block'] + 1)
|
337 |
+
|
338 |
+
resnets = [
|
339 |
+
key
|
340 |
+
for key in input_blocks[i]
|
341 |
+
if f'input_blocks.{i}.0' in key
|
342 |
+
and f'input_blocks.{i}.0.op' not in key
|
343 |
+
]
|
344 |
+
attentions = [
|
345 |
+
key for key in input_blocks[i] if f'input_blocks.{i}.1' in key
|
346 |
+
]
|
347 |
+
|
348 |
+
if f'input_blocks.{i}.0.op.weight' in unet_state_dict:
|
349 |
+
new_checkpoint[
|
350 |
+
f'down_blocks.{block_id}.downsamplers.0.conv.weight'
|
351 |
+
] = unet_state_dict.pop(f'input_blocks.{i}.0.op.weight')
|
352 |
+
new_checkpoint[
|
353 |
+
f'down_blocks.{block_id}.downsamplers.0.conv.bias'
|
354 |
+
] = unet_state_dict.pop(f'input_blocks.{i}.0.op.bias')
|
355 |
+
|
356 |
+
paths = renew_resnet_paths(resnets)
|
357 |
+
meta_path = {
|
358 |
+
'old': f'input_blocks.{i}.0',
|
359 |
+
'new': f'down_blocks.{block_id}.resnets.{layer_in_block_id}',
|
360 |
+
}
|
361 |
+
assign_to_checkpoint(
|
362 |
+
paths,
|
363 |
+
new_checkpoint,
|
364 |
+
unet_state_dict,
|
365 |
+
additional_replacements=[meta_path],
|
366 |
+
config=config,
|
367 |
+
)
|
368 |
+
|
369 |
+
if len(attentions):
|
370 |
+
paths = renew_attention_paths(attentions)
|
371 |
+
meta_path = {
|
372 |
+
'old': f'input_blocks.{i}.1',
|
373 |
+
'new': f'down_blocks.{block_id}.attentions.{layer_in_block_id}',
|
374 |
+
}
|
375 |
+
assign_to_checkpoint(
|
376 |
+
paths,
|
377 |
+
new_checkpoint,
|
378 |
+
unet_state_dict,
|
379 |
+
additional_replacements=[meta_path],
|
380 |
+
config=config,
|
381 |
+
)
|
382 |
+
|
383 |
+
resnet_0 = middle_blocks[0]
|
384 |
+
attentions = middle_blocks[1]
|
385 |
+
resnet_1 = middle_blocks[2]
|
386 |
+
|
387 |
+
resnet_0_paths = renew_resnet_paths(resnet_0)
|
388 |
+
assign_to_checkpoint(
|
389 |
+
resnet_0_paths, new_checkpoint, unet_state_dict, config=config
|
390 |
+
)
|
391 |
+
|
392 |
+
resnet_1_paths = renew_resnet_paths(resnet_1)
|
393 |
+
assign_to_checkpoint(
|
394 |
+
resnet_1_paths, new_checkpoint, unet_state_dict, config=config
|
395 |
+
)
|
396 |
+
|
397 |
+
attentions_paths = renew_attention_paths(attentions)
|
398 |
+
meta_path = {'old': 'middle_block.1', 'new': 'mid_block.attentions.0'}
|
399 |
+
assign_to_checkpoint(
|
400 |
+
attentions_paths,
|
401 |
+
new_checkpoint,
|
402 |
+
unet_state_dict,
|
403 |
+
additional_replacements=[meta_path],
|
404 |
+
config=config,
|
405 |
+
)
|
406 |
+
|
407 |
+
for i in range(num_output_blocks):
|
408 |
+
block_id = i // (config['layers_per_block'] + 1)
|
409 |
+
layer_in_block_id = i % (config['layers_per_block'] + 1)
|
410 |
+
output_block_layers = [
|
411 |
+
shave_segments(name, 2) for name in output_blocks[i]
|
412 |
+
]
|
413 |
+
output_block_list = {}
|
414 |
+
|
415 |
+
for layer in output_block_layers:
|
416 |
+
layer_id, layer_name = layer.split('.')[0], shave_segments(
|
417 |
+
layer, 1
|
418 |
+
)
|
419 |
+
if layer_id in output_block_list:
|
420 |
+
output_block_list[layer_id].append(layer_name)
|
421 |
+
else:
|
422 |
+
output_block_list[layer_id] = [layer_name]
|
423 |
+
|
424 |
+
if len(output_block_list) > 1:
|
425 |
+
resnets = [
|
426 |
+
key
|
427 |
+
for key in output_blocks[i]
|
428 |
+
if f'output_blocks.{i}.0' in key
|
429 |
+
]
|
430 |
+
attentions = [
|
431 |
+
key
|
432 |
+
for key in output_blocks[i]
|
433 |
+
if f'output_blocks.{i}.1' in key
|
434 |
+
]
|
435 |
+
|
436 |
+
resnet_0_paths = renew_resnet_paths(resnets)
|
437 |
+
paths = renew_resnet_paths(resnets)
|
438 |
+
|
439 |
+
meta_path = {
|
440 |
+
'old': f'output_blocks.{i}.0',
|
441 |
+
'new': f'up_blocks.{block_id}.resnets.{layer_in_block_id}',
|
442 |
+
}
|
443 |
+
assign_to_checkpoint(
|
444 |
+
paths,
|
445 |
+
new_checkpoint,
|
446 |
+
unet_state_dict,
|
447 |
+
additional_replacements=[meta_path],
|
448 |
+
config=config,
|
449 |
+
)
|
450 |
+
|
451 |
+
# オリジナル:
|
452 |
+
# if ["conv.weight", "conv.bias"] in output_block_list.values():
|
453 |
+
# index = list(output_block_list.values()).index(["conv.weight", "conv.bias"])
|
454 |
+
|
455 |
+
# biasとweightの順番に依存しないようにする:もっといいやり方がありそうだが
|
456 |
+
for l in output_block_list.values():
|
457 |
+
l.sort()
|
458 |
+
|
459 |
+
if ['conv.bias', 'conv.weight'] in output_block_list.values():
|
460 |
+
index = list(output_block_list.values()).index(
|
461 |
+
['conv.bias', 'conv.weight']
|
462 |
+
)
|
463 |
+
new_checkpoint[
|
464 |
+
f'up_blocks.{block_id}.upsamplers.0.conv.bias'
|
465 |
+
] = unet_state_dict[f'output_blocks.{i}.{index}.conv.bias']
|
466 |
+
new_checkpoint[
|
467 |
+
f'up_blocks.{block_id}.upsamplers.0.conv.weight'
|
468 |
+
] = unet_state_dict[f'output_blocks.{i}.{index}.conv.weight']
|
469 |
+
|
470 |
+
# Clear attentions as they have been attributed above.
|
471 |
+
if len(attentions) == 2:
|
472 |
+
attentions = []
|
473 |
+
|
474 |
+
if len(attentions):
|
475 |
+
paths = renew_attention_paths(attentions)
|
476 |
+
meta_path = {
|
477 |
+
'old': f'output_blocks.{i}.1',
|
478 |
+
'new': f'up_blocks.{block_id}.attentions.{layer_in_block_id}',
|
479 |
+
}
|
480 |
+
assign_to_checkpoint(
|
481 |
+
paths,
|
482 |
+
new_checkpoint,
|
483 |
+
unet_state_dict,
|
484 |
+
additional_replacements=[meta_path],
|
485 |
+
config=config,
|
486 |
+
)
|
487 |
+
else:
|
488 |
+
resnet_0_paths = renew_resnet_paths(
|
489 |
+
output_block_layers, n_shave_prefix_segments=1
|
490 |
+
)
|
491 |
+
for path in resnet_0_paths:
|
492 |
+
old_path = '.'.join(['output_blocks', str(i), path['old']])
|
493 |
+
new_path = '.'.join(
|
494 |
+
[
|
495 |
+
'up_blocks',
|
496 |
+
str(block_id),
|
497 |
+
'resnets',
|
498 |
+
str(layer_in_block_id),
|
499 |
+
path['new'],
|
500 |
+
]
|
501 |
+
)
|
502 |
+
|
503 |
+
new_checkpoint[new_path] = unet_state_dict[old_path]
|
504 |
+
|
505 |
+
# SDのv2では1*1のconv2dがlinearに変わっているので、linear->convに変換する
|
506 |
+
if v2:
|
507 |
+
linear_transformer_to_conv(new_checkpoint)
|
508 |
+
|
509 |
+
return new_checkpoint
|
510 |
+
|
511 |
+
|
512 |
+
def convert_ldm_vae_checkpoint(checkpoint, config):
|
513 |
+
# extract state dict for VAE
|
514 |
+
vae_state_dict = {}
|
515 |
+
vae_key = 'first_stage_model.'
|
516 |
+
keys = list(checkpoint.keys())
|
517 |
+
for key in keys:
|
518 |
+
if key.startswith(vae_key):
|
519 |
+
vae_state_dict[key.replace(vae_key, '')] = checkpoint.get(key)
|
520 |
+
# if len(vae_state_dict) == 0:
|
521 |
+
# # 渡されたcheckpointは.ckptから読み込んだcheckpointではなくvaeのstate_dict
|
522 |
+
# vae_state_dict = checkpoint
|
523 |
+
|
524 |
+
new_checkpoint = {}
|
525 |
+
|
526 |
+
new_checkpoint['encoder.conv_in.weight'] = vae_state_dict[
|
527 |
+
'encoder.conv_in.weight'
|
528 |
+
]
|
529 |
+
new_checkpoint['encoder.conv_in.bias'] = vae_state_dict[
|
530 |
+
'encoder.conv_in.bias'
|
531 |
+
]
|
532 |
+
new_checkpoint['encoder.conv_out.weight'] = vae_state_dict[
|
533 |
+
'encoder.conv_out.weight'
|
534 |
+
]
|
535 |
+
new_checkpoint['encoder.conv_out.bias'] = vae_state_dict[
|
536 |
+
'encoder.conv_out.bias'
|
537 |
+
]
|
538 |
+
new_checkpoint['encoder.conv_norm_out.weight'] = vae_state_dict[
|
539 |
+
'encoder.norm_out.weight'
|
540 |
+
]
|
541 |
+
new_checkpoint['encoder.conv_norm_out.bias'] = vae_state_dict[
|
542 |
+
'encoder.norm_out.bias'
|
543 |
+
]
|
544 |
+
|
545 |
+
new_checkpoint['decoder.conv_in.weight'] = vae_state_dict[
|
546 |
+
'decoder.conv_in.weight'
|
547 |
+
]
|
548 |
+
new_checkpoint['decoder.conv_in.bias'] = vae_state_dict[
|
549 |
+
'decoder.conv_in.bias'
|
550 |
+
]
|
551 |
+
new_checkpoint['decoder.conv_out.weight'] = vae_state_dict[
|
552 |
+
'decoder.conv_out.weight'
|
553 |
+
]
|
554 |
+
new_checkpoint['decoder.conv_out.bias'] = vae_state_dict[
|
555 |
+
'decoder.conv_out.bias'
|
556 |
+
]
|
557 |
+
new_checkpoint['decoder.conv_norm_out.weight'] = vae_state_dict[
|
558 |
+
'decoder.norm_out.weight'
|
559 |
+
]
|
560 |
+
new_checkpoint['decoder.conv_norm_out.bias'] = vae_state_dict[
|
561 |
+
'decoder.norm_out.bias'
|
562 |
+
]
|
563 |
+
|
564 |
+
new_checkpoint['quant_conv.weight'] = vae_state_dict['quant_conv.weight']
|
565 |
+
new_checkpoint['quant_conv.bias'] = vae_state_dict['quant_conv.bias']
|
566 |
+
new_checkpoint['post_quant_conv.weight'] = vae_state_dict[
|
567 |
+
'post_quant_conv.weight'
|
568 |
+
]
|
569 |
+
new_checkpoint['post_quant_conv.bias'] = vae_state_dict[
|
570 |
+
'post_quant_conv.bias'
|
571 |
+
]
|
572 |
+
|
573 |
+
# Retrieves the keys for the encoder down blocks only
|
574 |
+
num_down_blocks = len(
|
575 |
+
{
|
576 |
+
'.'.join(layer.split('.')[:3])
|
577 |
+
for layer in vae_state_dict
|
578 |
+
if 'encoder.down' in layer
|
579 |
+
}
|
580 |
+
)
|
581 |
+
down_blocks = {
|
582 |
+
layer_id: [key for key in vae_state_dict if f'down.{layer_id}' in key]
|
583 |
+
for layer_id in range(num_down_blocks)
|
584 |
+
}
|
585 |
+
|
586 |
+
# Retrieves the keys for the decoder up blocks only
|
587 |
+
num_up_blocks = len(
|
588 |
+
{
|
589 |
+
'.'.join(layer.split('.')[:3])
|
590 |
+
for layer in vae_state_dict
|
591 |
+
if 'decoder.up' in layer
|
592 |
+
}
|
593 |
+
)
|
594 |
+
up_blocks = {
|
595 |
+
layer_id: [key for key in vae_state_dict if f'up.{layer_id}' in key]
|
596 |
+
for layer_id in range(num_up_blocks)
|
597 |
+
}
|
598 |
+
|
599 |
+
for i in range(num_down_blocks):
|
600 |
+
resnets = [
|
601 |
+
key
|
602 |
+
for key in down_blocks[i]
|
603 |
+
if f'down.{i}' in key and f'down.{i}.downsample' not in key
|
604 |
+
]
|
605 |
+
|
606 |
+
if f'encoder.down.{i}.downsample.conv.weight' in vae_state_dict:
|
607 |
+
new_checkpoint[
|
608 |
+
f'encoder.down_blocks.{i}.downsamplers.0.conv.weight'
|
609 |
+
] = vae_state_dict.pop(f'encoder.down.{i}.downsample.conv.weight')
|
610 |
+
new_checkpoint[
|
611 |
+
f'encoder.down_blocks.{i}.downsamplers.0.conv.bias'
|
612 |
+
] = vae_state_dict.pop(f'encoder.down.{i}.downsample.conv.bias')
|
613 |
+
|
614 |
+
paths = renew_vae_resnet_paths(resnets)
|
615 |
+
meta_path = {
|
616 |
+
'old': f'down.{i}.block',
|
617 |
+
'new': f'down_blocks.{i}.resnets',
|
618 |
+
}
|
619 |
+
assign_to_checkpoint(
|
620 |
+
paths,
|
621 |
+
new_checkpoint,
|
622 |
+
vae_state_dict,
|
623 |
+
additional_replacements=[meta_path],
|
624 |
+
config=config,
|
625 |
+
)
|
626 |
+
|
627 |
+
mid_resnets = [key for key in vae_state_dict if 'encoder.mid.block' in key]
|
628 |
+
num_mid_res_blocks = 2
|
629 |
+
for i in range(1, num_mid_res_blocks + 1):
|
630 |
+
resnets = [
|
631 |
+
key for key in mid_resnets if f'encoder.mid.block_{i}' in key
|
632 |
+
]
|
633 |
+
|
634 |
+
paths = renew_vae_resnet_paths(resnets)
|
635 |
+
meta_path = {
|
636 |
+
'old': f'mid.block_{i}',
|
637 |
+
'new': f'mid_block.resnets.{i - 1}',
|
638 |
+
}
|
639 |
+
assign_to_checkpoint(
|
640 |
+
paths,
|
641 |
+
new_checkpoint,
|
642 |
+
vae_state_dict,
|
643 |
+
additional_replacements=[meta_path],
|
644 |
+
config=config,
|
645 |
+
)
|
646 |
+
|
647 |
+
mid_attentions = [
|
648 |
+
key for key in vae_state_dict if 'encoder.mid.attn' in key
|
649 |
+
]
|
650 |
+
paths = renew_vae_attention_paths(mid_attentions)
|
651 |
+
meta_path = {'old': 'mid.attn_1', 'new': 'mid_block.attentions.0'}
|
652 |
+
assign_to_checkpoint(
|
653 |
+
paths,
|
654 |
+
new_checkpoint,
|
655 |
+
vae_state_dict,
|
656 |
+
additional_replacements=[meta_path],
|
657 |
+
config=config,
|
658 |
+
)
|
659 |
+
conv_attn_to_linear(new_checkpoint)
|
660 |
+
|
661 |
+
for i in range(num_up_blocks):
|
662 |
+
block_id = num_up_blocks - 1 - i
|
663 |
+
resnets = [
|
664 |
+
key
|
665 |
+
for key in up_blocks[block_id]
|
666 |
+
if f'up.{block_id}' in key and f'up.{block_id}.upsample' not in key
|
667 |
+
]
|
668 |
+
|
669 |
+
if f'decoder.up.{block_id}.upsample.conv.weight' in vae_state_dict:
|
670 |
+
new_checkpoint[
|
671 |
+
f'decoder.up_blocks.{i}.upsamplers.0.conv.weight'
|
672 |
+
] = vae_state_dict[f'decoder.up.{block_id}.upsample.conv.weight']
|
673 |
+
new_checkpoint[
|
674 |
+
f'decoder.up_blocks.{i}.upsamplers.0.conv.bias'
|
675 |
+
] = vae_state_dict[f'decoder.up.{block_id}.upsample.conv.bias']
|
676 |
+
|
677 |
+
paths = renew_vae_resnet_paths(resnets)
|
678 |
+
meta_path = {
|
679 |
+
'old': f'up.{block_id}.block',
|
680 |
+
'new': f'up_blocks.{i}.resnets',
|
681 |
+
}
|
682 |
+
assign_to_checkpoint(
|
683 |
+
paths,
|
684 |
+
new_checkpoint,
|
685 |
+
vae_state_dict,
|
686 |
+
additional_replacements=[meta_path],
|
687 |
+
config=config,
|
688 |
+
)
|
689 |
+
|
690 |
+
mid_resnets = [key for key in vae_state_dict if 'decoder.mid.block' in key]
|
691 |
+
num_mid_res_blocks = 2
|
692 |
+
for i in range(1, num_mid_res_blocks + 1):
|
693 |
+
resnets = [
|
694 |
+
key for key in mid_resnets if f'decoder.mid.block_{i}' in key
|
695 |
+
]
|
696 |
+
|
697 |
+
paths = renew_vae_resnet_paths(resnets)
|
698 |
+
meta_path = {
|
699 |
+
'old': f'mid.block_{i}',
|
700 |
+
'new': f'mid_block.resnets.{i - 1}',
|
701 |
+
}
|
702 |
+
assign_to_checkpoint(
|
703 |
+
paths,
|
704 |
+
new_checkpoint,
|
705 |
+
vae_state_dict,
|
706 |
+
additional_replacements=[meta_path],
|
707 |
+
config=config,
|
708 |
+
)
|
709 |
+
|
710 |
+
mid_attentions = [
|
711 |
+
key for key in vae_state_dict if 'decoder.mid.attn' in key
|
712 |
+
]
|
713 |
+
paths = renew_vae_attention_paths(mid_attentions)
|
714 |
+
meta_path = {'old': 'mid.attn_1', 'new': 'mid_block.attentions.0'}
|
715 |
+
assign_to_checkpoint(
|
716 |
+
paths,
|
717 |
+
new_checkpoint,
|
718 |
+
vae_state_dict,
|
719 |
+
additional_replacements=[meta_path],
|
720 |
+
config=config,
|
721 |
+
)
|
722 |
+
conv_attn_to_linear(new_checkpoint)
|
723 |
+
return new_checkpoint
|
724 |
+
|
725 |
+
|
726 |
+
def create_unet_diffusers_config(v2):
|
727 |
+
"""
|
728 |
+
Creates a config for the diffusers based on the config of the LDM model.
|
729 |
+
"""
|
730 |
+
# unet_params = original_config.model.params.unet_config.params
|
731 |
+
|
732 |
+
block_out_channels = [
|
733 |
+
UNET_PARAMS_MODEL_CHANNELS * mult for mult in UNET_PARAMS_CHANNEL_MULT
|
734 |
+
]
|
735 |
+
|
736 |
+
down_block_types = []
|
737 |
+
resolution = 1
|
738 |
+
for i in range(len(block_out_channels)):
|
739 |
+
block_type = (
|
740 |
+
'CrossAttnDownBlock2D'
|
741 |
+
if resolution in UNET_PARAMS_ATTENTION_RESOLUTIONS
|
742 |
+
else 'DownBlock2D'
|
743 |
+
)
|
744 |
+
down_block_types.append(block_type)
|
745 |
+
if i != len(block_out_channels) - 1:
|
746 |
+
resolution *= 2
|
747 |
+
|
748 |
+
up_block_types = []
|
749 |
+
for i in range(len(block_out_channels)):
|
750 |
+
block_type = (
|
751 |
+
'CrossAttnUpBlock2D'
|
752 |
+
if resolution in UNET_PARAMS_ATTENTION_RESOLUTIONS
|
753 |
+
else 'UpBlock2D'
|
754 |
+
)
|
755 |
+
up_block_types.append(block_type)
|
756 |
+
resolution //= 2
|
757 |
+
|
758 |
+
config = dict(
|
759 |
+
sample_size=UNET_PARAMS_IMAGE_SIZE,
|
760 |
+
in_channels=UNET_PARAMS_IN_CHANNELS,
|
761 |
+
out_channels=UNET_PARAMS_OUT_CHANNELS,
|
762 |
+
down_block_types=tuple(down_block_types),
|
763 |
+
up_block_types=tuple(up_block_types),
|
764 |
+
block_out_channels=tuple(block_out_channels),
|
765 |
+
layers_per_block=UNET_PARAMS_NUM_RES_BLOCKS,
|
766 |
+
cross_attention_dim=UNET_PARAMS_CONTEXT_DIM
|
767 |
+
if not v2
|
768 |
+
else V2_UNET_PARAMS_CONTEXT_DIM,
|
769 |
+
attention_head_dim=UNET_PARAMS_NUM_HEADS
|
770 |
+
if not v2
|
771 |
+
else V2_UNET_PARAMS_ATTENTION_HEAD_DIM,
|
772 |
+
)
|
773 |
+
|
774 |
+
return config
|
775 |
+
|
776 |
+
|
777 |
+
def create_vae_diffusers_config():
|
778 |
+
"""
|
779 |
+
Creates a config for the diffusers based on the config of the LDM model.
|
780 |
+
"""
|
781 |
+
# vae_params = original_config.model.params.first_stage_config.params.ddconfig
|
782 |
+
# _ = original_config.model.params.first_stage_config.params.embed_dim
|
783 |
+
block_out_channels = [VAE_PARAMS_CH * mult for mult in VAE_PARAMS_CH_MULT]
|
784 |
+
down_block_types = ['DownEncoderBlock2D'] * len(block_out_channels)
|
785 |
+
up_block_types = ['UpDecoderBlock2D'] * len(block_out_channels)
|
786 |
+
|
787 |
+
config = dict(
|
788 |
+
sample_size=VAE_PARAMS_RESOLUTION,
|
789 |
+
in_channels=VAE_PARAMS_IN_CHANNELS,
|
790 |
+
out_channels=VAE_PARAMS_OUT_CH,
|
791 |
+
down_block_types=tuple(down_block_types),
|
792 |
+
up_block_types=tuple(up_block_types),
|
793 |
+
block_out_channels=tuple(block_out_channels),
|
794 |
+
latent_channels=VAE_PARAMS_Z_CHANNELS,
|
795 |
+
layers_per_block=VAE_PARAMS_NUM_RES_BLOCKS,
|
796 |
+
)
|
797 |
+
return config
|
798 |
+
|
799 |
+
|
800 |
+
def convert_ldm_clip_checkpoint_v1(checkpoint):
|
801 |
+
keys = list(checkpoint.keys())
|
802 |
+
text_model_dict = {}
|
803 |
+
for key in keys:
|
804 |
+
if key.startswith('cond_stage_model.transformer'):
|
805 |
+
text_model_dict[
|
806 |
+
key[len('cond_stage_model.transformer.') :]
|
807 |
+
] = checkpoint[key]
|
808 |
+
return text_model_dict
|
809 |
+
|
810 |
+
|
811 |
+
def convert_ldm_clip_checkpoint_v2(checkpoint, max_length):
|
812 |
+
# 嫌になるくらい違うぞ!
|
813 |
+
def convert_key(key):
|
814 |
+
if not key.startswith('cond_stage_model'):
|
815 |
+
return None
|
816 |
+
|
817 |
+
# common conversion
|
818 |
+
key = key.replace(
|
819 |
+
'cond_stage_model.model.transformer.', 'text_model.encoder.'
|
820 |
+
)
|
821 |
+
key = key.replace('cond_stage_model.model.', 'text_model.')
|
822 |
+
|
823 |
+
if 'resblocks' in key:
|
824 |
+
# resblocks conversion
|
825 |
+
key = key.replace('.resblocks.', '.layers.')
|
826 |
+
if '.ln_' in key:
|
827 |
+
key = key.replace('.ln_', '.layer_norm')
|
828 |
+
elif '.mlp.' in key:
|
829 |
+
key = key.replace('.c_fc.', '.fc1.')
|
830 |
+
key = key.replace('.c_proj.', '.fc2.')
|
831 |
+
elif '.attn.out_proj' in key:
|
832 |
+
key = key.replace('.attn.out_proj.', '.self_attn.out_proj.')
|
833 |
+
elif '.attn.in_proj' in key:
|
834 |
+
key = None # 特殊なので後で処理する
|
835 |
+
else:
|
836 |
+
raise ValueError(f'unexpected key in SD: {key}')
|
837 |
+
elif '.positional_embedding' in key:
|
838 |
+
key = key.replace(
|
839 |
+
'.positional_embedding',
|
840 |
+
'.embeddings.position_embedding.weight',
|
841 |
+
)
|
842 |
+
elif '.text_projection' in key:
|
843 |
+
key = None # 使われない???
|
844 |
+
elif '.logit_scale' in key:
|
845 |
+
key = None # 使われない???
|
846 |
+
elif '.token_embedding' in key:
|
847 |
+
key = key.replace(
|
848 |
+
'.token_embedding.weight', '.embeddings.token_embedding.weight'
|
849 |
+
)
|
850 |
+
elif '.ln_final' in key:
|
851 |
+
key = key.replace('.ln_final', '.final_layer_norm')
|
852 |
+
return key
|
853 |
+
|
854 |
+
keys = list(checkpoint.keys())
|
855 |
+
new_sd = {}
|
856 |
+
for key in keys:
|
857 |
+
# remove resblocks 23
|
858 |
+
if '.resblocks.23.' in key:
|
859 |
+
continue
|
860 |
+
new_key = convert_key(key)
|
861 |
+
if new_key is None:
|
862 |
+
continue
|
863 |
+
new_sd[new_key] = checkpoint[key]
|
864 |
+
|
865 |
+
# attnの変換
|
866 |
+
for key in keys:
|
867 |
+
if '.resblocks.23.' in key:
|
868 |
+
continue
|
869 |
+
if '.resblocks' in key and '.attn.in_proj_' in key:
|
870 |
+
# 三つに分割
|
871 |
+
values = torch.chunk(checkpoint[key], 3)
|
872 |
+
|
873 |
+
key_suffix = '.weight' if 'weight' in key else '.bias'
|
874 |
+
key_pfx = key.replace(
|
875 |
+
'cond_stage_model.model.transformer.resblocks.',
|
876 |
+
'text_model.encoder.layers.',
|
877 |
+
)
|
878 |
+
key_pfx = key_pfx.replace('_weight', '')
|
879 |
+
key_pfx = key_pfx.replace('_bias', '')
|
880 |
+
key_pfx = key_pfx.replace('.attn.in_proj', '.self_attn.')
|
881 |
+
new_sd[key_pfx + 'q_proj' + key_suffix] = values[0]
|
882 |
+
new_sd[key_pfx + 'k_proj' + key_suffix] = values[1]
|
883 |
+
new_sd[key_pfx + 'v_proj' + key_suffix] = values[2]
|
884 |
+
|
885 |
+
# position_idsの追加
|
886 |
+
new_sd['text_model.embeddings.position_ids'] = torch.Tensor(
|
887 |
+
[list(range(max_length))]
|
888 |
+
).to(torch.int64)
|
889 |
+
return new_sd
|
890 |
+
|
891 |
+
|
892 |
+
# endregion
|
893 |
+
|
894 |
+
|
895 |
+
# region Diffusers->StableDiffusion の変換コード
|
896 |
+
# convert_diffusers_to_original_stable_diffusion をコピーして修正している(ASL 2.0)
|
897 |
+
|
898 |
+
|
899 |
+
def conv_transformer_to_linear(checkpoint):
|
900 |
+
keys = list(checkpoint.keys())
|
901 |
+
tf_keys = ['proj_in.weight', 'proj_out.weight']
|
902 |
+
for key in keys:
|
903 |
+
if '.'.join(key.split('.')[-2:]) in tf_keys:
|
904 |
+
if checkpoint[key].ndim > 2:
|
905 |
+
checkpoint[key] = checkpoint[key][:, :, 0, 0]
|
906 |
+
|
907 |
+
|
908 |
+
def convert_unet_state_dict_to_sd(v2, unet_state_dict):
|
909 |
+
unet_conversion_map = [
|
910 |
+
# (stable-diffusion, HF Diffusers)
|
911 |
+
('time_embed.0.weight', 'time_embedding.linear_1.weight'),
|
912 |
+
('time_embed.0.bias', 'time_embedding.linear_1.bias'),
|
913 |
+
('time_embed.2.weight', 'time_embedding.linear_2.weight'),
|
914 |
+
('time_embed.2.bias', 'time_embedding.linear_2.bias'),
|
915 |
+
('input_blocks.0.0.weight', 'conv_in.weight'),
|
916 |
+
('input_blocks.0.0.bias', 'conv_in.bias'),
|
917 |
+
('out.0.weight', 'conv_norm_out.weight'),
|
918 |
+
('out.0.bias', 'conv_norm_out.bias'),
|
919 |
+
('out.2.weight', 'conv_out.weight'),
|
920 |
+
('out.2.bias', 'conv_out.bias'),
|
921 |
+
]
|
922 |
+
|
923 |
+
unet_conversion_map_resnet = [
|
924 |
+
# (stable-diffusion, HF Diffusers)
|
925 |
+
('in_layers.0', 'norm1'),
|
926 |
+
('in_layers.2', 'conv1'),
|
927 |
+
('out_layers.0', 'norm2'),
|
928 |
+
('out_layers.3', 'conv2'),
|
929 |
+
('emb_layers.1', 'time_emb_proj'),
|
930 |
+
('skip_connection', 'conv_shortcut'),
|
931 |
+
]
|
932 |
+
|
933 |
+
unet_conversion_map_layer = []
|
934 |
+
for i in range(4):
|
935 |
+
# loop over downblocks/upblocks
|
936 |
+
|
937 |
+
for j in range(2):
|
938 |
+
# loop over resnets/attentions for downblocks
|
939 |
+
hf_down_res_prefix = f'down_blocks.{i}.resnets.{j}.'
|
940 |
+
sd_down_res_prefix = f'input_blocks.{3*i + j + 1}.0.'
|
941 |
+
unet_conversion_map_layer.append(
|
942 |
+
(sd_down_res_prefix, hf_down_res_prefix)
|
943 |
+
)
|
944 |
+
|
945 |
+
if i < 3:
|
946 |
+
# no attention layers in down_blocks.3
|
947 |
+
hf_down_atn_prefix = f'down_blocks.{i}.attentions.{j}.'
|
948 |
+
sd_down_atn_prefix = f'input_blocks.{3*i + j + 1}.1.'
|
949 |
+
unet_conversion_map_layer.append(
|
950 |
+
(sd_down_atn_prefix, hf_down_atn_prefix)
|
951 |
+
)
|
952 |
+
|
953 |
+
for j in range(3):
|
954 |
+
# loop over resnets/attentions for upblocks
|
955 |
+
hf_up_res_prefix = f'up_blocks.{i}.resnets.{j}.'
|
956 |
+
sd_up_res_prefix = f'output_blocks.{3*i + j}.0.'
|
957 |
+
unet_conversion_map_layer.append(
|
958 |
+
(sd_up_res_prefix, hf_up_res_prefix)
|
959 |
+
)
|
960 |
+
|
961 |
+
if i > 0:
|
962 |
+
# no attention layers in up_blocks.0
|
963 |
+
hf_up_atn_prefix = f'up_blocks.{i}.attentions.{j}.'
|
964 |
+
sd_up_atn_prefix = f'output_blocks.{3*i + j}.1.'
|
965 |
+
unet_conversion_map_layer.append(
|
966 |
+
(sd_up_atn_prefix, hf_up_atn_prefix)
|
967 |
+
)
|
968 |
+
|
969 |
+
if i < 3:
|
970 |
+
# no downsample in down_blocks.3
|
971 |
+
hf_downsample_prefix = f'down_blocks.{i}.downsamplers.0.conv.'
|
972 |
+
sd_downsample_prefix = f'input_blocks.{3*(i+1)}.0.op.'
|
973 |
+
unet_conversion_map_layer.append(
|
974 |
+
(sd_downsample_prefix, hf_downsample_prefix)
|
975 |
+
)
|
976 |
+
|
977 |
+
# no upsample in up_blocks.3
|
978 |
+
hf_upsample_prefix = f'up_blocks.{i}.upsamplers.0.'
|
979 |
+
sd_upsample_prefix = (
|
980 |
+
f'output_blocks.{3*i + 2}.{1 if i == 0 else 2}.'
|
981 |
+
)
|
982 |
+
unet_conversion_map_layer.append(
|
983 |
+
(sd_upsample_prefix, hf_upsample_prefix)
|
984 |
+
)
|
985 |
+
|
986 |
+
hf_mid_atn_prefix = 'mid_block.attentions.0.'
|
987 |
+
sd_mid_atn_prefix = 'middle_block.1.'
|
988 |
+
unet_conversion_map_layer.append((sd_mid_atn_prefix, hf_mid_atn_prefix))
|
989 |
+
|
990 |
+
for j in range(2):
|
991 |
+
hf_mid_res_prefix = f'mid_block.resnets.{j}.'
|
992 |
+
sd_mid_res_prefix = f'middle_block.{2*j}.'
|
993 |
+
unet_conversion_map_layer.append(
|
994 |
+
(sd_mid_res_prefix, hf_mid_res_prefix)
|
995 |
+
)
|
996 |
+
|
997 |
+
# buyer beware: this is a *brittle* function,
|
998 |
+
# and correct output requires that all of these pieces interact in
|
999 |
+
# the exact order in which I have arranged them.
|
1000 |
+
mapping = {k: k for k in unet_state_dict.keys()}
|
1001 |
+
for sd_name, hf_name in unet_conversion_map:
|
1002 |
+
mapping[hf_name] = sd_name
|
1003 |
+
for k, v in mapping.items():
|
1004 |
+
if 'resnets' in k:
|
1005 |
+
for sd_part, hf_part in unet_conversion_map_resnet:
|
1006 |
+
v = v.replace(hf_part, sd_part)
|
1007 |
+
mapping[k] = v
|
1008 |
+
for k, v in mapping.items():
|
1009 |
+
for sd_part, hf_part in unet_conversion_map_layer:
|
1010 |
+
v = v.replace(hf_part, sd_part)
|
1011 |
+
mapping[k] = v
|
1012 |
+
new_state_dict = {v: unet_state_dict[k] for k, v in mapping.items()}
|
1013 |
+
|
1014 |
+
if v2:
|
1015 |
+
conv_transformer_to_linear(new_state_dict)
|
1016 |
+
|
1017 |
+
return new_state_dict
|
1018 |
+
|
1019 |
+
|
1020 |
+
# ================#
|
1021 |
+
# VAE Conversion #
|
1022 |
+
# ================#
|
1023 |
+
|
1024 |
+
|
1025 |
+
def reshape_weight_for_sd(w):
|
1026 |
+
# convert HF linear weights to SD conv2d weights
|
1027 |
+
return w.reshape(*w.shape, 1, 1)
|
1028 |
+
|
1029 |
+
|
1030 |
+
def convert_vae_state_dict(vae_state_dict):
|
1031 |
+
vae_conversion_map = [
|
1032 |
+
# (stable-diffusion, HF Diffusers)
|
1033 |
+
('nin_shortcut', 'conv_shortcut'),
|
1034 |
+
('norm_out', 'conv_norm_out'),
|
1035 |
+
('mid.attn_1.', 'mid_block.attentions.0.'),
|
1036 |
+
]
|
1037 |
+
|
1038 |
+
for i in range(4):
|
1039 |
+
# down_blocks have two resnets
|
1040 |
+
for j in range(2):
|
1041 |
+
hf_down_prefix = f'encoder.down_blocks.{i}.resnets.{j}.'
|
1042 |
+
sd_down_prefix = f'encoder.down.{i}.block.{j}.'
|
1043 |
+
vae_conversion_map.append((sd_down_prefix, hf_down_prefix))
|
1044 |
+
|
1045 |
+
if i < 3:
|
1046 |
+
hf_downsample_prefix = f'down_blocks.{i}.downsamplers.0.'
|
1047 |
+
sd_downsample_prefix = f'down.{i}.downsample.'
|
1048 |
+
vae_conversion_map.append(
|
1049 |
+
(sd_downsample_prefix, hf_downsample_prefix)
|
1050 |
+
)
|
1051 |
+
|
1052 |
+
hf_upsample_prefix = f'up_blocks.{i}.upsamplers.0.'
|
1053 |
+
sd_upsample_prefix = f'up.{3-i}.upsample.'
|
1054 |
+
vae_conversion_map.append((sd_upsample_prefix, hf_upsample_prefix))
|
1055 |
+
|
1056 |
+
# up_blocks have three resnets
|
1057 |
+
# also, up blocks in hf are numbered in reverse from sd
|
1058 |
+
for j in range(3):
|
1059 |
+
hf_up_prefix = f'decoder.up_blocks.{i}.resnets.{j}.'
|
1060 |
+
sd_up_prefix = f'decoder.up.{3-i}.block.{j}.'
|
1061 |
+
vae_conversion_map.append((sd_up_prefix, hf_up_prefix))
|
1062 |
+
|
1063 |
+
# this part accounts for mid blocks in both the encoder and the decoder
|
1064 |
+
for i in range(2):
|
1065 |
+
hf_mid_res_prefix = f'mid_block.resnets.{i}.'
|
1066 |
+
sd_mid_res_prefix = f'mid.block_{i+1}.'
|
1067 |
+
vae_conversion_map.append((sd_mid_res_prefix, hf_mid_res_prefix))
|
1068 |
+
|
1069 |
+
vae_conversion_map_attn = [
|
1070 |
+
# (stable-diffusion, HF Diffusers)
|
1071 |
+
('norm.', 'group_norm.'),
|
1072 |
+
('q.', 'query.'),
|
1073 |
+
('k.', 'key.'),
|
1074 |
+
('v.', 'value.'),
|
1075 |
+
('proj_out.', 'proj_attn.'),
|
1076 |
+
]
|
1077 |
+
|
1078 |
+
mapping = {k: k for k in vae_state_dict.keys()}
|
1079 |
+
for k, v in mapping.items():
|
1080 |
+
for sd_part, hf_part in vae_conversion_map:
|
1081 |
+
v = v.replace(hf_part, sd_part)
|
1082 |
+
mapping[k] = v
|
1083 |
+
for k, v in mapping.items():
|
1084 |
+
if 'attentions' in k:
|
1085 |
+
for sd_part, hf_part in vae_conversion_map_attn:
|
1086 |
+
v = v.replace(hf_part, sd_part)
|
1087 |
+
mapping[k] = v
|
1088 |
+
new_state_dict = {v: vae_state_dict[k] for k, v in mapping.items()}
|
1089 |
+
weights_to_convert = ['q', 'k', 'v', 'proj_out']
|
1090 |
+
for k, v in new_state_dict.items():
|
1091 |
+
for weight_name in weights_to_convert:
|
1092 |
+
if f'mid.attn_1.{weight_name}.weight' in k:
|
1093 |
+
# print(f"Reshaping {k} for SD format")
|
1094 |
+
new_state_dict[k] = reshape_weight_for_sd(v)
|
1095 |
+
|
1096 |
+
return new_state_dict
|
1097 |
+
|
1098 |
+
|
1099 |
+
# endregion
|
1100 |
+
|
1101 |
+
# region 自作のモデル読み書きなど
|
1102 |
+
|
1103 |
+
|
1104 |
+
def is_safetensors(path):
|
1105 |
+
return os.path.splitext(path)[1].lower() == '.safetensors'
|
1106 |
+
|
1107 |
+
|
1108 |
+
def load_checkpoint_with_text_encoder_conversion(ckpt_path):
|
1109 |
+
# text encoderの格納形式が違うモデルに対応する ('text_model'がない)
|
1110 |
+
TEXT_ENCODER_KEY_REPLACEMENTS = [
|
1111 |
+
(
|
1112 |
+
'cond_stage_model.transformer.embeddings.',
|
1113 |
+
'cond_stage_model.transformer.text_model.embeddings.',
|
1114 |
+
),
|
1115 |
+
(
|
1116 |
+
'cond_stage_model.transformer.encoder.',
|
1117 |
+
'cond_stage_model.transformer.text_model.encoder.',
|
1118 |
+
),
|
1119 |
+
(
|
1120 |
+
'cond_stage_model.transformer.final_layer_norm.',
|
1121 |
+
'cond_stage_model.transformer.text_model.final_layer_norm.',
|
1122 |
+
),
|
1123 |
+
]
|
1124 |
+
|
1125 |
+
if is_safetensors(ckpt_path):
|
1126 |
+
checkpoint = None
|
1127 |
+
state_dict = load_file(ckpt_path, 'cpu')
|
1128 |
+
else:
|
1129 |
+
checkpoint = torch.load(ckpt_path, map_location='cpu')
|
1130 |
+
if 'state_dict' in checkpoint:
|
1131 |
+
state_dict = checkpoint['state_dict']
|
1132 |
+
else:
|
1133 |
+
state_dict = checkpoint
|
1134 |
+
checkpoint = None
|
1135 |
+
|
1136 |
+
key_reps = []
|
1137 |
+
for rep_from, rep_to in TEXT_ENCODER_KEY_REPLACEMENTS:
|
1138 |
+
for key in state_dict.keys():
|
1139 |
+
if key.startswith(rep_from):
|
1140 |
+
new_key = rep_to + key[len(rep_from) :]
|
1141 |
+
key_reps.append((key, new_key))
|
1142 |
+
|
1143 |
+
for key, new_key in key_reps:
|
1144 |
+
state_dict[new_key] = state_dict[key]
|
1145 |
+
del state_dict[key]
|
1146 |
+
|
1147 |
+
return checkpoint, state_dict
|
1148 |
+
|
1149 |
+
|
1150 |
+
# TODO dtype指定の動作が怪しいので確認する text_encoderを指定形式で作れるか未確認
|
1151 |
+
def load_models_from_stable_diffusion_checkpoint(v2, ckpt_path, dtype=None):
|
1152 |
+
_, state_dict = load_checkpoint_with_text_encoder_conversion(ckpt_path)
|
1153 |
+
if dtype is not None:
|
1154 |
+
for k, v in state_dict.items():
|
1155 |
+
if type(v) is torch.Tensor:
|
1156 |
+
state_dict[k] = v.to(dtype)
|
1157 |
+
|
1158 |
+
# Convert the UNet2DConditionModel model.
|
1159 |
+
unet_config = create_unet_diffusers_config(v2)
|
1160 |
+
converted_unet_checkpoint = convert_ldm_unet_checkpoint(
|
1161 |
+
v2, state_dict, unet_config
|
1162 |
+
)
|
1163 |
+
|
1164 |
+
unet = UNet2DConditionModel(**unet_config)
|
1165 |
+
info = unet.load_state_dict(converted_unet_checkpoint)
|
1166 |
+
print('loading u-net:', info)
|
1167 |
+
|
1168 |
+
# Convert the VAE model.
|
1169 |
+
vae_config = create_vae_diffusers_config()
|
1170 |
+
converted_vae_checkpoint = convert_ldm_vae_checkpoint(
|
1171 |
+
state_dict, vae_config
|
1172 |
+
)
|
1173 |
+
|
1174 |
+
vae = AutoencoderKL(**vae_config)
|
1175 |
+
info = vae.load_state_dict(converted_vae_checkpoint)
|
1176 |
+
print('loadint vae:', info)
|
1177 |
+
|
1178 |
+
# convert text_model
|
1179 |
+
if v2:
|
1180 |
+
converted_text_encoder_checkpoint = convert_ldm_clip_checkpoint_v2(
|
1181 |
+
state_dict, 77
|
1182 |
+
)
|
1183 |
+
cfg = CLIPTextConfig(
|
1184 |
+
vocab_size=49408,
|
1185 |
+
hidden_size=1024,
|
1186 |
+
intermediate_size=4096,
|
1187 |
+
num_hidden_layers=23,
|
1188 |
+
num_attention_heads=16,
|
1189 |
+
max_position_embeddings=77,
|
1190 |
+
hidden_act='gelu',
|
1191 |
+
layer_norm_eps=1e-05,
|
1192 |
+
dropout=0.0,
|
1193 |
+
attention_dropout=0.0,
|
1194 |
+
initializer_range=0.02,
|
1195 |
+
initializer_factor=1.0,
|
1196 |
+
pad_token_id=1,
|
1197 |
+
bos_token_id=0,
|
1198 |
+
eos_token_id=2,
|
1199 |
+
model_type='clip_text_model',
|
1200 |
+
projection_dim=512,
|
1201 |
+
torch_dtype='float32',
|
1202 |
+
transformers_version='4.25.0.dev0',
|
1203 |
+
)
|
1204 |
+
text_model = CLIPTextModel._from_config(cfg)
|
1205 |
+
info = text_model.load_state_dict(converted_text_encoder_checkpoint)
|
1206 |
+
else:
|
1207 |
+
converted_text_encoder_checkpoint = convert_ldm_clip_checkpoint_v1(
|
1208 |
+
state_dict
|
1209 |
+
)
|
1210 |
+
text_model = CLIPTextModel.from_pretrained(
|
1211 |
+
'openai/clip-vit-large-patch14'
|
1212 |
+
)
|
1213 |
+
info = text_model.load_state_dict(converted_text_encoder_checkpoint)
|
1214 |
+
print('loading text encoder:', info)
|
1215 |
+
|
1216 |
+
return text_model, vae, unet
|
1217 |
+
|
1218 |
+
|
1219 |
+
def convert_text_encoder_state_dict_to_sd_v2(
|
1220 |
+
checkpoint, make_dummy_weights=False
|
1221 |
+
):
|
1222 |
+
def convert_key(key):
|
1223 |
+
# position_idsの除去
|
1224 |
+
if '.position_ids' in key:
|
1225 |
+
return None
|
1226 |
+
|
1227 |
+
# common
|
1228 |
+
key = key.replace('text_model.encoder.', 'transformer.')
|
1229 |
+
key = key.replace('text_model.', '')
|
1230 |
+
if 'layers' in key:
|
1231 |
+
# resblocks conversion
|
1232 |
+
key = key.replace('.layers.', '.resblocks.')
|
1233 |
+
if '.layer_norm' in key:
|
1234 |
+
key = key.replace('.layer_norm', '.ln_')
|
1235 |
+
elif '.mlp.' in key:
|
1236 |
+
key = key.replace('.fc1.', '.c_fc.')
|
1237 |
+
key = key.replace('.fc2.', '.c_proj.')
|
1238 |
+
elif '.self_attn.out_proj' in key:
|
1239 |
+
key = key.replace('.self_attn.out_proj.', '.attn.out_proj.')
|
1240 |
+
elif '.self_attn.' in key:
|
1241 |
+
key = None # 特殊なので後で処理する
|
1242 |
+
else:
|
1243 |
+
raise ValueError(f'unexpected key in DiffUsers model: {key}')
|
1244 |
+
elif '.position_embedding' in key:
|
1245 |
+
key = key.replace(
|
1246 |
+
'embeddings.position_embedding.weight', 'positional_embedding'
|
1247 |
+
)
|
1248 |
+
elif '.token_embedding' in key:
|
1249 |
+
key = key.replace(
|
1250 |
+
'embeddings.token_embedding.weight', 'token_embedding.weight'
|
1251 |
+
)
|
1252 |
+
elif 'final_layer_norm' in key:
|
1253 |
+
key = key.replace('final_layer_norm', 'ln_final')
|
1254 |
+
return key
|
1255 |
+
|
1256 |
+
keys = list(checkpoint.keys())
|
1257 |
+
new_sd = {}
|
1258 |
+
for key in keys:
|
1259 |
+
new_key = convert_key(key)
|
1260 |
+
if new_key is None:
|
1261 |
+
continue
|
1262 |
+
new_sd[new_key] = checkpoint[key]
|
1263 |
+
|
1264 |
+
# attnの変換
|
1265 |
+
for key in keys:
|
1266 |
+
if 'layers' in key and 'q_proj' in key:
|
1267 |
+
# 三つを結合
|
1268 |
+
key_q = key
|
1269 |
+
key_k = key.replace('q_proj', 'k_proj')
|
1270 |
+
key_v = key.replace('q_proj', 'v_proj')
|
1271 |
+
|
1272 |
+
value_q = checkpoint[key_q]
|
1273 |
+
value_k = checkpoint[key_k]
|
1274 |
+
value_v = checkpoint[key_v]
|
1275 |
+
value = torch.cat([value_q, value_k, value_v])
|
1276 |
+
|
1277 |
+
new_key = key.replace(
|
1278 |
+
'text_model.encoder.layers.', 'transformer.resblocks.'
|
1279 |
+
)
|
1280 |
+
new_key = new_key.replace('.self_attn.q_proj.', '.attn.in_proj_')
|
1281 |
+
new_sd[new_key] = value
|
1282 |
+
|
1283 |
+
# 最後の層などを捏造するか
|
1284 |
+
if make_dummy_weights:
|
1285 |
+
print(
|
1286 |
+
'make dummy weights for resblock.23, text_projection and logit scale.'
|
1287 |
+
)
|
1288 |
+
keys = list(new_sd.keys())
|
1289 |
+
for key in keys:
|
1290 |
+
if key.startswith('transformer.resblocks.22.'):
|
1291 |
+
new_sd[key.replace('.22.', '.23.')] = new_sd[
|
1292 |
+
key
|
1293 |
+
].clone() # copyしないとsafetensorsの保存で落ちる
|
1294 |
+
|
1295 |
+
# Diffusersに含まれない重みを作っておく
|
1296 |
+
new_sd['text_projection'] = torch.ones(
|
1297 |
+
(1024, 1024),
|
1298 |
+
dtype=new_sd[keys[0]].dtype,
|
1299 |
+
device=new_sd[keys[0]].device,
|
1300 |
+
)
|
1301 |
+
new_sd['logit_scale'] = torch.tensor(1)
|
1302 |
+
|
1303 |
+
return new_sd
|
1304 |
+
|
1305 |
+
|
1306 |
+
def save_stable_diffusion_checkpoint(
|
1307 |
+
v2,
|
1308 |
+
output_file,
|
1309 |
+
text_encoder,
|
1310 |
+
unet,
|
1311 |
+
ckpt_path,
|
1312 |
+
epochs,
|
1313 |
+
steps,
|
1314 |
+
save_dtype=None,
|
1315 |
+
vae=None,
|
1316 |
+
):
|
1317 |
+
if ckpt_path is not None:
|
1318 |
+
# epoch/stepを参照する。またVAEがメモリ上にないときなど、もう一度VAEを含めて読み込む
|
1319 |
+
checkpoint, state_dict = load_checkpoint_with_text_encoder_conversion(
|
1320 |
+
ckpt_path
|
1321 |
+
)
|
1322 |
+
if checkpoint is None: # safetensors または state_dictのckpt
|
1323 |
+
checkpoint = {}
|
1324 |
+
strict = False
|
1325 |
+
else:
|
1326 |
+
strict = True
|
1327 |
+
if 'state_dict' in state_dict:
|
1328 |
+
del state_dict['state_dict']
|
1329 |
+
else:
|
1330 |
+
# 新しく作る
|
1331 |
+
assert (
|
1332 |
+
vae is not None
|
1333 |
+
), 'VAE is required to save a checkpoint without a given checkpoint'
|
1334 |
+
checkpoint = {}
|
1335 |
+
state_dict = {}
|
1336 |
+
strict = False
|
1337 |
+
|
1338 |
+
def update_sd(prefix, sd):
|
1339 |
+
for k, v in sd.items():
|
1340 |
+
key = prefix + k
|
1341 |
+
assert (
|
1342 |
+
not strict or key in state_dict
|
1343 |
+
), f'Illegal key in save SD: {key}'
|
1344 |
+
if save_dtype is not None:
|
1345 |
+
v = v.detach().clone().to('cpu').to(save_dtype)
|
1346 |
+
state_dict[key] = v
|
1347 |
+
|
1348 |
+
# Convert the UNet model
|
1349 |
+
unet_state_dict = convert_unet_state_dict_to_sd(v2, unet.state_dict())
|
1350 |
+
update_sd('model.diffusion_model.', unet_state_dict)
|
1351 |
+
|
1352 |
+
# Convert the text encoder model
|
1353 |
+
if v2:
|
1354 |
+
make_dummy = (
|
1355 |
+
ckpt_path is None
|
1356 |
+
) # 参照元のcheckpointがない場合は最後の層を前の層から複製して作るなどダミーの重みを入れる
|
1357 |
+
text_enc_dict = convert_text_encoder_state_dict_to_sd_v2(
|
1358 |
+
text_encoder.state_dict(), make_dummy
|
1359 |
+
)
|
1360 |
+
update_sd('cond_stage_model.model.', text_enc_dict)
|
1361 |
+
else:
|
1362 |
+
text_enc_dict = text_encoder.state_dict()
|
1363 |
+
update_sd('cond_stage_model.transformer.', text_enc_dict)
|
1364 |
+
|
1365 |
+
# Convert the VAE
|
1366 |
+
if vae is not None:
|
1367 |
+
vae_dict = convert_vae_state_dict(vae.state_dict())
|
1368 |
+
update_sd('first_stage_model.', vae_dict)
|
1369 |
+
|
1370 |
+
# Put together new checkpoint
|
1371 |
+
key_count = len(state_dict.keys())
|
1372 |
+
new_ckpt = {'state_dict': state_dict}
|
1373 |
+
|
1374 |
+
if 'epoch' in checkpoint:
|
1375 |
+
epochs += checkpoint['epoch']
|
1376 |
+
if 'global_step' in checkpoint:
|
1377 |
+
steps += checkpoint['global_step']
|
1378 |
+
|
1379 |
+
new_ckpt['epoch'] = epochs
|
1380 |
+
new_ckpt['global_step'] = steps
|
1381 |
+
|
1382 |
+
if is_safetensors(output_file):
|
1383 |
+
# TODO Tensor以外のdictの値を削除したほうがいいか
|
1384 |
+
save_file(state_dict, output_file)
|
1385 |
+
else:
|
1386 |
+
torch.save(new_ckpt, output_file)
|
1387 |
+
|
1388 |
+
return key_count
|
1389 |
+
|
1390 |
+
|
1391 |
+
def save_diffusers_checkpoint(
|
1392 |
+
v2,
|
1393 |
+
output_dir,
|
1394 |
+
text_encoder,
|
1395 |
+
unet,
|
1396 |
+
pretrained_model_name_or_path,
|
1397 |
+
vae=None,
|
1398 |
+
use_safetensors=False,
|
1399 |
+
):
|
1400 |
+
if pretrained_model_name_or_path is None:
|
1401 |
+
# load default settings for v1/v2
|
1402 |
+
if v2:
|
1403 |
+
pretrained_model_name_or_path = DIFFUSERS_REF_MODEL_ID_V2
|
1404 |
+
else:
|
1405 |
+
pretrained_model_name_or_path = DIFFUSERS_REF_MODEL_ID_V1
|
1406 |
+
|
1407 |
+
scheduler = DDIMScheduler.from_pretrained(
|
1408 |
+
pretrained_model_name_or_path, subfolder='scheduler'
|
1409 |
+
)
|
1410 |
+
tokenizer = CLIPTokenizer.from_pretrained(
|
1411 |
+
pretrained_model_name_or_path, subfolder='tokenizer'
|
1412 |
+
)
|
1413 |
+
if vae is None:
|
1414 |
+
vae = AutoencoderKL.from_pretrained(
|
1415 |
+
pretrained_model_name_or_path, subfolder='vae'
|
1416 |
+
)
|
1417 |
+
|
1418 |
+
pipeline = StableDiffusionPipeline(
|
1419 |
+
unet=unet,
|
1420 |
+
text_encoder=text_encoder,
|
1421 |
+
vae=vae,
|
1422 |
+
scheduler=scheduler,
|
1423 |
+
tokenizer=tokenizer,
|
1424 |
+
safety_checker=None,
|
1425 |
+
feature_extractor=None,
|
1426 |
+
requires_safety_checker=None,
|
1427 |
+
)
|
1428 |
+
pipeline.save_pretrained(output_dir, safe_serialization=use_safetensors)
|
1429 |
+
|
1430 |
+
|
1431 |
+
VAE_PREFIX = 'first_stage_model.'
|
1432 |
+
|
1433 |
+
|
1434 |
+
def load_vae(vae_id, dtype):
|
1435 |
+
print(f'load VAE: {vae_id}')
|
1436 |
+
if os.path.isdir(vae_id) or not os.path.isfile(vae_id):
|
1437 |
+
# Diffusers local/remote
|
1438 |
+
try:
|
1439 |
+
vae = AutoencoderKL.from_pretrained(
|
1440 |
+
vae_id, subfolder=None, torch_dtype=dtype
|
1441 |
+
)
|
1442 |
+
except EnvironmentError as e:
|
1443 |
+
print(f'exception occurs in loading vae: {e}')
|
1444 |
+
print("retry with subfolder='vae'")
|
1445 |
+
vae = AutoencoderKL.from_pretrained(
|
1446 |
+
vae_id, subfolder='vae', torch_dtype=dtype
|
1447 |
+
)
|
1448 |
+
return vae
|
1449 |
+
|
1450 |
+
# local
|
1451 |
+
vae_config = create_vae_diffusers_config()
|
1452 |
+
|
1453 |
+
if vae_id.endswith('.bin'):
|
1454 |
+
# SD 1.5 VAE on Huggingface
|
1455 |
+
vae_sd = torch.load(vae_id, map_location='cpu')
|
1456 |
+
converted_vae_checkpoint = vae_sd
|
1457 |
+
else:
|
1458 |
+
# StableDiffusion
|
1459 |
+
vae_model = torch.load(vae_id, map_location='cpu')
|
1460 |
+
vae_sd = vae_model['state_dict']
|
1461 |
+
|
1462 |
+
# vae only or full model
|
1463 |
+
full_model = False
|
1464 |
+
for vae_key in vae_sd:
|
1465 |
+
if vae_key.startswith(VAE_PREFIX):
|
1466 |
+
full_model = True
|
1467 |
+
break
|
1468 |
+
if not full_model:
|
1469 |
+
sd = {}
|
1470 |
+
for key, value in vae_sd.items():
|
1471 |
+
sd[VAE_PREFIX + key] = value
|
1472 |
+
vae_sd = sd
|
1473 |
+
del sd
|
1474 |
+
|
1475 |
+
# Convert the VAE model.
|
1476 |
+
converted_vae_checkpoint = convert_ldm_vae_checkpoint(
|
1477 |
+
vae_sd, vae_config
|
1478 |
+
)
|
1479 |
+
|
1480 |
+
vae = AutoencoderKL(**vae_config)
|
1481 |
+
vae.load_state_dict(converted_vae_checkpoint)
|
1482 |
+
return vae
|
1483 |
+
|
1484 |
+
|
1485 |
+
def get_epoch_ckpt_name(use_safetensors, epoch):
|
1486 |
+
return f'epoch-{epoch:06d}' + (
|
1487 |
+
'.safetensors' if use_safetensors else '.ckpt'
|
1488 |
+
)
|
1489 |
+
|
1490 |
+
|
1491 |
+
def get_last_ckpt_name(use_safetensors):
|
1492 |
+
return f'last' + ('.safetensors' if use_safetensors else '.ckpt')
|
1493 |
+
|
1494 |
+
|
1495 |
+
# endregion
|
1496 |
+
|
1497 |
+
|
1498 |
+
def make_bucket_resolutions(
|
1499 |
+
max_reso, min_size=256, max_size=1024, divisible=64
|
1500 |
+
):
|
1501 |
+
max_width, max_height = max_reso
|
1502 |
+
max_area = (max_width // divisible) * (max_height // divisible)
|
1503 |
+
|
1504 |
+
resos = set()
|
1505 |
+
|
1506 |
+
size = int(math.sqrt(max_area)) * divisible
|
1507 |
+
resos.add((size, size))
|
1508 |
+
|
1509 |
+
size = min_size
|
1510 |
+
while size <= max_size:
|
1511 |
+
width = size
|
1512 |
+
height = min(max_size, (max_area // (width // divisible)) * divisible)
|
1513 |
+
resos.add((width, height))
|
1514 |
+
resos.add((height, width))
|
1515 |
+
|
1516 |
+
# # make additional resos
|
1517 |
+
# if width >= height and width - divisible >= min_size:
|
1518 |
+
# resos.add((width - divisible, height))
|
1519 |
+
# resos.add((height, width - divisible))
|
1520 |
+
# if height >= width and height - divisible >= min_size:
|
1521 |
+
# resos.add((width, height - divisible))
|
1522 |
+
# resos.add((height - divisible, width))
|
1523 |
+
|
1524 |
+
size += divisible
|
1525 |
+
|
1526 |
+
resos = list(resos)
|
1527 |
+
resos.sort()
|
1528 |
+
|
1529 |
+
aspect_ratios = [w / h for w, h in resos]
|
1530 |
+
return resos, aspect_ratios
|
1531 |
+
|
1532 |
+
|
1533 |
+
if __name__ == '__main__':
|
1534 |
+
resos, aspect_ratios = make_bucket_resolutions((512, 768))
|
1535 |
+
print(len(resos))
|
1536 |
+
print(resos)
|
1537 |
+
print(aspect_ratios)
|
1538 |
+
|
1539 |
+
ars = set()
|
1540 |
+
for ar in aspect_ratios:
|
1541 |
+
if ar in ars:
|
1542 |
+
print('error! duplicate ar:', ar)
|
1543 |
+
ars.add(ar)
|
StableTuner_RunPod_Fix/trainer.py
ADDED
@@ -0,0 +1,1750 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Copyright 2022 HuggingFace, ShivamShrirao
|
3 |
+
|
4 |
+
Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
you may not use this file except in compliance with the License.
|
6 |
+
You may obtain a copy of the License at
|
7 |
+
|
8 |
+
http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
|
10 |
+
Unless required by applicable law or agreed to in writing, software
|
11 |
+
distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
See the License for the specific language governing permissions and
|
14 |
+
limitations under the License.
|
15 |
+
"""
|
16 |
+
import keyboard
|
17 |
+
import gradio as gr
|
18 |
+
import argparse
|
19 |
+
import random
|
20 |
+
import hashlib
|
21 |
+
import itertools
|
22 |
+
import json
|
23 |
+
import math
|
24 |
+
import os
|
25 |
+
import copy
|
26 |
+
from contextlib import nullcontext
|
27 |
+
from pathlib import Path
|
28 |
+
import shutil
|
29 |
+
import torch
|
30 |
+
import torch.nn.functional as F
|
31 |
+
import torch.utils.checkpoint
|
32 |
+
import numpy as np
|
33 |
+
from accelerate import Accelerator
|
34 |
+
from accelerate.logging import get_logger
|
35 |
+
from accelerate.utils import set_seed
|
36 |
+
from diffusers import AutoencoderKL, DDIMScheduler, DDPMScheduler, DiffusionPipeline, UNet2DConditionModel,DiffusionPipeline, DPMSolverMultistepScheduler,EulerDiscreteScheduler
|
37 |
+
from diffusers.optimization import get_scheduler
|
38 |
+
from torchvision.transforms import functional
|
39 |
+
from tqdm.auto import tqdm
|
40 |
+
from transformers import CLIPTextModel, CLIPTokenizer
|
41 |
+
from typing import Dict, List, Generator, Tuple
|
42 |
+
from PIL import Image, ImageFile
|
43 |
+
from diffusers.utils.import_utils import is_xformers_available
|
44 |
+
from trainer_util import *
|
45 |
+
from dataloaders_util import *
|
46 |
+
from discriminator import Discriminator2D
|
47 |
+
from lion_pytorch import Lion
|
48 |
+
logger = get_logger(__name__)
|
49 |
+
def parse_args():
|
50 |
+
parser = argparse.ArgumentParser(description="Simple example of a training script.")
|
51 |
+
parser.add_argument(
|
52 |
+
"--revision",
|
53 |
+
type=str,
|
54 |
+
default=None,
|
55 |
+
required=False,
|
56 |
+
help="Revision of pretrained model identifier from huggingface.co/models.",
|
57 |
+
)
|
58 |
+
|
59 |
+
parser.add_argument(
|
60 |
+
"--attention",
|
61 |
+
type=str,
|
62 |
+
choices=["xformers", "flash_attention"],
|
63 |
+
default="xformers",
|
64 |
+
help="Type of attention to use."
|
65 |
+
)
|
66 |
+
parser.add_argument(
|
67 |
+
"--model_variant",
|
68 |
+
type=str,
|
69 |
+
default='base',
|
70 |
+
required=False,
|
71 |
+
help="Train Base/Inpaint/Depth2Img",
|
72 |
+
)
|
73 |
+
parser.add_argument(
|
74 |
+
"--aspect_mode",
|
75 |
+
type=str,
|
76 |
+
default='dynamic',
|
77 |
+
required=False,
|
78 |
+
help="Path to pretrained model or model identifier from huggingface.co/models.",
|
79 |
+
)
|
80 |
+
parser.add_argument(
|
81 |
+
"--aspect_mode_action_preference",
|
82 |
+
type=str,
|
83 |
+
default='add',
|
84 |
+
required=False,
|
85 |
+
help="Path to pretrained model or model identifier from huggingface.co/models.",
|
86 |
+
)
|
87 |
+
parser.add_argument('--use_lion',default=False,action="store_true", help='Use the new LION optimizer')
|
88 |
+
parser.add_argument('--use_ema',default=False,action="store_true", help='Use EMA for finetuning')
|
89 |
+
parser.add_argument('--clip_penultimate',default=False,action="store_true", help='Use penultimate CLIP layer for text embedding')
|
90 |
+
parser.add_argument("--conditional_dropout", type=float, default=None,required=False, help="Conditional dropout probability")
|
91 |
+
parser.add_argument('--disable_cudnn_benchmark', default=False, action="store_true")
|
92 |
+
parser.add_argument('--use_text_files_as_captions', default=False, action="store_true")
|
93 |
+
|
94 |
+
parser.add_argument(
|
95 |
+
"--sample_from_batch",
|
96 |
+
type=int,
|
97 |
+
default=0,
|
98 |
+
help=("Number of prompts to sample from the batch for inference"),
|
99 |
+
)
|
100 |
+
parser.add_argument(
|
101 |
+
"--flatten_sample_folder",
|
102 |
+
default=True,
|
103 |
+
action="store_true",
|
104 |
+
help="Will save samples in one folder instead of per-epoch",
|
105 |
+
)
|
106 |
+
parser.add_argument(
|
107 |
+
"--stop_text_encoder_training",
|
108 |
+
type=int,
|
109 |
+
default=999999999999999,
|
110 |
+
help=("The epoch at which the text_encoder is no longer trained"),
|
111 |
+
)
|
112 |
+
parser.add_argument(
|
113 |
+
"--use_bucketing",
|
114 |
+
default=False,
|
115 |
+
action="store_true",
|
116 |
+
help="Will save and generate samples before training",
|
117 |
+
)
|
118 |
+
parser.add_argument(
|
119 |
+
"--regenerate_latent_cache",
|
120 |
+
default=False,
|
121 |
+
action="store_true",
|
122 |
+
help="Will save and generate samples before training",
|
123 |
+
)
|
124 |
+
parser.add_argument(
|
125 |
+
"--sample_on_training_start",
|
126 |
+
default=False,
|
127 |
+
action="store_true",
|
128 |
+
help="Will save and generate samples before training",
|
129 |
+
)
|
130 |
+
|
131 |
+
parser.add_argument(
|
132 |
+
"--add_class_images_to_dataset",
|
133 |
+
default=False,
|
134 |
+
action="store_true",
|
135 |
+
help="will generate and add class images to the dataset without using prior reservation in training",
|
136 |
+
)
|
137 |
+
parser.add_argument(
|
138 |
+
"--auto_balance_concept_datasets",
|
139 |
+
default=False,
|
140 |
+
action="store_true",
|
141 |
+
help="will balance the number of images in each concept dataset to match the minimum number of images in any concept dataset",
|
142 |
+
)
|
143 |
+
parser.add_argument(
|
144 |
+
"--sample_aspect_ratios",
|
145 |
+
default=False,
|
146 |
+
action="store_true",
|
147 |
+
help="sample different aspect ratios for each image",
|
148 |
+
)
|
149 |
+
parser.add_argument(
|
150 |
+
"--dataset_repeats",
|
151 |
+
type=int,
|
152 |
+
default=1,
|
153 |
+
help="repeat the dataset this many times",
|
154 |
+
)
|
155 |
+
parser.add_argument(
|
156 |
+
"--save_every_n_epoch",
|
157 |
+
type=int,
|
158 |
+
default=1,
|
159 |
+
help="save on epoch finished",
|
160 |
+
)
|
161 |
+
parser.add_argument(
|
162 |
+
"--pretrained_model_name_or_path",
|
163 |
+
type=str,
|
164 |
+
default=None,
|
165 |
+
required=True,
|
166 |
+
help="Path to pretrained model or model identifier from huggingface.co/models.",
|
167 |
+
)
|
168 |
+
parser.add_argument(
|
169 |
+
"--pretrained_vae_name_or_path",
|
170 |
+
type=str,
|
171 |
+
default=None,
|
172 |
+
help="Path to pretrained vae or vae identifier from huggingface.co/models.",
|
173 |
+
)
|
174 |
+
parser.add_argument(
|
175 |
+
"--tokenizer_name",
|
176 |
+
type=str,
|
177 |
+
default=None,
|
178 |
+
help="Pretrained tokenizer name or path if not the same as model_name",
|
179 |
+
)
|
180 |
+
parser.add_argument(
|
181 |
+
"--instance_data_dir",
|
182 |
+
type=str,
|
183 |
+
default=None,
|
184 |
+
help="A folder containing the training data of instance images.",
|
185 |
+
)
|
186 |
+
parser.add_argument(
|
187 |
+
"--class_data_dir",
|
188 |
+
type=str,
|
189 |
+
default=None,
|
190 |
+
help="A folder containing the training data of class images.",
|
191 |
+
)
|
192 |
+
parser.add_argument(
|
193 |
+
"--instance_prompt",
|
194 |
+
type=str,
|
195 |
+
default=None,
|
196 |
+
help="The prompt with identifier specifying the instance",
|
197 |
+
)
|
198 |
+
parser.add_argument(
|
199 |
+
"--class_prompt",
|
200 |
+
type=str,
|
201 |
+
default=None,
|
202 |
+
help="The prompt to specify images in the same class as provided instance images.",
|
203 |
+
)
|
204 |
+
parser.add_argument(
|
205 |
+
"--save_sample_prompt",
|
206 |
+
type=str,
|
207 |
+
default=None,
|
208 |
+
help="The prompt used to generate sample outputs to save.",
|
209 |
+
)
|
210 |
+
parser.add_argument(
|
211 |
+
"--n_save_sample",
|
212 |
+
type=int,
|
213 |
+
default=4,
|
214 |
+
help="The number of samples to save.",
|
215 |
+
)
|
216 |
+
parser.add_argument(
|
217 |
+
"--sample_height",
|
218 |
+
type=int,
|
219 |
+
default=512,
|
220 |
+
help="The number of samples to save.",
|
221 |
+
)
|
222 |
+
parser.add_argument(
|
223 |
+
"--sample_width",
|
224 |
+
type=int,
|
225 |
+
default=512,
|
226 |
+
help="The number of samples to save.",
|
227 |
+
)
|
228 |
+
parser.add_argument(
|
229 |
+
"--save_guidance_scale",
|
230 |
+
type=float,
|
231 |
+
default=7.5,
|
232 |
+
help="CFG for save sample.",
|
233 |
+
)
|
234 |
+
parser.add_argument(
|
235 |
+
"--save_infer_steps",
|
236 |
+
type=int,
|
237 |
+
default=30,
|
238 |
+
help="The number of inference steps for save sample.",
|
239 |
+
)
|
240 |
+
parser.add_argument(
|
241 |
+
"--with_prior_preservation",
|
242 |
+
default=False,
|
243 |
+
action="store_true",
|
244 |
+
help="Flag to add prior preservation loss.",
|
245 |
+
)
|
246 |
+
parser.add_argument("--prior_loss_weight", type=float, default=1.0, help="The weight of prior preservation loss.")
|
247 |
+
parser.add_argument(
|
248 |
+
"--with_offset_noise",
|
249 |
+
default=False,
|
250 |
+
action="store_true",
|
251 |
+
help="Flag to offset noise applied to latents.",
|
252 |
+
)
|
253 |
+
|
254 |
+
parser.add_argument("--offset_noise_weight", type=float, default=0.1, help="The weight of offset noise applied during training.")
|
255 |
+
parser.add_argument(
|
256 |
+
"--num_class_images",
|
257 |
+
type=int,
|
258 |
+
default=100,
|
259 |
+
help=(
|
260 |
+
"Minimal class images for prior preservation loss. If not have enough images, additional images will be"
|
261 |
+
" sampled with class_prompt."
|
262 |
+
),
|
263 |
+
)
|
264 |
+
parser.add_argument(
|
265 |
+
"--output_dir",
|
266 |
+
type=str,
|
267 |
+
default="text-inversion-model",
|
268 |
+
help="The output directory where the model predictions and checkpoints will be written.",
|
269 |
+
)
|
270 |
+
parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
|
271 |
+
parser.add_argument(
|
272 |
+
"--resolution",
|
273 |
+
type=int,
|
274 |
+
default=512,
|
275 |
+
help=(
|
276 |
+
"The resolution for input images, all the images in the train/validation dataset will be resized to this"
|
277 |
+
" resolution"
|
278 |
+
),
|
279 |
+
)
|
280 |
+
parser.add_argument(
|
281 |
+
"--center_crop", default=False, action="store_true", help="Whether to center crop images before resizing to resolution"
|
282 |
+
)
|
283 |
+
parser.add_argument("--train_text_encoder", default=False, action="store_true", help="Whether to train the text encoder")
|
284 |
+
parser.add_argument(
|
285 |
+
"--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader."
|
286 |
+
)
|
287 |
+
parser.add_argument(
|
288 |
+
"--sample_batch_size", type=int, default=4, help="Batch size (per device) for sampling images."
|
289 |
+
)
|
290 |
+
parser.add_argument("--num_train_epochs", type=int, default=1)
|
291 |
+
parser.add_argument(
|
292 |
+
"--max_train_steps",
|
293 |
+
type=int,
|
294 |
+
default=None,
|
295 |
+
help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
|
296 |
+
)
|
297 |
+
parser.add_argument(
|
298 |
+
"--gradient_accumulation_steps",
|
299 |
+
type=int,
|
300 |
+
default=1,
|
301 |
+
help="Number of updates steps to accumulate before performing a backward/update pass.",
|
302 |
+
)
|
303 |
+
parser.add_argument(
|
304 |
+
"--gradient_checkpointing",
|
305 |
+
default=False,
|
306 |
+
action="store_true",
|
307 |
+
help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
|
308 |
+
)
|
309 |
+
parser.add_argument(
|
310 |
+
"--learning_rate",
|
311 |
+
type=float,
|
312 |
+
default=5e-6,
|
313 |
+
help="Initial learning rate (after the potential warmup period) to use.",
|
314 |
+
)
|
315 |
+
parser.add_argument(
|
316 |
+
"--scale_lr",
|
317 |
+
action="store_true",
|
318 |
+
default=False,
|
319 |
+
help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
|
320 |
+
)
|
321 |
+
parser.add_argument(
|
322 |
+
"--lr_scheduler",
|
323 |
+
type=str,
|
324 |
+
default="constant",
|
325 |
+
help=(
|
326 |
+
'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
|
327 |
+
' "constant", "constant_with_warmup"]'
|
328 |
+
),
|
329 |
+
)
|
330 |
+
parser.add_argument(
|
331 |
+
"--lr_warmup_steps", type=float, default=500, help="Number of steps for the warmup in the lr scheduler."
|
332 |
+
)
|
333 |
+
parser.add_argument(
|
334 |
+
"--use_8bit_adam", default=False, action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
|
335 |
+
)
|
336 |
+
parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
|
337 |
+
parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
|
338 |
+
parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
|
339 |
+
parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
|
340 |
+
parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
|
341 |
+
parser.add_argument("--push_to_hub", default=False, action="store_true", help="Whether or not to push the model to the Hub.")
|
342 |
+
parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
|
343 |
+
parser.add_argument(
|
344 |
+
"--hub_model_id",
|
345 |
+
type=str,
|
346 |
+
default=None,
|
347 |
+
help="The name of the repository to keep in sync with the local `output_dir`.",
|
348 |
+
)
|
349 |
+
parser.add_argument(
|
350 |
+
"--logging_dir",
|
351 |
+
type=str,
|
352 |
+
default="logs",
|
353 |
+
help=(
|
354 |
+
"[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
|
355 |
+
" *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
|
356 |
+
),
|
357 |
+
)
|
358 |
+
parser.add_argument("--log_interval", type=int, default=10, help="Log every N steps.")
|
359 |
+
parser.add_argument("--sample_step_interval", type=int, default=100000000000000, help="Sample images every N steps.")
|
360 |
+
parser.add_argument(
|
361 |
+
"--mixed_precision",
|
362 |
+
type=str,
|
363 |
+
default="no",
|
364 |
+
choices=["no", "fp16", "bf16","tf32"],
|
365 |
+
help=(
|
366 |
+
"Whether to use mixed precision. Choose"
|
367 |
+
"between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >= 1.10."
|
368 |
+
"and an Nvidia Ampere GPU."
|
369 |
+
),
|
370 |
+
)
|
371 |
+
parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
|
372 |
+
parser.add_argument(
|
373 |
+
"--concepts_list",
|
374 |
+
type=str,
|
375 |
+
default=None,
|
376 |
+
help="Path to json containing multiple concepts, will overwrite parameters like instance_prompt, class_prompt, etc.",
|
377 |
+
)
|
378 |
+
parser.add_argument("--save_sample_controlled_seed", type=int, action='append', help="Set a seed for an extra sample image to be constantly saved.")
|
379 |
+
parser.add_argument("--detect_full_drive", default=True, action="store_true", help="Delete checkpoints when the drive is full.")
|
380 |
+
parser.add_argument("--send_telegram_updates", default=False, action="store_true", help="Send Telegram updates.")
|
381 |
+
parser.add_argument("--telegram_chat_id", type=str, default="0", help="Telegram chat ID.")
|
382 |
+
parser.add_argument("--telegram_token", type=str, default="0", help="Telegram token.")
|
383 |
+
parser.add_argument("--use_deepspeed_adam", default=False, action="store_true", help="Use experimental DeepSpeed Adam 8.")
|
384 |
+
parser.add_argument('--append_sample_controlled_seed_action', action='append')
|
385 |
+
parser.add_argument('--add_sample_prompt', type=str, action='append')
|
386 |
+
parser.add_argument('--use_image_names_as_captions', default=False, action="store_true")
|
387 |
+
parser.add_argument('--shuffle_captions', default=False, action="store_true")
|
388 |
+
parser.add_argument("--masked_training", default=False, required=False, action='store_true', help="Whether to mask parts of the image during training")
|
389 |
+
parser.add_argument("--normalize_masked_area_loss", default=False, required=False, action='store_true', help="Normalize the loss, to make it independent of the size of the masked area")
|
390 |
+
parser.add_argument("--unmasked_probability", type=float, default=1, required=False, help="Probability of training a step without a mask")
|
391 |
+
parser.add_argument("--max_denoising_strength", type=float, default=1, required=False, help="Max denoising steps to train on")
|
392 |
+
parser.add_argument('--add_mask_prompt', type=str, default=None, action="append", dest="mask_prompts", help="Prompt for automatic mask creation")
|
393 |
+
parser.add_argument('--with_gan', default=False, action="store_true", help="Use GAN (experimental)")
|
394 |
+
parser.add_argument("--gan_weight", type=float, default=0.2, required=False, help="Strength of effect GAN has on training")
|
395 |
+
parser.add_argument("--gan_warmup", type=float, default=0, required=False, help="Slowly increases GAN weight from zero over this many steps, useful when initializing a GAN discriminator from scratch")
|
396 |
+
parser.add_argument('--discriminator_config', default="configs/discriminator_large.json", help="Location of config file to use when initializing a new GAN discriminator")
|
397 |
+
parser.add_argument('--sample_from_ema', default=True, action="store_true", help="Generate sample images using the EMA model")
|
398 |
+
parser.add_argument('--run_name', type=str, default=None, help="Adds a custom identifier to the sample and checkpoint directories")
|
399 |
+
args = parser.parse_args()
|
400 |
+
env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
|
401 |
+
if env_local_rank != -1 and env_local_rank != args.local_rank:
|
402 |
+
args.local_rank = env_local_rank
|
403 |
+
|
404 |
+
return args
|
405 |
+
|
406 |
+
def main():
|
407 |
+
print(f" {bcolors.OKBLUE}Booting Up StableTuner{bcolors.ENDC}")
|
408 |
+
print(f" {bcolors.OKBLUE}Please wait a moment as we load up some stuff...{bcolors.ENDC}")
|
409 |
+
#torch.cuda.set_per_process_memory_fraction(0.5)
|
410 |
+
args = parse_args()
|
411 |
+
#temp arg
|
412 |
+
args.batch_tokens = None
|
413 |
+
if args.disable_cudnn_benchmark:
|
414 |
+
torch.backends.cudnn.benchmark = False
|
415 |
+
else:
|
416 |
+
torch.backends.cudnn.benchmark = True
|
417 |
+
if args.send_telegram_updates:
|
418 |
+
send_telegram_message(f"Booting up StableTuner!\n", args.telegram_chat_id, args.telegram_token)
|
419 |
+
logging_dir = Path(args.output_dir, "logs", args.logging_dir)
|
420 |
+
if args.run_name:
|
421 |
+
main_sample_dir = os.path.join(args.output_dir, f"samples_{args.run_name}")
|
422 |
+
else:
|
423 |
+
main_sample_dir = os.path.join(args.output_dir, "samples")
|
424 |
+
if os.path.exists(main_sample_dir):
|
425 |
+
shutil.rmtree(main_sample_dir)
|
426 |
+
os.makedirs(main_sample_dir)
|
427 |
+
#create logging directory
|
428 |
+
if not logging_dir.exists():
|
429 |
+
logging_dir.mkdir(parents=True)
|
430 |
+
#create output directory
|
431 |
+
if not Path(args.output_dir).exists():
|
432 |
+
Path(args.output_dir).mkdir(parents=True)
|
433 |
+
|
434 |
+
|
435 |
+
accelerator = Accelerator(
|
436 |
+
gradient_accumulation_steps=args.gradient_accumulation_steps,
|
437 |
+
mixed_precision=args.mixed_precision if args.mixed_precision != 'tf32' else 'no',
|
438 |
+
log_with="tensorboard",
|
439 |
+
logging_dir=logging_dir,
|
440 |
+
)
|
441 |
+
|
442 |
+
# Currently, it's not possible to do gradient accumulation when training two models with accelerate.accumulate
|
443 |
+
# This will be enabled soon in accelerate. For now, we don't allow gradient accumulation when training two models.
|
444 |
+
# TODO (patil-suraj): Remove this check when gradient accumulation with two models is enabled in accelerate.
|
445 |
+
if args.train_text_encoder and args.gradient_accumulation_steps > 1 and accelerator.num_processes > 1:
|
446 |
+
raise ValueError(
|
447 |
+
"Gradient accumulation is not supported when training the text encoder in distributed training. "
|
448 |
+
"Please set gradient_accumulation_steps to 1. This feature will be supported in the future."
|
449 |
+
)
|
450 |
+
|
451 |
+
if args.seed is not None:
|
452 |
+
set_seed(args.seed)
|
453 |
+
|
454 |
+
if args.concepts_list is None:
|
455 |
+
args.concepts_list = [
|
456 |
+
{
|
457 |
+
"instance_prompt": args.instance_prompt,
|
458 |
+
"class_prompt": args.class_prompt,
|
459 |
+
"instance_data_dir": args.instance_data_dir,
|
460 |
+
"class_data_dir": args.class_data_dir
|
461 |
+
}
|
462 |
+
]
|
463 |
+
else:
|
464 |
+
with open(args.concepts_list, "r") as f:
|
465 |
+
args.concepts_list = json.load(f)
|
466 |
+
|
467 |
+
if args.with_prior_preservation or args.add_class_images_to_dataset:
|
468 |
+
pipeline = None
|
469 |
+
for concept in args.concepts_list:
|
470 |
+
class_images_dir = Path(concept["class_data_dir"])
|
471 |
+
class_images_dir.mkdir(parents=True, exist_ok=True)
|
472 |
+
cur_class_images = len(list(class_images_dir.iterdir()))
|
473 |
+
|
474 |
+
if cur_class_images < args.num_class_images:
|
475 |
+
torch_dtype = torch.float16 if accelerator.device.type == "cuda" else torch.float32
|
476 |
+
if pipeline is None:
|
477 |
+
|
478 |
+
pipeline = DiffusionPipeline.from_pretrained(
|
479 |
+
args.pretrained_model_name_or_path,
|
480 |
+
safety_checker=None,
|
481 |
+
vae=AutoencoderKL.from_pretrained(args.pretrained_vae_name_or_path or args.pretrained_model_name_or_path,subfolder=None if args.pretrained_vae_name_or_path else "vae" ,safe_serialization=True),
|
482 |
+
torch_dtype=torch_dtype,
|
483 |
+
requires_safety_checker=False,
|
484 |
+
)
|
485 |
+
pipeline.set_progress_bar_config(disable=True)
|
486 |
+
pipeline.to(accelerator.device)
|
487 |
+
|
488 |
+
#if args.use_bucketing == False:
|
489 |
+
num_new_images = args.num_class_images - cur_class_images
|
490 |
+
logger.info(f"Number of class images to sample: {num_new_images}.")
|
491 |
+
|
492 |
+
sample_dataset = PromptDataset(concept["class_prompt"], num_new_images)
|
493 |
+
sample_dataloader = torch.utils.data.DataLoader(sample_dataset, batch_size=args.sample_batch_size)
|
494 |
+
sample_dataloader = accelerator.prepare(sample_dataloader)
|
495 |
+
#else:
|
496 |
+
#create class images that match up to the concept target buckets
|
497 |
+
# instance_images_dir = Path(concept["instance_data_dir"])
|
498 |
+
# cur_instance_images = len(list(instance_images_dir.iterdir()))
|
499 |
+
#target_wh = min(self.aspects, key=lambda aspects:abs(aspects[0]/aspects[1] - image_aspect))
|
500 |
+
# num_new_images = cur_instance_images - cur_class_images
|
501 |
+
|
502 |
+
|
503 |
+
|
504 |
+
with torch.autocast("cuda"):
|
505 |
+
for example in tqdm(
|
506 |
+
sample_dataloader, desc="Generating class images", disable=not accelerator.is_local_main_process
|
507 |
+
):
|
508 |
+
with torch.autocast("cuda"):
|
509 |
+
images = pipeline(example["prompt"],height=args.resolution,width=args.resolution).images
|
510 |
+
for i, image in enumerate(images):
|
511 |
+
hash_image = hashlib.sha1(image.tobytes()).hexdigest()
|
512 |
+
image_filename = class_images_dir / f"{example['index'][i] + cur_class_images}-{hash_image}.jpg"
|
513 |
+
image.save(image_filename)
|
514 |
+
|
515 |
+
del pipeline
|
516 |
+
if torch.cuda.is_available():
|
517 |
+
torch.cuda.empty_cache()
|
518 |
+
torch.cuda.ipc_collect()
|
519 |
+
# Load the tokenizer
|
520 |
+
if args.tokenizer_name:
|
521 |
+
tokenizer = CLIPTokenizer.from_pretrained(args.tokenizer_name )
|
522 |
+
elif args.pretrained_model_name_or_path:
|
523 |
+
#print(os.getcwd())
|
524 |
+
tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder="tokenizer" )
|
525 |
+
|
526 |
+
# Load models and create wrapper for stable diffusion
|
527 |
+
#text_encoder = CLIPTextModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="text_encoder" )
|
528 |
+
text_encoder = CLIPTextModel.from_pretrained(
|
529 |
+
args.pretrained_model_name_or_path,
|
530 |
+
subfolder="text_encoder",
|
531 |
+
revision=args.revision,
|
532 |
+
)
|
533 |
+
vae = AutoencoderKL.from_pretrained(
|
534 |
+
args.pretrained_model_name_or_path,
|
535 |
+
subfolder="vae",
|
536 |
+
revision=args.revision,
|
537 |
+
)
|
538 |
+
unet = UNet2DConditionModel.from_pretrained(
|
539 |
+
args.pretrained_model_name_or_path,
|
540 |
+
subfolder="unet",
|
541 |
+
revision=args.revision,
|
542 |
+
torch_dtype=torch.float32
|
543 |
+
)
|
544 |
+
|
545 |
+
if args.with_gan:
|
546 |
+
if os.path.isdir(os.path.join(args.pretrained_model_name_or_path, "discriminator")):
|
547 |
+
discriminator = Discriminator2D.from_pretrained(
|
548 |
+
args.pretrained_model_name_or_path,
|
549 |
+
subfolder="discriminator",
|
550 |
+
revision=args.revision,
|
551 |
+
)
|
552 |
+
else:
|
553 |
+
print(f" {bcolors.WARNING}Discriminator network (GAN) not found. Initializing a new network. It may take a very large number of steps to train.{bcolors.ENDC}")
|
554 |
+
if not args.gan_warmup:
|
555 |
+
print(f" {bcolors.WARNING}Consider using --gan_warmup to stabilize the model while the discriminator is being trained.{bcolors.ENDC}")
|
556 |
+
with open(args.discriminator_config, "r") as f:
|
557 |
+
discriminator_config = json.load(f)
|
558 |
+
discriminator = Discriminator2D.from_config(discriminator_config)
|
559 |
+
|
560 |
+
|
561 |
+
if is_xformers_available() and args.attention=='xformers':
|
562 |
+
try:
|
563 |
+
vae.enable_xformers_memory_efficient_attention()
|
564 |
+
unet.enable_xformers_memory_efficient_attention()
|
565 |
+
if args.with_gan:
|
566 |
+
discriminator.enable_xformers_memory_efficient_attention()
|
567 |
+
except Exception as e:
|
568 |
+
logger.warning(
|
569 |
+
"Could not enable memory efficient attention. Make sure xformers is installed"
|
570 |
+
f" correctly and a GPU is available: {e}"
|
571 |
+
)
|
572 |
+
elif args.attention=='flash_attention':
|
573 |
+
replace_unet_cross_attn_to_flash_attention()
|
574 |
+
|
575 |
+
if args.use_ema == True:
|
576 |
+
if os.path.isdir(os.path.join(args.pretrained_model_name_or_path, "unet_ema")):
|
577 |
+
ema_unet = UNet2DConditionModel.from_pretrained(
|
578 |
+
args.pretrained_model_name_or_path,
|
579 |
+
subfolder="unet_ema",
|
580 |
+
revision=args.revision,
|
581 |
+
torch_dtype=torch.float32
|
582 |
+
)
|
583 |
+
else:
|
584 |
+
ema_unet = copy.deepcopy(unet)
|
585 |
+
ema_unet.config["step"] = 0
|
586 |
+
for param in ema_unet.parameters():
|
587 |
+
param.requires_grad = False
|
588 |
+
|
589 |
+
if args.model_variant == "depth2img":
|
590 |
+
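# Depth2Img is passed to the datasets below as extra_module to provide depth conditioning
# (assumption: it wraps a depth estimator used only by the depth2img variant).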
d2i = Depth2Img(unet,text_encoder,args.mixed_precision,args.pretrained_model_name_or_path,accelerator)
|
591 |
+
vae.requires_grad_(False)
|
592 |
+
vae.enable_slicing()
|
593 |
+
if not args.train_text_encoder:
|
594 |
+
text_encoder.requires_grad_(False)
|
595 |
+
|
596 |
+
if args.gradient_checkpointing:
|
597 |
+
unet.enable_gradient_checkpointing()
|
598 |
+
if args.train_text_encoder:
|
599 |
+
text_encoder.gradient_checkpointing_enable()
|
600 |
+
if args.with_gan:
|
601 |
+
discriminator.enable_gradient_checkpointing()
|
602 |
+
|
603 |
+
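# Optionally scale the learning rate linearly with the effective batch size
# (train_batch_size * gradient_accumulation_steps * num_processes).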
if args.scale_lr:
|
604 |
+
args.learning_rate = (
|
605 |
+
args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
|
606 |
+
)
|
607 |
+
|
608 |
+
# Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs
|
609 |
+
if args.use_8bit_adam and args.use_deepspeed_adam==False and args.use_lion==False:
|
610 |
+
try:
|
611 |
+
import bitsandbytes as bnb
|
612 |
+
except ImportError:
|
613 |
+
raise ImportError(
|
614 |
+
"To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`."
|
615 |
+
)
|
616 |
+
optimizer_class = bnb.optim.AdamW8bit
|
617 |
+
print("Using 8-bit Adam")
|
618 |
+
elif args.use_8bit_adam and args.use_deepspeed_adam==True:
|
619 |
+
try:
|
620 |
+
from deepspeed.ops.adam import DeepSpeedCPUAdam
|
621 |
+
except ImportError:
|
622 |
+
raise ImportError(
|
623 |
+
"To use 8-bit DeepSpeed Adam, try updating your cuda and deepspeed integrations."
|
624 |
+
)
|
625 |
+
optimizer_class = DeepSpeedCPUAdam
|
626 |
+
elif args.use_lion == True:
|
627 |
+
print("Using LION optimizer")
|
628 |
+
optimizer_class = Lion
|
629 |
+
elif args.use_deepspeed_adam==False and args.use_lion==False and args.use_8bit_adam==False:
|
630 |
+
optimizer_class = torch.optim.AdamW
|
631 |
+
params_to_optimize = (
|
632 |
+
itertools.chain(unet.parameters(), text_encoder.parameters()) if args.train_text_encoder else unet.parameters()
|
633 |
+
)
|
634 |
+
if args.use_lion == False:
|
635 |
+
optimizer = optimizer_class(
|
636 |
+
params_to_optimize,
|
637 |
+
lr=args.learning_rate,
|
638 |
+
betas=(args.adam_beta1, args.adam_beta2),
|
639 |
+
weight_decay=args.adam_weight_decay,
|
640 |
+
eps=args.adam_epsilon,
|
641 |
+
)
|
642 |
+
if args.with_gan:
|
643 |
+
optimizer_discriminator = optimizer_class(
|
644 |
+
discriminator.parameters(),
|
645 |
+
lr=args.learning_rate,
|
646 |
+
betas=(args.adam_beta1, args.adam_beta2),
|
647 |
+
weight_decay=args.adam_weight_decay,
|
648 |
+
eps=args.adam_epsilon,
|
649 |
+
)
|
650 |
+
else:
|
651 |
+
optimizer = optimizer_class(
|
652 |
+
params_to_optimize,
|
653 |
+
lr=args.learning_rate,
|
654 |
+
betas=(args.adam_beta1, args.adam_beta2),
|
655 |
+
weight_decay=args.adam_weight_decay,
|
656 |
+
#eps=args.adam_epsilon,
|
657 |
+
)
|
658 |
+
if args.with_gan:
|
659 |
+
optimizer_discriminator = optimizer_class(
|
660 |
+
discriminator.parameters(),
|
661 |
+
lr=args.learning_rate,
|
662 |
+
betas=(args.adam_beta1, args.adam_beta2),
|
663 |
+
weight_decay=args.adam_weight_decay,
|
664 |
+
#eps=args.adam_epsilon,
|
665 |
+
)
|
666 |
+
noise_scheduler = DDPMScheduler.from_config(args.pretrained_model_name_or_path, subfolder="scheduler")
|
667 |
+
|
668 |
+
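# AutoBucketing groups images into aspect-ratio buckets at the target resolution;
# the fallback NormalDataset resizes (and optionally center-crops) everything to one square size.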
if args.use_bucketing:
|
669 |
+
train_dataset = AutoBucketing(
|
670 |
+
concepts_list=args.concepts_list,
|
671 |
+
use_image_names_as_captions=args.use_image_names_as_captions,
|
672 |
+
shuffle_captions=args.shuffle_captions,
|
673 |
+
batch_size=args.train_batch_size,
|
674 |
+
tokenizer=tokenizer,
|
675 |
+
add_class_images_to_dataset=args.add_class_images_to_dataset,
|
676 |
+
balance_datasets=args.auto_balance_concept_datasets,
|
677 |
+
resolution=args.resolution,
|
678 |
+
with_prior_loss=False,#args.with_prior_preservation,
|
679 |
+
repeats=args.dataset_repeats,
|
680 |
+
use_text_files_as_captions=args.use_text_files_as_captions,
|
681 |
+
aspect_mode=args.aspect_mode,
|
682 |
+
action_preference=args.aspect_mode_action_preference,
|
683 |
+
seed=args.seed,
|
684 |
+
model_variant=args.model_variant,
|
685 |
+
extra_module=None if args.model_variant != "depth2img" else d2i,
|
686 |
+
mask_prompts=args.mask_prompts,
|
687 |
+
load_mask=args.masked_training,
|
688 |
+
)
|
689 |
+
else:
|
690 |
+
train_dataset = NormalDataset(
|
691 |
+
concepts_list=args.concepts_list,
|
692 |
+
tokenizer=tokenizer,
|
693 |
+
with_prior_preservation=args.with_prior_preservation,
|
694 |
+
size=args.resolution,
|
695 |
+
center_crop=args.center_crop,
|
696 |
+
num_class_images=args.num_class_images,
|
697 |
+
use_image_names_as_captions=args.use_image_names_as_captions,
|
698 |
+
shuffle_captions=args.shuffle_captions,
|
699 |
+
repeats=args.dataset_repeats,
|
700 |
+
use_text_files_as_captions=args.use_text_files_as_captions,
|
701 |
+
seed = args.seed,
|
702 |
+
model_variant=args.model_variant,
|
703 |
+
extra_module=None if args.model_variant != "depth2img" else d2i,
|
704 |
+
mask_prompts=args.mask_prompts,
|
705 |
+
load_mask=args.masked_training,
|
706 |
+
)
|
707 |
+
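# collate_fn stacks the image tensors, pads the token ids to the tokenizer's max length,
# and, with prior preservation, appends the class examples after the instance examples
# so both halves of the batch go through a single forward pass.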
def collate_fn(examples):
|
708 |
+
#print(examples)
|
709 |
+
#print('test')
|
710 |
+
input_ids = [example["instance_prompt_ids"] for example in examples]
|
711 |
+
tokens = input_ids
|
712 |
+
pixel_values = [example["instance_images"] for example in examples]
|
713 |
+
mask = None
|
714 |
+
if "mask" in examples[0]:
|
715 |
+
mask = [example["mask"] for example in examples]
|
716 |
+
if args.model_variant == 'depth2img':
|
717 |
+
depth = [example["instance_depth_images"] for example in examples]
|
718 |
+
|
719 |
+
#print('test')
|
720 |
+
# Concat class and instance examples for prior preservation.
|
721 |
+
# We do this to avoid doing two forward passes.
|
722 |
+
if args.with_prior_preservation:
|
723 |
+
input_ids += [example["class_prompt_ids"] for example in examples]
|
724 |
+
pixel_values += [example["class_images"] for example in examples]
|
725 |
+
if "mask" in examples[0]:
|
726 |
+
mask += [example["class_mask"] for example in examples]
|
727 |
+
if args.model_variant == 'depth2img':
|
728 |
+
depth = [example["class_depth_images"] for example in examples]
|
729 |
+
mask_values = None
|
730 |
+
if mask is not None:
|
731 |
+
mask_values = torch.stack(mask)
|
732 |
+
mask_values = mask_values.to(memory_format=torch.contiguous_format).float()
|
733 |
+
if args.model_variant == 'depth2img':
|
734 |
+
depth_values = torch.stack(depth)
|
735 |
+
depth_values = depth_values.to(memory_format=torch.contiguous_format).float()
|
736 |
+
### no need to do it now when it's loaded by the multiAspectsDataset
|
737 |
+
#if args.with_prior_preservation:
|
738 |
+
# input_ids += [example["class_prompt_ids"] for example in examples]
|
739 |
+
# pixel_values += [example["class_images"] for example in examples]
|
740 |
+
|
741 |
+
#print(pixel_values)
|
742 |
+
#unpack the pixel_values from tensor to list
|
743 |
+
|
744 |
+
|
745 |
+
pixel_values = torch.stack(pixel_values)
|
746 |
+
pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
|
747 |
+
input_ids = tokenizer.pad(
|
748 |
+
{"input_ids": input_ids},
|
749 |
+
padding="max_length",
|
750 |
+
max_length=tokenizer.model_max_length,
|
751 |
+
return_tensors="pt",
|
752 |
+
).input_ids
|
753 |
+
|
754 |
+
extra_values = None
|
755 |
+
if args.model_variant == 'depth2img':
|
756 |
+
extra_values = depth_values
|
757 |
+
|
758 |
+
return {
|
759 |
+
"input_ids": input_ids,
|
760 |
+
"pixel_values": pixel_values,
|
761 |
+
"extra_values": extra_values,
|
762 |
+
"mask_values": mask_values,
|
763 |
+
"tokens": tokens
|
764 |
+
}
|
765 |
+
|
766 |
+
train_dataloader = torch.utils.data.DataLoader(
|
767 |
+
train_dataset, batch_size=args.train_batch_size, shuffle=False, collate_fn=collate_fn, pin_memory=True
|
768 |
+
)
|
769 |
+
#get the length of the dataset
|
770 |
+
train_dataset_length = len(train_dataset)
|
771 |
+
#code to check if latent cache needs to be resaved
|
772 |
+
#check if last_run.json file exists in logging_dir
|
773 |
+
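# last_run.json stores the previous run's batch size and dataset length; if either changed,
# the cached latents no longer line up with the batches and must be regenerated.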
if os.path.exists(logging_dir / "last_run.json"):
|
774 |
+
#if it exists, load it
|
775 |
+
with open(logging_dir / "last_run.json", "r") as f:
|
776 |
+
last_run = json.load(f)
|
777 |
+
last_run_batch_size = last_run["batch_size"]
|
778 |
+
last_run_dataset_length = last_run["dataset_length"]
|
779 |
+
if last_run_batch_size != args.train_batch_size:
|
780 |
+
print(f" {bcolors.WARNING}The batch_size has changed since the last run. Regenerating Latent Cache.{bcolors.ENDC}")
|
781 |
+
|
782 |
+
args.regenerate_latent_cache = True
|
783 |
+
#save the new batch_size and dataset_length to last_run.json
|
784 |
+
if last_run_dataset_length != train_dataset_length:
|
785 |
+
print(f" {bcolors.WARNING}The dataset length has changed since the last run. Regenerating Latent Cache.{bcolors.ENDC}")
|
786 |
+
|
787 |
+
args.regenerate_latent_cache = True
|
788 |
+
#save the new batch_size and dataset_length to last_run.json
|
789 |
+
with open(logging_dir / "last_run.json", "w") as f:
|
790 |
+
json.dump({"batch_size": args.train_batch_size, "dataset_length": train_dataset_length}, f)
|
791 |
+
|
792 |
+
else:
|
793 |
+
#if it doesn't exist, create it
|
794 |
+
last_run = {"batch_size": args.train_batch_size, "dataset_length": train_dataset_length}
|
795 |
+
#create the file
|
796 |
+
with open(logging_dir / "last_run.json", "w") as f:
|
797 |
+
json.dump(last_run, f)
|
798 |
+
|
799 |
+
weight_dtype = torch.float32
|
800 |
+
if accelerator.mixed_precision == "fp16":
|
801 |
+
print("Using fp16")
|
802 |
+
weight_dtype = torch.float16
|
803 |
+
elif accelerator.mixed_precision == "bf16":
|
804 |
+
print("Using bf16")
|
805 |
+
weight_dtype = torch.bfloat16
|
806 |
+
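# tf32 is not handled by Accelerate's mixed_precision (it was mapped to 'no' above),
# so it is enabled here directly on the CUDA matmul backend while weights stay in fp32.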
elif args.mixed_precision == "tf32":
|
807 |
+
torch.backends.cuda.matmul.allow_tf32 = True
|
808 |
+
#torch.set_float32_matmul_precision("medium")
|
809 |
+
|
810 |
+
# Move text_encode and vae to gpu.
|
811 |
+
# For mixed precision training we cast the text_encoder and vae weights to half-precision
|
812 |
+
# as these models are only used for inference, keeping weights in full precision is not required.
|
813 |
+
vae.to(accelerator.device, dtype=weight_dtype)
|
814 |
+
if args.use_ema == True:
|
815 |
+
ema_unet.to(accelerator.device)
|
816 |
+
if not args.train_text_encoder:
|
817 |
+
text_encoder.to(accelerator.device, dtype=weight_dtype)
|
818 |
+
|
819 |
+
if args.use_bucketing:
|
820 |
+
wh = set([tuple(x.target_wh) for x in train_dataset.image_train_items])
|
821 |
+
else:
|
822 |
+
wh = set([tuple([args.resolution, args.resolution]) for x in train_dataset.image_paths])
|
823 |
+
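# For every resolution in use, precompute the scaled VAE latent of an all-zero image;
# this is reused later as the "fully unmasked" conditioning latent for inpainting steps.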
full_mask_by_aspect = {shape: vae.encode(torch.zeros(1, 3, shape[1], shape[0]).to(accelerator.device, dtype=weight_dtype)).latent_dist.mean * 0.18215 for shape in wh}
|
824 |
+
|
825 |
+
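# Latents (and, when the text encoder is frozen, text embeddings) are encoded once and
# cached to disk so the VAE and text encoder can be dropped from VRAM during training.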
cached_dataset = CachedLatentsDataset(batch_size=args.train_batch_size,
|
826 |
+
text_encoder=text_encoder,
|
827 |
+
tokenizer=tokenizer,
|
828 |
+
dtype=weight_dtype,
|
829 |
+
model_variant=args.model_variant,
|
830 |
+
shuffle_per_epoch="False",
|
831 |
+
args = args,)
|
832 |
+
|
833 |
+
gen_cache = False
|
834 |
+
data_len = len(train_dataloader)
|
835 |
+
latent_cache_dir = Path(args.output_dir, "logs", "latent_cache")
|
836 |
+
#check if latents_cache.pt exists in the output_dir
|
837 |
+
if not os.path.exists(latent_cache_dir):
|
838 |
+
os.makedirs(latent_cache_dir)
|
839 |
+
for i in range(0,data_len-1):
|
840 |
+
if not os.path.exists(os.path.join(latent_cache_dir, f"latents_cache_{i}.pt")):
|
841 |
+
gen_cache = True
|
842 |
+
break
|
843 |
+
if args.regenerate_latent_cache == True:
|
844 |
+
files = os.listdir(latent_cache_dir)
|
845 |
+
gen_cache = True
|
846 |
+
for file in files:
|
847 |
+
os.remove(os.path.join(latent_cache_dir,file))
|
848 |
+
if gen_cache == False :
|
849 |
+
print(f" {bcolors.OKGREEN}Loading Latent Cache from {latent_cache_dir}{bcolors.ENDC}")
|
850 |
+
del vae
|
851 |
+
if not args.train_text_encoder:
|
852 |
+
del text_encoder
|
853 |
+
if torch.cuda.is_available():
|
854 |
+
torch.cuda.empty_cache()
|
855 |
+
torch.cuda.ipc_collect()
|
856 |
+
#load all the cached latents into a single dataset
|
857 |
+
for i in range(0,data_len-1):
|
858 |
+
cached_dataset.add_pt_cache(os.path.join(latent_cache_dir,f"latents_cache_{i}.pt"))
|
859 |
+
if gen_cache == True:
|
860 |
+
#delete all the cached latents if they exist to avoid problems
|
861 |
+
print(f" {bcolors.WARNING}Generating latents cache...{bcolors.ENDC}")
|
862 |
+
train_dataset = LatentsDataset([], [], [], [], [], [])
|
863 |
+
counter = 0
|
864 |
+
ImageFile.LOAD_TRUNCATED_IMAGES = True
|
865 |
+
with torch.no_grad():
|
866 |
+
for batch in tqdm(train_dataloader, desc="Caching latents", bar_format='%s{l_bar}%s%s{bar}%s%s{r_bar}%s'%(bcolors.OKBLUE,bcolors.ENDC, bcolors.OKBLUE, bcolors.ENDC,bcolors.OKBLUE,bcolors.ENDC,)):
|
867 |
+
cached_extra = None
|
868 |
+
cached_mask = None
|
869 |
+
batch["pixel_values"] = batch["pixel_values"].to(accelerator.device, non_blocking=True, dtype=weight_dtype)
|
870 |
+
batch["input_ids"] = batch["input_ids"].to(accelerator.device, non_blocking=True)
|
871 |
+
cached_latent = vae.encode(batch["pixel_values"]).latent_dist
|
872 |
+
if batch["mask_values"] is not None:
|
873 |
+
cached_mask = functional.resize(batch["mask_values"], size=cached_latent.mean.shape[2:])
|
874 |
+
if batch["mask_values"] is not None and args.model_variant == "inpainting":
|
875 |
+
batch["mask_values"] = batch["mask_values"].to(accelerator.device, non_blocking=True, dtype=weight_dtype)
|
876 |
+
cached_extra = vae.encode(batch["pixel_values"] * (1 - batch["mask_values"])).latent_dist
|
877 |
+
if args.model_variant == "depth2img":
|
878 |
+
batch["extra_values"] = batch["extra_values"].to(accelerator.device, non_blocking=True, dtype=weight_dtype)
|
879 |
+
cached_extra = functional.resize(batch["extra_values"], size=cached_latent.mean.shape[2:])
|
880 |
+
if args.train_text_encoder:
|
881 |
+
cached_text_enc = batch["input_ids"]
|
882 |
+
else:
|
883 |
+
cached_text_enc = text_encoder(batch["input_ids"])[0]
|
884 |
+
train_dataset.add_latent(cached_latent, cached_text_enc, cached_mask, cached_extra, batch["tokens"])
|
885 |
+
del batch
|
886 |
+
del cached_latent
|
887 |
+
del cached_text_enc
|
888 |
+
del cached_mask
|
889 |
+
del cached_extra
|
890 |
+
torch.save(train_dataset, os.path.join(latent_cache_dir,f"latents_cache_{counter}.pt"))
|
891 |
+
cached_dataset.add_pt_cache(os.path.join(latent_cache_dir,f"latents_cache_{counter}.pt"))
|
892 |
+
counter += 1
|
893 |
+
train_dataset = LatentsDataset([], [], [], [], [], [])
|
894 |
+
#if counter % 300 == 0:
|
895 |
+
#train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=1, collate_fn=lambda x: x, shuffle=False)
|
896 |
+
# gc.collect()
|
897 |
+
# torch.cuda.empty_cache()
|
898 |
+
# accelerator.free_memory()
|
899 |
+
|
900 |
+
#clear vram after caching latents
|
901 |
+
del vae
|
902 |
+
if not args.train_text_encoder:
|
903 |
+
del text_encoder
|
904 |
+
if torch.cuda.is_available():
|
905 |
+
torch.cuda.empty_cache()
|
906 |
+
torch.cuda.ipc_collect()
|
907 |
+
#load all the cached latents into a single dataset
|
908 |
+
train_dataloader = torch.utils.data.DataLoader(cached_dataset, batch_size=1, collate_fn=lambda x: x, shuffle=False)
|
909 |
+
print(f" {bcolors.OKGREEN}Latents are ready.{bcolors.ENDC}")
|
910 |
+
# Scheduler and math around the number of training steps.
|
911 |
+
overrode_max_train_steps = False
|
912 |
+
num_update_steps_per_epoch = len(train_dataloader)
|
913 |
+
if args.max_train_steps is None:
|
914 |
+
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
|
915 |
+
overrode_max_train_steps = True
|
916 |
+
|
917 |
+
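# A warmup value below 1 is interpreted as a fraction of the total optimization steps.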
if args.lr_warmup_steps < 1:
|
918 |
+
args.lr_warmup_steps = math.floor(args.lr_warmup_steps * args.max_train_steps / args.gradient_accumulation_steps)
|
919 |
+
|
920 |
+
lr_scheduler = get_scheduler(
|
921 |
+
args.lr_scheduler,
|
922 |
+
optimizer=optimizer,
|
923 |
+
num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
|
924 |
+
num_training_steps=args.max_train_steps,
|
925 |
+
)
|
926 |
+
|
927 |
+
if args.train_text_encoder and not args.use_ema:
|
928 |
+
unet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
|
929 |
+
unet, text_encoder, optimizer, train_dataloader, lr_scheduler
|
930 |
+
)
|
931 |
+
elif args.train_text_encoder and args.use_ema:
|
932 |
+
unet, text_encoder, ema_unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
|
933 |
+
unet, text_encoder, ema_unet, optimizer, train_dataloader, lr_scheduler
|
934 |
+
)
|
935 |
+
elif not args.train_text_encoder and args.use_ema:
|
936 |
+
unet, ema_unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
|
937 |
+
unet, ema_unet, optimizer, train_dataloader, lr_scheduler
|
938 |
+
)
|
939 |
+
elif not args.train_text_encoder and not args.use_ema:
|
940 |
+
unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
|
941 |
+
unet, optimizer, train_dataloader, lr_scheduler
|
942 |
+
)
|
943 |
+
if args.with_gan:
|
944 |
+
lr_scheduler_discriminator = get_scheduler(
|
945 |
+
args.lr_scheduler,
|
946 |
+
optimizer=optimizer_discriminator,
|
947 |
+
num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
|
948 |
+
num_training_steps=args.max_train_steps,
|
949 |
+
)
|
950 |
+
discriminator, optimizer_discriminator, lr_scheduler_discriminator = accelerator.prepare(discriminator, optimizer_discriminator, lr_scheduler_discriminator)
|
951 |
+
|
952 |
+
# We need to recalculate our total training steps as the size of the training dataloader may have changed.
|
953 |
+
num_update_steps_per_epoch = len(train_dataloader)
|
954 |
+
if overrode_max_train_steps:
|
955 |
+
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
|
956 |
+
#print(args.max_train_steps, num_update_steps_per_epoch)
|
957 |
+
# Afterwards we recalculate our number of training epochs
|
958 |
+
#print(args.max_train_steps, num_update_steps_per_epoch)
|
959 |
+
args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
|
960 |
+
|
961 |
+
# We need to initialize the trackers we use, and also store our configuration.
|
962 |
+
# The trackers are initialized automatically on the main process.
|
963 |
+
if accelerator.is_main_process:
|
964 |
+
accelerator.init_trackers("dreambooth")
|
965 |
+
# Train!
|
966 |
+
total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
|
967 |
+
|
968 |
+
logger.info("***** Running training *****")
|
969 |
+
logger.info(f" Num examples = {len(train_dataset)}")
|
970 |
+
logger.info(f" Num batches each epoch = {len(train_dataloader)}")
|
971 |
+
logger.info(f" Num Epochs = {args.num_train_epochs}")
|
972 |
+
logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
|
973 |
+
logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
|
974 |
+
logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
|
975 |
+
logger.info(f" Total optimization steps = {args.max_train_steps}")
|
976 |
+
def mid_train_playground(step):
|
977 |
+
|
978 |
+
tqdm.write(f"{bcolors.WARNING} Booting up GUI{bcolors.ENDC}")
|
979 |
+
epoch = step // num_update_steps_per_epoch
|
980 |
+
if args.train_text_encoder and args.stop_text_encoder_training == True:
|
981 |
+
text_enc_model = accelerator.unwrap_model(text_encoder,True)
|
982 |
+
elif args.train_text_encoder and args.stop_text_encoder_training > epoch:
|
983 |
+
text_enc_model = accelerator.unwrap_model(text_encoder,True)
|
984 |
+
elif args.train_text_encoder == False:
|
985 |
+
text_enc_model = CLIPTextModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="text_encoder" )
|
986 |
+
elif args.train_text_encoder and args.stop_text_encoder_training <= epoch:
|
987 |
+
if 'frozen_directory' in locals():
|
988 |
+
text_enc_model = CLIPTextModel.from_pretrained(frozen_directory, subfolder="text_encoder")
|
989 |
+
else:
|
990 |
+
text_enc_model = accelerator.unwrap_model(text_encoder,True)
|
991 |
+
scheduler = DPMSolverMultistepScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
|
992 |
+
unwrapped_unet = accelerator.unwrap_model(ema_unet if args.use_ema else unet,True)
|
993 |
+
|
994 |
+
pipeline = DiffusionPipeline.from_pretrained(
|
995 |
+
args.pretrained_model_name_or_path,
|
996 |
+
unet=unwrapped_unet,
|
997 |
+
text_encoder=text_enc_model,
|
998 |
+
vae=AutoencoderKL.from_pretrained(args.pretrained_vae_name_or_path or args.pretrained_model_name_or_path,subfolder=None if args.pretrained_vae_name_or_path else "vae", safe_serialization=True),
|
999 |
+
safety_checker=None,
|
1000 |
+
torch_dtype=weight_dtype,
|
1001 |
+
local_files_only=False,
|
1002 |
+
requires_safety_checker=False,
|
1003 |
+
)
|
1004 |
+
pipeline.scheduler = scheduler
|
1005 |
+
if is_xformers_available() and args.attention=='xformers':
|
1006 |
+
try:
|
1007 |
+
vae.enable_xformers_memory_efficient_attention()
|
1008 |
+
unet.enable_xformers_memory_efficient_attention()
|
1009 |
+
except Exception as e:
|
1010 |
+
logger.warning(
|
1011 |
+
"Could not enable memory efficient attention. Make sure xformers is installed"
|
1012 |
+
f" correctly and a GPU is available: {e}"
|
1013 |
+
)
|
1014 |
+
elif args.attention=='flash_attention':
|
1015 |
+
replace_unet_cross_attn_to_flash_attention()
|
1016 |
+
pipeline = pipeline.to(accelerator.device)
|
1017 |
+
def inference(prompt, negative_prompt, num_samples, height=512, width=512, num_inference_steps=50,seed=-1,guidance_scale=7.5):
|
1018 |
+
with torch.autocast("cuda"), torch.inference_mode():
|
1019 |
+
if seed != -1:
|
1020 |
+
if g_cuda is None:
|
1021 |
+
g_cuda = torch.Generator(device='cuda')
|
1022 |
+
else:
|
1023 |
+
g_cuda.manual_seed(int(seed))
|
1024 |
+
else:
|
1025 |
+
seed = random.randint(0, 100000)
|
1026 |
+
g_cuda = torch.Generator(device='cuda')
|
1027 |
+
g_cuda.manual_seed(seed)
|
1028 |
+
return pipeline(
|
1029 |
+
prompt, height=int(height), width=int(width),
|
1030 |
+
negative_prompt=negative_prompt,
|
1031 |
+
num_images_per_prompt=int(num_samples),
|
1032 |
+
num_inference_steps=int(num_inference_steps), guidance_scale=guidance_scale,
|
1033 |
+
generator=g_cuda).images, seed
|
1034 |
+
|
1035 |
+
with gr.Blocks() as demo:
|
1036 |
+
with gr.Row():
|
1037 |
+
with gr.Column():
|
1038 |
+
prompt = gr.Textbox(label="Prompt", value="photo of zwx dog in a bucket")
|
1039 |
+
negative_prompt = gr.Textbox(label="Negative Prompt", value="")
|
1040 |
+
run = gr.Button(value="Generate")
|
1041 |
+
with gr.Row():
|
1042 |
+
num_samples = gr.Number(label="Number of Samples", value=4)
|
1043 |
+
guidance_scale = gr.Number(label="Guidance Scale", value=7.5)
|
1044 |
+
with gr.Row():
|
1045 |
+
height = gr.Number(label="Height", value=512)
|
1046 |
+
width = gr.Number(label="Width", value=512)
|
1047 |
+
with gr.Row():
|
1048 |
+
num_inference_steps = gr.Slider(label="Steps", value=25)
|
1049 |
+
seed = gr.Number(label="Seed", value=-1)
|
1050 |
+
with gr.Column():
|
1051 |
+
gallery = gr.Gallery()
|
1052 |
+
seedDisplay = gr.Number(label="Used Seed:", value=0)
|
1053 |
+
|
1054 |
+
run.click(inference, inputs=[prompt, negative_prompt, num_samples, height, width, num_inference_steps,seed, guidance_scale], outputs=[gallery,seedDisplay])
|
1055 |
+
|
1056 |
+
demo.launch(share=True,prevent_thread_lock=True)
|
1057 |
+
tqdm.write(f"{bcolors.WARNING}Gradio Session is active, Press 'F12' to resume training{bcolors.ENDC}")
|
1058 |
+
keyboard.wait('f12')
|
1059 |
+
demo.close()
|
1060 |
+
del demo
|
1061 |
+
del text_enc_model
|
1062 |
+
del unwrapped_unet
|
1063 |
+
del pipeline
|
1064 |
+
return
|
1065 |
+
|
1066 |
+
def save_and_sample_weights(step,context='checkpoint',save_model=True):
|
1067 |
+
try:
|
1068 |
+
#check how many folders are in the output dir
|
1069 |
+
#if there are more than 5, delete the oldest one
|
1070 |
+
#save the model
|
1071 |
+
#save the optimizer
|
1072 |
+
#save the lr_scheduler
|
1073 |
+
#save the args
|
1074 |
+
height = args.sample_height
|
1075 |
+
width = args.sample_width
|
1076 |
+
batch_prompts = []
|
1077 |
+
if args.sample_from_batch > 0:
|
1078 |
+
num_samples = args.sample_from_batch if args.sample_from_batch < args.train_batch_size else args.train_batch_size
|
1079 |
+
batch_prompts = []
|
1080 |
+
tokens = args.batch_tokens
|
1081 |
+
if tokens != None:
|
1082 |
+
allPrompts = list(set([tokenizer.decode(p).replace('<|endoftext|>','').replace('<|startoftext|>', '') for p in tokens]))
|
1083 |
+
if len(allPrompts) < num_samples:
|
1084 |
+
num_samples = len(allPrompts)
|
1085 |
+
batch_prompts = random.sample(allPrompts, num_samples)
|
1086 |
+
|
1087 |
+
|
1088 |
+
if args.sample_aspect_ratios:
|
1089 |
+
#choose random aspect ratio from ASPECTS
|
1090 |
+
aspect_ratio = random.choice(ASPECTS)
|
1091 |
+
height = aspect_ratio[0]
|
1092 |
+
width = aspect_ratio[1]
|
1093 |
+
if os.path.exists(args.output_dir):
|
1094 |
+
if args.detect_full_drive==True:
|
1095 |
+
folders = os.listdir(args.output_dir)
|
1096 |
+
#check how much space is left on the drive
|
1097 |
+
total, used, free = shutil.disk_usage("/")
|
1098 |
+
if (free // (2**30)) < 4:
|
1099 |
+
#folders.remove("0")
|
1100 |
+
#get the folder with the lowest number
|
1101 |
+
#oldest_folder = min(folder for folder in folders if folder.isdigit())
|
1102 |
+
tqdm.write(f"{bcolors.FAIL}Drive is almost full, Please make some space to continue training.{bcolors.ENDC}")
|
1103 |
+
if args.send_telegram_updates:
|
1104 |
+
try:
|
1105 |
+
send_telegram_message(f"Drive is almost full, Please make some space to continue training.", args.telegram_chat_id, args.telegram_token)
|
1106 |
+
except:
|
1107 |
+
pass
|
1108 |
+
#count time
|
1109 |
+
import time
|
1110 |
+
start_time = time.time()
|
1111 |
+
import platform
|
1112 |
+
while input("Press Enter to continue... if you're on linux we'll wait 5 minutes for you to make space and continue"):
|
1113 |
+
#check if five minutes have passed
|
1114 |
+
#check if os is linux
|
1115 |
+
if 'Linux' in platform.platform():
|
1116 |
+
if time.time() - start_time > 300:
|
1117 |
+
break
|
1118 |
+
|
1119 |
+
|
1120 |
+
#oldest_folder_path = os.path.join(args.output_dir, oldest_folder)
|
1121 |
+
#shutil.rmtree(oldest_folder_path)
|
1122 |
+
# Create the pipeline using the trained modules and save it.
|
1123 |
+
if accelerator.is_main_process:
|
1124 |
+
if 'step' in context:
|
1125 |
+
#what is the current epoch
|
1126 |
+
epoch = step // num_update_steps_per_epoch
|
1127 |
+
else:
|
1128 |
+
epoch = step
|
1129 |
+
if args.train_text_encoder and args.stop_text_encoder_training == True:
|
1130 |
+
text_enc_model = accelerator.unwrap_model(text_encoder,True)
|
1131 |
+
elif args.train_text_encoder and args.stop_text_encoder_training > epoch:
|
1132 |
+
text_enc_model = accelerator.unwrap_model(text_encoder,True)
|
1133 |
+
elif args.train_text_encoder == False:
|
1134 |
+
text_enc_model = CLIPTextModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="text_encoder" )
|
1135 |
+
elif args.train_text_encoder and args.stop_text_encoder_training <= epoch:
|
1136 |
+
if 'frozen_directory' in locals():
|
1137 |
+
text_enc_model = CLIPTextModel.from_pretrained(frozen_directory, subfolder="text_encoder")
|
1138 |
+
else:
|
1139 |
+
text_enc_model = accelerator.unwrap_model(text_encoder,True)
|
1140 |
+
|
1141 |
+
#scheduler = DDIMScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", clip_sample=False, set_alpha_to_one=False)
|
1142 |
+
#scheduler = EulerDiscreteScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler", prediction_type="v_prediction")
|
1143 |
+
scheduler = DPMSolverMultistepScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
|
1144 |
+
unwrapped_unet = accelerator.unwrap_model(unet,True)
|
1145 |
+
|
1146 |
+
pipeline = DiffusionPipeline.from_pretrained(
|
1147 |
+
args.pretrained_model_name_or_path,
|
1148 |
+
unet=unwrapped_unet,
|
1149 |
+
text_encoder=text_enc_model,
|
1150 |
+
vae=AutoencoderKL.from_pretrained(args.pretrained_vae_name_or_path or args.pretrained_model_name_or_path,subfolder=None if args.pretrained_vae_name_or_path else "vae",),
|
1151 |
+
safety_checker=None,
|
1152 |
+
torch_dtype=weight_dtype,
|
1153 |
+
local_files_only=False,
|
1154 |
+
requires_safety_checker=False,
|
1155 |
+
)
|
1156 |
+
pipeline.scheduler = scheduler
|
1157 |
+
if is_xformers_available() and args.attention=='xformers':
|
1158 |
+
try:
|
1159 |
+
unet.enable_xformers_memory_efficient_attention()
|
1160 |
+
except Exception as e:
|
1161 |
+
logger.warning(
|
1162 |
+
"Could not enable memory efficient attention. Make sure xformers is installed"
|
1163 |
+
f" correctly and a GPU is available: {e}"
|
1164 |
+
)
|
1165 |
+
elif args.attention=='flash_attention':
|
1166 |
+
replace_unet_cross_attn_to_flash_attention()
|
1167 |
+
if args.run_name:
|
1168 |
+
save_dir = os.path.join(args.output_dir, f"{context}_{step}_{args.run_name}")
|
1169 |
+
else:
|
1170 |
+
save_dir = os.path.join(args.output_dir, f"{context}_{step}")
|
1171 |
+
if args.flatten_sample_folder:
|
1172 |
+
sample_dir = main_sample_dir
|
1173 |
+
else:
|
1174 |
+
sample_dir = os.path.join(main_sample_dir, f"{context}_{step}")
|
1175 |
+
#if sample dir path does not exist, create it
|
1176 |
+
|
1177 |
+
if args.stop_text_encoder_training == True:
|
1178 |
+
save_dir = frozen_directory
|
1179 |
+
if save_model:
|
1180 |
+
pipeline.save_pretrained(save_dir,safe_serialization=True)
|
1181 |
+
if args.with_gan:
|
1182 |
+
discriminator.save_pretrained(os.path.join(save_dir, "discriminator"), safe_serialization=True)
|
1183 |
+
if args.use_ema:
|
1184 |
+
ema_unet.save_pretrained(os.path.join(save_dir, "unet_ema"), safe_serialization=True)
|
1185 |
+
with open(os.path.join(save_dir, "args.json"), "w") as f:
|
1186 |
+
json.dump(args.__dict__, f, indent=2)
|
1187 |
+
if args.stop_text_encoder_training == True:
|
1188 |
+
#delete every folder in frozen_directory but the text encoder
|
1189 |
+
for folder in os.listdir(save_dir):
|
1190 |
+
if folder != "text_encoder" and os.path.isdir(os.path.join(save_dir, folder)):
|
1191 |
+
shutil.rmtree(os.path.join(save_dir, folder))
|
1192 |
+
imgs = []
|
1193 |
+
if args.use_ema and args.sample_from_ema:
|
1194 |
+
pipeline.unet = ema_unet
|
1195 |
+
|
1196 |
+
for param in unet.parameters():
|
1197 |
+
param.requires_grad = False
|
1198 |
+
if torch.cuda.is_available():
|
1199 |
+
torch.cuda.empty_cache()
|
1200 |
+
torch.cuda.ipc_collect()
|
1201 |
+
|
1202 |
+
if args.add_sample_prompt is not None or batch_prompts != [] and args.stop_text_encoder_training != True:
|
1203 |
+
prompts = []
|
1204 |
+
if args.add_sample_prompt is not None:
|
1205 |
+
for prompt in args.add_sample_prompt:
|
1206 |
+
prompts.append(prompt)
|
1207 |
+
if batch_prompts != []:
|
1208 |
+
for prompt in batch_prompts:
|
1209 |
+
prompts.append(prompt)
|
1210 |
+
|
1211 |
+
pipeline = pipeline.to(accelerator.device)
|
1212 |
+
pipeline.set_progress_bar_config(disable=True)
|
1213 |
+
#sample_dir = os.path.join(save_dir, "samples")
|
1214 |
+
#if sample_dir exists, delete it
|
1215 |
+
if os.path.exists(sample_dir):
|
1216 |
+
if not args.flatten_sample_folder:
|
1217 |
+
shutil.rmtree(sample_dir)
|
1218 |
+
os.makedirs(sample_dir, exist_ok=True)
|
1219 |
+
with torch.autocast("cuda"), torch.inference_mode():
|
1220 |
+
if args.send_telegram_updates:
|
1221 |
+
try:
|
1222 |
+
send_telegram_message(f"Generating samples for <b>{step}</b> {context}", args.telegram_chat_id, args.telegram_token)
|
1223 |
+
except:
|
1224 |
+
pass
|
1225 |
+
n_sample = args.n_save_sample
|
1226 |
+
if args.save_sample_controlled_seed:
|
1227 |
+
n_sample += len(args.save_sample_controlled_seed)
|
1228 |
+
progress_bar_sample = tqdm(total=len(prompts)*n_sample,desc="Generating samples")
|
1229 |
+
for samplePrompt in prompts:
|
1230 |
+
sampleIndex = prompts.index(samplePrompt)
|
1231 |
+
#convert sampleIndex to number in words
|
1232 |
+
# Data to be written
|
1233 |
+
sampleProperties = {
|
1234 |
+
"samplePrompt" : samplePrompt
|
1235 |
+
}
|
1236 |
+
|
1237 |
+
# Serializing json
|
1238 |
+
json_object = json.dumps(sampleProperties, indent=4)
|
1239 |
+
|
1240 |
+
if args.flatten_sample_folder:
|
1241 |
+
sampleName = f"{context}_{step}_prompt_{sampleIndex+1}"
|
1242 |
+
else:
|
1243 |
+
sampleName = f"prompt_{sampleIndex+1}"
|
1244 |
+
|
1245 |
+
if not args.flatten_sample_folder:
|
1246 |
+
os.makedirs(os.path.join(sample_dir,sampleName), exist_ok=True)
|
1247 |
+
|
1248 |
+
if args.model_variant == 'inpainting':
|
1249 |
+
conditioning_image = torch.zeros(1, 3, height, width)
|
1250 |
+
mask = torch.ones(1, 1, height, width)
|
1251 |
+
if args.model_variant == 'depth2img':
|
1252 |
+
#pil new white image
|
1253 |
+
test_image = Image.new('RGB', (width, height), (255, 255, 255))
|
1254 |
+
depth_image = Image.new('RGB', (width, height), (255, 255, 255))
|
1255 |
+
depth = np.array(depth_image.convert("L"))
|
1256 |
+
depth = depth.astype(np.float32) / 255.0
|
1257 |
+
depth = depth[None, None]
|
1258 |
+
depth = torch.from_numpy(depth)
|
1259 |
+
for i in range(n_sample):
|
1260 |
+
#check if the sample is controlled by a seed
|
1261 |
+
if i < args.n_save_sample:
|
1262 |
+
if args.model_variant == 'inpainting':
|
1263 |
+
images = pipeline(samplePrompt, conditioning_image, mask, height=height,width=width, guidance_scale=args.save_guidance_scale, num_inference_steps=args.save_infer_steps).images
|
1264 |
+
if args.model_variant == 'depth2img':
|
1265 |
+
images = pipeline(samplePrompt,image=test_image, guidance_scale=args.save_guidance_scale, num_inference_steps=args.save_infer_steps,strength=1.0).images
|
1266 |
+
elif args.model_variant == 'base':
|
1267 |
+
images = pipeline(samplePrompt,height=height,width=width, guidance_scale=args.save_guidance_scale, num_inference_steps=args.save_infer_steps).images
|
1268 |
+
|
1269 |
+
if not args.flatten_sample_folder:
|
1270 |
+
images[0].save(os.path.join(sample_dir,sampleName, f"{sampleName}_{i}.png"))
|
1271 |
+
else:
|
1272 |
+
images[0].save(os.path.join(sample_dir, f"{sampleName}_{i}.png"))
|
1273 |
+
|
1274 |
+
else:
|
1275 |
+
seed = args.save_sample_controlled_seed[i - args.n_save_sample]
|
1276 |
+
generator = torch.Generator("cuda").manual_seed(seed)
|
1277 |
+
if args.model_variant == 'inpainting':
|
1278 |
+
images = pipeline(samplePrompt,conditioning_image, mask,height=height,width=width, guidance_scale=args.save_guidance_scale, num_inference_steps=args.save_infer_steps, generator=generator).images
|
1279 |
+
if args.model_variant == 'depth2img':
|
1280 |
+
images = pipeline(samplePrompt,image=test_image, guidance_scale=args.save_guidance_scale, num_inference_steps=args.save_infer_steps,generator=generator,strength=1.0).images
|
1281 |
+
elif args.model_variant == 'base':
|
1282 |
+
images = pipeline(samplePrompt,height=height,width=width, guidance_scale=args.save_guidance_scale, num_inference_steps=args.save_infer_steps, generator=generator).images
|
1283 |
+
|
1284 |
+
if not args.flatten_sample_folder:
|
1285 |
+
images[0].save(os.path.join(sample_dir,sampleName, f"{sampleName}_controlled_seed_{str(seed)}.png"))
|
1286 |
+
else:
|
1287 |
+
images[0].save(os.path.join(sample_dir, f"{sampleName}_controlled_seed_{str(seed)}.png"))
|
1288 |
+
progress_bar_sample.update(1)
|
1289 |
+
|
1290 |
+
if args.send_telegram_updates:
|
1291 |
+
imgs = []
|
1292 |
+
#get all the images from the sample folder
|
1293 |
+
if not args.flatten_sample_folder:
|
1294 |
+
dir = os.listdir(os.path.join(sample_dir,sampleName))
|
1295 |
+
else:
|
1296 |
+
dir = os.listdir(sample_dir)
for file in dir:
if file.endswith(".png"):
#open the image with pil
img = Image.open(os.path.join(sample_dir, file) if args.flatten_sample_folder else os.path.join(sample_dir, sampleName, file))
|
1302 |
+
imgs.append(img)
|
1303 |
+
try:
|
1304 |
+
send_media_group(args.telegram_chat_id,args.telegram_token,imgs, caption=f"Samples for the <b>{step}</b> {context} using the prompt:\n\n<b>{samplePrompt}</b>")
|
1305 |
+
except:
|
1306 |
+
pass
|
1307 |
+
del pipeline
|
1308 |
+
del unwrapped_unet
|
1309 |
+
for param in unet.parameters():
|
1310 |
+
param.requires_grad = True
|
1311 |
+
if torch.cuda.is_available():
|
1312 |
+
torch.cuda.empty_cache()
|
1313 |
+
torch.cuda.ipc_collect()
|
1314 |
+
if save_model == True:
|
1315 |
+
tqdm.write(f"{bcolors.OKGREEN}Weights saved to {save_dir}{bcolors.ENDC}")
|
1316 |
+
elif save_model == False and len(imgs) > 0:
|
1317 |
+
del imgs
|
1318 |
+
tqdm.write(f"{bcolors.OKGREEN}Samples saved to {sample_dir}{bcolors.ENDC}")
|
1319 |
+
|
1320 |
+
except Exception as e:
|
1321 |
+
tqdm.write(str(e))
tqdm.write(f"{bcolors.FAIL} Error occurred during sampling, skipping.{bcolors.ENDC}")
|
1323 |
+
pass
|
1324 |
+
|
1325 |
+
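# EMA update: decay = min((step + 1) / (step + 10), 0.9999); each EMA parameter then moves
# toward the live weights by a (1 - decay) fraction of their difference, and frozen
# parameters are copied through unchanged.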
@torch.no_grad()
|
1326 |
+
def update_ema(ema_model, model):
|
1327 |
+
ema_step = ema_model.config["step"]
|
1328 |
+
decay = min((ema_step + 1) / (ema_step + 10), 0.9999)
|
1329 |
+
ema_model.config["step"] += 1
|
1330 |
+
for (s_param, param) in zip(ema_model.parameters(), model.parameters()):
|
1331 |
+
if param.requires_grad:
|
1332 |
+
s_param.add_((1 - decay) * (param - s_param))
|
1333 |
+
else:
|
1334 |
+
s_param.copy_(param)
|
1335 |
+
|
1336 |
+
|
1337 |
+
# Only show the progress bar once on each machine.
|
1338 |
+
progress_bar = tqdm(range(args.max_train_steps),bar_format='%s{l_bar}%s%s{bar}%s%s{r_bar}%s'%(bcolors.OKBLUE,bcolors.ENDC, bcolors.OKBLUE, bcolors.ENDC,bcolors.OKBLUE,bcolors.ENDC,), disable=not accelerator.is_local_main_process)
|
1339 |
+
progress_bar_inter_epoch = tqdm(range(num_update_steps_per_epoch),bar_format='%s{l_bar}%s%s{bar}%s%s{r_bar}%s'%(bcolors.OKBLUE,bcolors.ENDC, bcolors.OKGREEN, bcolors.ENDC,bcolors.OKBLUE,bcolors.ENDC,), disable=not accelerator.is_local_main_process)
|
1340 |
+
progress_bar_e = tqdm(range(args.num_train_epochs),bar_format='%s{l_bar}%s%s{bar}%s%s{r_bar}%s'%(bcolors.OKBLUE,bcolors.ENDC, bcolors.OKGREEN, bcolors.ENDC,bcolors.OKBLUE,bcolors.ENDC,), disable=not accelerator.is_local_main_process)
|
1341 |
+
|
1342 |
+
progress_bar.set_description("Overall Steps")
|
1343 |
+
progress_bar_inter_epoch.set_description("Steps To Epoch")
|
1344 |
+
progress_bar_e.set_description("Overall Epochs")
|
1345 |
+
global_step = 0
|
1346 |
+
loss_avg = AverageMeter("loss_avg", max_eta=0.999)
|
1347 |
+
gan_loss_avg = AverageMeter("gan_loss_avg", max_eta=0.999)
|
1348 |
+
text_enc_context = nullcontext() if args.train_text_encoder else torch.no_grad()
|
1349 |
+
if args.send_telegram_updates:
|
1350 |
+
try:
|
1351 |
+
send_telegram_message(f"Starting training with the following settings:\n\n{format_dict(args.__dict__)}", args.telegram_chat_id, args.telegram_token)
|
1352 |
+
except:
|
1353 |
+
pass
|
1354 |
+
try:
|
1355 |
+
tqdm.write(f"{bcolors.OKBLUE}Starting Training!{bcolors.ENDC}")
|
1356 |
+
try:
|
1357 |
+
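# Hotkey handlers (Ctrl+Shift+<key>, with Alt for the per-step variants): each one only
# flips a flag; the training loop acts on the flags at the next step or epoch boundary.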
def toggle_gui(event=None):
|
1358 |
+
if keyboard.is_pressed("ctrl") and keyboard.is_pressed("shift") and keyboard.is_pressed("g"):
|
1359 |
+
tqdm.write(f"{bcolors.WARNING}GUI will boot as soon as the current step is done.{bcolors.ENDC}")
|
1360 |
+
nonlocal mid_generation
|
1361 |
+
if mid_generation == True:
|
1362 |
+
mid_generation = False
|
1363 |
+
tqdm.write(f"{bcolors.WARNING}Cancelled GUI.{bcolors.ENDC}")
|
1364 |
+
else:
|
1365 |
+
mid_generation = True
|
1366 |
+
|
1367 |
+
def toggle_checkpoint(event=None):
|
1368 |
+
if keyboard.is_pressed("ctrl") and keyboard.is_pressed("shift") and keyboard.is_pressed("s") and not keyboard.is_pressed("alt"):
|
1369 |
+
tqdm.write(f"{bcolors.WARNING}Saving the model as soon as this epoch is done.{bcolors.ENDC}")
|
1370 |
+
nonlocal mid_checkpoint
|
1371 |
+
if mid_checkpoint == True:
|
1372 |
+
mid_checkpoint = False
|
1373 |
+
tqdm.write(f"{bcolors.WARNING}Cancelled Checkpointing.{bcolors.ENDC}")
|
1374 |
+
else:
|
1375 |
+
mid_checkpoint = True
|
1376 |
+
|
1377 |
+
def toggle_sample(event=None):
|
1378 |
+
if keyboard.is_pressed("ctrl") and keyboard.is_pressed("shift") and keyboard.is_pressed("p") and not keyboard.is_pressed("alt"):
|
1379 |
+
tqdm.write(f"{bcolors.WARNING}Sampling will begin as soon as this epoch is done.{bcolors.ENDC}")
|
1380 |
+
nonlocal mid_sample
|
1381 |
+
if mid_sample == True:
|
1382 |
+
mid_sample = False
|
1383 |
+
tqdm.write(f"{bcolors.WARNING}Cancelled Sampling.{bcolors.ENDC}")
|
1384 |
+
else:
|
1385 |
+
mid_sample = True
|
1386 |
+
def toggle_checkpoint_step(event=None):
|
1387 |
+
if keyboard.is_pressed("ctrl") and keyboard.is_pressed("shift") and keyboard.is_pressed("alt") and keyboard.is_pressed("s"):
|
1388 |
+
tqdm.write(f"{bcolors.WARNING}Saving the model as soon as this step is done.{bcolors.ENDC}")
|
1389 |
+
nonlocal mid_checkpoint_step
|
1390 |
+
if mid_checkpoint_step == True:
|
1391 |
+
mid_checkpoint_step = False
|
1392 |
+
tqdm.write(f"{bcolors.WARNING}Cancelled Checkpointing.{bcolors.ENDC}")
|
1393 |
+
else:
|
1394 |
+
mid_checkpoint_step = True
|
1395 |
+
|
1396 |
+
def toggle_sample_step(event=None):
|
1397 |
+
if keyboard.is_pressed("ctrl") and keyboard.is_pressed("shift") and keyboard.is_pressed("alt") and keyboard.is_pressed("p"):
|
1398 |
+
tqdm.write(f"{bcolors.WARNING}Sampling will begin as soon as this step is done.{bcolors.ENDC}")
|
1399 |
+
nonlocal mid_sample_step
|
1400 |
+
if mid_sample_step == True:
|
1401 |
+
mid_sample_step = False
|
1402 |
+
tqdm.write(f"{bcolors.WARNING}Cancelled Sampling.{bcolors.ENDC}")
|
1403 |
+
else:
|
1404 |
+
mid_sample_step = True
|
1405 |
+
def toggle_quit_and_save_epoch(event=None):
|
1406 |
+
if keyboard.is_pressed("ctrl") and keyboard.is_pressed("shift") and keyboard.is_pressed("q") and not keyboard.is_pressed("alt"):
|
1407 |
+
tqdm.write(f"{bcolors.WARNING}Quitting and saving the model as soon as this epoch is done.{bcolors.ENDC}")
|
1408 |
+
nonlocal mid_quit
|
1409 |
+
if mid_quit == True:
|
1410 |
+
mid_quit = False
|
1411 |
+
tqdm.write(f"{bcolors.WARNING}Cancelled Quitting.{bcolors.ENDC}")
|
1412 |
+
else:
|
1413 |
+
mid_quit = True
|
1414 |
+
def toggle_quit_and_save_step(event=None):
|
1415 |
+
if keyboard.is_pressed("ctrl") and keyboard.is_pressed("shift") and keyboard.is_pressed("alt") and keyboard.is_pressed("q"):
|
1416 |
+
tqdm.write(f"{bcolors.WARNING}Quitting and saving the model as soon as this step is done.{bcolors.ENDC}")
|
1417 |
+
nonlocal mid_quit_step
|
1418 |
+
if mid_quit_step == True:
|
1419 |
+
mid_quit_step = False
|
1420 |
+
tqdm.write(f"{bcolors.WARNING}Cancelled Quitting.{bcolors.ENDC}")
|
1421 |
+
else:
|
1422 |
+
mid_quit_step = True
|
1423 |
+
def help(event=None):
|
1424 |
+
if keyboard.is_pressed("ctrl") and keyboard.is_pressed("h"):
|
1425 |
+
print_instructions()
|
1426 |
+
keyboard.on_press_key("g", toggle_gui)
|
1427 |
+
keyboard.on_press_key("s", toggle_checkpoint)
|
1428 |
+
keyboard.on_press_key("p", toggle_sample)
|
1429 |
+
keyboard.on_press_key("s", toggle_checkpoint_step)
|
1430 |
+
keyboard.on_press_key("p", toggle_sample_step)
|
1431 |
+
keyboard.on_press_key("q", toggle_quit_and_save_epoch)
|
1432 |
+
keyboard.on_press_key("q", toggle_quit_and_save_step)
|
1433 |
+
keyboard.on_press_key("h", help)
|
1434 |
+
print_instructions()
|
1435 |
+
except Exception as e:
|
1436 |
+
pass
|
1437 |
+
|
1438 |
+
mid_generation = False
|
1439 |
+
mid_checkpoint = False
|
1440 |
+
mid_sample = False
|
1441 |
+
mid_checkpoint_step = False
|
1442 |
+
mid_sample_step = False
|
1443 |
+
mid_quit = False
|
1444 |
+
mid_quit_step = False
|
1445 |
+
#lambda set mid_generation to true
|
1446 |
+
if args.run_name:
|
1447 |
+
frozen_directory = os.path.join(args.output_dir, f"frozen_text_encoder_{args.run_name}")
|
1448 |
+
else:
|
1449 |
+
frozen_directory = os.path.join(args.output_dir, "frozen_text_encoder")
|
1450 |
+
|
1451 |
+
unet_stats = {}
|
1452 |
+
discriminator_stats = {}
|
1453 |
+
|
1454 |
+
os.makedirs(main_sample_dir, exist_ok=True)
|
1455 |
+
with open(os.path.join(main_sample_dir, "args.json"), "w") as f:
|
1456 |
+
json.dump(args.__dict__, f, indent=2)
|
1457 |
+
if args.with_gan:
|
1458 |
+
with open(os.path.join(main_sample_dir, "discriminator_config.json"), "w") as f:
|
1459 |
+
json.dump(discriminator.config, f, indent=2)
|
1460 |
+
|
1461 |
+
for epoch in range(args.num_train_epochs):
|
1462 |
+
#every 10 epochs print instructions
|
1463 |
+
unet.train()
|
1464 |
+
if args.train_text_encoder:
|
1465 |
+
text_encoder.train()
|
1466 |
+
|
1467 |
+
#generate samples from the initial weights before any training steps
|
1468 |
+
if args.sample_on_training_start==True and epoch==0:
|
1469 |
+
save_and_sample_weights(epoch,'start',save_model=False)
|
1470 |
+
|
1471 |
+
if args.train_text_encoder and args.stop_text_encoder_training == epoch:
|
1472 |
+
args.stop_text_encoder_training = True
|
1473 |
+
if accelerator.is_main_process:
|
1474 |
+
tqdm.write(f"{bcolors.WARNING} Stopping text encoder training{bcolors.ENDC}")
|
1475 |
+
current_percentage = (epoch/args.num_train_epochs)*100
|
1476 |
+
#round to the nearest whole number
|
1477 |
+
current_percentage = round(current_percentage,0)
|
1478 |
+
try:
|
1479 |
+
send_telegram_message(f"Text encoder training stopped at epoch {epoch} which is {current_percentage}% of training. Freezing weights and saving.", args.telegram_chat_id, args.telegram_token)
|
1480 |
+
except:
|
1481 |
+
pass
|
1482 |
+
if os.path.exists(frozen_directory):
|
1483 |
+
#delete the folder if it already exists
|
1484 |
+
shutil.rmtree(frozen_directory)
|
1485 |
+
os.mkdir(frozen_directory)
|
1486 |
+
save_and_sample_weights(epoch,'epoch')
|
1487 |
+
args.stop_text_encoder_training = epoch
|
1488 |
+
progress_bar_inter_epoch.reset(total=num_update_steps_per_epoch)
|
1489 |
+
for step, batch in enumerate(train_dataloader):
|
1490 |
+
with accelerator.accumulate(unet):
|
1491 |
+
# Convert images to latent space
|
1492 |
+
with torch.no_grad():
|
1493 |
+
|
1494 |
+
latent_dist = batch[0][0]
|
1495 |
+
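# 0.18215 is the Stable Diffusion VAE latent scaling factor.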
latents = latent_dist.sample() * 0.18215
|
1496 |
+
|
1497 |
+
if args.model_variant == 'inpainting':
|
1498 |
+
mask = batch[0][2]
|
1499 |
+
mask_mean = batch[0][3]
|
1500 |
+
conditioning_latent_dist = batch[0][4]
|
1501 |
+
conditioning_latents = conditioning_latent_dist.sample() * 0.18215
|
1502 |
+
if args.model_variant == 'depth2img':
|
1503 |
+
depth = batch[0][4]
|
1504 |
+
if args.sample_from_batch > 0:
|
1505 |
+
args.batch_tokens = batch[0][5]
|
1506 |
+
# Sample noise that we'll add to the latents
|
1507 |
+
# plus optional offset noise so the model learns to shift the zero-frequency (mean) component of the image freely
|
1508 |
+
# https://www.crosslabs.org/blog/diffusion-with-offset-noise
|
1509 |
+
if (args.with_offset_noise == True):
|
1510 |
+
noise = torch.randn_like(latents) + (args.offset_noise_weight * torch.randn(latents.shape[0], latents.shape[1], 1, 1).to(accelerator.device))
|
1511 |
+
else:
|
1512 |
+
noise = torch.randn_like(latents)
|
1513 |
+
|
1514 |
+
bsz = latents.shape[0]
|
1515 |
+
# Sample a random timestep for each image
|
1516 |
+
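# max_denoising_strength < 1.0 limits sampling to the lower (less noisy) part of the schedule.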
timesteps = torch.randint(0, int(noise_scheduler.config.num_train_timesteps * args.max_denoising_strength), (bsz,), device=latents.device)
|
1517 |
+
timesteps = timesteps.long()
|
1518 |
+
|
1519 |
+
# Add noise to the latents according to the noise magnitude at each timestep
|
1520 |
+
# (this is the forward diffusion process)
|
1521 |
+
noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
|
1522 |
+
|
1523 |
+
# Get the text embedding for conditioning
|
1524 |
+
with text_enc_context:
|
1525 |
+
if args.train_text_encoder:
|
1526 |
+
if args.clip_penultimate == True:
|
1527 |
+
encoder_hidden_states = text_encoder(batch[0][1],output_hidden_states=True)
|
1528 |
+
encoder_hidden_states = text_encoder.text_model.final_layer_norm(encoder_hidden_states['hidden_states'][-2])
|
1529 |
+
else:
|
1530 |
+
encoder_hidden_states = text_encoder(batch[0][1])[0]
|
1531 |
+
else:
|
1532 |
+
encoder_hidden_states = batch[0][1]
|
1533 |
+
|
1534 |
+
|
1535 |
+
# Predict the noise residual
|
1536 |
+
mask=None
|
1537 |
+
if args.model_variant == 'inpainting':
|
1538 |
+
if mask is not None and random.uniform(0, 1) < args.unmasked_probability:
|
1539 |
+
# for some steps, predict the unmasked image
|
1540 |
+
conditioning_latents = torch.stack([full_mask_by_aspect[tuple([latents.shape[3]*8, latents.shape[2]*8])].squeeze()] * bsz)
|
1541 |
+
mask = torch.ones(bsz, 1, latents.shape[2], latents.shape[3]).to(accelerator.device, dtype=weight_dtype)
|
1542 |
+
noisy_inpaint_latents = torch.concat([noisy_latents, mask, conditioning_latents], 1)
|
1543 |
+
model_pred = unet(noisy_inpaint_latents, timesteps, encoder_hidden_states).sample
|
1544 |
+
elif args.model_variant == 'depth2img':
|
1545 |
+
noisy_depth_latents = torch.cat([noisy_latents, depth], dim=1)
|
1546 |
+
model_pred = unet(noisy_depth_latents, timesteps, encoder_hidden_states, depth).sample
|
1547 |
+
elif args.model_variant == "base":
|
1548 |
+
model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
|
1549 |
+
|
1550 |
+
|
1551 |
+
# Get the target for loss depending on the prediction type
|
1552 |
+
if noise_scheduler.config.prediction_type == "epsilon":
|
1553 |
+
target = noise
|
1554 |
+
elif noise_scheduler.config.prediction_type == "v_prediction":
|
1555 |
+
target = noise_scheduler.get_velocity(latents, noise, timesteps)
|
1556 |
+
else:
|
1557 |
+
raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
|
1558 |
+
|
1559 |
+
# GAN stuff
|
1560 |
+
# Input: noisy_latents
|
1561 |
+
# True output: target
|
1562 |
+
# Fake output: model_pred
|
1563 |
+
|
1564 |
+
if args.with_gan:
|
1565 |
+
# Turn on learning for the discriminator, and do an optimization step
|
1566 |
+
for param in discriminator.parameters():
|
1567 |
+
param.requires_grad = True
|
1568 |
+
|
1569 |
+
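# Least-squares GAN objective: the discriminator, conditioned on the noisy latents and text
# embedding, is trained toward 0 on the UNet's (detached) prediction and toward 1 on the true target.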
pred_fake = discriminator(torch.cat((noisy_latents, model_pred), 1).detach(), encoder_hidden_states)
|
1570 |
+
pred_real = discriminator(torch.cat((noisy_latents, target), 1), encoder_hidden_states)
|
1571 |
+
discriminator_loss = F.mse_loss(pred_fake, torch.zeros_like(pred_fake), reduction="mean") + F.mse_loss(pred_real, torch.ones_like(pred_real), reduction="mean")
|
1572 |
+
if discriminator_loss.isnan():
|
1573 |
+
tqdm.write(f"{bcolors.WARNING}Discriminator loss is NAN, skipping GAN update.{bcolors.ENDC}")
|
1574 |
+
else:
|
1575 |
+
accelerator.backward(discriminator_loss)
|
1576 |
+
if accelerator.sync_gradients:
|
1577 |
+
accelerator.clip_grad_norm_(discriminator.parameters(), args.max_grad_norm)
|
1578 |
+
optimizer_discriminator.step()
|
1579 |
+
lr_scheduler_discriminator.step()
|
1580 |
+
# Hack to fix NaNs caused by GAN training
|
1581 |
+
for name, p in discriminator.named_parameters():
|
1582 |
+
if p.isnan().any():
|
1583 |
+
fix_nans_(p, name, discriminator_stats[name])
|
1584 |
+
else:
|
1585 |
+
(std, mean) = torch.std_mean(p)
|
1586 |
+
discriminator_stats[name] = (std.item(), mean.item())
|
1587 |
+
del std, mean
|
1588 |
+
optimizer_discriminator.zero_grad()
|
1589 |
+
del pred_real, pred_fake, discriminator_loss
|
1590 |
+
|
1591 |
+
# Turn off learning for the discriminator for the generator optimization step
|
1592 |
+
for param in discriminator.parameters():
|
1593 |
+
param.requires_grad = False
|
1594 |
+
|
1595 |
+
if args.with_prior_preservation:
|
1596 |
+
# Chunk the noise and noise_pred into two parts and compute the loss on each part separately.
|
1597 |
+
"""
|
1598 |
+
noise_pred, noise_pred_prior = torch.chunk(noise_pred, 2, dim=0)
|
1599 |
+
noise, noise_prior = torch.chunk(noise, 2, dim=0)
|
1600 |
+
|
1601 |
+
# Compute instance loss
|
1602 |
+
loss = F.mse_loss(noise_pred.float(), noise.float(), reduction="none").mean([1, 2, 3]).mean()
|
1603 |
+
|
1604 |
+
# Compute prior loss
|
1605 |
+
prior_loss = F.mse_loss(noise_pred_prior.float(), noise_prior.float(), reduction="mean")
|
1606 |
+
|
1607 |
+
# Add the prior loss to the instance loss.
|
1608 |
+
loss = loss + args.prior_loss_weight * prior_loss
|
1609 |
+
"""
|
1610 |
+
# Chunk the noise and model_pred into two parts and compute the loss on each part separately.
|
1611 |
+
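# The first half of the batch is instance data and the second half class data (see collate_fn).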
model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)
|
1612 |
+
target, target_prior = torch.chunk(target, 2, dim=0)
|
1613 |
+
if mask is not None and args.model_variant != "inpainting":
|
1614 |
+
loss = masked_mse_loss(model_pred.float(), target.float(), mask, reduction="none").mean([1, 2, 3]).mean()
|
1615 |
+
prior_loss = masked_mse_loss(model_pred_prior.float(), target_prior.float(), mask, reduction="mean")
|
1616 |
+
else:
|
1617 |
+
loss = F.mse_loss(model_pred.float(), target.float(), reduction="none").mean([1, 2, 3]).mean()
|
1618 |
+
prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean")
|
1619 |
+
|
1620 |
+
# Add the prior loss to the instance loss.
|
1621 |
+
loss = loss + args.prior_loss_weight * prior_loss
|
1622 |
+
|
1623 |
+
if mask is not None and args.normalize_masked_area_loss:
|
1624 |
+
loss = loss / mask_mean
|
1625 |
+
|
1626 |
+
else:
|
1627 |
+
if mask is not None and args.model_variant != "inpainting":
|
1628 |
+
loss = masked_mse_loss(model_pred.float(), target.float(), mask, reduction="none").mean([1, 2, 3])
|
1629 |
+
loss = loss.mean()
|
1630 |
+
else:
|
1631 |
+
loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
|
1632 |
+
|
1633 |
+
if mask is not None and args.normalize_masked_area_loss:
|
1634 |
+
loss = loss / mask_mean
|
1635 |
+
|
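                    # The torch.chunk(..., 2, dim=0) calls above assume each batch is laid out as
                    # [instance samples | class/prior samples] along the batch dimension, so for a
                    # batch of 2N latents the first N drive the instance loss and the last N drive
                    # prior_loss, which is then scaled by args.prior_loss_weight.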
                    base_loss = loss

                    if args.with_gan:
                        # Add loss from the GAN
                        pred_fake = discriminator(torch.cat((noisy_latents, model_pred), 1), encoder_hidden_states)
                        gan_loss = F.mse_loss(pred_fake, torch.ones_like(pred_fake), reduction="mean")
                        if gan_loss.isnan():
                            tqdm.write(f"{bcolors.WARNING}GAN loss is NAN, skipping GAN loss.{bcolors.ENDC}")
                        else:
                            gan_weight = args.gan_weight
                            if args.gan_warmup and global_step < args.gan_warmup:
                                gan_weight *= global_step / args.gan_warmup
                            loss += gan_weight * gan_loss
                        del pred_fake

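                    # args.gan_warmup linearly ramps the adversarial term in: with e.g. args.gan_warmup = 1000,
                    # the effective weight grows from 0 up to args.gan_weight over the first 1000 global steps,
                    # presumably so the still-untrained discriminator cannot destabilize the UNet early on.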
                    accelerator.backward(loss)
                    if accelerator.sync_gradients:
                        params_to_clip = (
                            itertools.chain(unet.parameters(), text_encoder.parameters())
                            if args.train_text_encoder
                            else unet.parameters()
                        )
                        accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
                    optimizer.step()
                    lr_scheduler.step()
                    # Hack to fix NaNs caused by GAN training
                    for name, p in unet.named_parameters():
                        if p.isnan().any():
                            fix_nans_(p, name, unet_stats[name])
                        else:
                            (std, mean) = torch.std_mean(p)
                            unet_stats[name] = (std.item(), mean.item())
                            del std, mean
                    optimizer.zero_grad()
                    loss_avg.update(base_loss.detach_())
                    if args.with_gan and not gan_loss.isnan():
                        gan_loss_avg.update(gan_loss.detach_())
                    if args.use_ema == True:
                        update_ema(ema_unet, unet)

                    del loss, model_pred
                    if args.with_prior_preservation:
                        del model_pred_prior

                logs = {"loss": loss_avg.avg.item(), "lr": lr_scheduler.get_last_lr()[0]}
                if args.with_gan:
                    logs["gan_loss"] = gan_loss_avg.avg.item()
                progress_bar.set_postfix(**logs)
                if not global_step % args.log_interval:
                    accelerator.log(logs, step=global_step)



                if global_step > 0 and not global_step % args.sample_step_interval:
                    save_and_sample_weights(global_step,'step',save_model=False)

                progress_bar.update(1)
                progress_bar_inter_epoch.update(1)
                progress_bar_e.refresh()
                global_step += 1

                if mid_quit_step==True:
                    accelerator.wait_for_everyone()
                    save_and_sample_weights(global_step,'quit_step')
                    quit()
                if mid_generation==True:
                    mid_train_playground(global_step)
                    mid_generation=False
                if mid_checkpoint_step == True:
                    save_and_sample_weights(global_step,'step',save_model=True)
                    mid_checkpoint_step=False
                    mid_sample_step=False
                elif mid_sample_step == True:
                    save_and_sample_weights(global_step,'step',save_model=False)
                    mid_sample_step=False
                if global_step >= args.max_train_steps:
                    break
            progress_bar_e.update(1)
            if mid_quit==True:
                accelerator.wait_for_everyone()
                save_and_sample_weights(epoch,'quit_epoch')
                quit()
            if epoch == args.num_train_epochs - 1:
                save_and_sample_weights(epoch,'epoch',True)
            elif args.save_every_n_epoch and (epoch + 1) % args.save_every_n_epoch == 0:
                save_and_sample_weights(epoch,'epoch',True)
            elif mid_checkpoint==True:
                save_and_sample_weights(epoch,'epoch',True)
                mid_checkpoint=False
                mid_sample=False
            elif mid_sample==True:
                save_and_sample_weights(epoch,'epoch',False)
                mid_sample=False
            accelerator.wait_for_everyone()
    except Exception:
        try:
            send_telegram_message("Something went wrong while training! :(", args.telegram_chat_id, args.telegram_token)
            #save_and_sample_weights(global_step,'checkpoint')
            send_telegram_message(f"Saved checkpoint {global_step} on exit", args.telegram_chat_id, args.telegram_token)
        except Exception:
            pass
        raise
    except KeyboardInterrupt:
        send_telegram_message("Training stopped", args.telegram_chat_id, args.telegram_token)
        try:
            send_telegram_message("Training finished!", args.telegram_chat_id, args.telegram_token)
        except:
            pass

    accelerator.end_training()



if __name__ == "__main__":
    main()
StableTuner_RunPod_Fix/trainer_util.py
ADDED
@@ -0,0 +1,435 @@
import gradio as gr
import json
import math
from pathlib import Path
from typing import Optional
import torch
import torch.nn.functional as F
import torch.utils.checkpoint
from accelerate.logging import get_logger
from accelerate.utils import set_seed
from diffusers import AutoencoderKL, DDIMScheduler, DDPMScheduler, DiffusionPipeline, UNet2DConditionModel, DPMSolverMultistepScheduler, EulerDiscreteScheduler
from diffusers.optimization import get_scheduler
from huggingface_hub import HfFolder, Repository, whoami
from torchvision import transforms
from tqdm.auto import tqdm
from typing import Dict, List, Generator, Tuple
from PIL import Image, ImageFile, ImageFilter
from collections.abc import Iterable
from dataloaders_util import *
# used further down in this module (FlashAttention helpers, model class lookup, depth maps)
import diffusers
from einops import rearrange
from torch import einsum
from transformers import PretrainedConfig

# FlashAttention based on https://github.com/lucidrains/memory-efficient-attention-pytorch/blob/main/memory_efficient_attention_pytorch/flash_attention.py
# LICENSE MIT https://github.com/lucidrains/memory-efficient-attention-pytorch/blob/main/LICENSE

# constants
EPSILON = 1e-6

class bcolors:
    HEADER = '\033[95m'
    OKBLUE = '\033[94m'
    OKCYAN = '\033[96m'
    OKGREEN = '\033[92m'
    WARNING = '\033[93m'
    FAIL = '\033[91m'
    ENDC = '\033[0m'
    BOLD = '\033[1m'
    UNDERLINE = '\033[4m'

# helper functions
def print_instructions():
    tqdm.write(f"{bcolors.WARNING}Use 'CTRL+SHIFT+G' to open up a GUI to play around with the model (will pause training){bcolors.ENDC}")
    tqdm.write(f"{bcolors.WARNING}Use 'CTRL+SHIFT+S' to save a checkpoint of the current epoch{bcolors.ENDC}")
    tqdm.write(f"{bcolors.WARNING}Use 'CTRL+SHIFT+P' to generate samples for current epoch{bcolors.ENDC}")
    tqdm.write(f"{bcolors.WARNING}Use 'CTRL+SHIFT+Q' to save and quit after the current epoch{bcolors.ENDC}")
    tqdm.write(f"{bcolors.WARNING}Use 'CTRL+SHIFT+ALT+S' to save a checkpoint of the current step{bcolors.ENDC}")
    tqdm.write(f"{bcolors.WARNING}Use 'CTRL+SHIFT+ALT+P' to generate samples for current step{bcolors.ENDC}")
    tqdm.write(f"{bcolors.WARNING}Use 'CTRL+SHIFT+ALT+Q' to save and quit after the current step{bcolors.ENDC}")
    tqdm.write('')
    tqdm.write(f"{bcolors.WARNING}Use 'CTRL+H' to print this message again.{bcolors.ENDC}")

def get_full_repo_name(model_id: str, organization: Optional[str] = None, token: Optional[str] = None):
    if token is None:
        token = HfFolder.get_token()
    if organization is None:
        username = whoami(token)["name"]
        return f"{username}/{model_id}"
    else:
        return f"{organization}/{model_id}"

# function to format a dictionary into a telegram message
def format_dict(d):
    message = ""
    for key, value in d.items():
        # filter keys that have the word "token" in them
        if "token" in key and "tokenizer" not in key:
            value = "TOKEN"
        if 'id' in key:
            value = "ID"
        # if value is a dictionary, format it recursively
        if isinstance(value, dict):
            for k, v in value.items():
                message += f"\n- {k}: <b>{v}</b> \n"
        elif isinstance(value, list):
            # each value is a new line in the message
            message += f"- {key}:\n\n"
            for v in value:
                message += f" <b>{v}</b>\n\n"
        # if value is a list, format it as a list
        else:
            message += f"- {key}: <b>{value}</b>\n"
    return message

def send_telegram_message(message, chat_id, token):
    url = f"https://api.telegram.org/bot{token}/sendMessage?chat_id={chat_id}&text={message}&parse_mode=html&disable_notification=True"
    import requests
    req = requests.get(url)
    if req.status_code != 200:
        raise ValueError(f"Telegram request failed with status code {req.status_code}")

def send_media_group(chat_id, telegram_token, images, caption=None, reply_to_message_id=None):
    """
    Use this method to send an album of photos. On success, an array of Messages that were sent is returned.
    :param chat_id: chat id
    :param images: list of PIL images to send
    :param caption: caption of image
    :param reply_to_message_id: If the message is a reply, ID of the original message
    :return: response with the sent message
    """
    SEND_MEDIA_GROUP = f'https://api.telegram.org/bot{telegram_token}/sendMediaGroup'
    from io import BytesIO
    import requests
    files = {}
    media = []
    for i, img in enumerate(images):
        with BytesIO() as output:
            img.save(output, format='PNG')
            output.seek(0)
            name = f'photo{i}'
            files[name] = output.read()
        # a list of InputMediaPhoto. attach refers to the name of the file in the files dict
        media.append(dict(type='photo', media=f'attach://{name}'))
    media[0]['caption'] = caption
    media[0]['parse_mode'] = 'HTML'
    return requests.post(SEND_MEDIA_GROUP, data={'chat_id': chat_id, 'media': json.dumps(media), 'disable_notification': True, 'reply_to_message_id': reply_to_message_id}, files=files)

class AverageMeter:
    def __init__(self, name=None, max_eta=None):
        self.name = name
        self.max_eta = max_eta
        self.reset()

    def reset(self):
        self.count = self.avg = 0

    @torch.no_grad()
    def update(self, val, n=1):
        eta = self.count / (self.count + n)
        if self.max_eta:
            eta = min(eta, self.max_eta ** n)
        self.avg += (1 - eta) * (val - self.avg)
        self.count += n

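# Quick illustration of AverageMeter (values here are just for the example):
#
#   meter = AverageMeter(name="loss")
#   for v in (1.0, 2.0, 3.0):
#       meter.update(torch.tensor(v))
#   # meter.avg is now the running mean, tensor(2.); with max_eta set (e.g. 0.999) the
#   # update rule caps eta and the meter behaves like an exponential moving average.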
def exists(val):
    return val is not None


def default(val, d):
    return val if exists(val) else d


def masked_mse_loss(predicted, target, mask, reduction="none"):
    masked_predicted = predicted * mask
    masked_target = target * mask
    return F.mse_loss(masked_predicted, masked_target, reduction=reduction)

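# Example of how masked_mse_loss is typically called from the trainer (shapes are illustrative):
#
#   pred = torch.randn(2, 4, 64, 64)    # model prediction in latent space
#   target = torch.randn(2, 4, 64, 64)  # epsilon or v target
#   mask = torch.ones(2, 1, 64, 64)     # 1 = keep, 0 = ignore; broadcast over channels
#   loss = masked_mse_loss(pred, target, mask, reduction="none").mean([1, 2, 3]).mean()
#
# Zeroing both inputs outside the mask means masked-out regions contribute nothing to the
# squared error, which is why the trainer optionally renormalizes by mask_mean afterwards.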
# flash attention forwards and backwards
# https://arxiv.org/abs/2205.14135


class FlashAttentionFunction(torch.autograd.function.Function):
    @staticmethod
    @torch.no_grad()
    def forward(ctx, q, k, v, mask, causal, q_bucket_size, k_bucket_size):
        """ Algorithm 2 in the paper """

        device = q.device
        dtype = q.dtype
        max_neg_value = -torch.finfo(q.dtype).max
        qk_len_diff = max(k.shape[-2] - q.shape[-2], 0)

        o = torch.zeros_like(q)
        all_row_sums = torch.zeros(
            (*q.shape[:-1], 1), dtype=dtype, device=device)
        all_row_maxes = torch.full(
            (*q.shape[:-1], 1), max_neg_value, dtype=dtype, device=device)

        scale = (q.shape[-1] ** -0.5)

        if not exists(mask):
            mask = (None,) * math.ceil(q.shape[-2] / q_bucket_size)
        else:
            mask = rearrange(mask, 'b n -> b 1 1 n')
            mask = mask.split(q_bucket_size, dim=-1)

        row_splits = zip(
            q.split(q_bucket_size, dim=-2),
            o.split(q_bucket_size, dim=-2),
            mask,
            all_row_sums.split(q_bucket_size, dim=-2),
            all_row_maxes.split(q_bucket_size, dim=-2),
        )

        for ind, (qc, oc, row_mask, row_sums, row_maxes) in enumerate(row_splits):
            q_start_index = ind * q_bucket_size - qk_len_diff

            col_splits = zip(
                k.split(k_bucket_size, dim=-2),
                v.split(k_bucket_size, dim=-2),
            )

            for k_ind, (kc, vc) in enumerate(col_splits):
                k_start_index = k_ind * k_bucket_size

                attn_weights = einsum(
                    '... i d, ... j d -> ... i j', qc, kc) * scale

                if exists(row_mask):
                    attn_weights.masked_fill_(~row_mask, max_neg_value)

                if causal and q_start_index < (k_start_index + k_bucket_size - 1):
                    causal_mask = torch.ones((qc.shape[-2], kc.shape[-2]), dtype=torch.bool,
                                             device=device).triu(q_start_index - k_start_index + 1)
                    attn_weights.masked_fill_(causal_mask, max_neg_value)

                block_row_maxes = attn_weights.amax(dim=-1, keepdims=True)
                attn_weights -= block_row_maxes
                exp_weights = torch.exp(attn_weights)

                if exists(row_mask):
                    exp_weights.masked_fill_(~row_mask, 0.)

                block_row_sums = exp_weights.sum(
                    dim=-1, keepdims=True).clamp(min=EPSILON)

                new_row_maxes = torch.maximum(block_row_maxes, row_maxes)

                exp_values = einsum(
                    '... i j, ... j d -> ... i d', exp_weights, vc)

                exp_row_max_diff = torch.exp(row_maxes - new_row_maxes)
                exp_block_row_max_diff = torch.exp(
                    block_row_maxes - new_row_maxes)

                new_row_sums = exp_row_max_diff * row_sums + \
                    exp_block_row_max_diff * block_row_sums

                oc.mul_((row_sums / new_row_sums) * exp_row_max_diff).add_(
                    (exp_block_row_max_diff / new_row_sums) * exp_values)

                row_maxes.copy_(new_row_maxes)
                row_sums.copy_(new_row_sums)

        ctx.args = (causal, scale, mask, q_bucket_size, k_bucket_size)
        ctx.save_for_backward(q, k, v, o, all_row_sums, all_row_maxes)

        return o

    @staticmethod
    @torch.no_grad()
    def backward(ctx, do):
        """ Algorithm 4 in the paper """

        causal, scale, mask, q_bucket_size, k_bucket_size = ctx.args
        q, k, v, o, l, m = ctx.saved_tensors

        device = q.device

        max_neg_value = -torch.finfo(q.dtype).max
        qk_len_diff = max(k.shape[-2] - q.shape[-2], 0)

        dq = torch.zeros_like(q)
        dk = torch.zeros_like(k)
        dv = torch.zeros_like(v)

        row_splits = zip(
            q.split(q_bucket_size, dim=-2),
            o.split(q_bucket_size, dim=-2),
            do.split(q_bucket_size, dim=-2),
            mask,
            l.split(q_bucket_size, dim=-2),
            m.split(q_bucket_size, dim=-2),
            dq.split(q_bucket_size, dim=-2)
        )

        for ind, (qc, oc, doc, row_mask, lc, mc, dqc) in enumerate(row_splits):
            q_start_index = ind * q_bucket_size - qk_len_diff

            col_splits = zip(
                k.split(k_bucket_size, dim=-2),
                v.split(k_bucket_size, dim=-2),
                dk.split(k_bucket_size, dim=-2),
                dv.split(k_bucket_size, dim=-2),
            )

            for k_ind, (kc, vc, dkc, dvc) in enumerate(col_splits):
                k_start_index = k_ind * k_bucket_size

                attn_weights = einsum(
                    '... i d, ... j d -> ... i j', qc, kc) * scale

                if causal and q_start_index < (k_start_index + k_bucket_size - 1):
                    causal_mask = torch.ones((qc.shape[-2], kc.shape[-2]), dtype=torch.bool,
                                             device=device).triu(q_start_index - k_start_index + 1)
                    attn_weights.masked_fill_(causal_mask, max_neg_value)

                exp_attn_weights = torch.exp(attn_weights - mc)

                if exists(row_mask):
                    exp_attn_weights.masked_fill_(~row_mask, 0.)

                p = exp_attn_weights / lc

                dv_chunk = einsum('... i j, ... i d -> ... j d', p, doc)
                dp = einsum('... i d, ... j d -> ... i j', doc, vc)

                D = (doc * oc).sum(dim=-1, keepdims=True)
                ds = p * scale * (dp - D)

                dq_chunk = einsum('... i j, ... j d -> ... i d', ds, kc)
                dk_chunk = einsum('... i j, ... i d -> ... j d', ds, qc)

                dqc.add_(dq_chunk)
                dkc.add_(dk_chunk)
                dvc.add_(dv_chunk)

        return dq, dk, dv, None, None, None, None

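# Minimal sketch of calling the autograd function directly (bucket sizes and shapes are example
# values; in this file it is wired up through forward_flash_attn below):
#
#   q = torch.randn(1, 8, 4096, 40)  # (batch, heads, tokens, head_dim)
#   k = torch.randn(1, 8, 77, 40)
#   v = torch.randn(1, 8, 77, 40)
#   out = FlashAttentionFunction.apply(q, k, v, None, False, 512, 1024)
#
# out has the same shape as q; attention is computed tile by tile, so the full (4096 x 77)
# attention matrix is never materialized at once.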
def import_model_class_from_model_name_or_path(pretrained_model_name_or_path: str, revision: str):
    text_encoder_config = PretrainedConfig.from_pretrained(
        pretrained_model_name_or_path,
        subfolder="text_encoder",
        revision=revision,
    )
    model_class = text_encoder_config.architectures[0]

    if model_class == "CLIPTextModel":
        from transformers import CLIPTextModel

        return CLIPTextModel
    elif model_class == "RobertaSeriesModelWithTransformation":
        from diffusers.pipelines.alt_diffusion.modeling_roberta_series import RobertaSeriesModelWithTransformation

        return RobertaSeriesModelWithTransformation
    else:
        raise ValueError(f"{model_class} is not supported.")

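# Example usage (the model id is a placeholder):
#
#   text_encoder_cls = import_model_class_from_model_name_or_path("runwayml/stable-diffusion-v1-5", None)
#   text_encoder = text_encoder_cls.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="text_encoder")
#
# The function only reads the text_encoder config and returns the matching transformers class;
# instantiating the weights is left to the caller.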
def replace_unet_cross_attn_to_flash_attention():
    print("Using FlashAttention")

    def forward_flash_attn(self, x, context=None, mask=None):
        q_bucket_size = 512
        k_bucket_size = 1024

        h = self.heads
        q = self.to_q(x)

        context = context if context is not None else x
        context = context.to(x.dtype)

        if hasattr(self, 'hypernetwork') and self.hypernetwork is not None:
            context_k, context_v = self.hypernetwork.forward(x, context)
            context_k = context_k.to(x.dtype)
            context_v = context_v.to(x.dtype)
        else:
            context_k = context
            context_v = context

        k = self.to_k(context_k)
        v = self.to_v(context_v)
        del context, x

        q, k, v = map(lambda t: rearrange(
            t, 'b n (h d) -> b h n d', h=h), (q, k, v))

        out = FlashAttentionFunction.apply(q, k, v, mask, False,
                                           q_bucket_size, k_bucket_size)

        out = rearrange(out, 'b h n d -> b n (h d)')

        # diffusers 0.6.0
        if type(self.to_out) is torch.nn.Sequential:
            return self.to_out(out)

        # diffusers 0.7.0
        out = self.to_out[0](out)
        out = self.to_out[1](out)
        return out

    diffusers.models.attention.CrossAttention.forward = forward_flash_attn

class Depth2Img:
    def __init__(self, unet, text_encoder, revision, pretrained_model_name_or_path, accelerator):
        self.unet = unet
        self.text_encoder = text_encoder
        self.revision = revision if revision != 'no' else 'fp32'
        self.pretrained_model_name_or_path = pretrained_model_name_or_path
        self.accelerator = accelerator
        self.pipeline = None

    def depth_images(self, paths):
        if self.pipeline is None:
            self.pipeline = DiffusionPipeline.from_pretrained(
                self.pretrained_model_name_or_path,
                unet=self.accelerator.unwrap_model(self.unet),
                text_encoder=self.accelerator.unwrap_model(self.text_encoder),
                revision=self.revision,
                local_files_only=True,)
            self.pipeline.to(self.accelerator.device)
        self.vae_scale_factor = 2 ** (len(self.pipeline.vae.config.block_out_channels) - 1)
        non_depth_image_files = []
        image_paths_by_path = {}

        for path in paths:
            # if path is a list
            if isinstance(path, list):
                img = Path(path[0])
            else:
                img = Path(path)
            if self.get_depth_image_path(img).exists():
                continue
            else:
                non_depth_image_files.append(img)
        image_objects = []
        for image_path in non_depth_image_files:
            image_instance = Image.open(image_path)
            if not image_instance.mode == "RGB":
                image_instance = image_instance.convert("RGB")
            image_instance = self.pipeline.feature_extractor(
                image_instance, return_tensors="pt"
            ).pixel_values

            image_instance = image_instance.to(self.accelerator.device)
            image_objects.append((image_path, image_instance))

        for image_path, image_instance in image_objects:
            path = image_path.parent
            ogImg = Image.open(image_path)
            ogImg_x = ogImg.size[0]
            ogImg_y = ogImg.size[1]
            depth_map = self.pipeline.depth_estimator(image_instance).predicted_depth
            depth_min = torch.amin(depth_map, dim=[0, 1, 2], keepdim=True)
            depth_max = torch.amax(depth_map, dim=[0, 1, 2], keepdim=True)
            depth_map = torch.nn.functional.interpolate(depth_map.unsqueeze(1), size=(ogImg_y, ogImg_x), mode="bicubic", align_corners=False,)

            depth_map = 2.0 * (depth_map - depth_min) / (depth_max - depth_min) - 1.0
            depth_map = depth_map[0, :, :]
            depth_map_image = transforms.ToPILImage()(depth_map)
            depth_map_image = depth_map_image.filter(ImageFilter.GaussianBlur(radius=1))
            depth_map_image.save(self.get_depth_image_path(image_path))
            #quit()
        return 2 ** (len(self.pipeline.vae.config.block_out_channels) - 1)

    def get_depth_image_path(self, image_path):
        # if image_path is a string, convert it to a Path object
        if isinstance(image_path, str):
            image_path = Path(image_path)
        return image_path.parent / f"{image_path.stem}-depth.png"

def fix_nans_(param, name=None, stats=None):
    (std, mean) = stats or (1, 0)
    tqdm.write(f"Fixing NaNs in {name}: shape {tuple(param.shape)}, dtype {param.dtype}, mean {mean}, std {std}")
    param.data = torch.where(param.data.isnan(), torch.randn_like(param.data) * std + mean, param.data).detach()
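# How the trainer uses fix_nans_: while parameters are finite it records (std, mean) per tensor
# (unet_stats[name] / discriminator_stats[name]); if a later step produces NaNs, the NaN entries
# are re-drawn from a normal distribution with those recorded statistics instead of aborting the run.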