owiedotch commited on
Commit
de38782
·
verified ·
1 Parent(s): 59593f5

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +3 -1
app.py CHANGED
@@ -12,6 +12,7 @@ import einops
12
  import math
13
  import random
14
  import pytorch_lightning as pl
 
15
 
16
  def download_file(url, filename):
17
  response = requests.get(url, stream=True)
@@ -56,7 +57,7 @@ from utils.image import auto_resize
56
 
57
  config = OmegaConf.load("configs/model/ccsr_stage2.yaml")
58
  model = instantiate_from_config(config)
59
- ckpt = torch.load("weights/real-world_ccsr.ckpt", map_location="cpu")
60
  load_state_dict(model, ckpt, strict=True)
61
  model.freeze()
62
 
@@ -65,6 +66,7 @@ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
65
  model.to(device)
66
 
67
  @torch.no_grad()
 
68
  def process(
69
  control_img: Image.Image,
70
  num_samples: int,
 
12
  import math
13
  import random
14
  import pytorch_lightning as pl
15
+ import spaces
16
 
17
  def download_file(url, filename):
18
  response = requests.get(url, stream=True)
 
57
 
58
  config = OmegaConf.load("configs/model/ccsr_stage2.yaml")
59
  model = instantiate_from_config(config)
60
+ ckpt = torch.load("weights/real-world_ccsr.ckpt", map_location="cpu", weights_only=True)
61
  load_state_dict(model, ckpt, strict=True)
62
  model.freeze()
63
 
 
66
  model.to(device)
67
 
68
  @torch.no_grad()
69
+ @spaces.GPU
70
  def process(
71
  control_img: Image.Image,
72
  num_samples: int,