TroglodyteDerivations
commited on
Create run.py
Browse files
run.py
ADDED
@@ -0,0 +1,147 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import cv2
|
2 |
+
|
3 |
+
#Import Neural Network Model
|
4 |
+
from gan import DataLoader, DeepModel, tensor2im
|
5 |
+
|
6 |
+
#OpenCv Transform:
|
7 |
+
from opencv_transform.mask_to_maskref import create_maskref
|
8 |
+
from opencv_transform.maskdet_to_maskfin import create_maskfin
|
9 |
+
from opencv_transform.dress_to_correct import create_correct
|
10 |
+
from opencv_transform.nude_to_watermark import create_watermark
|
11 |
+
|
12 |
+
"""
|
13 |
+
run.py
|
14 |
+
This script manage the entire transormation.
|
15 |
+
Transformation happens in 6 phases:
|
16 |
+
0: dress -> correct [opencv] dress_to_correct
|
17 |
+
1: correct -> mask: [GAN] correct_to_mask
|
18 |
+
2: mask -> maskref [opencv] mask_to_maskref
|
19 |
+
3: maskref -> maskdet [GAN] maskref_to_maskdet
|
20 |
+
4: maskdet -> maskfin [opencv] maskdet_to_maskfin
|
21 |
+
5: maskfin -> nude [GAN] maskfin_to_nude
|
22 |
+
6: nude -> watermark [opencv] nude_to_watermark
|
23 |
+
"""
|
24 |
+
|
25 |
+
phases = ["dress_to_correct", "correct_to_mask", "mask_to_maskref", "maskref_to_maskdet", "maskdet_to_maskfin", "maskfin_to_nude", "nude_to_watermark"]
|
26 |
+
|
27 |
+
class Options():
|
28 |
+
|
29 |
+
#Init options with default values
|
30 |
+
def __init__(self):
|
31 |
+
|
32 |
+
# experiment specifics
|
33 |
+
self.norm = 'batch' #instance normalization or batch normalization
|
34 |
+
self.use_dropout = False #use dropout for the generator
|
35 |
+
self.data_type = 32 #Supported data type i.e. 8, 16, 32 bit
|
36 |
+
|
37 |
+
# input/output sizes
|
38 |
+
self.batchSize = 1 #input batch size
|
39 |
+
self.input_nc = 3 # of input image channels
|
40 |
+
self.output_nc = 3 # of output image channels
|
41 |
+
|
42 |
+
# for setting inputs
|
43 |
+
self.serial_batches = True #if true, takes images in order to make batches, otherwise takes them randomly
|
44 |
+
self.nThreads = 1 ## threads for loading data (???)
|
45 |
+
self.max_dataset_size = 1 #Maximum number of samples allowed per dataset. If the dataset directory contains more than max_dataset_size, only a subset is loaded.
|
46 |
+
|
47 |
+
# for generator
|
48 |
+
self.netG = 'global' #selects model to use for netG
|
49 |
+
self.ngf = 64 ## of gen filters in first conv layer
|
50 |
+
self.n_downsample_global = 4 #number of downsampling layers in netG
|
51 |
+
self.n_blocks_global = 9 #number of residual blocks in the global generator network
|
52 |
+
self.n_blocks_local = 0 #number of residual blocks in the local enhancer network
|
53 |
+
self.n_local_enhancers = 0 #number of local enhancers to use
|
54 |
+
self.niter_fix_global = 0 #number of epochs that we only train the outmost local enhancer
|
55 |
+
|
56 |
+
#Phase specific options
|
57 |
+
self.checkpoints_dir = ""
|
58 |
+
self.dataroot = ""
|
59 |
+
|
60 |
+
#Changes options accordlying to actual phase
|
61 |
+
def updateOptions(self, phase):
|
62 |
+
|
63 |
+
if phase == "correct_to_mask":
|
64 |
+
self.checkpoints_dir = "checkpoints/cm.lib"
|
65 |
+
|
66 |
+
elif phase == "maskref_to_maskdet":
|
67 |
+
self.checkpoints_dir = "checkpoints/mm.lib"
|
68 |
+
|
69 |
+
elif phase == "maskfin_to_nude":
|
70 |
+
self.checkpoints_dir = "checkpoints/mn.lib"
|
71 |
+
|
72 |
+
# process(cv_img, mode)
|
73 |
+
# return:
|
74 |
+
# watermark image
|
75 |
+
def process(cv_img):
|
76 |
+
|
77 |
+
#InMemory cv2 images:
|
78 |
+
dress = cv_img
|
79 |
+
correct = None
|
80 |
+
mask = None
|
81 |
+
maskref = None
|
82 |
+
maskfin = None
|
83 |
+
maskdet = None
|
84 |
+
nude = None
|
85 |
+
watermark = None
|
86 |
+
|
87 |
+
for index, phase in enumerate(phases):
|
88 |
+
|
89 |
+
print("Executing phase: " + phase)
|
90 |
+
|
91 |
+
#GAN phases:
|
92 |
+
if (phase == "correct_to_mask") or (phase == "maskref_to_maskdet") or (phase == "maskfin_to_nude"):
|
93 |
+
|
94 |
+
#Load global option
|
95 |
+
opt = Options()
|
96 |
+
|
97 |
+
#Load custom phase options:
|
98 |
+
opt.updateOptions(phase)
|
99 |
+
|
100 |
+
#Load Data
|
101 |
+
if (phase == "correct_to_mask"):
|
102 |
+
data_loader = DataLoader(opt, correct)
|
103 |
+
elif (phase == "maskref_to_maskdet"):
|
104 |
+
data_loader = DataLoader(opt, maskref)
|
105 |
+
elif (phase == "maskfin_to_nude"):
|
106 |
+
data_loader = DataLoader(opt, maskfin)
|
107 |
+
|
108 |
+
dataset = data_loader.load_data()
|
109 |
+
|
110 |
+
#Create Model
|
111 |
+
model = DeepModel()
|
112 |
+
model.initialize(opt)
|
113 |
+
|
114 |
+
#Run for every image:
|
115 |
+
for i, data in enumerate(dataset):
|
116 |
+
|
117 |
+
generated = model.inference(data['label'], data['inst'])
|
118 |
+
|
119 |
+
im = tensor2im(generated.data[0])
|
120 |
+
|
121 |
+
#Save Data
|
122 |
+
if (phase == "correct_to_mask"):
|
123 |
+
mask = cv2.cvtColor(im, cv2.COLOR_RGB2BGR)
|
124 |
+
|
125 |
+
elif (phase == "maskref_to_maskdet"):
|
126 |
+
maskdet = cv2.cvtColor(im, cv2.COLOR_RGB2BGR)
|
127 |
+
|
128 |
+
elif (phase == "maskfin_to_nude"):
|
129 |
+
nude = cv2.cvtColor(im, cv2.COLOR_RGB2BGR)
|
130 |
+
|
131 |
+
#Correcting:
|
132 |
+
elif (phase == 'dress_to_correct'):
|
133 |
+
correct = create_correct(dress)
|
134 |
+
|
135 |
+
#mask_ref phase (opencv)
|
136 |
+
elif (phase == "mask_to_maskref"):
|
137 |
+
maskref = create_maskref(mask, correct)
|
138 |
+
|
139 |
+
#mask_fin phase (opencv)
|
140 |
+
elif (phase == "maskdet_to_maskfin"):
|
141 |
+
maskfin = create_maskfin(maskref, maskdet)
|
142 |
+
|
143 |
+
#nude_to_watermark phase (opencv)
|
144 |
+
elif (phase == "nude_to_watermark"):
|
145 |
+
watermark = create_watermark(nude)
|
146 |
+
|
147 |
+
return watermark
|