Spaces:
Runtime error
Runtime error
Upload 7 files
Browse files- .gitattributes +13 -0
- app.py +108 -0
- models/det.onnx +3 -0
- models/humanparsing_572_384.pt +3 -0
- models/pose.onnx +3 -0
- requirements.txt +0 -0
.gitattributes
CHANGED
@@ -33,3 +33,16 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
36 |
+
|
37 |
+
### Python ###
|
38 |
+
# Byte-compiled / optimized / DLL files
|
39 |
+
__pycache__/
|
40 |
+
*.py[cod]
|
41 |
+
*$py.class
|
42 |
+
|
43 |
+
# PyCharm
|
44 |
+
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
|
45 |
+
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
|
46 |
+
# and can be added to the global gitignore or merged into this file. For a more nuclear
|
47 |
+
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
48 |
+
#.idea/
|
app.py
ADDED
@@ -0,0 +1,108 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from rtmlib import YOLOX, RTMPose, draw_bbox, draw_skeleton
|
2 |
+
import functools
|
3 |
+
from typing import Callable
|
4 |
+
from pathlib import Path
|
5 |
+
import gradio as gr
|
6 |
+
import numpy as np
|
7 |
+
import PIL.Image
|
8 |
+
import torch
|
9 |
+
import torch.nn as nn
|
10 |
+
import torch.nn.functional as F
|
11 |
+
from PIL import Image
|
12 |
+
import torchvision.transforms as T
|
13 |
+
|
14 |
+
|
15 |
+
# Title string handed to the Gradio interface built in main().
TITLE = 'Human Parsing'
|
16 |
+
|
17 |
+
def get_palette(num_cls: int) -> list:
    """Return the flat color map used to visualize a segmentation mask.

    The result is a flat list of ``3 * num_cls`` integers laid out as
    consecutive 3-value entries, one entry per class index.  Each entry is
    derived from the bits of the class index: the index is consumed three
    bits at a time, and within every 3-bit group bit 0 feeds channel 0,
    bit 1 feeds channel 1 and bit 2 feeds channel 2.  The first group lands
    on bit 7 of each channel, the next group on bit 6, and so on, so that
    small class indices map to well separated, saturated colors.

    Args:
        num_cls: Number of classes.

    Returns:
        The color map as a flat list of ints (class 0 is always ``0, 0, 0``).
    """
    # The list is created zero-filled, so entries only ever need OR-ing.
    palette = [0] * (num_cls * 3)
    for cls_idx in range(num_cls):
        base = cls_idx * 3
        remaining = cls_idx
        shift = 7  # destination bit inside each 8-bit channel value
        while remaining:
            palette[base + 0] |= ((remaining >> 0) & 1) << shift
            palette[base + 1] |= ((remaining >> 1) & 1) << shift
            palette[base + 2] |= ((remaining >> 2) & 1) << shift
            shift -= 1
            remaining >>= 3
    return palette
|
40 |
+
|
41 |
+
@torch.inference_mode()
def predict(image: PIL.Image.Image, model, transform: Callable,
            device: torch.device, palette) -> "tuple[PIL.Image.Image, PIL.Image.Image]":
    """Run detection, pose estimation and human parsing on one image.

    Args:
        image: Input picture as a PIL image.  It is converted to RGB first so
            that grayscale or RGBA inputs yield the three channels the
            downstream steps work with.
        model: Indexable collection of three callables:
            ``model[0]`` is the parsing network (called with a batched
            tensor), ``model[1]`` is the detector (called with an image
            array, returns bounding boxes) and ``model[2]`` is the pose
            estimator (called with an image array and ``bboxes=``, returns
            keypoints and scores).
        transform: Callable that turns the PIL image into an unbatched
            tensor for the parsing network.
        device: Device the input tensor is moved to before parsing.
        palette: Flat color palette applied to the label map via
            ``Image.putpalette``.

    Returns:
        A pair ``(label_image, blended)``: the colorized label map as an RGB
        image, and a 50/50 blend of that map with the input annotated with
        the detected boxes and skeletons.
    """
    # Force three channels: a non-RGB input would not match the 3-channel
    # normalization statistics configured by the caller.
    image = image.convert('RGB')

    # NOTE(review): arrays passed to the detector/pose models are in RGB
    # order here -- confirm this is the channel order rtmlib expects.
    bboxes = model[1](np.array(image))
    keypoints, scores = model[2](np.array(image), bboxes=bboxes)

    # Draw the detections on a separate array so the models' inputs stay
    # untouched.
    img_show = draw_bbox(np.array(image), bboxes)
    img_show = draw_skeleton(img_show, keypoints, scores)

    # Add the batch dimension and run the parsing network.
    data = transform(image).unsqueeze(0).to(device)
    out = model[0](data)

    # PIL's size is (width, height); interpolate wants (height, width).
    width, height = image.size
    out = F.interpolate(out, [height, width], mode="bilinear")

    # Per-pixel class index = argmax over the channel axis of the single
    # batch item (assumes the network output is laid out as N, C, H, W).
    parsing = torch.argmax(out[0], dim=0).cpu().numpy()

    output_im = Image.fromarray(np.asarray(parsing, dtype=np.uint8))
    output_im.putpalette(palette)
    output_im = output_im.convert('RGB')

    annotated = Image.fromarray(np.asarray(img_show, dtype=np.uint8))
    res = Image.blend(annotated.convert('RGB'), output_im, 0.5)
    return output_im, res
|
65 |
+
|
66 |
+
|
67 |
+
def load_parsing_model(model_path=Path("models/humanparsing_572_384.pt"),
                       map_location="cpu"):
    """Load the TorchScript human-parsing model and switch it to eval mode.

    Args:
        model_path: Location of the TorchScript archive.  Defaults to the
            checkpoint shipped under ``models/`` (resolved relative to the
            current working directory).
        map_location: Device the stored tensors are remapped to while
            loading.  Defaults to ``"cpu"`` because the app runs inference
            on the CPU and never moves the model afterwards; without it a
            checkpoint saved from a GPU would be restored onto that GPU.

    Returns:
        The loaded script module in evaluation mode.
    """
    model = torch.jit.load(model_path, map_location=map_location)
    model.eval()
    return model
|
71 |
+
|
72 |
+
|
73 |
+
def main():
    """Build the models and preprocessing, then launch the Gradio demo.

    Everything is set up for CPU execution: the parsing network, the YOLOX
    detector and the RTMPose estimator are created once, bundled in a list
    (parsing, detection, pose -- the order ``predict`` indexes them in) and
    bound to ``predict`` together with the preprocessing pipeline, the
    device and a 20-class palette.
    """
    device = torch.device('cpu')
    parsing_model = load_parsing_model()

    # Preprocessing for the parsing network: fixed 572x384 resize, tensor
    # conversion and per-channel normalization.
    preprocess = T.Compose([
        T.Resize((572, 384), interpolation=PIL.Image.NEAREST),
        T.ToTensor(),
        T.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
    ])
    palette = get_palette(20)

    detector = YOLOX(
        'models/det.onnx',
        model_input_size=(640, 640),
        backend='onnxruntime',
        device='cpu',
    )
    pose_estimator = RTMPose(
        'models/pose.onnx',
        model_input_size=(192, 256),
        to_openpose=False,
        backend='onnxruntime',
        device='cpu',
    )

    # Index 0: parsing, 1: detection, 2: pose.
    models = [parsing_model, detector, pose_estimator]

    predict_fn = functools.partial(
        predict,
        model=models,
        transform=preprocess,
        device=device,
        palette=palette,
    )

    demo = gr.Interface(
        fn=predict_fn,
        inputs=gr.Image(label='Input', type='pil'),
        outputs=[
            gr.Image(label='Predicted Labels', type='pil'),
            gr.Image(label='Masked', type='pil'),
        ],
        title=TITLE,
    )
    demo.queue().launch(show_api=False)
|
106 |
+
|
107 |
+
# Script entry point: start the demo only when this file is executed
# directly, not when it is imported.
if __name__ == "__main__":
    main()
|
models/det.onnx
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:3dea6513388889f0fff4b77bf7a26013600321b9eb9ceb0e9a400a82572f5f23
|
3 |
+
size 101400344
|
models/humanparsing_572_384.pt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:91fa5fa1cbbc59336e1a4c9cbb51f572ebff8289a084226d6b8b79fbeae922a6
|
3 |
+
size 257770490
|
models/pose.onnx
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:7bca282009dd5e5b8a94bb27e8015f610630643659b862818803febf9107e2e5
|
3 |
+
size 368041127
|
requirements.txt
ADDED
Binary file (192 Bytes). View file
|
|