TroglodyteDerivations commited on
Commit
e51ecd5
·
verified ·
1 Parent(s): 628a434

Create gan.py

Browse files
Files changed (1) hide show
  1. gan.py +240 -0
gan.py ADDED
@@ -0,0 +1,240 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from PIL import Image
2
+ import numpy as np
3
+ import cv2
4
+ import torchvision.transforms as transforms
5
+ import torch
6
+ import io
7
+ import os
8
+ import functools
9
+
10
+ class DataLoader():
11
+
12
+ def __init__(self, opt, cv_img):
13
+ super(DataLoader, self).__init__()
14
+
15
+ self.dataset = Dataset()
16
+ self.dataset.initialize(opt, cv_img)
17
+
18
+ self.dataloader = torch.utils.data.DataLoader(
19
+ self.dataset,
20
+ batch_size=opt.batchSize,
21
+ shuffle=not opt.serial_batches,
22
+ num_workers=int(opt.nThreads))
23
+
24
+ def load_data(self):
25
+ return self.dataloader
26
+
27
+ def __len__(self):
28
+ return 1
29
+
30
+ class Dataset(torch.utils.data.Dataset):
31
+ def __init__(self):
32
+ super(Dataset, self).__init__()
33
+
34
+ def initialize(self, opt, cv_img):
35
+ self.opt = opt
36
+ self.root = opt.dataroot
37
+
38
+ self.A = Image.fromarray(cv2.cvtColor(cv_img, cv2.COLOR_BGR2RGB))
39
+ self.dataset_size = 1
40
+
41
+ def __getitem__(self, index):
42
+
43
+ transform_A = get_transform(self.opt)
44
+ A_tensor = transform_A(self.A.convert('RGB'))
45
+
46
+ B_tensor = inst_tensor = feat_tensor = 0
47
+
48
+ input_dict = {'label': A_tensor, 'inst': inst_tensor, 'image': B_tensor,
49
+ 'feat': feat_tensor, 'path': ""}
50
+
51
+ return input_dict
52
+
53
+ def __len__(self):
54
+ return 1
55
+
56
+ class DeepModel(torch.nn.Module):
57
+
58
+ def initialize(self, opt):
59
+
60
+ torch.cuda.empty_cache()
61
+
62
+ self.opt = opt
63
+
64
+ self.gpu_ids = [] #FIX CPU
65
+
66
+ self.netG = self.__define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.netG,
67
+ opt.n_downsample_global, opt.n_blocks_global, opt.n_local_enhancers,
68
+ opt.n_blocks_local, opt.norm, self.gpu_ids)
69
+
70
+ # load networks
71
+ self.__load_network(self.netG)
72
+
73
+
74
+
75
+ def inference(self, label, inst):
76
+
77
+ # Encode Inputs
78
+ input_label, inst_map, _, _ = self.__encode_input(label, inst, infer=True)
79
+
80
+ # Fake Generation
81
+ input_concat = input_label
82
+
83
+ with torch.no_grad():
84
+ fake_image = self.netG.forward(input_concat)
85
+
86
+ return fake_image
87
+
88
+ # helper loading function that can be used by subclasses
89
+ def __load_network(self, network):
90
+
91
+ save_path = os.path.join(self.opt.checkpoints_dir)
92
+
93
+ network.load_state_dict(torch.load(save_path))
94
+
95
+ def __encode_input(self, label_map, inst_map=None, real_image=None, feat_map=None, infer=False):
96
+ if (len(self.gpu_ids) > 0):
97
+ input_label = label_map.data.cuda() #GPU
98
+ else:
99
+ input_label = label_map.data #CPU
100
+
101
+ return input_label, inst_map, real_image, feat_map
102
+
103
+ def __weights_init(self, m):
104
+ classname = m.__class__.__name__
105
+ if classname.find('Conv') != -1:
106
+ m.weight.data.normal_(0.0, 0.02)
107
+ elif classname.find('BatchNorm2d') != -1:
108
+ m.weight.data.normal_(1.0, 0.02)
109
+ m.bias.data.fill_(0)
110
+
111
+ def __define_G(self, input_nc, output_nc, ngf, netG, n_downsample_global=3, n_blocks_global=9, n_local_enhancers=1,
112
+ n_blocks_local=3, norm='instance', gpu_ids=[]):
113
+ norm_layer = self.__get_norm_layer(norm_type=norm)
114
+ netG = GlobalGenerator(input_nc, output_nc, ngf, n_downsample_global, n_blocks_global, norm_layer)
115
+
116
+ if len(gpu_ids) > 0:
117
+ netG.cuda(gpu_ids[0])
118
+ netG.apply(self.__weights_init)
119
+ return netG
120
+
121
+ def __get_norm_layer(self, norm_type='instance'):
122
+ norm_layer = functools.partial(torch.nn.InstanceNorm2d, affine=False)
123
+ return norm_layer
124
+
125
+ ##############################################################################
126
+ # Generator
127
+ ##############################################################################
128
+ class GlobalGenerator(torch.nn.Module):
129
+ def __init__(self, input_nc, output_nc, ngf=64, n_downsampling=3, n_blocks=9, norm_layer=torch.nn.BatchNorm2d,
130
+ padding_type='reflect'):
131
+ assert(n_blocks >= 0)
132
+ super(GlobalGenerator, self).__init__()
133
+ activation = torch.nn.ReLU(True)
134
+
135
+ model = [torch.nn.ReflectionPad2d(3), torch.nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0), norm_layer(ngf), activation]
136
+ ### downsample
137
+ for i in range(n_downsampling):
138
+ mult = 2**i
139
+ model += [torch.nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3, stride=2, padding=1),
140
+ norm_layer(ngf * mult * 2), activation]
141
+
142
+ ### resnet blocks
143
+ mult = 2**n_downsampling
144
+ for i in range(n_blocks):
145
+ model += [ResnetBlock(ngf * mult, padding_type=padding_type, activation=activation, norm_layer=norm_layer)]
146
+
147
+ ### upsample
148
+ for i in range(n_downsampling):
149
+ mult = 2**(n_downsampling - i)
150
+ model += [torch.nn.ConvTranspose2d(ngf * mult, int(ngf * mult / 2), kernel_size=3, stride=2, padding=1, output_padding=1),
151
+ norm_layer(int(ngf * mult / 2)), activation]
152
+ model += [torch.nn.ReflectionPad2d(3), torch.nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0), torch.nn.Tanh()]
153
+ self.model = torch.nn.Sequential(*model)
154
+
155
+ def forward(self, input):
156
+ return self.model(input)
157
+
158
+ # Define a resnet block
159
+ class ResnetBlock(torch.nn.Module):
160
+ def __init__(self, dim, padding_type, norm_layer, activation=torch.nn.ReLU(True), use_dropout=False):
161
+ super(ResnetBlock, self).__init__()
162
+ self.conv_block = self.__build_conv_block(dim, padding_type, norm_layer, activation, use_dropout)
163
+
164
+ def __build_conv_block(self, dim, padding_type, norm_layer, activation, use_dropout):
165
+ conv_block = []
166
+ p = 0
167
+ if padding_type == 'reflect':
168
+ conv_block += [torch.nn.ReflectionPad2d(1)]
169
+ elif padding_type == 'replicate':
170
+ conv_block += [torch.nn.ReplicationPad2d(1)]
171
+ elif padding_type == 'zero':
172
+ p = 1
173
+ else:
174
+ raise NotImplementedError('padding [%s] is not implemented' % padding_type)
175
+
176
+ conv_block += [torch.nn.Conv2d(dim, dim, kernel_size=3, padding=p),
177
+ norm_layer(dim),
178
+ activation]
179
+ if use_dropout:
180
+ conv_block += [torch.nn.Dropout(0.5)]
181
+
182
+ p = 0
183
+ if padding_type == 'reflect':
184
+ conv_block += [torch.nn.ReflectionPad2d(1)]
185
+ elif padding_type == 'replicate':
186
+ conv_block += [torch.nn.ReplicationPad2d(1)]
187
+ elif padding_type == 'zero':
188
+ p = 1
189
+ else:
190
+ raise NotImplementedError('padding [%s] is not implemented' % padding_type)
191
+ conv_block += [torch.nn.Conv2d(dim, dim, kernel_size=3, padding=p),
192
+ norm_layer(dim)]
193
+
194
+ return torch.nn.Sequential(*conv_block)
195
+
196
+ def forward(self, x):
197
+ out = x + self.conv_block(x)
198
+ return out
199
+
200
+ # Data utils:
201
+ def get_transform(opt, method=Image.BICUBIC, normalize=True):
202
+ transform_list = []
203
+
204
+ base = float(2 ** opt.n_downsample_global)
205
+ if opt.netG == 'local':
206
+ base *= (2 ** opt.n_local_enhancers)
207
+ transform_list.append(transforms.Lambda(lambda img: __make_power_2(img, base, method)))
208
+
209
+ transform_list += [transforms.ToTensor()]
210
+
211
+ if normalize:
212
+ transform_list += [transforms.Normalize((0.5, 0.5, 0.5),
213
+ (0.5, 0.5, 0.5))]
214
+ return transforms.Compose(transform_list)
215
+
216
+ def __make_power_2(img, base, method=Image.BICUBIC):
217
+ ow, oh = img.size
218
+ h = int(round(oh / base) * base)
219
+ w = int(round(ow / base) * base)
220
+ if (h == oh) and (w == ow):
221
+ return img
222
+ return img.resize((w, h), method)
223
+
224
+ # Converts a Tensor into a Numpy array
225
+ # |imtype|: the desired type of the converted numpy array
226
+ def tensor2im(image_tensor, imtype=np.uint8, normalize=True):
227
+ if isinstance(image_tensor, list):
228
+ image_numpy = []
229
+ for i in range(len(image_tensor)):
230
+ image_numpy.append(tensor2im(image_tensor[i], imtype, normalize))
231
+ return image_numpy
232
+ image_numpy = image_tensor.cpu().float().numpy()
233
+ if normalize:
234
+ image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + 1) / 2.0 * 255.0
235
+ else:
236
+ image_numpy = np.transpose(image_numpy, (1, 2, 0)) * 255.0
237
+ image_numpy = np.clip(image_numpy, 0, 255)
238
+ if image_numpy.shape[2] == 1 or image_numpy.shape[2] > 3:
239
+ image_numpy = image_numpy[:,:,0]
240
+ return image_numpy.astype(imtype)