Spaces:
No application file
No application file
File size: 541 Bytes
6755a2d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 |
import torch
def find_outlier(grid: torch.Tensor) -> torch.Tensor:
    """Mark which coordinates in ``grid`` fall inside the grid's own extent.

    Args:
        grid (torch.Tensor): Bx2xHxW tensor. Channel 0 is read as the x
            coordinate and channel 1 as the y coordinate; they are compared
            against the grid's own width W and height H.

    Returns:
        torch.Tensor: BxHxW mask in the default floating dtype, on the same
            device as ``grid``. 1 where ``0 <= x < W`` and ``0 <= y < H``,
            0 for outliers (including negative and NaN coordinates).
    """
    _, _, h, w = grid.shape
    x = grid[:, 0, :, :]
    y = grid[:, 1, :, :]
    # Negative coordinates lie outside the grid just like ones past the far
    # edge; NaN fails every comparison and is therefore treated as an outlier.
    inside = (x >= 0) & (x < w) & (y >= 0) & (y < h)
    # The comparison result already lives on grid's device (so CUDA inputs
    # work); convert to the same float dtype torch.ones() used to produce.
    return inside.to(torch.get_default_dtype())
|