Commit 74ceb80

parallel coo slice in gpu

1 parent 428dc2e

File tree

2 files changed (+63 -40 lines):
  paddle/phi/kernels/sparse/gpu/slice_kernel.cu
  python/paddle/fluid/tests/unittests/test_sparse_slice_op.py

paddle/phi/kernels/sparse/gpu/slice_kernel.cu

Lines changed: 58 additions & 40 deletions
@@ -12,6 +12,9 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
+#include <thrust/device_vector.h>
+#include <thrust/sort.h>
+
 #include "paddle/phi/kernels/sparse/unary_kernel.h"
 
 #include "paddle/phi/backends/gpu/gpu_context.h"
@@ -32,7 +35,8 @@ __global__ void GetCooNonZeroNumberCudaKernel(const int64_t* x_indices_data,
                                               const int64_t* ends,
                                               const int64_t axes_size,
                                               const int64_t x_nnz,
-                                              int* out_nnz) {
+                                              int* out_nnz,
+                                              int64_t* out_nnz_indices) {
   CUDA_KERNEL_LOOP_TYPE(j, x_nnz, int64_t) {
     bool hit = true;
     for (size_t ii = 0; ii < axes_size; ++ii) {
@@ -43,7 +47,8 @@ __global__ void GetCooNonZeroNumberCudaKernel(const int64_t* x_indices_data,
       }
     }
     if (!hit) continue;
-    atomicAdd(out_nnz, 1);
+    int old_value = atomicAdd(out_nnz, 1);
+    out_nnz_indices[old_value] = j;
   }
 }
 
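Note on the hunk above: GetCooNonZeroNumberCudaKernel now doubles as a compaction pass. Each thread that hits the slice range reserves an output slot with atomicAdd and records the source position j, so the slots fill densely but in scheduling order (they are sorted later in SliceCooKernel). A minimal standalone sketch of this atomic stream-compaction pattern, with illustrative names that are not Paddle's:

    // Compact the indices of the elements that satisfy a predicate.
    // Each hitting thread reserves a unique slot via atomicAdd; the slots are
    // filled densely but in a nondeterministic order, hence the later sort.
    __global__ void CompactIndices(const bool* pred,
                                   int64_t n,
                                   int* out_count,
                                   int64_t* out_idx) {
      for (int64_t j =
               static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
           j < n;
           j += static_cast<int64_t>(gridDim.x) * blockDim.x) {
        if (!pred[j]) continue;
        int slot = atomicAdd(out_count, 1);  // reserve the next free slot
        out_idx[slot] = j;                   // remember the source position
      }
    }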
@@ -52,37 +57,27 @@ __global__ void GetCooOutCudaKernel(const int64_t* x_indices_data,
                                     const T* x_values_data,
                                     const int64_t* axes,
                                     const int64_t* starts,
-                                    const int64_t* ends,
                                     const int64_t axes_size,
                                     const int64_t sparse_dim,
                                     const int64_t x_nnz,
                                     const int64_t out_nnz,
+                                    const int64_t* out_nnz_indices,
                                     int64_t* out_indices_data,
                                     T* out_values_data) {
-  int64_t tid = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
-  if (tid == 0) {
-    int64_t index = 0;
-    for (int64_t j = 0; j < x_nnz && index < out_nnz; ++j) {
-      bool hit = true;
-      for (size_t ii = 0; ii < axes_size; ++ii) {
-        auto item = x_indices_data[axes[ii] * x_nnz + j];
-        if (!(starts[ii] <= item && item < ends[ii])) {
-          hit = false;
-          break;
-        }
-      }
-      if (!hit) continue;
-      // set value
-      out_values_data[index] = x_values_data[j];
-      // set coordinate
-      for (int64_t i = 0; i < sparse_dim; ++i) {
-        out_indices_data[i * out_nnz + index] = x_indices_data[i * x_nnz + j];
-      }
-      for (size_t ii = 0; ii < axes_size; ++ii) {
-        auto i = axes[ii];
-        out_indices_data[i * out_nnz + index] -= starts[ii];
-      }
-      index++;
+  CUDA_KERNEL_LOOP_TYPE(index, out_nnz, int64_t) {
+    // index is in the order of the non-zero elements in out
+    // out_nnz_indices[index] is the valid index in x's non-zero elements, where
+    // the `hit` is true.
+    int64_t j = out_nnz_indices[index];
+    // set value
+    out_values_data[index] = x_values_data[j];
+    // set coordinate
+    for (int64_t i = 0; i < sparse_dim; ++i) {
+      out_indices_data[i * out_nnz + index] = x_indices_data[i * x_nnz + j];
+    }
+    for (size_t ii = 0; ii < axes_size; ++ii) {
+      auto i = axes[ii];
+      out_indices_data[i * out_nnz + index] -= starts[ii];
     }
   }
 }
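The rewritten GetCooOutCudaKernel drops the single-thread scan (the old tid == 0 branch) and gathers each output non-zero independently through CUDA_KERNEL_LOOP_TYPE, Paddle's grid-stride-loop macro. A hedged sketch of what such a macro is assumed to expand to (illustrative, not Paddle's exact definition):

    // Grid-stride loop: each thread starts at its global id and strides by the
    // total number of launched threads, so any grid size covers all n elements.
    #define GRID_STRIDE_LOOP(i, n, index_type)                                 \
      for (index_type i =                                                      \
               static_cast<index_type>(blockIdx.x) * blockDim.x + threadIdx.x; \
           i < (n);                                                            \
           i += static_cast<index_type>(gridDim.x) * blockDim.x)

    // Example use: every element of `in` is visited exactly once, regardless
    // of the launch shape.
    __global__ void ScaleKernel(const float* in, float* out, int64_t n) {
      GRID_STRIDE_LOOP(i, n, int64_t) { out[i] = 2.0f * in[i]; }
    }

Because each output slot's source position comes from out_nnz_indices, the threads in this second kernel never contend and no atomics are needed.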
@@ -113,6 +108,13 @@ void SliceCooKernel(const Context& dev_ctx,
   phi::backends::gpu::GpuMemsetAsync(
       d_out_nnz_ptr, 0, sizeof(int32_t), dev_ctx.stream());
 
+  // out_nnz_indices is the indices where the data is valid in out
+  // the length of the out_nnz_indices must be less than x.nnz()
+  DenseTensor d_out_nnz_indices = phi::Empty<int64_t>(dev_ctx, {x.nnz()});
+  int64_t* d_out_nnz_indices_ptr = d_out_nnz_indices.data<int64_t>();
+  phi::backends::gpu::GpuMemsetAsync(
+      d_out_nnz_indices_ptr, 0, sizeof(int64_t), dev_ctx.stream());
+
   // copy axes to device
   auto d_axes_tensor = memory_utils::Alloc(
       dev_ctx.GetPlace(),
@@ -164,14 +166,27 @@ void SliceCooKernel(const Context& dev_ctx,
                                                       d_ends,
                                                       axes.size(),
                                                       x.nnz(),
-                                                      d_out_nnz_ptr);
+                                                      d_out_nnz_ptr,
+                                                      d_out_nnz_indices_ptr);
 
+  // copy d_out_nnz from device to host (out_nnz)
   int32_t out_nnz = 0;
   phi::backends::gpu::GpuMemcpyAsync(&out_nnz,
                                      d_out_nnz_ptr,
                                      sizeof(int32_t),
                                      gpuMemcpyDeviceToHost,
                                      dev_ctx.stream());
+  // sort `d_out_nnz_indices_ptr`
+  d_out_nnz_indices.Resize({out_nnz});
+  thrust::device_vector<int64_t> d_out_nnz_indices_vec(
+      d_out_nnz_indices_ptr, d_out_nnz_indices_ptr + out_nnz);
+  thrust::sort(d_out_nnz_indices_vec.begin(), d_out_nnz_indices_vec.end());
+  phi::backends::gpu::GpuMemcpyAsync(
+      d_out_nnz_indices_ptr,
+      thrust::raw_pointer_cast(d_out_nnz_indices_vec.data()),
+      out_nnz * sizeof(int64_t),
+      gpuMemcpyDeviceToDevice,
+      dev_ctx.stream());
 
   // Step4: Get the values and indices of output
   auto sparse_dim = static_cast<int64_t>(x.sparse_dim());
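The indices recorded by the counting kernel arrive in scheduling order, so they are sorted back into ascending order before the gather, keeping the sliced COO entries in the same order as in x. The commit stages them through a thrust::device_vector and copies the sorted data back; the same sort can also run in place on the raw device pointer with a stream-bound execution policy. A hedged alternative sketch, not what this commit does:

    #include <thrust/sort.h>
    #include <thrust/system/cuda/execution_policy.h>

    // Sort n device-resident indices in place, running the work on `stream`.
    // thrust::cuda::par.on(stream) avoids the device_vector round trip.
    void SortIndicesInPlace(int64_t* d_idx, int n, cudaStream_t stream) {
      thrust::sort(thrust::cuda::par.on(stream), d_idx, d_idx + n);
    }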
@@ -184,18 +199,21 @@ void SliceCooKernel(const Context& dev_ctx,
   auto* out_values_data = out_values.data<T>();
   const auto* x_values_data = x.values().data<T>();
 
-  GetCooOutCudaKernel<T>
-      <<<1, 1, 0, dev_ctx.stream()>>>(x_indices_data,
-                                      x_values_data,
-                                      d_axes,
-                                      d_starts,
-                                      d_ends,
-                                      axes.size(),
-                                      sparse_dim,
-                                      x.nnz(),
-                                      static_cast<int64_t>(out_nnz),
-                                      out_indices_data,
-                                      out_values_data);
+  config = phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, out_nnz, 1);
+  GetCooOutCudaKernel<T><<<config.block_per_grid.x,
+                           config.thread_per_block.x,
+                           0,
+                           dev_ctx.stream()>>>(x_indices_data,
+                                               x_values_data,
+                                               d_axes,
+                                               d_starts,
+                                               axes.size(),
+                                               sparse_dim,
+                                               x.nnz(),
+                                               static_cast<int64_t>(out_nnz),
+                                               d_out_nnz_indices_ptr,
+                                               out_indices_data,
+                                               out_values_data);
 }
 
 __global__ void GetCsrNonZeroNumberCudaKernel(const int64_t* x_crows_data,
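With the gather kernel now parallel, the old <<<1, 1>>> launch is replaced by a configuration sized from out_nnz via GetGpuLaunchConfig1D. A minimal sketch of the usual ceil-divide arithmetic behind such a 1-D launch-config helper (generic, not Paddle's implementation):

    #include <algorithm>
    #include <cstdint>

    struct LaunchConfig1D {
      unsigned int blocks;
      unsigned int threads;
    };

    // Pick a block size and enough blocks to cover n elements; a grid-stride
    // loop inside the kernel tolerates any grid that is at least one block.
    inline LaunchConfig1D MakeLaunchConfig1D(int64_t n,
                                             unsigned int threads = 256) {
      LaunchConfig1D cfg;
      cfg.threads = threads;
      cfg.blocks = std::max<unsigned int>(
          static_cast<unsigned int>((n + threads - 1) / threads), 1u);
      return cfg;
    }

    // Usage (illustrative):
    //   auto cfg = MakeLaunchConfig1D(out_nnz);
    //   GetCooOutCudaKernel<T><<<cfg.blocks, cfg.threads, 0, stream>>>(...);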

python/paddle/fluid/tests/unittests/test_sparse_slice_op.py

Lines changed: 5 additions & 0 deletions
@@ -28,9 +28,11 @@
 
 data_5d = [
     [[2, 3, 4, 5, 6], [0, 1, 2, 4], [0, 1, 2, -4], [3, 3, 4, -2]],
+    [[2, 64, 256, 256, 10], [0, 1, 2, 4], [0, 1, 2, -4], [3, 3, 4, -2]],
 ]
 data_4d = [
     [[2, 3, 4, 5], [0, 1, 2, 3], [0, 1, 2, -4], [3, 3, 4, -2]],
+    [[64, 256, 256, 10], [0, 1, 2, 3], [0, 1, 2, -4], [3, 3, 4, -2]],
 ]
 
 data_3d = [
@@ -41,6 +43,7 @@
     [[4, 4, 5], [1], [2], [3]],
     [[4, 4, 5], [1, 2], [2, 2], [3, 4]],
     [[4, 4, 5], [0, 2], [2, 2], [3, 4]],
+    [[256, 256, 10], [0, 2], [2, 2], [3, 4]],
 ]
 
 data_2d = [
@@ -115,6 +118,8 @@ def test_coo_3d(self):
             self.check_result_with_shape(*item, format='coo')
 
     def test_coo_2d(self):
+        x = [[1, 2, 3, 4], [0, 1, 2, 0]]
+        self.check_result_with_list(x, [0, 1], [0, 1], [2, 3], format='coo')
         for item in data_2d:
             self.check_result_with_shape(*item, format='coo')
 