# Reference: https://github.com/facebookresearch/Mask2Former/blob/main/mask2former/modeling/criterion.py # Reference: https://github.com/google-research/deeplab2/blob/main/model/loss/max_deeplab_loss.py # Modified by Qihang Yu import torch import torch.nn.functional as F from torch import nn _SOFTMAX_MASKING_CONSTANT = -99999.0 # https://www.tensorflow.org/api_docs/python/tf/math/divide_no_nan def divide_no_nan(x: torch.Tensor, y: torch.Tensor): return torch.nan_to_num(x / y, nan=0.0, posinf=0.0, neginf=0.0) # https://github.com/google-research/deeplab2/blob/main/model/loss/base_loss.py#L393 def focal_cross_entropy_loss( pred: torch.Tensor, gt: torch.Tensor, weight: torch.Tensor, # This is for PQ-loss weighting focal_loss_alpha: float = 0.75, focal_loss_gamma: float = 0.0, background_channel_index: int = -1): """ pred: B x N x C gt: B x N weight: B x N """ pred = pred.transpose(1, 2) # B x C x N gt = F.one_hot(gt, num_classes=pred.shape[1]).transpose(1, 2).to(pred) # B x C x N loss = F.cross_entropy(pred, gt, reduction="none") # B x N if focal_loss_gamma == 0.0: focal_loss = loss else: pred = F.softmax(pred, dim=1) # B x C x N pt = (pred * gt).sum(1) # B x N focal_loss = torch.pow(1.0 - pt, focal_loss_gamma) * loss # B x N if focal_loss_alpha >= 0: alpha_weights = ( focal_loss_alpha * (1.0 - gt[:, background_channel_index]) + (1 - focal_loss_alpha) * gt[:, background_channel_index]) # B x N focal_loss = alpha_weights * focal_loss # B x N focal_loss = focal_loss * weight # B x N focal_loss = focal_loss.flatten(1) num_non_zero = (focal_loss != 0.0).to(focal_loss).sum(-1) # B num_non_zero = torch.clamp(num_non_zero, min=1.0) loss_sum_per_sample = focal_loss.sum(-1) # B return divide_no_nan(loss_sum_per_sample, num_non_zero).mean() # 1 # https://github.com/google-research/deeplab2/blob/main/model/loss/max_deeplab_loss.py#L50 def _gumbel_topk_sample(logits: torch.Tensor, k: int): """Samples k points from the softmax distribution with Gumbel-Top-k trick.""" # Note that torch.rand is [0, 1), we need to make it (0, 1) to ensure the log is valid. gumbel_noise = torch.rand(size=logits.shape, dtype=logits.dtype, device=logits.device) gumbel_noise = -torch.log(-torch.log(gumbel_noise)) _, indices = torch.topk(logits + gumbel_noise, k) return indices # https://github.com/google-research/deeplab2/blob/main/model/loss/max_deeplab_loss.py#L576 def pixelwise_insdis_loss( pixel_feature: torch.Tensor, gt_mask: torch.Tensor, sample_temperature: float, sample_k: int, instance_discrimination_temperature: float, pixel_gt_void_mask: torch.Tensor, inverse_gt_mask_area: torch.Tensor ): # pixel_feature: B x C x H x W # gt_mask: B x N x H x W pixel_feature = pixel_feature.flatten(2) # B x C x HW gt_mask = gt_mask.flatten(2) # B x N x HW pixel_gt_void_mask = pixel_gt_void_mask.flatten(1) # B x HW inverse_gt_mask_area = inverse_gt_mask_area.flatten(1) # B x HW sample_logits = torch.log(inverse_gt_mask_area) * sample_temperature # B x HW # sample_logits.masked_fill_(pixel_gt_void_mask, float('-inf')) sample_logits += pixel_gt_void_mask.to(sample_logits) * _SOFTMAX_MASKING_CONSTANT sample_indices = _gumbel_topk_sample(sample_logits, sample_k) # B x K # Sample ground truth one-hot encodings and compute gt_similarity. pixel_gt_sampled_feature = torch.gather(gt_mask, dim=2, index=sample_indices.unsqueeze(1).repeat(1, gt_mask.shape[1], 1)) # B x N x K sampled_gt_similarity = torch.einsum('bnk,bnj->bkj', pixel_gt_sampled_feature, pixel_gt_sampled_feature) # B x K x K # Normalize the ground truth similarity into a distribution (sum to 1). pixel_normalizing_constant = sampled_gt_similarity.sum(dim=1, keepdim=True) # B x 1 x K sampled_gt_similarity /= torch.clamp(pixel_normalizing_constant, min=1.0) # B x K x K # Sample predicted features and compute pred_similarity. pixel_pred_sampled_feature = torch.gather(pixel_feature, dim=2, index=sample_indices.unsqueeze(1).repeat(1, pixel_feature.shape[1], 1)) # B x C x K sampled_pred_similarity = torch.einsum('bck,bcj->bkj', pixel_pred_sampled_feature, pixel_pred_sampled_feature) # B x K x K sampled_pred_similarity /= instance_discrimination_temperature # B x K x K loss = F.cross_entropy(sampled_pred_similarity, sampled_gt_similarity, reduction="none") # B x K num_non_zero = (loss != 0.0).to(loss).sum(-1) # B num_non_zero = torch.clamp(num_non_zero, min=1.0) loss_sum_per_sample = loss.sum(-1) # B return divide_no_nan(loss_sum_per_sample, num_non_zero).mean() # 1 def aux_semantic_loss( pred_semantic_logits: torch.Tensor, ground_truth_semantic: torch.Tensor, sample_temperature: float, sample_k: int, pixel_gt_void_mask: torch.Tensor, inverse_gt_mask_area: torch.Tensor, num_classes: int): pred_semantic_logits = pred_semantic_logits.flatten(2) # B x C x HW ground_truth_semantic = ground_truth_semantic.flatten(1) # B x HW pixel_gt_void_mask = pixel_gt_void_mask.flatten(1) # B x HW inverse_gt_mask_area = inverse_gt_mask_area.flatten(1) # B x HW sample_logits = torch.log(inverse_gt_mask_area) * sample_temperature # B x HW sample_logits += pixel_gt_void_mask.to(sample_logits) * _SOFTMAX_MASKING_CONSTANT sample_indices = _gumbel_topk_sample(sample_logits, sample_k) # B x K sampled_ground_truth_semantic = torch.gather(ground_truth_semantic, dim=1, index=sample_indices) # B x K sampled_pred_semantic_logits = torch.gather(pred_semantic_logits, dim=2, index=sample_indices.unsqueeze(1).repeat(1, pred_semantic_logits.shape[1], 1)) # B x C x K # ignore the class index num_classes. keep_mask = (sampled_ground_truth_semantic != num_classes) # B x K loss = F.cross_entropy(sampled_pred_semantic_logits, sampled_ground_truth_semantic, ignore_index=num_classes, reduction='none') # B x K loss = loss * keep_mask.to(loss) num_non_zero = (loss != 0.0).to(loss).sum(-1) # B num_non_zero = torch.clamp(num_non_zero, min=1.0) loss_sum_per_sample = loss.sum(-1) # B return divide_no_nan(loss_sum_per_sample, num_non_zero).mean() # 1 # https://github.com/google-research/deeplab2/blob/c4a533c14fac1a1071a6d24c5379c31a69a3e5e6/model/loss/base_loss.py#L56 # https://github.com/google-research/deeplab2/blob/main/model/loss/base_loss.py#L510 def dice_loss( inputs: torch.Tensor, targets: torch.Tensor, pixel_gt_void_mask: torch.Tensor, matched_cls_prob: torch.Tensor ): """ Compute the DICE loss, similar to generalized IOU for masks Args: inputs: A float tensor of arbitrary shape. The predictions for each example. targets: A float tensor with the same shape as inputs. Stores the binary classification label for each element in inputs (0 for the negative class and 1 for the positive class). """ inputs = inputs.softmax(1) # B N HW # https://github.com/google-research/deeplab2/blob/main/model/loss/base_loss.py#L111 inputs = inputs.masked_fill(pixel_gt_void_mask.unsqueeze(1), 0) # remove void pixels. smooth = 1.0 intersection = 2 * (inputs * targets).sum(-1) + smooth # B x N denominator = inputs.sum(-1) + targets.sum(-1) + smooth # B x N loss = 1.0 - divide_no_nan(intersection, denominator) loss *= matched_cls_prob # Note: kMaX-DeepLab sum over num_masks and avg over batches. But here batch and num_mask are one # https://github.com/google-research/deeplab2/blob/c4a533c14fac1a1071a6d24c5379c31a69a3e5e6/model/loss/base_loss.py#L559 # https://github.com/google-research/deeplab2/blob/c4a533c14fac1a1071a6d24c5379c31a69a3e5e6/model/loss/max_deeplab_loss.py#L402 # As the existing of modifer, it equals to multiplier by 0.75 return (loss.sum(1) * 0.75/128).mean() # sum over masks and mean over batches. def softmax_ce_loss( inputs: torch.Tensor, targets: torch.Tensor, pixel_gt_void_mask: torch.Tensor, ): """ Args: inputs: A float tensor of arbitrary shape. The predictions for each example. targets: A float tensor with the same shape as inputs. Stores the binary classification label for each element in inputs (0 for the negative class and 1 for the positive class). Returns: Loss tensor """ loss = F.cross_entropy(inputs, targets, reduction="none") # B x HW loss = loss.masked_fill(pixel_gt_void_mask, 0) # remove void pixels. num_non_zero = (loss != 0.0).to(loss).sum(-1) # B num_non_zero = torch.clamp(num_non_zero, min=1.0) loss_sum_per_sample = loss.sum(-1) # B return divide_no_nan(loss_sum_per_sample, num_non_zero).mean() # 1 class SetCriterion(nn.Module): """This class computes the loss for DETR. The process happens in two steps: 1) we compute hungarian assignment between ground truth boxes and the outputs of the model 2) we supervise each pair of matched ground-truth / prediction (supervise class and box) """ def __init__(self, num_classes, matcher, weight_dict, eos_coef, losses, share_final_matching, pixel_insdis_temperature=1.5, pixel_insdis_sample_k=4096, aux_semantic_temperature=2.0, aux_semantic_sample_k=4096): """Create the criterion. Parameters: num_classes: number of object categories, omitting the special no-object category matcher: module able to compute a matching between targets and proposals eos_coef: relative classification weight applied to the no-object category losses: list of all the losses to be applied. See get_loss for list of available losses. """ super().__init__() self.num_classes = num_classes self.matcher = matcher self.weight_dict = weight_dict self.eos_coef = eos_coef self.losses = losses self.share_final_matching = share_final_matching self.pixel_insdis_temperature = pixel_insdis_temperature self.pixel_insdis_sample_k = pixel_insdis_sample_k self.aux_semantic_temperature = aux_semantic_temperature self.aux_semantic_sample_k = aux_semantic_sample_k def loss_labels(self, outputs, targets): """Classification loss (NLL) targets dicts must contain the key "labels" containing a tensor of dim [nb_target_boxes] """ assert "pred_logits" in outputs src_logits = outputs["pred_logits"] # B x N x C target_classes = targets["labels"] # B x N pq_loss_class_weight = targets["pq_loss_class_weight"] losses = {"loss_ce": focal_cross_entropy_loss(src_logits, target_classes, pq_loss_class_weight)} return losses def loss_masks(self, outputs, targets): """Compute the losses related to the masks: the focal loss and the dice loss. targets dicts must contain the key "masks" containing a tensor of dim [nb_target_boxes, h, w] """ src_masks = outputs["pred_masks"] # B x N x H x W target_masks = targets["masks"] pq_loss_mask_weight = targets["pq_loss_mask_weight"] pixel_gt_void_mask = targets["pixel_gt_void_mask"] src_masks = src_masks.flatten(2) # B x N x HW target_masks = target_masks.flatten(2) # B x N x HW pixel_gt_void_mask = pixel_gt_void_mask.flatten(1) # B x HW losses = { "loss_mask": softmax_ce_loss(src_masks, target_masks, pixel_gt_void_mask), "loss_dice": dice_loss(src_masks, target_masks, pixel_gt_void_mask, pq_loss_mask_weight), } return losses def loss_pixels(self, outputs, targets): pixel_feature = outputs["pixel_feature"] target_masks = targets["masks"] pixel_gt_void_mask = targets["pixel_gt_void_mask"] inverse_gt_mask_area = targets["inverse_gt_mask_area"] losses = {"loss_pixel_insdis": pixelwise_insdis_loss( pixel_feature=pixel_feature, gt_mask=target_masks, sample_temperature=self.pixel_insdis_temperature, sample_k=self.pixel_insdis_sample_k, instance_discrimination_temperature=0.3, pixel_gt_void_mask=pixel_gt_void_mask, inverse_gt_mask_area=inverse_gt_mask_area )} del target_masks return losses def loss_semantic(self, outputs, targets): pred_semantic_logits = outputs["aux_semantic_pred"] ground_truth_semantic = targets["ground_truth_semantic"] pixel_gt_void_mask = targets["pixel_gt_void_mask"].flatten(1) inverse_gt_mask_area = targets["inverse_gt_mask_area"].flatten(1) losses = {"loss_aux_semantic": aux_semantic_loss( pred_semantic_logits=pred_semantic_logits, ground_truth_semantic=ground_truth_semantic, sample_temperature=self.aux_semantic_temperature, sample_k=self.aux_semantic_sample_k, pixel_gt_void_mask=pixel_gt_void_mask, inverse_gt_mask_area=inverse_gt_mask_area, num_classes=self.num_classes )} return losses @torch.no_grad() def _get_src_permutation_idx(self, indices): # permute predictions following indices # torch.full_like gives a tensor full of i in shape of src.shape # at each iter, i is the index, src is the src ind in shape of (N) # so batch_idx is concat of (0,0,...), (1,1,...), with shape (N0+N1+N2+...+Nb) # so if we flatten gt/pred across bathces, this gives the batch_id of each sample batch_idx = torch.cat([torch.full_like(src, i) for i, (src, _) in enumerate(indices)]) # src_idx is src_ind concated to shape (N0+N1+N2+...+Nb) # it is a flattened concat of mask_id at each batch src_idx = torch.cat([src for (src, _) in indices]) return batch_idx, src_idx def get_loss(self, loss, outputs, targets): loss_map = { 'labels': self.loss_labels, 'masks': self.loss_masks, 'pixels': self.loss_pixels, 'aux_semantic': self.loss_semantic, } assert loss in loss_map, f"do you really want to compute {loss} loss?" return loss_map[loss](outputs, targets) @torch.no_grad() def process_gt(self, outputs, targets, indices, matched_dice, matched_cls_prob, process_semantic=False): # Permute&Pad Pred> for loss compuation. # By controling process_gt, we can share the matching results for all preds. src_idx = self._get_src_permutation_idx(indices) src_masks = outputs["pred_masks"].detach() # B x N x H x W # Pad and permute the target_mask to B x N x H x W target_masks = torch.zeros_like(src_masks) target_masks_o = torch.cat([t["masks"][J] for t, (_, J) in zip(targets, indices)]).to(target_masks) target_masks[src_idx] = target_masks_o # Pad and permute the matched_cls_prob to B x N matched_cls_prob_o = torch.cat([cls_prob for cls_prob in matched_cls_prob]) matched_cls_prob_o = torch.clamp(matched_cls_prob_o, min=self.eos_coef) # https://github.com/google-research/deeplab2/blob/main/model/loss/max_deeplab_loss.py#L1034 # no penalty for unmatched masks. matched_cls_prob = torch.full( src_masks.shape[:2], 0, dtype=src_masks.dtype, device=src_masks.device ) # B x N matched_cls_prob[src_idx] = matched_cls_prob_o.to(matched_cls_prob) # pixel_gt_void_mask is used to indicate those pixels without labels. pixel_gt_void_mask = (target_masks.sum(1) < 1) # B x H x W # inverse_gt_mask_area is used to sample pixels. mask_gt_area = target_masks.sum(2).sum(2) # B x N pixel_gt_area = torch.einsum('bnhw,bn->bhw', target_masks, mask_gt_area) # B x H x W inverse_gt_mask_area = (pixel_gt_area.shape[1] * pixel_gt_area.shape[2]) / torch.clamp(pixel_gt_area, min=1.0) # B x H x W src_logits = outputs["pred_logits"] # B x N x C # Pad and permute the target_classes to B x N target_classes_o = torch.cat([t["labels"][J] for t, (_, J) in zip(targets, indices)]) # This serves as a padding. target_classes = torch.full( src_logits.shape[:2], self.num_classes, dtype=torch.int64, device=src_logits.device ) # We put real GT to those corresponds to src_idx, and put void into other places. target_classes[src_idx] = target_classes_o src_masks_prob = src_masks.softmax(1) void_mask = pixel_gt_void_mask.to(src_masks_prob) # B x H x W # compute iou instead of dice for void overlapping. def computer_iou_score(x, y): # x : B x N x H x W # y : B x H x W x = x.flatten(2) # B x N x L y = y.flatten(1) # B x L intersection = torch.einsum('bnl,bl->bn', x, y) # B x N denominator = x.sum(-1) # B x N return intersection / (denominator + 1e-5) # B x N # Pad and permute the matched_dice to B x N matched_dice_o = torch.cat([dice for dice in matched_dice]) matched_dice = computer_iou_score(src_masks_prob, void_mask) # unmatched masks use their dice with void matched_dice[src_idx] = matched_dice_o.to(matched_dice) matched_dice = torch.clamp(matched_dice, min=self.eos_coef) processed_gt = {"masks": target_masks, "labels": target_classes, "pq_loss_mask_weight": matched_cls_prob, "pq_loss_class_weight": matched_dice, "pixel_gt_void_mask": pixel_gt_void_mask, "inverse_gt_mask_area": inverse_gt_mask_area,} if process_semantic: # To obtain semantic gt ground_truth_semantic = [t["semantic_masks"] for t in targets] ground_truth_semantic = torch.stack(ground_truth_semantic, dim=0) # B x H x W # self.num_classes is set to ignore label ground_truth_semantic[ground_truth_semantic==-1] = self.num_classes processed_gt.update({"ground_truth_semantic": ground_truth_semantic}) return processed_gt def forward(self, outputs, targets): """This performs the loss computation. Parameters: outputs: dict of tensors, see the output specification of the model for the format targets: list of dicts, such that len(targets) == batch_size. The expected keys in each dict depends on the losses applied, see each loss' doc """ outputs_without_aux = {k: v for k, v in outputs.items() if k != "aux_outputs"} indices, matched_dice, matched_cls_prob = self.matcher(outputs_without_aux, targets) # Pad GT to the same number of prediction. processed_targets = self.process_gt(outputs, targets, indices, matched_dice, matched_cls_prob, process_semantic=True) # Compute all the requested losses losses = {} for loss in self.losses: losses.update(self.get_loss(loss, outputs, processed_targets)) if "aux_outputs" in outputs: for i, aux_outputs in enumerate(outputs["aux_outputs"]): # We share matching results across predictions. if not self.share_final_matching: indices, matched_dice, matched_cls_prob = self.matcher(aux_outputs, targets) if not self.share_final_matching: processed_targets = self.process_gt(aux_outputs, targets, indices, matched_dice, matched_cls_prob) for loss in self.losses: if loss in ['aux_semantic']: # Only for final output. continue l_dict = self.get_loss(loss, aux_outputs, processed_targets) l_dict = {k + f"_{i}": v for k, v in l_dict.items()} losses.update(l_dict) return losses def __repr__(self): head = "Criterion " + self.__class__.__name__ body = [ "matcher: {}".format(self.matcher.__repr__(_repr_indent=8)), "losses: {}".format(self.losses), "weight_dict: {}".format(self.weight_dict), "num_classes: {}".format(self.num_classes), "eos_coef: {}".format(self.eos_coef), ] _repr_indent = 4 lines = [head] + [" " * _repr_indent + line for line in body] return "\n".join(lines)