Update app.py
app.py CHANGED
@@ -4,7 +4,7 @@ Train a diffusion model on images.
 import gradio as gr
 import argparse
 from einops import rearrange
-from glide_text2im import dist_util, logger
+#from glide_text2im import dist_util, logger
 from torchvision.utils import make_grid
 from glide_text2im.script_util import (
     model_and_diffusion_defaults,
@@ -30,7 +30,7 @@ def run(image, mode, sample_c=1.3, num_samples=3, sample_step=100):
 
     args = parser.parse_args()
     args_up = parser_up.parse_args()
-    dist_util.setup_dist()
+    #dist_util.setup_dist()
 
     if mode == 'sketch':
         args.mode = 'coco-edge'
@@ -59,21 +59,21 @@ def run(image, mode, sample_c=1.3, num_samples=3, sample_step=100):
 
     if args.model_path:
         print('loading model')
-        model_ckpt =
+        model_ckpt = torch.load(args.model_path, map_location="cpu")
 
         model.load_state_dict(
             model_ckpt , strict=True )
 
     if args.sr_model_path:
         print('loading sr model')
-        model_ckpt2 =
+        model_ckpt2 = torch.load(args.sr_model_path, map_location="cpu")
 
         model_up.load_state_dict(
             model_ckpt2 , strict=True )
 
 
-    model.
-    model_up.
+    model.cuda()
+    model_up.cuda()
     model.eval()
     model_up.eval()
 
@@ -122,7 +122,7 @@ def run(image, mode, sample_c=1.3, num_samples=3, sample_step=100):
         prompt=model_kwargs,
         batch_size= args.num_samples,
         guidance_scale=args.sample_c,
-        device=
+        device=torch.device('cuda'),
         prediction_respacing= str(sample_step),
         upsample_enabled= False,
         upsample_temp=0.997,
@@ -142,7 +142,7 @@ def run(image, mode, sample_c=1.3, num_samples=3, sample_step=100):
         prompt=model_kwargs,
         batch_size=args.num_samples,
         guidance_scale=1,
-        device=
+        device=torch.device('cuda'),
         prediction_respacing= "fast27",
         upsample_enabled=True,
         upsample_temp=0.997,
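For context, a minimal standalone sketch of the loading and device pattern this commit switches to: plain torch.load of the checkpoint on CPU, then moving the model to the GPU. It assumes torch is already imported in app.py; the helper name load_checkpoint and its checkpoint_path argument are illustrative and not part of the repository. The guarded device fallback is an optional variation on the hard-coded model.cuda() / torch.device('cuda') in the diff, which require a CUDA-enabled runtime.

```python
import torch

def load_checkpoint(model: torch.nn.Module, checkpoint_path: str) -> torch.nn.Module:
    # Deserialize the weights on CPU first, as the commit does, so loading
    # does not itself require a GPU.
    state_dict = torch.load(checkpoint_path, map_location="cpu")
    model.load_state_dict(state_dict, strict=True)

    # The commit hard-codes CUDA; falling back to CPU here is an illustrative
    # choice (not from the diff) that keeps the same path usable without a GPU.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    model.eval()
    return model
```

Usage in the app would then look like model = load_checkpoint(model, args.model_path) and model_up = load_checkpoint(model_up, args.sr_model_path), matching the added lines above.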