
Commit f54fb1e

fix stack grad gpu (#32781)
1 parent ded39f8 commit f54fb1e

1 file changed: 12 additions & 5 deletions


paddle/fluid/operators/stack_op.cu

@@ -96,9 +96,10 @@ class StackGPUKernel : public framework::OpKernel<T> {
 };
 
 template <typename T, typename IntType>
-__global__ void UnStackCUDAKernel(const T* __restrict__ input, int pre_dim_size,
-                                  int split_dim_size, int suf_dim_size,
-                                  int num_split, T** output_ptrs) {
+__global__ void UnStackHelperCUDAKernel(const T* __restrict__ input,
+                                        int pre_dim_size, int split_dim_size,
+                                        int suf_dim_size, int num_split,
+                                        T** output_ptrs) {
   assert(blockDim.y == 1);
   assert(blockDim.z == 1);
   // In this case they are equal
@@ -114,6 +115,9 @@ __global__ void UnStackCUDAKernel(const T* __restrict__ input, int pre_dim_size,
   IntType k = offset % suf_dim_size;
 
   T* output = output_ptrs[j / each_dim_size];
+  if (output == nullptr) {
+    return;
+  }
   IntType output_ind = i * each_dim_size * suf_dim_size +
                        (j % each_dim_size) * suf_dim_size + k;
   *(output + output_ind) = input[offset];
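
For context on the guard above: the kernel walks a flattened (pre_dim x split_dim x suf_dim) index space and scatters each gradient element into one of the per-input output tensors. When an input needs no gradient, its entry in output_ptrs is null, and the thread now bails out instead of dereferencing it. Below is a minimal standalone sketch of the same decomposition and guard; the kernel name, host setup, and shapes are illustrative, not Paddle's actual code.

#include <cstdio>
#include <cuda_runtime.h>

// Minimal sketch: scatter a stacked input of shape
// (pre_dim, split_dim, suf_dim) into num_split outputs,
// skipping outputs whose pointer is null (gradient not requested).
template <typename T, typename IntType>
__global__ void UnStackSketch(const T* __restrict__ input, IntType pre_dim,
                              IntType split_dim, IntType suf_dim,
                              IntType num_split, T** output_ptrs) {
  IntType size = pre_dim * split_dim * suf_dim;
  IntType each_dim = split_dim / num_split;
  for (IntType offset = blockIdx.x * blockDim.x + threadIdx.x; offset < size;
       offset += blockDim.x * gridDim.x) {
    IntType i = offset / (split_dim * suf_dim);  // index along pre dim
    IntType j = (offset / suf_dim) % split_dim;  // index along split dim
    IntType k = offset % suf_dim;                // index along suffix dim
    T* output = output_ptrs[j / each_dim];
    if (output == nullptr) continue;  // this slice's gradient is not needed
    IntType out_ind = i * each_dim * suf_dim + (j % each_dim) * suf_dim + k;
    output[out_ind] = input[offset];
  }
}

int main() {
  // Unstack a (2, 3, 4) tensor along the middle axis into 3 outputs,
  // materializing only outputs 0 and 2 (output 1 stays null).
  const int pre = 2, split = 3, suf = 4, n = pre * split * suf;
  float h_in[n];
  for (int t = 0; t < n; ++t) h_in[t] = static_cast<float>(t);

  float *d_in, *d_out0, *d_out2;
  cudaMalloc(&d_in, n * sizeof(float));
  cudaMalloc(&d_out0, pre * suf * sizeof(float));
  cudaMalloc(&d_out2, pre * suf * sizeof(float));
  cudaMemcpy(d_in, h_in, n * sizeof(float), cudaMemcpyHostToDevice);

  float* h_ptrs[3] = {d_out0, nullptr, d_out2};
  float** d_ptrs;
  cudaMalloc(&d_ptrs, 3 * sizeof(float*));
  cudaMemcpy(d_ptrs, h_ptrs, 3 * sizeof(float*), cudaMemcpyHostToDevice);

  UnStackSketch<float, int><<<1, 128>>>(d_in, pre, split, suf, 3, d_ptrs);
  cudaDeviceSynchronize();

  float h_out0[pre * suf];
  cudaMemcpy(h_out0, d_out0, sizeof(h_out0), cudaMemcpyDeviceToHost);
  printf("out0[0]=%g out0[4]=%g\n", h_out0[0], h_out0[4]);  // 0 and 12
  return 0;
}
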
@@ -142,6 +146,9 @@ class StackGradGPUKernel : public framework::OpKernel<T> {
     std::vector<T*> outputs(n);
     auto out_var_names = ctx.OutputNames(framework::GradVarName("X"));
     for (size_t j = 0; j < dx.size(); ++j) {
+      if (dx[j] == nullptr) {
+        outputs[j] = nullptr;
+      }
       if (out_var_names[j] != framework::kEmptyVarName &&
           dx[j]->numel() != 0UL) {
         T* ptr = dx[j]->mutable_data<T>(ctx.GetPlace());
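
On the host side, the fix above records nullptr in outputs[j] whenever dx[j] is absent, so the device-side guard can skip that slice. The pointer table is then staged to GPU memory before launch; here is a hedged sketch of that staging step using plain CUDA runtime calls (Paddle itself routes this through its own temporary allocation, tmp_out_data above).

#include <cuda_runtime.h>
#include <vector>

// Illustrative helper (not Paddle code): gather raw output pointers,
// keeping nullptr entries for gradients that are not requested, and
// upload the table so each GPU thread can look up its destination.
float** MakeDeviceOutputTable(const std::vector<float*>& outputs) {
  float** d_table = nullptr;
  cudaMalloc(&d_table, outputs.size() * sizeof(float*));
  cudaMemcpy(d_table, outputs.data(), outputs.size() * sizeof(float*),
             cudaMemcpyHostToDevice);
  return d_table;  // caller releases with cudaFree
}
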
@@ -170,13 +177,13 @@ class StackGradGPUKernel : public framework::OpKernel<T> {
     auto config = GetGpuLaunchConfig1D(dev_ctx, dy_pre * split_dim * dy_suf);
 
     if (dy->numel() < std::numeric_limits<int32_t>::max()) {
-      UnStackCUDAKernel<
+      UnStackHelperCUDAKernel<
           T, int32_t><<<config.block_per_grid.x, config.thread_per_block.x, 0,
                         dev_ctx.stream()>>>(
           dy_data, dy_pre, split_dim, dy_suf, split_dim,
           reinterpret_cast<T**>(tmp_out_data->ptr()));
     } else {
-      UnStackCUDAKernel<
+      UnStackHelperCUDAKernel<
           T, int64_t><<<config.block_per_grid.x, config.thread_per_block.x, 0,
                         dev_ctx.stream()>>>(
           dy_data, dy_pre, split_dim, dy_suf, split_dim,
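
At the launch sites only the kernel name changes; the int32_t/int64_t split is preexisting: indices stay 32-bit while dy->numel() is below std::numeric_limits<int32_t>::max(), since 32-bit index arithmetic is cheaper on the GPU, and the int64_t instantiation covers larger tensors. A generic sketch of that dispatch pattern, with a hypothetical fill kernel standing in for the unstack kernel:

#include <cstdint>
#include <cuda_runtime.h>
#include <limits>

// Illustrative kernel: the same body compiled for two index types.
template <typename T, typename IntType>
__global__ void FillSketch(T* data, IntType n, T value) {
  for (IntType i = blockIdx.x * blockDim.x + threadIdx.x; i < n;
       i += blockDim.x * gridDim.x) {
    data[i] = value;
  }
}

// Pick the narrowest index type that can address every element.
template <typename T>
void FillDispatch(T* data, int64_t n, T value, cudaStream_t stream) {
  const int threads = 256;
  const int blocks = 64;
  if (n < std::numeric_limits<int32_t>::max()) {
    FillSketch<T, int32_t><<<blocks, threads, 0, stream>>>(
        data, static_cast<int32_t>(n), value);
  } else {
    FillSketch<T, int64_t><<<blocks, threads, 0, stream>>>(data, n, value);
  }
}
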
