IDM-VTON
update IDM-VTON Demo
938e515
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
from dataclasses import dataclass
from typing import Any, Optional
import torch
from detectron2.structures import BoxMode, Instances
from .utils import AnnotationsAccumulator
@dataclass
class PackedCseAnnotations:
    """
    Concatenated CSE ground truth, as produced by `CseAnnotationsAccumulator.pack`.

    Every tensor is the concatenation (along dim 0) of the per-box chunks
    gathered by the accumulator. The field order is the positional order of
    the generated constructor and must not change.
    """

    # per-point coordinates taken from `dp_gt.x` / `dp_gt.y`
    x_gt: torch.Tensor
    y_gt: torch.Tensor
    # stacked coarse segmentations; None unless every stored box supplied one
    coarse_segm_gt: Optional[torch.Tensor]
    # per-point mesh id (broadcast from `dp_gt.mesh_id`) and vertex id
    vertex_mesh_ids_gt: torch.Tensor
    vertex_ids_gt: torch.Tensor
    # one XYWH row per stored box: ground truth and estimated (proposal)
    bbox_xywh_gt: torch.Tensor
    bbox_xywh_est: torch.Tensor
    # per-point index of the owning box among boxes that carried DensePose data
    point_bbox_with_dp_indices: torch.Tensor
    # per-point index of the owning box among all matched boxes
    point_bbox_indices: torch.Tensor
    # index (among all matched boxes) of each stored box
    bbox_indices: torch.Tensor
class CseAnnotationsAccumulator(AnnotationsAccumulator):
    """
    Gathers CSE ground-truth annotations for matched (proposal, GT) box pairs,
    one image at a time, and concatenates them into a single
    `PackedCseAnnotations` when `pack` is called.

    Two running counters tag the stored data:
      - `nxt_bbox_index` advances for every matched box seen, whether or not
        its DensePose data was stored;
      - `nxt_bbox_with_dp_index` advances only for boxes whose data was stored.
    """

    def __init__(self):
        # Per-box chunks; `pack` concatenates each list along dim 0.
        self.x_gt = []
        self.y_gt = []
        # Coarse segmentation chunks, only appended for GT that has `segm`.
        self.s_gt = []
        self.vertex_mesh_ids_gt = []
        self.vertex_ids_gt = []
        self.bbox_xywh_gt = []
        self.bbox_xywh_est = []
        self.point_bbox_with_dp_indices = []
        self.point_bbox_indices = []
        # Python ints; converted to a long tensor in `pack`.
        self.bbox_indices = []
        # Running box counters (see class docstring).
        self.nxt_bbox_with_dp_index = 0
        self.nxt_bbox_index = 0

    def accumulate(self, instances_one_image: Instances):
        """
        Add the matched instances of one image.

        Proposal and GT boxes are converted from XYXY_ABS to XYWH_ABS. Every
        matched box advances `nxt_bbox_index`. Only boxes whose DensePose GT
        is not None and has at least one point are stored.

        Args:
            instances_one_image (Instances): matched instances of one image;
                `proposal_boxes` and `gt_boxes` are read, `gt_densepose` is
                optional
        """
        est_xywh = BoxMode.convert(
            instances_one_image.proposal_boxes.tensor.clone(), BoxMode.XYXY_ABS, BoxMode.XYWH_ABS
        )
        gt_xywh = BoxMode.convert(
            instances_one_image.gt_boxes.tensor.clone(), BoxMode.XYXY_ABS, BoxMode.XYWH_ABS
        )
        n_matches = len(gt_xywh)
        assert n_matches == len(
            est_xywh
        ), f"Got {len(est_xywh)} proposal boxes and {len(gt_xywh)} GT boxes"
        if n_matches == 0:
            # no detection - GT matches, nothing to record
            return
        # a missing `gt_densepose` field is handled exactly like an explicit None
        densepose_gt = getattr(instances_one_image, "gt_densepose", None)
        if densepose_gt is None:
            # boxes without DensePose GT still consume box indices
            self.nxt_bbox_index += n_matches
            return
        for est_box, gt_box, dp_gt in zip(est_xywh, gt_xywh, densepose_gt):
            if dp_gt is not None and len(dp_gt.x) > 0:
                # pyre-fixme[6]: For 1st argument expected `Tensor` but got `float`.
                # pyre-fixme[6]: For 2nd argument expected `Tensor` but got `float`.
                self._do_accumulate(gt_box, est_box, dp_gt)
            # counted for every matched box, stored or skipped
            self.nxt_bbox_index += 1

    def _do_accumulate(self, box_xywh_gt: torch.Tensor, box_xywh_est: torch.Tensor, dp_gt: Any):
        """
        Store the data of one matched box whose DensePose GT is non-empty.

        Args:
            box_xywh_gt (tensor): GT bounding box (XYWH, as converted by `accumulate`)
            box_xywh_est (tensor): estimated bounding box (XYWH, as converted by `accumulate`)
            dp_gt: GT densepose data; the attributes read here are:
                - x: X coordinates of the annotated points (documented as normalized)
                - y: Y coordinates of the annotated points (documented as normalized)
                - segm (optional): coarse segmentation, documented as size [S, S]
                - vertex_ids: per-point vertex ids
                - mesh_id: scalar mesh identifier, broadcast to every point
        """
        self.x_gt.append(dp_gt.x)
        self.y_gt.append(dp_gt.y)
        if hasattr(dp_gt, "segm"):
            # leading axis so per-box masks stack along dim 0 in `pack`
            self.s_gt.append(dp_gt.segm.unsqueeze(0))
        vertex_ids = dp_gt.vertex_ids
        self.vertex_ids_gt.append(vertex_ids)
        self.vertex_mesh_ids_gt.append(torch.full_like(vertex_ids, dp_gt.mesh_id))
        # one XYWH row per box
        self.bbox_xywh_gt.append(box_xywh_gt.view(-1, 4))
        self.bbox_xywh_est.append(box_xywh_est.view(-1, 4))
        # tag every point with the indices of the box it belongs to
        self.point_bbox_with_dp_indices.append(
            torch.full_like(vertex_ids, self.nxt_bbox_with_dp_index)
        )
        self.point_bbox_indices.append(torch.full_like(vertex_ids, self.nxt_bbox_index))
        self.bbox_indices.append(self.nxt_bbox_index)
        self.nxt_bbox_with_dp_index += 1

    def pack(self) -> Optional[PackedCseAnnotations]:
        """
        Concatenate everything accumulated so far.

        Returns:
            None if no box was stored; otherwise a `PackedCseAnnotations` whose
            tensors are the per-box chunks concatenated along dim 0.
            `coarse_segm_gt` is set only when every stored box supplied a
            segmentation, and is None otherwise. `bbox_indices` is a long
            tensor on the device of the first stored `x` chunk.
        """
        if not self.x_gt:
            # TODO: return properly shaped empty annotations instead; that
            # needs empty tensors of the right shape and dtype on the right
            # device, so None signals "no annotations" for now
            return None
        segm_for_every_box = len(self.s_gt) == len(self.bbox_xywh_gt)
        device = self.x_gt[0].device
        return PackedCseAnnotations(
            x_gt=torch.cat(self.x_gt, 0),
            y_gt=torch.cat(self.y_gt, 0),
            vertex_mesh_ids_gt=torch.cat(self.vertex_mesh_ids_gt, 0),
            vertex_ids_gt=torch.cat(self.vertex_ids_gt, 0),
            # partial segmentation annotations are dropped entirely
            coarse_segm_gt=torch.cat(self.s_gt, 0) if segm_for_every_box else None,
            bbox_xywh_gt=torch.cat(self.bbox_xywh_gt, 0),
            bbox_xywh_est=torch.cat(self.bbox_xywh_est, 0),
            point_bbox_with_dp_indices=torch.cat(self.point_bbox_with_dp_indices, 0),
            point_bbox_indices=torch.cat(self.point_bbox_indices, 0),
            bbox_indices=torch.as_tensor(self.bbox_indices, dtype=torch.long, device=device),
        )