ailm commited on
Commit
31a9876
1 Parent(s): f28730a

import spaces

Browse files
Files changed (1) hide show
  1. app.py +3 -1
app.py CHANGED
@@ -9,13 +9,15 @@ from torchvision.utils import save_image #to save the generated images
9
  from tqdm import tqdm # progress bar
10
  import matplotlib.pyplot as plt
11
  import gradio as gr
 
12
 
13
  from styleTransfer import style_transfer
14
  from dataTransform import tensor_to_image
15
 
16
- device = 'cuda' if torch.cuda.is_available() else 'cpu'
17
  print(device)
18
 
 
19
  def gradio_style_transfer(steps, content_image, style_image):
20
  generated_tensor = style_transfer(content_image, style_image, total_steps= steps)
21
  generated_image = tensor_to_image(generated_tensor)
 
9
  from tqdm import tqdm # progress bar
10
  import matplotlib.pyplot as plt
11
  import gradio as gr
12
+ import spaces
13
 
14
  from styleTransfer import style_transfer
15
  from dataTransform import tensor_to_image
16
 
17
+ device = 'cuda'
18
  print(device)
19
 
20
+ @spaces.GPU
21
  def gradio_style_transfer(steps, content_image, style_image):
22
  generated_tensor = style_transfer(content_image, style_image, total_steps= steps)
23
  generated_image = tensor_to_image(generated_tensor)