Ziqi commited on
Commit
1b0d680
1 Parent(s): 684fb8f

edit device; edit dropdown

Browse files
Files changed (2) hide show
  1. app.py +10 -8
  2. inference.py +4 -3
app.py CHANGED
@@ -1,11 +1,11 @@
1
  #!/usr/bin/env python
2
  """Demo app for https://github.com/ziqihuangg/ReVersion.
3
  The code in this repo is partly adapted from the following repository:
4
- https://huggingface.co/spaces/hysts/LoRA-SD-training
5
 
6
  S-Lab License 1.0
7
 
8
- Copyright 2022 S-Lab
9
  Redistribution and use for non-commercial purpose in source and binary forms, with or without modification, are permitted provided that the following conditions are met:
10
  1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.
11
  2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.
@@ -101,13 +101,13 @@ def reload_custom_diffusion_weight_list() -> dict:
101
  return gr.update(choices=find_weight_files())
102
 
103
 
104
- def create_inference_demo(func: inference_fn) -> gr.Blocks:
105
  with gr.Blocks() as demo:
106
  with gr.Row():
107
  with gr.Column():
108
  model_id = gr.Dropdown(
109
- choices=['experiments/painted_on'],
110
- value='experiments/painted_on',
111
  label='Relation',
112
  visible=True)
113
  # reload_button = gr.Button('Reload Weight List')
@@ -151,7 +151,8 @@ def create_inference_demo(func: inference_fn) -> gr.Blocks:
151
  prompt,
152
  placeholder_string,
153
  num_samples,
154
- guidance_scale
 
155
  ],
156
  outputs=result,
157
  queue=False)
@@ -161,7 +162,8 @@ def create_inference_demo(func: inference_fn) -> gr.Blocks:
161
  prompt,
162
  placeholder_string,
163
  num_samples,
164
- guidance_scale
 
165
  ],
166
  outputs=result,
167
  queue=False)
@@ -184,7 +186,7 @@ with gr.Blocks(css='style.css') as demo:
184
 
185
  with gr.Tabs():
186
  with gr.TabItem('Test'):
187
- create_inference_demo(inference_fn)
188
 
189
  demo.launch(
190
  enable_queue=args.enable_queue,
 
1
  #!/usr/bin/env python
2
  """Demo app for https://github.com/ziqihuangg/ReVersion.
3
  The code in this repo is partly adapted from the following repository:
4
+ https://huggingface.co/spaces/nupurkmr9/custom-diffusion
5
 
6
  S-Lab License 1.0
7
 
8
+ Copyright 2023 S-Lab
9
  Redistribution and use for non-commercial purpose in source and binary forms, with or without modification, are permitted provided that the following conditions are met:
10
  1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.
11
  2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.
 
101
  return gr.update(choices=find_weight_files())
102
 
103
 
104
+ def create_inference_demo(func: inference_fn, device=args.device) -> gr.Blocks:
105
  with gr.Blocks() as demo:
106
  with gr.Row():
107
  with gr.Column():
108
  model_id = gr.Dropdown(
109
+ choices=['painted_on'],
110
+ value='painted_on',
111
  label='Relation',
112
  visible=True)
113
  # reload_button = gr.Button('Reload Weight List')
 
151
  prompt,
152
  placeholder_string,
153
  num_samples,
154
+ guidance_scale,
155
+ device
156
  ],
157
  outputs=result,
158
  queue=False)
 
162
  prompt,
163
  placeholder_string,
164
  num_samples,
165
+ guidance_scale,
166
+ device
167
  ],
168
  outputs=result,
169
  queue=False)
 
186
 
187
  with gr.Tabs():
188
  with gr.TabItem('Test'):
189
+ create_inference_demo(inference_fn, args.device)
190
 
191
  demo.launch(
192
  enable_queue=args.enable_queue,
inference.py CHANGED
@@ -46,14 +46,15 @@ def inference_fn(
46
  prompt,
47
  placeholder_string,
48
  num_samples,
49
- guidance_scale
 
50
  ):
51
 
52
  # create inference pipeline
53
- pipe = StableDiffusionPipeline.from_pretrained(model_id,torch_dtype=torch.float16).to("cuda")
54
 
55
  # make directory to save images
56
- image_root_folder = os.path.join(model_id, 'inference')
57
  os.makedirs(image_root_folder, exist_ok = True)
58
 
59
  # if prompt is None and args.template_name is None:
 
46
  prompt,
47
  placeholder_string,
48
  num_samples,
49
+ guidance_scale,
50
+ device
51
  ):
52
 
53
  # create inference pipeline
54
+ pipe = StableDiffusionPipeline.from_pretrained(model_id,torch_dtype=torch.float16).to(device)
55
 
56
  # make directory to save images
57
+ image_root_folder = os.path.join('experiments', model_id, 'inference')
58
  os.makedirs(image_root_folder, exist_ok = True)
59
 
60
  # if prompt is None and args.template_name is None: