macguyver committed
Commit 86e1696 · verified · 1 Parent(s): 682e493

Create run_inference_train_x.py

Files changed (1)
  1. anydoor/run_inference_train_x.py +270 -0
anydoor/run_inference_train_x.py ADDED
@@ -0,0 +1,270 @@
+ import cv2
+ import einops
+ import numpy as np
+ import torch
+ import random
+ from pytorch_lightning import seed_everything
+ from cldm.model import create_model, load_state_dict
+ from cldm.ddim_hacked import DDIMSampler
+ from cldm.hack import disable_verbosity, enable_sliced_attention
+ from datasets.data_utils import *
+ cv2.setNumThreads(0)
+ cv2.ocl.setUseOpenCL(False)
+ import albumentations as A
+ from omegaconf import OmegaConf
+ from PIL import Image
+
+
+ save_memory = False
+ disable_verbosity()
+ if save_memory:
+     enable_sliced_attention()
+
+
+ config = OmegaConf.load('./configs/inference.yaml')
+ model_ckpt = config.pretrained_model
+ model_config = config.config_file
+
+ model = create_model(model_config).cpu()
+ model.load_state_dict(load_state_dict(model_ckpt, location='cuda'))
+ model = model.cuda()
+ ddim_sampler = DDIMSampler(model)
+
+
+
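+ # Augmentation helper: random horizontal flip plus brightness/contrast jitter.
+ # Note the call sites below are commented out, so this is currently unused.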
+ def aug_data_mask(image, mask):
+     transform = A.Compose([
+         A.HorizontalFlip(p=0.5),
+         A.RandomBrightnessContrast(p=0.5),
+     ])
+     transformed = transform(image=image.astype(np.uint8), mask=mask)
+     transformed_image = transformed["image"]
+     transformed_mask = transformed["mask"]
+     return transformed_image, transformed_mask
+
+
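+ # Build the model inputs from a (reference, target) pair: crop the reference
+ # object onto a white background, resize it to 224x224, paste its Sobel
+ # high-frequency map into the target region, and return the 512x512 collage
+ # together with the sizes needed to paste the generation back later.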
+ def process_pairs(ref_image, ref_mask, tar_image, tar_mask):
+     # ========= Reference ===========
+     # ref expand
+     ref_box_yyxx = get_bbox_from_mask(ref_mask)
+
+     # ref filter mask
+     ref_mask_3 = np.stack([ref_mask, ref_mask, ref_mask], -1)
+     masked_ref_image = ref_image * ref_mask_3 + np.ones_like(ref_image) * 255 * (1 - ref_mask_3)
+
+     y1, y2, x1, x2 = ref_box_yyxx
+     masked_ref_image = masked_ref_image[y1:y2, x1:x2, :]
+     ref_mask = ref_mask[y1:y2, x1:x2]
+
+
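+     # expand the reference crop; np.random.randint's upper bound is exclusive,
+     # so randint(12, 13) always returns 12 and the ratio is fixed at 1.2.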
+     ratio = np.random.randint(12, 13) / 10
+     masked_ref_image, ref_mask = expand_image_mask(masked_ref_image, ref_mask, ratio=ratio)
+     ref_mask_3 = np.stack([ref_mask, ref_mask, ref_mask], -1)
+
+     # to square and resize
+     masked_ref_image = pad_to_square(masked_ref_image, pad_value=255, random=False)
+     masked_ref_image = cv2.resize(masked_ref_image, (224, 224)).astype(np.uint8)
+
+     ref_mask_3 = pad_to_square(ref_mask_3 * 255, pad_value=0, random=False)
+     ref_mask_3 = cv2.resize(ref_mask_3, (224, 224)).astype(np.uint8)
+     ref_mask = ref_mask_3[:, :, 0]
+
+     # ref aug
+     masked_ref_image_aug = masked_ref_image  # aug_data(masked_ref_image)
+
+     # collage aug
+     masked_ref_image_compose, ref_mask_compose = masked_ref_image, ref_mask  # aug_data_mask(masked_ref_image, ref_mask)
+     masked_ref_image_aug = masked_ref_image_compose.copy()
+     ref_mask_3 = np.stack([ref_mask_compose, ref_mask_compose, ref_mask_compose], -1)
+     ref_image_collage = sobel(masked_ref_image_compose, ref_mask_compose / 255)
+
+     # ========= Target ===========
+     tar_box_yyxx = get_bbox_from_mask(tar_mask)
+     tar_box_yyxx = expand_bbox(tar_mask, tar_box_yyxx, ratio=[1.1, 1.2])
+
+     # crop
+     tar_box_yyxx_crop = expand_bbox(tar_image, tar_box_yyxx, ratio=[1.5, 3])  # 1.2 1.6
+     tar_box_yyxx_crop = box2squre(tar_image, tar_box_yyxx_crop)  # crop box
+     y1, y2, x1, x2 = tar_box_yyxx_crop
+
+     cropped_target_image = tar_image[y1:y2, x1:x2, :]
+     tar_box_yyxx = box_in_box(tar_box_yyxx, tar_box_yyxx_crop)
+     y1, y2, x1, x2 = tar_box_yyxx
+
+     # collage
+     ref_image_collage = cv2.resize(ref_image_collage, (x2-x1, y2-y1))
+     ref_mask_compose = cv2.resize(ref_mask_compose.astype(np.uint8), (x2-x1, y2-y1))
+     ref_mask_compose = (ref_mask_compose > 128).astype(np.uint8)
+
+     collage = cropped_target_image.copy()
+     collage[y1:y2, x1:x2, :] = ref_image_collage
+
+     collage_mask = cropped_target_image.copy() * 0.0
+     collage_mask[y1:y2, x1:x2, :] = 1.0
+
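+     # record sizes before and after square padding so crop_back can undo both
+     # the padding and the context crop after generation.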
+     # the size before pad
+     H1, W1 = collage.shape[0], collage.shape[1]
+     cropped_target_image = pad_to_square(cropped_target_image, pad_value=0, random=False).astype(np.uint8)
+     collage = pad_to_square(collage, pad_value=0, random=False).astype(np.uint8)
+     collage_mask = pad_to_square(collage_mask, pad_value=-1, random=False).astype(np.uint8)
+
+     # the size after pad
+     H2, W2 = collage.shape[0], collage.shape[1]
+     cropped_target_image = cv2.resize(cropped_target_image, (512, 512)).astype(np.float32)
+     collage = cv2.resize(collage, (512, 512)).astype(np.float32)
+     collage_mask = (cv2.resize(collage_mask, (512, 512)).astype(np.float32) > 0.5).astype(np.float32)
+
+     masked_ref_image_aug = masked_ref_image_aug / 255
+     cropped_target_image = cropped_target_image / 127.5 - 1.0
+     collage = collage / 127.5 - 1.0
+     collage = np.concatenate([collage, collage_mask[:, :, :1]], -1)
+
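+     # value ranges: ref is in [0, 1]; jpg and the first three hint channels are
+     # in [-1, 1]; the fourth hint channel is the binary collage mask.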
+     item = dict(ref=masked_ref_image_aug.copy(), jpg=cropped_target_image.copy(), hint=collage.copy(),
+                 extra_sizes=np.array([H1, W1, H2, W2]), tar_box_yyxx_crop=np.array(tar_box_yyxx_crop))
+     return item
+
+
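+ # Paste the 512x512 generation back into the full-resolution target image,
+ # undoing the square padding and the context crop, with a small pixel margin
+ # to hide seams.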
+ def crop_back(pred, tar_image, extra_sizes, tar_box_yyxx_crop):
+     H1, W1, H2, W2 = extra_sizes
+     y1, y2, x1, x2 = tar_box_yyxx_crop
+     pred = cv2.resize(pred, (W2, H2))
+     m = 5  # margin pixels
+
+     if W1 == H1:
+         tar_image[y1+m:y2-m, x1+m:x2-m, :] = pred[m:-m, m:-m]
+         return tar_image
+
+     if W1 < W2:
+         pad1 = int((W2 - W1) / 2)
+         pad2 = W2 - W1 - pad1
+         pred = pred[:, pad1:-pad2, :]
+     else:
+         pad1 = int((H2 - H1) / 2)
+         pad2 = H2 - H1 - pad1
+         pred = pred[pad1:-pad2, :, :]
+
+     gen_image = tar_image.copy()
+     gen_image[y1+m:y2-m, x1+m:x2-m, :] = pred[m:-m, m:-m]
+     return gen_image
+
+
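+ # End-to-end inference for one (reference, target) pair: preprocess with
+ # process_pairs, run classifier-free-guided DDIM sampling, decode the latents,
+ # and paste the result back into the target image.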
+ def inference_single_image(ref_image, ref_mask, tar_image, tar_mask, guidance_scale=5.0):
+     item = process_pairs(ref_image, ref_mask, tar_image, tar_mask)
+     ref = item['ref'] * 255
+     tar = item['jpg'] * 127.5 + 127.5
+     hint = item['hint'] * 127.5 + 127.5
+
+     hint_image = hint[:, :, :-1]
+     hint_mask = item['hint'][:, :, -1] * 255
+     hint_mask = np.stack([hint_mask, hint_mask, hint_mask], -1)
+     ref = cv2.resize(ref.astype(np.uint8), (512, 512))
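+     # the rescaled copies above are only useful for visualization; ref, tar and
+     # hint are re-read from item below before sampling.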
+
+     seed = random.randint(0, 65535)
+     if save_memory:
+         model.low_vram_shift(is_diffusing=False)
+
+     ref = item['ref']
+     tar = item['jpg']
+     hint = item['hint']
+     num_samples = 1
+
+     control = torch.from_numpy(hint.copy()).float().cuda()
+     control = torch.stack([control for _ in range(num_samples)], dim=0)
+     control = einops.rearrange(control, 'b h w c -> b c h w').clone()
+
+
+     clip_input = torch.from_numpy(ref.copy()).float().cuda()
+     clip_input = torch.stack([clip_input for _ in range(num_samples)], dim=0)
+     clip_input = einops.rearrange(clip_input, 'b h w c -> b c h w').clone()
+
+     guess_mode = False
+     H, W = 512, 512
+
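+     # Conditioning: the 4-channel collage (control) enters through c_concat and
+     # the 224x224 reference patch through c_crossattn; the unconditional branch
+     # replaces the reference with zeros for classifier-free guidance.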
+     cond = {"c_concat": [control], "c_crossattn": [model.get_learned_conditioning(clip_input)]}
+     un_cond = {"c_concat": None if guess_mode else [control], "c_crossattn": [model.get_learned_conditioning([torch.zeros((1, 3, 224, 224))] * num_samples)]}
+     shape = (4, H // 8, W // 8)
+
+     if save_memory:
+         model.low_vram_shift(is_diffusing=True)
+
+     # ====
+     num_samples = 1  # gr.Slider(label="Images", minimum=1, maximum=12, value=1, step=1)
+     image_resolution = 512  # gr.Slider(label="Image Resolution", minimum=256, maximum=768, value=512, step=64)
+     strength = 1  # gr.Slider(label="Control Strength", minimum=0.0, maximum=2.0, value=1.0, step=0.01)
+     guess_mode = False  # gr.Checkbox(label='Guess Mode', value=False)
+     # detect_resolution = 512  # gr.Slider(label="Segmentation Resolution", minimum=128, maximum=1024, value=512, step=1)
+     ddim_steps = 50  # gr.Slider(label="Steps", minimum=1, maximum=100, value=20, step=1)
+     scale = guidance_scale  # gr.Slider(label="Guidance Scale", minimum=0.1, maximum=30.0, value=9.0, step=0.1)
+     seed = -1  # gr.Slider(label="Seed", minimum=-1, maximum=2147483647, step=1, randomize=True)
+     eta = 0.0  # gr.Number(label="eta (DDIM)", value=0.0)
+
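+     # note: seed was reset to -1 above, so the random seed drawn earlier is
+     # never applied.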
+     model.control_scales = [strength * (0.825 ** float(12 - i)) for i in range(13)] if guess_mode else ([strength] * 13)  # Magic number. IDK why. Perhaps because 0.825**12<0.01 but 0.826**12>0.01
+     samples, intermediates = ddim_sampler.sample(ddim_steps, num_samples,
+                                                  shape, cond, verbose=False, eta=eta,
+                                                  unconditional_guidance_scale=scale,
+                                                  unconditional_conditioning=un_cond)
+     if save_memory:
+         model.low_vram_shift(is_diffusing=False)
+
+     x_samples = model.decode_first_stage(samples)
+     x_samples = (einops.rearrange(x_samples, 'b c h w -> b h w c') * 127.5 + 127.5).cpu().numpy()  # .clip(0, 255).astype(np.uint8)
+
+     result = x_samples[0][:, :, ::-1]
+     result = np.clip(result, 0, 255)
+
+     pred = x_samples[0]
+     pred = np.clip(pred, 0, 255)[1:, :, :]
+     sizes = item['extra_sizes']
+     tar_box_yyxx_crop = item['tar_box_yyxx_crop']
+     gen_image = crop_back(pred, tar_image, sizes, tar_box_yyxx_crop)
+     return gen_image
+
+
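+ # Batch driver: composite every cloth image in cloth_dir (the reference object)
+ # onto a single fixed person image (the target).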
+ if __name__ == '__main__':
+
+     import os
+     import cv2
+     import itertools
+
+     # 'inference_single_image' is defined above
+
+     save_dir = '/work/pink_girl/out'
+     cloth_dir = '/work/pink_girl/cloth/top'
+     cloth_mask_dir = '/work/pink_girl/cloth-mask'
+     image_dir = '/work/pink_girl/image'
+     image_parse_v3_dir = '/work/pink_girl/image-mask'
+
+     # Fixed reference image (the person/target) and its parse mask
+     fixed_ref_image_name = 'we_picked_8.jpg'
+     fixed_ref_image_path = os.path.join(image_dir, fixed_ref_image_name)
+     fixed_ref_mask_path = os.path.join(image_parse_v3_dir, 'top_long_collar.png')
+
+     # Ensure the save directory exists
+     if not os.path.exists(save_dir):
+         os.makedirs(save_dir)
+
+     # Create the list of cloth image names
+     cloth_image_names = os.listdir(cloth_dir)
+
+     for cloth_image_name in cloth_image_names:
+         # Construct paths for the cloth image and its mask
+         cloth_image_path = os.path.join(cloth_dir, cloth_image_name)
+         cloth_mask_path = os.path.join(cloth_mask_dir, cloth_image_name)
+
+         # Load the cloth image (the reference object) and its binary mask
+         cloth_image = cv2.imread(cloth_image_path)
+         cloth_image = cv2.cvtColor(cloth_image, cv2.COLOR_BGR2RGB)
+         cloth_mask = (cv2.imread(cloth_mask_path) > 128).astype(np.uint8)[:, :, 0]
+
+         # Load the fixed person image and mask (constant across iterations,
+         # so this could be hoisted out of the loop)
+         ref_image = cv2.imread(fixed_ref_image_path)
+         ref_image = cv2.cvtColor(ref_image, cv2.COLOR_BGR2RGB)
+         ref_mask = Image.open(fixed_ref_mask_path).convert('P')
+         ref_mask = np.array(ref_mask) == 5  # update this label id if the parse map uses a different class
+
+         # Generate: the cloth is the reference, the fixed person image the target
+         gen_image = inference_single_image(cloth_image, cloth_mask, ref_image, ref_mask)
+         gen_path = os.path.join(save_dir, '5_' + cloth_image_name)
+
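+         # cv2.hconcat requires all three images to share the same height and
+         # dtype; resize beforehand if the cloth and person images differ.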
+         # Concatenate and save the visualization
+         vis_image = cv2.hconcat([cloth_image, ref_image, gen_image])
+         cv2.imwrite(gen_path, cv2.cvtColor(vis_image, cv2.COLOR_RGB2BGR))