Skip to content

Commit 53e12d2

Browse files
author
Prapti Devansh Trivedi
committed
add rough idea for kernel
1 parent b7b08ba commit 53e12d2

File tree

5 files changed

+260
-39
lines changed

5 files changed

+260
-39
lines changed

cuda_rasterizer/rasterizer_impl.cu

Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -424,6 +424,101 @@ int CudaRasterizer::Rasterizer::preprocessForward(
424424
return num_rendered;
425425
}
426426

427+
428+
int CudaRasterizer::Rasterizer::preprocessForwardBatches(
429+
float2* means2D,
430+
float* depths,
431+
int* radii,
432+
float* cov3D,
433+
float4* conic_opacity,
434+
float* rgb,
435+
bool* clamped,//the above are all per-Gaussian intemediate results.
436+
const int P, int D, int M,
437+
const std::vector<int>& width, std::vector<int>& height,
438+
const float* means3D,
439+
const float* scales,
440+
const float* rotations,
441+
const float* shs,
442+
const float* opacities,//3dgs parameters
443+
const std::vector<float>& scale_modifier,
444+
const std::vector<torch::Tensor>& viewmatrix,
445+
const std::vector<torch::Tensor>& projmatrix,
446+
const std::vector<float>& cam_pos,
447+
const std::vector<float>& tan_fovx, std::vector<float>& tan_fovy,
448+
const std::vector<bool>& prefiltered,
449+
std::vector<bool>& debug,//raster_settings
450+
const std::vector<pybind11::dict> &args)
451+
{
452+
auto [global_rank, world_size, iteration, log_interval, device, zhx_debug, zhx_time, mode, dist_division_mode, log_folder] = prepareArgs(args);
453+
char* log_tmp = new char[500];
454+
455+
// print out the environment variables
456+
if (mode == "train" && zhx_debug && iteration % log_interval == 1) {
457+
sprintf(log_tmp, "world_size: %d, global_rank: %d, iteration: %d, log_folder: %s, zhx_debug: %d, zhx_time: %d, device: %d, log_interval: %d, dist_division_mode: %s",
458+
world_size, global_rank, iteration, log_folder.c_str(), zhx_debug, zhx_time, device, log_interval, dist_division_mode.c_str());
459+
save_log_in_file(iteration, global_rank, world_size, log_folder, "cuda", log_tmp);
460+
}
461+
462+
MyTimerOnGPU timer;
463+
// const float focal_y = height / (2.0f * tan_fovy);
464+
// const float focal_x = width / (2.0f * tan_fovx);
465+
const int num_viewpoints=viewmatrix.size();
466+
467+
//CONVERT ALL VECTORS TO FLOATSSSSSS PRAPTIIIIIII
468+
469+
dim3 tile_grid((width + BLOCK_X - 1) / BLOCK_X, (height + BLOCK_Y - 1) / BLOCK_Y, num_viewpoints);
470+
dim3 block(BLOCK_X, BLOCK_Y, num_viewpoints);
471+
int tile_num = tile_grid.x * tile_grid.y*tile_grid.z;
472+
473+
// allocate temporary buffer for tiles_touched.
474+
// In sep_rendering==True case, we will compute tiles_touched in the renderForward.
475+
// TODO: remove it later by modifying FORWARD::preprocess when we deprecate sep_rendering==False case
476+
uint32_t* tiles_touched_temp_buffer;
477+
CHECK_CUDA(cudaMalloc(&tiles_touched_temp_buffer, P * sizeof(uint32_t)), debug);
478+
CHECK_CUDA(cudaMemset(tiles_touched_temp_buffer, 0, P * sizeof(uint32_t)), debug);
479+
480+
timer.start("10 preprocess");
481+
// Run preprocessing per-Gaussian (transformation, bounding, conversion of SHs to RGB)
482+
CHECK_CUDA(FORWARD::preprocess(
483+
P, D, M,
484+
means3D,
485+
(glm::vec3*)scales,
486+
scale_modifier,
487+
(glm::vec4*)rotations,
488+
opacities,
489+
shs,
490+
clamped,
491+
nullptr,//cov3D_precomp,
492+
nullptr,//colors_precomp,TODO: this is correct?
493+
viewmatrix, projmatrix,
494+
(glm::vec3*)cam_pos,
495+
width, height,
496+
focal_x, focal_y,
497+
tan_fovx, tan_fovy,
498+
radii,
499+
means2D,
500+
depths,
501+
cov3D,
502+
rgb,
503+
conic_opacity,
504+
tile_grid,
505+
tiles_touched_temp_buffer,
506+
prefiltered
507+
), debug)
508+
timer.stop("10 preprocess");
509+
510+
int num_rendered = 0;//TODO: should I calculate this here?
511+
512+
// Print out timing information
513+
if (zhx_time && iteration % log_interval == 1) {
514+
timer.printAllTimes(iteration, world_size, global_rank, log_folder, true);
515+
}
516+
delete log_tmp;
517+
// free temporary buffer for tiles_touched. TODO: remove it.
518+
CHECK_CUDA(cudaFree(tiles_touched_temp_buffer), debug);
519+
return num_rendered;
520+
}
521+
427522
void CudaRasterizer::Rasterizer::preprocessBackward(
428523
const int* radii,
429524
const float* cov3D,

diff_gaussian_rasterization/__init__.py

Lines changed: 81 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ def preprocess_gaussians(
3131
sh,
3232
opacities,
3333
raster_settings,
34-
cuda_args,
34+
cuda_args,flag_batched=False
3535
):
3636
return _PreprocessGaussians.apply(
3737
means3D,
@@ -40,7 +40,7 @@ def preprocess_gaussians(
4040
sh,
4141
opacities,
4242
raster_settings,
43-
cuda_args,
43+
cuda_args,flag_batched
4444
)
4545

4646
class _PreprocessGaussians(torch.autograd.Function):
@@ -52,45 +52,88 @@ def forward(
5252
rotations,
5353
sh,
5454
opacities,
55-
raster_settings,
56-
cuda_args,
55+
raster_settings_list,
56+
batched_cuda_args,flag_batched
5757
):
5858

5959
# Restructure arguments the way that the C++ lib expects them
60-
args = (
61-
means3D,
62-
scales,
63-
rotations,
64-
sh,
65-
opacities,# 3dgs' parametes.
66-
raster_settings.scale_modifier,
67-
raster_settings.viewmatrix,
68-
raster_settings.projmatrix,
69-
raster_settings.tanfovx,
70-
raster_settings.tanfovy,
71-
raster_settings.image_height,
72-
raster_settings.image_width,
73-
raster_settings.sh_degree,
74-
raster_settings.campos,
75-
raster_settings.prefiltered,
76-
raster_settings.debug,#raster_settings
77-
cuda_args
78-
)
79-
80-
# TODO: update this.
81-
num_rendered, means2D, depths, radii, cov3D, conic_opacity, rgb, clamped = _C.preprocess_gaussians(*args)
60+
if flag_batched==False:
61+
args = (
62+
means3D,
63+
scales,
64+
rotations,
65+
sh,
66+
opacities,# 3dgs' parametes.
67+
raster_settings.scale_modifier,
68+
raster_settings.viewmatrix,
69+
raster_settings.projmatrix,
70+
raster_settings.tanfovx,
71+
raster_settings.tanfovy,
72+
raster_settings.image_height,
73+
raster_settings.image_width,
74+
raster_settings.sh_degree,
75+
raster_settings.campos,
76+
raster_settings.prefiltered,
77+
raster_settings.debug,#raster_settings
78+
cuda_args
79+
)
80+
81+
# TODO: update this.
82+
num_rendered, means2D, depths, radii, cov3D, conic_opacity, rgb, clamped = _C.preprocess_gaussians(*args)
83+
84+
# Keep relevant tensors for backward
85+
ctx.raster_settings = raster_settings
86+
ctx.cuda_args = cuda_args
87+
ctx.num_rendered = num_rendered
88+
ctx.save_for_backward(means3D, scales, rotations, sh, means2D, depths, radii, cov3D, conic_opacity, rgb, clamped)
89+
ctx.mark_non_differentiable(radii, depths)
90+
91+
# # TODO: double check. means2D is padded to (P, 3) in python. It is (P, 2) in cuda code.
92+
# means2D_pad = torch.zeros((means2D.shape[0], 1), dtype = means2D.dtype, device = means2D.device)
93+
# means2D = torch.cat((means2D, means2D_pad), dim = 1).contiguous()
94+
return means2D, rgb, conic_opacity, radii, depths
95+
96+
else:
97+
args_list=[]
98+
for raster_settings,cuda_args in zip(raster_settings_list,batched_cuda_args):
99+
100+
args = (
101+
means3D,
102+
scales,
103+
rotations,
104+
sh,
105+
opacities,# 3dgs' parametes.
106+
raster_settings.scale_modifier,
107+
raster_settings.viewmatrix,
108+
raster_settings.projmatrix,
109+
raster_settings.tanfovx,
110+
raster_settings.tanfovy,
111+
raster_settings.image_height,
112+
raster_settings.image_width,
113+
raster_settings.sh_degree,
114+
raster_settings.campos,
115+
raster_settings.prefiltered,
116+
raster_settings.debug,#raster_settings
117+
cuda_args
118+
)
119+
args_list.append(args)
120+
121+
# TODO: update this.
122+
num_rendered, means2D, depths, radii, cov3D, conic_opacity, rgb, clamped = _C.preprocess_gaussians_batches(*args_list)
123+
124+
# Keep relevant tensors for backward
125+
ctx.raster_settings = raster_settings_list
126+
ctx.cuda_args = batched_cuda_args
127+
ctx.num_rendered = num_rendered
128+
ctx.save_for_backward(means3D, scales, rotations, sh, means2D, depths, radii, cov3D, conic_opacity, rgb, clamped)
129+
ctx.mark_non_differentiable(radii, depths)
130+
131+
# # TODO: double check. means2D is padded to (P, 3) in python. It is (P, 2) in cuda code.
132+
# means2D_pad = torch.zeros((means2D.shape[0], 1), dtype = means2D.dtype, device = means2D.device)
133+
# means2D = torch.cat((means2D, means2D_pad), dim = 1).contiguous()
134+
return means2D, rgb, conic_opacity, radii, depths
82135

83-
# Keep relevant tensors for backward
84-
ctx.raster_settings = raster_settings
85-
ctx.cuda_args = cuda_args
86-
ctx.num_rendered = num_rendered
87-
ctx.save_for_backward(means3D, scales, rotations, sh, means2D, depths, radii, cov3D, conic_opacity, rgb, clamped)
88-
ctx.mark_non_differentiable(radii, depths)
89136

90-
# # TODO: double check. means2D is padded to (P, 3) in python. It is (P, 2) in cuda code.
91-
# means2D_pad = torch.zeros((means2D.shape[0], 1), dtype = means2D.dtype, device = means2D.device)
92-
# means2D = torch.cat((means2D, means2D_pad), dim = 1).contiguous()
93-
return means2D, rgb, conic_opacity, radii, depths
94137

95138
@staticmethod # TODO: gradient for conic_opacity is tricky. because cuda render backward generate dL_dconic and dL_dopacity sperately.
96139
def backward(ctx, grad_means2D, grad_rgb, grad_conic_opacity, grad_radii, grad_depths):
@@ -320,14 +363,14 @@ def markVisible(self, positions):
320363
def preprocess_gaussians(self, means3D, scales, rotations, shs, opacities, batched_cuda_args=None):
321364
# Invoke C++/CUDA rasterization routine
322365

323-
return preprocess_gaussians_batches(
366+
return preprocess_gaussians(
324367
means3D,
325368
scales,
326369
rotations,
327370
shs,
328371
opacities,
329372
self.raster_settings_list,
330-
batched_cuda_args)
373+
batched_cuda_args,True)
331374

332375
class GaussianRasterizer(nn.Module):
333376
def __init__(self, raster_settings):

ext.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
1717
m.def("mark_visible", &markVisible);
1818
m.def("preprocess_gaussians", &PreprocessGaussiansCUDA);
19+
m.def("preprocess_gaussians_batched", &PreprocessGaussiansCUDABatches);
1920
m.def("preprocess_gaussians_backward", &PreprocessGaussiansBackwardCUDA);
2021
m.def("get_distribution_strategy", &GetDistributionStrategyCUDA);
2122
m.def("render_gaussians", &RenderGaussiansCUDA);

rasterization_tests.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -143,7 +143,7 @@ def test_improved_gaussian_rasterizer():
143143

144144
rasterizer=GaussianRasterizerBatches(raster_settings=raster_settings_list)
145145
start_time = time.time()
146-
batched_means2D, batched_rgb, batched_conic_opacity, batched_radii, batched_depths = rasterizer.preprocess_gaussians_batches(
146+
batched_means2D, batched_rgb, batched_conic_opacity, batched_radii, batched_depths = rasterizer.preprocess_gaussians(
147147
means3D=means3D,
148148
scales=scales,
149149
rotations=rotations,

rasterize_points.cu

Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -142,6 +142,88 @@ PreprocessGaussiansCUDA(
142142
return std::make_tuple(rendered, means2D, depths, radii, cov3D, conic_opacity, rgb, clamped);
143143
}
144144

145+
std::tuple<int, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>
146+
PreprocessGaussiansCUDABatches(
147+
const torch::Tensor& means3D,
148+
const torch::Tensor& scales,
149+
const torch::Tensor& rotations,
150+
const torch::Tensor& sh,
151+
const torch::Tensor& opacity,//3dgs' parametes.
152+
const std::vector<float>& scale_modifier,
153+
const std::vector<torch::Tensor>& viewmatrix,
154+
const std::vector<torch::Tensor>& projmatrix,
155+
const std::vector<float>& tan_fovx,
156+
const std::vector<float>& tan_fovy,
157+
const std::vector<int>& image_height,
158+
const std::vector<int>& image_width,
159+
const std::vector<int>& degree,
160+
const std::vector<torch::Tensor>& campos,
161+
const std::vector<bool>& prefiltered,//raster_settings
162+
const std::vector<bool>& debug,
163+
const std::vector<pybind11::dict> &args) {
164+
165+
if (means3D.ndimension() != 2 || means3D.size(1) != 3) {
166+
AT_ERROR("means3D must have dimensions (num_points, 3)");
167+
}
168+
169+
const int P = means3D.size(0);
170+
// const int H = image_height;
171+
// const int W = image_width;
172+
173+
// of shape (P, 2). means2D is (P, 2) in cuda. It will be converted to (P, 3) when is sent back to python to meet torch graph's requirement.
174+
torch::Tensor means2D = torch::full({P, 2}, 0.0, means3D.options());//TODO: what about require_grads?
175+
// of shape (P)
176+
torch::Tensor depths = torch::full({P}, 0.0, means3D.options());
177+
// of shape (P)
178+
torch::Tensor radii = torch::full({P}, 0, means3D.options().dtype(torch::kInt32));
179+
// of shape (P, 6)
180+
torch::Tensor cov3D = torch::full({P, 6}, 0.0, means3D.options());
181+
// of shape (P, 4)
182+
torch::Tensor conic_opacity = torch::full({P, 4}, 0.0, means3D.options());
183+
// of shape (P, 3)
184+
torch::Tensor rgb = torch::full({P, 3}, 0.0, means3D.options());
185+
// of shape (P)
186+
torch::Tensor clamped = torch::full({P, 3}, false, means3D.options().dtype(at::kBool));
187+
//TODO: compare to original GeometryState implementation, this one does not explicitly do gpu memory alignment.
188+
//That may lead to problems. However, pytorch does implicit memory alignment.
189+
190+
int rendered = 0;//TODO: I could compute rendered here by summing up geomState.tiles_touched.
191+
if(P != 0)
192+
{
193+
int M = 0;
194+
if(sh.size(0) != 0)
195+
{
196+
M = sh.size(1);
197+
}
198+
199+
rendered = CudaRasterizer::Rasterizer::preprocessForwardBatches(
200+
reinterpret_cast<float2*>(means2D.contiguous().data<float>()),//TODO: check whether it supports float2?
201+
depths.contiguous().data<float>(),
202+
radii.contiguous().data<int>(),
203+
cov3D.contiguous().data<float>(),
204+
reinterpret_cast<float4*>(conic_opacity.contiguous().data<float>()),
205+
rgb.contiguous().data<float>(),
206+
clamped.contiguous().data<bool>(),
207+
P, degree, M,
208+
image_width, image_height,
209+
means3D,
210+
scales,
211+
rotations,
212+
sh,
213+
opacity,
214+
scale_modifier,
215+
viewmatrix,
216+
projmatrix,
217+
campos,
218+
tan_fovx,
219+
tan_fovy,
220+
prefiltered,
221+
debug,
222+
args);
223+
}
224+
return std::make_tuple(rendered, means2D, depths, radii, cov3D, conic_opacity, rgb, clamped);
225+
}
226+
145227

146228
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>
147229
PreprocessGaussiansBackwardCUDA(

0 commit comments

Comments
 (0)