Contrebande Labs commited on
Commit
b07346d
1 Parent(s): e0cb68e

put CPU offloading and half precision back

Browse files
Files changed (1) hide show
  1. app.py +18 -13
app.py CHANGED
@@ -18,7 +18,6 @@ from transformers import ByT5Tokenizer, FlaxT5ForConditionalGeneration
18
 
19
 
20
  def get_inference_lambda(seed):
21
-
22
  tokenizer = ByT5Tokenizer()
23
 
24
  language_model = FlaxT5ForConditionalGeneration.from_pretrained(
@@ -53,17 +52,17 @@ def get_inference_lambda(seed):
53
  }
54
  )
55
  timesteps = 20
56
- guidance_scale = jnp.array([7.5], dtype=jnp.float32)
57
 
58
  unet, unet_params = FlaxUNet2DConditionModel.from_pretrained(
59
  "character-aware-diffusion/charred",
60
- dtype=jnp.float32,
61
  )
62
 
63
  vae, vae_params = FlaxAutoencoderKL.from_pretrained(
64
  "flax/stable-diffusion-2-1",
65
  subfolder="vae",
66
- dtype=jnp.float32,
67
  )
68
  vae_scale_factor = 2 ** (len(vae.config.block_out_channels) - 1)
69
 
@@ -71,14 +70,13 @@ def get_inference_lambda(seed):
71
 
72
  # Generating latent shape
73
  latent_shape = (
74
- negative_prompt_text_encoder_hidden_states.shape[0],
75
  unet.in_channels,
76
  image_width // vae_scale_factor,
77
  image_height // vae_scale_factor,
78
  )
79
 
80
  def __tokenize_prompt(prompt: str):
81
-
82
  return tokenizer(
83
  text=prompt,
84
  max_length=1024,
@@ -91,20 +89,21 @@ def get_inference_lambda(seed):
91
  # create PIL image from JAX tensor converted to numpy
92
  return Image.fromarray(np.asarray(image), mode="RGB")
93
 
94
- def __predict_image(tokenized_prompt: jnp.array):
95
-
96
  # Get the text embedding
97
  text_encoder_hidden_states = text_encoder(
98
  tokenized_prompt,
99
  params=text_encoder_params,
100
  train=False,
101
  )[0]
102
- context = jnp.concatenate(
 
 
103
  [negative_prompt_text_encoder_hidden_states, text_encoder_hidden_states]
104
  )
105
 
 
106
  def ___timestep(step, step_args):
107
-
108
  latents, scheduler_state = step_args
109
 
110
  t = jnp.array(scheduler_state.timesteps, dtype=jnp.int32)[step]
@@ -153,7 +152,7 @@ def get_inference_lambda(seed):
153
  # initialize latents
154
  initial_latents = (
155
  jax.random.normal(
156
- jax.random.PRNGKey(seed), shape=latent_shape, dtype=jnp.float32
157
  )
158
  * initial_scheduler_state.init_noise_sigma
159
  )
@@ -175,10 +174,16 @@ def get_inference_lambda(seed):
175
  .astype(jnp.uint8)[0]
176
  )
177
 
178
- jax_jit_compiled_predict_image = jax.jit(__predict_image)
 
 
 
 
179
 
180
  return lambda prompt: __convert_image(
181
- jax_jit_compiled_predict_image(__tokenize_prompt(prompt))
 
 
182
  )
183
 
184
 
 
18
 
19
 
20
  def get_inference_lambda(seed):
 
21
  tokenizer = ByT5Tokenizer()
22
 
23
  language_model = FlaxT5ForConditionalGeneration.from_pretrained(
 
52
  }
53
  )
54
  timesteps = 20
55
+ guidance_scale = jnp.array([7.5], dtype=jnp.bfloat16)
56
 
57
  unet, unet_params = FlaxUNet2DConditionModel.from_pretrained(
58
  "character-aware-diffusion/charred",
59
+ dtype=jnp.bfloat16,
60
  )
61
 
62
  vae, vae_params = FlaxAutoencoderKL.from_pretrained(
63
  "flax/stable-diffusion-2-1",
64
  subfolder="vae",
65
+ dtype=jnp.bfloat16,
66
  )
67
  vae_scale_factor = 2 ** (len(vae.config.block_out_channels) - 1)
68
 
 
70
 
71
  # Generating latent shape
72
  latent_shape = (
73
+ negative_prompt_text_encoder_hidden_states.shape[0], # is th
74
  unet.in_channels,
75
  image_width // vae_scale_factor,
76
  image_height // vae_scale_factor,
77
  )
78
 
79
  def __tokenize_prompt(prompt: str):
 
80
  return tokenizer(
81
  text=prompt,
82
  max_length=1024,
 
89
  # create PIL image from JAX tensor converted to numpy
90
  return Image.fromarray(np.asarray(image), mode="RGB")
91
 
92
+ def __get_context(tokenized_prompt: jnp.array):
 
93
  # Get the text embedding
94
  text_encoder_hidden_states = text_encoder(
95
  tokenized_prompt,
96
  params=text_encoder_params,
97
  train=False,
98
  )[0]
99
+
100
+ # context = empty negative prompt embedding + prompt embedding
101
+ return jnp.concatenate(
102
  [negative_prompt_text_encoder_hidden_states, text_encoder_hidden_states]
103
  )
104
 
105
+ def __predict_image(context: jnp.array):
106
  def ___timestep(step, step_args):
 
107
  latents, scheduler_state = step_args
108
 
109
  t = jnp.array(scheduler_state.timesteps, dtype=jnp.int32)[step]
 
152
  # initialize latents
153
  initial_latents = (
154
  jax.random.normal(
155
+ jax.random.PRNGKey(seed), shape=latent_shape, dtype=jnp.bfloat16
156
  )
157
  * initial_scheduler_state.init_noise_sigma
158
  )
 
174
  .astype(jnp.uint8)[0]
175
  )
176
 
177
+ jax_jit_compiled_accel_predict_image = jax.jit(__predict_image)
178
+
179
+ jax_jit_compiled_cpu_get_context = jax.jit(
180
+ __get_context, device=jax.devices(backend="cpu")[0]
181
+ )
182
 
183
  return lambda prompt: __convert_image(
184
+ jax_jit_compiled_accel_predict_image(
185
+ jax_jit_compiled_cpu_get_context(__tokenize_prompt(prompt))
186
+ )
187
  )
188
 
189