Spaces:
Runtime error
Runtime error
File size: 4,370 Bytes
5d756f1 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 |
import torch
import lzma
from dp2.detection.base import BaseDetector
from .utils import combine_cse_maskrcnn_dets
from .models.cse import CSEDetector
from .models.mask_rcnn import MaskRCNNDetector
from .models.keypoint_maskrcnn import KeypointMaskRCNN
from .structures import CSEPersonDetection, PersonDetection
from pathlib import Path
class CSEPersonDetector(BaseDetector):
    """Detector that combines Mask R-CNN segmentations with CSE detections.

    Both underlying models are built with the same score threshold. Their
    outputs are merged with ``combine_cse_maskrcnn_dets`` and wrapped in a
    single ``CSEPersonDetection``.
    """

    def __init__(
        self,
        score_threshold: float,
        mask_rcnn_cfg: dict,
        cse_cfg: dict,
        cse_post_process_cfg: dict,
        **kwargs
    ) -> None:
        """Build the two detectors and store the post-processing settings.

        Args:
            score_threshold: Forwarded as ``score_thres`` to both
                ``MaskRCNNDetector`` and ``CSEDetector``.
            mask_rcnn_cfg: Extra keyword arguments for ``MaskRCNNDetector``.
            cse_cfg: Extra keyword arguments for ``CSEDetector``.
            cse_post_process_cfg: Keyword arguments later forwarded to
                ``CSEPersonDetection``. Must contain the key
                ``"iou_combine_threshold"``, which is split off and used when
                combining detections. The mapping passed in is not modified.
            **kwargs: Forwarded unchanged to ``BaseDetector.__init__``.

        Raises:
            KeyError: If ``cse_post_process_cfg`` is a plain dict without an
                ``"iou_combine_threshold"`` entry.
        """
        import copy

        super().__init__(**kwargs)
        self.mask_rcnn = MaskRCNNDetector(**mask_rcnn_cfg, score_thres=score_threshold)
        self.cse_detector = CSEDetector(**cse_cfg, score_thres=score_threshold)
        # Pop from a shallow copy so the caller's config keeps its
        # "iou_combine_threshold" entry; otherwise reusing the same config for
        # a second detector would fail on the missing key. copy.copy keeps
        # the mapping's own type instead of coercing it to a plain dict.
        self.post_process_cfg = copy.copy(cse_post_process_cfg)
        self.iou_combine_threshold = self.post_process_cfg.pop("iou_combine_threshold")

    def __call__(self, *args, **kwargs):
        """Delegate directly to :meth:`forward`."""
        return self.forward(*args, **kwargs)

    def load_from_cache(self, cache_path: Path):
        """Rebuild detections from an LZMA-compressed cache file.

        The file is expected to hold an iterable of state mappings. Each one
        carries its detection class under the ``"cls"`` key, and that class's
        ``from_state_dict`` is used to reconstruct the detection.

        Args:
            cache_path: Path of the LZMA-compressed file to read.

        Returns:
            A list with one reconstructed detection per cached state.
        """
        # NOTE(security): torch.load unpickles the file contents, so only
        # cache files from a trusted source should be loaded here.
        with lzma.open(cache_path, "rb") as fp:
            state_dict = torch.load(fp)
        kwargs = dict(
            post_process_cfg=self.post_process_cfg,
            embed_map=self.cse_detector.embed_map,
        )
        return [
            state["cls"].from_state_dict(**kwargs, state_dict=state)
            for state in state_dict
        ]

    @torch.no_grad()
    def forward(self, im: torch.Tensor, cse_dets=None):
        """Detect persons in ``im`` with gradient tracking disabled.

        Args:
            im: Input image tensor. Its ``shape`` is passed on as
                ``orig_imshape_CHW`` (presumably channels, height, width;
                confirm against the callers).
            cse_dets: Optional precomputed CSE detections. When ``None`` the
                CSE detector is run on ``im``.

        Returns:
            A single-element list holding the ``CSEPersonDetection``.
        """
        mask_dets = self.mask_rcnn(im)
        if cse_dets is None:
            cse_dets = self.cse_detector(im)
        segmentation = mask_dets["segmentation"]
        # The third value returned by the combine step is not needed here.
        segmentation, cse_dets, _ = combine_cse_maskrcnn_dets(
            segmentation, cse_dets, self.iou_combine_threshold
        )
        det = CSEPersonDetection(
            segmentation=segmentation,
            cse_dets=cse_dets,
            embed_map=self.cse_detector.embed_map,
            orig_imshape_CHW=im.shape,
            **self.post_process_cfg
        )
        return [det]
class MaskRCNNPersonDetector(BaseDetector):
    """Detector that wraps Mask R-CNN segmentations in a ``PersonDetection``."""

    def __init__(
        self,
        score_threshold: float,
        mask_rcnn_cfg: dict,
        cse_post_process_cfg: dict,
        **kwargs
    ) -> None:
        """Build the Mask R-CNN detector and keep the post-processing settings.

        Args:
            score_threshold: Forwarded as ``score_thres`` to
                ``MaskRCNNDetector``.
            mask_rcnn_cfg: Extra keyword arguments for ``MaskRCNNDetector``.
            cse_post_process_cfg: Keyword arguments later forwarded to
                ``PersonDetection``; stored by reference.
            **kwargs: Forwarded unchanged to ``BaseDetector.__init__``.
        """
        super().__init__(**kwargs)
        self.mask_rcnn = MaskRCNNDetector(
            **mask_rcnn_cfg,
            score_thres=score_threshold,
        )
        self.post_process_cfg = cse_post_process_cfg

    def __call__(self, *args, **kwargs):
        """Delegate directly to :meth:`forward`."""
        return self.forward(*args, **kwargs)

    def load_from_cache(self, cache_path: Path):
        """Rebuild detections from an LZMA-compressed cache file.

        Each cached state names its detection class under the ``"cls"`` key;
        that class's ``from_state_dict`` reconstructs the detection.

        Args:
            cache_path: Path of the LZMA-compressed file to read.

        Returns:
            A list with one reconstructed detection per cached state.
        """
        # NOTE: torch.load unpickles the file, so the cache must be trusted.
        with lzma.open(cache_path, "rb") as cache_file:
            cached_states = torch.load(cache_file)

        post_process_cfg = self.post_process_cfg
        detections = []
        for entry in cached_states:
            detection_cls = entry["cls"]
            detections.append(
                detection_cls.from_state_dict(
                    post_process_cfg=post_process_cfg,
                    state_dict=entry,
                )
            )
        return detections

    @torch.no_grad()
    def forward(self, im: torch.Tensor):
        """Detect persons in ``im`` with gradient tracking disabled.

        Args:
            im: Input image tensor; its ``shape`` is passed on as
                ``orig_imshape_CHW``.

        Returns:
            A single-element list holding the ``PersonDetection``.
        """
        masks = self.mask_rcnn(im)["segmentation"]
        detection = PersonDetection(
            masks,
            **self.post_process_cfg,
            orig_imshape_CHW=im.shape,
        )
        return [detection]
class KeypointMaskRCNNPersonDetector(BaseDetector):
    """Detector that wraps keypoint Mask R-CNN output in a ``PersonDetection``."""

    def __init__(
        self,
        score_threshold: float,
        mask_rcnn_cfg: dict,
        cse_post_process_cfg: dict,
        **kwargs
    ) -> None:
        """Build the keypoint Mask R-CNN model and keep the post-processing settings.

        Args:
            score_threshold: Forwarded as ``score_threshold`` to
                ``KeypointMaskRCNN``.
            mask_rcnn_cfg: Extra keyword arguments for ``KeypointMaskRCNN``.
            cse_post_process_cfg: Keyword arguments later forwarded to
                ``PersonDetection``; stored by reference.
            **kwargs: Forwarded unchanged to ``BaseDetector.__init__``.
        """
        super().__init__(**kwargs)
        self.mask_rcnn = KeypointMaskRCNN(**mask_rcnn_cfg, score_threshold=score_threshold)
        self.post_process_cfg = cse_post_process_cfg

    def __call__(self, *args, **kwargs):
        """Delegate directly to :meth:`forward`."""
        return self.forward(*args, **kwargs)

    def load_from_cache(self, cache_path: Path):
        """Rebuild detections from an LZMA-compressed cache file.

        Each cached state names its detection class under the ``"cls"`` key;
        that class's ``from_state_dict`` reconstructs the detection.

        Args:
            cache_path: Path of the LZMA-compressed file to read.

        Returns:
            A list with one reconstructed detection per cached state.
        """
        # NOTE: torch.load unpickles the file, so the cache must be trusted.
        with lzma.open(cache_path, "rb") as cache_file:
            cached_states = torch.load(cache_file)

        post_process_cfg = self.post_process_cfg
        detections = []
        for entry in cached_states:
            detection_cls = entry["cls"]
            detections.append(
                detection_cls.from_state_dict(
                    post_process_cfg=post_process_cfg,
                    state_dict=entry,
                )
            )
        return detections

    @torch.no_grad()
    def forward(self, im: torch.Tensor):
        """Detect persons and their keypoints in ``im`` without gradients.

        Args:
            im: Input image tensor; its ``shape`` is passed on as
                ``orig_imshape_CHW``.

        Returns:
            A single-element list holding the ``PersonDetection``, built from
            the model's ``"segmentation"`` and ``"keypoints"`` outputs.
        """
        model_out = self.mask_rcnn(im)
        masks = model_out["segmentation"]
        detection = PersonDetection(
            masks,
            **self.post_process_cfg,
            orig_imshape_CHW=im.shape,
            keypoints=model_out["keypoints"]
        )
        return [detection]
|