Update demo/model.py

demo/model.py (+23 -17) CHANGED
@@ -17,6 +17,7 @@ import cv2
 import numpy as np
 import torch.nn.functional as F
 
+
 def preprocessing(image, device):
     # Resize
     scale = 640 / max(image.shape[:2])
@@ -39,6 +40,7 @@ def preprocessing(image, device):
 
     return image, raw_image
 
+
 def imshow_keypoints(img,
                      pose_result,
                      skeleton=None,
@@ -138,18 +140,22 @@ class Model_all:
                                     use_conv=False).to(device)
         self.model_sketch.load_state_dict(torch.load("models/t2iadapter_sketch_sd14v1.pth", map_location=device))
         self.model_edge = pidinet().to(device)
-        self.model_edge.load_state_dict({k.replace('module.', ''): v for k, v in torch.load('models/table5_pidinet.pth', map_location=device)['state_dict'].items()})
+        self.model_edge.load_state_dict({k.replace('module.', ''): v for k, v in
+                                         torch.load('models/table5_pidinet.pth', map_location=device)[
+                                             'state_dict'].items()})
 
         # segmentation part
         self.model_seger = seger().to(device)
         self.model_seger.eval()
         self.coler = Colorize(n=182)
-        self.model_seg = Adapter(cin=int(3*64), channels=[320, 640, 1280, 1280][:4], nums_rb=2, ksize=1, sk=True, use_conv=False).to(device)
+        self.model_seg = Adapter(cin=int(3 * 64), channels=[320, 640, 1280, 1280][:4], nums_rb=2, ksize=1, sk=True,
+                                 use_conv=False).to(device)
         self.model_seg.load_state_dict(torch.load("models/t2iadapter_seg_sd14v1.pth", map_location=device))
         self.depth_model = MiDaSInference(model_type='dpt_hybrid').to(device)
 
         # depth part
-        self.model_depth = Adapter(cin=3*64, channels=[320, 640, 1280, 1280][:4], nums_rb=2, ksize=1, sk=True, use_conv=False).to(device)
+        self.model_depth = Adapter(cin=3 * 64, channels=[320, 640, 1280, 1280][:4], nums_rb=2, ksize=1, sk=True,
+                                   use_conv=False).to(device)
         self.model_depth.load_state_dict(torch.load("models/t2iadapter_depth_sd14v1.pth", map_location=device))
 
         # keypose part
@@ -183,7 +189,7 @@ class Model_all:
             [0, 255, 0], [255, 128, 0], [51, 153, 255], [51, 153, 255], [51, 153, 255],
             [51, 153, 255],
             [51, 153, 255], [51, 153, 255], [51, 153, 255]]
-
+
     def load_vae(self):
         vae_sd = torch.load(os.path.join('models', 'anything-v4.0.vae.pt'), map_location="cuda")
         sd = vae_sd["state_dict"]
@@ -254,7 +260,7 @@ class Model_all:
 
     @torch.no_grad()
     def process_depth(self, input_img, type_in, prompt, neg_prompt, pos_prompt, fix_sample, scale,
-
+                      con_strength, base_model):
         if self.current_base != base_model:
             ckpt = os.path.join("models", base_model)
             pl_sd = torch.load(ckpt, map_location="cuda")
@@ -312,7 +318,8 @@ class Model_all:
         return [im_depth, x_samples_ddim]
 
     @torch.no_grad()
-    def process_depth_keypose(self, input_img_depth, input_img_keypose, type_in_depth, type_in_keypose, w_depth, w_keypose, prompt, neg_prompt, pos_prompt, fix_sample, scale, con_strength, base_model):
+    def process_depth_keypose(self, input_img_depth, input_img_keypose, type_in_depth, type_in_keypose, w_depth,
+                              w_keypose, prompt, neg_prompt, pos_prompt, fix_sample, scale, con_strength, base_model):
         if self.current_base != base_model:
             ckpt = os.path.join("models", base_model)
             pl_sd = torch.load(ckpt, map_location="cuda")
@@ -343,8 +350,7 @@ class Model_all:
 
         # get keypose
         if type_in_keypose == 'Keypose':
-            im_keypose_out = im_keypose.copy()
-            pose = img2tensor(im_keypose).unsqueeze(0) / 255.
+            im_keypose_out = im_keypose.copy()[:,:,::-1]
         elif type_in_keypose == 'Image':
             image = im_keypose.copy()
             im_keypose = img2tensor(im_keypose).unsqueeze(0) / 255.
@@ -378,7 +384,7 @@ class Model_all:
                              pose_link_color=self.pose_link_color,
                              radius=2,
                              thickness=2)
-            im_keypose_out = im_keypose_out.astype(np.uint8)[
+            im_keypose_out = im_keypose_out.astype(np.uint8)[:, :, ::-1]
 
         # extract condition features
         c = self.base_model.get_learned_conditioning([prompt + ', ' + pos_prompt])
@@ -387,7 +393,8 @@ class Model_all:
             pose = img2tensor(im_keypose_out, bgr2rgb=True, float32=True) / 255.
             pose = pose.unsqueeze(0)
             features_adapter_keypose = self.model_pose(pose.to(self.device))
-        features_adapter = [f_d*w_depth + f_k*w_keypose for f_d, f_k in zip(features_adapter_depth, features_adapter_keypose)]
+        features_adapter = [f_d * w_depth + f_k * w_keypose for f_d, f_k in
+                            zip(features_adapter_depth, features_adapter_keypose)]
         shape = [4, 64, 64]
 
         # sampling
@@ -416,7 +423,7 @@ class Model_all:
 
     @torch.no_grad()
     def process_seg(self, input_img, type_in, prompt, neg_prompt, pos_prompt, fix_sample, scale,
-
+                    con_strength, base_model):
         if self.current_base != base_model:
             ckpt = os.path.join("models", base_model)
             pl_sd = torch.load(ckpt, map_location="cuda")
@@ -450,10 +457,10 @@ class Model_all:
         labelmap = np.argmax(probs, axis=0)
 
         labelmap = self.coler(labelmap)
-        labelmap = np.transpose(labelmap, (1,2,0))
+        labelmap = np.transpose(labelmap, (1, 2, 0))
         labelmap = cv2.resize(labelmap, (512, 512))
-        labelmap = img2tensor(labelmap, bgr2rgb=False, float32=True)/255.
-        im_seg = tensor2img(labelmap)[
+        labelmap = img2tensor(labelmap, bgr2rgb=False, float32=True) / 255.
+        im_seg = tensor2img(labelmap)[:, :, ::-1]
         labelmap = labelmap.unsqueeze(0)
 
         # extract condition features
@@ -564,8 +571,7 @@ class Model_all:
         im = cv2.resize(input_img, (512, 512))
 
         if type_in == 'Keypose':
-            im_pose = im.copy()
-            im = img2tensor(im).unsqueeze(0) / 255.
+            im_pose = im.copy()[:,:,::-1]
         elif type_in == 'Image':
             image = im.copy()
             im = img2tensor(im).unsqueeze(0) / 255.
@@ -599,7 +605,7 @@ class Model_all:
                              pose_link_color=self.pose_link_color,
                              radius=2,
                              thickness=2)
-        im_pose = cv2.resize(im_pose, (512, 512))
+        # im_pose = cv2.resize(im_pose, (512, 512))
 
         # extract condition features
         c = self.base_model.get_learned_conditioning([prompt + ', ' + pos_prompt])
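
A note on the re-wrapped pidinet load: models/table5_pidinet.pth is evidently a checkpoint saved from a DataParallel-wrapped model, so every key in its state_dict carries a "module." prefix that must be stripped before calling load_state_dict on the bare network. A minimal sketch of that idiom, using a toy nn.Sequential as a stand-in for pidinet (which is defined elsewhere in the repo):

import torch
import torch.nn as nn

# Toy stand-in for pidinet (the real definition lives elsewhere in the repo).
net = nn.Sequential(nn.Conv2d(3, 8, 3, padding=1), nn.ReLU())

# A checkpoint written from nn.DataParallel prefixes every key with "module.".
ckpt = {'state_dict': nn.DataParallel(net).state_dict()}  # keys like "module.0.weight"

# Strip the prefix so the keys match the unwrapped module,
# as demo/model.py does for models/table5_pidinet.pth.
net.load_state_dict({k.replace('module.', ''): v
                     for k, v in ckpt['state_dict'].items()})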
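
Several of the other additions are [:, :, ::-1] slices, which reverse the channel axis to convert between OpenCV's BGR order and RGB before an image is handed back for display. A small NumPy illustration of the idiom (the sample array is made up):

import numpy as np

# A 1x2 "image" in BGR order: one pure-blue pixel and one pure-green pixel.
bgr = np.array([[[255, 0, 0], [0, 255, 0]]], dtype=np.uint8)

# Reversing the last axis swaps BGR <-> RGB without copying the data.
rgb = bgr[:, :, ::-1]
print(rgb[0, 0])  # -> [0 0 255]: the blue pixel, now expressed in RGB order

# The slice returns a view; np.ascontiguousarray() (or .copy()) gives a
# contiguous buffer when a downstream call needs one.
rgb = np.ascontiguousarray(rgb)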
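
The re-wrapped line in process_depth_keypose combines the two condition branches level by level: each depth feature map is scaled by w_depth, each keypose map by w_keypose, and the element-wise sums are passed on as the adapter features. A sketch with dummy tensors; the channel counts echo the [320, 640, 1280, 1280] adapter configuration in the diff, while the spatial sizes are purely illustrative:

import torch

# Dummy per-level adapter outputs for the depth and keypose branches
# (shapes are illustrative, not the real ones).
channels = (320, 640, 1280, 1280)
sizes = (64, 32, 16, 8)
feats_depth = [torch.randn(1, c, s, s) for c, s in zip(channels, sizes)]
feats_keypose = [torch.randn(1, c, s, s) for c, s in zip(channels, sizes)]

w_depth, w_keypose = 1.0, 1.5  # example weights

# Level-wise weighted sum, mirroring the list comprehension in the diff.
features_adapter = [f_d * w_depth + f_k * w_keypose
                    for f_d, f_k in zip(feats_depth, feats_keypose)]
print([tuple(f.shape) for f in features_adapter])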