 // See the License for the specific language governing permissions and
 // limitations under the License.

+#include <thrust/device_vector.h>
+#include <thrust/sort.h>
+
 #include "paddle/phi/kernels/sparse/unary_kernel.h"

 #include "paddle/phi/backends/gpu/gpu_context.h"
@@ -32,7 +35,8 @@ __global__ void GetCooNonZeroNumberCudaKernel(const int64_t* x_indices_data,
                                               const int64_t* ends,
                                               const int64_t axes_size,
                                               const int64_t x_nnz,
-                                              int* out_nnz) {
+                                              int* out_nnz,
+                                              int64_t* out_nnz_indices) {
   CUDA_KERNEL_LOOP_TYPE(j, x_nnz, int64_t) {
     bool hit = true;
     for (size_t ii = 0; ii < axes_size; ++ii) {
@@ -43,7 +47,8 @@ __global__ void GetCooNonZeroNumberCudaKernel(const int64_t* x_indices_data,
       }
     }
     if (!hit) continue;
-    atomicAdd(out_nnz, 1);
+    int old_value = atomicAdd(out_nnz, 1);
+    out_nnz_indices[old_value] = j;
   }
 }

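A note on the counting hunk above: the value returned by atomicAdd is used as a unique output slot, so a single pass both counts the hits and records which of x's non-zeros they are. Below is a minimal sketch of that pattern in plain CUDA, with hypothetical names, a simplified bounds predicate, and no phi launch helpers; it illustrates the technique rather than the kernel itself.

#include <cuda_runtime.h>

// Count the inputs that satisfy a predicate and record their source indices.
// atomicAdd returns the counter's old value, which serves as a unique slot,
// so concurrent hits never overwrite each other; the recorded indices land in
// an arbitrary order and need a later sort if order matters.
__global__ void CountAndRecordHits(const int64_t* data,
                                   int64_t n,
                                   int64_t lo,
                                   int64_t hi,
                                   int* out_count,
                                   int64_t* out_indices) {
  int64_t j = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
  if (j >= n) return;
  bool hit = (lo <= data[j] && data[j] < hi);  // stands in for the slice bounds test
  if (!hit) return;
  int slot = atomicAdd(out_count, 1);  // old counter value == this hit's slot
  out_indices[slot] = j;               // remember which input element hit
}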
@@ -52,37 +57,27 @@ __global__ void GetCooOutCudaKernel(const int64_t* x_indices_data,
                                     const T* x_values_data,
                                     const int64_t* axes,
                                     const int64_t* starts,
-                                    const int64_t* ends,
                                     const int64_t axes_size,
                                     const int64_t sparse_dim,
                                     const int64_t x_nnz,
                                     const int64_t out_nnz,
+                                    const int64_t* out_nnz_indices,
                                     int64_t* out_indices_data,
                                     T* out_values_data) {
-  int64_t tid = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
-  if (tid == 0) {
-    int64_t index = 0;
-    for (int64_t j = 0; j < x_nnz && index < out_nnz; ++j) {
-      bool hit = true;
-      for (size_t ii = 0; ii < axes_size; ++ii) {
-        auto item = x_indices_data[axes[ii] * x_nnz + j];
-        if (!(starts[ii] <= item && item < ends[ii])) {
-          hit = false;
-          break;
-        }
-      }
-      if (!hit) continue;
-      // set value
-      out_values_data[index] = x_values_data[j];
-      // set coordinate
-      for (int64_t i = 0; i < sparse_dim; ++i) {
-        out_indices_data[i * out_nnz + index] = x_indices_data[i * x_nnz + j];
-      }
-      for (size_t ii = 0; ii < axes_size; ++ii) {
-        auto i = axes[ii];
-        out_indices_data[i * out_nnz + index] -= starts[ii];
-      }
-      index++;
+  CUDA_KERNEL_LOOP_TYPE(index, out_nnz, int64_t) {
+    // index is in the order of the non-zero elements in out
+    // out_nnz_indices[index] is the valid index in x's non-zero elements, where
+    // the `hit` is true.
+    int64_t j = out_nnz_indices[index];
+    // set value
+    out_values_data[index] = x_values_data[j];
+    // set coordinate
+    for (int64_t i = 0; i < sparse_dim; ++i) {
+      out_indices_data[i * out_nnz + index] = x_indices_data[i * x_nnz + j];
+    }
+    for (size_t ii = 0; ii < axes_size; ++ii) {
+      auto i = axes[ii];
+      out_indices_data[i * out_nnz + index] -= starts[ii];
     }
   }
 }
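For the values array, the rewritten kernel above performs a plain gather through out_nnz_indices; the per-axis offset adjustment of the coordinates is what still requires the custom kernel. As a point of reference only, and not part of this change, the value copy alone could be expressed with Thrust roughly as below, assuming raw device pointers and an out_nnz already known on the host; GatherValues is a hypothetical helper name.

#include <cuda_runtime.h>
#include <thrust/device_ptr.h>
#include <thrust/gather.h>
#include <thrust/system/cuda/execution_policy.h>

// Gather d_out_values[i] = d_x_values[d_out_nnz_indices[i]] for i in [0, out_nnz),
// enqueued on the given stream.
template <typename T>
void GatherValues(const int64_t* d_out_nnz_indices,
                  const T* d_x_values,
                  T* d_out_values,
                  int64_t out_nnz,
                  cudaStream_t stream) {
  thrust::device_ptr<const int64_t> map(d_out_nnz_indices);
  thrust::device_ptr<const T> src(d_x_values);
  thrust::device_ptr<T> dst(d_out_values);
  thrust::gather(thrust::cuda::par.on(stream), map, map + out_nnz, src, dst);
}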
@@ -113,6 +108,13 @@ void SliceCooKernel(const Context& dev_ctx,
   phi::backends::gpu::GpuMemsetAsync(
       d_out_nnz_ptr, 0, sizeof(int32_t), dev_ctx.stream());

+  // out_nnz_indices is the indices where the data is valid in out
+  // the length of the out_nnz_indices must be less than x.nnz()
+  DenseTensor d_out_nnz_indices = phi::Empty<int64_t>(dev_ctx, {x.nnz()});
+  int64_t* d_out_nnz_indices_ptr = d_out_nnz_indices.data<int64_t>();
+  phi::backends::gpu::GpuMemsetAsync(
+      d_out_nnz_indices_ptr, 0, sizeof(int64_t), dev_ctx.stream());
+
   // copy axes to device
   auto d_axes_tensor = memory_utils::Alloc(
       dev_ctx.GetPlace(),
@@ -164,14 +166,27 @@ void SliceCooKernel(const Context& dev_ctx,
                                          d_ends,
                                          axes.size(),
                                          x.nnz(),
-                                         d_out_nnz_ptr);
+                                         d_out_nnz_ptr,
+                                         d_out_nnz_indices_ptr);

+  // copy d_out_nnz from device to host (out_nnz)
   int32_t out_nnz = 0;
   phi::backends::gpu::GpuMemcpyAsync(&out_nnz,
                                      d_out_nnz_ptr,
                                      sizeof(int32_t),
                                      gpuMemcpyDeviceToHost,
                                      dev_ctx.stream());
+  // sort `d_out_nnz_indices_ptr`
+  d_out_nnz_indices.Resize({out_nnz});
+  thrust::device_vector<int64_t> d_out_nnz_indices_vec(
+      d_out_nnz_indices_ptr, d_out_nnz_indices_ptr + out_nnz);
+  thrust::sort(d_out_nnz_indices_vec.begin(), d_out_nnz_indices_vec.end());
+  phi::backends::gpu::GpuMemcpyAsync(
+      d_out_nnz_indices_ptr,
+      thrust::raw_pointer_cast(d_out_nnz_indices_vec.data()),
+      out_nnz * sizeof(int64_t),
+      gpuMemcpyDeviceToDevice,
+      dev_ctx.stream());

   // Step4: Get the values and indices of output
   auto sparse_dim = static_cast<int64_t>(x.sparse_dim());
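On the sorting step in the hunk above: the change rounds the recorded indices through a temporary thrust::device_vector and copies the sorted result back with a device-to-device memcpy. Under the assumption that sorting the raw buffer in place is acceptable, the same effect can be obtained with a stream-ordered in-place sort, sketched below; this is not what the change does, and SortHitIndicesInPlace is a hypothetical helper name. It assumes out_nnz is already valid on the host (i.e. the async copy above has completed).

#include <cuda_runtime.h>
#include <thrust/device_ptr.h>
#include <thrust/sort.h>
#include <thrust/system/cuda/execution_policy.h>

// Sort the first out_nnz recorded indices in place on the given stream,
// avoiding the temporary device_vector and the extra device-to-device copy.
void SortHitIndicesInPlace(int64_t* d_out_nnz_indices_ptr,
                           int out_nnz,
                           cudaStream_t stream) {
  thrust::device_ptr<int64_t> first(d_out_nnz_indices_ptr);
  thrust::sort(thrust::cuda::par.on(stream), first, first + out_nnz);
}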
@@ -184,18 +199,21 @@ void SliceCooKernel(const Context& dev_ctx,
   auto* out_values_data = out_values.data<T>();
   const auto* x_values_data = x.values().data<T>();

-  GetCooOutCudaKernel<T>
-      <<<1, 1, 0, dev_ctx.stream()>>>(x_indices_data,
-                                      x_values_data,
-                                      d_axes,
-                                      d_starts,
-                                      d_ends,
-                                      axes.size(),
-                                      sparse_dim,
-                                      x.nnz(),
-                                      static_cast<int64_t>(out_nnz),
-                                      out_indices_data,
-                                      out_values_data);
+  config = phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, out_nnz, 1);
+  GetCooOutCudaKernel<T><<<config.block_per_grid.x,
+                           config.thread_per_block.x,
+                           0,
+                           dev_ctx.stream()>>>(x_indices_data,
+                                               x_values_data,
+                                               d_axes,
+                                               d_starts,
+                                               axes.size(),
+                                               sparse_dim,
+                                               x.nnz(),
+                                               static_cast<int64_t>(out_nnz),
+                                               d_out_nnz_indices_ptr,
+                                               out_indices_data,
+                                               out_values_data);
 }

 __global__ void GetCsrNonZeroNumberCudaKernel(const int64_t* x_crows_data,