Mehdi Cherti committed · Commit 8ab4de9 · Parent: 3dcdf92

add basic cross attention + global attention block
Files changed:
- score_sde/models/layers.py                 +1  -1
- score_sde/models/layerspp.py                +28 -0
- score_sde/models/ncsnpp_generator_adagn.py  +42 -4
- train_ddgan.py                              +39 -25
score_sde/models/layers.py
CHANGED

@@ -583,7 +583,7 @@ class Identity(nn.Module):
  def forward(self, x, *args, **kwargs):
    return x

-
+
class CrossAttention(nn.Module):
  def __init__(
      self,
score_sde/models/layerspp.py
CHANGED

@@ -123,6 +123,34 @@ class AttnBlockpp(nn.Module):
    else:
      return (x + h) / np.sqrt(2.)

+class AttnBlockppRaw(nn.Module):
+  """Channel-wise self-attention block. Modified from DDPM."""
+
+  def __init__(self, channels, skip_rescale=False, init_scale=0.):
+    super().__init__()
+    self.GroupNorm_0 = nn.GroupNorm(num_groups=min(channels // 4, 32), num_channels=channels,
+                                    eps=1e-6)
+    self.NIN_0 = NIN(channels, channels)
+    self.NIN_1 = NIN(channels, channels)
+    self.NIN_2 = NIN(channels, channels)
+    self.NIN_3 = NIN(channels, channels, init_scale=init_scale)
+    self.skip_rescale = skip_rescale
+
+  def forward(self, x):
+    B, C, H, W = x.shape
+    h = self.GroupNorm_0(x)
+    q = self.NIN_0(h)
+    k = self.NIN_1(h)
+    v = self.NIN_2(h)
+
+    w = torch.einsum('bchw,bcij->bhwij', q, k) * (int(C) ** (-0.5))
+    w = torch.reshape(w, (B, H, W, H * W))
+    w = F.softmax(w, dim=-1)
+    w = torch.reshape(w, (B, H, W, H, W))
+    h = torch.einsum('bhwij,bcij->bchw', w, v)
+    h = self.NIN_3(h)
+    return h
+

class Upsample(nn.Module):
  def __init__(self, in_ch=None, out_ch=None, with_conv=False, fir=False,
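For orientation, the einsum pipeline in AttnBlockppRaw computes full pairwise attention between all H*W spatial positions. Below is a minimal, self-contained sketch of that computation with random tensors standing in for the repo's GroupNorm/NIN projections; the shapes and names are illustrative and are not part of the commit.

import torch
import torch.nn.functional as F

B, C, H, W = 2, 64, 16, 16
# q, k, v would normally be 1x1 (NIN) projections of the group-normalized feature map.
q = torch.randn(B, C, H, W)
k = torch.randn(B, C, H, W)
v = torch.randn(B, C, H, W)

# Scaled dot products between every query position (h, w) and every key position (i, j).
w = torch.einsum('bchw,bcij->bhwij', q, k) * (C ** -0.5)          # (B, H, W, H, W)
w = F.softmax(w.reshape(B, H, W, H * W), dim=-1).reshape(B, H, W, H, W)

# Attention-weighted sum of values, back to the feature-map layout.
out = torch.einsum('bhwij,bcij->bchw', w, v)                      # (B, C, H, W)
print(out.shape)  # torch.Size([2, 64, 16, 16])

Unlike AttnBlockpp, which ends with a residual connection (and optional 1/sqrt(2) rescale), the new class returns the attention output directly, so a caller can combine it with other branches before adding the residual itself.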
score_sde/models/ncsnpp_generator_adagn.py
CHANGED

@@ -53,6 +53,36 @@ get_act = layers.get_act
default_initializer = layers.default_init
dense = dense_layer.dense

+class CrossAndGlobalAttnBlock(nn.Module):
+  """Channel-wise self-attention block."""
+  def __init__(self, channels, *, context_dim=None, dim_head=64, heads=8, norm_context=False, cosine_sim_attn=False):
+    super().__init__()
+    self.GroupNorm_0 = nn.GroupNorm(num_groups=32, num_channels=channels, eps=1e-6)
+    self.ca = layers.CrossAttention(
+      channels,
+      context_dim=context_dim,
+      dim_head=dim_head,
+      heads=heads,
+      norm_context=norm_context,
+      cosine_sim_attn=cosine_sim_attn,
+    )
+    self.attn = layerspp.AttnBlockppRaw(channels)
+
+  def forward(self, x, cond, mask=None):
+    B, C, H, W = x.shape
+    h = self.GroupNorm_0(x)
+    h = h.view(B, C, H*W)
+    h = h.permute(0,2,1)
+    h = h.contiguous()
+    h_new = self.ca(h, cond, mask=mask)
+    h_new = h_new.permute(0,2,1)
+    h_new = h_new.contiguous()
+    h_new = h_new.view(B, C, H, W)
+
+    h_global = self.attn(x)
+    h = h_new + h_global
+    return x + h
+
class PixelNorm(nn.Module):
  def __init__(self):
    super().__init__()

@@ -68,6 +98,7 @@ class NCSNpp(nn.Module):
  def __init__(self, config):
    super().__init__()
    self.config = config
+    self.cross_attention_block = config.cross_attention_block
    self.grad_checkpointing = config.grad_checkpointing if hasattr(config, "grad_checkpointing") else False
    self.not_use_tanh = config.not_use_tanh
    self.act = act = nn.SiLU()

@@ -124,7 +155,14 @@ class NCSNpp(nn.Module):
    modules[-1].weight.data = default_initializer()(modules[-1].weight.shape)
    nn.init.zeros_(modules[-1].bias)
    if config.cross_attention:
-
+
+      #block_name = config.cross_attention_block if hasattr(config, "cross_attention_block") else "basic"
+      block_name = config.cross_attention_block
+      if block_name == "basic":
+        AttnBlock = functools.partial(layers.CondAttnBlock, context_dim=config.cond_size)
+      elif block_name == "cross_and_global_attention":
+        AttnBlock = functools.partial(CrossAndGlobalAttnBlock, context_dim=config.cond_size)
+      print(AttnBlock)
    else:
      AttnBlock = functools.partial(layerspp.AttnBlockpp,
                                    init_scale=init_scale,

@@ -342,7 +380,7 @@ class NCSNpp(nn.Module):
        h = modules[m_idx](hs[-1], temb, zemb)
        m_idx += 1
        if h.shape[-1] in self.attn_resolutions:
-          if type(modules[m_idx])
+          if type(modules[m_idx]) in (layers.CondAttnBlock, CrossAndGlobalAttnBlock):
            h = modules[m_idx](h, cond, cond_mask)
          else:
            h = modules[m_idx](h)

@@ -377,7 +415,7 @@ class NCSNpp(nn.Module):
    h = hs[-1]
    h = modules[m_idx](h, temb, zemb)
    m_idx += 1
-    if type(modules[m_idx])
+    if type(modules[m_idx]) in (layers.CondAttnBlock, CrossAndGlobalAttnBlock):
      h = modules[m_idx](h, cond, cond_mask)
    else:
      h = modules[m_idx](h)

@@ -394,7 +432,7 @@ class NCSNpp(nn.Module):
        m_idx += 1

        if h.shape[-1] in self.attn_resolutions:
-          if type(modules[m_idx])
+          if type(modules[m_idx]) in (layers.CondAttnBlock, CrossAndGlobalAttnBlock):
            h = modules[m_idx](h, cond, cond_mask)
          else:
            h = modules[m_idx](h)
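The new CrossAndGlobalAttnBlock flattens the feature map into a token sequence, cross-attends to the text conditioning, runs the raw global self-attention on the original map in parallel, and adds both results back to the input. The following is a hypothetical usage sketch: it assumes the CrossAttention and AttnBlockppRaw interfaces shown above, and the shapes (e.g. a 768-dimensional, 77-token conditioning) are illustrative rather than taken from the commit.

import torch
from score_sde.models.ncsnpp_generator_adagn import CrossAndGlobalAttnBlock

B, C, H, W = 2, 128, 16, 16
cond_len, cond_dim = 77, 768                 # illustrative token count and text hidden size

block = CrossAndGlobalAttnBlock(C, context_dim=cond_dim)
x = torch.randn(B, C, H, W)                  # feature map at an attention resolution
cond = torch.randn(B, cond_len, cond_dim)    # encoded text tokens
cond_mask = torch.ones(B, cond_len).bool()   # token validity mask (assumed semantics)

out = block(x, cond, mask=cond_mask)         # cross-attn + global self-attn, residual added
print(out.shape)                             # torch.Size([2, 128, 16, 16])

Inside NCSNpp the same call pattern is used: blocks that take conditioning (layers.CondAttnBlock or CrossAndGlobalAttnBlock) receive (h, cond, cond_mask), while the unconditional AttnBlockpp keeps the single-argument call.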
train_ddgan.py
CHANGED

@@ -385,9 +385,10 @@ def train(rank, gpu, args):
        backbone_kwargs={"cond_size": text_encoder.output_size}
    )
    netD = netD.to(device)
-
-
-
+
+    if args.world_size > 1:
+        broadcast_params(netG.parameters())
+        broadcast_params(netD.parameters())

    if args.fsdp:
        from fairscale.nn.checkpoint.checkpoint_activations import checkpoint_wrapper

@@ -410,8 +411,9 @@ def train(rank, gpu, args):
    if args.fsdp:
        netD = nn.parallel.DistributedDataParallel(netD, device_ids=[gpu])
    else:
-
-
+        if args.world_size > 1:
+            netG = nn.parallel.DistributedDataParallel(netG, device_ids=[gpu])
+            netD = nn.parallel.DistributedDataParallel(netD, device_ids=[gpu], find_unused_parameters=args.discr_type=="projected_gan")
    #if args.discr_type == "projected_gan":
    #    netD._set_static_graph()

@@ -652,7 +654,8 @@ def train(rank, gpu, args):
                torchvision.utils.save_image(fake_sample, os.path.join(exp_path, 'sample_discrete_epoch_{}_iteration_{}.png'.format(epoch, iteration)), normalize=True)

            if args.save_content:
-
+                if args.world_size > 1:
+                    dist.barrier()
                if rank == 0:
                    print('Saving content.')
                    def to_cpu(d):

@@ -709,20 +712,26 @@ def init_processes(rank, size, fn, args):
    """ Initialize the distributed environment. """

    import os
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+
+    if size == 1:
+        args.rank = 0
+        args.world_size = 1
+        args.local_rank = 0
+        fn(rank, args.local_rank, args)
+    else:
+        args.rank = int(os.environ['SLURM_PROCID'])
+        args.world_size = int(os.getenv("SLURM_NTASKS"))
+        args.local_rank = int(os.environ['SLURM_LOCALID'])
+        print(args.rank, args.world_size)
+        args.master_address = os.getenv("SLURM_LAUNCH_NODE_IPADDR")
+        os.environ['MASTER_ADDR'] = args.master_address
+        os.environ['MASTER_PORT'] = "12345"
+        torch.cuda.set_device(args.local_rank)
+        gpu = args.local_rank
+        dist.init_process_group(backend='nccl', init_method='env://', rank=rank, world_size=args.world_size)
+        fn(rank, gpu, args)
+        dist.barrier()
+        cleanup()

def cleanup():
    dist.destroy_process_group()

@@ -737,6 +746,8 @@ if __name__ == '__main__':
    parser.add_argument('--mismatch_loss', action='store_true',default=False, help="use mismatch loss")
    parser.add_argument('--text_encoder', type=str, default="google/t5-v1_1-base")
    parser.add_argument('--cross_attention', action='store_true',default=False, help="use cross attention in generator")
+    parser.add_argument('--cross_attention_block', default="basic", help="cross attention block type")
+
    parser.add_argument('--fsdp', action='store_true',default=False, help='use FSDP')
    parser.add_argument('--grad_checkpointing', action='store_true',default=False, help='use grad checkpointing')

@@ -809,7 +820,7 @@ if __name__ == '__main__':
    parser.add_argument('--beta2', type=float, default=0.9,
                        help='beta2 for adam')
    parser.add_argument('--no_lr_decay',action='store_true', default=False)
-    parser.add_argument('--grad_penalty_cond', action='store_true',default=False, help="cond based grad
+    parser.add_argument('--grad_penalty_cond', action='store_true',default=False, help="cond based grad")

    parser.add_argument('--use_ema', action='store_true', default=False,
                        help='use EMA or not')

@@ -828,6 +839,7 @@ if __name__ == '__main__':
    parser.add_argument('--precision', type=str, default="fp32")

    ###ddp
+
    parser.add_argument('--num_proc_node', type=int, default=1,
                        help='The number of nodes in multi node env.')
    parser.add_argument('--num_process_per_node', type=int, default=1,

@@ -840,8 +852,10 @@ if __name__ == '__main__':
                        help='address for master')

    args = parser.parse_args()
-
-
-
-
+    if 'SLURM_NTASKS' in os.environ:
+        args.world_size = int(os.getenv("SLURM_NTASKS"))
+        args.rank = int(os.environ['SLURM_PROCID'])
+    else:
+        args.world_size = 1
+        args.rank = 0
    init_processes(args.rank, args.world_size, train, args)
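Taken together, the training script now derives rank and world size from SLURM variables when they are set and falls back to a single process otherwise, and it only broadcasts parameters or wraps the models in DistributedDataParallel when world_size > 1. Below is a standalone sketch of that resolution logic written purely for illustration; the helper name is hypothetical and not part of the commit.

import os

def resolve_distributed_env():
    """Mirror the SLURM-vs-single-process resolution used by train_ddgan.py (illustrative)."""
    if 'SLURM_NTASKS' in os.environ:
        rank = int(os.environ['SLURM_PROCID'])
        world_size = int(os.environ['SLURM_NTASKS'])
        local_rank = int(os.environ['SLURM_LOCALID'])
    else:
        rank, world_size, local_rank = 0, 1, 0
    return rank, world_size, local_rank

print(resolve_distributed_env())  # (0, 1, 0) when run outside a SLURM allocation

On the model side, the new generator option is selected with --cross_attention together with --cross_attention_block cross_and_global_attention; leaving the new flag at its default ("basic") keeps the previous layers.CondAttnBlock behaviour.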