
Commit 12f2dcc

sparse conv fix bug (#72404)
1 parent 2eaa787 commit 12f2dcc

File tree (4 files changed, +29 -128 lines):

  cmake/operators.cmake
  paddle/phi/kernels/sparse/gpu/conv.cu.h
  paddle/phi/kernels/sparse/gpu/conv_with_buffer.cu.h
  paddle/phi/kernels/sparse/gpu/sparse_conv_hashmap.cuh

cmake/operators.cmake (-1)

@@ -705,7 +705,6 @@ function(prune_pybind_h)
   list(APPEND op_list "fusion_seqconv_eltadd_relu")
   list(APPEND op_list "fusion_seqpool_cvm_concat")
   list(APPEND op_list "fusion_gru")
-  list(APPEND op_list "fusion_seqexpand_concat_fc")
   list(APPEND op_list "fusion_repeated_fc_relu")
   list(APPEND op_list "fusion_squared_mat_sub")

paddle/phi/kernels/sparse/gpu/conv.cu.h (-94)

@@ -207,73 +207,6 @@ __global__ void UniqueKernel(const IntT* in_indices,
   }
 }
 
-template <int BS>
-__global__ void GetOutIndices(const int* flags,
-                              const int n,
-                              const int* offsets,
-                              const int out_nnz,
-                              int* out) {
-  int tid = threadIdx.x + blockDim.x * blockIdx.x;
-  __shared__ int block_counts[BS];
-  __shared__ int block_outs[BS * 32];
-
-  int count = 0;
-
-  if (tid < n) {
-    // get the count of 1 in flags[tid]
-    int flag = flags[tid];
-    count = BitCount(static_cast<uint32_t>(flag));
-  }
-
-  // call block prefix_sum
-  // using namespace cub;
-  typedef cub::BlockScan<int, BS> BlockScan;
-  __shared__ typename BlockScan::TempStorage temp_storage;
-  BlockScan(temp_storage).ExclusiveSum(count, count);
-  __syncthreads();
-
-  // write index to out
-  if (tid < n) {
-    // get the count of 1 in flags[tid]
-    int flag = flags[tid];
-    // int j = block_counts[threadIdx.x];
-    int j = count;
-    // TODO(zhangkaihuo): opt the loop
-    for (int i = 0; i < 32; ++i) {
-      if ((1 & (flag >> i)) == 1) {
-        block_outs[j++] = (tid << 5) + i;
-      }
-    }
-  }
-
-  __syncthreads();
-  // write to block_outs
-  int start = offsets[blockIdx.x];
-  int end = blockIdx.x == gridDim.x - 1 ? out_nnz : offsets[blockIdx.x + 1];
-  for (int i = threadIdx.x; i < end - start; i += blockDim.x) {
-    out[start + i] = block_outs[i];
-  }
-}
-
-template <typename IntT>
-__global__ void GroupIndices(const int* out_index_table,
-                             const int n,
-                             const int kernel_size,
-                             IntT* out_indices,
-                             int* out_index_counts,
-                             int* out_index_groups) {
-  CUDA_KERNEL_LOOP_TYPE(i, n, int64_t) {
-    IntT index = out_indices[i];
-    int real_index = out_index_table[index];
-    out_indices[i] = real_index;
-
-    // kernel_size at most
-    int j = atomicAdd(out_index_counts + real_index, 1);
-    // nnz * kernel_size
-    out_index_groups[real_index * kernel_size + j] = i;
-  }
-}
-
 template <typename IntT>
 __global__ void GetOutIndexTable1(const IntT* indices,
                                   const IntT non_zero_num,
@@ -294,33 +227,6 @@ __global__ void GetOutIndexTable1(const IntT* indices,
   }
 }
 
-template <typename IntT>
-__global__ void GetOutIndexTable(int* indices,
-                                 const int non_zero_num,
-                                 const Dims4D out_dims,
-                                 const bool is2D,
-                                 int* out_index_table,
-                                 IntT* out_indices) {
-  CUDA_KERNEL_LOOP_TYPE(i, non_zero_num, int64_t) {
-    IntT index = static_cast<IntT>(indices[i]);
-    out_index_table[index] = i;
-    IntT batch, x, y, z;
-    phi::funcs::sparse::IndexToPoint<Dims4D>(
-        index, out_dims, &batch, &x, &y, &z);
-    // get out indices
-    out_indices[i] = batch;
-    if (is2D) {
-      out_indices[i + non_zero_num] = y;
-      out_indices[i + non_zero_num * 2] = x;
-    } else {
-      out_indices[i + non_zero_num] = z;
-      out_indices[i + non_zero_num * 2] = y;
-      out_indices[i + non_zero_num * 3] = x;
-    }
-    indices[i] = 0;
-  }
-}
-
 template <typename IntT>
 __global__ void CopyRuleBook(const int* counters,
                              const int* offsets,
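
The kernels deleted above (GetOutIndices, GroupIndices, and GetOutIndexTable) have the same names and, after the signature changes in the next file, the same parameter lists as the copies kept in conv_with_buffer.cu.h, so dropping them presumably avoids duplicate definitions of the now-identical kernels. For readers who want the compaction idea behind GetOutIndices in isolation, here is a minimal standalone sketch; the kernel name CompactSetBits and the assumption that offsets[] already holds a per-block exclusive prefix sum of set-bit counts are illustrative, not part of the Paddle source.

// Sketch: compact the positions of set bits in 32-bit flag words into a dense
// index list. Launch with BS threads per block; one thread handles one flag
// word, a cub::BlockScan exclusive sum gives each thread its write slot inside
// the block, and the block's staged results are copied to the global output at
// a precomputed per-block offset.
#include <cub/block/block_scan.cuh>
#include <cstdint>

template <int BS>
__global__ void CompactSetBits(const uint32_t* flags,
                               const int n,         // number of flag words
                               const int* offsets,  // per-block output offsets
                               const int total,     // total number of set bits
                               int* out) {
  int tid = threadIdx.x + blockDim.x * blockIdx.x;
  __shared__ int block_outs[BS * 32];  // worst case: every bit of every word set

  // Count the set bits of this thread's flag word.
  int count = (tid < n) ? __popc(flags[tid]) : 0;

  // Exclusive prefix sum over the block turns counts into write positions.
  typedef cub::BlockScan<int, BS> BlockScan;
  __shared__ typename BlockScan::TempStorage temp_storage;
  BlockScan(temp_storage).ExclusiveSum(count, count);
  __syncthreads();

  // Expand each flag word into the shared staging buffer.
  if (tid < n) {
    uint32_t flag = flags[tid];
    int j = count;
    for (int i = 0; i < 32; ++i) {
      if ((flag >> i) & 1u) {
        block_outs[j++] = (tid << 5) + i;  // global bit index = word * 32 + bit
      }
    }
  }
  __syncthreads();

  // Flush this block's staged indices to global memory.
  int start = offsets[blockIdx.x];
  int end = (blockIdx.x == gridDim.x - 1) ? total : offsets[blockIdx.x + 1];
  for (int i = threadIdx.x; i < end - start; i += blockDim.x) {
    out[start + i] = block_outs[i];
  }
}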

paddle/phi/kernels/sparse/gpu/conv_with_buffer.cu.h (+28, -33)

@@ -181,7 +181,7 @@ template <int BS>
 __global__ void GetOutIndices(const int* flags,
                               const int n,
                               const int* offsets,
-                              const int* out_nnz,
+                              const int out_nnz,
                               int* out) {
   int tid = threadIdx.x + blockDim.x * blockIdx.x;
   __shared__ int block_counts[BS];
@@ -219,21 +219,19 @@ __global__ void GetOutIndices(const int* flags,
   __syncthreads();
   // write to block_outs
   int start = offsets[blockIdx.x];
-  int end = blockIdx.x == gridDim.x - 1 ? out_nnz[0] : offsets[blockIdx.x + 1];
+  int end = blockIdx.x == gridDim.x - 1 ? out_nnz : offsets[blockIdx.x + 1];
   for (int i = threadIdx.x; i < end - start; i += blockDim.x) {
     out[start + i] = block_outs[i];
   }
 }
 
 template <typename IntT>
 __global__ void GroupIndices(const int* out_index_table,
-                             const int* rulebook_len_ptr,
+                             const int n,
                              const int kernel_size,
                              IntT* out_indices,
                              int* out_index_counts,
                              int* out_index_groups) {
-  int n = rulebook_len_ptr[0] / 2;
-  out_indices = out_indices + n;
   CUDA_KERNEL_LOOP_TYPE(i, n, int64_t) {
     IntT index = out_indices[i];
     int real_index = out_index_table[index];
@@ -248,12 +246,11 @@ __global__ void GroupIndices(const int* out_index_table,
 
 template <typename IntT>
 __global__ void GetOutIndexTable(int* indices,
-                                 const int* non_zero_num_ptr,
+                                 const int non_zero_num,
                                  const Dims4D out_dims,
                                  const bool is2D,
                                  int* out_index_table,
                                  IntT* out_indices) {
-  int non_zero_num = non_zero_num_ptr[0];
   CUDA_KERNEL_LOOP_TYPE(i, non_zero_num, int64_t) {
     IntT index = static_cast<IntT>(indices[i]);
     out_index_table[index] = i;
@@ -463,6 +460,18 @@ int ProductRuleBookWithBuffer(const Context& dev_ctx,
                                      sizeof(int),
                                      gpuMemcpyDeviceToDevice,
                                      dev_ctx.stream());
+  phi::backends::gpu::GpuMemcpyAsync(h_buffer,
+                                     d_buffer.data<int>(),
+                                     (2 * kernel_size + 3) * sizeof(int),
+                                     gpuMemcpyDeviceToHost,
+                                     dev_ctx.stream());
+
+  dev_ctx.Wait();
+  int rulebook_len = h_buffer[2 * kernel_size + 1] / 2;
+  int out_nnz = h_buffer[2 * kernel_size + 2];
+
+  rulebook->Resize({rulebook_rows, static_cast<int>(rulebook_len)});
+  out_index->Resize({static_cast<int>(rulebook_len)});
 
   const int threads = 256;
   const int blocks = (index_flags->numel() + threads - 1) / threads;
@@ -493,57 +502,43 @@ int ProductRuleBookWithBuffer(const Context& dev_ctx,
       <<<blocks, threads, 0, dev_ctx.stream()>>>(index_flags_ptr,
                                                  index_flags->numel(),
                                                  out_index_table_ptr,
-                                                 unique_key_ptr,
+                                                 out_nnz,
                                                  out_index_ptr);
 
   const int64_t sparse_dim = is2D ? 3 : 4;
   phi::DenseTensor out_indices =
-      phi::Empty<IntT>(dev_ctx, {sparse_dim, static_cast<int>(max_nnz)});
-  phi::DenseTensor out_values = phi::Empty<T>(
-      dev_ctx, {static_cast<int>(max_nnz), kernel_sizes[sparse_dim]});
+      phi::Empty<IntT>(dev_ctx, {sparse_dim, out_nnz});
+
+  phi::DenseTensor out_values =
+      phi::Empty<T>(dev_ctx, {out_nnz, kernel_sizes[sparse_dim]});
+  out->SetMember(out_indices, out_values, out_dims, false);
 
   IntT* out_indices_ptr = out_indices.data<IntT>();
 
-  config = phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, max_nnz, 1);
+  config = phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, out_nnz, 1);
   GetOutIndexTable<IntT>
       <<<config.block_per_grid, config.thread_per_block, 0, dev_ctx.stream()>>>(
           out_index_ptr,
-          unique_key_ptr,
+          out_nnz,
           d_out_dims,
           is2D,
           out_index_table_ptr,
          out_indices_ptr);
 
-  config = phi::backends::gpu::GetGpuLaunchConfig1D(
-      dev_ctx, static_cast<int>(max_nnz), 1);
-  unique_value->ResizeAndAllocate({static_cast<int>(max_nnz * kernel_size)});
+  config = phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, rulebook_len, 1);
+  unique_value->ResizeAndAllocate({static_cast<int>(out_nnz * kernel_size)});
   int* unique_value_ptr = unique_value->data<int>();
 
   GroupIndices<<<config.block_per_grid,
                  config.thread_per_block,
                  0,
                  dev_ctx.stream()>>>(out_index_table_ptr,
-                                     rulebook_len_tensor.data<int>(),
+                                     rulebook_len,
                                      kernel_size,
-                                     rulebook_ptr,
+                                     rulebook_ptr + rulebook_len,
                                      out_index_ptr,
                                      unique_value_ptr);
 
-  phi::backends::gpu::GpuMemcpyAsync(h_buffer,
-                                     d_buffer.data<int>(),
-                                     (2 * kernel_size + 3) * sizeof(int),
-                                     gpuMemcpyDeviceToHost,
-                                     dev_ctx.stream());
-  dev_ctx.Wait();
-  int rulebook_len = h_buffer[2 * kernel_size + 1] / 2;
-  int out_nnz = h_buffer[2 * kernel_size + 2];
-  rulebook->Resize({rulebook_rows, static_cast<int>(rulebook_len)});
-  out_index->Resize({static_cast<int>(rulebook_len)});
-  out_indices.Resize({sparse_dim, static_cast<int>(out_nnz)});
-  unique_value->Resize(
-      {static_cast<int>(static_cast<int>(out_nnz) * kernel_size)});
-  out_values.Resize({out_nnz, kernel_sizes[sparse_dim]});
-  out->SetMember(out_indices, out_values, out_dims, false);
   return rulebook_len;
 }
 }  // namespace sparse
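
Taken together, the changes in this file move the device-to-host copy of h_buffer and the dev_ctx.Wait() ahead of the kernel launches in ProductRuleBookWithBuffer. With rulebook_len and out_nnz known on the host before launch, GetOutIndices, GetOutIndexTable, and GroupIndices can take plain int arguments instead of dereferencing single-element device buffers, the launch configurations are built from the real counts rather than max_nnz, out_indices and out_values are allocated at their final shapes, and the trailing Resize/SetMember block is no longer needed; GroupIndices also stops offsetting out_indices internally, since the caller now passes rulebook_ptr + rulebook_len directly. The general host-side shape of this pattern is sketched below; the function name, buffer layout, and raw CUDA calls are illustrative only, not the phi API.

#include <cuda_runtime.h>
#include <vector>

// Sketch: d_counters was filled by earlier kernels with, e.g., a rulebook
// length and an output nnz. Copy it back, synchronize, then size outputs and
// launch the follow-up kernels with host-side scalars.
void LaunchWithExactSizes(const int* d_counters, cudaStream_t stream) {
  std::vector<int> h_counters(2);
  cudaMemcpyAsync(h_counters.data(), d_counters, 2 * sizeof(int),
                  cudaMemcpyDeviceToHost, stream);
  cudaStreamSynchronize(stream);  // counters are now valid on the host

  const int rulebook_len = h_counters[0];
  const int out_nnz = h_counters[1];

  // Allocate outputs at their exact sizes and pass the scalars to the kernels,
  // e.g. SomeKernel<<<grid, block, 0, stream>>>(..., out_nnz, ...),
  // instead of allocating an upper bound and resizing after the fact.
  (void)rulebook_len;
  (void)out_nnz;
}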

paddle/phi/kernels/sparse/gpu/sparse_conv_hashmap.cuh (+1)

@@ -118,6 +118,7 @@ __global__ void lookup_coords_kernel(
 {
     int tidx = blockIdx.x * blockDim.x + threadIdx.x;
     int idx = tidx / kernel_volume;
+    if (idx >= n) return;
     int _kernel_idx = tidx % kernel_volume;
     int kernel_idx = _kernel_idx;
     const int* in_coords = coords + _width * idx;
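
The one-line guard added above makes lookup_coords_kernel ignore threads whose derived point index idx lands past the end of the input: with one thread per (point, kernel offset) pair and a grid that is presumably rounded up to whole blocks, the last block can contain threads with idx >= n, and without the early return those threads would index coords at _width * idx beyond the valid range. A minimal standalone illustration of the pattern follows; the kernel name and launch arithmetic are hypothetical, not the Paddle code.

// Sketch: one thread per (point, kernel-offset) pair, with surplus threads in
// the final block returning before they touch memory indexed by idx.
__global__ void per_pair_kernel(const int* coords,
                                const int width,
                                const int n,             // number of points
                                const int kernel_volume) {
  int tidx = blockIdx.x * blockDim.x + threadIdx.x;
  int idx = tidx / kernel_volume;         // which input point
  if (idx >= n) return;                   // guard: surplus threads do nothing
  int kernel_idx = tidx % kernel_volume;  // which kernel offset
  const int* in_coords = coords + width * idx;  // safe: idx < n
  // ... use in_coords and kernel_idx ...
  (void)in_coords;
  (void)kernel_idx;
}

// Hypothetical launch: round the pair count up to whole blocks.
// int threads = 256;
// int blocks = (n * kernel_volume + threads - 1) / threads;
// per_pair_kernel<<<blocks, threads>>>(coords, width, n, kernel_volume);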
