sayakpaul HF staff commited on
Commit
ddc8a59
·
1 Parent(s): 7079251

add: files.

Browse files
README.md CHANGED
@@ -1,7 +1,7 @@
1
  ---
2
- title: Convert Kerascv Sd Diffusers
3
- emoji: 🐢
4
- colorFrom: green
5
  colorTo: red
6
  sdk: gradio
7
  sdk_version: 3.16.2
 
1
  ---
2
+ title: Convert Kerascv SD to Diffusers
3
+ emoji: 🧨
4
+ colorFrom: red
5
  colorTo: red
6
  sdk: gradio
7
  sdk_version: 3.16.2
app.py ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from convert import run_conversion
3
+ from hub_utils import save_model_card, push_to_hub
4
+
5
+
6
+ PRETRAINED_CKPT = "CompVis/stable-diffusion-v1-4"
7
+ DESCRIPTION = """
8
+ This Space lets you convert KerasCV Stable Diffusion weights to a format compatible with [Diffusers](https://github.com/huggingface/diffusers) 🧨. This allows users to fine-tune using KerasCV and use the fine-tuned weights in Diffusers taking advantage of its nifty features (like schedulers, fast attention, etc.). Specifically, the parameters are converted and then they are wrapped into a [`StableDiffusionPipeline`](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/overview). This pipeline is then pushed to the Hugging Face Hub given you have provided a `your_hf_token`.
9
+
10
+ ## Notes (important)
11
+
12
+ * Only Stable Diffusion (v1) is supported as of now. In particular this checkpoint: [`"CompVis/stable-diffusion-v1-4"`](https://huggingface.co/CompVis/stable-diffusion-v1-4).
13
+ * Only the text encoder and the UNet parameters converted since only these two elements are generally fine-tuned.
14
+ * [This Colab Notebook](https://colab.research.google.com/drive/1RYY077IQbAJldg8FkK8HSEpNILKHEwLb?usp=sharing) was used to develop the conversion utilities initially.
15
+ * You can choose not to provide `text_encoder_weights` and `unet_weights` in case you don't have any fine-tuned weights. In that case, the original parameters of the respective models (text encoder and UNet) from KerasCV will be used.
16
+ * You can provide only `text_encoder_weights` or `unet_weights` or both.
17
+ * When providing the weights' links, ensure they're directly downloadable. Internally, the Space uses [`tf.keras.utils.get_file()`](https://www.tensorflow.org/api_docs/python/tf/keras/utils/get_file) to retrieve the weights locally.
18
+ * If you don't provide `your_hf_token` the converted pipeline won't be pushed.
19
+
20
+ Check [here](https://github.com/huggingface/diffusers/blob/31be42209ddfdb69d9640a777b32e9b5c6259bf0/examples/dreambooth/train_dreambooth_lora.py#L975) for an example on how you can change the scheduler of an already initialized pipeline.
21
+ """
22
+
23
+ def run(hf_token, text_encoder_weights, unet_weights, repo_prefix):
24
+ if text_encoder_weights == "":
25
+ text_encoder_weights = None
26
+ if unet_weights == "":
27
+ unet_weights = None
28
+ pipeline = run_conversion(text_encoder_weights, unet_weights)
29
+ output_path = "kerascv_sd_diffusers_pipeline"
30
+ pipeline.save_pretrained(output_path)
31
+ save_model_card(base_model=PRETRAINED_CKPT, repo_folder=output_path, weight_paths=[text_encoder_weights, unet_weights], repo_prefix=repo_prefix)
32
+ push_str = push_to_hub(hf_token, output_path, repo_prefix)
33
+ return push_str
34
+
35
+ demo = gr.Interface(
36
+ title="KerasCV Stable Diffusion to Diffusers Stable Diffusion Pipelines 🧨🤗",
37
+ description=DESCRIPTION,
38
+ allow_flagging="never",
39
+ inputs=[gr.Text(max_lines=1, label="your_hf_token"), gr.Text(max_lines=1, label="text_encoder_weights"), gr.Text(max_lines=1, label="unet_weights"), gr.Text(max_lines=1, label="output_repo_prefix")],
40
+ outputs=[gr.Markdown(label="output")],
41
+ fn=run,
42
+ )
43
+
44
+ demo.launch()
conversion_utils/__init__.py ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ from .text_encoder import populate_text_encoder
2
+ from .unet import populate_unet
3
+ from .utils import run_assertion
conversion_utils/text_encoder.py ADDED
@@ -0,0 +1,110 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from keras_cv.models import stable_diffusion
2
+ import tensorflow as tf
3
+ import torch
4
+ from typing import Dict
5
+
6
+ MAX_SEQ_LENGTH = 77
7
+
8
+ def populate_text_encoder(tf_text_encoder: tf.keras.Model) -> Dict[str, torch.Tensor]:
9
+ """Populates the state dict from the provided TensorFlow model
10
+ (applicable only for the text encoder)."""
11
+ text_state_dict = dict()
12
+ num_encoder_layers = 0
13
+
14
+ for layer in tf_text_encoder.layers:
15
+ # Embeddings.
16
+ if isinstance(layer, stable_diffusion.text_encoder.CLIPEmbedding):
17
+ text_state_dict[
18
+ "text_model.embeddings.token_embedding.weight"
19
+ ] = torch.from_numpy(layer.token_embedding.get_weights()[0])
20
+ text_state_dict[
21
+ "text_model.embeddings.position_embedding.weight"
22
+ ] = torch.from_numpy(layer.position_embedding.get_weights()[0])
23
+
24
+ # Encoder blocks.
25
+ elif isinstance(layer, stable_diffusion.text_encoder.CLIPEncoderLayer):
26
+ # LayerNorms
27
+ for i in range(1, 3):
28
+ if i == 1:
29
+ text_state_dict[
30
+ f"text_model.encoder.layers.{num_encoder_layers}.layer_norm1.weight"
31
+ ] = torch.from_numpy(layer.layer_norm1.get_weights()[0])
32
+ text_state_dict[
33
+ f"text_model.encoder.layers.{num_encoder_layers}.layer_norm1.bias"
34
+ ] = torch.from_numpy(layer.layer_norm1.get_weights()[1])
35
+ else:
36
+ text_state_dict[
37
+ f"text_model.encoder.layers.{num_encoder_layers}.layer_norm2.weight"
38
+ ] = torch.from_numpy(layer.layer_norm2.get_weights()[0])
39
+ text_state_dict[
40
+ f"text_model.encoder.layers.{num_encoder_layers}.layer_norm2.bias"
41
+ ] = torch.from_numpy(layer.layer_norm2.get_weights()[1])
42
+
43
+ # Attention.
44
+ q_proj = layer.clip_attn.q_proj
45
+ k_proj = layer.clip_attn.k_proj
46
+ v_proj = layer.clip_attn.v_proj
47
+ out_proj = layer.clip_attn.out_proj
48
+
49
+ text_state_dict[
50
+ f"text_model.encoder.layers.{num_encoder_layers}.self_attn.q_proj.weight"
51
+ ] = torch.from_numpy(q_proj.get_weights()[0].transpose())
52
+ text_state_dict[
53
+ f"text_model.encoder.layers.{num_encoder_layers}.self_attn.q_proj.bias"
54
+ ] = torch.from_numpy(q_proj.get_weights()[1])
55
+
56
+ text_state_dict[
57
+ f"text_model.encoder.layers.{num_encoder_layers}.self_attn.k_proj.weight"
58
+ ] = torch.from_numpy(k_proj.get_weights()[0].transpose())
59
+ text_state_dict[
60
+ f"text_model.encoder.layers.{num_encoder_layers}.self_attn.k_proj.bias"
61
+ ] = torch.from_numpy(k_proj.get_weights()[1])
62
+
63
+ text_state_dict[
64
+ f"text_model.encoder.layers.{num_encoder_layers}.self_attn.v_proj.weight"
65
+ ] = torch.from_numpy(v_proj.get_weights()[0].transpose())
66
+ text_state_dict[
67
+ f"text_model.encoder.layers.{num_encoder_layers}.self_attn.v_proj.bias"
68
+ ] = torch.from_numpy(v_proj.get_weights()[1])
69
+
70
+ text_state_dict[
71
+ f"text_model.encoder.layers.{num_encoder_layers}.self_attn.out_proj.weight"
72
+ ] = torch.from_numpy(out_proj.get_weights()[0].transpose())
73
+ text_state_dict[
74
+ f"text_model.encoder.layers.{num_encoder_layers}.self_attn.out_proj.bias"
75
+ ] = torch.from_numpy(out_proj.get_weights()[1])
76
+
77
+ # MLPs.
78
+ fc1 = layer.fc1
79
+ fc2 = layer.fc2
80
+
81
+ text_state_dict[
82
+ f"text_model.encoder.layers.{num_encoder_layers}.mlp.fc1.weight"
83
+ ] = torch.from_numpy(fc1.get_weights()[0].transpose())
84
+ text_state_dict[
85
+ f"text_model.encoder.layers.{num_encoder_layers}.mlp.fc1.bias"
86
+ ] = torch.from_numpy(fc1.get_weights()[1])
87
+ text_state_dict[
88
+ f"text_model.encoder.layers.{num_encoder_layers}.mlp.fc2.weight"
89
+ ] = torch.from_numpy(fc2.get_weights()[0].transpose())
90
+ text_state_dict[
91
+ f"text_model.encoder.layers.{num_encoder_layers}.mlp.fc2.bias"
92
+ ] = torch.from_numpy(fc2.get_weights()[1])
93
+
94
+ num_encoder_layers += 1
95
+
96
+ # Final LayerNorm.
97
+ elif isinstance(layer, tf.keras.layers.LayerNormalization):
98
+ text_state_dict["text_model.final_layer_norm.weight"] = torch.from_numpy(
99
+ layer.get_weights()[0]
100
+ )
101
+ text_state_dict["text_model.final_layer_norm.bias"] = torch.from_numpy(
102
+ layer.get_weights()[1]
103
+ )
104
+
105
+ # Position ids.
106
+ text_state_dict["text_model.embeddings.position_ids"] = torch.tensor(
107
+ list(range(77))
108
+ ).unsqueeze(0)
109
+
110
+ return text_state_dict
conversion_utils/unet.py ADDED
@@ -0,0 +1,291 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import tensorflow as tf
2
+ import torch
3
+ from typing import Dict
4
+ from itertools import product
5
+ from keras_cv.models import stable_diffusion
6
+
7
+ def port_transformer_block(transformer_block: tf.keras.Model, up_down: int, block_id: int, attention_id: int) -> Dict[str, torch.Tensor]:
8
+ """Populates a Transformer block."""
9
+ transformer_dict = dict()
10
+ if block_id is not None:
11
+ prefix = f"{up_down}_blocks.{block_id}"
12
+ else:
13
+ prefix = "mid_block"
14
+
15
+ # Norms.
16
+ for i in range(1, 4):
17
+ if i == 1:
18
+ norm = transformer_block.norm1
19
+ elif i == 2:
20
+ norm = transformer_block.norm2
21
+ elif i == 3:
22
+ norm = transformer_block.norm3
23
+ transformer_dict[f"{prefix}.attentions.{attention_id}.transformer_blocks.0.norm{i}.weight"] = torch.from_numpy(norm.get_weights()[0])
24
+ transformer_dict[f"{prefix}.attentions.{attention_id}.transformer_blocks.0.norm{i}.bias"] = torch.from_numpy(norm.get_weights()[1])
25
+
26
+ # Attentions.
27
+ for i in range(1, 3):
28
+ if i == 1:
29
+ attn = transformer_block.attn1
30
+ else:
31
+ attn = transformer_block.attn2
32
+ transformer_dict[f"{prefix}.attentions.{attention_id}.transformer_blocks.0.attn{i}.to_q.weight"] = torch.from_numpy(attn.to_q.get_weights()[0].transpose())
33
+ transformer_dict[f"{prefix}.attentions.{attention_id}.transformer_blocks.0.attn{i}.to_k.weight"] = torch.from_numpy(attn.to_k.get_weights()[0].transpose())
34
+ transformer_dict[f"{prefix}.attentions.{attention_id}.transformer_blocks.0.attn{i}.to_v.weight"] = torch.from_numpy(attn.to_v.get_weights()[0].transpose())
35
+ transformer_dict[f"{prefix}.attentions.{attention_id}.transformer_blocks.0.attn{i}.to_out.0.weight"] = torch.from_numpy(attn.out_proj.get_weights()[0].transpose())
36
+ transformer_dict[f"{prefix}.attentions.{attention_id}.transformer_blocks.0.attn{i}.to_out.0.bias"] = torch.from_numpy(attn.out_proj.get_weights()[1])
37
+
38
+ # Dense.
39
+ for i in range(0, 3, 2):
40
+ if i == 0:
41
+ layer = transformer_block.geglu.dense
42
+ transformer_dict[f"{prefix}.attentions.{attention_id}.transformer_blocks.0.ff.net.{i}.proj.weight"] = torch.from_numpy(layer.get_weights()[0].transpose())
43
+ transformer_dict[f"{prefix}.attentions.{attention_id}.transformer_blocks.0.ff.net.{i}.proj.bias"] = torch.from_numpy(layer.get_weights()[1])
44
+ else:
45
+ layer = transformer_block.dense
46
+ transformer_dict[f"{prefix}.attentions.{attention_id}.transformer_blocks.0.ff.net.{i}.weight"] = torch.from_numpy(layer.get_weights()[0].transpose())
47
+ transformer_dict[f"{prefix}.attentions.{attention_id}.transformer_blocks.0.ff.net.{i}.bias"] = torch.from_numpy(layer.get_weights()[1])
48
+
49
+ return transformer_dict
50
+
51
+
52
+ def populate_unet(tf_unet: tf.keras.Model) -> Dict[str, torch.Tensor]:
53
+ """Populates the state dict from the provided TensorFlow model
54
+ (applicable only for the UNet)."""
55
+ unet_state_dict = dict()
56
+
57
+ timstep_emb = 1
58
+ padded_conv = 1
59
+ up_block = 0
60
+
61
+ up_res_blocks = list(product([0, 1, 2, 3], [0, 1, 2]))
62
+ up_res_block_flag = 0
63
+
64
+ up_spatial_transformer_blocks = list(product([1, 2, 3], [0, 1, 2]))
65
+ up_spatial_transformer_flag = 0
66
+
67
+ for layer in tf_unet.layers:
68
+ # Timstep embedding.
69
+ if isinstance(layer, tf.keras.layers.Dense):
70
+ unet_state_dict[f"time_embedding.linear_{timstep_emb}.weight"] = torch.from_numpy(layer.get_weights()[0].transpose())
71
+ unet_state_dict[f"time_embedding.linear_{timstep_emb}.bias"] = torch.from_numpy(layer.get_weights()[1])
72
+ timstep_emb += 1
73
+
74
+ # Padded convs (downsamplers).
75
+ elif isinstance(layer, stable_diffusion.__internal__.layers.padded_conv2d.PaddedConv2D):
76
+ if padded_conv == 1:
77
+ # Transposition axes taken from https://github.com/huggingface/transformers/blob/main/src/transformers/modeling_tf_pytorch_utils.py#L104
78
+ unet_state_dict["conv_in.weight"] = torch.from_numpy(layer.get_weights()[0].transpose(3, 2, 0, 1))
79
+ unet_state_dict["conv_in.bias"] = torch.from_numpy(layer.get_weights()[1])
80
+ elif padded_conv in [2, 3, 4]:
81
+ unet_state_dict[f"down_blocks.{padded_conv-2}.downsamplers.0.conv.weight"] = torch.from_numpy(layer.get_weights()[0].transpose(3, 2, 0, 1))
82
+ unet_state_dict[f"down_blocks.{padded_conv-2}.downsamplers.0.conv.bias"] = torch.from_numpy(layer.get_weights()[1])
83
+ elif padded_conv == 5:
84
+ unet_state_dict["conv_out.weight"] = torch.from_numpy(layer.get_weights()[0].transpose(3, 2, 0, 1))
85
+ unet_state_dict["conv_out.bias"] = torch.from_numpy(layer.get_weights()[1])
86
+
87
+ padded_conv += 1
88
+
89
+ # Upsamplers.
90
+ elif isinstance(layer, stable_diffusion.diffusion_model.Upsample):
91
+ conv = layer.conv
92
+ unet_state_dict[f"up_blocks.{up_block}.upsamplers.0.conv.weight"] = torch.from_numpy(conv.get_weights()[0].transpose(3, 2, 0, 1))
93
+ unet_state_dict[f"up_blocks.{up_block}.upsamplers.0.conv.bias"] = torch.from_numpy(conv.get_weights()[1])
94
+ up_block += 1
95
+
96
+ # Output norms.
97
+ elif isinstance(layer, stable_diffusion.__internal__.layers.group_normalization.GroupNormalization):
98
+ unet_state_dict["conv_norm_out.weight"] = torch.from_numpy(layer.get_weights()[0])
99
+ unet_state_dict["conv_norm_out.bias"] = torch.from_numpy(layer.get_weights()[1])
100
+
101
+ # All ResBlocks.
102
+ elif isinstance(layer, stable_diffusion.diffusion_model.ResBlock):
103
+ layer_name = layer.name
104
+ parts = layer_name.split("_")
105
+
106
+ # Down.
107
+ if len(parts) == 2 or int(parts[-1]) < 8:
108
+ entry_flow = layer.entry_flow
109
+ embedding_flow = layer.embedding_flow
110
+ exit_flow = layer.exit_flow
111
+
112
+ down_block_id = 0 if len(parts) == 2 else int(parts[-1]) // 2
113
+ down_resnet_id = 0 if len(parts) == 2 else int(parts[-1]) % 2
114
+
115
+ # Conv blocks.
116
+ first_conv_layer = entry_flow[-1]
117
+ unet_state_dict[f"down_blocks.{down_block_id}.resnets.{down_resnet_id}.conv1.weight"] = torch.from_numpy(first_conv_layer.get_weights()[0].transpose(3, 2, 0, 1))
118
+ unet_state_dict[f"down_blocks.{down_block_id}.resnets.{down_resnet_id}.conv1.bias"] = torch.from_numpy(first_conv_layer.get_weights()[1])
119
+ second_conv_layer = exit_flow[-1]
120
+ unet_state_dict[f"down_blocks.{down_block_id}.resnets.{down_resnet_id}.conv2.weight"] = torch.from_numpy(second_conv_layer.get_weights()[0].transpose(3, 2, 0, 1))
121
+ unet_state_dict[f"down_blocks.{down_block_id}.resnets.{down_resnet_id}.conv2.bias"] = torch.from_numpy(second_conv_layer.get_weights()[1])
122
+
123
+ # Residual blocks.
124
+ if hasattr(layer, "residual_projection"):
125
+ if isinstance(layer.residual_projection, stable_diffusion.__internal__.layers.padded_conv2d.PaddedConv2D):
126
+ residual = layer.residual_projection
127
+ unet_state_dict[f"down_blocks.{down_block_id}.resnets.{down_resnet_id}.conv_shortcut.weight"] = torch.from_numpy(residual.get_weights()[0].transpose(3, 2, 0, 1))
128
+ unet_state_dict[f"down_blocks.{down_block_id}.resnets.{down_resnet_id}.conv_shortcut.bias"] = torch.from_numpy(residual.get_weights()[1])
129
+
130
+ # Timestep embedding.
131
+ embedding_proj = embedding_flow[-1]
132
+ unet_state_dict[f"down_blocks.{down_block_id}.resnets.{down_resnet_id}.time_emb_proj.weight"] = torch.from_numpy(embedding_proj.get_weights()[0].transpose())
133
+ unet_state_dict[f"down_blocks.{down_block_id}.resnets.{down_resnet_id}.time_emb_proj.bias"] = torch.from_numpy(embedding_proj.get_weights()[1])
134
+
135
+ # Norms.
136
+ first_group_norm = entry_flow[0]
137
+ unet_state_dict[f"down_blocks.{down_block_id}.resnets.{down_resnet_id}.norm1.weight"] = torch.from_numpy(first_group_norm.get_weights()[0])
138
+ unet_state_dict[f"down_blocks.{down_block_id}.resnets.{down_resnet_id}.norm1.bias"] = torch.from_numpy(first_group_norm.get_weights()[1])
139
+ second_group_norm = exit_flow[0]
140
+ unet_state_dict[f"down_blocks.{down_block_id}.resnets.{down_resnet_id}.norm2.weight"] = torch.from_numpy(second_group_norm.get_weights()[0])
141
+ unet_state_dict[f"down_blocks.{down_block_id}.resnets.{down_resnet_id}.norm2.bias"] = torch.from_numpy(second_group_norm.get_weights()[1])
142
+
143
+ # Middle.
144
+ elif int(parts[-1]) == 8 or int(parts[-1]) == 9:
145
+ entry_flow = layer.entry_flow
146
+ embedding_flow = layer.embedding_flow
147
+ exit_flow = layer.exit_flow
148
+
149
+ mid_resnet_id = int(parts[-1]) % 2
150
+
151
+ # Conv blocks.
152
+ first_conv_layer = entry_flow[-1]
153
+ unet_state_dict[f"mid_block.resnets.{mid_resnet_id}.conv1.weight"] = torch.from_numpy(first_conv_layer.get_weights()[0].transpose(3, 2, 0, 1))
154
+ unet_state_dict[f"mid_block.resnets.{mid_resnet_id}.conv1.bias"] = torch.from_numpy(first_conv_layer.get_weights()[1])
155
+ second_conv_layer = exit_flow[-1]
156
+ unet_state_dict[f"mid_block.resnets.{mid_resnet_id}.conv2.weight"] = torch.from_numpy(second_conv_layer.get_weights()[0].transpose(3, 2, 0, 1))
157
+ unet_state_dict[f"mid_block.resnets.{mid_resnet_id}.conv2.bias"] = torch.from_numpy(second_conv_layer.get_weights()[1])
158
+
159
+ # Residual blocks.
160
+ if hasattr(layer, "residual_projection"):
161
+ if isinstance(layer.residual_projection, stable_diffusion.__internal__.layers.padded_conv2d.PaddedConv2D):
162
+ residual = layer.residual_projection
163
+ unet_state_dict[f"mid_block.resnets.{mid_resnet_id}.conv_shortcut.weight"] = torch.from_numpy(residual.get_weights()[0].transpose(3, 2, 0, 1))
164
+ unet_state_dict[f"mid_block.resnets.{mid_resnet_id}.conv_shortcut.bias"] = torch.from_numpy(residual.get_weights()[1])
165
+
166
+ # Timestep embedding.
167
+ embedding_proj = embedding_flow[-1]
168
+ unet_state_dict[f"mid_block.resnets.{mid_resnet_id}.time_emb_proj.weight"] = torch.from_numpy(embedding_proj.get_weights()[0].transpose())
169
+ unet_state_dict[f"mid_block.resnets.{mid_resnet_id}.time_emb_proj.bias"] = torch.from_numpy(embedding_proj.get_weights()[1])
170
+
171
+ # Norms.
172
+ first_group_norm = entry_flow[0]
173
+ unet_state_dict[f"mid_block.resnets.{mid_resnet_id}.norm1.weight"] = torch.from_numpy(first_group_norm.get_weights()[0])
174
+ unet_state_dict[f"mid_block.resnets.{mid_resnet_id}.norm1.bias"] = torch.from_numpy(first_group_norm.get_weights()[1])
175
+ second_group_norm = exit_flow[0]
176
+ unet_state_dict[f"mid_block.resnets.{mid_resnet_id}.norm2.weight"] = torch.from_numpy(second_group_norm.get_weights()[0])
177
+ unet_state_dict[f"mid_block.resnets.{mid_resnet_id}.norm2.bias"] = torch.from_numpy(second_group_norm.get_weights()[1])
178
+
179
+ # Up.
180
+ elif int(parts[-1]) > 9 and up_res_block_flag < len(up_res_blocks):
181
+ entry_flow = layer.entry_flow
182
+ embedding_flow = layer.embedding_flow
183
+ exit_flow = layer.exit_flow
184
+
185
+ up_res_block = up_res_blocks[up_res_block_flag]
186
+ up_block_id = up_res_block[0]
187
+ up_resnet_id = up_res_block[1]
188
+
189
+ # Conv blocks.
190
+ first_conv_layer = entry_flow[-1]
191
+ unet_state_dict[f"up_blocks.{up_block_id}.resnets.{up_resnet_id}.conv1.weight"] = torch.from_numpy(first_conv_layer.get_weights()[0].transpose(3, 2, 0, 1))
192
+ unet_state_dict[f"up_blocks.{up_block_id}.resnets.{up_resnet_id}.conv1.bias"] = torch.from_numpy(first_conv_layer.get_weights()[1])
193
+ second_conv_layer = exit_flow[-1]
194
+ unet_state_dict[f"up_blocks.{up_block_id}.resnets.{up_resnet_id}.conv2.weight"] = torch.from_numpy(second_conv_layer.get_weights()[0].transpose(3, 2, 0, 1))
195
+ unet_state_dict[f"up_blocks.{up_block_id}.resnets.{up_resnet_id}.conv2.bias"] = torch.from_numpy(second_conv_layer.get_weights()[1])
196
+
197
+ # Residual blocks.
198
+ if hasattr(layer, "residual_projection"):
199
+ if isinstance(layer.residual_projection, stable_diffusion.__internal__.layers.padded_conv2d.PaddedConv2D):
200
+ residual = layer.residual_projection
201
+ unet_state_dict[f"up_blocks.{up_block_id}.resnets.{up_resnet_id}.conv_shortcut.weight"] = torch.from_numpy(residual.get_weights()[0].transpose(3, 2, 0, 1))
202
+ unet_state_dict[f"up_blocks.{up_block_id}.resnets.{up_resnet_id}.conv_shortcut.bias"] = torch.from_numpy(residual.get_weights()[1])
203
+
204
+ # Timestep embedding.
205
+ embedding_proj = embedding_flow[-1]
206
+ unet_state_dict[f"up_blocks.{up_block_id}.resnets.{up_resnet_id}.time_emb_proj.weight"] = torch.from_numpy(embedding_proj.get_weights()[0].transpose())
207
+ unet_state_dict[f"up_blocks.{up_block_id}.resnets.{up_resnet_id}.time_emb_proj.bias"] = torch.from_numpy(embedding_proj.get_weights()[1])
208
+
209
+ # Norms.
210
+ first_group_norm = entry_flow[0]
211
+ unet_state_dict[f"up_blocks.{up_block_id}.resnets.{up_resnet_id}.norm1.weight"] = torch.from_numpy(first_group_norm.get_weights()[0])
212
+ unet_state_dict[f"up_blocks.{up_block_id}.resnets.{up_resnet_id}.norm1.bias"] = torch.from_numpy(first_group_norm.get_weights()[1])
213
+ second_group_norm = exit_flow[0]
214
+ unet_state_dict[f"up_blocks.{up_block_id}.resnets.{up_resnet_id}.norm2.weight"] = torch.from_numpy(second_group_norm.get_weights()[0])
215
+ unet_state_dict[f"up_blocks.{up_block_id}.resnets.{up_resnet_id}.norm2.bias"] = torch.from_numpy(second_group_norm.get_weights()[1])
216
+
217
+ up_res_block_flag += 1
218
+
219
+ # All SpatialTransformer blocks.
220
+ elif isinstance(layer, stable_diffusion.diffusion_model.SpatialTransformer):
221
+ layer_name = layer.name
222
+ parts = layer_name.split("_")
223
+
224
+ # Down.
225
+ if len(parts) == 2 or int(parts[-1]) < 6:
226
+ down_block_id = 0 if len(parts) == 2 else int(parts[-1]) // 2
227
+ down_attention_id = 0 if len(parts) == 2 else int(parts[-1]) % 2
228
+
229
+ # Convs.
230
+ proj1 = layer.proj1
231
+ unet_state_dict[f"down_blocks.{down_block_id}.attentions.{down_attention_id}.proj_in.weight"] = torch.from_numpy(proj1.get_weights()[0].transpose(3, 2, 0, 1))
232
+ unet_state_dict[f"down_blocks.{down_block_id}.attentions.{down_attention_id}.proj_in.bias"] = torch.from_numpy(proj1.get_weights()[1])
233
+ proj2 = layer.proj2
234
+ unet_state_dict[f"down_blocks.{down_block_id}.attentions.{down_attention_id}.proj_out.weight"] = torch.from_numpy(proj2.get_weights()[0].transpose(3, 2, 0, 1))
235
+ unet_state_dict[f"down_blocks.{down_block_id}.attentions.{down_attention_id}.proj_out.bias"] = torch.from_numpy(proj2.get_weights()[1])
236
+
237
+ # Transformer blocks.
238
+ transformer_block = layer.transformer_block
239
+ unet_state_dict.update(port_transformer_block(transformer_block, "down", down_block_id, down_attention_id))
240
+
241
+ # Norms.
242
+ norm = layer.norm
243
+ unet_state_dict[f"down_blocks.{down_block_id}.attentions.{down_attention_id}.norm.weight"] = torch.from_numpy(norm.get_weights()[0])
244
+ unet_state_dict[f"down_blocks.{down_block_id}.attentions.{down_attention_id}.norm.bias"] = torch.from_numpy(norm.get_weights()[1])
245
+
246
+ # Middle.
247
+ elif int(parts[-1]) == 6:
248
+ mid_attention_id = int(parts[-1]) % 2
249
+ # Convs.
250
+ proj1 = layer.proj1
251
+ unet_state_dict[f"mid_block.attentions.{mid_attention_id}.proj_in.weight"] = torch.from_numpy(proj1.get_weights()[0].transpose(3, 2, 0, 1))
252
+ unet_state_dict[f"mid_block.attentions.{mid_attention_id}.proj_in.bias"] = torch.from_numpy(proj1.get_weights()[1])
253
+ proj2 = layer.proj2
254
+ unet_state_dict[f"mid_block.attentions.{mid_resnet_id}.proj_out.weight"] = torch.from_numpy(proj2.get_weights()[0].transpose(3, 2, 0, 1))
255
+ unet_state_dict[f"mid_block.attentions.{mid_attention_id}.proj_out.bias"] = torch.from_numpy(proj2.get_weights()[1])
256
+
257
+ # Transformer blocks.
258
+ transformer_block = layer.transformer_block
259
+ unet_state_dict.update(port_transformer_block(transformer_block, "mid", None, mid_attention_id))
260
+
261
+ # Norms.
262
+ norm = layer.norm
263
+ unet_state_dict[f"mid_block.attentions.{mid_attention_id}.norm.weight"] = torch.from_numpy(norm.get_weights()[0])
264
+ unet_state_dict[f"mid_block.attentions.{mid_attention_id}.norm.bias"] = torch.from_numpy(norm.get_weights()[1])
265
+
266
+ # Up.
267
+ elif int(parts[-1]) > 6 and up_spatial_transformer_flag < len(up_spatial_transformer_blocks):
268
+ up_spatial_transformer_block = up_spatial_transformer_blocks[up_spatial_transformer_flag]
269
+ up_block_id = up_spatial_transformer_block[0]
270
+ up_attention_id = up_spatial_transformer_block[1]
271
+
272
+ # Convs.
273
+ proj1 = layer.proj1
274
+ unet_state_dict[f"up_blocks.{up_block_id}.attentions.{up_attention_id}.proj_in.weight"] = torch.from_numpy(proj1.get_weights()[0].transpose(3, 2, 0, 1))
275
+ unet_state_dict[f"up_blocks.{up_block_id}.attentions.{up_attention_id}.proj_in.bias"] = torch.from_numpy(proj1.get_weights()[1])
276
+ proj2 = layer.proj2
277
+ unet_state_dict[f"up_blocks.{up_block_id}.attentions.{up_attention_id}.proj_out.weight"] = torch.from_numpy(proj2.get_weights()[0].transpose(3, 2, 0, 1))
278
+ unet_state_dict[f"up_blocks.{up_block_id}.attentions.{up_attention_id}.proj_out.bias"] = torch.from_numpy(proj2.get_weights()[1])
279
+
280
+ # Transformer blocks.
281
+ transformer_block = layer.transformer_block
282
+ unet_state_dict.update(port_transformer_block(transformer_block, "up", up_block_id, up_attention_id))
283
+
284
+ # Norms.
285
+ norm = layer.norm
286
+ unet_state_dict[f"up_blocks.{up_block_id}.attentions.{up_attention_id}.norm.weight"] = torch.from_numpy(norm.get_weights()[0])
287
+ unet_state_dict[f"up_blocks.{up_block_id}.attentions.{up_attention_id}.norm.bias"] = torch.from_numpy(norm.get_weights()[1])
288
+
289
+ up_spatial_transformer_flag += 1
290
+
291
+ return unet_state_dict
conversion_utils/utils.py ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import numpy as np
3
+ import torch
4
+ from typing import Dict
5
+
6
+
7
+ def run_assertion(orig_pt_state_dict: Dict[str, torch.Tensor], pt_state_dict_from_tf: Dict[str, torch.Tensor]):
8
+ for k in orig_pt_state_dict:
9
+ try:
10
+ np.testing.assert_allclose(
11
+ orig_pt_state_dict[k].numpy(),
12
+ pt_state_dict_from_tf[k].numpy()
13
+ )
14
+ except:
15
+ raise ValueError("There are problems in the parameter population process. Cannot proceed :(")
convert.py ADDED
@@ -0,0 +1,90 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from conversion_utils import populate_text_encoder, populate_unet, run_assertion
2
+
3
+ from diffusers import (
4
+ AutoencoderKL,
5
+ StableDiffusionPipeline,
6
+ UNet2DConditionModel,
7
+ )
8
+ from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
9
+ from transformers import CLIPTextModel
10
+ import keras_cv
11
+ import tensorflow as tf
12
+
13
+
14
+ PRETRAINED_CKPT = "CompVis/stable-diffusion-v1-4"
15
+ REVISION = None
16
+ NON_EMA_REVISION = None
17
+ IMG_HEIGHT = IMG_WIDTH = 512
18
+
19
+ def initialize_pt_models():
20
+ """Initializes the separate models of Stable Diffusion from diffusers and downloads
21
+ their pre-trained weights."""
22
+ pt_text_encoder = CLIPTextModel.from_pretrained(
23
+ PRETRAINED_CKPT, subfolder="text_encoder", revision=REVISION
24
+ )
25
+ pt_vae = AutoencoderKL.from_pretrained(
26
+ PRETRAINED_CKPT, subfolder="vae", revision=REVISION
27
+ )
28
+ pt_unet = UNet2DConditionModel.from_pretrained(
29
+ PRETRAINED_CKPT, subfolder="unet", revision=NON_EMA_REVISION
30
+ )
31
+ pt_safety_checker = StableDiffusionSafetyChecker.from_pretrained(
32
+ PRETRAINED_CKPT, subfolder="safety_checker", revision=NON_EMA_REVISION
33
+ )
34
+
35
+ return pt_text_encoder, pt_vae, pt_unet, pt_safety_checker
36
+
37
+ def initialize_tf_models():
38
+ """Initializes the separate models of Stable Diffusion from KerasCV and downloads
39
+ their pre-trained weights."""
40
+ tf_sd_model = keras_cv.models.StableDiffusion(img_height=IMG_HEIGHT, img_width=IMG_WIDTH)
41
+ _ = tf_sd_model.text_to_image("Cartoon") # To download the weights.
42
+
43
+ tf_text_encoder = tf_sd_model.text_encoder
44
+ tf_vae = tf_sd_model.image_encoder
45
+ tf_unet = tf_sd_model.diffusion_model
46
+ return tf_sd_model, tf_text_encoder, tf_vae, tf_unet
47
+
48
+
49
+ def run_conversion(text_encoder_weights: str = None, unet_weights: str = None):
50
+ pt_text_encoder, pt_vae, pt_unet, pt_safety_checker = initialize_pt_models()
51
+ tf_sd_model, tf_text_encoder, tf_vae, tf_unet = initialize_tf_models()
52
+ print("Pre-trained model weights downloaded.")
53
+
54
+ if text_encoder_weights is not None:
55
+ print("Loading fine-tuned text encoder weights.")
56
+ text_encoder_weights_path = tf.keras.utils.get_file(text_encoder_weights)
57
+ tf_text_encoder.load_weights(text_encoder_weights_path)
58
+ if unet_weights is not None:
59
+ print("Loading fine-tuned UNet weights.")
60
+ unet_weights_path = tf.keras.utils.get_file(unet_weights)
61
+ tf_unet.load_weights(unet_weights_path)
62
+
63
+ text_encoder_state_dict_from_tf = populate_text_encoder(tf_text_encoder)
64
+ unet_state_dict_from_tf = populate_unet(tf_unet)
65
+ print("Conversion done, now running assertions...")
66
+
67
+ # Since we cannot compare the fine-tuned weights.
68
+ if text_encoder_weights is None:
69
+ text_encoder_state_dict_from_pt = pt_text_encoder.state_dict()
70
+ run_assertion(text_encoder_state_dict_from_pt, text_encoder_state_dict_from_tf)
71
+ if unet_weights is None:
72
+ unet_state_dict_from_pt = pt_text_encoder.state_dict()
73
+ run_assertion(unet_state_dict_from_pt, unet_state_dict_from_tf)
74
+
75
+ print("Assertions successful, populating the converted parameters into the diffusers models...")
76
+ pt_text_encoder.load_state_dict(text_encoder_state_dict_from_tf)
77
+ pt_unet.load_state_dict(unet_state_dict_from_tf)
78
+
79
+ print("Parameters ported, preparing StabelDiffusionPipeline...")
80
+ pipeline = StableDiffusionPipeline.from_pretrained(
81
+ PRETRAINED_CKPT,
82
+ unet=pt_unet,
83
+ text_encoder=pt_text_encoder,
84
+ vae=pt_vae,
85
+ safety_checker=pt_safety_checker,
86
+ revision=None,
87
+ )
88
+ return pipeline
89
+
90
+
hub_utils/__init__.py ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ from .readme import save_model_card
2
+ from .repo import push_to_hub
hub_utils/readme.py ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
+
4
+ # Copied from https://github.com/huggingface/diffusers/blob/31be42209ddfdb69d9640a777b32e9b5c6259bf0/examples/text_to_image/train_text_to_image_lora.py#L55
5
+ def save_model_card(base_model=str, repo_folder=None, weight_paths=None):
6
+ yaml = f"""
7
+ ---
8
+ license: creativeml-openrail-m
9
+ base_model: {base_model}
10
+ tags:
11
+ - stable-diffusion
12
+ - stable-diffusion-diffusers
13
+ - text-to-image
14
+ - diffusers
15
+ inference: true
16
+ ---
17
+ """
18
+ model_card = f"""
19
+ # KerasCV Stable Diffusion in Diffusers 🧨🤗
20
+
21
+ The pipeline contained in this repository was created using [this Space](https://huggingface.co/spaces/sayakpaul/convert-kerascv-sd-diffusers). The purpose is to convert the KerasCV Stable Diffusion weights in a way that is compatible with Diffusers. This allows users to fine-tune using KerasCV and use the fine-tuned weights in Diffusers taking advantage of its nifty features (like schedulers, fast attention, etc.).\n
22
+
23
+ """
24
+
25
+ if weight_paths is not None:
26
+ model_card += "Following weight paths (KerasCV) were used: {weight_paths}"
27
+
28
+ with open(os.path.join(repo_folder, "README.md"), "w") as f:
29
+ f.write(yaml + model_card)
hub_utils/repo.py ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from huggingface_hub import HfApi, create_repo
2
+
3
+ def push_to_hub(hf_token: str, push_dir: str, repo_prefix: None) -> str:
4
+ try:
5
+ if hf_token == "":
6
+ return "No HF token provided. Model won't be pushed."
7
+ else:
8
+ hf_api = HfApi(token=hf_token)
9
+ user = hf_api.whoami()["name"]
10
+ repo_id = f"{user}/{push_dir}" if repo_prefix == "" else f"{user}/{repo_prefix}-{push_dir}"
11
+ _ = create_repo(repo_id=repo_id, token=hf_token)
12
+ url = hf_api.upload_folder(folder_path=push_dir, repo_id=repo_id, exist_ok=True)
13
+ return f"Model successfully pushed: [{url}]({url})"
14
+ except Exception as e:
15
+ return f"{e}"
requirements.txt ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ transformers==4.25.1
2
+ numpy==1.21.6
3
+ torch==1.12.1
4
+ tensorflow==2.10.0
5
+ git+https://github.com/keras-team/keras-cv.git@master
6
+ git+https://github.com/huggingface/diffusers.git@main
7
+ tensorflow-datasets==4.8.0