|
from pathlib import Path |
|
from typing import Optional, Union |
|
|
|
from mmdet.apis import inference_detector, init_detector |
|
from mmpose.apis import ( |
|
inference_top_down_pose_model, |
|
init_pose_model, |
|
process_mmdet_results, |
|
vis_pose_result, |
|
) |
|
from mmpose.datasets import DatasetInfo |
|
from PIL import Image, ImageDraw |
|
from torch import ge |
|
|
|
from internals.util.commons import download_file, download_image |
|
from internals.util.config import get_root_dir |
|
|
|
|
|
class PoseDetector:
    """Detect a person in an image and render the pose as an 18-point skeleton.

    A detection model (Faster R-CNN checkpoint) finds people and a top-down
    pose model (HRNet-W48 checkpoint) estimates their keypoints. The
    keypoints are reordered into an 18-point layout (presumably the
    OpenPose ordering -- confirm against the consumer of the image) and can
    be drawn as a coloured stick figure on a black canvas.

    ``load()`` must be called before ``infer()`` / ``transform()``; those
    methods read ``self.det_model`` and ``self.pose_model``.
    """

    # Checkpoints downloaded into ~/.cache by load().
    __det_model = "https://comic-assets.s3.ap-south-1.amazonaws.com/models/faster_rcnn_r50_fpn_1x_coco_20200130-047c8118.pth"
    __pose_model = "https://comic-assets.s3.ap-south-1.amazonaws.com/models/hrnet_w48_coco_256x192-b9e0b3ab_20200708.pth"

    # Index map applied to the model keypoints plus the appended midpoint
    # (index 17) to obtain the 18-point ordering used by the tables below.
    __kim = [0, 17, 6, 8, 10, 5, 7, 9, 12, 14, 16, 11, 13, 15, 2, 1, 4, 3]

    # Limbs to draw: [start point, end point, colour], with 1-based indices
    # into the 18-point list.
    __pose_logical_map = [
        [1, 2, "#000099"],
        [1, 16, "#330099"],
        [1, 15, "#660099"],
        [16, 18, "#990099"],
        [15, 17, "#990066"],
        [2, 3, "#990001"],
        [2, 6, "#993301"],
        [3, 4, "#996502"],
        [4, 5, "#999900"],
        [6, 7, "#669900"],
        [7, 8, "#349900"],
        [2, 9, "#009900"],
        [2, 12, "#009999"],
        [9, 10, "#009966"],
        [10, 11, "#009966"],
        [12, 13, "#006699"],
        [13, 14, "#013399"],
    ]

    # Fill colour of each joint marker, indexed by position in the point list.
    __points_color = [
        "#ff0000",
        "#ff5600",
        "#ffaa01",
        "#ffff00",
        "#aaff03",
        "#53ff00",
        "#03ff00",
        "#03ff55",
        "#03ffaa",
        "#03ffff",
        "#05aaff",
        "#0055ff",
        "#0000ff",
        "#5500ff",
        "#aa00ff",
        "#ff00aa",
        "#ff00ff",
        "#ff0055",
    ]

    __loaded = False

    def load(self, device: str = "cpu"):
        """Download the checkpoints and initialise both models once.

        Subsequent calls on the same instance return immediately, even if a
        different ``device`` is passed.

        Args:
            device: Device string forwarded to ``init_detector`` and
                ``init_pose_model``. Defaults to ``"cpu"``.
        """
        if self.__loaded:
            return

        cache_dir = Path.home() / ".cache"
        # The cache directory is not guaranteed to exist on a fresh machine.
        cache_dir.mkdir(parents=True, exist_ok=True)

        det_path = cache_dir / self.__det_model.split("/")[-1]
        pose_path = cache_dir / self.__pose_model.split("/")[-1]

        download_file(self.__det_model, det_path)
        download_file(self.__pose_model, pose_path)

        self.det_model = init_detector(
            f"{get_root_dir()}/external/faster_rcnn_r50_fpn_coco.py",
            str(det_path),
            device=device,
        )
        self.pose_model = init_pose_model(
            f"{get_root_dir()}/external/hrnet_w48_coco_256x192.py",
            str(pose_path),
            device=device,
        )

        test_cfg = self.pose_model.cfg.data["test"]
        self.dataset = test_cfg["type"]
        raw_dataset_info = test_cfg.get("dataset_info", None)
        # Only wrap when the config actually provides the entry; passing
        # None to DatasetInfo would fail instead of falling back to the
        # dataset name.
        self.dataset_info = (
            DatasetInfo(raw_dataset_info) if raw_dataset_info is not None else None
        )

        self.__loaded = True

    # PIL types in annotations are quoted so they are not evaluated when the
    # class body is executed.
    def transform(
        self,
        image: "Union[str, Image.Image]",
        width: int,
        height: int,
        client_coordinates: Optional[dict],
    ) -> "Image.Image":
        """Infer the pose in ``image`` and render it as a skeleton image.

        When ``client_coordinates`` carries a non-empty ``"candidate"`` list,
        those points are rescaled from a 384x384 space to the target size and
        merged into the inferred pose (see ``map_head_to_body``).

        Args:
            image: A URL (downloaded with ``download_image``) or a PIL image.
            width: Width of the output canvas; the input is resized to it.
            height: Height of the output canvas; the input is resized to it.
            client_coordinates: Optional dict with ``"candidate"`` and
                ``"subset"`` keys supplied by the client.

        Returns:
            The rendered pose on a black ``width`` x ``height`` canvas.
        """
        if isinstance(image, str):
            image = download_image(image)

        infer_coordinates = self.infer(image, width, height)
        if client_coordinates and client_coordinates.get("candidate"):
            client_coordinates = self.resize_coordinates(
                client_coordinates, 384, 384, width, height
            )
            infer_coordinates = self.map_head_to_body(
                client_coordinates, infer_coordinates
            )

        print(infer_coordinates)

        return self.create_pose(infer_coordinates, width, height)

    def resize_coordinates(
        self, data: dict, ori_width, ori_height, new_width, new_height
    ):
        """Scale every point in ``data["candidate"]`` by one uniform factor.

        The factor is ``new / ori`` where, for a target wider than it is tall,
        both values are the smaller side of their canvas, and otherwise both
        are the larger side.

        Args:
            data: Dict with ``"candidate"`` (list of ``[x, y]``) and
                ``"subset"``.
            ori_width, ori_height: Size of the canvas the points refer to.
            new_width, new_height: Size of the canvas to map onto.

        Returns:
            A new dict with the scaled points and the original ``"subset"``.
        """
        pick = min if new_width > new_height else max
        ori_side = pick(ori_width, ori_height)
        new_side = pick(new_width, new_height)

        scaled = [
            [point[0] * new_side / ori_side, point[1] * new_side / ori_side]
            for point in data["candidate"]
        ]
        return {"candidate": scaled, "subset": data["subset"]}

    def create_pose(self, data: dict, width: int, height: int) -> "Image.Image":
        """Draw the limbs and joints of ``data["candidate"]`` on a black canvas.

        Limbs whose endpoints are missing from the point list are skipped, so
        an empty or partial pose yields a (partly) blank image rather than an
        ``IndexError``.

        Args:
            data: Dict whose ``"candidate"`` entry is a list of ``[x, y]``.
            width: Canvas width in pixels.
            height: Canvas height in pixels.

        Returns:
            The rendered RGB image.
        """
        image = Image.new("RGB", (width, height), "black")
        draw = ImageDraw.Draw(image)

        points = data["candidate"]
        count = len(points)

        for start, end, color in self.__pose_logical_map:
            # Table indices are 1-based; skip limbs with a missing endpoint.
            if start > count or end > count:
                continue
            head = points[start - 1]
            tail = points[end - 1]
            draw.line(
                (head[0], head[1], tail[0], tail[1]),
                fill=color,
                width=4,
            )

        palette = self.__points_color
        for i, point in enumerate(points):
            x = point[0]
            y = point[1]
            # Cycle the palette so extra points never index past its end.
            draw.ellipse(
                (x - 3, y - 3, x + 3, y + 3), fill=palette[i % len(palette)]
            )

        return image

    def infer(self, imageUrl: "Union[str, Image.Image]", width, height) -> dict:
        """Run person detection and pose estimation on one image.

        The image is resized to ``width`` x ``height`` and written under
        ``~/.cache`` because both models are invoked with a file path.
        Requires ``load()`` to have been called.

        Args:
            imageUrl: A PIL image (any ``Image.Image`` subclass) or a URL.
            width: Width the image is resized to before inference.
            height: Height the image is resized to before inference.

        Returns:
            ``{"candidate": [...], "subset": [...]}`` where ``candidate``
            holds at most 18 reordered ``[x, y]`` points and ``subset`` holds,
            per accepted person, the indices its points occupy.

        Raises:
            TypeError: If ``imageUrl`` is neither a PIL image nor a string.
        """
        candidate = []
        subset = []

        cache_dir = Path.home() / ".cache"
        cache_dir.mkdir(parents=True, exist_ok=True)

        # isinstance (not an exact type comparison) so that images returned
        # by Image.open, which are Image.Image subclasses, are accepted.
        if isinstance(imageUrl, Image.Image):
            image_path = cache_dir / "input.png"
            imageUrl.resize((width, height)).save(image_path)
        elif isinstance(imageUrl, str):
            image_path = cache_dir / imageUrl.split("/")[-1]
            download_image(imageUrl).resize((width, height)).save(image_path)
        else:
            raise TypeError("Invalid image type")

        mmdet_results = inference_detector(self.det_model, str(image_path))
        person_results = process_mmdet_results(mmdet_results, 1)

        pose_results, _ = inference_top_down_pose_model(
            self.pose_model,
            str(image_path),
            person_results,
            bbox_thr=0.3,
            format="xyxy",
            dataset=self.dataset,
            dataset_info=self.dataset_info,
            return_heatmap=False,
            outputs=None,
        )

        for person in pose_results:
            # Keep only detections whose bbox score (5th element) is >= 0.9.
            if person["bbox"][4] < 0.9:
                continue
            keypoints = person["keypoints"][:, :2].tolist()
            # Midpoint of keypoints 5 and 6 (presumably the shoulders); after
            # reordering it lands at index 1, which map_head_to_body treats
            # as the neck.
            midpoint = [
                (keypoints[5][0] + keypoints[6][0]) / 2,
                (keypoints[5][1] + keypoints[6][1]) / 2,
            ]
            keypoints.append(midpoint)

            first = len(candidate)
            candidate.extend(self.__convert_keypoints(keypoints))
            subset.append(list(range(first, len(candidate))))

        return {"candidate": candidate[:18], "subset": subset[:18]}

    def map_head_to_body(
        self, client_coordinates: dict, infer_coordinates: dict
    ) -> dict:
        """Overlay client points 1..13 onto the inferred pose, aligned at the neck.

        The client points are translated so the client neck (index 1)
        coincides with the inferred neck, then replace the inferred points at
        indices 1 through 13. Indices 0 and 14+ keep the inferred values.

        The input dicts and their lists are not modified. If either side has
        fewer than 14 points (for example when nobody was detected) the
        inferred pose is returned unchanged.

        Args:
            client_coordinates: Dict with a ``"candidate"`` list of ``[x, y]``.
            infer_coordinates: Dict with ``"candidate"`` and ``"subset"``.

        Returns:
            A new dict with the merged ``"candidate"`` list and the inferred
            ``"subset"``.
        """
        client_points = client_coordinates["candidate"]
        # Work on a copy so the caller's list is left untouched.
        merged = list(infer_coordinates["candidate"])
        result = {"candidate": merged, "subset": infer_coordinates["subset"]}

        needed = 14  # highest index written below is 13
        if len(client_points) < needed or len(merged) < needed:
            return result

        c_neck = client_points[1]
        i_neck = merged[1]
        dx = i_neck[0] - c_neck[0]
        dy = i_neck[1] - c_neck[1]

        for idx in range(1, needed):
            point = client_points[idx]
            merged[idx] = [point[0] + dx, point[1] + dy]

        return result

    def __convert_keypoints(self, keypoints):
        """Return ``keypoints`` reordered according to the ``__kim`` index map."""
        return [keypoints[i] for i in self.__kim]
|
|