#include <cuda.h>
#include <cuda_fp16.h>
#include <cuda_runtime.h>

#include <ATen/cuda/CUDAContext.h>
#include <torch/torch.h>

#include <cstdio>
#include <stdint.h>
#include <stdexcept>
#include <limits>

#define CHECK_CUDA(x) TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be a contiguous tensor")
#define CHECK_IS_INT(x) TORCH_CHECK(x.scalar_type() == at::ScalarType::Int, #x " must be an int tensor")
#define CHECK_IS_FLOATING(x) TORCH_CHECK(x.scalar_type() == at::ScalarType::Float || x.scalar_type() == at::ScalarType::Half || x.scalar_type() == at::ScalarType::Double, #x " must be a floating tensor")

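// The CHECK_* macros above are the usual guards applied to each tensor argument
// before launching a kernel. A minimal sketch of how they would typically be used
// at the top of a host wrapper such as near_far_from_aabb (not wired in here):
//
//     CHECK_CUDA(rays_o); CHECK_CUDA(rays_d); CHECK_CUDA(aabb);
//     CHECK_CONTIGUOUS(rays_o); CHECK_CONTIGUOUS(rays_d); CHECK_CONTIGUOUS(aabb);
//     CHECK_IS_FLOATING(rays_o); CHECK_IS_FLOATING(rays_d); CHECK_IS_FLOATING(aabb);
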
inline constexpr __device__ float SQRT3() { return 1.7320508075688772f; }
inline constexpr __device__ float RSQRT3() { return 0.5773502691896258f; }
inline constexpr __device__ float PI() { return 3.141592653589793f; }
inline constexpr __device__ float RPI() { return 0.3183098861837907f; }

template <typename T>
inline __host__ __device__ T div_round_up(T val, T divisor) {
    return (val + divisor - 1) / divisor;
}

inline __host__ __device__ float signf(const float x) {
    return copysignf(1.0f, x);
}

inline __host__ __device__ float clamp(const float x, const float min, const float max) {
    return fminf(max, fmaxf(min, x));
}

inline __host__ __device__ void swapf(float& a, float& b) {
    float c = a; a = b; b = c;
}

inline __device__ int mip_from_pos(const float x, const float y, const float z, const float max_cascade) {
    const float mx = fmaxf(fabsf(x), fmaxf(fabsf(y), fabsf(z)));
    int exponent;
    frexpf(mx, &exponent);
    return fminf(max_cascade - 1, fmaxf(0, exponent));
}

inline __device__ int mip_from_dt(const float dt, const float H, const float max_cascade) {
    const float mx = dt * H * 0.5f;
    int exponent;
    frexpf(mx, &exponent);
    return fminf(max_cascade - 1, fmaxf(0, exponent));
}

inline __host__ __device__ uint32_t __expand_bits(uint32_t v)
{
    v = (v * 0x00010001u) & 0xFF0000FFu;
    v = (v * 0x00000101u) & 0x0F00F00Fu;
    v = (v * 0x00000011u) & 0xC30C30C3u;
    v = (v * 0x00000005u) & 0x49249249u;
    return v;
}

inline __host__ __device__ uint32_t __morton3D(uint32_t x, uint32_t y, uint32_t z)
{
    uint32_t xx = __expand_bits(x);
    uint32_t yy = __expand_bits(y);
    uint32_t zz = __expand_bits(z);
    return xx | (yy << 1) | (zz << 2);
}

inline __host__ __device__ uint32_t __morton3D_invert(uint32_t x)
{
    x = x & 0x49249249;
    x = (x | (x >> 2)) & 0xc30c30c3;
    x = (x | (x >> 4)) & 0x0f00f00f;
    x = (x | (x >> 8)) & 0xff0000ff;
    x = (x | (x >> 16)) & 0x0000ffff;
    return x;
}

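// __morton3D above interleaves the bits of three 10-bit coordinates into a single
// 30-bit Z-order (Morton) index, so spatially adjacent grid cells map to nearby
// indices; __morton3D_invert recovers one coordinate by keeping every third bit.
// A small worked example (values chosen for illustration only):
//   __morton3D(1, 0, 0) == 1,  __morton3D(0, 1, 0) == 2,  __morton3D(0, 0, 1) == 4
//   __morton3D_invert(__morton3D(x, y, z) >> 1) == y   for 10-bit y
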
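// kernel_near_far_from_aabb: one thread per ray. rays_o / rays_d are flattened
// [N, 3] arrays, aabb holds [xmin, ymin, zmin, xmax, ymax, zmax], and the kernel
// writes each ray's entry/exit distances into nears / fars [N] via a standard
// slab test, clamping the entry distance to min_near. Rays that miss the box get
// std::numeric_limits<scalar_t>::max() in both outputs.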
template <typename scalar_t>
__global__ void kernel_near_far_from_aabb(
    const scalar_t * __restrict__ rays_o,
    const scalar_t * __restrict__ rays_d,
    const scalar_t * __restrict__ aabb,
    const uint32_t N,
    const float min_near,
    scalar_t * nears, scalar_t * fars
) {
    const uint32_t n = threadIdx.x + blockIdx.x * blockDim.x;
    if (n >= N) return;

    rays_o += n * 3;
    rays_d += n * 3;

    const float ox = rays_o[0], oy = rays_o[1], oz = rays_o[2];
    const float dx = rays_d[0], dy = rays_d[1], dz = rays_d[2];
    const float rdx = 1 / dx, rdy = 1 / dy, rdz = 1 / dz;

    float near = (aabb[0] - ox) * rdx;
    float far = (aabb[3] - ox) * rdx;
    if (near > far) swapf(near, far);

    float near_y = (aabb[1] - oy) * rdy;
    float far_y = (aabb[4] - oy) * rdy;
    if (near_y > far_y) swapf(near_y, far_y);

    if (near > far_y || near_y > far) {
        nears[n] = fars[n] = std::numeric_limits<scalar_t>::max();
        return;
    }

    if (near_y > near) near = near_y;
    if (far_y < far) far = far_y;

    float near_z = (aabb[2] - oz) * rdz;
    float far_z = (aabb[5] - oz) * rdz;
    if (near_z > far_z) swapf(near_z, far_z);

    if (near > far_z || near_z > far) {
        nears[n] = fars[n] = std::numeric_limits<scalar_t>::max();
        return;
    }

    if (near_z > near) near = near_z;
    if (far_z < far) far = far_z;

    if (near < min_near) near = min_near;

    nears[n] = near;
    fars[n] = far;
}

void near_far_from_aabb(const at::Tensor rays_o, const at::Tensor rays_d, const at::Tensor aabb, const uint32_t N, const float min_near, at::Tensor nears, at::Tensor fars) {

    static constexpr uint32_t N_THREAD = 128;

    AT_DISPATCH_FLOATING_TYPES_AND_HALF(
    rays_o.scalar_type(), "near_far_from_aabb", ([&] {
        kernel_near_far_from_aabb<<<div_round_up(N, N_THREAD), N_THREAD>>>(rays_o.data_ptr<scalar_t>(), rays_d.data_ptr<scalar_t>(), aabb.data_ptr<scalar_t>(), N, min_near, nears.data_ptr<scalar_t>(), fars.data_ptr<scalar_t>());
    }));
}

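// kernel_sph_from_ray: intersects each ray with a sphere of the given radius
// centered at the origin (taking the far root of the quadratic), then converts
// the hit point to spherical coordinates. coords is a flattened [N, 2] array of
// (theta, phi) normalized to [-1, 1].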
template <typename scalar_t>
__global__ void kernel_sph_from_ray(
    const scalar_t * __restrict__ rays_o,
    const scalar_t * __restrict__ rays_d,
    const float radius,
    const uint32_t N,
    scalar_t * coords
) {
    const uint32_t n = threadIdx.x + blockIdx.x * blockDim.x;
    if (n >= N) return;

    rays_o += n * 3;
    rays_d += n * 3;
    coords += n * 2;

    const float ox = rays_o[0], oy = rays_o[1], oz = rays_o[2];
    const float dx = rays_d[0], dy = rays_d[1], dz = rays_d[2];

    // solve |o + t * d|^2 = radius^2 and take the far root
    const float A = dx * dx + dy * dy + dz * dz;
    const float B = ox * dx + oy * dy + oz * dz;
    const float C = ox * ox + oy * oy + oz * oz - radius * radius;

    const float t = (- B + sqrtf(B * B - A * C)) / A;

    const float x = ox + t * dx, y = oy + t * dy, z = oz + t * dz;
    const float theta = atan2f(sqrtf(x * x + z * z), y);
    const float phi = atan2f(z, x);

    coords[0] = 2 * theta * RPI() - 1;
    coords[1] = phi * RPI();
}

void sph_from_ray(const at::Tensor rays_o, const at::Tensor rays_d, const float radius, const uint32_t N, at::Tensor coords) {

    static constexpr uint32_t N_THREAD = 128;

    AT_DISPATCH_FLOATING_TYPES_AND_HALF(
    rays_o.scalar_type(), "sph_from_ray", ([&] {
        kernel_sph_from_ray<<<div_round_up(N, N_THREAD), N_THREAD>>>(rays_o.data_ptr<scalar_t>(), rays_d.data_ptr<scalar_t>(), radius, N, coords.data_ptr<scalar_t>());
    }));
}

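// kernel_morton3D / kernel_morton3D_invert: convert between integer grid
// coordinates ([N, 3] int tensors) and their Morton (Z-order) indices ([N] int
// tensor), one thread per entry.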
__global__ void kernel_morton3D(
    const int * __restrict__ coords,
    const uint32_t N,
    int * indices
) {
    const uint32_t n = threadIdx.x + blockIdx.x * blockDim.x;
    if (n >= N) return;

    coords += n * 3;
    indices[n] = __morton3D(coords[0], coords[1], coords[2]);
}


void morton3D(const at::Tensor coords, const uint32_t N, at::Tensor indices) {
    static constexpr uint32_t N_THREAD = 128;
    kernel_morton3D<<<div_round_up(N, N_THREAD), N_THREAD>>>(coords.data_ptr<int>(), N, indices.data_ptr<int>());
}

__global__ void kernel_morton3D_invert(
    const int * __restrict__ indices,
    const uint32_t N,
    int * coords
) {
    const uint32_t n = threadIdx.x + blockIdx.x * blockDim.x;
    if (n >= N) return;

    coords += n * 3;

    const int ind = indices[n];

    coords[0] = __morton3D_invert(ind >> 0);
    coords[1] = __morton3D_invert(ind >> 1);
    coords[2] = __morton3D_invert(ind >> 2);
}


void morton3D_invert(const at::Tensor indices, const uint32_t N, at::Tensor coords) {
    static constexpr uint32_t N_THREAD = 128;
    kernel_morton3D_invert<<<div_round_up(N, N_THREAD), N_THREAD>>>(indices.data_ptr<int>(), N, coords.data_ptr<int>());
}

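// kernel_packbits: compresses the density grid into a bitfield. Each thread
// reads 8 consecutive grid cells and packs "density > density_thresh" into one
// uint8, so the occupancy grid can be queried with a single byte load during
// ray marching.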
template <typename scalar_t>
__global__ void kernel_packbits(
    const scalar_t * __restrict__ grid,
    const uint32_t N,
    const float density_thresh,
    uint8_t * bitfield
) {
    const uint32_t n = threadIdx.x + blockIdx.x * blockDim.x;
    if (n >= N) return;

    grid += n * 8;

    uint8_t bits = 0;

    #pragma unroll
    for (uint8_t i = 0; i < 8; i++) {
        bits |= (grid[i] > density_thresh) ? ((uint8_t)1 << i) : 0;
    }

    bitfield[n] = bits;
}


void packbits(const at::Tensor grid, const uint32_t N, const float density_thresh, at::Tensor bitfield) {

    static constexpr uint32_t N_THREAD = 128;

    AT_DISPATCH_FLOATING_TYPES_AND_HALF(
    grid.scalar_type(), "packbits", ([&] {
        kernel_packbits<<<div_round_up(N, N_THREAD), N_THREAD>>>(grid.data_ptr<scalar_t>(), N, density_thresh, bitfield.data_ptr<uint8_t>());
    }));
}

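// kernel_march_rays_train: training-time ray marching. A first pass walks each
// ray through the multi-resolution occupancy bitfield (C cascades of an H^3 grid)
// and counts how many occupied samples it will produce; the per-ray sample budget
// is then reserved with atomicAdd on counter (counter[0] = total points,
// counter[1] = total rays). A second pass re-walks the ray and writes the sample
// positions (xyzs, [M, 3]), directions (dirs, [M, 3]) and step sizes (deltas,
// [M, 2]: dt and the t advance since the previous sample). rays stores, per ray,
// (ray index, first point offset, num_steps). noises holds a per-ray jitter for
// the start offset; dt_gamma grows the step size with distance, clamped to
// [dt_min, dt_max].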
template <typename scalar_t>
__global__ void kernel_march_rays_train(
    const scalar_t * __restrict__ rays_o,
    const scalar_t * __restrict__ rays_d,
    const uint8_t * __restrict__ grid,
    const float bound,
    const float dt_gamma, const uint32_t max_steps,
    const uint32_t N, const uint32_t C, const uint32_t H, const uint32_t M,
    const scalar_t* __restrict__ nears,
    const scalar_t* __restrict__ fars,
    scalar_t * xyzs, scalar_t * dirs, scalar_t * deltas,
    int * rays,
    int * counter,
    const scalar_t* __restrict__ noises
) {
    const uint32_t n = threadIdx.x + blockIdx.x * blockDim.x;
    if (n >= N) return;

    rays_o += n * 3;
    rays_d += n * 3;

    const float ox = rays_o[0], oy = rays_o[1], oz = rays_o[2];
    const float dx = rays_d[0], dy = rays_d[1], dz = rays_d[2];
    const float rdx = 1 / dx, rdy = 1 / dy, rdz = 1 / dz;
    const float rH = 1 / (float)H;
    const float H3 = H * H * H;

    const float near = nears[n];
    const float far = fars[n];
    const float noise = noises[n];

    const float dt_min = 2 * SQRT3() / max_steps;
    const float dt_max = 2 * SQRT3() * (1 << (C - 1)) / H;

    float t0 = near;

    // jitter the starting point by a fraction of the first step
    t0 += clamp(t0 * dt_gamma, dt_min, dt_max) * noise;

    // first pass: count the number of occupied samples along this ray
    float t = t0;
    uint32_t num_steps = 0;

    while (t < far && num_steps < max_steps) {

        const float x = clamp(ox + t * dx, -bound, bound);
        const float y = clamp(oy + t * dy, -bound, bound);
        const float z = clamp(oz + t * dz, -bound, bound);

        const float dt = clamp(t * dt_gamma, dt_min, dt_max);

        const int level = max(mip_from_pos(x, y, z, C), mip_from_dt(dt, H, C));

        const float mip_bound = fminf(scalbnf(1.0f, level), bound);
        const float mip_rbound = 1 / mip_bound;

        const int nx = clamp(0.5 * (x * mip_rbound + 1) * H, 0.0f, (float)(H - 1));
        const int ny = clamp(0.5 * (y * mip_rbound + 1) * H, 0.0f, (float)(H - 1));
        const int nz = clamp(0.5 * (z * mip_rbound + 1) * H, 0.0f, (float)(H - 1));

        const uint32_t index = level * H3 + __morton3D(nx, ny, nz);
        const bool occ = grid[index / 8] & (1 << (index % 8));

        if (occ) {
            num_steps++;
            t += dt;

        } else {
            // empty cell: skip to the next cell boundary along the ray
            const float tx = (((nx + 0.5f + 0.5f * signf(dx)) * rH * 2 - 1) * mip_bound - x) * rdx;
            const float ty = (((ny + 0.5f + 0.5f * signf(dy)) * rH * 2 - 1) * mip_bound - y) * rdy;
            const float tz = (((nz + 0.5f + 0.5f * signf(dz)) * rH * 2 - 1) * mip_bound - z) * rdz;

            const float tt = t + fmaxf(0.0f, fminf(tx, fminf(ty, tz)));

            do {
                t += clamp(t * dt_gamma, dt_min, dt_max);
            } while (t < tt);
        }
    }

    // reserve output space for this ray
    uint32_t point_index = atomicAdd(counter, num_steps);
    uint32_t ray_index = atomicAdd(counter + 1, 1);

    rays[ray_index * 3] = n;
    rays[ray_index * 3 + 1] = point_index;
    rays[ray_index * 3 + 2] = num_steps;

    if (num_steps == 0) return;
    if (point_index + num_steps > M) return;

    xyzs += point_index * 3;
    dirs += point_index * 3;
    deltas += point_index * 2;

    // second pass: re-walk the ray and write out the samples
    t = t0;
    uint32_t step = 0;

    float last_t = t;

    while (t < far && step < num_steps) {

        const float x = clamp(ox + t * dx, -bound, bound);
        const float y = clamp(oy + t * dy, -bound, bound);
        const float z = clamp(oz + t * dz, -bound, bound);

        const float dt = clamp(t * dt_gamma, dt_min, dt_max);

        const int level = max(mip_from_pos(x, y, z, C), mip_from_dt(dt, H, C));

        const float mip_bound = fminf(scalbnf(1.0f, level), bound);
        const float mip_rbound = 1 / mip_bound;

        const int nx = clamp(0.5 * (x * mip_rbound + 1) * H, 0.0f, (float)(H - 1));
        const int ny = clamp(0.5 * (y * mip_rbound + 1) * H, 0.0f, (float)(H - 1));
        const int nz = clamp(0.5 * (z * mip_rbound + 1) * H, 0.0f, (float)(H - 1));

        const uint32_t index = level * H3 + __morton3D(nx, ny, nz);
        const bool occ = grid[index / 8] & (1 << (index % 8));

        if (occ) {
            xyzs[0] = x;
            xyzs[1] = y;
            xyzs[2] = z;
            dirs[0] = dx;
            dirs[1] = dy;
            dirs[2] = dz;
            t += dt;
            deltas[0] = dt;
            deltas[1] = t - last_t;
            last_t = t;
            xyzs += 3;
            dirs += 3;
            deltas += 2;
            step++;

        } else {
            const float tx = (((nx + 0.5f + 0.5f * signf(dx)) * rH * 2 - 1) * mip_bound - x) * rdx;
            const float ty = (((ny + 0.5f + 0.5f * signf(dy)) * rH * 2 - 1) * mip_bound - y) * rdy;
            const float tz = (((nz + 0.5f + 0.5f * signf(dz)) * rH * 2 - 1) * mip_bound - z) * rdz;
            const float tt = t + fmaxf(0.0f, fminf(tx, fminf(ty, tz)));

            do {
                t += clamp(t * dt_gamma, dt_min, dt_max);
            } while (t < tt);
        }
    }
}

void march_rays_train(const at::Tensor rays_o, const at::Tensor rays_d, const at::Tensor grid, const float bound, const float dt_gamma, const uint32_t max_steps, const uint32_t N, const uint32_t C, const uint32_t H, const uint32_t M, const at::Tensor nears, const at::Tensor fars, at::Tensor xyzs, at::Tensor dirs, at::Tensor deltas, at::Tensor rays, at::Tensor counter, at::Tensor noises) {

    static constexpr uint32_t N_THREAD = 128;

    AT_DISPATCH_FLOATING_TYPES_AND_HALF(
    rays_o.scalar_type(), "march_rays_train", ([&] {
        kernel_march_rays_train<<<div_round_up(N, N_THREAD), N_THREAD>>>(rays_o.data_ptr<scalar_t>(), rays_d.data_ptr<scalar_t>(), grid.data_ptr<uint8_t>(), bound, dt_gamma, max_steps, N, C, H, M, nears.data_ptr<scalar_t>(), fars.data_ptr<scalar_t>(), xyzs.data_ptr<scalar_t>(), dirs.data_ptr<scalar_t>(), deltas.data_ptr<scalar_t>(), rays.data_ptr<int>(), counter.data_ptr<int>(), noises.data_ptr<scalar_t>());
    }));
}

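// kernel_composite_rays_train_forward: volume rendering forward pass over the
// packed training samples. For each ray it accumulates color, depth and the sum
// of weights front-to-back with alpha = 1 - exp(-sigma * dt) and transmittance T,
// stopping early once T drops below T_thresh.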
template <typename scalar_t>
__global__ void kernel_composite_rays_train_forward(
    const scalar_t * __restrict__ sigmas,
    const scalar_t * __restrict__ rgbs,
    const scalar_t * __restrict__ deltas,
    const int * __restrict__ rays,
    const uint32_t M, const uint32_t N, const float T_thresh,
    scalar_t * weights_sum,
    scalar_t * depth,
    scalar_t * image
) {
    const uint32_t n = threadIdx.x + blockIdx.x * blockDim.x;
    if (n >= N) return;

    uint32_t index = rays[n * 3];
    uint32_t offset = rays[n * 3 + 1];
    uint32_t num_steps = rays[n * 3 + 2];

    if (num_steps == 0 || offset + num_steps > M) {
        weights_sum[index] = 0;
        depth[index] = 0;
        image[index * 3] = 0;
        image[index * 3 + 1] = 0;
        image[index * 3 + 2] = 0;
        return;
    }

    sigmas += offset;
    rgbs += offset * 3;
    deltas += offset * 2;

    uint32_t step = 0;

    scalar_t T = 1.0f;
    scalar_t r = 0, g = 0, b = 0, ws = 0, t = 0, d = 0;

    while (step < num_steps) {

        const scalar_t alpha = 1.0f - __expf(- sigmas[0] * deltas[0]);
        const scalar_t weight = alpha * T;

        r += weight * rgbs[0];
        g += weight * rgbs[1];
        b += weight * rgbs[2];

        t += deltas[1];
        d += weight * t;

        ws += weight;

        T *= 1.0f - alpha;

        if (T < T_thresh) break;

        sigmas++;
        rgbs += 3;
        deltas += 2;

        step++;
    }

    weights_sum[index] = ws;
    depth[index] = d;
    image[index * 3] = r;
    image[index * 3 + 1] = g;
    image[index * 3 + 2] = b;
}

void composite_rays_train_forward(const at::Tensor sigmas, const at::Tensor rgbs, const at::Tensor deltas, const at::Tensor rays, const uint32_t M, const uint32_t N, const float T_thresh, at::Tensor weights_sum, at::Tensor depth, at::Tensor image) {

    static constexpr uint32_t N_THREAD = 128;

    AT_DISPATCH_FLOATING_TYPES_AND_HALF(
    sigmas.scalar_type(), "composite_rays_train_forward", ([&] {
        kernel_composite_rays_train_forward<<<div_round_up(N, N_THREAD), N_THREAD>>>(sigmas.data_ptr<scalar_t>(), rgbs.data_ptr<scalar_t>(), deltas.data_ptr<scalar_t>(), rays.data_ptr<int>(), M, N, T_thresh, weights_sum.data_ptr<scalar_t>(), depth.data_ptr<scalar_t>(), image.data_ptr<scalar_t>());
    }));
}

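// kernel_composite_rays_train_backward: backward pass of the compositing above.
// It replays the forward accumulation to recover the running transmittance T and
// the partial color / weight sums, then writes grad_rgbs = grad_image * weight
// and accumulates grad_sigmas from the replayed quantities (the contribution of
// samples behind the current one enters with a negative sign, scaled by delta).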
template <typename scalar_t>
__global__ void kernel_composite_rays_train_backward(
    const scalar_t * __restrict__ grad_weights_sum,
    const scalar_t * __restrict__ grad_image,
    const scalar_t * __restrict__ sigmas,
    const scalar_t * __restrict__ rgbs,
    const scalar_t * __restrict__ deltas,
    const int * __restrict__ rays,
    const scalar_t * __restrict__ weights_sum,
    const scalar_t * __restrict__ image,
    const uint32_t M, const uint32_t N, const float T_thresh,
    scalar_t * grad_sigmas,
    scalar_t * grad_rgbs
) {
    const uint32_t n = threadIdx.x + blockIdx.x * blockDim.x;
    if (n >= N) return;

    uint32_t index = rays[n * 3];
    uint32_t offset = rays[n * 3 + 1];
    uint32_t num_steps = rays[n * 3 + 2];

    if (num_steps == 0 || offset + num_steps > M) return;

    grad_weights_sum += index;
    grad_image += index * 3;
    weights_sum += index;
    image += index * 3;
    sigmas += offset;
    rgbs += offset * 3;
    deltas += offset * 2;
    grad_sigmas += offset;
    grad_rgbs += offset * 3;

    uint32_t step = 0;

    scalar_t T = 1.0f;
    const scalar_t r_final = image[0], g_final = image[1], b_final = image[2], ws_final = weights_sum[0];
    scalar_t r = 0, g = 0, b = 0, ws = 0;

    while (step < num_steps) {

        const scalar_t alpha = 1.0f - __expf(- sigmas[0] * deltas[0]);
        const scalar_t weight = alpha * T;

        r += weight * rgbs[0];
        g += weight * rgbs[1];
        b += weight * rgbs[2];
        ws += weight;

        T *= 1.0f - alpha;

        grad_rgbs[0] = grad_image[0] * weight;
        grad_rgbs[1] = grad_image[1] * weight;
        grad_rgbs[2] = grad_image[2] * weight;

        grad_sigmas[0] = deltas[0] * (
            grad_image[0] * (T * rgbs[0] - (r_final - r)) +
            grad_image[1] * (T * rgbs[1] - (g_final - g)) +
            grad_image[2] * (T * rgbs[2] - (b_final - b)) +
            grad_weights_sum[0] * (1 - ws_final)
        );

        if (T < T_thresh) break;

        sigmas++;
        rgbs += 3;
        deltas += 2;
        grad_sigmas++;
        grad_rgbs += 3;

        step++;
    }
}

void composite_rays_train_backward(const at::Tensor grad_weights_sum, const at::Tensor grad_image, const at::Tensor sigmas, const at::Tensor rgbs, const at::Tensor deltas, const at::Tensor rays, const at::Tensor weights_sum, const at::Tensor image, const uint32_t M, const uint32_t N, const float T_thresh, at::Tensor grad_sigmas, at::Tensor grad_rgbs) {

    static constexpr uint32_t N_THREAD = 128;

    AT_DISPATCH_FLOATING_TYPES_AND_HALF(
    grad_image.scalar_type(), "composite_rays_train_backward", ([&] {
        kernel_composite_rays_train_backward<<<div_round_up(N, N_THREAD), N_THREAD>>>(grad_weights_sum.data_ptr<scalar_t>(), grad_image.data_ptr<scalar_t>(), sigmas.data_ptr<scalar_t>(), rgbs.data_ptr<scalar_t>(), deltas.data_ptr<scalar_t>(), rays.data_ptr<int>(), weights_sum.data_ptr<scalar_t>(), image.data_ptr<scalar_t>(), M, N, T_thresh, grad_sigmas.data_ptr<scalar_t>(), grad_rgbs.data_ptr<scalar_t>());
    }));
}

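// kernel_march_rays: inference-time ray marching. Only the rays still alive
// (indices in rays_alive, current distances in rays_t) are advanced; each one
// writes up to n_step new samples into its own slice of xyzs / dirs / deltas,
// using the same occupancy-grid traversal and empty-space skipping as the
// training kernel.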
template <typename scalar_t>
__global__ void kernel_march_rays(
    const uint32_t n_alive,
    const uint32_t n_step,
    const int* __restrict__ rays_alive,
    const scalar_t* __restrict__ rays_t,
    const scalar_t* __restrict__ rays_o,
    const scalar_t* __restrict__ rays_d,
    const float bound,
    const float dt_gamma, const uint32_t max_steps,
    const uint32_t C, const uint32_t H,
    const uint8_t * __restrict__ grid,
    const scalar_t* __restrict__ nears,
    const scalar_t* __restrict__ fars,
    scalar_t* xyzs, scalar_t* dirs, scalar_t* deltas,
    const scalar_t* __restrict__ noises
) {
    const uint32_t n = threadIdx.x + blockIdx.x * blockDim.x;
    if (n >= n_alive) return;

    const int index = rays_alive[n];
    const float noise = noises[n];

    rays_o += index * 3;
    rays_d += index * 3;
    xyzs += n * n_step * 3;
    dirs += n * n_step * 3;
    deltas += n * n_step * 2;

    const float ox = rays_o[0], oy = rays_o[1], oz = rays_o[2];
    const float dx = rays_d[0], dy = rays_d[1], dz = rays_d[2];
    const float rdx = 1 / dx, rdy = 1 / dy, rdz = 1 / dz;
    const float rH = 1 / (float)H;
    const float H3 = H * H * H;

    float t = rays_t[index];
    const float near = nears[index], far = fars[index];

    const float dt_min = 2 * SQRT3() / max_steps;
    const float dt_max = 2 * SQRT3() * (1 << (C - 1)) / H;

    uint32_t step = 0;

    t += clamp(t * dt_gamma, dt_min, dt_max) * noise;

    float last_t = t;

    while (t < far && step < n_step) {

        const float x = clamp(ox + t * dx, -bound, bound);
        const float y = clamp(oy + t * dy, -bound, bound);
        const float z = clamp(oz + t * dz, -bound, bound);

        const float dt = clamp(t * dt_gamma, dt_min, dt_max);

        const int level = max(mip_from_pos(x, y, z, C), mip_from_dt(dt, H, C));

        const float mip_bound = fminf(scalbnf(1.0f, level), bound);
        const float mip_rbound = 1 / mip_bound;

        const int nx = clamp(0.5 * (x * mip_rbound + 1) * H, 0.0f, (float)(H - 1));
        const int ny = clamp(0.5 * (y * mip_rbound + 1) * H, 0.0f, (float)(H - 1));
        const int nz = clamp(0.5 * (z * mip_rbound + 1) * H, 0.0f, (float)(H - 1));

        const uint32_t index = level * H3 + __morton3D(nx, ny, nz);
        const bool occ = grid[index / 8] & (1 << (index % 8));

        if (occ) {

            xyzs[0] = x;
            xyzs[1] = y;
            xyzs[2] = z;
            dirs[0] = dx;
            dirs[1] = dy;
            dirs[2] = dz;

            t += dt;
            deltas[0] = dt;
            deltas[1] = t - last_t;
            last_t = t;

            xyzs += 3;
            dirs += 3;
            deltas += 2;
            step++;

        } else {

            const float tx = (((nx + 0.5f + 0.5f * signf(dx)) * rH * 2 - 1) * mip_bound - x) * rdx;
            const float ty = (((ny + 0.5f + 0.5f * signf(dy)) * rH * 2 - 1) * mip_bound - y) * rdy;
            const float tz = (((nz + 0.5f + 0.5f * signf(dz)) * rH * 2 - 1) * mip_bound - z) * rdz;
            const float tt = t + fmaxf(0.0f, fminf(tx, fminf(ty, tz)));

            do {
                t += clamp(t * dt_gamma, dt_min, dt_max);
            } while (t < tt);
        }
    }
}

void march_rays(const uint32_t n_alive, const uint32_t n_step, const at::Tensor rays_alive, const at::Tensor rays_t, const at::Tensor rays_o, const at::Tensor rays_d, const float bound, const float dt_gamma, const uint32_t max_steps, const uint32_t C, const uint32_t H, const at::Tensor grid, const at::Tensor near, const at::Tensor far, at::Tensor xyzs, at::Tensor dirs, at::Tensor deltas, at::Tensor noises) {
    static constexpr uint32_t N_THREAD = 128;

    AT_DISPATCH_FLOATING_TYPES_AND_HALF(
    rays_o.scalar_type(), "march_rays", ([&] {
        kernel_march_rays<<<div_round_up(n_alive, N_THREAD), N_THREAD>>>(n_alive, n_step, rays_alive.data_ptr<int>(), rays_t.data_ptr<scalar_t>(), rays_o.data_ptr<scalar_t>(), rays_d.data_ptr<scalar_t>(), bound, dt_gamma, max_steps, C, H, grid.data_ptr<uint8_t>(), near.data_ptr<scalar_t>(), far.data_ptr<scalar_t>(), xyzs.data_ptr<scalar_t>(), dirs.data_ptr<scalar_t>(), deltas.data_ptr<scalar_t>(), noises.data_ptr<scalar_t>());
    }));
}

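// kernel_composite_rays: inference-time compositing. Each alive ray folds its
// freshly marched samples into the running weights_sum / depth / image buffers;
// a ray is terminated (rays_alive[n] = -1) when it runs out of samples
// (deltas[0] == 0) or its transmittance falls below T_thresh, otherwise its
// current distance rays_t is advanced.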
template <typename scalar_t>
__global__ void kernel_composite_rays(
    const uint32_t n_alive,
    const uint32_t n_step,
    const float T_thresh,
    int* rays_alive,
    scalar_t* rays_t,
    const scalar_t* __restrict__ sigmas,
    const scalar_t* __restrict__ rgbs,
    const scalar_t* __restrict__ deltas,
    scalar_t* weights_sum, scalar_t* depth, scalar_t* image
) {
    const uint32_t n = threadIdx.x + blockIdx.x * blockDim.x;
    if (n >= n_alive) return;

    const int index = rays_alive[n];

    sigmas += n * n_step;
    rgbs += n * n_step * 3;
    deltas += n * n_step * 2;

    rays_t += index;
    weights_sum += index;
    depth += index;
    image += index * 3;

    scalar_t t = rays_t[0];

    scalar_t weight_sum = weights_sum[0];
    scalar_t d = depth[0];
    scalar_t r = image[0];
    scalar_t g = image[1];
    scalar_t b = image[2];

    uint32_t step = 0;
    while (step < n_step) {

        if (deltas[0] == 0) break;

        const scalar_t alpha = 1.0f - __expf(- sigmas[0] * deltas[0]);

        const scalar_t T = 1 - weight_sum;
        const scalar_t weight = alpha * T;
        weight_sum += weight;

        t += deltas[1];
        d += weight * t;
        r += weight * rgbs[0];
        g += weight * rgbs[1];
        b += weight * rgbs[2];

        if (T < T_thresh) break;

        sigmas++;
        rgbs += 3;
        deltas += 2;
        step++;
    }

    if (step < n_step) {
        rays_alive[n] = -1;
    } else {
        rays_t[0] = t;
    }

    weights_sum[0] = weight_sum;
    depth[0] = d;
    image[0] = r;
    image[1] = g;
    image[2] = b;
}

void composite_rays(const uint32_t n_alive, const uint32_t n_step, const float T_thresh, at::Tensor rays_alive, at::Tensor rays_t, at::Tensor sigmas, at::Tensor rgbs, at::Tensor deltas, at::Tensor weights, at::Tensor depth, at::Tensor image) {
    static constexpr uint32_t N_THREAD = 128;
    AT_DISPATCH_FLOATING_TYPES_AND_HALF(
    image.scalar_type(), "composite_rays", ([&] {
        kernel_composite_rays<<<div_round_up(n_alive, N_THREAD), N_THREAD>>>(n_alive, n_step, T_thresh, rays_alive.data_ptr<int>(), rays_t.data_ptr<scalar_t>(), sigmas.data_ptr<scalar_t>(), rgbs.data_ptr<scalar_t>(), deltas.data_ptr<scalar_t>(), weights.data_ptr<scalar_t>(), depth.data_ptr<scalar_t>(), image.data_ptr<scalar_t>());
    }));
}