Adapter committed on
Commit aa01e5b · 1 Parent(s): 070c4da

Update demo/model.py

Files changed (1)
  1. demo/model.py +23 -17
demo/model.py CHANGED
@@ -17,6 +17,7 @@ import cv2
 import numpy as np
 import torch.nn.functional as F
 
+
 def preprocessing(image, device):
     # Resize
     scale = 640 / max(image.shape[:2])
@@ -39,6 +40,7 @@ def preprocessing(image, device):
 
     return image, raw_image
 
+
 def imshow_keypoints(img,
                      pose_result,
                      skeleton=None,
@@ -138,18 +140,22 @@ class Model_all:
                                   use_conv=False).to(device)
         self.model_sketch.load_state_dict(torch.load("models/t2iadapter_sketch_sd14v1.pth", map_location=device))
         self.model_edge = pidinet().to(device)
-        self.model_edge.load_state_dict({k.replace('module.', ''): v for k, v in torch.load('models/table5_pidinet.pth', map_location=device)['state_dict'].items()})
+        self.model_edge.load_state_dict({k.replace('module.', ''): v for k, v in
+                                         torch.load('models/table5_pidinet.pth', map_location=device)[
+                                             'state_dict'].items()})
 
         # segmentation part
         self.model_seger = seger().to(device)
         self.model_seger.eval()
         self.coler = Colorize(n=182)
-        self.model_seg = Adapter(cin=int(3*64), channels=[320, 640, 1280, 1280][:4], nums_rb=2, ksize=1, sk=True, use_conv=False).to(device)
+        self.model_seg = Adapter(cin=int(3 * 64), channels=[320, 640, 1280, 1280][:4], nums_rb=2, ksize=1, sk=True,
+                                 use_conv=False).to(device)
         self.model_seg.load_state_dict(torch.load("models/t2iadapter_seg_sd14v1.pth", map_location=device))
         self.depth_model = MiDaSInference(model_type='dpt_hybrid').to(device)
 
         # depth part
-        self.model_depth = Adapter(cin=3*64, channels=[320, 640, 1280, 1280][:4], nums_rb=2, ksize=1, sk=True, use_conv=False).to(device)
+        self.model_depth = Adapter(cin=3 * 64, channels=[320, 640, 1280, 1280][:4], nums_rb=2, ksize=1, sk=True,
+                                   use_conv=False).to(device)
         self.model_depth.load_state_dict(torch.load("models/t2iadapter_depth_sd14v1.pth", map_location=device))
 
         # keypose part
@@ -183,7 +189,7 @@ class Model_all:
                               [0, 255, 0], [255, 128, 0], [51, 153, 255], [51, 153, 255], [51, 153, 255],
                               [51, 153, 255],
                               [51, 153, 255], [51, 153, 255], [51, 153, 255]]
-
+
     def load_vae(self):
         vae_sd = torch.load(os.path.join('models', 'anything-v4.0.vae.pt'), map_location="cuda")
         sd = vae_sd["state_dict"]
@@ -254,7 +260,7 @@ class Model_all:
 
     @torch.no_grad()
     def process_depth(self, input_img, type_in, prompt, neg_prompt, pos_prompt, fix_sample, scale,
-                  con_strength, base_model):
+                      con_strength, base_model):
         if self.current_base != base_model:
             ckpt = os.path.join("models", base_model)
             pl_sd = torch.load(ckpt, map_location="cuda")
@@ -312,7 +318,8 @@ class Model_all:
         return [im_depth, x_samples_ddim]
 
     @torch.no_grad()
-    def process_depth_keypose(self, input_img_depth, input_img_keypose, type_in_depth, type_in_keypose, w_depth, w_keypose, prompt, neg_prompt, pos_prompt, fix_sample, scale, con_strength, base_model):
+    def process_depth_keypose(self, input_img_depth, input_img_keypose, type_in_depth, type_in_keypose, w_depth,
+                              w_keypose, prompt, neg_prompt, pos_prompt, fix_sample, scale, con_strength, base_model):
         if self.current_base != base_model:
             ckpt = os.path.join("models", base_model)
             pl_sd = torch.load(ckpt, map_location="cuda")
@@ -343,8 +350,7 @@ class Model_all:
 
         # get keypose
         if type_in_keypose == 'Keypose':
-            im_keypose_out = im_keypose.copy()
-            pose = img2tensor(im_keypose).unsqueeze(0) / 255.
+            im_keypose_out = im_keypose.copy()[:,:,::-1]
         elif type_in_keypose == 'Image':
             image = im_keypose.copy()
             im_keypose = img2tensor(im_keypose).unsqueeze(0) / 255.
@@ -378,7 +384,7 @@ class Model_all:
                 pose_link_color=self.pose_link_color,
                 radius=2,
                 thickness=2)
-            im_keypose_out = im_keypose_out.astype(np.uint8)[:,:,::-1]
+            im_keypose_out = im_keypose_out.astype(np.uint8)[:, :, ::-1]
 
         # extract condition features
         c = self.base_model.get_learned_conditioning([prompt + ', ' + pos_prompt])
@@ -387,7 +393,8 @@ class Model_all:
         pose = img2tensor(im_keypose_out, bgr2rgb=True, float32=True) / 255.
         pose = pose.unsqueeze(0)
         features_adapter_keypose = self.model_pose(pose.to(self.device))
-        features_adapter = [f_d*w_depth + f_k*w_keypose for f_d, f_k in zip(features_adapter_depth, features_adapter_keypose)]
+        features_adapter = [f_d * w_depth + f_k * w_keypose for f_d, f_k in
+                            zip(features_adapter_depth, features_adapter_keypose)]
         shape = [4, 64, 64]
 
         # sampling
@@ -416,7 +423,7 @@ class Model_all:
 
     @torch.no_grad()
     def process_seg(self, input_img, type_in, prompt, neg_prompt, pos_prompt, fix_sample, scale,
-                con_strength, base_model):
+                    con_strength, base_model):
         if self.current_base != base_model:
             ckpt = os.path.join("models", base_model)
             pl_sd = torch.load(ckpt, map_location="cuda")
@@ -450,10 +457,10 @@ class Model_all:
         labelmap = np.argmax(probs, axis=0)
 
         labelmap = self.coler(labelmap)
-        labelmap = np.transpose(labelmap, (1,2,0))
+        labelmap = np.transpose(labelmap, (1, 2, 0))
         labelmap = cv2.resize(labelmap, (512, 512))
-        labelmap = img2tensor(labelmap, bgr2rgb=False, float32=True)/255.
-        im_seg = tensor2img(labelmap)[:,:,::-1]
+        labelmap = img2tensor(labelmap, bgr2rgb=False, float32=True) / 255.
+        im_seg = tensor2img(labelmap)[:, :, ::-1]
         labelmap = labelmap.unsqueeze(0)
 
         # extract condition features
@@ -564,8 +571,7 @@ class Model_all:
         im = cv2.resize(input_img, (512, 512))
 
         if type_in == 'Keypose':
-            im_pose = im.copy()
-            im = img2tensor(im).unsqueeze(0) / 255.
+            im_pose = im.copy()[:,:,::-1]
         elif type_in == 'Image':
             image = im.copy()
             im = img2tensor(im).unsqueeze(0) / 255.
@@ -599,7 +605,7 @@ class Model_all:
                 pose_link_color=self.pose_link_color,
                 radius=2,
                 thickness=2)
-            im_pose = cv2.resize(im_pose, (512, 512))
+            # im_pose = cv2.resize(im_pose, (512, 512))
 
         # extract condition features
         c = self.base_model.get_learned_conditioning([prompt + ', ' + pos_prompt])
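A note on the `[:,:,::-1]` slice that this commit now applies directly when copying the keypose input (`im_keypose.copy()[:,:,::-1]` and `im.copy()[:,:,::-1]`): OpenCV arrays are laid out as BGR, and reversing the last axis swaps them to RGB (or back again). A minimal sketch of the idiom, not taken from the repository:

    import numpy as np

    # Toy 2x2 "image" with only the blue plane set, in OpenCV's BGR layout.
    bgr = np.zeros((2, 2, 3), dtype=np.uint8)
    bgr[..., 0] = 255           # blue channel
    rgb = bgr[:, :, ::-1]       # reverse the channel axis: BGR -> RGB (returns a view)
    print(rgb[0, 0])            # [  0   0 255] - the blue value now sits in the last slot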
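The `load_state_dict` call for `self.model_edge` (rewrapped above) strips a `module.` prefix from every key of the `table5_pidinet.pth` checkpoint. That prefix is what `torch.nn.DataParallel` adds when a wrapped model's weights are saved, so it has to be removed before loading into the bare network. A minimal sketch of the pattern with a stand-in model (not the actual pidinet):

    import torch
    import torch.nn as nn

    # Stand-in for pidinet(); only the key-renaming pattern matters here.
    model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.ReLU(), nn.Conv2d(8, 3, 3))

    # Simulate a checkpoint saved from a DataParallel-wrapped model: every
    # parameter name gains a 'module.' prefix.
    ckpt = {'state_dict': nn.DataParallel(model).state_dict()}

    # Strip the prefix so the keys match the unwrapped model again.
    state_dict = {k.replace('module.', ''): v for k, v in ckpt['state_dict'].items()}
    model.load_state_dict(state_dict)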
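The rewrapped `features_adapter` list comprehension in `process_depth_keypose` fuses the two conditioning signals by a per-scale weighted sum of the depth and keypose adapter features. A small self-contained sketch of that combination (the shapes are illustrative, not the adapter's real ones):

    import torch

    # Hypothetical multi-scale feature maps from two adapters.
    features_adapter_depth = [torch.randn(1, c, s, s) for c, s in [(320, 64), (640, 32)]]
    features_adapter_keypose = [torch.randn(1, c, s, s) for c, s in [(320, 64), (640, 32)]]

    w_depth, w_keypose = 1.0, 1.5
    # Weighted sum per scale, mirroring the line in the diff.
    features_adapter = [f_d * w_depth + f_k * w_keypose
                        for f_d, f_k in zip(features_adapter_depth, features_adapter_keypose)]
    print([tuple(f.shape) for f in features_adapter])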