nupurkmr9 commited on
Commit
c16e3db
·
1 Parent(s): 7403db4

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +26 -31
app.py CHANGED
@@ -72,6 +72,8 @@ def create_training_demo(trainer: Trainer,
72
  concept_images = gr.Files(label='Images for your concept')
73
  concept_prompt = gr.Textbox(label='Concept Prompt',
74
  max_lines=1)
 
 
75
  gr.Markdown('''
76
  - Upload images of the style you are planning on training on.
77
  - For a concept prompt, use a unique, made up word to avoid collisions.
@@ -80,11 +82,13 @@ def create_training_demo(trainer: Trainer,
80
  gr.Markdown('Training Parameters')
81
  num_training_steps = gr.Number(
82
  label='Number of Training Steps', value=1000, precision=0)
83
- learning_rate = gr.Number(label='Learning Rate', value=0.0001)
84
  train_text_encoder = gr.Checkbox(label='Train Text Encoder',
85
  value=True)
 
 
86
  learning_rate_text = gr.Number(
87
- label='Learning Rate for Text Encoder', value=0.00005)
88
  gradient_accumulation = gr.Number(
89
  label='Number of Gradient Accumulation',
90
  value=1,
@@ -145,7 +149,7 @@ def find_weight_files() -> list[str]:
145
  return [path.relative_to(curr_dir).as_posix() for path in paths]
146
 
147
 
148
- def reload_lora_weight_list() -> dict:
149
  return gr.update(choices=find_weight_files())
150
 
151
 
@@ -159,23 +163,13 @@ def create_inference_demo(pipe: InferencePipeline) -> gr.Blocks:
159
  label='Base Model',
160
  visible=False)
161
  reload_button = gr.Button('Reload Weight List')
162
- lora_weight_name = gr.Dropdown(choices=find_weight_files(),
163
- value='lora/lora_disney.pt',
164
- label='LoRA Weight File')
165
  prompt = gr.Textbox(
166
  label='Prompt',
167
  max_lines=1,
168
- placeholder='Example: "style of sks, baby lion"')
169
- alpha = gr.Slider(label='Alpha',
170
- minimum=0,
171
- maximum=2,
172
- step=0.05,
173
- value=1)
174
- alpha_for_text = gr.Slider(label='Alpha for Text Encoder',
175
- minimum=0,
176
- maximum=2,
177
- step=0.05,
178
- value=1)
179
  seed = gr.Slider(label='Seed',
180
  minimum=0,
181
  maximum=100000,
@@ -184,52 +178,53 @@ def create_inference_demo(pipe: InferencePipeline) -> gr.Blocks:
184
  with gr.Accordion('Other Parameters', open=False):
185
  num_steps = gr.Slider(label='Number of Steps',
186
  minimum=0,
187
- maximum=100,
188
  step=1,
189
- value=50)
190
  guidance_scale = gr.Slider(label='CFG Scale',
191
  minimum=0,
192
  maximum=50,
193
  step=0.1,
194
- value=7)
 
 
 
 
 
195
 
196
  run_button = gr.Button('Generate')
197
 
198
  gr.Markdown('''
199
- - Models with names starting with "lora/" are the pretrained models provided in the [original repo](https://github.com/cloneofsimo/lora), and the ones with names starting with "results/" are your trained models.
200
  - After training, you can press "Reload Weight List" button to load your trained model names.
201
- - The pretrained models for "disney", "illust" and "pop" are trained with the concept prompt "style of sks".
202
- - The pretrained model for "kiriko" is trained with the concept prompt "game character bnha". For this model, the text encoder is also trained.
203
  ''')
204
  with gr.Column():
205
  result = gr.Image(label='Result')
206
 
207
- reload_button.click(fn=reload_lora_weight_list,
208
  inputs=None,
209
- outputs=lora_weight_name)
210
  prompt.submit(fn=pipe.run,
211
  inputs=[
212
  base_model,
213
- lora_weight_name,
214
  prompt,
215
- alpha,
216
- alpha_for_text,
217
  seed,
218
  num_steps,
219
  guidance_scale,
 
220
  ],
221
  outputs=result,
222
  queue=False)
223
  run_button.click(fn=pipe.run,
224
  inputs=[
225
  base_model,
226
- lora_weight_name,
227
  prompt,
228
- alpha,
229
- alpha_for_text,
230
  seed,
231
  num_steps,
232
  guidance_scale,
 
233
  ],
234
  outputs=result,
235
  queue=False)
 
72
  concept_images = gr.Files(label='Images for your concept')
73
  concept_prompt = gr.Textbox(label='Concept Prompt',
74
  max_lines=1)
75
+ class_prompt = gr.Textbox(label='Regularization set Prompt',
76
+ max_lines=1)
77
  gr.Markdown('''
78
  - Upload images of the style you are planning on training on.
79
  - For a concept prompt, use a unique, made up word to avoid collisions.
 
82
  gr.Markdown('Training Parameters')
83
  num_training_steps = gr.Number(
84
  label='Number of Training Steps', value=1000, precision=0)
85
+ learning_rate = gr.Number(label='Learning Rate', value=0.00001)
86
  train_text_encoder = gr.Checkbox(label='Train Text Encoder',
87
  value=True)
88
+ modifier_token = gr.Checkbox(label='modifier token',
89
+ value=True)
90
  learning_rate_text = gr.Number(
91
+ label='Learning Rate for Text Encoder', value=0.00001)
92
  gradient_accumulation = gr.Number(
93
  label='Number of Gradient Accumulation',
94
  value=1,
 
149
  return [path.relative_to(curr_dir).as_posix() for path in paths]
150
 
151
 
152
+ def reload_custom_diffusion_weight_list() -> dict:
153
  return gr.update(choices=find_weight_files())
154
 
155
 
 
163
  label='Base Model',
164
  visible=False)
165
  reload_button = gr.Button('Reload Weight List')
166
+ weight_name = gr.Dropdown(choices=find_weight_files(),
167
+ value='custom-diffusion/cat.ckpt',
168
+ label='Custom Diffusion Weight File')
169
  prompt = gr.Textbox(
170
  label='Prompt',
171
  max_lines=1,
172
+ placeholder='Example: "<new1> cat swimming in a pool"')
 
 
 
 
 
 
 
 
 
 
173
  seed = gr.Slider(label='Seed',
174
  minimum=0,
175
  maximum=100000,
 
178
  with gr.Accordion('Other Parameters', open=False):
179
  num_steps = gr.Slider(label='Number of Steps',
180
  minimum=0,
181
+ maximum=500,
182
  step=1,
183
+ value=200)
184
  guidance_scale = gr.Slider(label='CFG Scale',
185
  minimum=0,
186
  maximum=50,
187
  step=0.1,
188
+ value=6
189
+ eta = gr.Slider(label='CFG Scale',
190
+ minimum=0,
191
+ maximum=1.,
192
+ step=0.1,
193
+ value=1.)
194
 
195
  run_button = gr.Button('Generate')
196
 
197
  gr.Markdown('''
198
+ - Models with names starting with "custom-diffusion/" are the pretrained models provided in the [original repo](https://github.com/adobe-research/custom-diffusion), and the ones with names starting with "results/" are your trained models.
199
  - After training, you can press "Reload Weight List" button to load your trained model names.
 
 
200
  ''')
201
  with gr.Column():
202
  result = gr.Image(label='Result')
203
 
204
+ reload_button.click(fn=reload_custom_diffusion_weight_list,
205
  inputs=None,
206
+ outputs=weight_name)
207
  prompt.submit(fn=pipe.run,
208
  inputs=[
209
  base_model,
210
+ weight_name,
211
  prompt,
 
 
212
  seed,
213
  num_steps,
214
  guidance_scale,
215
+ eta,
216
  ],
217
  outputs=result,
218
  queue=False)
219
  run_button.click(fn=pipe.run,
220
  inputs=[
221
  base_model,
222
+ weight_name,
223
  prompt,
 
 
224
  seed,
225
  num_steps,
226
  guidance_scale,
227
+ eta,
228
  ],
229
  outputs=result,
230
  queue=False)