@@ -96,9 +96,10 @@ class StackGPUKernel : public framework::OpKernel<T> {
 };

 template <typename T, typename IntType>
-__global__ void UnStackCUDAKernel(const T* __restrict__ input, int pre_dim_size,
-                                  int split_dim_size, int suf_dim_size,
-                                  int num_split, T** output_ptrs) {
+__global__ void UnStackHelperCUDAKernel(const T* __restrict__ input,
+                                        int pre_dim_size, int split_dim_size,
+                                        int suf_dim_size, int num_split,
+                                        T** output_ptrs) {
   assert(blockDim.y == 1);
   assert(blockDim.z == 1);
   // In this case they are equal
@@ -114,6 +115,9 @@ __global__ void UnStackCUDAKernel(const T* __restrict__ input, int pre_dim_size,
     IntType k = offset % suf_dim_size;

     T* output = output_ptrs[j / each_dim_size];
+    if (output == nullptr) {
+      return;
+    }
     IntType output_ind = i * each_dim_size * suf_dim_size +
                          (j % each_dim_size) * suf_dim_size + k;
     *(output + output_ind) = input[offset];
@@ -142,6 +146,9 @@ class StackGradGPUKernel : public framework::OpKernel<T> {
     std::vector<T*> outputs(n);
     auto out_var_names = ctx.OutputNames(framework::GradVarName("X"));
     for (size_t j = 0; j < dx.size(); ++j) {
+      if (dx[j] == nullptr) {
+        outputs[j] = nullptr;
+      }
       if (out_var_names[j] != framework::kEmptyVarName &&
           dx[j]->numel() != 0UL) {
         T* ptr = dx[j]->mutable_data<T>(ctx.GetPlace());
@@ -170,13 +177,13 @@ class StackGradGPUKernel : public framework::OpKernel<T> {
     auto config = GetGpuLaunchConfig1D(dev_ctx, dy_pre * split_dim * dy_suf);

     if (dy->numel() < std::numeric_limits<int32_t>::max()) {
-      UnStackCUDAKernel<
+      UnStackHelperCUDAKernel<
           T, int32_t><<<config.block_per_grid.x, config.thread_per_block.x, 0,
                         dev_ctx.stream()>>>(
           dy_data, dy_pre, split_dim, dy_suf, split_dim,
           reinterpret_cast<T**>(tmp_out_data->ptr()));
     } else {
-      UnStackCUDAKernel<
+      UnStackHelperCUDAKernel<
           T, int64_t><<<config.block_per_grid.x, config.thread_per_block.x, 0,
                         dev_ctx.stream()>>>(
           dy_data, dy_pre, split_dim, dy_suf, split_dim,
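For context: the grad kernel addresses its outputs through a device-side array of raw slice pointers, so if a gradient output is not materialized its slot is nullptr, and an unguarded dereference crashes the kernel. Below is a minimal, self-contained sketch of that guard pattern, not the Paddle source: the kernel name, the main() harness, and the sizes are all illustrative assumptions, and the sketch uses continue rather than the patch's return so a grid-stride thread still processes later, non-null slices.

// Sketch (assumed names, not Paddle code): unstack a stacked tensor into
// per-slice output buffers addressed via a device array of pointers,
// skipping entries the host marked nullptr.
#include <cstdio>
#include <vector>
#include <cuda_runtime.h>

template <typename T, typename IntType>
__global__ void GuardedUnStackKernel(const T* __restrict__ input,
                                     IntType pre_dim_size,
                                     IntType split_dim_size,
                                     IntType suf_dim_size, IntType num_split,
                                     T** output_ptrs) {
  IntType size = pre_dim_size * split_dim_size * suf_dim_size;
  IntType each_dim_size = split_dim_size / num_split;
  // Grid-stride loop over every element of the stacked input.
  for (IntType offset = blockIdx.x * blockDim.x + threadIdx.x; offset < size;
       offset += blockDim.x * gridDim.x) {
    IntType i = offset / (split_dim_size * suf_dim_size);
    IntType j = (offset % (split_dim_size * suf_dim_size)) / suf_dim_size;
    IntType k = offset % suf_dim_size;
    T* output = output_ptrs[j / each_dim_size];
    if (output == nullptr) {
      continue;  // the guard: never dereference a missing output slice
    }
    IntType output_ind = i * each_dim_size * suf_dim_size +
                         (j % each_dim_size) * suf_dim_size + k;
    output[output_ind] = input[offset];
  }
}

int main() {
  // Unstack a 2 x 3 x 4 tensor along dim 1 into 3 slices of 2 x 4,
  // with the middle slice deliberately missing (nullptr).
  const int pre = 2, split = 3, suf = 4, n = split;
  const int total = pre * split * suf, slice = pre * suf;

  float* d_in;
  cudaMalloc(&d_in, total * sizeof(float));
  std::vector<float> h_in(total);
  for (int t = 0; t < total; ++t) h_in[t] = static_cast<float>(t);
  cudaMemcpy(d_in, h_in.data(), total * sizeof(float), cudaMemcpyHostToDevice);

  std::vector<float*> h_outs(n, nullptr);
  cudaMalloc(&h_outs[0], slice * sizeof(float));
  cudaMalloc(&h_outs[2], slice * sizeof(float));  // h_outs[1] stays nullptr

  float** d_outs;
  cudaMalloc(&d_outs, n * sizeof(float*));
  cudaMemcpy(d_outs, h_outs.data(), n * sizeof(float*),
             cudaMemcpyHostToDevice);

  GuardedUnStackKernel<float, int><<<1, 128>>>(d_in, pre, split, suf, n,
                                               d_outs);
  cudaDeviceSynchronize();

  std::vector<float> h_out(slice);
  cudaMemcpy(h_out.data(), h_outs[2], slice * sizeof(float),
             cudaMemcpyDeviceToHost);
  printf("slice 2, first row: %g %g %g %g\n", h_out[0], h_out[1], h_out[2],
         h_out[3]);  // expect 8 9 10 11
  return 0;
}

The host-side half of the patch mirrors the same idea: the loop over dx marks missing gradient outputs as nullptr in the pointer table before it is copied to the device, so the kernel-side guard has a well-defined sentinel to test against.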