KingNish committed
Commit c82ffd4
1 Parent(s): 66892aa

Update app.py

Files changed (1)
  1. app.py +3 -2
app.py CHANGED
@@ -4,7 +4,7 @@ import random
 import spaces
 import torch
 import time
-from diffusers import DiffusionPipeline
+from diffusers import DiffusionPipeline, AutoencoderTiny
 from custom_pipeline import FLUXPipelineWithIntermediateOutputs
 
 # Constants
@@ -19,6 +19,7 @@ dtype = torch.float16
 pipe = FLUXPipelineWithIntermediateOutputs.from_pretrained(
     "black-forest-labs/FLUX.1-schnell", torch_dtype=dtype
 ).to("cuda")
+pipe.vae = AutoencoderTiny.from_pretrained("madebyollin/taef1", torch_dtype=torch.float16)
 torch.cuda.empty_cache()
 
 # Inference function
@@ -26,7 +27,7 @@ torch.cuda.empty_cache()
 def generate_image(prompt, seed=42, width=DEFAULT_WIDTH, height=DEFAULT_HEIGHT, randomize_seed=False, num_inference_steps=2, progress=gr.Progress(track_tqdm=True)):
     if randomize_seed:
         seed = random.randint(0, MAX_SEED)
-    generator = torch.Generator(device="cuda").manual_seed(int(float(seed)))
+    generator = torch.Generator().manual_seed(int(float(seed)))
 
     start_time = time.time()
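For reference, a minimal standalone sketch of the setup the diff arrives at (assuming the FLUXPipelineWithIntermediateOutputs class in this repo's custom_pipeline.py and the standard diffusers AutoencoderTiny API; the seed value 42 here is just an illustrative default):

import torch
from diffusers import AutoencoderTiny
from custom_pipeline import FLUXPipelineWithIntermediateOutputs

# Load FLUX.1-schnell in fp16 and move it to the GPU
pipe = FLUXPipelineWithIntermediateOutputs.from_pretrained(
    "black-forest-labs/FLUX.1-schnell", torch_dtype=torch.float16
).to("cuda")

# Swap in the tiny taef1 autoencoder for faster latent decoding
pipe.vae = AutoencoderTiny.from_pretrained("madebyollin/taef1", torch_dtype=torch.float16)

# Seed a CPU generator for reproducible sampling (illustrative seed)
generator = torch.Generator().manual_seed(42)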