File size: 541 Bytes
6755a2d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
import torch


def find_outlier(grid: torch.Tensor) -> torch.Tensor:
    """Mark coordinates that fall beyond the right/bottom edge of the grid.

    Args:
        grid (torch.Tensor): Bx2xHxW coordinate tensor. Channel 0 is compared
            against the width ``W`` and channel 1 against the height ``H``.

    Returns:
        torch.Tensor: BxHxW mask in the default floating dtype, on the same
        device as ``grid``. It is 1 for a coordinate inside the grid and 0 for
        an outlier, i.e. where ``grid[:, 0] >= W`` or ``grid[:, 1] >= H``.

    Note:
        Only the upper bounds are checked; negative coordinates are not
        marked as outliers. NaN coordinates are not marked as outliers either.
    """
    _, _, h, w = grid.shape
    x_out = grid[:, 0, :, :] >= w
    y_out = grid[:, 1, :, :] >= h
    # Negate the "outside" test instead of testing ``< w`` / ``< h``: a NaN
    # compares False either way, so this keeps NaN coordinates marked as 1.
    inside = ~(x_out | y_out)
    # ``.to(dtype=...)`` keeps grid's device. The old CPU-only mask could not
    # be indexed with CUDA indices when grid lived on the GPU.
    return inside.to(dtype=torch.get_default_dtype())