import torch
import torch.nn as nn
import torch.nn.functional as F

from .BasePIFuNet import BasePIFuNet


class VhullPIFuNet(BasePIFuNet):
    '''
    Vhull Piximp network is a minimal network demonstrating how the template works;
    it also helps with debugging the training/test schemes.
    It does the following (see the usage sketch at the end of this file):
        1. Compute the masks of the input images and store them under self.im_feat
        2. Project the query points with the calibration matrices and index into the masks
        3. Return whether the points fall into the intersection of all masks
    '''

    def __init__(self,
                 num_views,
                 projection_mode='orthogonal',
                 error_term=nn.MSELoss(),
                 ):
        super(VhullPIFuNet, self).__init__(
            projection_mode=projection_mode,
            error_term=error_term)
        self.name = 'vhull'

        self.num_views = num_views
        self.im_feat = None

    def filter(self, images):
        '''
        Filter the input images and
        store all intermediate features.
        :param images: [B, C, H, W] input images
        '''
        # If the image has an alpha channel, use it as the foreground mask
        if images.shape[1] > 3:
            self.im_feat = images[:, 3:4, :, :]
        # Else, use the first channel as a cue for non-white (foreground) pixels
        else:
            self.im_feat = images[:, 0:1, :, :]

    def query(self, points, calibs, transforms=None, labels=None):
        '''
        Given 3D points, query the network predictions for each point.
        Image features should be pre-computed before this call.
        Stores all intermediate features.
        query() may behave differently during training and testing.
        :param points: [B, 3, N] world space coordinates of points
        :param calibs: [B, 3, 4] calibration matrices for each image
        :param transforms: Optional [B, 2, 3] image space coordinate transforms
        :param labels: Optional [B, Res, N] ground truth labels
        :return: [B, Res, N] predictions for each point (stored in self.preds)
        '''
        if labels is not None:
            self.labels = labels

        # Project world-space points into each view and keep the xy image coordinates
        xyz = self.projection(points, calibs, transforms)
        xy = xyz[:, :2, :]

        # Sample each view's mask at the projected point locations
        point_local_feat = self.index(self.im_feat, xy)

        # Group the per-view samples of each point along the channel dimension
        local_shape = point_local_feat.shape
        point_feat = point_local_feat.view(
            local_shape[0] // self.num_views,
            local_shape[1] * self.num_views,
            -1)

        # A point lies in the visual hull only if it falls inside every mask,
        # so the product over views acts as the intersection of all masks
        pred = torch.prod(point_feat, dim=1)

        self.preds = pred.unsqueeze(1)
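

# A minimal usage sketch (hypothetical shapes and random inputs; it assumes
# BasePIFuNet provides the projection/index helpers used above and a
# get_preds() accessor for self.preds, as elsewhere in this codebase):
if __name__ == '__main__':
    num_views = 4
    net = VhullPIFuNet(num_views=num_views)

    images = torch.rand(num_views, 4, 512, 512)               # RGBA images, one per view
    calibs = torch.rand(num_views, 3, 4)                      # per-view calibration matrices
    points = torch.rand(1, 3, 5000).repeat(num_views, 1, 1)   # same points repeated per view

    net.filter(images)         # extract the per-view masks
    net.query(points, calibs)  # test points against the intersection of the masks
    preds = net.get_preds()    # [1, 1, 5000] visual hull occupancy per point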