Files changed (7)
  1. .pre-commit-config.yaml +45 -59
  2. .style.yapf +5 -0
  3. README.md +1 -4
  4. app.py +179 -134
  5. model.py +119 -82
  6. requirements.txt +2 -2
  7. style.css +4 -1
.pre-commit-config.yaml CHANGED
@@ -1,60 +1,46 @@
  repos:
-   - repo: https://github.com/pre-commit/pre-commit-hooks
-     rev: v4.6.0
-     hooks:
-       - id: check-executables-have-shebangs
-       - id: check-json
-       - id: check-merge-conflict
-       - id: check-shebang-scripts-are-executable
-       - id: check-toml
-       - id: check-yaml
-       - id: end-of-file-fixer
-       - id: mixed-line-ending
-         args: ["--fix=lf"]
-       - id: requirements-txt-fixer
-       - id: trailing-whitespace
-   - repo: https://github.com/myint/docformatter
-     rev: v1.7.5
-     hooks:
-       - id: docformatter
-         args: ["--in-place"]
-   - repo: https://github.com/pycqa/isort
-     rev: 5.13.2
-     hooks:
-       - id: isort
-         args: ["--profile", "black"]
-   - repo: https://github.com/pre-commit/mirrors-mypy
-     rev: v1.10.0
-     hooks:
-       - id: mypy
-         args: ["--ignore-missing-imports"]
-         additional_dependencies:
-           [
-             "types-python-slugify",
-             "types-requests",
-             "types-PyYAML",
-             "types-pytz",
-           ]
-   - repo: https://github.com/psf/black
-     rev: 24.4.2
-     hooks:
-       - id: black
-         language_version: python3.10
-         args: ["--line-length", "119"]
-   - repo: https://github.com/kynan/nbstripout
-     rev: 0.7.1
-     hooks:
-       - id: nbstripout
-         args:
-           [
-             "--extra-keys",
-             "metadata.interpreter metadata.kernelspec cell.metadata.pycharm",
-           ]
-   - repo: https://github.com/nbQA-dev/nbQA
-     rev: 1.8.5
-     hooks:
-       - id: nbqa-black
-       - id: nbqa-pyupgrade
-         args: ["--py37-plus"]
-       - id: nbqa-isort
-         args: ["--float-to-top"]

+ exclude: ^(ViTPose/|mmdet_configs/configs/)
  repos:
+ - repo: https://github.com/pre-commit/pre-commit-hooks
+   rev: v4.2.0
+   hooks:
+   - id: check-executables-have-shebangs
+   - id: check-json
+   - id: check-merge-conflict
+   - id: check-shebang-scripts-are-executable
+   - id: check-toml
+   - id: check-yaml
+   - id: double-quote-string-fixer
+   - id: end-of-file-fixer
+   - id: mixed-line-ending
+     args: ['--fix=lf']
+   - id: requirements-txt-fixer
+   - id: trailing-whitespace
+ - repo: https://github.com/myint/docformatter
+   rev: v1.4
+   hooks:
+   - id: docformatter
+     args: ['--in-place']
+ - repo: https://github.com/pycqa/isort
+   rev: 5.10.1
+   hooks:
+   - id: isort
+ - repo: https://github.com/pre-commit/mirrors-mypy
+   rev: v0.812
+   hooks:
+   - id: mypy
+     args: ['--ignore-missing-imports']
+ - repo: https://github.com/google/yapf
+   rev: v0.32.0
+   hooks:
+   - id: yapf
+     args: ['--parallel', '--in-place']
+ - repo: https://github.com/kynan/nbstripout
+   rev: 0.5.0
+   hooks:
+   - id: nbstripout
+     args: ['--extra-keys', 'metadata.interpreter metadata.kernelspec cell.metadata.pycharm']
+ - repo: https://github.com/nbQA-dev/nbQA
+   rev: 1.3.1
+   hooks:
+   - id: nbqa-isort
+   - id: nbqa-yapf
.style.yapf ADDED
@@ -0,0 +1,5 @@
+ [style]
+ based_on_style = pep8
+ blank_line_before_nested_class_or_def = false
+ spaces_before_comment = 2
+ split_before_logical_operator = true
README.md CHANGED
@@ -4,12 +4,9 @@ emoji: 📊
  colorFrom: yellow
  colorTo: indigo
  sdk: gradio
- sdk_version: 4.36.1
  app_file: app.py
  pinned: false
- suggested_hardware: t4-small
  ---

  Check out the configuration reference at https://huggingface.co/docs/hub/spaces#reference
-
- https://arxiv.org/abs/2204.12484

  colorFrom: yellow
  colorTo: indigo
  sdk: gradio
+ sdk_version: 3.1.1
  app_file: app.py
  pinned: false
  ---

  Check out the configuration reference at https://huggingface.co/docs/hub/spaces#reference
app.py CHANGED
@@ -2,153 +2,198 @@

  from __future__ import annotations

- import os
  import pathlib
- import shlex
- import subprocess
  import tarfile

- if os.getenv("SYSTEM") == "spaces":
-     subprocess.run(shlex.split("pip install click==7.1.2"))
-     subprocess.run(shlex.split("pip install typer==0.9.4"))

-     import mim

-     mim.uninstall("mmcv-full", confirm_yes=True)
-     mim.install("mmcv-full==1.5.0", is_yes=True)

-     subprocess.run(shlex.split("pip uninstall -y opencv-python"))
-     subprocess.run(shlex.split("pip uninstall -y opencv-python-headless"))
-     subprocess.run(shlex.split("pip install opencv-python-headless==4.8.0.74"))

- import gradio as gr

- from model import AppDetModel, AppPoseModel

- DESCRIPTION = "# [ViTPose](https://github.com/ViTAE-Transformer/ViTPose)"


  def extract_tar() -> None:
-     if pathlib.Path("mmdet_configs/configs").exists():
          return
-     with tarfile.open("mmdet_configs/configs.tar") as f:
-         f.extractall("mmdet_configs")
-
-
- extract_tar()
-
- det_model = AppDetModel()
- pose_model = AppPoseModel()
-
- with gr.Blocks(css="style.css") as demo:
-     gr.Markdown(DESCRIPTION)
-
-     with gr.Group():
-         gr.Markdown("## Step 1")
-         with gr.Row():
-             with gr.Column():
-                 with gr.Row():
-                     input_image = gr.Image(label="Input Image", type="numpy")
-                 with gr.Row():
-                     detector_name = gr.Dropdown(
-                         label="Detector", choices=list(det_model.MODEL_DICT.keys()), value=det_model.model_name
-                     )
-                 with gr.Row():
-                     detect_button = gr.Button("Detect")
-                     det_preds = gr.State()
-             with gr.Column():
-                 with gr.Row():
-                     detection_visualization = gr.Image(label="Detection Result", type="numpy", elem_id="det-result")
-                 with gr.Row():
-                     vis_det_score_threshold = gr.Slider(
-                         label="Visualization Score Threshold", minimum=0, maximum=1, step=0.05, value=0.5
-                     )
-                 with gr.Row():
-                     redraw_det_button = gr.Button(value="Redraw")
-
-     with gr.Row():
-         paths = sorted(pathlib.Path("images").rglob("*.jpg"))
-         example_images = gr.Examples(examples=[[path.as_posix()] for path in paths], inputs=input_image)
-
-     with gr.Group():
-         gr.Markdown("## Step 2")
-         with gr.Row():
-             with gr.Column():
-                 with gr.Row():
-                     pose_model_name = gr.Dropdown(
-                         label="Pose Model", choices=list(pose_model.MODEL_DICT.keys()), value=pose_model.model_name
-                     )
-                     det_score_threshold = gr.Slider(
-                         label="Box Score Threshold", minimum=0, maximum=1, step=0.05, value=0.5
-                     )
-                 with gr.Row():
-                     predict_button = gr.Button("Predict")
-                     pose_preds = gr.State()
-             with gr.Column():
-                 with gr.Row():
-                     pose_visualization = gr.Image(label="Result", type="numpy", elem_id="pose-result")
-                 with gr.Row():
-                     vis_kpt_score_threshold = gr.Slider(
-                         label="Visualization Score Threshold", minimum=0, maximum=1, step=0.05, value=0.3
-                     )
-                 with gr.Row():
-                     vis_dot_radius = gr.Slider(label="Dot Radius", minimum=1, maximum=10, step=1, value=4)
-                 with gr.Row():
-                     vis_line_thickness = gr.Slider(label="Line Thickness", minimum=1, maximum=10, step=1, value=2)
-                 with gr.Row():
-                     redraw_pose_button = gr.Button("Redraw")
-
-     detector_name.change(fn=det_model.set_model, inputs=detector_name)
-     detect_button.click(
-         fn=det_model.run,
-         inputs=[
-             detector_name,
-             input_image,
-             vis_det_score_threshold,
-         ],
-         outputs=[
-             det_preds,
-             detection_visualization,
-         ],
-     )
-     redraw_det_button.click(
-         fn=det_model.visualize_detection_results,
-         inputs=[
-             input_image,
-             det_preds,
-             vis_det_score_threshold,
-         ],
-         outputs=detection_visualization,
      )

-     pose_model_name.change(fn=pose_model.set_model, inputs=pose_model_name)
-     predict_button.click(
-         fn=pose_model.run,
-         inputs=[
-             pose_model_name,
-             input_image,
-             det_preds,
-             det_score_threshold,
-             vis_kpt_score_threshold,
-             vis_dot_radius,
-             vis_line_thickness,
-         ],
-         outputs=[
-             pose_preds,
-             pose_visualization,
-         ],
-     )
-     redraw_pose_button.click(
-         fn=pose_model.visualize_pose_results,
-         inputs=[
-             input_image,
-             pose_preds,
-             vis_kpt_score_threshold,
-             vis_dot_radius,
-             vis_line_thickness,
-         ],
-         outputs=pose_visualization,
-     )

- if __name__ == "__main__":
-     demo.queue(max_size=10).launch()


  from __future__ import annotations

+ import argparse
  import pathlib
  import tarfile

+ import gradio as gr

+ from model import AppDetModel, AppPoseModel

+ DESCRIPTION = '''# ViTPose

+ This is an unofficial demo for [https://github.com/ViTAE-Transformer/ViTPose](https://github.com/ViTAE-Transformer/ViTPose).'''
+ FOOTER = '<img id="visitor-badge" alt="visitor badge" src="https://visitor-badge.glitch.me/badge?page_id=hysts.vitpose" />'


+ def parse_args() -> argparse.Namespace:
+     parser = argparse.ArgumentParser()
+     parser.add_argument('--device', type=str, default='cpu')
+     parser.add_argument('--theme', type=str)
+     parser.add_argument('--share', action='store_true')
+     parser.add_argument('--port', type=int)
+     parser.add_argument('--disable-queue',
+                         dest='enable_queue',
+                         action='store_false')
+     return parser.parse_args()

+
+ def set_example_image(example: list) -> dict:
+     return gr.Image.update(value=example[0])


  def extract_tar() -> None:
+     if pathlib.Path('mmdet_configs/configs').exists():
          return
+     with tarfile.open('mmdet_configs/configs.tar') as f:
+         f.extractall('mmdet_configs')
+
+
+ def main():
+     args = parse_args()
+
+     extract_tar()
+
+     det_model = AppDetModel(device=args.device)
+     pose_model = AppPoseModel(device=args.device)
+
+     with gr.Blocks(theme=args.theme, css='style.css') as demo:
+         gr.Markdown(DESCRIPTION)
+
+         with gr.Box():
+             gr.Markdown('## Step 1')
+             with gr.Row():
+                 with gr.Column():
+                     with gr.Row():
+                         input_image = gr.Image(label='Input Image',
+                                                type='numpy')
+                     with gr.Row():
+                         detector_name = gr.Dropdown(list(
+                             det_model.MODEL_DICT.keys()),
+                                                     value=det_model.model_name,
+                                                     label='Detector')
+                     with gr.Row():
+                         detect_button = gr.Button(value='Detect')
+                         det_preds = gr.Variable()
+                 with gr.Column():
+                     with gr.Row():
+                         detection_visualization = gr.Image(
+                             label='Detection Result',
+                             type='numpy',
+                             elem_id='det-result')
+                     with gr.Row():
+                         vis_det_score_threshold = gr.Slider(
+                             0,
+                             1,
+                             step=0.05,
+                             value=0.5,
+                             label='Visualization Score Threshold')
+                     with gr.Row():
+                         redraw_det_button = gr.Button(value='Redraw')
+
+         with gr.Row():
+             paths = sorted(pathlib.Path('images').rglob('*.jpg'))
+             example_images = gr.Dataset(components=[input_image],
+                                         samples=[[path.as_posix()]
+                                                  for path in paths])
+
+         with gr.Box():
+             gr.Markdown('## Step 2')
+             with gr.Row():
+                 with gr.Column():
+                     with gr.Row():
+                         pose_model_name = gr.Dropdown(
+                             list(pose_model.MODEL_DICT.keys()),
+                             value=pose_model.model_name,
+                             label='Pose Model')
+                         det_score_threshold = gr.Slider(
+                             0,
+                             1,
+                             step=0.05,
+                             value=0.5,
+                             label='Box Score Threshold')
+                     with gr.Row():
+                         predict_button = gr.Button(value='Predict')
+                         pose_preds = gr.Variable()
+                 with gr.Column():
+                     with gr.Row():
+                         pose_visualization = gr.Image(label='Result',
+                                                       type='numpy',
+                                                       elem_id='pose-result')
+                     with gr.Row():
+                         vis_kpt_score_threshold = gr.Slider(
+                             0,
+                             1,
+                             step=0.05,
+                             value=0.3,
+                             label='Visualization Score Threshold')
+                     with gr.Row():
+                         vis_dot_radius = gr.Slider(1,
+                                                    10,
+                                                    step=1,
+                                                    value=4,
+                                                    label='Dot Radius')
+                     with gr.Row():
+                         vis_line_thickness = gr.Slider(1,
+                                                        10,
+                                                        step=1,
+                                                        value=2,
+                                                        label='Line Thickness')
+                     with gr.Row():
+                         redraw_pose_button = gr.Button(value='Redraw')
+
+         gr.Markdown(FOOTER)
+
+         detector_name.change(fn=det_model.set_model,
+                              inputs=detector_name,
+                              outputs=None)
+         detect_button.click(fn=det_model.run,
+                             inputs=[
+                                 detector_name,
+                                 input_image,
+                                 vis_det_score_threshold,
+                             ],
+                             outputs=[
+                                 det_preds,
+                                 detection_visualization,
+                             ])
+         redraw_det_button.click(fn=det_model.visualize_detection_results,
+                                 inputs=[
+                                     input_image,
+                                     det_preds,
+                                     vis_det_score_threshold,
+                                 ],
+                                 outputs=detection_visualization)
+
+         pose_model_name.change(fn=pose_model.set_model,
+                                inputs=pose_model_name,
+                                outputs=None)
+         predict_button.click(fn=pose_model.run,
+                              inputs=[
+                                  pose_model_name,
+                                  input_image,
+                                  det_preds,
+                                  det_score_threshold,
+                                  vis_kpt_score_threshold,
+                                  vis_dot_radius,
+                                  vis_line_thickness,
+                              ],
+                              outputs=[
+                                  pose_preds,
+                                  pose_visualization,
+                              ])
+         redraw_pose_button.click(fn=pose_model.visualize_pose_results,
+                                  inputs=[
+                                      input_image,
+                                      pose_preds,
+                                      vis_kpt_score_threshold,
+                                      vis_dot_radius,
+                                      vis_line_thickness,
+                                  ],
+                                  outputs=pose_visualization)
+
+         example_images.click(
+             fn=set_example_image,
+             inputs=example_images,
+             outputs=input_image,
+         )
+
+     demo.launch(
+         enable_queue=args.enable_queue,
+         server_port=args.port,
+         share=args.share,
      )


+ if __name__ == '__main__':
+     main()
model.py CHANGED
@@ -1,50 +1,74 @@
  from __future__ import annotations

  import pathlib
  import sys

  import huggingface_hub
  import numpy as np
  import torch
  import torch.nn as nn

  app_dir = pathlib.Path(__file__).parent
- submodule_dir = app_dir / "ViTPose"
  sys.path.insert(0, submodule_dir.as_posix())

  from mmdet.apis import inference_detector, init_detector
- from mmpose.apis import (
-     inference_top_down_pose_model,
-     init_pose_model,
-     process_mmdet_results,
-     vis_pose_result,
- )


  class DetModel:
      MODEL_DICT = {
-         "YOLOX-tiny": {
-             "config": "mmdet_configs/configs/yolox/yolox_tiny_8x8_300e_coco.py",
-             "model": "https://download.openmmlab.com/mmdetection/v2.0/yolox/yolox_tiny_8x8_300e_coco/yolox_tiny_8x8_300e_coco_20211124_171234-b4047906.pth",
          },
-         "YOLOX-s": {
-             "config": "mmdet_configs/configs/yolox/yolox_s_8x8_300e_coco.py",
-             "model": "https://download.openmmlab.com/mmdetection/v2.0/yolox/yolox_s_8x8_300e_coco/yolox_s_8x8_300e_coco_20211121_095711-4592a793.pth",
          },
-         "YOLOX-l": {
-             "config": "mmdet_configs/configs/yolox/yolox_l_8x8_300e_coco.py",
-             "model": "https://download.openmmlab.com/mmdetection/v2.0/yolox/yolox_l_8x8_300e_coco/yolox_l_8x8_300e_coco_20211126_140236-d3bd2b23.pth",
          },
-         "YOLOX-x": {
-             "config": "mmdet_configs/configs/yolox/yolox_x_8x8_300e_coco.py",
-             "model": "https://download.openmmlab.com/mmdetection/v2.0/yolox/yolox_x_8x8_300e_coco/yolox_x_8x8_300e_coco_20211126_140254-1ef88d67.pth",
          },
      }

-     def __init__(self):
-         self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
          self._load_all_models_once()
-         self.model_name = "YOLOX-l"
          self.model = self._load_model(self.model_name)

      def _load_all_models_once(self) -> None:
@@ -52,8 +76,8 @@ class DetModel:
              self._load_model(name)

      def _load_model(self, name: str) -> nn.Module:
-         d = self.MODEL_DICT[name]
-         return init_detector(d["config"], d["model"], device=self.device)

      def set_model(self, name: str) -> None:
          if name == self.model_name:
@@ -61,7 +85,9 @@ class DetModel:
          self.model_name = name
          self.model = self._load_model(name)

-     def detect_and_visualize(self, image: np.ndarray, score_threshold: float) -> tuple[list[np.ndarray], np.ndarray]:
          out = self.detect(image)
          vis = self.visualize_detection_results(image, out, score_threshold)
          return out, vis
@@ -72,46 +98,56 @@ class DetModel:
          return out

      def visualize_detection_results(
-         self, image: np.ndarray, detection_results: list[np.ndarray], score_threshold: float = 0.3
-     ) -> np.ndarray:
          person_det = [detection_results[0]] + [np.array([]).reshape(0, 5)] * 79

          image = image[:, :, ::-1]  # RGB -> BGR
-         vis = self.model.show_result(
-             image, person_det, score_thr=score_threshold, bbox_color=None, text_color=(200, 200, 200), mask_color=None
-         )
          return vis[:, :, ::-1]  # BGR -> RGB


  class AppDetModel(DetModel):
-     def run(self, model_name: str, image: np.ndarray, score_threshold: float) -> tuple[list[np.ndarray], np.ndarray]:
          self.set_model(model_name)
          return self.detect_and_visualize(image, score_threshold)


  class PoseModel:
      MODEL_DICT = {
-         "ViTPose-B (single-task train)": {
-             "config": "ViTPose/configs/body/2d_kpt_sview_rgb_img/topdown_heatmap/coco/ViTPose_base_coco_256x192.py",
-             "model": "models/vitpose-b.pth",
          },
-         "ViTPose-L (single-task train)": {
-             "config": "ViTPose/configs/body/2d_kpt_sview_rgb_img/topdown_heatmap/coco/ViTPose_large_coco_256x192.py",
-             "model": "models/vitpose-l.pth",
          },
-         "ViTPose-B (multi-task train, COCO)": {
-             "config": "ViTPose/configs/body/2d_kpt_sview_rgb_img/topdown_heatmap/coco/ViTPose_base_coco_256x192.py",
-             "model": "models/vitpose-b-multi-coco.pth",
          },
-         "ViTPose-L (multi-task train, COCO)": {
-             "config": "ViTPose/configs/body/2d_kpt_sview_rgb_img/topdown_heatmap/coco/ViTPose_large_coco_256x192.py",
-             "model": "models/vitpose-l-multi-coco.pth",
          },
      }

-     def __init__(self):
-         self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
-         self.model_name = "ViTPose-B (multi-task train, COCO)"
          self.model = self._load_model(self.model_name)

      def _load_all_models_once(self) -> None:
@@ -119,9 +155,11 @@ class PoseModel:
              self._load_model(name)

      def _load_model(self, name: str) -> nn.Module:
-         d = self.MODEL_DICT[name]
-         ckpt_path = huggingface_hub.hf_hub_download("public-data/ViTPose", d["model"])
-         model = init_pose_model(d["config"], ckpt_path, device=self.device)
          return model

      def set_model(self, name: str) -> None:
@@ -140,51 +178,50 @@ class PoseModel:
          vis_line_thickness: int,
      ) -> tuple[list[dict[str, np.ndarray]], np.ndarray]:
          out = self.predict_pose(image, det_results, box_score_threshold)
-         vis = self.visualize_pose_results(image, out, kpt_score_threshold, vis_dot_radius, vis_line_thickness)
          return out, vis

      def predict_pose(
-         self, image: np.ndarray, det_results: list[np.ndarray], box_score_threshold: float = 0.5
-     ) -> list[dict[str, np.ndarray]]:
          image = image[:, :, ::-1]  # RGB -> BGR
          person_results = process_mmdet_results(det_results, 1)
-         out, _ = inference_top_down_pose_model(
-             self.model, image, person_results=person_results, bbox_thr=box_score_threshold, format="xyxy"
-         )
          return out

-     def visualize_pose_results(
-         self,
-         image: np.ndarray,
-         pose_results: list[np.ndarray],
-         kpt_score_threshold: float = 0.3,
-         vis_dot_radius: int = 4,
-         vis_line_thickness: int = 1,
-     ) -> np.ndarray:
          image = image[:, :, ::-1]  # RGB -> BGR
-         vis = vis_pose_result(
-             self.model,
-             image,
-             pose_results,
-             kpt_score_thr=kpt_score_threshold,
-             radius=vis_dot_radius,
-             thickness=vis_line_thickness,
-         )
          return vis[:, :, ::-1]  # BGR -> RGB


  class AppPoseModel(PoseModel):
      def run(
-         self,
-         model_name: str,
-         image: np.ndarray,
-         det_results: list[np.ndarray],
-         box_score_threshold: float,
-         kpt_score_threshold: float,
-         vis_dot_radius: int,
-         vis_line_thickness: int,
      ) -> tuple[list[dict[str, np.ndarray]], np.ndarray]:
          self.set_model(model_name)
-         return self.predict_pose_and_visualize(
-             image, det_results, box_score_threshold, kpt_score_threshold, vis_dot_radius, vis_line_thickness
-         )

  from __future__ import annotations

+ import os
  import pathlib
+ import subprocess
  import sys

+ try:
+     from mmcv.ops import get_compiling_cuda_version, get_compiler_version
+ except:
+     import mim
+     mim.install('mmcv-full==1.5.0')
+
+ if os.getenv('SYSTEM') == 'spaces':
+     import mim
+
+     mim.uninstall('mmcv-full', confirm_yes=True)
+     mim.install('mmcv-full==1.5.0', is_yes=True)
+
+     subprocess.run('pip uninstall -y opencv-python'.split())
+     subprocess.run('pip uninstall -y opencv-python-headless'.split())
+     subprocess.run('pip install opencv-python-headless==4.5.5.64'.split())
+
  import huggingface_hub
  import numpy as np
  import torch
  import torch.nn as nn

  app_dir = pathlib.Path(__file__).parent
+ submodule_dir = app_dir / 'ViTPose/'
  sys.path.insert(0, submodule_dir.as_posix())

  from mmdet.apis import inference_detector, init_detector
+ from mmpose.apis import (inference_top_down_pose_model, init_pose_model,
+                          process_mmdet_results, vis_pose_result)
+
+ HF_TOKEN = os.environ['HF_TOKEN']


  class DetModel:
      MODEL_DICT = {
+         'YOLOX-tiny': {
+             'config':
+             'mmdet_configs/configs/yolox/yolox_tiny_8x8_300e_coco.py',
+             'model':
+             'https://download.openmmlab.com/mmdetection/v2.0/yolox/yolox_tiny_8x8_300e_coco/yolox_tiny_8x8_300e_coco_20211124_171234-b4047906.pth',
          },
+         'YOLOX-s': {
+             'config':
+             'mmdet_configs/configs/yolox/yolox_s_8x8_300e_coco.py',
+             'model':
+             'https://download.openmmlab.com/mmdetection/v2.0/yolox/yolox_s_8x8_300e_coco/yolox_s_8x8_300e_coco_20211121_095711-4592a793.pth',
          },
+         'YOLOX-l': {
+             'config':
+             'mmdet_configs/configs/yolox/yolox_l_8x8_300e_coco.py',
+             'model':
+             'https://download.openmmlab.com/mmdetection/v2.0/yolox/yolox_l_8x8_300e_coco/yolox_l_8x8_300e_coco_20211126_140236-d3bd2b23.pth',
          },
+         'YOLOX-x': {
+             'config':
+             'mmdet_configs/configs/yolox/yolox_x_8x8_300e_coco.py',
+             'model':
+             'https://download.openmmlab.com/mmdetection/v2.0/yolox/yolox_x_8x8_300e_coco/yolox_x_8x8_300e_coco_20211126_140254-1ef88d67.pth',
          },
      }

+     def __init__(self, device: str | torch.device):
+         self.device = torch.device(device)
          self._load_all_models_once()
+         self.model_name = 'YOLOX-l'
          self.model = self._load_model(self.model_name)

      def _load_all_models_once(self) -> None:
              self._load_model(name)

      def _load_model(self, name: str) -> nn.Module:
+         dic = self.MODEL_DICT[name]
+         return init_detector(dic['config'], dic['model'], device=self.device)

      def set_model(self, name: str) -> None:
          if name == self.model_name:
          self.model_name = name
          self.model = self._load_model(name)

+     def detect_and_visualize(
+             self, image: np.ndarray,
+             score_threshold: float) -> tuple[list[np.ndarray], np.ndarray]:
          out = self.detect(image)
          vis = self.visualize_detection_results(image, out, score_threshold)
          return out, vis
          return out

      def visualize_detection_results(
+             self,
+             image: np.ndarray,
+             detection_results: list[np.ndarray],
+             score_threshold: float = 0.3) -> np.ndarray:
          person_det = [detection_results[0]] + [np.array([]).reshape(0, 5)] * 79

          image = image[:, :, ::-1]  # RGB -> BGR
+         vis = self.model.show_result(image,
+                                      person_det,
+                                      score_thr=score_threshold,
+                                      bbox_color=None,
+                                      text_color=(200, 200, 200),
+                                      mask_color=None)
          return vis[:, :, ::-1]  # BGR -> RGB


  class AppDetModel(DetModel):
+     def run(self, model_name: str, image: np.ndarray,
+             score_threshold: float) -> tuple[list[np.ndarray], np.ndarray]:
          self.set_model(model_name)
          return self.detect_and_visualize(image, score_threshold)


  class PoseModel:
      MODEL_DICT = {
+         'ViTPose-B (single-task train)': {
+             'config':
+             'ViTPose/configs/body/2d_kpt_sview_rgb_img/topdown_heatmap/coco/ViTPose_base_coco_256x192.py',
+             'model': 'models/vitpose-b.pth',
          },
+         'ViTPose-L (single-task train)': {
+             'config':
+             'ViTPose/configs/body/2d_kpt_sview_rgb_img/topdown_heatmap/coco/ViTPose_large_coco_256x192.py',
+             'model': 'models/vitpose-l.pth',
          },
+         'ViTPose-B (multi-task train, COCO)': {
+             'config':
+             'ViTPose/configs/body/2d_kpt_sview_rgb_img/topdown_heatmap/coco/ViTPose_base_coco_256x192.py',
+             'model': 'models/vitpose-b-multi-coco.pth',
          },
+         'ViTPose-L (multi-task train, COCO)': {
+             'config':
+             'ViTPose/configs/body/2d_kpt_sview_rgb_img/topdown_heatmap/coco/ViTPose_large_coco_256x192.py',
+             'model': 'models/vitpose-l-multi-coco.pth',
          },
      }

+     def __init__(self, device: str | torch.device):
+         self.device = torch.device(device)
+         self.model_name = 'ViTPose-B (multi-task train, COCO)'
          self.model = self._load_model(self.model_name)

      def _load_all_models_once(self) -> None:
              self._load_model(name)

      def _load_model(self, name: str) -> nn.Module:
+         dic = self.MODEL_DICT[name]
+         ckpt_path = huggingface_hub.hf_hub_download('hysts/ViTPose',
+                                                     dic['model'],
+                                                     use_auth_token=HF_TOKEN)
+         model = init_pose_model(dic['config'], ckpt_path, device=self.device)
          return model

      def set_model(self, name: str) -> None:
          vis_line_thickness: int,
      ) -> tuple[list[dict[str, np.ndarray]], np.ndarray]:
          out = self.predict_pose(image, det_results, box_score_threshold)
+         vis = self.visualize_pose_results(image, out, kpt_score_threshold,
+                                           vis_dot_radius, vis_line_thickness)
          return out, vis

      def predict_pose(
+             self,
+             image: np.ndarray,
+             det_results: list[np.ndarray],
+             box_score_threshold: float = 0.5) -> list[dict[str, np.ndarray]]:
          image = image[:, :, ::-1]  # RGB -> BGR
          person_results = process_mmdet_results(det_results, 1)
+         out, _ = inference_top_down_pose_model(self.model,
+                                                image,
+                                                person_results=person_results,
+                                                bbox_thr=box_score_threshold,
+                                                format='xyxy')
          return out

+     def visualize_pose_results(self,
+                                image: np.ndarray,
+                                pose_results: list[np.ndarray],
+                                kpt_score_threshold: float = 0.3,
+                                vis_dot_radius: int = 4,
+                                vis_line_thickness: int = 1) -> np.ndarray:
          image = image[:, :, ::-1]  # RGB -> BGR
+         vis = vis_pose_result(self.model,
+                               image,
+                               pose_results,
+                               kpt_score_thr=kpt_score_threshold,
+                               radius=vis_dot_radius,
+                               thickness=vis_line_thickness)
          return vis[:, :, ::-1]  # BGR -> RGB


  class AppPoseModel(PoseModel):
      def run(
+             self, model_name: str, image: np.ndarray,
+             det_results: list[np.ndarray], box_score_threshold: float,
+             kpt_score_threshold: float, vis_dot_radius: int,
+             vis_line_thickness: int
      ) -> tuple[list[dict[str, np.ndarray]], np.ndarray]:
          self.set_model(model_name)
+         return self.predict_pose_and_visualize(image, det_results,
+                                                box_score_threshold,
+                                                kpt_score_threshold,
+                                                vis_dot_radius,
+                                                vis_line_thickness)
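
For reference, the two classes in the new model.py can also be driven outside the Gradio UI. The sketch below is not part of the Space; it assumes the pinned requirements.txt environment, an HF_TOKEN environment variable (the new model.py reads it at import time to download the ViTPose checkpoints), and a hypothetical local image at images/example.jpg.

```python
# Minimal usage sketch (assumption: not part of the Space itself).
import cv2

from model import AppDetModel, AppPoseModel  # the new model.py

det_model = AppDetModel(device='cpu')
pose_model = AppPoseModel(device='cpu')

# The classes expect RGB arrays (as gr.Image provides); OpenCV loads BGR.
bgr = cv2.imread('images/example.jpg')  # hypothetical example image
image = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)

# Step 1: person detection, as wired to the "Detect" button in app.py.
det_preds, det_vis = det_model.run('YOLOX-l', image, 0.5)

# Step 2: top-down pose estimation on the detected boxes, mirroring the
# "Predict" button (box thr, keypoint thr, dot radius, line thickness).
pose_preds, pose_vis = pose_model.run('ViTPose-B (multi-task train, COCO)',
                                      image, det_preds, 0.5, 0.3, 4, 2)

cv2.imwrite('pose_result.png', cv2.cvtColor(pose_vis, cv2.COLOR_RGB2BGR))
```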
requirements.txt CHANGED
@@ -1,8 +1,8 @@
  mmcv-full==1.5.0
  mmdet==2.24.1
  mmpose==0.25.1
- numpy==1.23.5
- opencv-python-headless==4.8.0.74
  openmim==0.1.5
  timm==0.5.4
  torch==1.11.0

  mmcv-full==1.5.0
  mmdet==2.24.1
  mmpose==0.25.1
+ numpy==1.22.4
+ opencv-python-headless==4.5.5.64
  openmim==0.1.5
  timm==0.5.4
  torch==1.11.0
style.css CHANGED
@@ -1,6 +1,5 @@
  h1 {
    text-align: center;
-   display: block;
  }
  div#det-result {
    max-width: 600px;
@@ -10,3 +9,7 @@ div#pose-result {
    max-width: 600px;
    max-height: 600px;
  }

  h1 {
    text-align: center;
  }
  div#det-result {
    max-width: 600px;
    max-width: 600px;
    max-height: 600px;
  }
+ img#visitor-badge {
+   display: block;
+   margin: auto;
+ }