# Reference: https://github.com/facebookresearch/Mask2Former/blob/main/mask2former/modeling/criterion.py
# Reference: https://github.com/google-research/deeplab2/blob/main/model/loss/max_deeplab_loss.py
# Modified by Qihang Yu
import torch
import torch.nn.functional as F
from torch import nn

_SOFTMAX_MASKING_CONSTANT = -99999.0


# https://www.tensorflow.org/api_docs/python/tf/math/divide_no_nan
def divide_no_nan(x: torch.Tensor, y: torch.Tensor):
    return torch.nan_to_num(x / y, nan=0.0, posinf=0.0, neginf=0.0)
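
# Illustrative sanity check (added; not from the reference implementations):
#   divide_no_nan(torch.tensor([1.0, 0.0]), torch.tensor([0.0, 0.0]))
# returns tensor([0., 0.]): 1/0 = inf and 0/0 = nan are both mapped to 0,
# mirroring tf.math.divide_no_nan.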

# https://github.com/google-research/deeplab2/blob/main/model/loss/base_loss.py#L393
def focal_cross_entropy_loss(
    pred: torch.Tensor,
    gt: torch.Tensor,
    weight: torch.Tensor,  # This is for PQ-loss weighting.
    focal_loss_alpha: float = 0.75,
    focal_loss_gamma: float = 0.0,
    background_channel_index: int = -1):
    """
    pred: B x N x C
    gt: B x N
    weight: B x N
    """
    pred = pred.transpose(1, 2)  # B x C x N
    gt = F.one_hot(gt, num_classes=pred.shape[1]).transpose(1, 2).to(pred)  # B x C x N
    # Soft-target cross entropy: gt is a one-hot float tensor along the class dim.
    loss = F.cross_entropy(pred, gt, reduction="none")  # B x N
    if focal_loss_gamma == 0.0:
        focal_loss = loss
    else:
        pred = F.softmax(pred, dim=1)  # B x C x N
        pt = (pred * gt).sum(1)  # B x N
        focal_loss = torch.pow(1.0 - pt, focal_loss_gamma) * loss  # B x N
    if focal_loss_alpha >= 0:
        alpha_weights = (
            focal_loss_alpha * (1.0 - gt[:, background_channel_index])
            + (1 - focal_loss_alpha) * gt[:, background_channel_index])  # B x N
        focal_loss = alpha_weights * focal_loss  # B x N
    focal_loss = focal_loss * weight  # B x N
    focal_loss = focal_loss.flatten(1)
    # Average over the non-zero entries only, matching deeplab2's normalization.
    num_non_zero = (focal_loss != 0.0).to(focal_loss).sum(-1)  # B
    num_non_zero = torch.clamp(num_non_zero, min=1.0)
    loss_sum_per_sample = focal_loss.sum(-1)  # B
    return divide_no_nan(loss_sum_per_sample, num_non_zero).mean()  # 1
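
# Shape sketch (added; random tensors, purely illustrative):
#   pred = torch.randn(2, 128, 134)                    # B x N x C query class logits
#   gt = torch.randint(0, 134, (2, 128))               # B x N target class indices
#   weight = torch.ones(2, 128)                        # B x N PQ-loss weights
#   loss = focal_cross_entropy_loss(pred, gt, weight)  # scalar
# Note: F.cross_entropy with a floating-point (soft) target requires PyTorch >= 1.10.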

# https://github.com/google-research/deeplab2/blob/main/model/loss/max_deeplab_loss.py#L50
def _gumbel_topk_sample(logits: torch.Tensor, k: int):
    """Samples k points (without replacement) from the softmax distribution with the Gumbel-Top-k trick."""
    # torch.rand is [0, 1); clamp it into (0, 1) so that both logs below are valid.
    gumbel_noise = torch.rand(size=logits.shape, dtype=logits.dtype, device=logits.device)
    gumbel_noise = torch.clamp(gumbel_noise, min=torch.finfo(logits.dtype).tiny)
    gumbel_noise = -torch.log(-torch.log(gumbel_noise))
    _, indices = torch.topk(logits + gumbel_noise, k)
    return indices
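
# Background (added note): adding i.i.d. Gumbel(0, 1) noise to the logits and
# taking the top-k is equivalent to drawing k indices without replacement from
# softmax(logits) (Kool et al., 2019, "Stochastic Beams and Where to Find Them").
# For example, _gumbel_topk_sample(torch.zeros(2, 100), k=10) picks 10 of 100
# indices uniformly at random for each of the 2 batch rows.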

# https://github.com/google-research/deeplab2/blob/main/model/loss/max_deeplab_loss.py#L576
def pixelwise_insdis_loss(
    pixel_feature: torch.Tensor,
    gt_mask: torch.Tensor,
    sample_temperature: float,
    sample_k: int,
    instance_discrimination_temperature: float,
    pixel_gt_void_mask: torch.Tensor,
    inverse_gt_mask_area: torch.Tensor
):
    # pixel_feature: B x C x H x W
    # gt_mask: B x N x H x W
    pixel_feature = pixel_feature.flatten(2)  # B x C x HW
    gt_mask = gt_mask.flatten(2)  # B x N x HW
    pixel_gt_void_mask = pixel_gt_void_mask.flatten(1)  # B x HW
    inverse_gt_mask_area = inverse_gt_mask_area.flatten(1)  # B x HW
    sample_logits = torch.log(inverse_gt_mask_area) * sample_temperature  # B x HW
    # Mask void pixels with a large negative constant so they are effectively never sampled.
    # sample_logits.masked_fill_(pixel_gt_void_mask, float('-inf'))
    sample_logits += pixel_gt_void_mask.to(sample_logits) * _SOFTMAX_MASKING_CONSTANT
    sample_indices = _gumbel_topk_sample(sample_logits, sample_k)  # B x K
    # Sample ground truth one-hot encodings and compute gt_similarity.
    pixel_gt_sampled_feature = torch.gather(
        gt_mask, dim=2, index=sample_indices.unsqueeze(1).repeat(1, gt_mask.shape[1], 1))  # B x N x K
    sampled_gt_similarity = torch.einsum(
        'bnk,bnj->bkj', pixel_gt_sampled_feature, pixel_gt_sampled_feature)  # B x K x K
    # Normalize the ground truth similarity into a distribution (sum to 1).
    pixel_normalizing_constant = sampled_gt_similarity.sum(dim=1, keepdim=True)  # B x 1 x K
    sampled_gt_similarity /= torch.clamp(pixel_normalizing_constant, min=1.0)  # B x K x K
    # Sample predicted features and compute pred_similarity.
    pixel_pred_sampled_feature = torch.gather(
        pixel_feature, dim=2, index=sample_indices.unsqueeze(1).repeat(1, pixel_feature.shape[1], 1))  # B x C x K
    sampled_pred_similarity = torch.einsum(
        'bck,bcj->bkj', pixel_pred_sampled_feature, pixel_pred_sampled_feature)  # B x K x K
    sampled_pred_similarity /= instance_discrimination_temperature  # B x K x K
    loss = F.cross_entropy(sampled_pred_similarity, sampled_gt_similarity, reduction="none")  # B x K
    num_non_zero = (loss != 0.0).to(loss).sum(-1)  # B
    num_non_zero = torch.clamp(num_non_zero, min=1.0)
    loss_sum_per_sample = loss.sum(-1)  # B
    return divide_no_nan(loss_sum_per_sample, num_non_zero).mean()  # 1
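
# Shape sketch (added; random tensors, purely illustrative):
#   feat = torch.randn(2, 128, 32, 32)                   # B x C x H x W pixel features
#   gt = torch.randint(0, 2, (2, 16, 32, 32)).float()    # B x N x H x W binary masks
#   void = gt.sum(1) < 1                                 # B x H x W unlabeled pixels
#   inv_area = torch.ones(2, 32, 32)                     # placeholder inverse areas
#   l = pixelwise_insdis_loss(feat, gt, 1.5, 512, 0.3, void, inv_area)  # scalar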

def aux_semantic_loss(
    pred_semantic_logits: torch.Tensor,
    ground_truth_semantic: torch.Tensor,
    sample_temperature: float,
    sample_k: int,
    pixel_gt_void_mask: torch.Tensor,
    inverse_gt_mask_area: torch.Tensor,
    num_classes: int):
    pred_semantic_logits = pred_semantic_logits.flatten(2)  # B x C x HW
    ground_truth_semantic = ground_truth_semantic.flatten(1)  # B x HW
    pixel_gt_void_mask = pixel_gt_void_mask.flatten(1)  # B x HW
    inverse_gt_mask_area = inverse_gt_mask_area.flatten(1)  # B x HW
    sample_logits = torch.log(inverse_gt_mask_area) * sample_temperature  # B x HW
    sample_logits += pixel_gt_void_mask.to(sample_logits) * _SOFTMAX_MASKING_CONSTANT
    sample_indices = _gumbel_topk_sample(sample_logits, sample_k)  # B x K
    sampled_ground_truth_semantic = torch.gather(ground_truth_semantic, dim=1, index=sample_indices)  # B x K
    sampled_pred_semantic_logits = torch.gather(
        pred_semantic_logits, dim=2, index=sample_indices.unsqueeze(1).repeat(1, pred_semantic_logits.shape[1], 1))  # B x C x K
    # ignore the class index num_classes.
    keep_mask = (sampled_ground_truth_semantic != num_classes)  # B x K
    loss = F.cross_entropy(sampled_pred_semantic_logits, sampled_ground_truth_semantic,
                           ignore_index=num_classes, reduction='none')  # B x K
    loss = loss * keep_mask.to(loss)
    num_non_zero = (loss != 0.0).to(loss).sum(-1)  # B
    num_non_zero = torch.clamp(num_non_zero, min=1.0)
    loss_sum_per_sample = loss.sum(-1)  # B
    return divide_no_nan(loss_sum_per_sample, num_non_zero).mean()  # 1
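
# Note (added): with reduction='none' and ignore_index=num_classes,
# F.cross_entropy already returns 0 at void positions, so the keep_mask multiply
# above is a redundancy guard that keeps the non-zero counting exact.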

# https://github.com/google-research/deeplab2/blob/c4a533c14fac1a1071a6d24c5379c31a69a3e5e6/model/loss/base_loss.py#L56
# https://github.com/google-research/deeplab2/blob/main/model/loss/base_loss.py#L510
def dice_loss(
    inputs: torch.Tensor,
    targets: torch.Tensor,
    pixel_gt_void_mask: torch.Tensor,
    matched_cls_prob: torch.Tensor
):
    """
    Compute the DICE loss, similar to generalized IoU for masks.
    Args:
        inputs: A float tensor of arbitrary shape.
            The predictions for each example.
        targets: A float tensor with the same shape as inputs. Stores the binary
            classification label for each element in inputs
            (0 for the negative class and 1 for the positive class).
    """
    inputs = inputs.softmax(1)  # B x N x HW
    # https://github.com/google-research/deeplab2/blob/main/model/loss/base_loss.py#L111
    inputs = inputs.masked_fill(pixel_gt_void_mask.unsqueeze(1), 0)  # remove void pixels.
    smooth = 1.0
    intersection = 2 * (inputs * targets).sum(-1) + smooth  # B x N
    denominator = inputs.sum(-1) + targets.sum(-1) + smooth  # B x N
    loss = 1.0 - divide_no_nan(intersection, denominator)
    loss *= matched_cls_prob
    # Note: kMaX-DeepLab sums over num_masks and averages over batches (there,
    # batch and num_masks are folded into a single dimension).
    # https://github.com/google-research/deeplab2/blob/c4a533c14fac1a1071a6d24c5379c31a69a3e5e6/model/loss/base_loss.py#L559
    # https://github.com/google-research/deeplab2/blob/c4a533c14fac1a1071a6d24c5379c31a69a3e5e6/model/loss/max_deeplab_loss.py#L402
    # Because of the loss modifier in the references above, this is equivalent to
    # multiplying by 0.75; the constant 128 matches kMaX-DeepLab's default number
    # of cluster centers (queries).
    return (loss.sum(1) * 0.75 / 128).mean()  # sum over masks and mean over batches.
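
# Worked example (added; ignoring softmax, void masking, and weighting): a
# perfectly predicted binary mask of area A gives intersection = 2A + 1 and
# denominator = 2A + 1, so its dice term is 1 - (2A + 1) / (2A + 1) = 0; a
# completely wrong mask of area A gives 1 - 1 / (2A + 1), approaching 1 for large A.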

def softmax_ce_loss(
    inputs: torch.Tensor,
    targets: torch.Tensor,
    pixel_gt_void_mask: torch.Tensor,
):
    """
    Args:
        inputs: A float tensor of shape B x N x HW with per-pixel mask logits.
        targets: A float tensor with the same shape as inputs. Stores the one-hot
            mask assignment of each pixel (along the N dimension).
        pixel_gt_void_mask: A bool tensor of shape B x HW marking unlabeled pixels.
    Returns:
        Loss tensor
    """
    # Soft-target cross entropy over the mask dimension (dim=1).
    loss = F.cross_entropy(inputs, targets, reduction="none")  # B x HW
    loss = loss.masked_fill(pixel_gt_void_mask, 0)  # remove void pixels.
    num_non_zero = (loss != 0.0).to(loss).sum(-1)  # B
    num_non_zero = torch.clamp(num_non_zero, min=1.0)
    loss_sum_per_sample = loss.sum(-1)  # B
    return divide_no_nan(loss_sum_per_sample, num_non_zero).mean()  # 1

class SetCriterion(nn.Module):
    """This class computes the loss for DETR.
    The process happens in two steps:
        1) we compute hungarian assignment between ground truth boxes and the outputs of the model
        2) we supervise each pair of matched ground-truth / prediction (supervise class and box)
    """

    def __init__(self, num_classes, matcher, weight_dict, eos_coef, losses, share_final_matching,
                 pixel_insdis_temperature=1.5, pixel_insdis_sample_k=4096,
                 aux_semantic_temperature=2.0, aux_semantic_sample_k=4096):
        """Create the criterion.
        Parameters:
            num_classes: number of object categories, omitting the special no-object category
            matcher: module able to compute a matching between targets and proposals
            weight_dict: dict containing as key the names of the losses and as values their relative weight
            eos_coef: relative classification weight applied to the no-object category
            losses: list of all the losses to be applied. See get_loss for list of available losses.
            share_final_matching: if True, reuse the matching of the final prediction for all auxiliary outputs
        """
        super().__init__()
        self.num_classes = num_classes
        self.matcher = matcher
        self.weight_dict = weight_dict
        self.eos_coef = eos_coef
        self.losses = losses
        self.share_final_matching = share_final_matching
        self.pixel_insdis_temperature = pixel_insdis_temperature
        self.pixel_insdis_sample_k = pixel_insdis_sample_k
        self.aux_semantic_temperature = aux_semantic_temperature
        self.aux_semantic_sample_k = aux_semantic_sample_k

    def loss_labels(self, outputs, targets):
        """Classification loss (NLL)
        targets dicts must contain the key "labels" containing a tensor of dim [nb_target_boxes]
        """
        assert "pred_logits" in outputs
        src_logits = outputs["pred_logits"]  # B x N x C
        target_classes = targets["labels"]  # B x N
        pq_loss_class_weight = targets["pq_loss_class_weight"]
        losses = {"loss_ce": focal_cross_entropy_loss(src_logits, target_classes, pq_loss_class_weight)}
        return losses

    def loss_masks(self, outputs, targets):
        """Compute the losses related to the masks: the softmax cross-entropy loss and the dice loss.
        targets dicts must contain the key "masks" containing a tensor of dim [nb_target_boxes, h, w]
        """
        src_masks = outputs["pred_masks"]  # B x N x H x W
        target_masks = targets["masks"]
        pq_loss_mask_weight = targets["pq_loss_mask_weight"]
        pixel_gt_void_mask = targets["pixel_gt_void_mask"]
        src_masks = src_masks.flatten(2)  # B x N x HW
        target_masks = target_masks.flatten(2)  # B x N x HW
        pixel_gt_void_mask = pixel_gt_void_mask.flatten(1)  # B x HW
        losses = {
            "loss_mask": softmax_ce_loss(src_masks, target_masks, pixel_gt_void_mask),
            "loss_dice": dice_loss(src_masks, target_masks, pixel_gt_void_mask, pq_loss_mask_weight),
        }
        return losses

    def loss_pixels(self, outputs, targets):
        pixel_feature = outputs["pixel_feature"]
        target_masks = targets["masks"]
        pixel_gt_void_mask = targets["pixel_gt_void_mask"]
        inverse_gt_mask_area = targets["inverse_gt_mask_area"]
        losses = {"loss_pixel_insdis": pixelwise_insdis_loss(
            pixel_feature=pixel_feature,
            gt_mask=target_masks,
            sample_temperature=self.pixel_insdis_temperature,
            sample_k=self.pixel_insdis_sample_k,
            instance_discrimination_temperature=0.3,  # hard-coded softmax temperature
            pixel_gt_void_mask=pixel_gt_void_mask,
            inverse_gt_mask_area=inverse_gt_mask_area
        )}
        del target_masks
        return losses

    def loss_semantic(self, outputs, targets):
        pred_semantic_logits = outputs["aux_semantic_pred"]
        ground_truth_semantic = targets["ground_truth_semantic"]
        pixel_gt_void_mask = targets["pixel_gt_void_mask"].flatten(1)
        inverse_gt_mask_area = targets["inverse_gt_mask_area"].flatten(1)
        losses = {"loss_aux_semantic": aux_semantic_loss(
            pred_semantic_logits=pred_semantic_logits,
            ground_truth_semantic=ground_truth_semantic,
            sample_temperature=self.aux_semantic_temperature,
            sample_k=self.aux_semantic_sample_k,
            pixel_gt_void_mask=pixel_gt_void_mask,
            inverse_gt_mask_area=inverse_gt_mask_area,
            num_classes=self.num_classes
        )}
        return losses

    def _get_src_permutation_idx(self, indices):
        # Permute predictions following indices.
        # torch.full_like(src, i) gives a tensor filled with i in the shape of src.
        # At each iteration, i is the batch index and src holds the matched
        # prediction indices of that sample, so batch_idx is the concatenation of
        # (0, 0, ...), (1, 1, ...), with shape (N0 + N1 + ... + Nb): if we flatten
        # gt/pred across batches, this gives the batch id of each sample.
        batch_idx = torch.cat([torch.full_like(src, i) for i, (src, _) in enumerate(indices)])
        # src_idx concatenates the per-batch prediction (mask) indices to shape
        # (N0 + N1 + ... + Nb).
        src_idx = torch.cat([src for (src, _) in indices])
        return batch_idx, src_idx
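
    # Worked example (added): for indices = [(tensor([3, 1]), tensor([0, 1])),
    #                                        (tensor([2]), tensor([0]))]
    # this returns batch_idx = tensor([0, 0, 1]) and src_idx = tensor([3, 1, 2]),
    # so matched predictions can be gathered as pred[batch_idx, src_idx].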

    def get_loss(self, loss, outputs, targets):
        loss_map = {
            'labels': self.loss_labels,
            'masks': self.loss_masks,
            'pixels': self.loss_pixels,
            'aux_semantic': self.loss_semantic,
        }
        assert loss in loss_map, f"do you really want to compute {loss} loss?"
        return loss_map[loss](outputs, targets)

    def process_gt(self, outputs, targets, indices, matched_dice, matched_cls_prob, process_semantic=False):
        # Permute & pad the ground truth for loss computation.
        # By controlling when process_gt is called, we can share the matching
        # results across all predictions.
        src_idx = self._get_src_permutation_idx(indices)
        src_masks = outputs["pred_masks"].detach()  # B x N x H x W
        # Pad and permute the target_masks to B x N x H x W.
        target_masks = torch.zeros_like(src_masks)
        target_masks_o = torch.cat([t["masks"][J] for t, (_, J) in zip(targets, indices)]).to(target_masks)
        target_masks[src_idx] = target_masks_o
        # Pad and permute the matched_cls_prob to B x N.
        matched_cls_prob_o = torch.cat([cls_prob for cls_prob in matched_cls_prob])
        matched_cls_prob_o = torch.clamp(matched_cls_prob_o, min=self.eos_coef)
        # https://github.com/google-research/deeplab2/blob/main/model/loss/max_deeplab_loss.py#L1034
        # no penalty for unmatched masks.
        matched_cls_prob = torch.full(
            src_masks.shape[:2], 0, dtype=src_masks.dtype, device=src_masks.device
        )  # B x N
        matched_cls_prob[src_idx] = matched_cls_prob_o.to(matched_cls_prob)
        # pixel_gt_void_mask indicates the pixels without labels.
        pixel_gt_void_mask = (target_masks.sum(1) < 1)  # B x H x W
        # inverse_gt_mask_area is used to sample pixels.
        mask_gt_area = target_masks.sum(2).sum(2)  # B x N
        pixel_gt_area = torch.einsum('bnhw,bn->bhw', target_masks, mask_gt_area)  # B x H x W
        inverse_gt_mask_area = (pixel_gt_area.shape[1] * pixel_gt_area.shape[2]) / torch.clamp(pixel_gt_area, min=1.0)  # B x H x W
        src_logits = outputs["pred_logits"]  # B x N x C
        # Pad and permute the target_classes to B x N.
        target_classes_o = torch.cat([t["labels"][J] for t, (_, J) in zip(targets, indices)])
        # This serves as a padding.
        target_classes = torch.full(
            src_logits.shape[:2], self.num_classes, dtype=torch.int64, device=src_logits.device
        )
        # We put the real GT at the positions given by src_idx and void everywhere else.
        target_classes[src_idx] = target_classes_o
        src_masks_prob = src_masks.softmax(1)
        void_mask = pixel_gt_void_mask.to(src_masks_prob)  # B x H x W
        # Compute IoU instead of dice for the overlap with the void region.
        def compute_iou_score(x, y):
            # x : B x N x H x W
            # y : B x H x W
            x = x.flatten(2)  # B x N x L
            y = y.flatten(1)  # B x L
            intersection = torch.einsum('bnl,bl->bn', x, y)  # B x N
            denominator = x.sum(-1)  # B x N
            return intersection / (denominator + 1e-5)  # B x N
        # Pad and permute the matched_dice to B x N.
        matched_dice_o = torch.cat([dice for dice in matched_dice])
        matched_dice = compute_iou_score(src_masks_prob, void_mask)  # unmatched masks use their IoU with the void region
        matched_dice[src_idx] = matched_dice_o.to(matched_dice)
        matched_dice = torch.clamp(matched_dice, min=self.eos_coef)
        processed_gt = {"masks": target_masks, "labels": target_classes,
                        "pq_loss_mask_weight": matched_cls_prob,
                        "pq_loss_class_weight": matched_dice,
                        "pixel_gt_void_mask": pixel_gt_void_mask,
                        "inverse_gt_mask_area": inverse_gt_mask_area,}
        if process_semantic:
            # Obtain the semantic GT.
            ground_truth_semantic = [t["semantic_masks"] for t in targets]
            ground_truth_semantic = torch.stack(ground_truth_semantic, dim=0)  # B x H x W
            # self.num_classes is used as the ignore label.
            ground_truth_semantic[ground_truth_semantic == -1] = self.num_classes
            processed_gt.update({"ground_truth_semantic": ground_truth_semantic})
        return processed_gt
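
    # Note (added): the cross-weighting above implements the PQ-style loss of
    # MaX-DeepLab: each mask's dice term is weighted by the class probability of
    # its matched prediction ("pq_loss_mask_weight"), and each class term is
    # weighted by the dice/IoU of its matched mask ("pq_loss_class_weight"),
    # mirroring the product of recognition and segmentation quality in PQ.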

    def forward(self, outputs, targets):
        """This performs the loss computation.
        Parameters:
            outputs: dict of tensors, see the output specification of the model for the format
            targets: list of dicts, such that len(targets) == batch_size.
                The expected keys in each dict depend on the losses applied, see each loss' doc
        """
        outputs_without_aux = {k: v for k, v in outputs.items() if k != "aux_outputs"}
        indices, matched_dice, matched_cls_prob = self.matcher(outputs_without_aux, targets)
        # Pad GT to the same number of predictions.
        processed_targets = self.process_gt(outputs, targets, indices, matched_dice, matched_cls_prob, process_semantic=True)
        # Compute all the requested losses.
        losses = {}
        for loss in self.losses:
            losses.update(self.get_loss(loss, outputs, processed_targets))
        if "aux_outputs" in outputs:
            for i, aux_outputs in enumerate(outputs["aux_outputs"]):
                # If share_final_matching is set, we reuse the matching of the
                # final prediction; otherwise we re-match per auxiliary output.
                if not self.share_final_matching:
                    indices, matched_dice, matched_cls_prob = self.matcher(aux_outputs, targets)
                    processed_targets = self.process_gt(aux_outputs, targets, indices, matched_dice, matched_cls_prob)
                for loss in self.losses:
                    if loss in ['aux_semantic']:
                        # Only for the final output.
                        continue
                    l_dict = self.get_loss(loss, aux_outputs, processed_targets)
                    l_dict = {k + f"_{i}": v for k, v in l_dict.items()}
                    losses.update(l_dict)
        return losses

    def __repr__(self):
        head = "Criterion " + self.__class__.__name__
        body = [
            "matcher: {}".format(self.matcher.__repr__(_repr_indent=8)),
            "losses: {}".format(self.losses),
            "weight_dict: {}".format(self.weight_dict),
            "num_classes: {}".format(self.num_classes),
            "eos_coef: {}".format(self.eos_coef),
        ]
        _repr_indent = 4
        lines = [head] + [" " * _repr_indent + line for line in body]
        return "\n".join(lines)