@@ -64,17 +64,16 @@ static __forceinline__ __device__ U WarpReduceSum(U val) {
64
64
}
65
65
66
66
template <typename U>
67
- __forceinline__ __device__ U BlockReduceSum (U val) {
68
- static __shared__ U shared[32 ];
67
+ __forceinline__ __device__ U BlockReduceSum (U val, U *shared) {
69
68
int lane = threadIdx .x % warpSize ;
70
69
int wid = threadIdx .x / warpSize ;
71
70
72
71
val = WarpReduceSum (val); // Each warp performs partial reduction
73
72
73
+ __syncthreads ();
74
74
if (lane == 0 ) shared[wid] = val; // Write reduced value to shared memory
75
75
76
76
__syncthreads (); // Wait for all partial reductions
77
-
78
77
// read from shared memory only if that warp existed
79
78
val =
80
79
(threadIdx .x < blockDim .x / warpSize ) ? shared[lane] : static_cast <U>(0 );
@@ -183,6 +182,8 @@ __global__ void LayerNormForward(const T *x, const U *scale, const U *bias,
183
182
int64_t feature_size) {
184
183
__shared__ U mean_share;
185
184
__shared__ U var_share;
185
+ __shared__ U shared_mean[32 ];
186
+ __shared__ U shared_var[32 ];
186
187
187
188
int64_t beg_idx = blockIdx .x * feature_size + threadIdx .x ;
188
189
int64_t end_idx = (blockIdx .x + 1 ) * feature_size;
@@ -196,8 +197,8 @@ __global__ void LayerNormForward(const T *x, const U *scale, const U *bias,
196
197
var_val += (tmp * tmp);
197
198
}
198
199
199
- mean_val = BlockReduceSum<U>(mean_val);
200
- var_val = BlockReduceSum<U>(var_val);
200
+ mean_val = BlockReduceSum<U>(mean_val, shared_mean );
201
+ var_val = BlockReduceSum<U>(var_val, shared_var );
201
202
202
203
if (threadIdx .x == 0 ) {
203
204
auto scale = static_cast <float >(1 .) / static_cast <float >(feature_size);
@@ -541,8 +542,10 @@ __global__ void LayerNormBackwardGradientAll(
541
542
}
542
543
}
543
544
544
- d_scale_partial = BlockReduceSum<U>(d_scale_partial);
545
- d_bias_partial = BlockReduceSum<U>(d_bias_partial);
545
+ __shared__ U shared_scale[32 ];
546
+ __shared__ U shared_bias[32 ];
547
+ d_scale_partial = BlockReduceSum<U>(d_scale_partial, shared_scale);
548
+ d_bias_partial = BlockReduceSum<U>(d_bias_partial, shared_bias);
546
549
547
550
if (threadIdx .x == 0 ) {
548
551
d_scale[blockIdx .x + col_offset] = d_scale_partial;
0 commit comments