Eyalgut committed
Commit 5bf77c5 · verified · 1 parent: fa999ed

Update app.py

Files changed (1): app.py (+8 -8)
app.py CHANGED
@@ -28,14 +28,13 @@ pipe.load_lora_weights(f'{pipeline_path}/pytorch_lora_weights.safetensors')
 pipe.fuse_lora()
 pipe.unload_lora_weights()
 pipe.force_zeros_for_empty_prompt = False
+pipe = EllaXLPipeline(pipe,f'{pipeline_path}/pytorch_model.bin')
 
-def get_pipe():
-    global pipe
-    if type(pipe)!=EllaXLPipeline:
-        pipe.to("cuda")
-        pipe = EllaXLPipeline(pipe,f'{pipeline_path}/pytorch_model.bin')
-
-    return pipe
+def tocuda():
+    pipe.pipe.vae.to('cuda')
+    pipe.t5_encoder.to('cuda')
+    pipe.pipe.unet.unet.to('cuda')
+    pipe.pipe.unet.ella.to('cuda')
 
 
 # print("Optimizing BRIA-2.4 - this could take a while")
@@ -62,7 +61,8 @@ def get_pipe():
 @spaces.GPU(enable_queue=True)
 def infer(prompt,negative_prompt,seed,resolution, steps):
 
-    pipe = get_pipe()
+    if 'cuda' not in pipe.pipe.device.type:
+        tocuda()
 
     print(f"""
     —/n
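
For context, here is a minimal sketch of the pattern this commit switches to: the pipeline is wrapped in EllaXLPipeline once at import time (still on CPU), and the weights are moved to CUDA lazily, on the first call into the `@spaces.GPU`-decorated handler. The `DummyPipeline` class and its submodules below are hypothetical stand-ins so the sketch is self-contained and runnable; the Space itself moves `pipe.pipe.vae`, `pipe.t5_encoder`, `pipe.pipe.unet.unet`, and `pipe.pipe.unet.ella`.

```python
import torch
import torch.nn as nn

# Hypothetical stand-in for the Space's EllaXLPipeline wrapper, so the
# sketch runs without the BRIA/ELLA weights.
class DummyPipeline:
    def __init__(self):
        self.vae = nn.Linear(8, 8)         # stands in for pipe.pipe.vae
        self.t5_encoder = nn.Linear(8, 8)  # stands in for pipe.t5_encoder
        self.unet = nn.Linear(8, 8)        # stands in for the UNet + ELLA adapter

    @property
    def device(self):
        # Mirror pipe.pipe.device: report where the weights currently live.
        return next(self.vae.parameters()).device

# Built once at import time, on CPU -- like the module-level `pipe` in app.py.
pipe = DummyPipeline()

def tocuda():
    # Move each submodule explicitly, as the commit does, rather than one
    # pipeline-wide .to("cuda").
    pipe.vae.to('cuda')
    pipe.t5_encoder.to('cuda')
    pipe.unet.to('cuda')

def infer():
    # Lazy placement: move weights only the first time the handler runs.
    # The availability guard is an addition for local testing; app.py
    # assumes CUDA is present inside the decorated function.
    if torch.cuda.is_available() and 'cuda' not in pipe.device.type:
        tocuda()
    # ... actual generation would happen here ...
```

On ZeroGPU Spaces, CUDA is only usable inside functions decorated with `@spaces.GPU`, which is likely why the `.to('cuda')` calls are deferred into `infer` rather than run at import time. The removed `get_pipe()` achieved a similar effect, but it rebound the module-level `pipe` to an `EllaXLPipeline` on the first call; after this commit the wrapper is constructed once at load time and only the device placement remains lazy.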