Spaces:
Running
Running
# Copyright (c) OpenMMLab. All rights reserved. | |
import copy | |
import math | |
from typing import List, Optional, Tuple | |
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
from mmcv.cnn import ConvModule, is_norm | |
from mmcv.ops import batched_nms | |
from mmengine.model import (BaseModule, bias_init_with_prob, constant_init, | |
normal_init) | |
from mmengine.structures import InstanceData | |
from torch import Tensor | |
from mmdet.models.layers.transformer import inverse_sigmoid | |
from mmdet.models.utils import (filter_scores_and_topk, multi_apply, | |
select_single_mlvl, sigmoid_geometric_mean) | |
from mmdet.registry import MODELS | |
from mmdet.structures.bbox import (cat_boxes, distance2bbox, get_box_tensor, | |
get_box_wh, scale_boxes) | |
from mmdet.utils import ConfigType, InstanceList, OptInstanceList, reduce_mean | |
from mmdet.models.dense_heads.rtmdet_head import RTMDetHead | |
from mmdet.models.dense_heads.rtmdet_ins_head import RTMDetInsHead, RTMDetInsSepBNHead, MaskFeatModule | |
from mmdet.utils import AvoidCUDAOOM | |
def sthgoeswrong(logits):
    """Report whether ``logits`` holds any NaN or infinite entry.

    Args:
        logits (Tensor): Tensor to inspect.

    Returns:
        Tensor: 0-dim boolean tensor that is true when at least one
        element of ``logits`` is NaN or +/-inf, false otherwise
        (including for an empty tensor).
    """
    # An element is finite exactly when it is neither NaN nor infinite,
    # so "not all finite" is the same predicate as "any NaN or any inf".
    every_entry_finite = torch.isfinite(logits).all()
    return torch.logical_not(every_entry_finite)
from time import time | |
class RTMDetInsHeadCustom(RTMDetInsHead):
    """``RTMDetInsHead`` subclass with a NaN-guarded dynamic mask head and
    a mask loss that honours ``gt_instances.ignore_mask``.

    Overrides ``loss_by_feat``, ``_mask_predict_by_feat_single`` and
    ``loss_mask_by_feat``; everything else is inherited.
    """

    def loss_by_feat(self,
                     cls_scores: List[Tensor],
                     bbox_preds: List[Tensor],
                     kernel_preds: List[Tensor],
                     mask_feat: Tensor,
                     batch_gt_instances: InstanceList,
                     batch_img_metas: List[dict],
                     batch_gt_instances_ignore: OptInstanceList = None):
        """Compute classification, box and mask losses of the head.

        Args:
            cls_scores (list[Tensor]): Per-level classification maps. Each
                is moved to channel-last and reshaped to
                ``(num_imgs, -1, self.cls_out_channels)``.
            bbox_preds (list[Tensor]): Per-level box regression maps. Each
                is moved to channel-last, reshaped to ``(num_imgs, -1, 4)``
                and decoded with ``distance2bbox`` against that level's
                priors.
            kernel_preds (list[Tensor]): Per-level dynamic-kernel maps.
                Each is moved to channel-last and reshaped to
                ``(num_imgs, -1, self.num_gen_params)``.
            mask_feat (Tensor): Mask features, forwarded unchanged to
                ``loss_mask_by_feat``.
            batch_gt_instances (list[:obj:`InstanceData`]): Ground truth
                per image. NOTE: the ``masks`` attribute of every entry is
                replaced in place by a boolean tensor on the prediction
                device.
            batch_img_metas (list[dict]): Meta information of each image;
                its length defines ``num_imgs``.
            batch_gt_instances_ignore (list[:obj:`InstanceData`], Optional):
                Forwarded to ``self.get_targets``. Defaults to None.

        Returns:
            dict: ``loss_cls`` and ``loss_bbox`` (one entry per level,
            each divided by its reduced average factor) and ``loss_mask``.
        """
        num_imgs = len(batch_img_metas)
        featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores]
        assert len(featmap_sizes) == self.prior_generator.num_levels

        device = cls_scores[0].device
        anchor_list, valid_flag_list = self.get_anchors(
            featmap_sizes, batch_img_metas, device=device)

        flatten_cls_scores = torch.cat([
            cls_score.permute(0, 2, 3, 1).reshape(num_imgs, -1,
                                                  self.cls_out_channels)
            for cls_score in cls_scores
        ], 1)
        flatten_kernels = torch.cat([
            kernel_pred.permute(0, 2, 3, 1).reshape(num_imgs, -1,
                                                    self.num_gen_params)
            for kernel_pred in kernel_preds
        ], 1)

        # Decode every level against the first image's priors
        # (``anchor_list[0]``); the same priors are applied to all images.
        decoded_bboxes = []
        for anchor, bbox_pred in zip(anchor_list[0], bbox_preds):
            anchor = anchor.reshape(-1, 4)
            bbox_pred = bbox_pred.permute(0, 2, 3, 1).reshape(num_imgs, -1, 4)
            decoded_bboxes.append(distance2bbox(anchor, bbox_pred))
        flatten_bboxes = torch.cat(decoded_bboxes, 1)

        # Convert GT masks to bool tensors on the prediction device so they
        # can be indexed and sliced in ``loss_mask_by_feat``.
        for gt_instances in batch_gt_instances:
            gt_instances.masks = gt_instances.masks.to_tensor(
                dtype=torch.bool, device=device)

        cls_reg_targets = self.get_targets(
            flatten_cls_scores,
            flatten_bboxes,
            anchor_list,
            valid_flag_list,
            batch_gt_instances,
            batch_img_metas,
            batch_gt_instances_ignore=batch_gt_instances_ignore)
        (anchor_list, labels_list, label_weights_list, bbox_targets_list,
         assign_metrics_list, sampling_results_list) = cls_reg_targets

        (losses_cls, losses_bbox, cls_avg_factors,
         bbox_avg_factors) = multi_apply(
             self.loss_by_feat_single,
             cls_scores,
             decoded_bboxes,
             labels_list,
             label_weights_list,
             bbox_targets_list,
             assign_metrics_list,
             self.prior_generator.strides)

        # Normalise by the reduced average factors, floored at 1 to avoid
        # dividing by zero.
        cls_avg_factor = reduce_mean(sum(cls_avg_factors)).clamp_(min=1).item()
        losses_cls = [loss / cls_avg_factor for loss in losses_cls]

        bbox_avg_factor = reduce_mean(
            sum(bbox_avg_factors)).clamp_(min=1).item()
        losses_bbox = [loss / bbox_avg_factor for loss in losses_bbox]

        loss_mask = self.loss_mask_by_feat(mask_feat, flatten_kernels,
                                           sampling_results_list,
                                           batch_gt_instances)
        return dict(
            loss_cls=losses_cls, loss_bbox=losses_bbox, loss_mask=loss_mask)

    def _mask_predict_by_feat_single(self, mask_feat: Tensor, kernels: Tensor,
                                     priors: Tensor) -> Tensor:
        """Run the per-instance dynamic convolutions for one image.

        Args:
            mask_feat (Tensor): Mask features of one image; the last two
                dimensions are read as ``(h, w)``.
            kernels (Tensor): Generated parameters, split into conv weights
                and biases by ``self.parse_dynamic_params``.
            priors (Tensor): One row per instance. Columns ``[:2]`` are used
                as point coordinates and columns ``[2:]`` as strides.

        Returns:
            Tensor: Mask logits of shape ``(num_inst, h, w)``.

        Raises:
            Exception: If the output contains NaN/inf. The inputs are first
                dumped to ``maskhead_nan_input.pt`` (relative path) for
                offline debugging.
        """
        ori_maskfeat = mask_feat  # kept untouched for the debug dump below
        num_inst = priors.shape[0]
        h, w = mask_feat.size()[-2:]
        if num_inst < 1:
            return torch.empty(
                size=(num_inst, h, w),
                dtype=mask_feat.dtype,
                device=mask_feat.device)
        if len(mask_feat.shape) < 4:
            # ``unsqueeze`` is out-of-place: the result must be re-bound,
            # otherwise the statement has no effect.
            mask_feat = mask_feat.unsqueeze(0)

        coord = self.prior_generator.single_level_grid_priors(
            (h, w), level_idx=0, device=mask_feat.device).reshape(1, -1, 2)
        points = priors[:, :2].reshape(-1, 1, 2)
        strides = priors[:, 2:].reshape(-1, 1, 2)
        # Offset of each instance's point to every grid location, scaled by
        # 8x the instance's stride.
        relative_coord = (points - coord).permute(0, 2, 1) / (
            strides[..., 0].reshape(-1, 1, 1) * 8)
        relative_coord = relative_coord.reshape(num_inst, 2, h, w)

        mask_feat = torch.cat(
            [relative_coord,
             mask_feat.repeat(num_inst, 1, 1, 1)], dim=1)
        weights, biases = self.parse_dynamic_params(kernels)
        fp16_used = weights[0].dtype == torch.float16

        n_layers = len(weights)
        # All instances are folded into the channel axis and separated
        # again by ``groups=num_inst`` in the grouped convolution.
        x = mask_feat.reshape(1, -1, h, w)
        for i, (weight, bias) in enumerate(zip(weights, biases)):
            # Autocast is switched off here; half-precision parameters are
            # upcast to float32 before the convolution.
            with torch.cuda.amp.autocast(enabled=False):
                if fp16_used:
                    weight = weight.to(torch.float32)
                    bias = bias.to(torch.float32)
                x = F.conv2d(
                    x, weight, bias=bias, stride=1, padding=0,
                    groups=num_inst)
                if i < n_layers - 1:
                    x = F.relu(x)
            if fp16_used:
                # Bound activations when the parameters were float16.
                x = torch.clip(x, -8192, 8192)

        if sthgoeswrong(x):
            torch.save(
                {
                    'mask_feat': ori_maskfeat,
                    'kernels': kernels,
                    'priors': priors
                }, 'maskhead_nan_input.pt')
            raise Exception('Mask Head NaN')
        return x.reshape(num_inst, h, w)

    def loss_mask_by_feat(self, mask_feats: Tensor, flatten_kernels: Tensor,
                          sampling_results_list: list,
                          batch_gt_instances: InstanceList) -> Tensor:
        """Compute the mask loss over all non-ignored positive samples.

        Args:
            mask_feats (Tensor): Mask features, iterated per image along
                dim 0.
            flatten_kernels (Tensor): Generated kernel parameters, iterated
                per image along dim 0 and indexed by positive indices.
            sampling_results_list (list): Per-image sampling results
                providing ``pos_priors``, ``pos_inds`` and
                ``pos_assigned_gt_inds``.
            batch_gt_instances (list[:obj:`InstanceData`]): Per-image ground
                truth providing ``masks`` and ``ignore_mask``. Positives
                whose assigned GT has a truthy ``ignore_mask`` entry are
                excluded from the loss.

        Returns:
            Tensor: The mask loss, or ``mask_feats.sum() * 0`` when no
            positive sample remains.
        """
        batch_pos_mask_logits = []
        pos_gt_masks = []
        ignore_flags = []
        for mask_feat, kernels, sampling_results, gt_instances in zip(
                mask_feats, flatten_kernels, sampling_results_list,
                batch_gt_instances):
            pos_kernels = kernels[
                sampling_results.pos_inds]  # n_pos, num_gen_params
            pos_mask_logits = self._mask_predict_by_feat_single(
                mask_feat, pos_kernels, sampling_results.pos_priors)

            if gt_instances.masks.numel() == 0:
                gt_masks = torch.empty_like(gt_instances.masks)
                # Always contribute a flag tensor (possibly zero-length) so
                # that ``torch.cat`` below never receives an empty list when
                # no image in the batch has ground truth.
                ignore = torch.zeros(
                    gt_masks.shape[0],
                    dtype=torch.bool,
                    device=gt_masks.device)
            else:
                gt_inds = sampling_results.pos_assigned_gt_inds
                gt_masks = gt_instances.masks[gt_inds, :]
                ignore = gt_instances.ignore_mask[gt_inds]

            batch_pos_mask_logits.append(pos_mask_logits)
            pos_gt_masks.append(gt_masks)
            ignore_flags.append(ignore)

        pos_gt_masks = torch.cat(pos_gt_masks, 0)
        batch_pos_mask_logits = torch.cat(batch_pos_mask_logits, 0)

        # Keep only the positives whose assigned GT is not flagged.
        keep = torch.logical_not(torch.cat(ignore_flags, 0))
        pos_gt_masks = pos_gt_masks[keep]
        batch_pos_mask_logits = batch_pos_mask_logits[keep]

        # avg_factor: reduced before the early return so the reduction is
        # executed regardless of whether positives remain.
        num_pos = batch_pos_mask_logits.shape[0]
        num_pos = reduce_mean(mask_feats.new_tensor([num_pos
                                                     ])).clamp_(min=1).item()

        if batch_pos_mask_logits.shape[0] == 0:
            # Zero-valued loss that still depends on ``mask_feats``.
            return mask_feats.sum() * 0

        scale = self.prior_generator.strides[0][0] // self.mask_loss_stride
        # upsample pred masks
        batch_pos_mask_logits = F.interpolate(
            batch_pos_mask_logits.unsqueeze(0),
            scale_factor=scale,
            mode='bilinear',
            align_corners=False).squeeze(0)
        # downsample gt masks by strided sampling at each cell centre
        pos_gt_masks = pos_gt_masks[:, self.mask_loss_stride //
                                    2::self.mask_loss_stride,
                                    self.mask_loss_stride //
                                    2::self.mask_loss_stride]

        loss_mask = self.loss_mask(
            batch_pos_mask_logits,
            pos_gt_masks,
            weight=None,
            avg_factor=num_pos)
        return loss_mask
class RTMDetInsSepBNHeadCustom(RTMDetInsSepBNHead):
    """``RTMDetInsSepBNHead`` subclass with a NaN-guarded dynamic mask head
    and a mask loss that honours ``gt_instances.ignore_mask``.

    Overrides ``_mask_predict_by_feat_single`` and ``loss_mask_by_feat``;
    everything else is inherited.
    """

    def _mask_predict_by_feat_single(self, mask_feat: Tensor, kernels: Tensor,
                                     priors: Tensor) -> Tensor:
        """Run the per-instance dynamic convolutions for one image.

        Args:
            mask_feat (Tensor): Mask features of one image; the last two
                dimensions are read as ``(h, w)``.
            kernels (Tensor): Generated parameters, split into conv weights
                and biases by ``self.parse_dynamic_params``.
            priors (Tensor): One row per instance. Columns ``[:2]`` are used
                as point coordinates and columns ``[2:]`` as strides.

        Returns:
            Tensor: Mask logits of shape ``(num_inst, h, w)``.

        Raises:
            Exception: If the output contains NaN/inf. The inputs are first
                dumped to ``maskhead_nan_input.pt`` (relative path) for
                offline debugging.
        """
        ori_maskfeat = mask_feat  # kept untouched for the debug dump below
        num_inst = priors.shape[0]
        h, w = mask_feat.size()[-2:]
        if num_inst < 1:
            return torch.empty(
                size=(num_inst, h, w),
                dtype=mask_feat.dtype,
                device=mask_feat.device)
        if len(mask_feat.shape) < 4:
            # ``unsqueeze`` is out-of-place: the result must be re-bound,
            # otherwise the statement has no effect.
            mask_feat = mask_feat.unsqueeze(0)

        coord = self.prior_generator.single_level_grid_priors(
            (h, w), level_idx=0, device=mask_feat.device).reshape(1, -1, 2)
        points = priors[:, :2].reshape(-1, 1, 2)
        strides = priors[:, 2:].reshape(-1, 1, 2)
        # Offset of each instance's point to every grid location, scaled by
        # 8x the instance's stride.
        relative_coord = (points - coord).permute(0, 2, 1) / (
            strides[..., 0].reshape(-1, 1, 1) * 8)
        relative_coord = relative_coord.reshape(num_inst, 2, h, w)

        mask_feat = torch.cat(
            [relative_coord,
             mask_feat.repeat(num_inst, 1, 1, 1)], dim=1)
        weights, biases = self.parse_dynamic_params(kernels)
        fp16_used = weights[0].dtype == torch.float16

        n_layers = len(weights)
        # All instances are folded into the channel axis and separated
        # again by ``groups=num_inst`` in the grouped convolution.
        x = mask_feat.reshape(1, -1, h, w)
        for i, (weight, bias) in enumerate(zip(weights, biases)):
            # Autocast is switched off here; half-precision parameters are
            # upcast to float32 before the convolution.
            with torch.cuda.amp.autocast(enabled=False):
                if fp16_used:
                    weight = weight.to(torch.float32)
                    bias = bias.to(torch.float32)
                x = F.conv2d(
                    x, weight, bias=bias, stride=1, padding=0,
                    groups=num_inst)
                if i < n_layers - 1:
                    x = F.relu(x)
            if fp16_used:
                # Bound activations when the parameters were float16.
                x = torch.clip(x, -8192, 8192)

        if sthgoeswrong(x):
            torch.save(
                {
                    'mask_feat': ori_maskfeat,
                    'kernels': kernels,
                    'priors': priors
                }, 'maskhead_nan_input.pt')
            raise Exception('Mask Head NaN')
        return x.reshape(num_inst, h, w)

    def loss_mask_by_feat(self, mask_feats: Tensor, flatten_kernels: Tensor,
                          sampling_results_list: list,
                          batch_gt_instances: InstanceList) -> Tensor:
        """Compute the mask loss over all non-ignored positive samples.

        Args:
            mask_feats (Tensor): Mask features, iterated per image along
                dim 0.
            flatten_kernels (Tensor): Generated kernel parameters, iterated
                per image along dim 0 and indexed by positive indices.
            sampling_results_list (list): Per-image sampling results
                providing ``pos_priors``, ``pos_inds`` and
                ``pos_assigned_gt_inds``.
            batch_gt_instances (list[:obj:`InstanceData`]): Per-image ground
                truth providing ``masks`` and ``ignore_mask``. Positives
                whose assigned GT has a truthy ``ignore_mask`` entry are
                dropped per image before batching.

        Returns:
            Tensor: The mask loss, or ``mask_feats.sum() * 0`` when no
            positive sample remains.
        """
        batch_pos_mask_logits = []
        pos_gt_masks = []
        for mask_feat, kernels, sampling_results, gt_instances in zip(
                mask_feats, flatten_kernels, sampling_results_list,
                batch_gt_instances):
            pos_kernels = kernels[
                sampling_results.pos_inds]  # n_pos, num_gen_params
            pos_mask_logits = self._mask_predict_by_feat_single(
                mask_feat, pos_kernels, sampling_results.pos_priors)

            if gt_instances.masks.numel() == 0:
                gt_masks = torch.empty_like(gt_instances.masks)
            else:
                gt_inds = sampling_results.pos_assigned_gt_inds
                # Filter this image's positives and their GT masks with the
                # same keep-mask so both stay aligned.
                keep = torch.logical_not(gt_instances.ignore_mask[gt_inds])
                gt_masks = gt_instances.masks[gt_inds, :][keep]
                pos_mask_logits = pos_mask_logits[keep]

            batch_pos_mask_logits.append(pos_mask_logits)
            pos_gt_masks.append(gt_masks)

        pos_gt_masks = torch.cat(pos_gt_masks, 0)
        batch_pos_mask_logits = torch.cat(batch_pos_mask_logits, 0)

        # avg_factor: reduced before the early return so the reduction is
        # executed regardless of whether positives remain.
        num_pos = batch_pos_mask_logits.shape[0]
        num_pos = reduce_mean(mask_feats.new_tensor([num_pos
                                                     ])).clamp_(min=1).item()

        if batch_pos_mask_logits.shape[0] == 0:
            # Zero-valued loss that still depends on ``mask_feats``.
            return mask_feats.sum() * 0

        scale = self.prior_generator.strides[0][0] // self.mask_loss_stride
        # upsample pred masks
        batch_pos_mask_logits = F.interpolate(
            batch_pos_mask_logits.unsqueeze(0),
            scale_factor=scale,
            mode='bilinear',
            align_corners=False).squeeze(0)
        # downsample gt masks by strided sampling at each cell centre
        pos_gt_masks = pos_gt_masks[:, self.mask_loss_stride //
                                    2::self.mask_loss_stride,
                                    self.mask_loss_stride //
                                    2::self.mask_loss_stride]

        loss_mask = self.loss_mask(
            batch_pos_mask_logits,
            pos_gt_masks,
            weight=None,
            avg_factor=num_pos)
        return loss_mask