@@ -181,7 +181,7 @@ template <int BS>
 __global__ void GetOutIndices(const int* flags,
                               const int n,
                               const int* offsets,
-                              const int* out_nnz,
+                              const int out_nnz,
                               int* out) {
   int tid = threadIdx.x + blockDim.x * blockIdx.x;
   __shared__ int block_counts[BS];
@@ -219,21 +219,19 @@ __global__ void GetOutIndices(const int* flags,
   __syncthreads();
   // write to block_outs
   int start = offsets[blockIdx.x];
-  int end = blockIdx.x == gridDim.x - 1 ? out_nnz[0] : offsets[blockIdx.x + 1];
+  int end = blockIdx.x == gridDim.x - 1 ? out_nnz : offsets[blockIdx.x + 1];
   for (int i = threadIdx.x; i < end - start; i += blockDim.x) {
     out[start + i] = block_outs[i];
   }
 }
 
 template <typename IntT>
 __global__ void GroupIndices(const int* out_index_table,
-                             const int* rulebook_len_ptr,
+                             const int n,
                              const int kernel_size,
                              IntT* out_indices,
                              int* out_index_counts,
                              int* out_index_groups) {
-  int n = rulebook_len_ptr[0] / 2;
-  out_indices = out_indices + n;
   CUDA_KERNEL_LOOP_TYPE(i, n, int64_t) {
     IntT index = out_indices[i];
     int real_index = out_index_table[index];
@@ -248,12 +246,11 @@ __global__ void GroupIndices(const int* out_index_table,
 
 template <typename IntT>
 __global__ void GetOutIndexTable(int* indices,
-                                 const int* non_zero_num_ptr,
+                                 const int non_zero_num,
                                  const Dims4D out_dims,
                                  const bool is2D,
                                  int* out_index_table,
                                  IntT* out_indices) {
-  int non_zero_num = non_zero_num_ptr[0];
   CUDA_KERNEL_LOOP_TYPE(i, non_zero_num, int64_t) {
     IntT index = static_cast<IntT>(indices[i]);
     out_index_table[index] = i;
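The three kernel hunks above share one change: counts that used to arrive as device pointers and be dereferenced inside the kernel (out_nnz[0], rulebook_len_ptr[0], non_zero_num_ptr[0]) are now passed by value, so the host must know them before launch. As a rough standalone illustration of the two styles (a sketch only, not Paddle code; the kernel and variable names below are made up):

// old_vs_new_count.cu -- hypothetical sketch, not part of this PR.
#include <cstdio>
#include <cuda_runtime.h>

// Old style: the element count lives in device memory, so every thread
// dereferences a pointer and the host never needs to read the count back.
__global__ void FillFromDeviceCount(const int* d_count, int* out) {
  int n = d_count[0];
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) out[i] = i;
}

// New style: the count is a plain by-value kernel argument, which means the
// host has to read it back first (see the memcpy + Wait added in the next hunk).
__global__ void FillFromHostCount(const int n, int* out) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) out[i] = i;
}

int main() {
  const int n = 256;
  int *d_out = nullptr, *d_count = nullptr;
  cudaMalloc(&d_out, n * sizeof(int));
  cudaMalloc(&d_count, sizeof(int));
  cudaMemcpy(d_count, &n, sizeof(int), cudaMemcpyHostToDevice);

  FillFromDeviceCount<<<(n + 255) / 256, 256>>>(d_count, d_out);  // count read on device
  FillFromHostCount<<<(n + 255) / 256, 256>>>(n, d_out);          // count passed by value

  cudaDeviceSynchronize();
  printf("last error: %s\n", cudaGetErrorString(cudaGetLastError()));
  cudaFree(d_out);
  cudaFree(d_count);
  return 0;
}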
@@ -463,6 +460,18 @@ int ProductRuleBookWithBuffer(const Context& dev_ctx,
                                      sizeof(int),
                                      gpuMemcpyDeviceToDevice,
                                      dev_ctx.stream());
+  phi::backends::gpu::GpuMemcpyAsync(h_buffer,
+                                     d_buffer.data<int>(),
+                                     (2 * kernel_size + 3) * sizeof(int),
+                                     gpuMemcpyDeviceToHost,
+                                     dev_ctx.stream());
+
+  dev_ctx.Wait();
+  int rulebook_len = h_buffer[2 * kernel_size + 1] / 2;
+  int out_nnz = h_buffer[2 * kernel_size + 2];
+
+  rulebook->Resize({rulebook_rows, static_cast<int>(rulebook_len)});
+  out_index->Resize({static_cast<int>(rulebook_len)});
 
   const int threads = 256;
   const int blocks = (index_flags->numel() + threads - 1) / threads;
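The hunk above moves the counter readback (GpuMemcpyAsync into h_buffer plus dev_ctx.Wait()) ahead of the output allocations, so rulebook_len and out_nnz are known on the host before any tensor is sized. A minimal sketch of that readback pattern in plain CUDA runtime calls (a hypothetical helper; the real code goes through the phi::backends::gpu wrappers and the dev_ctx stream):

// readback_sketch.cu -- hypothetical helper, not part of this PR.
#include <vector>
#include <cuda_runtime.h>

// Copies a small device-side counter buffer to the host, waits for the copy to
// land, and returns the count stored in the last slot so the caller can
// allocate outputs at their exact size instead of a max_nnz upper bound.
int ReadBackCount(const int* d_counters, int num_counters, cudaStream_t stream) {
  std::vector<int> h_counters(num_counters);
  cudaMemcpyAsync(h_counters.data(), d_counters, num_counters * sizeof(int),
                  cudaMemcpyDeviceToHost, stream);
  cudaStreamSynchronize(stream);  // the host may only read h_counters after this
  return h_counters[num_counters - 1];
}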
@@ -493,57 +502,43 @@ int ProductRuleBookWithBuffer(const Context& dev_ctx,
       <<<blocks, threads, 0, dev_ctx.stream()>>>(index_flags_ptr,
                                                  index_flags->numel(),
                                                  out_index_table_ptr,
-                                                 unique_key_ptr,
+                                                 out_nnz,
                                                  out_index_ptr);
 
   const int64_t sparse_dim = is2D ? 3 : 4;
   phi::DenseTensor out_indices =
-      phi::Empty<IntT>(dev_ctx, {sparse_dim, static_cast<int>(max_nnz)});
-  phi::DenseTensor out_values = phi::Empty<T>(
-      dev_ctx, {static_cast<int>(max_nnz), kernel_sizes[sparse_dim]});
+      phi::Empty<IntT>(dev_ctx, {sparse_dim, out_nnz});
+
+  phi::DenseTensor out_values =
+      phi::Empty<T>(dev_ctx, {out_nnz, kernel_sizes[sparse_dim]});
+  out->SetMember(out_indices, out_values, out_dims, false);
 
   IntT* out_indices_ptr = out_indices.data<IntT>();
 
-  config = phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, max_nnz, 1);
+  config = phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, out_nnz, 1);
   GetOutIndexTable<IntT>
       <<<config.block_per_grid, config.thread_per_block, 0, dev_ctx.stream()>>>(
          out_index_ptr,
-          unique_key_ptr,
+          out_nnz,
          d_out_dims,
          is2D,
          out_index_table_ptr,
          out_indices_ptr);
 
-  config = phi::backends::gpu::GetGpuLaunchConfig1D(
-      dev_ctx, static_cast<int>(max_nnz), 1);
-  unique_value->ResizeAndAllocate({static_cast<int>(max_nnz * kernel_size)});
+  config = phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, rulebook_len, 1);
+  unique_value->ResizeAndAllocate({static_cast<int>(out_nnz * kernel_size)});
   int* unique_value_ptr = unique_value->data<int>();
 
   GroupIndices<<<config.block_per_grid,
                  config.thread_per_block,
                  0,
                  dev_ctx.stream()>>>(out_index_table_ptr,
-                                     rulebook_len_tensor.data<int>(),
+                                     rulebook_len,
                                      kernel_size,
-                                     rulebook_ptr,
+                                     rulebook_ptr + rulebook_len,
                                      out_index_ptr,
                                      unique_value_ptr);
 
-  phi::backends::gpu::GpuMemcpyAsync(h_buffer,
-                                     d_buffer.data<int>(),
-                                     (2 * kernel_size + 3) * sizeof(int),
-                                     gpuMemcpyDeviceToHost,
-                                     dev_ctx.stream());
-  dev_ctx.Wait();
-  int rulebook_len = h_buffer[2 * kernel_size + 1] / 2;
-  int out_nnz = h_buffer[2 * kernel_size + 2];
-  rulebook->Resize({rulebook_rows, static_cast<int>(rulebook_len)});
-  out_index->Resize({static_cast<int>(rulebook_len)});
-  out_indices.Resize({sparse_dim, static_cast<int>(out_nnz)});
-  unique_value->Resize(
-      {static_cast<int>(static_cast<int>(out_nnz) * kernel_size)});
-  out_values.Resize({out_nnz, kernel_sizes[sparse_dim]});
-  out->SetMember(out_indices, out_values, out_dims, false);
   return rulebook_len;
 }
 }  // namespace sparse
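One more detail from the last hunk: GroupIndices used to offset its output inside the kernel (out_indices = out_indices + n); now the call site passes rulebook_ptr + rulebook_len directly and the kernel loops over a by-value n. A tiny standalone sketch of that launch-site offset, with made-up names:

// offset_at_launch.cu -- hypothetical sketch, not part of this PR.
#include <cuda_runtime.h>

// Operates only on the second half of a buffer laid out as
// [first half (n) | second half (n)]; the caller applies the offset.
__global__ void TouchSecondHalf(int* second_half, const int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) second_half[i] += 1;
}

void Launch(int* buffer, int n, cudaStream_t stream) {
  // Equivalent to what the kernel previously did internally by
  // advancing the pointer before its loop.
  TouchSecondHalf<<<(n + 255) / 256, 256, 0, stream>>>(buffer + n, n);
}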