@@ -74,6 +74,88 @@ def test_gaussian_rasterizer_time():
74
74
preprocess_time = end_time - start_time
75
75
print (f"Time taken by preprocess_gaussians: { preprocess_time :.4f} seconds" )
76
76
77
+ def test_improved_gaussian_rasterizer ():
78
+
79
+ # Set up the input data
80
+ num_gaussians = 10000
81
+ num_batches = 4
82
+ means3D = torch .randn (num_gaussians , 3 ).cuda ()
83
+ scales = torch .randn (num_gaussians , 3 ).cuda ()
84
+ rotations = torch .randn (num_gaussians , 3 , 3 ).cuda ()
85
+ shs = torch .randn (num_gaussians , 9 ).cuda ()
86
+ opacity = torch .randn (num_gaussians , 1 ).cuda ()
87
+
88
+ # Set up the viewpoint cameras
89
+ batched_viewpoint_cameras = []
90
+ for _ in range (num_batches ):
91
+ viewpoint_camera = type ('ViewpointCamera' , (), {})
92
+ viewpoint_camera .FoVx = math .radians (60 )
93
+ viewpoint_camera .FoVy = math .radians (60 )
94
+ viewpoint_camera .image_height = 512
95
+ viewpoint_camera .image_width = 512
96
+ viewpoint_camera .world_view_transform = torch .eye (4 ).cuda ()
97
+ viewpoint_camera .full_proj_transform = torch .eye (4 ).cuda ()
98
+ viewpoint_camera .camera_center = torch .zeros (3 ).cuda ()
99
+ batched_viewpoint_cameras .append (viewpoint_camera )
100
+
101
+ # Set up the strategies
102
+ batched_strategies = [None ] * num_batches
103
+
104
+ # Set up other parameters
105
+ bg_color = torch .ones (3 ).cuda ()
106
+ scaling_modifier = 1.0
107
+ pc = type ('PC' , (), {})
108
+ pc .active_sh_degree = 2
109
+ pipe = type ('Pipe' , (), {})
110
+ pipe .debug = False
111
+ mode = "train"
112
+
113
+ batched_rasterizers = []
114
+ batched_cuda_args = []
115
+ batched_screenspace_params = []
116
+ batched_means2D = []
117
+ batched_radii = []
118
+ raster_settings_list = []
119
+ for i , (viewpoint_camera , strategy ) in enumerate (zip (batched_viewpoint_cameras , batched_strategies )):
120
+ ########## [START] Prepare CUDA Rasterization Settings ##########
121
+ cuda_args = get_cuda_args (strategy , mode )
122
+ batched_cuda_args .append (cuda_args )
123
+
124
+ # Set up rasterization configuration
125
+ tanfovx = math .tan (viewpoint_camera .FoVx * 0.5 )
126
+ tanfovy = math .tan (viewpoint_camera .FoVy * 0.5 )
127
+
128
+ raster_settings_list .append (GaussianRasterizationSettings (
129
+ image_height = int (viewpoint_camera .image_height ),
130
+ image_width = int (viewpoint_camera .image_width ),
131
+ tanfovx = tanfovx ,
132
+ tanfovy = tanfovy ,
133
+ bg = bg_color ,
134
+ scale_modifier = scaling_modifier ,
135
+ viewmatrix = viewpoint_camera .world_view_transform ,
136
+ projmatrix = viewpoint_camera .full_proj_transform ,
137
+ sh_degree = pc .active_sh_degree ,
138
+ campos = viewpoint_camera .camera_center ,
139
+ prefiltered = False ,
140
+ debug = pipe .debug
141
+ ))
142
+
143
+
144
+ rasterizer = GaussianRasterizerBatches (raster_settings = raster_settings_list )
145
+ start_time = time .time ()
146
+ batched_means2D , batched_rgb , batched_conic_opacity , batched_radii , batched_depths = rasterizer .preprocess_gaussians_batches (
147
+ means3D = means3D ,
148
+ scales = scales ,
149
+ rotations = rotations ,
150
+ shs = shs ,
151
+ opacities = opacity ,
152
+ cuda_args = batched_cuda_args
153
+ )
154
+ end_time = time .time ()
155
+
156
+ preprocess_time = end_time - start_time
157
+ print (f"Time taken by preprocess_gaussians: { preprocess_time :.4f} seconds" )
158
+
77
159
78
160
def test_batched_gaussian_rasterizer ():
79
161
# Set up the input data
@@ -163,5 +245,7 @@ def test_batched_gaussian_rasterizer():
163
245
# Perform further operations with the batched results
164
246
# ...
165
247
248
+
249
+
166
250
if __name__ == "__main__" :
167
251
test_gaussian_rasterizer_time ()
0 commit comments