import torch.nn as nn
import torch.nn.functional as F

from pointnet2_utils import PointNetSetAbstractionMsg, PointNetSetAbstraction


class get_model(nn.Module):
    """PointNet++ multi-scale-grouping (MSG) point-cloud classifier.

    Three set-abstraction stages (two MSG stages and one single-scale stage
    built with ``group_all=True``-style ``None`` sampling arguments) feed a
    three-layer fully connected head that emits per-class log-probabilities.

    NOTE: the submodule attribute names (``sa1``..``sa3``, ``fc1``..``fc3``,
    ``bn1``/``bn2``, ``drop1``/``drop2``) are the state_dict keys; renaming
    them would break loading of existing checkpoints.
    """

    def __init__(self, num_class, normal_channel=True):
        """Build the network.

        Args:
            num_class: number of output classes (output width of ``fc3``).
            normal_channel: if True, input channels from index 3 onward are
                passed to ``sa1`` as extra per-point features, and ``sa1`` is
                configured for exactly 3 of them (presumably surface normals,
                i.e. a 6-channel input — confirm against the data loader).
                If False, only the first 3 (coordinate) channels are used.
        """
        super().__init__()
        # Extra per-point feature channels beyond xyz handed to sa1.
        in_channel = 3 if normal_channel else 0
        self.normal_channel = normal_channel

        # 512 centroids, three radii / neighbour counts, one MLP per scale.
        self.sa1 = PointNetSetAbstractionMsg(
            512, [0.1, 0.2, 0.4], [16, 32, 128], in_channel,
            [[32, 32, 64], [64, 64, 128], [64, 96, 128]],
        )
        # 320 = 64 + 128 + 128, the sum of sa1's last MLP widths
        # (presumably the helper concatenates the per-scale outputs).
        self.sa2 = PointNetSetAbstractionMsg(
            128, [0.2, 0.4, 0.8], [32, 64, 128], 320,
            [[64, 64, 128], [128, 128, 256], [128, 128, 256]],
        )
        # 640 = 128 + 256 + 256 from sa2; the +3 presumably accounts for the
        # xyz coordinates the helper concatenates before its MLP.
        self.sa3 = PointNetSetAbstraction(
            None, None, None, 640 + 3, [256, 512, 1024], True,
        )

        # Classification head: 1024 -> 512 -> 256 -> num_class.
        self.fc1 = nn.Linear(1024, 512)
        self.bn1 = nn.BatchNorm1d(512)
        self.drop1 = nn.Dropout(0.4)
        self.fc2 = nn.Linear(512, 256)
        self.bn2 = nn.BatchNorm1d(256)
        self.drop2 = nn.Dropout(0.5)
        self.fc3 = nn.Linear(256, num_class)

    def forward(self, xyz):
        """Classify a batch of point clouds.

        Args:
            xyz: rank-3 tensor whose dim 1 holds channels; channels 0-2 are
                the coordinates and, when ``normal_channel`` is True, channels
                3+ are extra features (presumably shape (B, C, N) — confirm).

        Returns:
            A tuple ``(log_probs, l3_points)``: ``log_probs`` has shape
            (B, num_class) and is ``log_softmax`` over the class dimension;
            ``l3_points`` is the raw output of ``sa3`` (B * 1024 elements,
            presumably shaped (B, 1024, 1)).
        """
        B, _, _ = xyz.shape
        if self.normal_channel:
            norm = xyz[:, 3:, :]
            xyz = xyz[:, :3, :]
        else:
            norm = None

        l1_xyz, l1_points = self.sa1(xyz, norm)
        l2_xyz, l2_points = self.sa2(l1_xyz, l1_points)
        # The sa3 centroid coordinates are not needed by the head.
        _, l3_points = self.sa3(l2_xyz, l2_points)

        # 1024 must match the last width of sa3's MLP and fc1's input size.
        x = l3_points.view(B, 1024)
        x = self.drop1(F.relu(self.bn1(self.fc1(x))))
        x = self.drop2(F.relu(self.bn2(self.fc2(x))))
        x = self.fc3(x)
        x = F.log_softmax(x, -1)
        return x, l3_points


class get_loss(nn.Module):
    """Negative log-likelihood loss for ``get_model``'s log-probability output.

    Because ``get_model`` applies ``log_softmax``, this is equivalent to
    cross-entropy on the pre-softmax logits.
    """

    def forward(self, pred, target, trans_feat):
        """Return the mean NLL loss.

        Args:
            pred: log-probabilities, shape (B, num_class).
            target: integer class indices, shape (B,).
            trans_feat: unused; kept so the call signature stays unchanged
                (presumably shared with other models' loss modules that do
                use it — confirm against the training script).

        Returns:
            Scalar loss tensor from ``F.nll_loss`` (default 'mean' reduction).
        """
        total_loss = F.nll_loss(pred, target)
        return total_loss