Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
@@ -12,6 +12,7 @@ import einops
|
|
12 |
import math
|
13 |
import random
|
14 |
import pytorch_lightning as pl
|
|
|
15 |
|
16 |
def download_file(url, filename):
|
17 |
response = requests.get(url, stream=True)
|
@@ -56,7 +57,7 @@ from utils.image import auto_resize
|
|
56 |
|
57 |
config = OmegaConf.load("configs/model/ccsr_stage2.yaml")
|
58 |
model = instantiate_from_config(config)
|
59 |
-
ckpt = torch.load("weights/real-world_ccsr.ckpt", map_location="cpu")
|
60 |
load_state_dict(model, ckpt, strict=True)
|
61 |
model.freeze()
|
62 |
|
@@ -65,6 +66,7 @@ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
|
65 |
model.to(device)
|
66 |
|
67 |
@torch.no_grad()
|
|
|
68 |
def process(
|
69 |
control_img: Image.Image,
|
70 |
num_samples: int,
|
|
|
12 |
import math
|
13 |
import random
|
14 |
import pytorch_lightning as pl
|
15 |
+
import spaces
|
16 |
|
17 |
def download_file(url, filename):
|
18 |
response = requests.get(url, stream=True)
|
|
|
57 |
|
58 |
config = OmegaConf.load("configs/model/ccsr_stage2.yaml")
|
59 |
model = instantiate_from_config(config)
|
60 |
+
ckpt = torch.load("weights/real-world_ccsr.ckpt", map_location="cpu", weights_only=True)
|
61 |
load_state_dict(model, ckpt, strict=True)
|
62 |
model.freeze()
|
63 |
|
|
|
66 |
model.to(device)
|
67 |
|
68 |
@torch.no_grad()
|
69 |
+
@spaces.GPU
|
70 |
def process(
|
71 |
control_img: Image.Image,
|
72 |
num_samples: int,
|