Paras Shah commited on
Commit
26e5c1d
·
1 Parent(s): 2d17020

Add webapp files

Browse files
SingleTreePointCloudLoader.py ADDED
@@ -0,0 +1,50 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import laspy
2
+ import torch
3
+ import numpy as np
4
+ import open3d as o3d
5
+ from torch.utils.data import Dataset
6
+
7
+ def random_sample(point, npoint):
8
+ if len(point) > npoint:
9
+ sampled_indices = np.random.choice(len(point), npoint, replace=False)
10
+ point = point[sampled_indices]
11
+ else:
12
+ padding = np.zeros((npoint - len(point), 3))
13
+ point = np.vstack((point, padding))
14
+ return point
15
+
16
+
17
+ class SingleTreePointCloudLoader(Dataset):
18
+ def __init__(self, file, file_type, npoints=2048):
19
+ self.file = file
20
+ self.npoints = npoints
21
+ self.list_of_points = []
22
+ self.list_of_labels = []
23
+
24
+ if file_type == 'pcd':
25
+ pcd = o3d.io.read_point_cloud(self.file)
26
+ point = np.asarray(pcd.points)
27
+ else:
28
+ las_file = laspy.read(self.file)
29
+ point = np.vstack((las_file.x, las_file.y, las_file.z)).transpose()
30
+
31
+ point_set = random_sample(point, self.npoints)
32
+ point_set = torch.tensor(point_set, dtype=torch.float32)
33
+ self.list_of_points.append(point_set)
34
+ self.list_of_labels.append(np.array([-1]).astype(np.int32))
35
+
36
+ def __len__(self):
37
+ return 1
38
+
39
+ def __getitem__(self, index):
40
+ point_set, label = self.list_of_points[index], self.list_of_labels[index]
41
+ return point_set, label[0]
42
+
43
+
44
+ if __name__ == '__main__':
45
+ dataset = SingleTreePointCloudLoader(file='E:/Important PDFs/Wildlife Institute of India/PointNet ML/Pointnet_Pointnet2_pytorch-master/data/tree_species')
46
+ dataloader = torch.utils.data.DataLoader(dataset, batch_size=8, shuffle=True, num_workers=0)
47
+
48
+ for point, label in dataloader:
49
+ print(point.shape)
50
+ print(label.shape)
app.py ADDED
@@ -0,0 +1,148 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gc
2
+ import laspy
3
+ import torch
4
+ import tempfile
5
+ import numpy as np
6
+ import open3d as o3d
7
+ import streamlit as st
8
+ import plotly.graph_objs as go
9
+
10
+ import pointnet2_cls_msg as pn2
11
+ from utils import calculate_dbh, calc_canopy_volume, CLASSES
12
+ from SingleTreePointCloudLoader import SingleTreePointCloudLoader
13
+ gc.enable()
14
+
15
+ with st.spinner("Loading PointNet++ model..."):
16
+ checkpoint = torch.load('checkpoints/best_model.pth', map_location=torch.device('cuda'))
17
+ classifier = pn2.get_model(num_class=4, normal_channel=False)
18
+ classifier.load_state_dict(checkpoint['model_state_dict'])
19
+ classifier.eval()
20
+
21
+ st.title("Tree Species Identification")
22
+
23
+ uploaded_file = st.file_uploader(
24
+ label="Upload Point Cloud Data",
25
+ type=['laz', 'las', 'pcd'],
26
+ help="Please upload trees with ground points removed"
27
+ )
28
+ Z_THRESHOLD = st.slider(
29
+ label="Z-Threshold(%)",
30
+ min_value=5,
31
+ max_value=100,
32
+ value=50,
33
+ step=1,
34
+ help="Please select a Z-Threshold for canopy volume calculation"
35
+ )
36
+ DBH_HEIGHT = st.slider(
37
+ label="DBH Height(m)",
38
+ min_value=1.3,
39
+ max_value=1.4,
40
+ value=1.4,
41
+ step=0.01,
42
+ help="Enter height used for DBH calculation"
43
+ )
44
+ proceed = None
45
+
46
+ if uploaded_file:
47
+ try:
48
+ with st.spinner("Reading point cloud file..."):
49
+ file_type = uploaded_file.name.split('.')[-1].lower()
50
+ with tempfile.NamedTemporaryFile(delete=False, suffix=f".{uploaded_file.name.split('.')[-1]}") as tmp:
51
+ tmp.write(uploaded_file.read())
52
+ temp_file_path = tmp.name
53
+
54
+ if file_type == 'pcd':
55
+ pcd = o3d.io.read_point_cloud(temp_file_path)
56
+ points = np.asarray(pcd.points)
57
+ else:
58
+ point_cloud = laspy.read(temp_file_path)
59
+ points = np.vstack((point_cloud.x, point_cloud.y, point_cloud.z)).transpose()
60
+
61
+ proceed = st.button("Run model")
62
+ except Exception as e:
63
+ st.error(f"An error occured: {str(e)}")
64
+
65
+ if proceed:
66
+ try:
67
+ with st.spinner("Calculating tree inventory..."):
68
+ dbh, trunk_points = calculate_dbh(points, DBH_HEIGHT)
69
+
70
+ z_min = np.min(points[:, 2])
71
+ z_max = np.max(points[:, 2])
72
+ height = z_max - z_min
73
+
74
+ canopy_volume, canopy_points = calc_canopy_volume(points, Z_THRESHOLD, height, z_min)
75
+
76
+ with st.spinner("Visualizing point cloud..."):
77
+ fig = go.Figure()
78
+ fig.add_trace(go.Scatter3d(
79
+ x=points[:, 0],
80
+ y=points[:, 1],
81
+ z=points[:, 2],
82
+ mode='markers',
83
+ marker=dict(
84
+ size=0.5,
85
+ color=points[:, 2],
86
+ colorscale='Viridis',
87
+ opacity=1.0,
88
+ ),
89
+ name='Tree'
90
+ ))
91
+ fig.add_trace(go.Scatter3d(
92
+ x=canopy_points[:, 0],
93
+ y=canopy_points[:, 1],
94
+ z=canopy_points[:, 2],
95
+ mode='markers',
96
+ marker=dict(
97
+ size=2,
98
+ color='blue',
99
+ opacity=0.8,
100
+ ),
101
+ name='Canopy points'
102
+ ))
103
+ fig.add_trace(go.Scatter3d(
104
+ x=trunk_points[:, 0],
105
+ y=trunk_points[:, 1],
106
+ z=trunk_points[:, 2],
107
+ mode='markers',
108
+ marker=dict(
109
+ size=2,
110
+ color='red',
111
+ opacity=0.9,
112
+ ),
113
+ name='DBH'
114
+ ))
115
+ fig.update_layout(
116
+ margin=dict(l=0, r=0, b=0, t=0),
117
+ scene=dict(
118
+ xaxis_title="X",
119
+ yaxis_title="Y",
120
+ zaxis_title="Z",
121
+ aspectmode='data'
122
+ )
123
+ )
124
+ st.plotly_chart(fig, use_container_width=True)
125
+
126
+
127
+ with st.spinner("Running inference..."):
128
+ testFile = SingleTreePointCloudLoader(temp_file_path, file_type)
129
+ testFileLoader = torch.utils.data.DataLoader(testFile, batch_size=8, shuffle=False, num_workers=0)
130
+ point_set, _ = next(iter(testFileLoader))
131
+ point_set = point_set.transpose(2, 1)
132
+
133
+ with torch.no_grad():
134
+ logits, _ = classifier(point_set)
135
+ probabilities = torch.softmax(logits, dim=-1)
136
+ predicted_class = torch.argmax(probabilities, dim=-1).item()
137
+ confidence_score = (probabilities.numpy().tolist())[0][predicted_class] * 100
138
+ predicted_label = CLASSES[predicted_class]
139
+
140
+ st.write(f"**Predicted class: {predicted_label}**")
141
+ # st.write(f"Class Probabilities: {probabilities.numpy().tolist()}")
142
+ st.write(f"**Confidence score: {confidence_score:.2f}%**")
143
+ st.write(f"**Height of tree: {height:.2f}m**")
144
+ st.write(f"**Canopy volume: {canopy_volume:.2f}m\u00b3**")
145
+ st.write(f"**DBH: {dbh:.2f}m**")
146
+
147
+ except Exception as e:
148
+ st.error(f"An error occured: {str(e)}")
checkpoints/best_model.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:241e6ebaa818ecd1def1406bc3fa60702e49287e11bac30d4fcd8dd4b3c40b04
3
+ size 21010087
pointnet2_cls_msg.py ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch.nn as nn
2
+ import torch.nn.functional as F
3
+ from pointnet2_utils import PointNetSetAbstractionMsg, PointNetSetAbstraction
4
+
5
+
6
+ class get_model(nn.Module):
7
+ def __init__(self,num_class,normal_channel=True):
8
+ super(get_model, self).__init__()
9
+ in_channel = 3 if normal_channel else 0
10
+ self.normal_channel = normal_channel
11
+ self.sa1 = PointNetSetAbstractionMsg(512, [0.1, 0.2, 0.4], [16, 32, 128], in_channel,[[32, 32, 64], [64, 64, 128], [64, 96, 128]])
12
+ self.sa2 = PointNetSetAbstractionMsg(128, [0.2, 0.4, 0.8], [32, 64, 128], 320,[[64, 64, 128], [128, 128, 256], [128, 128, 256]])
13
+ self.sa3 = PointNetSetAbstraction(None, None, None, 640 + 3, [256, 512, 1024], True)
14
+ self.fc1 = nn.Linear(1024, 512)
15
+ self.bn1 = nn.BatchNorm1d(512)
16
+ self.drop1 = nn.Dropout(0.4)
17
+ self.fc2 = nn.Linear(512, 256)
18
+ self.bn2 = nn.BatchNorm1d(256)
19
+ self.drop2 = nn.Dropout(0.5)
20
+ self.fc3 = nn.Linear(256, num_class)
21
+
22
+ def forward(self, xyz):
23
+ B, _, _ = xyz.shape
24
+ if self.normal_channel:
25
+ norm = xyz[:, 3:, :]
26
+ xyz = xyz[:, :3, :]
27
+ else:
28
+ norm = None
29
+ l1_xyz, l1_points = self.sa1(xyz, norm)
30
+ l2_xyz, l2_points = self.sa2(l1_xyz, l1_points)
31
+ l3_xyz, l3_points = self.sa3(l2_xyz, l2_points)
32
+ x = l3_points.view(B, 1024)
33
+ x = self.drop1(F.relu(self.bn1(self.fc1(x))))
34
+ x = self.drop2(F.relu(self.bn2(self.fc2(x))))
35
+ x = self.fc3(x)
36
+ x = F.log_softmax(x, -1)
37
+
38
+
39
+ return x,l3_points
40
+
41
+
42
+ class get_loss(nn.Module):
43
+ def __init__(self):
44
+ super(get_loss, self).__init__()
45
+
46
+ def forward(self, pred, target, trans_feat):
47
+ total_loss = F.nll_loss(pred, target)
48
+
49
+ return total_loss
pointnet2_utils.py ADDED
@@ -0,0 +1,315 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from time import time
5
+ import numpy as np
6
+
7
+ def timeit(tag, t):
8
+ print("{}: {}s".format(tag, time() - t))
9
+ return time()
10
+
11
+ def pc_normalize(pc):
12
+ l = pc.shape[0]
13
+ centroid = np.mean(pc, axis=0)
14
+ pc = pc - centroid
15
+ m = np.max(np.sqrt(np.sum(pc**2, axis=1)))
16
+ pc = pc / m
17
+ return pc
18
+
19
+ def square_distance(src, dst):
20
+ """
21
+ Calculate Euclid distance between each two points.
22
+
23
+ src^T * dst = xn * xm + yn * ym + zn * zm;
24
+ sum(src^2, dim=-1) = xn*xn + yn*yn + zn*zn;
25
+ sum(dst^2, dim=-1) = xm*xm + ym*ym + zm*zm;
26
+ dist = (xn - xm)^2 + (yn - ym)^2 + (zn - zm)^2
27
+ = sum(src**2,dim=-1)+sum(dst**2,dim=-1)-2*src^T*dst
28
+
29
+ Input:
30
+ src: source points, [B, N, C]
31
+ dst: target points, [B, M, C]
32
+ Output:
33
+ dist: per-point square distance, [B, N, M]
34
+ """
35
+ B, N, _ = src.shape
36
+ _, M, _ = dst.shape
37
+ dist = -2 * torch.matmul(src, dst.permute(0, 2, 1))
38
+ dist += torch.sum(src ** 2, -1).view(B, N, 1)
39
+ dist += torch.sum(dst ** 2, -1).view(B, 1, M)
40
+ return dist
41
+
42
+
43
+ def index_points(points, idx):
44
+ """
45
+
46
+ Input:
47
+ points: input points data, [B, N, C]
48
+ idx: sample index data, [B, S]
49
+ Return:
50
+ new_points:, indexed points data, [B, S, C]
51
+ """
52
+ device = points.device
53
+ B = points.shape[0]
54
+ view_shape = list(idx.shape)
55
+ view_shape[1:] = [1] * (len(view_shape) - 1)
56
+ repeat_shape = list(idx.shape)
57
+ repeat_shape[0] = 1
58
+ batch_indices = torch.arange(B, dtype=torch.long).to(device).view(view_shape).repeat(repeat_shape)
59
+ new_points = points[batch_indices, idx, :]
60
+ return new_points
61
+
62
+
63
+ def farthest_point_sample(xyz, npoint):
64
+ """
65
+ Input:
66
+ xyz: pointcloud data, [B, N, 3]
67
+ npoint: number of samples
68
+ Return:
69
+ centroids: sampled pointcloud index, [B, npoint]
70
+ """
71
+ device = xyz.device
72
+ B, N, C = xyz.shape
73
+ centroids = torch.zeros(B, npoint, dtype=torch.long).to(device)
74
+ distance = torch.ones(B, N).to(device) * 1e10
75
+ farthest = torch.randint(0, N, (B,), dtype=torch.long).to(device)
76
+ batch_indices = torch.arange(B, dtype=torch.long).to(device)
77
+ for i in range(npoint):
78
+ centroids[:, i] = farthest
79
+ centroid = xyz[batch_indices, farthest, :].view(B, 1, 3)
80
+ dist = torch.sum((xyz - centroid) ** 2, -1)
81
+ mask = dist < distance
82
+ distance[mask] = dist[mask]
83
+ farthest = torch.max(distance, -1)[1]
84
+ return centroids
85
+
86
+
87
+ def query_ball_point(radius, nsample, xyz, new_xyz):
88
+ """
89
+ Input:
90
+ radius: local region radius
91
+ nsample: max sample number in local region
92
+ xyz: all points, [B, N, 3]
93
+ new_xyz: query points, [B, S, 3]
94
+ Return:
95
+ group_idx: grouped points index, [B, S, nsample]
96
+ """
97
+ device = xyz.device
98
+ B, N, C = xyz.shape
99
+ _, S, _ = new_xyz.shape
100
+ group_idx = torch.arange(N, dtype=torch.long).to(device).view(1, 1, N).repeat([B, S, 1])
101
+ sqrdists = square_distance(new_xyz, xyz)
102
+ group_idx[sqrdists > radius ** 2] = N
103
+ group_idx = group_idx.sort(dim=-1)[0][:, :, :nsample]
104
+ group_first = group_idx[:, :, 0].view(B, S, 1).repeat([1, 1, nsample])
105
+ mask = group_idx == N
106
+ group_idx[mask] = group_first[mask]
107
+ return group_idx
108
+
109
+
110
+ def sample_and_group(npoint, radius, nsample, xyz, points, returnfps=False):
111
+ """
112
+ Input:
113
+ npoint:
114
+ radius:
115
+ nsample:
116
+ xyz: input points position data, [B, N, 3]
117
+ points: input points data, [B, N, D]
118
+ Return:
119
+ new_xyz: sampled points position data, [B, npoint, nsample, 3]
120
+ new_points: sampled points data, [B, npoint, nsample, 3+D]
121
+ """
122
+ B, N, C = xyz.shape
123
+ S = npoint
124
+ fps_idx = farthest_point_sample(xyz, npoint) # [B, npoint, C]
125
+ new_xyz = index_points(xyz, fps_idx)
126
+ idx = query_ball_point(radius, nsample, xyz, new_xyz)
127
+ grouped_xyz = index_points(xyz, idx) # [B, npoint, nsample, C]
128
+ grouped_xyz_norm = grouped_xyz - new_xyz.view(B, S, 1, C)
129
+
130
+ if points is not None:
131
+ grouped_points = index_points(points, idx)
132
+ new_points = torch.cat([grouped_xyz_norm, grouped_points], dim=-1) # [B, npoint, nsample, C+D]
133
+ else:
134
+ new_points = grouped_xyz_norm
135
+ if returnfps:
136
+ return new_xyz, new_points, grouped_xyz, fps_idx
137
+ else:
138
+ return new_xyz, new_points
139
+
140
+
141
+ def sample_and_group_all(xyz, points):
142
+ """
143
+ Input:
144
+ xyz: input points position data, [B, N, 3]
145
+ points: input points data, [B, N, D]
146
+ Return:
147
+ new_xyz: sampled points position data, [B, 1, 3]
148
+ new_points: sampled points data, [B, 1, N, 3+D]
149
+ """
150
+ device = xyz.device
151
+ B, N, C = xyz.shape
152
+ new_xyz = torch.zeros(B, 1, C).to(device)
153
+ grouped_xyz = xyz.view(B, 1, N, C)
154
+ if points is not None:
155
+ new_points = torch.cat([grouped_xyz, points.view(B, 1, N, -1)], dim=-1)
156
+ else:
157
+ new_points = grouped_xyz
158
+ return new_xyz, new_points
159
+
160
+
161
+ class PointNetSetAbstraction(nn.Module):
162
+ def __init__(self, npoint, radius, nsample, in_channel, mlp, group_all):
163
+ super(PointNetSetAbstraction, self).__init__()
164
+ self.npoint = npoint
165
+ self.radius = radius
166
+ self.nsample = nsample
167
+ self.mlp_convs = nn.ModuleList()
168
+ self.mlp_bns = nn.ModuleList()
169
+ last_channel = in_channel
170
+ for out_channel in mlp:
171
+ self.mlp_convs.append(nn.Conv2d(last_channel, out_channel, 1))
172
+ self.mlp_bns.append(nn.BatchNorm2d(out_channel))
173
+ last_channel = out_channel
174
+ self.group_all = group_all
175
+
176
+ def forward(self, xyz, points):
177
+ """
178
+ Input:
179
+ xyz: input points position data, [B, C, N]
180
+ points: input points data, [B, D, N]
181
+ Return:
182
+ new_xyz: sampled points position data, [B, C, S]
183
+ new_points_concat: sample points feature data, [B, D', S]
184
+ """
185
+ xyz = xyz.permute(0, 2, 1)
186
+ if points is not None:
187
+ points = points.permute(0, 2, 1)
188
+
189
+ if self.group_all:
190
+ new_xyz, new_points = sample_and_group_all(xyz, points)
191
+ else:
192
+ new_xyz, new_points = sample_and_group(self.npoint, self.radius, self.nsample, xyz, points)
193
+ # new_xyz: sampled points position data, [B, npoint, C]
194
+ # new_points: sampled points data, [B, npoint, nsample, C+D]
195
+ new_points = new_points.permute(0, 3, 2, 1) # [B, C+D, nsample,npoint]
196
+ for i, conv in enumerate(self.mlp_convs):
197
+ bn = self.mlp_bns[i]
198
+ new_points = F.relu(bn(conv(new_points)))
199
+
200
+ new_points = torch.max(new_points, 2)[0]
201
+ new_xyz = new_xyz.permute(0, 2, 1)
202
+ return new_xyz, new_points
203
+
204
+
205
+ class PointNetSetAbstractionMsg(nn.Module):
206
+ def __init__(self, npoint, radius_list, nsample_list, in_channel, mlp_list):
207
+ super(PointNetSetAbstractionMsg, self).__init__()
208
+ self.npoint = npoint
209
+ self.radius_list = radius_list
210
+ self.nsample_list = nsample_list
211
+ self.conv_blocks = nn.ModuleList()
212
+ self.bn_blocks = nn.ModuleList()
213
+ for i in range(len(mlp_list)):
214
+ convs = nn.ModuleList()
215
+ bns = nn.ModuleList()
216
+ last_channel = in_channel + 3
217
+ for out_channel in mlp_list[i]:
218
+ convs.append(nn.Conv2d(last_channel, out_channel, 1))
219
+ bns.append(nn.BatchNorm2d(out_channel))
220
+ last_channel = out_channel
221
+ self.conv_blocks.append(convs)
222
+ self.bn_blocks.append(bns)
223
+
224
+ def forward(self, xyz, points):
225
+ """
226
+ Input:
227
+ xyz: input points position data, [B, C, N]
228
+ points: input points data, [B, D, N]
229
+ Return:
230
+ new_xyz: sampled points position data, [B, C, S]
231
+ new_points_concat: sample points feature data, [B, D', S]
232
+ """
233
+ xyz = xyz.permute(0, 2, 1)
234
+ if points is not None:
235
+ points = points.permute(0, 2, 1)
236
+
237
+ B, N, C = xyz.shape
238
+ S = self.npoint
239
+ new_xyz = index_points(xyz, farthest_point_sample(xyz, S))
240
+ new_points_list = []
241
+ for i, radius in enumerate(self.radius_list):
242
+ K = self.nsample_list[i]
243
+ group_idx = query_ball_point(radius, K, xyz, new_xyz)
244
+ grouped_xyz = index_points(xyz, group_idx)
245
+ grouped_xyz -= new_xyz.view(B, S, 1, C)
246
+ if points is not None:
247
+ grouped_points = index_points(points, group_idx)
248
+ grouped_points = torch.cat([grouped_points, grouped_xyz], dim=-1)
249
+ else:
250
+ grouped_points = grouped_xyz
251
+
252
+ grouped_points = grouped_points.permute(0, 3, 2, 1) # [B, D, K, S]
253
+ for j in range(len(self.conv_blocks[i])):
254
+ conv = self.conv_blocks[i][j]
255
+ bn = self.bn_blocks[i][j]
256
+ grouped_points = F.relu(bn(conv(grouped_points)))
257
+ new_points = torch.max(grouped_points, 2)[0] # [B, D', S]
258
+ new_points_list.append(new_points)
259
+
260
+ new_xyz = new_xyz.permute(0, 2, 1)
261
+ new_points_concat = torch.cat(new_points_list, dim=1)
262
+ return new_xyz, new_points_concat
263
+
264
+
265
+ class PointNetFeaturePropagation(nn.Module):
266
+ def __init__(self, in_channel, mlp):
267
+ super(PointNetFeaturePropagation, self).__init__()
268
+ self.mlp_convs = nn.ModuleList()
269
+ self.mlp_bns = nn.ModuleList()
270
+ last_channel = in_channel
271
+ for out_channel in mlp:
272
+ self.mlp_convs.append(nn.Conv1d(last_channel, out_channel, 1))
273
+ self.mlp_bns.append(nn.BatchNorm1d(out_channel))
274
+ last_channel = out_channel
275
+
276
+ def forward(self, xyz1, xyz2, points1, points2):
277
+ """
278
+ Input:
279
+ xyz1: input points position data, [B, C, N]
280
+ xyz2: sampled input points position data, [B, C, S]
281
+ points1: input points data, [B, D, N]
282
+ points2: input points data, [B, D, S]
283
+ Return:
284
+ new_points: upsampled points data, [B, D', N]
285
+ """
286
+ xyz1 = xyz1.permute(0, 2, 1)
287
+ xyz2 = xyz2.permute(0, 2, 1)
288
+
289
+ points2 = points2.permute(0, 2, 1)
290
+ B, N, C = xyz1.shape
291
+ _, S, _ = xyz2.shape
292
+
293
+ if S == 1:
294
+ interpolated_points = points2.repeat(1, N, 1)
295
+ else:
296
+ dists = square_distance(xyz1, xyz2)
297
+ dists, idx = dists.sort(dim=-1)
298
+ dists, idx = dists[:, :, :3], idx[:, :, :3] # [B, N, 3]
299
+
300
+ dist_recip = 1.0 / (dists + 1e-8)
301
+ norm = torch.sum(dist_recip, dim=2, keepdim=True)
302
+ weight = dist_recip / norm
303
+ interpolated_points = torch.sum(index_points(points2, idx) * weight.view(B, N, 3, 1), dim=2)
304
+
305
+ if points1 is not None:
306
+ points1 = points1.permute(0, 2, 1)
307
+ new_points = torch.cat([points1, interpolated_points], dim=-1)
308
+ else:
309
+ new_points = interpolated_points
310
+
311
+ new_points = new_points.permute(0, 2, 1)
312
+ for i, conv in enumerate(self.mlp_convs):
313
+ bn = self.mlp_bns[i]
314
+ new_points = F.relu(bn(conv(new_points)))
315
+ return new_points
utils.py ADDED
@@ -0,0 +1,121 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ from sklearn.cluster import DBSCAN
3
+ from scipy.spatial import ConvexHull
4
+ from scipy.optimize import least_squares
5
+
6
+ CLASSES = [
7
+ 'Betula_pendula',
8
+ 'Fagus_sylvatica',
9
+ 'Picea_abies',
10
+ 'Pinus_sylvestris'
11
+ ]
12
+
13
+ def fit_circle(x, y):
14
+ """Fit a circle to given x, y points."""
15
+ def calc_radius(params):
16
+ cx, cy, r = params
17
+ return np.sqrt((x - cx)**2 + (y - cy)**2) - r
18
+
19
+ # Initial guess for circle parameters: center (mean x, mean y), radius
20
+ x_m, y_m = np.mean(x), np.mean(y)
21
+ r_initial = np.mean(np.sqrt((x - x_m)**2 + (y - y_m)**2))
22
+ initial_params = [x_m, y_m, r_initial]
23
+
24
+ # Use least squares optimization to fit the circle
25
+ result = least_squares(calc_radius, initial_params)
26
+ cx, cy, r = result.x
27
+ return cx, cy, r
28
+
29
+ def remove_noise(points, eps=0.05, min_samples=10):
30
+ """
31
+ Remove noise from points using DBSCAN clustering.
32
+
33
+ Args:
34
+ points (numpy.ndarray): Array of shape (N, 3) with columns [x, y, z].
35
+ eps (float): Maximum distance between two samples to consider them as in the same neighborhood.
36
+ min_samples (int): Minimum number of points to form a dense region.
37
+
38
+ Returns:
39
+ numpy.ndarray: Denoised points.
40
+ """
41
+ clustering = DBSCAN(eps=eps, min_samples=min_samples).fit(points)
42
+ labels = clustering.labels_
43
+ largest_cluster = labels == np.argmax(np.bincount(labels[labels >= 0]))
44
+ return points[largest_cluster]
45
+
46
+ def calculate_dbh(points, dbh_height=1.3, height_buffer=0.1, eps=0.05, min_samples=10):
47
+ """
48
+ Calculate the Diameter at Breast Height (DBH) of a tree from point cloud data.
49
+
50
+ Args:
51
+ points (numpy.ndarray): Array of shape (N, 3) with columns [x, y, z].
52
+ dbh_height (float): Height at which DBH is measured (default is 1.3 meters).
53
+ height_buffer (float): Range around dbh_height to include points (default is ±0.1 meters).
54
+
55
+ Returns:
56
+ float: DBH in meters.
57
+ """
58
+ z_min, z_max = dbh_height - height_buffer, dbh_height + height_buffer
59
+ trunk_points = points[(points[:, 2] >= (z_min)) & (points[:, 2] <= (z_max))]
60
+
61
+ if trunk_points.shape[0] < 3:
62
+ raise ValueError("Not enough points to calculate DBH.")
63
+
64
+ # Remove noise
65
+ denoised_points = remove_noise(trunk_points[:, :2], eps=eps, min_samples=min_samples)
66
+ denoised_points = np.hstack((denoised_points, np.full((denoised_points.shape[0], 1), dbh_height)))
67
+
68
+ if denoised_points.shape[0] < 3:
69
+ raise ValueError("Not enough points left after noise removal.")
70
+
71
+ # Fit a circle to the trunk points
72
+ x, y = denoised_points[:, 0], denoised_points[:, 1]
73
+ cx, cy, radius = fit_circle(x, y)
74
+
75
+ # Generate points along the fitted circle for visualization
76
+ theta = np.linspace(0, 2 * np.pi, 100)
77
+ circle_x = cx + radius * np.cos(theta)
78
+ circle_y = cy + radius * np.sin(theta)
79
+ circle_points = np.column_stack((circle_x, circle_y, np.full_like(circle_x, dbh_height)))
80
+
81
+ # Calculate DBH (Diameter = 2 * radius)
82
+ dbh = 2 * radius
83
+ return dbh, circle_points
84
+
85
+ def calc_canopy_volume(points, threshold, height, z_min):
86
+ '''
87
+ Calculates the canopy points for a given point cloud data of a tree
88
+ and calculates the volume using the Qhull algorithm
89
+
90
+ Args:
91
+ points: point cloud data
92
+ threshold: z_threshold in percentage
93
+ height, z_min
94
+
95
+ Returns:
96
+ canopy_volume, canopy_points
97
+ '''
98
+ canopy_points = points[points[:, 2] > z_min + 0.2 * height]
99
+ z_threshold = np.percentile(canopy_points[:, 2], threshold)
100
+ canopy_points = canopy_points[canopy_points[:, 2] >= z_threshold]
101
+ clustering = DBSCAN(eps=1.0, min_samples=10).fit(canopy_points[:, :3])
102
+ labels = clustering.labels_
103
+ canopy_points = canopy_points[labels != -1]
104
+
105
+ if canopy_points.shape[0] < 4:
106
+ canopy_volume = None
107
+ else:
108
+ '''
109
+ Uses the QuickHull algorithm which uses a divide-and-conquer approach.
110
+ It selects the 2 leftmost and rightmost points on a 2D plane;
111
+ These are part of the ConvexHull.
112
+ Then it selects the point farthest away from the line joining the
113
+ above 2 points and adds it to the ConvexHull.
114
+ The points enclosed within that shape cannot be part of the
115
+ ConvexHull and are ignored. This process is then repeated until
116
+ all points are either part of the ConvexHull or contained inside it.
117
+ '''
118
+ hull = ConvexHull(canopy_points)
119
+ canopy_volume = hull.volume
120
+
121
+ return canopy_volume, canopy_points