add option to enable camera pose jittering
Browse files
- main.py +1 -0
- nerf/provider.py +12 -5
main.py
CHANGED
@@ -42,6 +42,7 @@ if __name__ == '__main__':
|
|
42 |
# rendering resolution in training, decrease this if CUDA OOM.
|
43 |
parser.add_argument('--w', type=int, default=128, help="render width for NeRF in training")
|
44 |
parser.add_argument('--h', type=int, default=128, help="render height for NeRF in training")
|
|
|
45 |
|
46 |
### dataset options
|
47 |
parser.add_argument('--bound', type=float, default=1, help="assume the scene is bounded in box(-bound, bound)")
|
|
|
42 |
# rendering resolution in training, decrease this if CUDA OOM.
|
43 |
parser.add_argument('--w', type=int, default=128, help="render width for NeRF in training")
|
44 |
parser.add_argument('--h', type=int, default=128, help="render height for NeRF in training")
|
45 |
+
parser.add_argument('--jitter_pose', action='store_true', help="add jitters to the randomly sampled camera poses")
|
46 |
|
47 |
### dataset options
|
48 |
parser.add_argument('--bound', type=float, default=1, help="assume the scene is bounded in box(-bound, bound)")
|
nerf/provider.py
CHANGED
@@ -55,7 +55,7 @@ def get_view_direction(thetas, phis, overhead, front):
|
|
55 |
return res
|
56 |
|
57 |
|
58 |
-
def rand_poses(size, device, radius_range=[1, 1.5], theta_range=[0, 150], phi_range=[0, 360], return_dirs=False, angle_overhead=30, angle_front=60):
|
59 |
''' generate random poses from an orbit camera
|
60 |
Args:
|
61 |
size: batch size of generated poses.
|
@@ -82,16 +82,23 @@ def rand_poses(size, device, radius_range=[1, 1.5], theta_range=[0, 150], phi_ra
|
|
82 |
radius * torch.sin(thetas) * torch.cos(phis),
|
83 |
], dim=-1) # [B, 3]
|
84 |
|
|
|
|
|
85 |
# jitters
|
86 |
-
|
87 |
-
|
|
|
88 |
|
89 |
# lookat
|
90 |
forward_vector = safe_normalize(targets - centers)
|
91 |
up_vector = torch.FloatTensor([0, -1, 0]).to(device).unsqueeze(0).repeat(size, 1)
|
92 |
right_vector = safe_normalize(torch.cross(forward_vector, up_vector, dim=-1))
|
|
|
|
|
|
|
|
|
|
|
93 |
|
94 |
-
up_noise = torch.randn_like(up_vector) * 0.02
|
95 |
up_vector = safe_normalize(torch.cross(right_vector, forward_vector, dim=-1) + up_noise)
|
96 |
|
97 |
poses = torch.eye(4, dtype=torch.float, device=device).unsqueeze(0).repeat(size, 1, 1)
|
@@ -170,7 +177,7 @@ class NeRFDataset:
|
|
170 |
|
171 |
if self.training:
|
172 |
# random pose on the fly
|
173 |
-
poses, dirs = rand_poses(B, self.device, radius_range=self.radius_range, return_dirs=self.opt.dir_text, angle_overhead=self.opt.angle_overhead, angle_front=self.opt.angle_front)
|
174 |
|
175 |
# random focal
|
176 |
fov = random.random() * (self.fovy_range[1] - self.fovy_range[0]) + self.fovy_range[0]
|
|
|
55 |
return res
|
56 |
|
57 |
|
58 |
+
def rand_poses(size, device, radius_range=[1, 1.5], theta_range=[0, 150], phi_range=[0, 360], return_dirs=False, angle_overhead=30, angle_front=60, jitter=False):
|
59 |
''' generate random poses from an orbit camera
|
60 |
Args:
|
61 |
size: batch size of generated poses.
|
|
|
82 |
radius * torch.sin(thetas) * torch.cos(phis),
|
83 |
], dim=-1) # [B, 3]
|
84 |
|
85 |
+
targets = 0
|
86 |
+
|
87 |
# jitters
|
88 |
+
if jitter:
|
89 |
+
centers = centers + (torch.rand_like(centers) * 0.2 - 0.1)
|
90 |
+
targets = targets + torch.randn_like(centers) * 0.2
|
91 |
|
92 |
# lookat
|
93 |
forward_vector = safe_normalize(targets - centers)
|
94 |
up_vector = torch.FloatTensor([0, -1, 0]).to(device).unsqueeze(0).repeat(size, 1)
|
95 |
right_vector = safe_normalize(torch.cross(forward_vector, up_vector, dim=-1))
|
96 |
+
|
97 |
+
if jitter:
|
98 |
+
up_noise = torch.randn_like(up_vector) * 0.02
|
99 |
+
else:
|
100 |
+
up_noise = 0
|
101 |
|
|
|
102 |
up_vector = safe_normalize(torch.cross(right_vector, forward_vector, dim=-1) + up_noise)
|
103 |
|
104 |
poses = torch.eye(4, dtype=torch.float, device=device).unsqueeze(0).repeat(size, 1, 1)
|
|
|
177 |
|
178 |
if self.training:
|
179 |
# random pose on the fly
|
180 |
+
poses, dirs = rand_poses(B, self.device, radius_range=self.radius_range, return_dirs=self.opt.dir_text, angle_overhead=self.opt.angle_overhead, angle_front=self.opt.angle_front, jitter=self.opt.jitter_pose)
|
181 |
|
182 |
# random focal
|
183 |
fov = random.random() * (self.fovy_range[1] - self.fovy_range[0]) + self.fovy_range[0]
|