MuseV-test / mmcm /vision /utils /torch_util.py
kevinwang676's picture
Upload folder using huggingface_hub
6755a2d verified
raw
history blame contribute delete
541 Bytes
import torch
def find_outlier(grid: torch.Tensor) -> torch.Tensor:
"""find outlier coordinary out of grid
Args:
grid (torch.Tensor): Bx2xHxW
Returns:
mask: ndarray, BxHxW, 1 for coordinary in grid, 0 for outlier
"""
b, _, h, w = grid.shape
mask = torch.ones((b, h, w))
outlier_x_coordinary = (grid[:,0,:,:] >= w).nonzero(as_tuple=True)
outlier_y_coordinary = (grid[:,1,:,:] >= h).nonzero(as_tuple=True)
mask[outlier_x_coordinary] = 0
mask[outlier_y_coordinary] = 0
return mask