|
import os
import sys

import laspy
import numpy as np
import open3d as o3d
import torch
from torch.utils.data import Dataset
|
|
|
def random_sample(point, npoint):
    """Return exactly ``npoint`` rows of ``point``.

    Inputs with more than ``npoint`` rows are subsampled uniformly at random
    without replacement, using NumPy's global random state. Shorter inputs
    keep their original rows in order and are padded at the end with rows of
    zeros, i.e. the extra points sit at the origin.

    Args:
        point: Array-like with one point per row, typically shape ``(N, 3)``
            (x, y, z); any number of trailing columns is accepted.
        npoint: Number of rows to return.

    Returns:
        ``np.ndarray`` with ``npoint`` rows and the same trailing shape as
        ``point``. In the padding branch the result is promoted to float64
        (or wider), because the zero padding is float64.
    """
    point = np.asarray(point)

    if len(point) > npoint:
        sampled_indices = np.random.choice(len(point), npoint, replace=False)
        return point[sampled_indices]

    # Size the padding from the input's trailing dimensions instead of a
    # fixed 3 columns, so clouds carrying extra per-point attributes can be
    # padded too (a fixed (k, 3) block made vstack fail for them).
    padding = np.zeros((npoint - len(point),) + point.shape[1:])
    return np.concatenate((point, padding), axis=0)
|
|
|
|
|
class SingleTreePointCloudLoader(Dataset):
    """Dataset that exposes one point-cloud file as a single sample.

    The file is read once in ``__init__``, resampled to ``npoints`` rows with
    ``random_sample`` and stored as a float32 tensor, paired with a
    placeholder label of ``-1``.

    Args:
        file: Path to the point-cloud file.
        file_type: ``'pcd'`` (case-insensitive) reads the file with Open3D;
            any other value reads it with laspy. When ``None`` (the default)
            it is inferred from the file extension.
        npoints: Number of points in the returned sample.

    Raises:
        ValueError: If no points were read from the file. Open3D may return
            an empty cloud rather than raise when it cannot read a file, and
            an empty cloud would otherwise be padded into an all-zero sample.
    """

    def __init__(self, file, file_type=None, npoints=2048):
        self.file = file
        self.npoints = npoints
        self.list_of_points = []
        self.list_of_labels = []

        if file_type is None:
            # e.g. 'tree.PCD' -> 'PCD', 'tree.las' -> 'las'
            file_type = os.path.splitext(str(file))[1].lstrip('.')

        if file_type.lower() == 'pcd':
            pcd = o3d.io.read_point_cloud(self.file)
            point = np.asarray(pcd.points)
        else:
            las_file = laspy.read(self.file)
            # Stack the x, y, z arrays column-wise into an (N, 3) array.
            point = np.vstack((las_file.x, las_file.y, las_file.z)).transpose()

        if len(point) == 0:
            raise ValueError(f"no points could be read from {self.file!r}")

        point_set = random_sample(point, self.npoints)
        self.list_of_points.append(torch.tensor(point_set, dtype=torch.float32))
        # -1 is a fixed placeholder label.
        # NOTE(review): presumably this loader serves unlabeled inference; confirm.
        self.list_of_labels.append(np.array([-1], dtype=np.int32))

    def __len__(self):
        """Return the number of stored samples (one per loaded file)."""
        return len(self.list_of_points)

    def __getitem__(self, index):
        """Return ``(points, label)``; ``label`` is the scalar np.int32 placeholder."""
        point_set, label = self.list_of_points[index], self.list_of_labels[index]
        return point_set, label[0]
|
|
|
|
|
if __name__ == '__main__':
    # Smoke test: load one point-cloud file and print the batch shapes.
    # Usage: python <this script> [path/to/tree.pcd | path/to/tree.las]
    # NOTE(review): the default below is the original hard-coded path. It has
    # no file extension and appears to name a directory rather than a single
    # file; pass a file path on the command line instead.
    default_path = 'E:/Important PDFs/Wildlife Institute of India/PointNet ML/Pointnet_Pointnet2_pytorch-master/data/tree_species'
    path = sys.argv[1] if len(sys.argv) > 1 else default_path

    # The original call omitted the required file_type argument. Derive it
    # from the extension: 'pcd' is read by Open3D, anything else by laspy.
    file_type = os.path.splitext(path)[1].lstrip('.').lower()

    dataset = SingleTreePointCloudLoader(file=path, file_type=file_type)
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=8, shuffle=True, num_workers=0)

    for point, label in dataloader:
        print(point.shape)
        print(label.shape)
|
|