bluestyle97 commited on
Commit
80695f7
·
verified ·
1 Parent(s): 772e07e

Delete freesplatter/utils/mesh_renderer.py

Browse files
Files changed (1) hide show
  1. freesplatter/utils/mesh_renderer.py +0 -608
freesplatter/utils/mesh_renderer.py DELETED
@@ -1,608 +0,0 @@
1
- import math
2
- import torch
3
- import torch.nn as nn
4
- import torch.nn.functional as F
5
- import nvdiffrast.torch as dr
6
-
7
-
8
- def get_ray_directions(h, w, intrinsics, norm=False, device=None):
9
- """
10
- Args:
11
- h (int)
12
- w (int)
13
- intrinsics: (*, 4), in [fx, fy, cx, cy]
14
-
15
- Returns:
16
- directions: (*, h, w, 3), the direction of the rays in camera coordinate
17
- """
18
- batch_size = intrinsics.shape[:-1]
19
- x = torch.linspace(0.5, w - 0.5, w, device=device)
20
- y = torch.linspace(0.5, h - 0.5, h, device=device)
21
- # (*, h, w, 2)
22
- directions_xy = torch.stack(
23
- [((x - intrinsics[..., 2:3]) / intrinsics[..., 0:1])[..., None, :].expand(*batch_size, h, w),
24
- ((y - intrinsics[..., 3:4]) / intrinsics[..., 1:2])[..., :, None].expand(*batch_size, h, w)], dim=-1)
25
- # (*, h, w, 3)
26
- directions = F.pad(directions_xy, [0, 1], mode='constant', value=1.0)
27
- if norm:
28
- directions = F.normalize(directions, dim=-1)
29
- return directions
30
-
31
-
32
- def edge_dilation(img, mask, radius=3, iter=7):
33
- """
34
- Args:
35
- img (torch.Tensor): (n, c, h, w)
36
- mask (torch.Tensor): (n, 1, h, w)
37
- radius (float): Radius of dilation.
38
-
39
- Returns:
40
- torch.Tensor: Dilated image.
41
- """
42
- n, c, h, w = img.size()
43
- int_radius = round(radius)
44
- kernel_size = int(int_radius * 2 + 1)
45
- distance1d_sq = torch.linspace(-int_radius, int_radius, kernel_size, dtype=img.dtype, device=img.device).square()
46
- kernel_distance = (distance1d_sq.reshape(1, -1) + distance1d_sq.reshape(-1, 1)).sqrt()
47
- kernel_neg_distance = kernel_distance.max() - kernel_distance + 1
48
-
49
- for _ in range(iter):
50
-
51
- mask_out = F.max_pool2d(mask, kernel_size, stride=1, padding=int_radius)
52
- do_fill_mask = ((mask_out - mask) > 0.5).squeeze(1)
53
- # (num_fill, 3) in [ind_n, ind_h, ind_w]
54
- do_fill = do_fill_mask.nonzero()
55
-
56
- # unfold the image and mask
57
- mask_unfold = F.unfold(mask, kernel_size, padding=int_radius).reshape(
58
- n, kernel_size * kernel_size, h, w).permute(0, 2, 3, 1)
59
-
60
- fill_ind = (mask_unfold[do_fill_mask] * kernel_neg_distance.flatten()).argmax(dim=-1)
61
- do_fill_h = do_fill[:, 1] + fill_ind // kernel_size - int_radius
62
- do_fill_w = do_fill[:, 2] + fill_ind % kernel_size - int_radius
63
-
64
- img_out = img.clone()
65
- img_out[do_fill[:, 0], :, do_fill[:, 1], do_fill[:, 2]] = img[
66
- do_fill[:, 0], :, do_fill_h, do_fill_w]
67
-
68
- img = img_out
69
- mask = mask_out
70
-
71
- return img
72
-
73
-
74
- def depth_to_normal(depth, directions, format='opengl'):
75
- """
76
- Args:
77
- depth: shape (*, h, w), inverse depth defined as 1 / z
78
- directions: shape (*, h, w, 3), unnormalized ray directions, under OpenCV coordinate system
79
-
80
- Returns:
81
- out_normal: shape (*, h, w, 3), in range [0, 1]
82
- """
83
- out_xyz = directions / depth.unsqueeze(-1).clamp(min=1e-6)
84
- dx = out_xyz[..., :, 1:, :] - out_xyz[..., :, :-1, :]
85
- dy = out_xyz[..., 1:, :, :] - out_xyz[..., :-1, :, :]
86
- right = F.pad(dx, (0, 0, 0, 1, 0, 0), mode='replicate')
87
- up = F.pad(-dy, (0, 0, 0, 0, 1, 0), mode='replicate')
88
- left = F.pad(-dx, (0, 0, 1, 0, 0, 0), mode='replicate')
89
- down = F.pad(dy, (0, 0, 0, 0, 0, 1), mode='replicate')
90
- out_normal = F.normalize(
91
- F.normalize(torch.cross(right, up, dim=-1), dim=-1)
92
- + F.normalize(torch.cross(up, left, dim=-1), dim=-1)
93
- + F.normalize(torch.cross(left, down, dim=-1), dim=-1)
94
- + F.normalize(torch.cross(down, right, dim=-1), dim=-1),
95
- dim=-1)
96
- if format == 'opengl':
97
- out_normal[..., 1:3] = -out_normal[..., 1:3] # to opengl coord
98
- elif format == 'opencv':
99
- out_normal = out_normal
100
- else:
101
- raise ValueError('format should be opengl or opencv')
102
- out_normal = out_normal / 2 + 0.5
103
- return out_normal
104
-
105
-
106
- def make_divisible(x, m=8):
107
- return int(math.ceil(x / m) * m)
108
-
109
-
110
- def interpolate_hwc(x, scale_factor, mode='area'):
111
- batch_dim = x.shape[:-3]
112
- y = x.reshape(batch_dim.numel(), *x.shape[-3:]).permute(0, 3, 1, 2)
113
- y = F.interpolate(y, scale_factor=scale_factor, mode=mode).permute(0, 2, 3, 1)
114
- return y.reshape(*batch_dim, *y.shape[1:])
115
-
116
-
117
- def compute_edge_to_face_mapping(attr_idx):
118
- with torch.no_grad():
119
- # Get unique edges
120
- # Create all edges, packed by triangle
121
- all_edges = torch.cat((
122
- torch.stack((attr_idx[:, 0], attr_idx[:, 1]), dim=-1),
123
- torch.stack((attr_idx[:, 1], attr_idx[:, 2]), dim=-1),
124
- torch.stack((attr_idx[:, 2], attr_idx[:, 0]), dim=-1),
125
- ), dim=-1).view(-1, 2)
126
-
127
- # Swap edge order so min index is always first
128
- order = (all_edges[:, 0] > all_edges[:, 1]).long().unsqueeze(dim=1)
129
- sorted_edges = torch.cat((
130
- torch.gather(all_edges, 1, order),
131
- torch.gather(all_edges, 1, 1 - order)
132
- ), dim=-1)
133
-
134
- # Elliminate duplicates and return inverse mapping
135
- unique_edges, idx_map = torch.unique(sorted_edges, dim=0, return_inverse=True)
136
-
137
- tris = torch.arange(attr_idx.shape[0]).repeat_interleave(3).cuda()
138
-
139
- tris_per_edge = torch.zeros((unique_edges.shape[0], 2), dtype=torch.int64).cuda()
140
-
141
- # Compute edge to face table
142
- mask0 = order[:,0] == 0
143
- mask1 = order[:,0] == 1
144
- tris_per_edge[idx_map[mask0], 0] = tris[mask0]
145
- tris_per_edge[idx_map[mask1], 1] = tris[mask1]
146
-
147
- return tris_per_edge
148
-
149
-
150
- @torch.cuda.amp.autocast(enabled=False)
151
- def normal_consistency(face_normals, t_pos_idx):
152
-
153
- tris_per_edge = compute_edge_to_face_mapping(t_pos_idx)
154
-
155
- # Fetch normals for both faces sharind an edge
156
- n0 = face_normals[tris_per_edge[:, 0], :]
157
- n1 = face_normals[tris_per_edge[:, 1], :]
158
-
159
- # Compute error metric based on normal difference
160
- term = torch.clamp(torch.sum(n0 * n1, -1, keepdim=True), min=-1.0, max=1.0)
161
- term = (1.0 - term)
162
-
163
- return torch.mean(torch.abs(term))
164
-
165
-
166
- def laplacian_uniform(verts, faces):
167
-
168
- V = verts.shape[0]
169
- F = faces.shape[0]
170
-
171
- # Neighbor indices
172
- ii = faces[:, [1, 2, 0]].flatten()
173
- jj = faces[:, [2, 0, 1]].flatten()
174
- adj = torch.stack([torch.cat([ii, jj]), torch.cat([jj, ii])], dim=0).unique(dim=1)
175
- adj_values = torch.ones(adj.shape[1], device=verts.device, dtype=torch.float)
176
-
177
- # Diagonal indices
178
- diag_idx = adj[0]
179
-
180
- # Build the sparse matrix
181
- idx = torch.cat((adj, torch.stack((diag_idx, diag_idx), dim=0)), dim=1)
182
- values = torch.cat((-adj_values, adj_values))
183
-
184
- # The coalesce operation sums the duplicate indices, resulting in the
185
- # correct diagonal
186
- return torch.sparse_coo_tensor(idx, values, (V,V)).coalesce()
187
-
188
-
189
- @torch.cuda.amp.autocast(enabled=False)
190
- def laplacian_smooth_loss(verts, faces):
191
- with torch.no_grad():
192
- L = laplacian_uniform(verts, faces.long())
193
- loss = L.mm(verts)
194
- loss = loss.norm(dim=1)
195
- loss = loss.mean()
196
- return loss
197
-
198
-
199
- class DMTet:
200
-
201
- def __init__(self, device):
202
- self.device = device
203
- self.triangle_table = torch.tensor([
204
- [-1, -1, -1, -1, -1, -1],
205
- [1, 0, 2, -1, -1, -1],
206
- [4, 0, 3, -1, -1, -1],
207
- [1, 4, 2, 1, 3, 4],
208
- [3, 1, 5, -1, -1, -1],
209
- [2, 3, 0, 2, 5, 3],
210
- [1, 4, 0, 1, 5, 4],
211
- [4, 2, 5, -1, -1, -1],
212
- [4, 5, 2, -1, -1, -1],
213
- [4, 1, 0, 4, 5, 1],
214
- [3, 2, 0, 3, 5, 2],
215
- [1, 3, 5, -1, -1, -1],
216
- [4, 1, 2, 4, 3, 1],
217
- [3, 0, 4, -1, -1, -1],
218
- [2, 0, 1, -1, -1, -1],
219
- [-1, -1, -1, -1, -1, -1]
220
- ], dtype=torch.long, device=device)
221
- self.num_triangles_table = torch.tensor([0, 1, 1, 2, 1, 2, 2, 1, 1, 2, 2, 1, 2, 1, 1, 0], dtype=torch.long,
222
- device=device)
223
- self.base_tet_edges = torch.tensor([0, 1, 0, 2, 0, 3, 1, 2, 1, 3, 2, 3], dtype=torch.long, device=device)
224
-
225
- def sort_edges(self, edges_ex2):
226
- with torch.no_grad():
227
- order = (edges_ex2[:, 0] > edges_ex2[:, 1]).long()
228
- order = order.unsqueeze(dim=1)
229
-
230
- a = torch.gather(input=edges_ex2, index=order, dim=1)
231
- b = torch.gather(input=edges_ex2, index=1 - order, dim=1)
232
-
233
- return torch.stack([a, b], -1)
234
-
235
- def __call__(self, pos_nx3, sdf_n, tet_fx4):
236
- # pos_nx3: [N, 3]
237
- # sdf_n: [N]
238
- # tet_fx4: [F, 4]
239
-
240
- with torch.no_grad():
241
- occ_n = sdf_n > 0
242
- occ_fx4 = occ_n[tet_fx4.reshape(-1)].reshape(-1, 4)
243
- occ_sum = torch.sum(occ_fx4, -1) # [F,]
244
- valid_tets = (occ_sum > 0) & (occ_sum < 4)
245
- # occ_sum = occ_sum[valid_tets]
246
-
247
- # find all vertices
248
- all_edges = tet_fx4[valid_tets][:, self.base_tet_edges].reshape(-1, 2)
249
- all_edges = self.sort_edges(all_edges)
250
- unique_edges, idx_map = torch.unique(all_edges, dim=0, return_inverse=True)
251
-
252
- unique_edges = unique_edges.long()
253
- mask_edges = occ_n[unique_edges.reshape(-1)].reshape(-1, 2).sum(-1) == 1
254
- mapping = torch.ones((unique_edges.shape[0]), dtype=torch.long, device=self.device) * -1
255
- mapping[mask_edges] = torch.arange(mask_edges.sum(), dtype=torch.long, device=self.device)
256
- idx_map = mapping[idx_map] # map edges to verts
257
-
258
- interp_v = unique_edges[mask_edges]
259
-
260
- edges_to_interp = pos_nx3[interp_v.reshape(-1)].reshape(-1, 2, 3)
261
- edges_to_interp_sdf = sdf_n[interp_v.reshape(-1)].reshape(-1, 2, 1)
262
- edges_to_interp_sdf[:, -1] *= -1
263
-
264
- denominator = edges_to_interp_sdf.sum(1, keepdim=True)
265
-
266
- edges_to_interp_sdf = torch.flip(edges_to_interp_sdf, [1]) / denominator
267
- verts = (edges_to_interp * edges_to_interp_sdf).sum(1)
268
-
269
- idx_map = idx_map.reshape(-1, 6)
270
-
271
- v_id = torch.pow(2, torch.arange(4, dtype=torch.long, device=self.device))
272
- tetindex = (occ_fx4[valid_tets] * v_id.unsqueeze(0)).sum(-1)
273
- num_triangles = self.num_triangles_table[tetindex]
274
-
275
- # Generate triangle indices
276
- faces = torch.cat((
277
- torch.gather(input=idx_map[num_triangles == 1], dim=1,
278
- index=self.triangle_table[tetindex[num_triangles == 1]][:, :3]).reshape(-1, 3),
279
- torch.gather(input=idx_map[num_triangles == 2], dim=1,
280
- index=self.triangle_table[tetindex[num_triangles == 2]][:, :6]).reshape(-1, 3),
281
- ), dim=0)
282
-
283
- return verts, faces
284
-
285
-
286
- class MeshRenderer(nn.Module):
287
- def __init__(self,
288
- near=0.1,
289
- far=10,
290
- ssaa=1,
291
- texture_filter='linear-mipmap-linear',
292
- opengl=False,
293
- device='cuda'):
294
- super().__init__()
295
- self.near = near
296
- self.far = far
297
- assert isinstance(ssaa, int) and ssaa >= 1
298
- self.ssaa = ssaa
299
- self.texture_filter = texture_filter
300
- self.glctx = dr.RasterizeGLContext(output_db=False)
301
-
302
- def forward(self, meshes, poses, intrinsics, h, w, shading_fun=None,
303
- dilate_edges=0, normal_bg=[0.5, 0.5, 1.0], aa=True, render_vc=False):
304
- """
305
- Args:
306
- meshes (list[Mesh]): list of Mesh objects
307
- poses: Shape (num_scenes, num_images, 3, 4)
308
- intrinsics: Shape (num_scenes, num_images, 4) in [fx, fy, cx, cy]
309
- """
310
- num_scenes, num_images, _, _ = poses.size()
311
-
312
- if self.ssaa > 1:
313
- h = h * self.ssaa
314
- w = w * self.ssaa
315
- intrinsics = intrinsics * self.ssaa
316
-
317
- r_mat_c2w = torch.cat(
318
- [poses[..., :3, :1], -poses[..., :3, 1:3]], dim=-1) # opencv to opengl conversion
319
-
320
- proj = poses.new_zeros([num_scenes, num_images, 4, 4])
321
- proj[..., 0, 0] = 2 * intrinsics[..., 0] / w
322
- proj[..., 0, 2] = -2 * intrinsics[..., 2] / w + 1
323
- proj[..., 1, 1] = -2 * intrinsics[..., 1] / h
324
- proj[..., 1, 2] = -2 * intrinsics[..., 3] / h + 1
325
- proj[..., 2, 2] = -(self.far + self.near) / (self.far - self.near)
326
- proj[..., 2, 3] = -(2 * self.far * self.near) / (self.far - self.near)
327
- proj[..., 3, 2] = -1
328
-
329
- # (num_scenes, (num_images, num_vertices, 3))
330
- v_cam = [(mesh.v - poses[i, :, :3, 3].unsqueeze(-2)) @ r_mat_c2w[i] for i, mesh in enumerate(meshes)]
331
- # (num_scenes, (num_images, num_vertices, 4))
332
- v_clip = [F.pad(v, pad=(0, 1), mode='constant', value=1.0) @ proj[i].transpose(-1, -2) for i, v in enumerate(v_cam)]
333
-
334
- if num_scenes == 1:
335
- # (num_images, h, w, 4) in [u, v, z/w, triangle_id] & (num_images, h, w, 4 or 0)
336
- rast, rast_db = dr.rasterize(
337
- self.glctx, v_clip[0], meshes[0].f, (h, w), grad_db=torch.is_grad_enabled())
338
-
339
- fg = (rast[..., 3] > 0).unsqueeze(0) # (num_scenes, num_images, h, w)
340
- alpha = fg.float().unsqueeze(-1)
341
-
342
- depth = 1 / dr.interpolate(
343
- -v_cam[0][..., 2:3].contiguous(), rast, meshes[0].f)[0].reshape(num_scenes, num_images, h, w)
344
- depth.masked_fill_(~fg, 0)
345
-
346
- normal = dr.interpolate(
347
- meshes[0].vn.unsqueeze(0).contiguous(), rast, meshes[0].fn)[0].reshape(num_scenes, num_images, h, w, 3)
348
- normal = F.normalize(normal, dim=-1)
349
- # (num_scenes, num_images, h, w, 3) = (num_scenes, num_images, h, w, 3) @ (num_scenes, num_images, 1, 3, 3)
350
- rot_normal = (normal @ r_mat_c2w.unsqueeze(2)) / 2 + 0.5
351
- rot_normal[~fg] = rot_normal.new_tensor(normal_bg)
352
-
353
- if meshes[0].vt is not None and meshes[0].albedo is not None:
354
- # (num_images, h, w, 2) & (num_images, h, w, 4)
355
- texc, texc_db = dr.interpolate(
356
- meshes[0].vt.unsqueeze(0).contiguous(), rast, meshes[0].ft, rast_db=rast_db, diff_attrs='all')
357
- # (num_scenes, num_images, h, w, 3)
358
- albedo = dr.texture(
359
- meshes[0].albedo.unsqueeze(0)[..., :3].contiguous(), texc, uv_da=texc_db, filter_mode=self.texture_filter).unsqueeze(0)
360
- albedo[~fg] = 0
361
- elif meshes[0].vc is not None:
362
- rgba = dr.interpolate(
363
- meshes[0].vc.contiguous(), rast, meshes[0].f)[0].reshape(num_scenes, num_images, h, w, 4)
364
- alpha = alpha * rgba[..., 3:4]
365
- albedo = rgba[..., :3] * alpha
366
- else:
367
- albedo = torch.zeros_like(rot_normal)
368
-
369
- prev_grad_enabled = torch.is_grad_enabled()
370
- torch.set_grad_enabled(True)
371
- if shading_fun is not None:
372
- xyz = dr.interpolate(
373
- meshes[0].v.unsqueeze(0).contiguous(), rast, meshes[0].f)[0].reshape(num_scenes, num_images, h, w, 3)
374
- rgb_reshade = shading_fun(
375
- world_pos=xyz[fg],
376
- albedo=albedo[fg],
377
- world_normal=normal[fg],
378
- fg_mask=fg)
379
- albedo = torch.zeros_like(albedo)
380
- albedo[fg] = rgb_reshade
381
-
382
- # (num_scenes, num_images, h, w, 4)
383
- rgba = torch.cat([albedo, alpha], dim=-1)
384
-
385
- if dilate_edges > 0:
386
- rgba = rgba.reshape(num_scenes * num_images, h, w, 4).permute(0, 3, 1, 2)
387
- rgba = edge_dilation(rgba, rgba[:, 3:], dilate_edges)
388
- rgba = rgba.permute(0, 2, 3, 1).reshape(num_scenes, num_images, h, w, 4)
389
-
390
- if aa:
391
- rgba, depth, rot_normal = dr.antialias(
392
- torch.cat([rgba, depth.unsqueeze(-1), rot_normal], dim=-1).squeeze(0),
393
- rast, v_clip[0], meshes[0].f).unsqueeze(0).split([4, 1, 3], dim=-1)
394
- depth = depth.squeeze(-1)
395
-
396
- else: # concat and range mode
397
- # v_cat = []
398
- v_clip_cat = []
399
- v_cam_cat = []
400
- vn_cat = []
401
- vt_cat = []
402
- f_cat = []
403
- fn_cat = []
404
- ft_cat = []
405
- v_count = 0
406
- vn_count = 0
407
- vt_count = 0
408
- f_count = 0
409
- f_ranges = []
410
- for i, mesh in enumerate(meshes):
411
- num_v = v_clip[i].size(1)
412
- num_vn = mesh.vn.size(0)
413
- num_vt = mesh.vt.size(0)
414
- # v_cat.append(mesh.v.unsqueeze(0).expand(num_images, -1, -1).reshape(num_images * num_v, 3))
415
- v_clip_cat.append(v_clip[i].reshape(num_images * num_v, 4))
416
- v_cam_cat.append(v_cam[i].reshape(num_images * num_v, 3))
417
- vn_cat.append(mesh.vn.unsqueeze(0).expand(num_images, -1, -1).reshape(num_images * num_vn, 3))
418
- vt_cat.append(mesh.vt.unsqueeze(0).expand(num_images, -1, -1).reshape(num_images * num_vt, 2))
419
- for _ in range(num_images):
420
- f_cat.append(mesh.f + v_count)
421
- fn_cat.append(mesh.fn + vn_count)
422
- ft_cat.append(mesh.ft + vt_count)
423
- v_count += num_v
424
- vn_count += num_vn
425
- vt_count += num_vt
426
- f_ranges.append([f_count, mesh.f.size(0)])
427
- f_count += mesh.f.size(0)
428
- # v_cat = torch.cat(v_cat, dim=0)
429
- v_clip_cat = torch.cat(v_clip_cat, dim=0)
430
- v_cam_cat = torch.cat(v_cam_cat, dim=0)
431
- vn_cat = torch.cat(vn_cat, dim=0)
432
- f_cat = torch.cat(f_cat, dim=0)
433
- f_ranges = torch.tensor(f_ranges, device=poses.device, dtype=torch.int32)
434
- # (num_scenes * num_images, h, w, 4) in [u, v, z/w, triangle_id] & (num_scenes * num_images, h, w, 4 or 0)
435
- rast, rast_db = dr.rasterize(
436
- self.glctx, v_clip_cat, f_cat, (h, w), ranges=f_ranges, grad_db=torch.is_grad_enabled())
437
-
438
- fg = (rast[..., 3] > 0).reshape(num_scenes, num_images, h, w)
439
-
440
- depth = 1 / dr.interpolate(
441
- -v_cam_cat[..., 2:3].contiguous(), rast, f_cat)[0].reshape(num_scenes, num_images, h, w)
442
- depth.masked_fill_(~fg, 0)
443
-
444
- normal = dr.interpolate(
445
- vn_cat, rast, fn_cat)[0].reshape(num_scenes, num_images, h, w, 3)
446
- normal = F.normalize(normal, dim=-1)
447
- # (num_scenes, num_images, h, w, 3) = (num_scenes, num_images, h, w, 3) @ (num_scenes, num_images, 1, 3, 3)
448
- rot_normal = (normal @ r_mat_c2w.unsqueeze(2)) / 2 + 0.5
449
- rot_normal[~fg] = rot_normal.new_tensor(normal_bg)
450
-
451
- # (num_scenes * num_images, h, w, 2) & (num_scenes * num_images, h, w, 4)
452
- texc, texc_db = dr.interpolate(
453
- vt_cat, rast, ft_cat, rast_db=rast_db, diff_attrs='all')
454
- albedo = dr.texture(
455
- torch.cat([mesh.albedo.unsqueeze(0)[..., :3].expand(num_images, -1, -1, -1) for mesh in meshes], dim=0),
456
- texc, uv_da=texc_db, filter_mode=self.texture_filter
457
- ).reshape(num_scenes, num_images, h, w, 3)
458
-
459
- prev_grad_enabled = torch.is_grad_enabled()
460
- torch.set_grad_enabled(True)
461
- if shading_fun is not None:
462
- raise NotImplementedError
463
-
464
- # (num_scenes, num_images, h, w, 4)
465
- rgba = torch.cat([albedo, fg.float().unsqueeze(-1)], dim=-1)
466
-
467
- if dilate_edges > 0:
468
- rgba = rgba.reshape(num_scenes * num_images, h, w, 4).permute(0, 3, 1, 2)
469
- rgba = edge_dilation(rgba, rgba[:, 3:], dilate_edges)
470
- rgba = rgba.permute(0, 2, 3, 1).reshape(num_scenes, num_images, h, w, 4)
471
-
472
- if aa:
473
- # Todo: depth/normal antialiasing
474
- rgba = dr.antialias(
475
- rgba.reshape(num_scenes * num_images, h, w, 4), rast, v_clip_cat, f_cat
476
- ).reshape(num_scenes, num_images, h, w, 4)
477
-
478
- if self.ssaa > 1:
479
- rgba = interpolate_hwc(rgba, 1 / self.ssaa)
480
- depth = interpolate_hwc(depth.unsqueeze(-1), 1 / self.ssaa).squeeze(-1)
481
- rot_normal = interpolate_hwc(rot_normal, 1 / self.ssaa)
482
-
483
- results = dict(
484
- rgba=rgba,
485
- depth=depth,
486
- normal=rot_normal)
487
-
488
- torch.set_grad_enabled(prev_grad_enabled)
489
-
490
- return results
491
-
492
- def bake_xyz_shading_fun(self, meshes, shading_fun, map_size=1024, force_auto_uv=False):
493
- assert len(meshes) == 1, 'only support one mesh'
494
- mesh = meshes[0]
495
-
496
- if mesh.vt is None or force_auto_uv:
497
- mesh.auto_uv()
498
- assert len(mesh.ft) == len(mesh.f)
499
-
500
- vt_clip = torch.cat([mesh.vt * 2 - 1, mesh.vt.new_tensor([[0., 1.]]).expand(mesh.vt.size(0), -1)], dim=-1)
501
-
502
- rast = dr.rasterize(self.glctx, vt_clip[None], mesh.ft, (map_size, map_size), grad_db=False)[0]
503
- valid = (rast[..., 3] > 0).reshape(map_size, map_size)
504
-
505
- xyz = dr.interpolate(mesh.v[None], rast, mesh.f)[0].reshape(map_size, map_size, 3)
506
- rgb_reshade = shading_fun(world_pos=xyz[valid])
507
- new_albedo_map = xyz.new_zeros((map_size, map_size, 3))
508
- new_albedo_map[valid] = rgb_reshade
509
- torch.cuda.empty_cache()
510
- new_albedo_map = edge_dilation(
511
- new_albedo_map.permute(2, 0, 1)[None], valid[None, None].float(),
512
- ).squeeze(0).permute(1, 2, 0)
513
- mesh.albedo = torch.cat(
514
- [new_albedo_map.clamp(min=0, max=1),
515
- torch.ones_like(new_albedo_map[..., :1])], dim=-1)
516
-
517
- mesh.textureless = False
518
- return [mesh]
519
-
520
- def bake_multiview(self, meshes, images, alphas, poses, intrinsics, map_size=1024, cos_weight_pow=4.0):
521
- assert len(meshes) == 1, 'only support one mesh'
522
- mesh = meshes[0]
523
- images = images[0] # (n, h, w, 3)
524
- alphas = alphas[0] # (n, h, w, 1)
525
- n, h, w, _ = images.size()
526
-
527
- r_mat_c2w = torch.cat(
528
- [poses[..., :3, :1], -poses[..., :3, 1:3]], dim=-1)[0] # opencv to opengl conversion
529
-
530
- proj = poses.new_zeros([n, 4, 4])
531
- proj[..., 0, 0] = 2 * intrinsics[..., 0] / w
532
- proj[..., 0, 2] = -2 * intrinsics[..., 2] / w + 1
533
- proj[..., 1, 1] = -2 * intrinsics[..., 1] / h
534
- proj[..., 1, 2] = -2 * intrinsics[..., 3] / h + 1
535
- proj[..., 2, 2] = -(self.far + self.near) / (self.far - self.near)
536
- proj[..., 2, 3] = -(2 * self.far * self.near) / (self.far - self.near)
537
- proj[..., 3, 2] = -1
538
-
539
- # (num_images, num_vertices, 3)
540
- v_cam = (mesh.v.detach() - poses[0, :, :3, 3].unsqueeze(-2)) @ r_mat_c2w
541
- # (num_images, num_vertices, 4)
542
- v_clip = F.pad(v_cam, pad=(0, 1), mode='constant', value=1.0) @ proj.transpose(-1, -2)
543
-
544
- rast, rast_db = dr.rasterize(self.glctx, v_clip, mesh.f, (h, w), grad_db=False)
545
- texc, texc_db = dr.interpolate(
546
- mesh.vt.unsqueeze(0).contiguous(), rast, mesh.ft, rast_db=rast_db, diff_attrs='all')
547
-
548
- with torch.enable_grad():
549
- dummy_maps = torch.ones((n, map_size, map_size, 1), device=images.device, dtype=images.dtype).requires_grad_(True)
550
- # (num_images, h, w, 1)
551
- albedo = dr.texture(
552
- dummy_maps, texc, uv_da=texc_db, filter_mode=self.texture_filter)
553
- visibility_grad = torch.autograd.grad(albedo.sum(), dummy_maps, create_graph=False)[0]
554
-
555
- fg = rast[..., 3] > 0 # (num_images, h, w)
556
- depth = 1 / dr.interpolate(
557
- -v_cam[..., 2:3].contiguous(), rast, mesh.f)[0].reshape(n, h, w)
558
- depth.masked_fill_(~fg, 0)
559
-
560
- # # save all the depth maps for visualization debug
561
- # import matplotlib.pyplot as plt
562
- # for i in range(n):
563
- # plt.imshow(depth[i].cpu().numpy())
564
- # plt.savefig(f'depth_{i}.png')
565
- # # also save the alphas
566
- # for i in range(n):
567
- # plt.imshow(alphas[i].cpu().numpy())
568
- # plt.savefig(f'alpha_{i}.png')
569
-
570
- directions = get_ray_directions(
571
- h, w, intrinsics.squeeze(0), norm=True, device=intrinsics.device)
572
-
573
- normals_opencv = depth_to_normal(
574
- depth, directions, format='opencv') * 2 - 1
575
- normals_cos_weight = (normals_opencv[..., None, :] @ directions[..., :, None]).squeeze(-1).neg().clamp(min=0)
576
-
577
- img_space_weight = (normals_cos_weight ** cos_weight_pow) * alphas
578
- img_space_weight = -F.max_pool2d( # alleviate edge effect
579
- -img_space_weight.permute(0, 3, 1, 2), 5, stride=1, padding=2).permute(0, 2, 3, 1)
580
-
581
- # bake texture
582
- vt_clip = torch.cat([mesh.vt * 2 - 1, mesh.vt.new_tensor([[0., 1.]]).expand(mesh.vt.size(0), -1)], dim=-1)
583
-
584
- rast, rast_db = dr.rasterize(self.glctx, vt_clip[None], mesh.ft, (map_size, map_size), grad_db=False)
585
- valid = (rast[..., 3] > 0).reshape(map_size, map_size)
586
- rast = rast.expand(n, -1, -1, -1)
587
- rast_db = rast_db.expand(n, -1, -1, -1)
588
- v_img = v_clip[..., :2] / v_clip[..., 3:] * 0.5 + 0.5
589
- # print(v_img.min(), v_img.max())
590
- texc, texc_db = dr.interpolate(
591
- v_img.contiguous(), rast.contiguous(), mesh.f, rast_db=rast_db.contiguous(), diff_attrs='all')
592
- # (n, map_size, map_size, 4)
593
- tex = dr.texture(
594
- torch.cat([images, img_space_weight], dim=-1), texc, uv_da=texc_db, filter_mode=self.texture_filter)
595
-
596
- weight = tex[..., 3:4] * visibility_grad
597
-
598
- new_albedo_map = (tex[..., :3] * weight).sum(dim=0) / weight.sum(dim=0).clamp(min=1e-6)
599
-
600
- new_albedo_map = edge_dilation(
601
- new_albedo_map.permute(2, 0, 1)[None], valid[None, None].float(),
602
- ).squeeze(0).permute(1, 2, 0)
603
- mesh.albedo = torch.cat(
604
- [new_albedo_map.clamp(min=0, max=1),
605
- torch.ones_like(new_albedo_map[..., :1])], dim=-1)
606
-
607
- mesh.textureless = False
608
- return [mesh]