tfwang committed
Commit 9d92961
Parent: f4a50a2

Update glide_text2im/train_util.py

Files changed (1)
  1. glide_text2im/train_util.py +10 -10
glide_text2im/train_util.py CHANGED
@@ -9,7 +9,7 @@ import torch.distributed as dist
 from torch.nn.parallel.distributed import DistributedDataParallel as DDP
 from torch.optim import AdamW
 from .glide_util import sample
-from . import dist_util, logger
+from . import logger
 from .fp16_util import (
     make_master_params,
     master_params_to_model_params,
@@ -84,7 +84,7 @@ class TrainLoop:
         self.global_batch = self.batch_size * dist.get_world_size()
 
         if use_vgg:
-            self.vgg = VGG(conv_index='22').to(dist_util.dev())
+            self.vgg = VGG(conv_index='22').cuda()
             print('use perc')
         else:
             self.vgg = None
@@ -131,8 +131,8 @@ class TrainLoop:
             self.use_ddp = True
             self.ddp_model = DDP(
                 self.model,
-                device_ids=[dist_util.dev()],
-                output_device=dist_util.dev(),
+                device_ids=[torch.device('cuda')],
+                output_device=torch.device('cuda'),
                 broadcast_buffers=False,
                 bucket_cap_mb=128,
                 find_unused_parameters=False,
@@ -155,7 +155,7 @@
                 logger.log(f"loading model from checkpoint: {resume_checkpoint}...")
                 self.model.load_state_dict(th.load(resume_checkpoint, map_location="cpu"),strict=False)
 
-        dist_util.sync_params(self.model.parameters())
+        #dist_util.sync_params(self.model.parameters())
 
     def _load_ema_parameters(self, rate):
         ema_params = copy.deepcopy(self.master_params)
@@ -165,7 +165,7 @@
         if ema_checkpoint:
             if dist.get_rank() == 0:
                 logger.log(f"loading EMA from checkpoint: {ema_checkpoint}...")
-                state_dict = th.load(ema_checkpoint, map_location=dist_util.dev())
+                state_dict = th.load(ema_checkpoint, map_location=torch.device('cuda'))
                 ema_params = self._state_dict_to_master_params(state_dict)
 
         #dist_util.sync_params(ema_params)
@@ -227,10 +227,10 @@
     def forward_backward(self, batch, model_kwargs):
         zero_grad(self.model_params)
         for i in range(0, batch.shape[0], self.microbatch):
-            micro = batch[i : i + self.microbatch].to(dist_util.dev())
-            micro_cond={n:model_kwargs[n][i:i+self.microbatch].to(dist_util.dev()) for n in model_kwargs if n in ['ref', 'low_res']}
+            micro = batch[i : i + self.microbatch].cuda()
+            micro_cond={n:model_kwargs[n][i:i+self.microbatch].cuda() for n in model_kwargs if n in ['ref', 'low_res']}
             last_batch = (i + self.microbatch) >= batch.shape[0]
-            t, weights = self.schedule_sampler.sample(micro.shape[0], dist_util.dev())
+            t, weights = self.schedule_sampler.sample(micro.shape[0], torch.device('cuda'))
 
             if self.step <100:
                 vgg_loss = None
@@ -295,7 +295,7 @@
                 prompt=model_kwargs,
                 batch_size=self.glide_options['batch_size']//2,
                 guidance_scale=guidance_scale,
-                device=dist_util.dev(),
+                device=torch.device('cuda'),
                 prediction_respacing=self.glide_options['sample_respacing'],
                 upsample_enabled=self.glide_options['super_res'],
                 upsample_temp=0.997,
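
In short, the commit drops the `dist_util` dependency and places tensors and the model on the default CUDA device via `torch.device('cuda')` / `.cuda()`. Below is a minimal sketch, not part of the commit, of a stand-in for the removed `dist_util.dev()` helper, assuming that helper's job was to return the GPU assigned to the current process (falling back to CPU); the hardcoded `torch.device('cuda')` above behaves the same way only when each training process already has its own default CUDA device, for example via `CUDA_VISIBLE_DEVICES` or `torch.cuda.set_device`.

import os

import torch


def dev() -> torch.device:
    # Hypothetical stand-in for the removed dist_util.dev() helper (an
    # assumption, not code from this repository). Under torchrun, each
    # process reads its LOCAL_RANK and uses the matching GPU; without a
    # visible GPU it falls back to the CPU.
    if torch.cuda.is_available():
        local_rank = int(os.environ.get("LOCAL_RANK", "0"))
        return torch.device(f"cuda:{local_rank}")
    return torch.device("cpu")

With such a helper, the DDP construction above could read `device_ids=[dev()]` and `output_device=dev()`; with one visible GPU per process this resolves to the same device as the hardcoded `torch.device('cuda')`.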