John6666 commited on
Commit
c200633
·
verified ·
1 Parent(s): 1d94112

Delete convert_repo_to_safetensors.py

Browse files
Files changed (1) hide show
  1. convert_repo_to_safetensors.py +0 -366
convert_repo_to_safetensors.py DELETED
@@ -1,366 +0,0 @@
1
- # Script for converting a HF Diffusers saved pipeline to a Stable Diffusion checkpoint.
2
- # *Only* converts the UNet, VAE, and Text Encoder.
3
- # Does not convert optimizer state or any other thing.
4
-
5
- import argparse
6
- import os.path as osp
7
- import re
8
-
9
- import torch
10
- from safetensors.torch import load_file, save_file
11
-
12
-
13
- # =================#
14
- # UNet Conversion #
15
- # =================#
16
-
17
- unet_conversion_map = [
18
- # (stable-diffusion, HF Diffusers)
19
- ("time_embed.0.weight", "time_embedding.linear_1.weight"),
20
- ("time_embed.0.bias", "time_embedding.linear_1.bias"),
21
- ("time_embed.2.weight", "time_embedding.linear_2.weight"),
22
- ("time_embed.2.bias", "time_embedding.linear_2.bias"),
23
- ("input_blocks.0.0.weight", "conv_in.weight"),
24
- ("input_blocks.0.0.bias", "conv_in.bias"),
25
- ("out.0.weight", "conv_norm_out.weight"),
26
- ("out.0.bias", "conv_norm_out.bias"),
27
- ("out.2.weight", "conv_out.weight"),
28
- ("out.2.bias", "conv_out.bias"),
29
- # the following are for sdxl
30
- ("label_emb.0.0.weight", "add_embedding.linear_1.weight"),
31
- ("label_emb.0.0.bias", "add_embedding.linear_1.bias"),
32
- ("label_emb.0.2.weight", "add_embedding.linear_2.weight"),
33
- ("label_emb.0.2.bias", "add_embedding.linear_2.bias"),
34
- ]
35
-
36
- unet_conversion_map_resnet = [
37
- # (stable-diffusion, HF Diffusers)
38
- ("in_layers.0", "norm1"),
39
- ("in_layers.2", "conv1"),
40
- ("out_layers.0", "norm2"),
41
- ("out_layers.3", "conv2"),
42
- ("emb_layers.1", "time_emb_proj"),
43
- ("skip_connection", "conv_shortcut"),
44
- ]
45
-
46
- unet_conversion_map_layer = []
47
- # hardcoded number of downblocks and resnets/attentions...
48
- # would need smarter logic for other networks.
49
- for i in range(3):
50
- # loop over downblocks/upblocks
51
-
52
- for j in range(2):
53
- # loop over resnets/attentions for downblocks
54
- hf_down_res_prefix = f"down_blocks.{i}.resnets.{j}."
55
- sd_down_res_prefix = f"input_blocks.{3*i + j + 1}.0."
56
- unet_conversion_map_layer.append((sd_down_res_prefix, hf_down_res_prefix))
57
-
58
- if i > 0:
59
- hf_down_atn_prefix = f"down_blocks.{i}.attentions.{j}."
60
- sd_down_atn_prefix = f"input_blocks.{3*i + j + 1}.1."
61
- unet_conversion_map_layer.append((sd_down_atn_prefix, hf_down_atn_prefix))
62
-
63
- for j in range(4):
64
- # loop over resnets/attentions for upblocks
65
- hf_up_res_prefix = f"up_blocks.{i}.resnets.{j}."
66
- sd_up_res_prefix = f"output_blocks.{3*i + j}.0."
67
- unet_conversion_map_layer.append((sd_up_res_prefix, hf_up_res_prefix))
68
-
69
- if i < 2:
70
- # no attention layers in up_blocks.0
71
- hf_up_atn_prefix = f"up_blocks.{i}.attentions.{j}."
72
- sd_up_atn_prefix = f"output_blocks.{3 * i + j}.1."
73
- unet_conversion_map_layer.append((sd_up_atn_prefix, hf_up_atn_prefix))
74
-
75
- if i < 3:
76
- # no downsample in down_blocks.3
77
- hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0.conv."
78
- sd_downsample_prefix = f"input_blocks.{3*(i+1)}.0.op."
79
- unet_conversion_map_layer.append((sd_downsample_prefix, hf_downsample_prefix))
80
-
81
- # no upsample in up_blocks.3
82
- hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
83
- sd_upsample_prefix = f"output_blocks.{3*i + 2}.{1 if i == 0 else 2}."
84
- unet_conversion_map_layer.append((sd_upsample_prefix, hf_upsample_prefix))
85
- unet_conversion_map_layer.append(("output_blocks.2.2.conv.", "output_blocks.2.1.conv."))
86
-
87
- hf_mid_atn_prefix = "mid_block.attentions.0."
88
- sd_mid_atn_prefix = "middle_block.1."
89
- unet_conversion_map_layer.append((sd_mid_atn_prefix, hf_mid_atn_prefix))
90
- for j in range(2):
91
- hf_mid_res_prefix = f"mid_block.resnets.{j}."
92
- sd_mid_res_prefix = f"middle_block.{2*j}."
93
- unet_conversion_map_layer.append((sd_mid_res_prefix, hf_mid_res_prefix))
94
-
95
-
96
- def convert_unet_state_dict(unet_state_dict):
97
- # buyer beware: this is a *brittle* function,
98
- # and correct output requires that all of these pieces interact in
99
- # the exact order in which I have arranged them.
100
- mapping = {k: k for k in unet_state_dict.keys()}
101
- for sd_name, hf_name in unet_conversion_map:
102
- mapping[hf_name] = sd_name
103
- for k, v in mapping.items():
104
- if "resnets" in k:
105
- for sd_part, hf_part in unet_conversion_map_resnet:
106
- v = v.replace(hf_part, sd_part)
107
- mapping[k] = v
108
- for k, v in mapping.items():
109
- for sd_part, hf_part in unet_conversion_map_layer:
110
- v = v.replace(hf_part, sd_part)
111
- mapping[k] = v
112
- new_state_dict = {sd_name: unet_state_dict[hf_name] for hf_name, sd_name in mapping.items()}
113
- return new_state_dict
114
-
115
-
116
- # ================#
117
- # VAE Conversion #
118
- # ================#
119
-
120
- vae_conversion_map = [
121
- # (stable-diffusion, HF Diffusers)
122
- ("nin_shortcut", "conv_shortcut"),
123
- ("norm_out", "conv_norm_out"),
124
- ("mid.attn_1.", "mid_block.attentions.0."),
125
- ]
126
-
127
- for i in range(4):
128
- # down_blocks have two resnets
129
- for j in range(2):
130
- hf_down_prefix = f"encoder.down_blocks.{i}.resnets.{j}."
131
- sd_down_prefix = f"encoder.down.{i}.block.{j}."
132
- vae_conversion_map.append((sd_down_prefix, hf_down_prefix))
133
-
134
- if i < 3:
135
- hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0."
136
- sd_downsample_prefix = f"down.{i}.downsample."
137
- vae_conversion_map.append((sd_downsample_prefix, hf_downsample_prefix))
138
-
139
- hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
140
- sd_upsample_prefix = f"up.{3-i}.upsample."
141
- vae_conversion_map.append((sd_upsample_prefix, hf_upsample_prefix))
142
-
143
- # up_blocks have three resnets
144
- # also, up blocks in hf are numbered in reverse from sd
145
- for j in range(3):
146
- hf_up_prefix = f"decoder.up_blocks.{i}.resnets.{j}."
147
- sd_up_prefix = f"decoder.up.{3-i}.block.{j}."
148
- vae_conversion_map.append((sd_up_prefix, hf_up_prefix))
149
-
150
- # this part accounts for mid blocks in both the encoder and the decoder
151
- for i in range(2):
152
- hf_mid_res_prefix = f"mid_block.resnets.{i}."
153
- sd_mid_res_prefix = f"mid.block_{i+1}."
154
- vae_conversion_map.append((sd_mid_res_prefix, hf_mid_res_prefix))
155
-
156
-
157
- vae_conversion_map_attn = [
158
- # (stable-diffusion, HF Diffusers)
159
- ("norm.", "group_norm."),
160
- # the following are for SDXL
161
- ("q.", "to_q."),
162
- ("k.", "to_k."),
163
- ("v.", "to_v."),
164
- ("proj_out.", "to_out.0."),
165
- ]
166
-
167
-
168
- def reshape_weight_for_sd(w):
169
- # convert HF linear weights to SD conv2d weights
170
- if not w.ndim == 1:
171
- return w.reshape(*w.shape, 1, 1)
172
- else:
173
- return w
174
-
175
-
176
- def convert_vae_state_dict(vae_state_dict):
177
- mapping = {k: k for k in vae_state_dict.keys()}
178
- for k, v in mapping.items():
179
- for sd_part, hf_part in vae_conversion_map:
180
- v = v.replace(hf_part, sd_part)
181
- mapping[k] = v
182
- for k, v in mapping.items():
183
- if "attentions" in k:
184
- for sd_part, hf_part in vae_conversion_map_attn:
185
- v = v.replace(hf_part, sd_part)
186
- mapping[k] = v
187
- new_state_dict = {v: vae_state_dict[k] for k, v in mapping.items()}
188
- weights_to_convert = ["q", "k", "v", "proj_out"]
189
- for k, v in new_state_dict.items():
190
- for weight_name in weights_to_convert:
191
- if f"mid.attn_1.{weight_name}.weight" in k:
192
- print(f"Reshaping {k} for SD format")
193
- new_state_dict[k] = reshape_weight_for_sd(v)
194
- return new_state_dict
195
-
196
-
197
- # =========================#
198
- # Text Encoder Conversion #
199
- # =========================#
200
-
201
-
202
- textenc_conversion_lst = [
203
- # (stable-diffusion, HF Diffusers)
204
- ("transformer.resblocks.", "text_model.encoder.layers."),
205
- ("ln_1", "layer_norm1"),
206
- ("ln_2", "layer_norm2"),
207
- (".c_fc.", ".fc1."),
208
- (".c_proj.", ".fc2."),
209
- (".attn", ".self_attn"),
210
- ("ln_final.", "text_model.final_layer_norm."),
211
- ("token_embedding.weight", "text_model.embeddings.token_embedding.weight"),
212
- ("positional_embedding", "text_model.embeddings.position_embedding.weight"),
213
- ]
214
- protected = {re.escape(x[1]): x[0] for x in textenc_conversion_lst}
215
- textenc_pattern = re.compile("|".join(protected.keys()))
216
-
217
- # Ordering is from https://github.com/pytorch/pytorch/blob/master/test/cpp/api/modules.cpp
218
- code2idx = {"q": 0, "k": 1, "v": 2}
219
-
220
-
221
- def convert_openclip_text_enc_state_dict(text_enc_dict):
222
- new_state_dict = {}
223
- capture_qkv_weight = {}
224
- capture_qkv_bias = {}
225
- for k, v in text_enc_dict.items():
226
- if (
227
- k.endswith(".self_attn.q_proj.weight")
228
- or k.endswith(".self_attn.k_proj.weight")
229
- or k.endswith(".self_attn.v_proj.weight")
230
- ):
231
- k_pre = k[: -len(".q_proj.weight")]
232
- k_code = k[-len("q_proj.weight")]
233
- if k_pre not in capture_qkv_weight:
234
- capture_qkv_weight[k_pre] = [None, None, None]
235
- capture_qkv_weight[k_pre][code2idx[k_code]] = v
236
- continue
237
-
238
- if (
239
- k.endswith(".self_attn.q_proj.bias")
240
- or k.endswith(".self_attn.k_proj.bias")
241
- or k.endswith(".self_attn.v_proj.bias")
242
- ):
243
- k_pre = k[: -len(".q_proj.bias")]
244
- k_code = k[-len("q_proj.bias")]
245
- if k_pre not in capture_qkv_bias:
246
- capture_qkv_bias[k_pre] = [None, None, None]
247
- capture_qkv_bias[k_pre][code2idx[k_code]] = v
248
- continue
249
-
250
- relabelled_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], k)
251
- new_state_dict[relabelled_key] = v
252
-
253
- for k_pre, tensors in capture_qkv_weight.items():
254
- if None in tensors:
255
- raise Exception("CORRUPTED MODEL: one of the q-k-v values for the text encoder was missing")
256
- relabelled_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], k_pre)
257
- new_state_dict[relabelled_key + ".in_proj_weight"] = torch.cat(tensors)
258
-
259
- for k_pre, tensors in capture_qkv_bias.items():
260
- if None in tensors:
261
- raise Exception("CORRUPTED MODEL: one of the q-k-v values for the text encoder was missing")
262
- relabelled_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], k_pre)
263
- new_state_dict[relabelled_key + ".in_proj_bias"] = torch.cat(tensors)
264
-
265
- return new_state_dict
266
-
267
-
268
- def convert_openai_text_enc_state_dict(text_enc_dict):
269
- return text_enc_dict
270
-
271
-
272
- def convert_diffusers_to_safetensors(model_path, checkpoint_path, half = True, use_safetensors = True):
273
- # Path for safetensors
274
- unet_path = osp.join(model_path, "unet", "diffusion_pytorch_model.safetensors")
275
- vae_path = osp.join(model_path, "vae", "diffusion_pytorch_model.safetensors")
276
- text_enc_path = osp.join(model_path, "text_encoder", "model.safetensors")
277
- text_enc_2_path = osp.join(model_path, "text_encoder_2", "model.safetensors")
278
-
279
- # Load models from safetensors if it exists, if it doesn't pytorch
280
- if osp.exists(unet_path):
281
- unet_state_dict = load_file(unet_path, device="cpu")
282
- else:
283
- unet_path = osp.join(model_path, "unet", "diffusion_pytorch_model.bin")
284
- unet_state_dict = torch.load(unet_path, map_location="cpu")
285
-
286
- if osp.exists(vae_path):
287
- vae_state_dict = load_file(vae_path, device="cpu")
288
- else:
289
- vae_path = osp.join(model_path, "vae", "diffusion_pytorch_model.bin")
290
- vae_state_dict = torch.load(vae_path, map_location="cpu")
291
-
292
- if osp.exists(text_enc_path):
293
- text_enc_dict = load_file(text_enc_path, device="cpu")
294
- else:
295
- text_enc_path = osp.join(model_path, "text_encoder", "pytorch_model.bin")
296
- text_enc_dict = torch.load(text_enc_path, map_location="cpu")
297
-
298
- if osp.exists(text_enc_2_path):
299
- text_enc_2_dict = load_file(text_enc_2_path, device="cpu")
300
- else:
301
- text_enc_2_path = osp.join(model_path, "text_encoder_2", "pytorch_model.bin")
302
- text_enc_2_dict = torch.load(text_enc_2_path, map_location="cpu")
303
-
304
- # Convert the UNet model
305
- unet_state_dict = convert_unet_state_dict(unet_state_dict)
306
- unet_state_dict = {"model.diffusion_model." + k: v for k, v in unet_state_dict.items()}
307
-
308
- # Convert the VAE model
309
- vae_state_dict = convert_vae_state_dict(vae_state_dict)
310
- vae_state_dict = {"first_stage_model." + k: v for k, v in vae_state_dict.items()}
311
-
312
- # Convert text encoder 1
313
- text_enc_dict = convert_openai_text_enc_state_dict(text_enc_dict)
314
- text_enc_dict = {"conditioner.embedders.0.transformer." + k: v for k, v in text_enc_dict.items()}
315
-
316
- # Convert text encoder 2
317
- text_enc_2_dict = convert_openclip_text_enc_state_dict(text_enc_2_dict)
318
- text_enc_2_dict = {"conditioner.embedders.1.model." + k: v for k, v in text_enc_2_dict.items()}
319
- # We call the `.T.contiguous()` to match what's done in
320
- # https://github.com/huggingface/diffusers/blob/84905ca7287876b925b6bf8e9bb92fec21c78764/src/diffusers/loaders/single_file_utils.py#L1085
321
- text_enc_2_dict["conditioner.embedders.1.model.text_projection"] = text_enc_2_dict.pop(
322
- "conditioner.embedders.1.model.text_projection.weight"
323
- ).T.contiguous()
324
-
325
- # Put together new checkpoint
326
- state_dict = {**unet_state_dict, **vae_state_dict, **text_enc_dict, **text_enc_2_dict}
327
-
328
- if half:
329
- state_dict = {k: v.half() for k, v in state_dict.items()}
330
-
331
- if use_safetensors:
332
- save_file(state_dict, checkpoint_path)
333
- else:
334
- state_dict = {"state_dict": state_dict}
335
- torch.save(state_dict, checkpoint_path)
336
-
337
-
338
- def download_repo(repo_id, dir_path):
339
- from huggingface_hub import snapshot_download
340
- try:
341
- snapshot_download(repo_id=repo_id, local_dir=dir_path)
342
- except Exception as e:
343
- print(f"Error: Failed to download {repo_id}. ")
344
- return
345
-
346
-
347
- def convert_repo_to_safetensors(repo_id):
348
- download_dir = f"{repo_id.split('/')[0]}_{repo_id.split('/')[-1]}"
349
- output_filename = f"{repo_id.split('/')[0]}_{repo_id.split('/')[-1]}.safetensors"
350
- download_repo(repo_id, download_dir)
351
- convert_diffusers_to_safetensors(download_dir, output_filename)
352
- return output_filename
353
-
354
-
355
- if __name__ == "__main__":
356
- parser = argparse.ArgumentParser()
357
-
358
- parser.add_argument("--repo_id", default=None, type=str, required=True, help="HF Repo ID of the model to convert.")
359
-
360
- args = parser.parse_args()
361
- assert args.repo_id is not None, "Must provide a Repo ID!"
362
-
363
- convert_repo_to_safetensors(args.repo_id)
364
-
365
-
366
- # Usage: python convert_repo_to_safetensors.py --repo_id GraydientPlatformAPI/goodfit-pony41-xl