taki0112 commited on
Commit
95aa48f
·
1 Parent(s): 0d4c732
Files changed (1) hide show
  1. app.py +4 -1
app.py CHANGED
@@ -13,6 +13,7 @@ from random import randint
13
  from utils import init_latent
14
  from transformers import Blip2Processor, Blip2ForConditionalGeneration
15
  from diffusers import DDIMScheduler
 
16
 
17
  device = 'cuda' if torch.cuda.is_available() else 'cpu'
18
  if device == 'cpu':
@@ -147,6 +148,7 @@ def load_example_style():
147
 
148
  return examples
149
 
 
150
  def style_fn(image_path, style_name, content_text, output_number, diffusion_step=50):
151
  user_image_flag = not style_name.strip() # empty
152
 
@@ -219,7 +221,7 @@ def style_fn(image_path, style_name, content_text, output_number, diffusion_step
219
  use_advanced_sampling = False
220
  use_prompt_as_null = True
221
 
222
- ref_prompt = blip_inf_prompt(origin_real_img).to(device)
223
  inf_prompt = inf_object
224
  style_description_neg = None
225
 
@@ -275,6 +277,7 @@ def style_fn(image_path, style_name, content_text, output_number, diffusion_step
275
 
276
  return grid
277
 
 
278
  def controlnet_fn(image_path, depth_image_path, style_name, content_text, output_number, controlnet_scale=0.5, diffusion_step=50):
279
  config_path = './config/{}.json'.format(style_name)
280
  config = parse_config(config_path)
 
13
  from utils import init_latent
14
  from transformers import Blip2Processor, Blip2ForConditionalGeneration
15
  from diffusers import DDIMScheduler
16
+ import spaces
17
 
18
  device = 'cuda' if torch.cuda.is_available() else 'cpu'
19
  if device == 'cpu':
 
148
 
149
  return examples
150
 
151
+ @spaces.GPU
152
  def style_fn(image_path, style_name, content_text, output_number, diffusion_step=50):
153
  user_image_flag = not style_name.strip() # empty
154
 
 
221
  use_advanced_sampling = False
222
  use_prompt_as_null = True
223
 
224
+ ref_prompt = blip_inf_prompt(origin_real_img)
225
  inf_prompt = inf_object
226
  style_description_neg = None
227
 
 
277
 
278
  return grid
279
 
280
+ @spaces.GPU
281
  def controlnet_fn(image_path, depth_image_path, style_name, content_text, output_number, controlnet_scale=0.5, diffusion_step=50):
282
  config_path = './config/{}.json'.format(style_name)
283
  config = parse_config(config_path)