tfwang commited on
Commit
96849a6
1 Parent(s): e9bd7f5

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +8 -8
app.py CHANGED
@@ -4,7 +4,7 @@ Train a diffusion model on images.
4
  import gradio as gr
5
  import argparse
6
  from einops import rearrange
7
- from glide_text2im import dist_util, logger
8
  from torchvision.utils import make_grid
9
  from glide_text2im.script_util import (
10
  model_and_diffusion_defaults,
@@ -30,7 +30,7 @@ def run(image, mode, sample_c=1.3, num_samples=3, sample_step=100):
30
 
31
  args = parser.parse_args()
32
  args_up = parser_up.parse_args()
33
- dist_util.setup_dist()
34
 
35
  if mode == 'sketch':
36
  args.mode = 'coco-edge'
@@ -59,21 +59,21 @@ def run(image, mode, sample_c=1.3, num_samples=3, sample_step=100):
59
 
60
  if args.model_path:
61
  print('loading model')
62
- model_ckpt = dist_util.load_state_dict(args.model_path, map_location="cpu")
63
 
64
  model.load_state_dict(
65
  model_ckpt , strict=True )
66
 
67
  if args.sr_model_path:
68
  print('loading sr model')
69
- model_ckpt2 = dist_util.load_state_dict(args.sr_model_path, map_location="cpu")
70
 
71
  model_up.load_state_dict(
72
  model_ckpt2 , strict=True )
73
 
74
 
75
- model.to(dist_util.dev())
76
- model_up.to(dist_util.dev())
77
  model.eval()
78
  model_up.eval()
79
 
@@ -122,7 +122,7 @@ def run(image, mode, sample_c=1.3, num_samples=3, sample_step=100):
122
  prompt=model_kwargs,
123
  batch_size= args.num_samples,
124
  guidance_scale=args.sample_c,
125
- device=dist_util.dev(),
126
  prediction_respacing= str(sample_step),
127
  upsample_enabled= False,
128
  upsample_temp=0.997,
@@ -142,7 +142,7 @@ def run(image, mode, sample_c=1.3, num_samples=3, sample_step=100):
142
  prompt=model_kwargs,
143
  batch_size=args.num_samples,
144
  guidance_scale=1,
145
- device=dist_util.dev(),
146
  prediction_respacing= "fast27",
147
  upsample_enabled=True,
148
  upsample_temp=0.997,
 
4
  import gradio as gr
5
  import argparse
6
  from einops import rearrange
7
+ #from glide_text2im import dist_util, logger
8
  from torchvision.utils import make_grid
9
  from glide_text2im.script_util import (
10
  model_and_diffusion_defaults,
 
30
 
31
  args = parser.parse_args()
32
  args_up = parser_up.parse_args()
33
+ #dist_util.setup_dist()
34
 
35
  if mode == 'sketch':
36
  args.mode = 'coco-edge'
 
59
 
60
  if args.model_path:
61
  print('loading model')
62
+ model_ckpt = torch.load(args.model_path, map_location="cpu")
63
 
64
  model.load_state_dict(
65
  model_ckpt , strict=True )
66
 
67
  if args.sr_model_path:
68
  print('loading sr model')
69
+ model_ckpt2 = torch.load(args.sr_model_path, map_location="cpu")
70
 
71
  model_up.load_state_dict(
72
  model_ckpt2 , strict=True )
73
 
74
 
75
+ model.cuda()
76
+ model_up.cuda()
77
  model.eval()
78
  model_up.eval()
79
 
 
122
  prompt=model_kwargs,
123
  batch_size= args.num_samples,
124
  guidance_scale=args.sample_c,
125
+ device=torch.device('cuda'),
126
  prediction_respacing= str(sample_step),
127
  upsample_enabled= False,
128
  upsample_temp=0.997,
 
142
  prompt=model_kwargs,
143
  batch_size=args.num_samples,
144
  guidance_scale=1,
145
+ device=torch.device('cuda'),
146
  prediction_respacing= "fast27",
147
  upsample_enabled=True,
148
  upsample_temp=0.997,