Skip to content

Commit b7b08ba

Browse files
author
Prapti Devansh Trivedi
committed
add mock of improved preproc
1 parent 6aed776 commit b7b08ba

File tree

2 files changed

+109
-0
lines changed

2 files changed

+109
-0
lines changed

diff_gaussian_rasterization/__init__.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -304,6 +304,31 @@ class GaussianRasterizationSettings(NamedTuple):
304304
prefiltered : bool
305305
debug : bool
306306

307+
class GaussianRasterizerBatches(nn.Module):
308+
def __init__(self, raster_settings):
309+
super().__init__()
310+
self.raster_settings_list = raster_settings
311+
312+
def markVisible(self, positions):
313+
# Mark visible points (based on frustum culling for camera) with a boolean
314+
with torch.no_grad():
315+
visible = []
316+
for viewmatrix, projmatrix in zip(self.raster_settings.viewmatrix, self.raster_settings.projmatrix):
317+
visible.append(_C.mark_visible(positions, viewmatrix, projmatrix))
318+
return visible
319+
320+
def preprocess_gaussians(self, means3D, scales, rotations, shs, opacities, batched_cuda_args=None):
321+
# Invoke C++/CUDA rasterization routine
322+
323+
return preprocess_gaussians_batches(
324+
means3D,
325+
scales,
326+
rotations,
327+
shs,
328+
opacities,
329+
self.raster_settings_list,
330+
batched_cuda_args)
331+
307332
class GaussianRasterizer(nn.Module):
308333
def __init__(self, raster_settings):
309334
super().__init__()

rasterization_tests.py

Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,88 @@ def test_gaussian_rasterizer_time():
7474
preprocess_time = end_time - start_time
7575
print(f"Time taken by preprocess_gaussians: {preprocess_time:.4f} seconds")
7676

77+
def test_improved_gaussian_rasterizer():
78+
79+
# Set up the input data
80+
num_gaussians = 10000
81+
num_batches = 4
82+
means3D = torch.randn(num_gaussians, 3).cuda()
83+
scales = torch.randn(num_gaussians, 3).cuda()
84+
rotations = torch.randn(num_gaussians, 3, 3).cuda()
85+
shs = torch.randn(num_gaussians, 9).cuda()
86+
opacity = torch.randn(num_gaussians, 1).cuda()
87+
88+
# Set up the viewpoint cameras
89+
batched_viewpoint_cameras = []
90+
for _ in range(num_batches):
91+
viewpoint_camera = type('ViewpointCamera', (), {})
92+
viewpoint_camera.FoVx = math.radians(60)
93+
viewpoint_camera.FoVy = math.radians(60)
94+
viewpoint_camera.image_height = 512
95+
viewpoint_camera.image_width = 512
96+
viewpoint_camera.world_view_transform = torch.eye(4).cuda()
97+
viewpoint_camera.full_proj_transform = torch.eye(4).cuda()
98+
viewpoint_camera.camera_center = torch.zeros(3).cuda()
99+
batched_viewpoint_cameras.append(viewpoint_camera)
100+
101+
# Set up the strategies
102+
batched_strategies = [None] * num_batches
103+
104+
# Set up other parameters
105+
bg_color = torch.ones(3).cuda()
106+
scaling_modifier = 1.0
107+
pc = type('PC', (), {})
108+
pc.active_sh_degree = 2
109+
pipe = type('Pipe', (), {})
110+
pipe.debug = False
111+
mode = "train"
112+
113+
batched_rasterizers = []
114+
batched_cuda_args = []
115+
batched_screenspace_params = []
116+
batched_means2D = []
117+
batched_radii = []
118+
raster_settings_list=[]
119+
for i, (viewpoint_camera, strategy) in enumerate(zip(batched_viewpoint_cameras, batched_strategies)):
120+
########## [START] Prepare CUDA Rasterization Settings ##########
121+
cuda_args = get_cuda_args(strategy, mode)
122+
batched_cuda_args.append(cuda_args)
123+
124+
# Set up rasterization configuration
125+
tanfovx = math.tan(viewpoint_camera.FoVx * 0.5)
126+
tanfovy = math.tan(viewpoint_camera.FoVy * 0.5)
127+
128+
raster_settings_list.append(GaussianRasterizationSettings(
129+
image_height=int(viewpoint_camera.image_height),
130+
image_width=int(viewpoint_camera.image_width),
131+
tanfovx=tanfovx,
132+
tanfovy=tanfovy,
133+
bg=bg_color,
134+
scale_modifier=scaling_modifier,
135+
viewmatrix=viewpoint_camera.world_view_transform,
136+
projmatrix=viewpoint_camera.full_proj_transform,
137+
sh_degree=pc.active_sh_degree,
138+
campos=viewpoint_camera.camera_center,
139+
prefiltered=False,
140+
debug=pipe.debug
141+
))
142+
143+
144+
rasterizer=GaussianRasterizerBatches(raster_settings=raster_settings_list)
145+
start_time = time.time()
146+
batched_means2D, batched_rgb, batched_conic_opacity, batched_radii, batched_depths = rasterizer.preprocess_gaussians_batches(
147+
means3D=means3D,
148+
scales=scales,
149+
rotations=rotations,
150+
shs=shs,
151+
opacities=opacity,
152+
cuda_args=batched_cuda_args
153+
)
154+
end_time = time.time()
155+
156+
preprocess_time = end_time - start_time
157+
print(f"Time taken by preprocess_gaussians: {preprocess_time:.4f} seconds")
158+
77159

78160
def test_batched_gaussian_rasterizer():
79161
# Set up the input data
@@ -163,5 +245,7 @@ def test_batched_gaussian_rasterizer():
163245
# Perform further operations with the batched results
164246
# ...
165247

248+
249+
166250
if __name__ == "__main__":
167251
test_gaussian_rasterizer_time()

0 commit comments

Comments
 (0)