fix: background net should condition on rays_d
- nerf/network.py +4 -5
- nerf/network_grid.py +4 -5
- nerf/network_tcnn.py +6 -16
- nerf/renderer.py +4 -4
- nerf/utils.py +3 -0
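
Across the three network variants, the background MLP previously consumed 2D spherical coordinates produced by raymarching.sph_from_ray; after this change it is conditioned on the 3D ray direction directly. Each variant therefore switches its background encoder to a 3-channel frequency encoding, and background() takes the per-ray direction d instead of x.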
nerf/network.py
@@ -52,7 +52,7 @@ class NeRFNetwork(NeRFRenderer):
         if self.bg_radius > 0:
             self.num_layers_bg = num_layers_bg
             self.hidden_dim_bg = hidden_dim_bg
-            self.encoder_bg, self.in_dim_bg = get_encoder('frequency', input_dim=2)
+            self.encoder_bg, self.in_dim_bg = get_encoder('frequency', input_dim=3)
             self.bg_net = MLP(self.in_dim_bg, 3, hidden_dim_bg, num_layers_bg, bias=True)

         else:
@@ -80,7 +80,7 @@ class NeRFNetwork(NeRFRenderer):
         return sigma, albedo

     # ref: https://github.com/zhaofuq/Instant-NSR/blob/main/nerf/network_sdf.py#L192
-    def finite_difference_normal(self, x, epsilon=
+    def finite_difference_normal(self, x, epsilon=1e-2):
         # x: [N, 3]
         dx_pos, _ = self.common_forward((x + torch.tensor([[epsilon, 0.00, 0.00]], device=x.device)).clamp(-self.bound, self.bound))
         dx_neg, _ = self.common_forward((x + torch.tensor([[-epsilon, 0.00, 0.00]], device=x.device)).clamp(-self.bound, self.bound))
@@ -148,10 +148,9 @@ class NeRFNetwork(NeRFRenderer):
         }


-    def background(self, x):
-        # x: [N, 2], in [-1, 1]
+    def background(self, d):

-        h = self.encoder_bg(x) # [N, C]
+        h = self.encoder_bg(d) # [N, C]

         h = self.bg_net(h)

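For intuition, a minimal self-contained sketch of the patched background path: a frequency (positional) encoding of the normalized ray direction feeds a small MLP that predicts one RGB color per ray. The freq_encode helper and the layer sizes below are illustrative assumptions, not the repo's actual get_encoder('frequency', ...) or MLP implementations.

import torch
import torch.nn.functional as F

# Illustrative stand-in for a 'frequency' encoder: sin/cos of the input
# at octave-spaced frequencies, concatenated with the raw input.
def freq_encode(d, num_freqs=6):
    # d: [N, 3] ray directions -> [N, 3 + 3 * 2 * num_freqs] features
    feats = [d]
    for i in range(num_freqs):
        feats += [torch.sin((2.0 ** i) * d), torch.cos((2.0 ** i) * d)]
    return torch.cat(feats, dim=-1)

rays_d = F.normalize(torch.randn(1024, 3), dim=-1)  # stand-in ray directions
h = freq_encode(rays_d)                             # [N, 39] for num_freqs=6
bg_net = torch.nn.Sequential(                       # stand-in for MLP(self.in_dim_bg, 3, ...)
    torch.nn.Linear(h.shape[-1], 64), torch.nn.ReLU(inplace=True),
    torch.nn.Linear(64, 3),
)
bg_color = torch.sigmoid(bg_net(h))                 # [N, 3] background RGB per ray

Because the input is only a direction, the learned background is view-dependent but position-independent, i.e. an environment map.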
nerf/network_grid.py
@@ -57,7 +57,7 @@ class NeRFNetwork(NeRFRenderer):

         # use a very simple network to avoid it learning the prompt...
         # self.encoder_bg, self.in_dim_bg = get_encoder('tiledgrid', input_dim=2, num_levels=4, desired_resolution=2048)
-        self.encoder_bg, self.in_dim_bg = get_encoder('frequency', input_dim=2)
+        self.encoder_bg, self.in_dim_bg = get_encoder('frequency', input_dim=3)

         self.bg_net = MLP(self.in_dim_bg, 3, hidden_dim_bg, num_layers_bg, bias=True)

@@ -87,7 +87,7 @@ class NeRFNetwork(NeRFRenderer):
         return sigma, albedo

     # ref: https://github.com/zhaofuq/Instant-NSR/blob/main/nerf/network_sdf.py#L192
-    def finite_difference_normal(self, x, epsilon=
+    def finite_difference_normal(self, x, epsilon=1e-2):
         # x: [N, 3]
         dx_pos, _ = self.common_forward((x + torch.tensor([[epsilon, 0.00, 0.00]], device=x.device)).clamp(-self.bound, self.bound))
         dx_neg, _ = self.common_forward((x + torch.tensor([[-epsilon, 0.00, 0.00]], device=x.device)).clamp(-self.bound, self.bound))
@@ -155,10 +155,9 @@ class NeRFNetwork(NeRFRenderer):
         }


-    def background(self, x):
-        # x: [N, 2], in [-1, 1]
+    def background(self, d):

-        h = self.encoder_bg(x) # [N, C]
+        h = self.encoder_bg(d) # [N, C]

         h = self.bg_net(h)

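The comment kept above the changed line records the design intent: the background deliberately uses a low-capacity frequency encoding (the commented-out tiledgrid encoder is higher capacity) so the background model cannot simply learn the prompt's content itself.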
nerf/network_tcnn.py
@@ -4,6 +4,7 @@ import torch.nn.functional as F

 from activation import trunc_exp
 from .renderer import NeRFRenderer
+from encoding import get_encoder

 import numpy as np
 import tinycudann as tcnn
@@ -65,19 +66,9 @@ class NeRFNetwork(NeRFRenderer):
             self.num_layers_bg = num_layers_bg
             self.hidden_dim_bg = hidden_dim_bg

-            self.encoder_bg = tcnn.Encoding(
-                n_input_dims=2,
-                encoding_config={
-                    "otype": "HashGrid",
-                    "n_levels": 4,
-                    "n_features_per_level": 2,
-                    "log2_hashmap_size": 16,
-                    "base_resolution": 16,
-                    "per_level_scale": 1.5,
-                },
-            )
-
-            self.bg_net = MLP(8, 3, hidden_dim_bg, num_layers_bg, bias=True)
+            self.encoder_bg, self.in_dim_bg = get_encoder('frequency', input_dim=3)
+
+            self.bg_net = MLP(self.in_dim_bg, 3, hidden_dim_bg, num_layers_bg, bias=True)

         else:
             self.bg_net = None
@@ -156,11 +147,10 @@ class NeRFNetwork(NeRFRenderer):
         }


-    def background(self, x):
+    def background(self, d):
         # x: [N, 2], in [-1, 1]

-        h = (x + 1) / 2 # to [0, 1]
-        h = self.encoder_bg(h) # [N, C]
+        h = self.encoder_bg(d) # [N, C]

         h = self.bg_net(h)

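Note that the tcnn variant drops its 2D HashGrid background entirely rather than just widening it: a HashGrid expects inputs in [0, 1] (hence the old remap h = (x + 1) / 2), while the frequency encoding takes the raw direction, so background() now encodes d directly. The surviving # x: [N, 2] comment is stale after this change.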
nerf/renderer.py
@@ -420,8 +420,8 @@ class NeRFRenderer(nn.Module):
         # mix background color
         if self.bg_radius > 0:
             # use the bg model to calculate bg_color
-            sph = raymarching.sph_from_ray(rays_o, rays_d, self.bg_radius) # [N, 2] in [-1, 1]
-            bg_color = self.background(sph) # [N, 3]
+            # sph = raymarching.sph_from_ray(rays_o, rays_d, self.bg_radius) # [N, 2] in [-1, 1]
+            bg_color = self.background(rays_d.reshape(-1, 3)) # [N, 3]
         elif bg_color is None:
             bg_color = 1

@@ -526,8 +526,8 @@ class NeRFRenderer(nn.Module):
         if self.bg_radius > 0:

             # use the bg model to calculate bg_color
-            sph = raymarching.sph_from_ray(rays_o, rays_d, self.bg_radius) # [N, 2] in [-1, 1]
-            bg_color = self.background(sph) # [N, 3]
+            # sph = raymarching.sph_from_ray(rays_o, rays_d, self.bg_radius) # [N, 2] in [-1, 1]
+            bg_color = self.background(rays_d) # [N, 3]

         elif bg_color is None:
             bg_color = 1
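Note the small asymmetry between the two call sites: the first flattens with rays_d.reshape(-1, 3) (that path presumably still carries rays as [B, N, 3]), while the second passes rays_d as-is. In both cases the returned bg_color is composited behind the accumulated foreground. A minimal sketch of that mix, where weights_sum is an assumed name for the per-ray accumulated opacity rather than a quote from this file:

import torch

N = 1024
image = torch.rand(N, 3)        # accumulated foreground radiance per ray
weights_sum = torch.rand(N, 1)  # accumulated opacity along each ray
bg_color = torch.rand(N, 3)     # per-ray output of self.background(rays_d)

# rays that hit little geometry (low weights_sum) show mostly background
image = image + (1.0 - weights_sum) * bg_color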
nerf/utils.py
@@ -343,6 +343,9 @@ class Trainer(object):
         pred_rgb = outputs['image'].reshape(B, H, W, 3).permute(0, 3, 1, 2).contiguous() # [1, 3, H, W]
         # torch.cuda.synchronize(); print(f'[TIME] nerf render {time.time() - _t:.4f}s')

+        # print(shading)
+        # torch_vis_2d(pred_rgb[0])
+
         # text embeddings
         if self.opt.dir_text:
             dirs = data['dir'] # [B,]