Ziqi committed
Commit ac3a4de
1 Parent(s): 76e546c
Files changed (2):
  1. app.py +2 -4
  2. inference.py +1 -1
app.py CHANGED
@@ -101,7 +101,7 @@ def reload_custom_diffusion_weight_list() -> dict:
     return gr.update(choices=find_weight_files())
 
 
-def create_inference_demo(func: inference_fn, device) -> gr.Blocks:
+def create_inference_demo(func: inference_fn) -> gr.Blocks:
     with gr.Blocks() as demo:
         with gr.Row():
             with gr.Column():
@@ -151,7 +151,6 @@ def create_inference_demo(func: inference_fn, device) -> gr.Blocks:
                     prompt,
                     num_samples,
                     guidance_scale,
-                    device
                 ],
                 outputs=result,
                 queue=False)
@@ -161,7 +160,6 @@ def create_inference_demo(func: inference_fn, device) -> gr.Blocks:
                     prompt,
                     num_samples,
                     guidance_scale,
-                    device
                 ],
                 outputs=result,
                 queue=False)
@@ -184,7 +182,7 @@ with gr.Blocks(css='style.css') as demo:
 
     with gr.Tabs():
         with gr.TabItem('Test'):
-            create_inference_demo(inference_fn, device=args.device)
+            create_inference_demo(inference_fn)
 
     demo.launch(
         enable_queue=args.enable_queue,
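
The net effect of the app.py change is that the Gradio click handlers stop forwarding a device component to inference_fn; the handler inputs are now just prompt, num_samples, and guidance_scale. Below is a minimal sketch of the resulting wiring, assuming an older Gradio 3.x API (the file still passes enable_queue to demo.launch). The widget constructors and the run_button name are illustrative assumptions; only prompt, num_samples, guidance_scale, and result appear in the diff.

# Sketch only, not the repository's exact code. Widget types, labels, and
# `run_button` are assumptions; only the input list
# [prompt, num_samples, guidance_scale] and `outputs=result` follow the diff.
import gradio as gr

def create_inference_demo(func) -> gr.Blocks:
    with gr.Blocks() as demo:
        with gr.Row():
            with gr.Column():
                prompt = gr.Textbox(label='Prompt')
                num_samples = gr.Slider(1, 4, value=1, step=1,
                                        label='Number of samples')
                guidance_scale = gr.Slider(0, 20, value=7.5,
                                           label='Guidance scale')
                run_button = gr.Button('Generate')  # assumed name
            with gr.Column():
                result = gr.Gallery(label='Result')
        # `device` is no longer part of the inputs; inference_fn resolves it itself.
        run_button.click(fn=func,
                         inputs=[prompt, num_samples, guidance_scale],
                         outputs=result,
                         queue=False)
    return demo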
inference.py CHANGED
@@ -46,10 +46,10 @@ def inference_fn(
         prompt,
         num_samples,
         guidance_scale,
-        device
 ):
 
     # create inference pipeline
+    device = 'cuda' if torch.cuda.is_available() else 'cpu'
     pipe = StableDiffusionPipeline.from_pretrained(model_id,torch_dtype=torch.float16).to(device)
 
     # make directory to save images
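
On the inference side, device is now derived inside inference_fn from torch.cuda.is_available() rather than arriving as an argument. A hedged sketch of the surrounding function is below; model_id as a parameter and the call/return shape are assumptions, since the hunk only shows the parameter list, the new device line, and the pipeline construction. Note that the pipeline is still loaded in torch.float16, which can be slow or unsupported for some ops if the auto-detected device falls back to CPU.

# Sketch under stated assumptions: `model_id` as a parameter and the
# call/return handling are illustrative, not taken from the hunk.
import torch
from diffusers import StableDiffusionPipeline

def inference_fn(model_id,
                 prompt,
                 num_samples,
                 guidance_scale):
    # device is resolved here now, instead of being passed in from the Gradio UI
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    pipe = StableDiffusionPipeline.from_pretrained(
        model_id, torch_dtype=torch.float16).to(device)
    out = pipe(prompt,
               num_images_per_prompt=num_samples,
               guidance_scale=guidance_scale)
    return out.images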