sayakpaul HF staff commited on
Commit
89a6b3b
·
1 Parent(s): 7081a39

add support for textual conversion pipelines.

Browse files
Files changed (3) hide show
  1. app.py +10 -2
  2. convert.py +27 -5
  3. hub_utils/readme.py +9 -1
app.py CHANGED
@@ -19,6 +19,7 @@ This Space lets you convert KerasCV Stable Diffusion weights to a format compati
19
  * [Traditional text2image fine-tuning](https://github.com/huggingface/diffusers/blob/main/examples/text_to_image): UNet
20
 
21
  **In case none of the `text_encoder_weights` and `unet_weights` is provided, nothing will be done.**
 
22
  * When providing the weights' links, ensure they're directly downloadable. Internally, the Space uses [`tf.keras.utils.get_file()`](https://www.tensorflow.org/api_docs/python/tf/keras/utils/get_file) to retrieve the weights locally.
23
  * If you don't provide `your_hf_token` the converted pipeline won't be pushed.
24
 
@@ -26,7 +27,7 @@ Check [here](https://github.com/huggingface/diffusers/blob/31be42209ddfdb69d9640
26
  """
27
 
28
 
29
- def run(hf_token, text_encoder_weights, unet_weights, repo_prefix):
30
  if text_encoder_weights == "":
31
  text_encoder_weights = None
32
  if unet_weights == "":
@@ -35,7 +36,12 @@ def run(hf_token, text_encoder_weights, unet_weights, repo_prefix):
35
  if text_encoder_weights is None and unet_weights is None:
36
  return "❌ No fine-tuned weights provided, nothing to do."
37
 
38
- pipeline = run_conversion(text_encoder_weights, unet_weights)
 
 
 
 
 
39
  output_path = "kerascv_sd_diffusers_pipeline"
40
  pipeline.save_pretrained(output_path)
41
 
@@ -48,6 +54,7 @@ def run(hf_token, text_encoder_weights, unet_weights, repo_prefix):
48
  base_model=PRETRAINED_CKPT,
49
  repo_folder=output_path,
50
  weight_paths=weight_paths,
 
51
  )
52
  push_str = push_to_hub(hf_token, output_path, repo_prefix)
53
  return push_str
@@ -61,6 +68,7 @@ demo = gr.Interface(
61
  gr.Text(max_lines=1, label="your_hf_token"),
62
  gr.Text(max_lines=1, label="text_encoder_weights"),
63
  gr.Text(max_lines=1, label="unet_weights"),
 
64
  gr.Text(max_lines=1, label="output_repo_prefix"),
65
  ],
66
  outputs=[gr.Markdown(label="output")],
 
19
  * [Traditional text2image fine-tuning](https://github.com/huggingface/diffusers/blob/main/examples/text_to_image): UNet
20
 
21
  **In case none of the `text_encoder_weights` and `unet_weights` is provided, nothing will be done.**
22
+ * For Textual Inversion, you MUST provide a valid `placeholder_token` i.e., the text concept used for conducting Textual Inversion.
23
  * When providing the weights' links, ensure they're directly downloadable. Internally, the Space uses [`tf.keras.utils.get_file()`](https://www.tensorflow.org/api_docs/python/tf/keras/utils/get_file) to retrieve the weights locally.
24
  * If you don't provide `your_hf_token` the converted pipeline won't be pushed.
25
 
 
27
  """
28
 
29
 
30
+ def run(hf_token, text_encoder_weights, unet_weights, placeholder_token, repo_prefix):
31
  if text_encoder_weights == "":
32
  text_encoder_weights = None
33
  if unet_weights == "":
 
36
  if text_encoder_weights is None and unet_weights is None:
37
  return "❌ No fine-tuned weights provided, nothing to do."
38
 
39
+ if placeholder_token == "":
40
+ placeholder_token = None
41
+ if placeholder_token is not None and text_encoder_weights is None:
42
+ return "❌ Placeholder token provided but no text encoder weights were provided. Cannot proceed."
43
+
44
+ pipeline = run_conversion(text_encoder_weights, unet_weights, placeholder_token)
45
  output_path = "kerascv_sd_diffusers_pipeline"
46
  pipeline.save_pretrained(output_path)
47
 
 
54
  base_model=PRETRAINED_CKPT,
55
  repo_folder=output_path,
56
  weight_paths=weight_paths,
57
+ placeholder_token=placeholder_token,
58
  )
59
  push_str = push_to_hub(hf_token, output_path, repo_prefix)
60
  return push_str
 
68
  gr.Text(max_lines=1, label="your_hf_token"),
69
  gr.Text(max_lines=1, label="text_encoder_weights"),
70
  gr.Text(max_lines=1, label="unet_weights"),
71
+ gr.Text(max_lines=1, label="placeholder_token"),
72
  gr.Text(max_lines=1, label="output_repo_prefix"),
73
  ],
74
  outputs=[gr.Markdown(label="output")],
convert.py CHANGED
@@ -4,7 +4,7 @@ from diffusers import (AutoencoderKL, StableDiffusionPipeline,
4
  UNet2DConditionModel)
5
  from diffusers.pipelines.stable_diffusion.safety_checker import \
6
  StableDiffusionSafetyChecker
7
- from transformers import CLIPTextModel
8
 
9
  from conversion_utils import populate_text_encoder, populate_unet
10
 
@@ -21,6 +21,7 @@ def initialize_pt_models():
21
  pt_text_encoder = CLIPTextModel.from_pretrained(
22
  PRETRAINED_CKPT, subfolder="text_encoder", revision=REVISION
23
  )
 
24
  pt_vae = AutoencoderKL.from_pretrained(
25
  PRETRAINED_CKPT, subfolder="vae", revision=REVISION
26
  )
@@ -31,7 +32,7 @@ def initialize_pt_models():
31
  PRETRAINED_CKPT, subfolder="safety_checker", revision=NON_EMA_REVISION
32
  )
33
 
34
- return pt_text_encoder, pt_vae, pt_unet, pt_safety_checker
35
 
36
 
37
  def initialize_tf_models(text_encoder_weights: str, unet_weights: str):
@@ -56,8 +57,18 @@ def initialize_tf_models(text_encoder_weights: str, unet_weights: str):
56
  return tf_sd_model, tf_text_encoder, tf_vae, tf_unet
57
 
58
 
59
- def run_conversion(text_encoder_weights: str = None, unet_weights: str = None):
60
- pt_text_encoder, pt_vae, pt_unet, pt_safety_checker = initialize_pt_models()
 
 
 
 
 
 
 
 
 
 
61
  tf_sd_model, tf_text_encoder, tf_vae, tf_unet = initialize_tf_models(
62
  text_encoder_weights, unet_weights
63
  )
@@ -70,19 +81,30 @@ def run_conversion(text_encoder_weights: str = None, unet_weights: str = None):
70
  text_encoder_state_dict_from_tf = populate_text_encoder(tf_text_encoder)
71
  pt_text_encoder.load_state_dict(text_encoder_state_dict_from_tf)
72
  print("Populated PT text encoder from TF weights.")
 
73
  if unet_weights is not None:
74
  print("Loading fine-tuned UNet weights.")
75
  unet_weights_path = tf.keras.utils.get_file(origin=unet_weights)
76
  tf_unet.load_weights(unet_weights_path)
77
  unet_state_dict_from_tf = populate_unet(tf_unet)
78
  pt_unet.load_state_dict(unet_state_dict_from_tf)
79
- print("Populated PT UNet from TF weights.")
 
 
 
 
 
 
 
 
 
80
 
81
  print("Weights ported, preparing StabelDiffusionPipeline...")
82
  pipeline = StableDiffusionPipeline.from_pretrained(
83
  PRETRAINED_CKPT,
84
  unet=pt_unet,
85
  text_encoder=pt_text_encoder,
 
86
  vae=pt_vae,
87
  safety_checker=pt_safety_checker,
88
  revision=None,
 
4
  UNet2DConditionModel)
5
  from diffusers.pipelines.stable_diffusion.safety_checker import \
6
  StableDiffusionSafetyChecker
7
+ from transformers import CLIPTextModel, CLIPTokenizer
8
 
9
  from conversion_utils import populate_text_encoder, populate_unet
10
 
 
21
  pt_text_encoder = CLIPTextModel.from_pretrained(
22
  PRETRAINED_CKPT, subfolder="text_encoder", revision=REVISION
23
  )
24
+ pt_tokenizer = CLIPTokenizer.from_pretrained(PRETRAINED_CKPT, subfolder="tokenizer")
25
  pt_vae = AutoencoderKL.from_pretrained(
26
  PRETRAINED_CKPT, subfolder="vae", revision=REVISION
27
  )
 
32
  PRETRAINED_CKPT, subfolder="safety_checker", revision=NON_EMA_REVISION
33
  )
34
 
35
+ return pt_text_encoder, pt_tokenizer, pt_vae, pt_unet, pt_safety_checker
36
 
37
 
38
  def initialize_tf_models(text_encoder_weights: str, unet_weights: str):
 
57
  return tf_sd_model, tf_text_encoder, tf_vae, tf_unet
58
 
59
 
60
+ def run_conversion(
61
+ text_encoder_weights: str = None,
62
+ unet_weights: str = None,
63
+ placeholder_token: str = None,
64
+ ):
65
+ (
66
+ pt_text_encoder,
67
+ pt_tokenizer,
68
+ pt_vae,
69
+ pt_unet,
70
+ pt_safety_checker,
71
+ ) = initialize_pt_models()
72
  tf_sd_model, tf_text_encoder, tf_vae, tf_unet = initialize_tf_models(
73
  text_encoder_weights, unet_weights
74
  )
 
81
  text_encoder_state_dict_from_tf = populate_text_encoder(tf_text_encoder)
82
  pt_text_encoder.load_state_dict(text_encoder_state_dict_from_tf)
83
  print("Populated PT text encoder from TF weights.")
84
+
85
  if unet_weights is not None:
86
  print("Loading fine-tuned UNet weights.")
87
  unet_weights_path = tf.keras.utils.get_file(origin=unet_weights)
88
  tf_unet.load_weights(unet_weights_path)
89
  unet_state_dict_from_tf = populate_unet(tf_unet)
90
  pt_unet.load_state_dict(unet_state_dict_from_tf)
91
+ print("Populated PT UNet from TF weights.")
92
+
93
+ if placeholder_token is not None:
94
+ print("Adding the placeholder_token to CLIPTokenizer...")
95
+ num_added_tokens = pt_tokenizer.add_tokens(placeholder_token)
96
+ if num_added_tokens == 0:
97
+ raise ValueError(
98
+ f"The tokenizer already contains the token {placeholder_token}. Please pass a different"
99
+ " `placeholder_token` that is not already in the tokenizer."
100
+ )
101
 
102
  print("Weights ported, preparing StabelDiffusionPipeline...")
103
  pipeline = StableDiffusionPipeline.from_pretrained(
104
  PRETRAINED_CKPT,
105
  unet=pt_unet,
106
  text_encoder=pt_text_encoder,
107
+ tokenizer=pt_tokenizer,
108
  vae=pt_vae,
109
  safety_checker=pt_safety_checker,
110
  revision=None,
hub_utils/readme.py CHANGED
@@ -3,7 +3,12 @@ from typing import List
3
 
4
 
5
  # Copied from https://github.com/huggingface/diffusers/blob/31be42209ddfdb69d9640a777b32e9b5c6259bf0/examples/text_to_image/train_text_to_image_lora.py#L55
6
- def save_model_card(base_model=str, repo_folder=None, weight_paths: List = None):
 
 
 
 
 
7
  yaml = f"""
8
  ---
9
  license: creativeml-openrail-m
@@ -26,5 +31,8 @@ The pipeline contained in this repository was created using [this Space](https:/
26
  if len(weight_paths) > 0:
27
  model_card += f"Following weight paths (KerasCV) were used \n: {weight_paths}"
28
 
 
 
 
29
  with open(os.path.join(repo_folder, "README.md"), "w") as f:
30
  f.write(yaml + model_card)
 
3
 
4
 
5
  # Copied from https://github.com/huggingface/diffusers/blob/31be42209ddfdb69d9640a777b32e9b5c6259bf0/examples/text_to_image/train_text_to_image_lora.py#L55
6
+ def save_model_card(
7
+ base_model=str,
8
+ repo_folder=None,
9
+ weight_paths: List = None,
10
+ placeholder_token: str = None,
11
+ ):
12
  yaml = f"""
13
  ---
14
  license: creativeml-openrail-m
 
31
  if len(weight_paths) > 0:
32
  model_card += f"Following weight paths (KerasCV) were used \n: {weight_paths}"
33
 
34
+ if placeholder_token is not None:
35
+ model_card += "\nFollowing `placeholder_token` was added to the tokenizer: {placeholder_token}."
36
+
37
  with open(os.path.join(repo_folder, "README.md"), "w") as f:
38
  f.write(yaml + model_card)