Update glide_text2im/train_util.py
glide_text2im/train_util.py  CHANGED  (+10 -10)
@@ -9,7 +9,7 @@ import torch.distributed as dist
 from torch.nn.parallel.distributed import DistributedDataParallel as DDP
 from torch.optim import AdamW
 from .glide_util import sample
-from . import
+from . import logger
 from .fp16_util import (
     make_master_params,
     master_params_to_model_params,
@@ -84,7 +84,7 @@ class TrainLoop:
         self.global_batch = self.batch_size * dist.get_world_size()

         if use_vgg:
-            self.vgg = VGG(conv_index='22').
+            self.vgg = VGG(conv_index='22').cuda()
             print('use perc')
         else:
             self.vgg = None
@@ -131,8 +131,8 @@ class TrainLoop:
             self.use_ddp = True
             self.ddp_model = DDP(
                 self.model,
-                device_ids=[
-                output_device=
+                device_ids=[torch.device('cuda')],
+                output_device=torch.device('cuda'),
                 broadcast_buffers=False,
                 bucket_cap_mb=128,
                 find_unused_parameters=False,
@@ -155,7 +155,7 @@ class TrainLoop:
             logger.log(f"loading model from checkpoint: {resume_checkpoint}...")
             self.model.load_state_dict(th.load(resume_checkpoint, map_location="cpu"),strict=False)

-        dist_util.sync_params(self.model.parameters())
+        #dist_util.sync_params(self.model.parameters())

     def _load_ema_parameters(self, rate):
         ema_params = copy.deepcopy(self.master_params)
@@ -165,7 +165,7 @@
         if ema_checkpoint:
             if dist.get_rank() == 0:
                 logger.log(f"loading EMA from checkpoint: {ema_checkpoint}...")
-                state_dict = th.load(ema_checkpoint, map_location=
+                state_dict = th.load(ema_checkpoint, map_location=torch.device('cuda'))
                 ema_params = self._state_dict_to_master_params(state_dict)

         #dist_util.sync_params(ema_params)
@@ -227,10 +227,10 @@ class TrainLoop:
     def forward_backward(self, batch, model_kwargs):
         zero_grad(self.model_params)
         for i in range(0, batch.shape[0], self.microbatch):
-            micro = batch[i : i + self.microbatch].
-            micro_cond={n:model_kwargs[n][i:i+self.microbatch].
+            micro = batch[i : i + self.microbatch].cuda()
+            micro_cond={n:model_kwargs[n][i:i+self.microbatch].cuda() for n in model_kwargs if n in ['ref', 'low_res']}
             last_batch = (i + self.microbatch) >= batch.shape[0]
-            t, weights = self.schedule_sampler.sample(micro.shape[0],
+            t, weights = self.schedule_sampler.sample(micro.shape[0], torch.device('cuda'))

             if self.step <100:
                 vgg_loss = None
@@ -295,7 +295,7 @@
                 prompt=model_kwargs,
                 batch_size=self.glide_options['batch_size']//2,
                 guidance_scale=guidance_scale,
-                device=
+                device=torch.device('cuda'),
                 prediction_respacing=self.glide_options['sample_respacing'],
                 upsample_enabled=self.glide_options['super_res'],
                 upsample_temp=0.997,
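
The hunk at line 131 re-points the DistributedDataParallel wrapper at the GPU. Below is a minimal, stand-alone sketch of that wrapping, assuming a single-process group so it can run on its own; the tiny Linear model, the local_rank variable, and the gloo/env:// initialization are illustrative placeholders, not part of the repository. The commit itself passes torch.device('cuda') for device_ids and output_device; an integer local rank, shown here, is the other form DDP accepts.

    # Sketch only: single-process group for illustration, not the repo's launcher.
    import os

    import torch
    import torch.distributed as dist
    from torch.nn.parallel.distributed import DistributedDataParallel as DDP


    def main():
        # Single-process "world" so the example runs stand-alone.
        os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
        os.environ.setdefault("MASTER_PORT", "29500")
        dist.init_process_group(backend="gloo", rank=0, world_size=1)

        model = torch.nn.Linear(8, 8)  # placeholder for the diffusion model

        if torch.cuda.is_available():
            local_rank = 0  # with torchrun this would come from LOCAL_RANK
            model = model.cuda(local_rank)
            ddp_model = DDP(
                model,
                device_ids=[local_rank],    # the diff passes [torch.device('cuda')]
                output_device=local_rank,   # the diff passes torch.device('cuda')
                broadcast_buffers=False,
                bucket_cap_mb=128,
                find_unused_parameters=False,
            )
        else:
            # CPU fallback: DDP without device_ids.
            ddp_model = DDP(model, broadcast_buffers=False)

        out = ddp_model(torch.randn(2, 8, device=next(ddp_model.parameters()).device))
        print(out.shape)
        dist.destroy_process_group()


    if __name__ == "__main__":
        main()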
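
The forward_backward hunk is the heart of the change: each microbatch and its tensor-valued conditioning ('ref', 'low_res') are moved to the GPU, and the diffusion timesteps are sampled on that same device. A minimal sketch of that slicing-and-moving pattern follows; the UniformSampler placeholder and the CPU fallback are assumptions for illustration and do not appear in the repository code, which hard-codes .cuda() and torch.device('cuda').

    # Sketch only: the commit pins everything to CUDA; this sketch keeps a
    # device variable so it also runs on CPU.
    import torch

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


    class UniformSampler:
        """Placeholder for the schedule sampler: uniform timesteps, unit weights."""

        def __init__(self, num_timesteps: int):
            self.num_timesteps = num_timesteps

        def sample(self, batch_size: int, device: torch.device):
            t = torch.randint(0, self.num_timesteps, (batch_size,), device=device)
            weights = torch.ones_like(t, dtype=torch.float32)
            return t, weights


    def forward_backward_sketch(batch, model_kwargs, microbatch, sampler):
        """Iterate over microbatches the way the patched loop does."""
        for i in range(0, batch.shape[0], microbatch):
            micro = batch[i : i + microbatch].to(device)
            # Only tensor-valued conditioning ('ref', 'low_res' in the diff) is sliced and moved.
            micro_cond = {
                n: model_kwargs[n][i : i + microbatch].to(device)
                for n in model_kwargs
                if n in ("ref", "low_res")
            }
            last_batch = (i + microbatch) >= batch.shape[0]
            t, weights = sampler.sample(micro.shape[0], device)
            # ...the real loop computes losses from micro, micro_cond, t, weights,
            # then backpropagates; here we just show the shapes.
            print(micro.shape, {k: v.shape for k, v in micro_cond.items()}, last_batch, t.shape, weights.shape)


    if __name__ == "__main__":
        sampler = UniformSampler(num_timesteps=1000)
        batch = torch.randn(8, 3, 64, 64)
        kwargs = {"ref": torch.randn(8, 3, 64, 64), "low_res": torch.randn(8, 3, 16, 16)}
        forward_backward_sketch(batch, kwargs, microbatch=4, sampler=sampler)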