Spaces:
Runtime error
Runtime error
Update glide_text2im/adv.py
Browse files- glide_text2im/adv.py +5 -5
glide_text2im/adv.py
CHANGED
@@ -4,7 +4,7 @@ import torch.nn.functional as F
|
|
4 |
import torch.optim as optim
|
5 |
from torch.nn.parallel.distributed import DistributedDataParallel as DDP
|
6 |
from .nn import mean_flat
|
7 |
-
from . import dist_util
|
8 |
import functools
|
9 |
|
10 |
class AdversarialLoss(nn.Module):
|
@@ -16,11 +16,11 @@ class AdversarialLoss(nn.Module):
|
|
16 |
self.gan_type = gan_type
|
17 |
self.gan_k = gan_k
|
18 |
|
19 |
-
model = NLayerDiscriminator().
|
20 |
self.discriminator = DDP(
|
21 |
model,
|
22 |
-
device_ids=[
|
23 |
-
output_device=
|
24 |
broadcast_buffers=False,
|
25 |
bucket_cap_mb=128,
|
26 |
find_unused_parameters=False,
|
@@ -41,7 +41,7 @@ class AdversarialLoss(nn.Module):
|
|
41 |
if (self.gan_type.find('WGAN') >= 0):
|
42 |
loss_d = (d_fake - d_real).mean()
|
43 |
if self.gan_type.find('GP') >= 0:
|
44 |
-
epsilon = torch.rand(real.size(0), 1, 1, 1).
|
45 |
epsilon = epsilon.expand(real.size())
|
46 |
hat = fake_detach.mul(1 - epsilon) + real.mul(epsilon)
|
47 |
hat.requires_grad = True
|
|
|
4 |
import torch.optim as optim
|
5 |
from torch.nn.parallel.distributed import DistributedDataParallel as DDP
|
6 |
from .nn import mean_flat
|
7 |
+
#from . import dist_util
|
8 |
import functools
|
9 |
|
10 |
class AdversarialLoss(nn.Module):
|
|
|
16 |
self.gan_type = gan_type
|
17 |
self.gan_k = gan_k
|
18 |
|
19 |
+
model = NLayerDiscriminator().cuda()
|
20 |
self.discriminator = DDP(
|
21 |
model,
|
22 |
+
device_ids=[torch.device('cuda')],
|
23 |
+
output_device=torch.device('cuda'),
|
24 |
broadcast_buffers=False,
|
25 |
bucket_cap_mb=128,
|
26 |
find_unused_parameters=False,
|
|
|
41 |
if (self.gan_type.find('WGAN') >= 0):
|
42 |
loss_d = (d_fake - d_real).mean()
|
43 |
if self.gan_type.find('GP') >= 0:
|
44 |
+
epsilon = torch.rand(real.size(0), 1, 1, 1).cuda()
|
45 |
epsilon = epsilon.expand(real.size())
|
46 |
hat = fake_detach.mul(1 - epsilon) + real.mul(epsilon)
|
47 |
hat.requires_grad = True
|