Commit cdeffff

Fix GPT-2 train loss NaN problem by adding a __syncthreads() in BlockReduceSum (#33659)
1 parent 18043ab commit cdeffff

File tree

3 files changed: +12 -7 lines

paddle/fluid/operators/correlation_op.cu (+1)

@@ -42,6 +42,7 @@ __forceinline__ __device__ T blockReduceSum(T val) {
   int wid = threadIdx.x / warpSize;

   val = warpReduceSum(val);
+  __syncthreads();
   if (lane == 0) shared[wid] = val;

   __syncthreads();
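The one-line fix in correlation_op.cu (and the matching one in math_cuda_utils.h below) is easiest to read against a full picture of the reduction helper. Each of these helpers stages per-warp partial sums in a small shared-memory buffer; the layer_norm version below shows it as a `static __shared__ U shared[32]`, and the sketch here, which uses illustrative names and a shuffle-down warp reduction rather than Paddle's exact code, assumes the same shape. Because the buffer is reused on every call within a kernel, a warp that races ahead into a second reduction can overwrite `shared[wid]` while slower warps are still reading the first reduction's partials; that is the kind of race that can intermittently turn the GPT-2 training loss into NaN. The added `__syncthreads()` before the write makes every thread finish with the old contents first.

// Minimal sketch of the reduction pattern shared by correlation_op.cu and
// math_cuda_utils.h; helper names and shuffle details are illustrative.
template <typename T>
__forceinline__ __device__ T warpReduceSum(T val) {
  // Reduce within one warp via shuffle-down (full-warp mask assumed here).
  for (int offset = warpSize / 2; offset > 0; offset >>= 1)
    val += __shfl_down_sync(0xffffffffu, val, offset);
  return val;
}

template <typename T>
__forceinline__ __device__ T blockReduceSum(T val) {
  static __shared__ T shared[32];    // one slot per warp, reused on every call
  int lane = threadIdx.x % warpSize;
  int wid = threadIdx.x / warpSize;

  val = warpReduceSum(val);          // each warp produces a partial sum

  __syncthreads();                   // the fix: no thread may still be reading
                                     // shared[] from a previous reduction
  if (lane == 0) shared[wid] = val;  // publish this warp's partial sum
  __syncthreads();                   // wait for all partial sums

  // warp 0 combines the per-warp partials into the block-wide sum
  val = (threadIdx.x < blockDim.x / warpSize) ? shared[lane] : static_cast<T>(0);
  if (wid == 0) val = warpReduceSum(val);
  return val;
}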

paddle/fluid/operators/layer_norm_op.cu (+10 -7)

@@ -64,17 +64,16 @@ static __forceinline__ __device__ U WarpReduceSum(U val) {
 }

 template <typename U>
-__forceinline__ __device__ U BlockReduceSum(U val) {
-  static __shared__ U shared[32];
+__forceinline__ __device__ U BlockReduceSum(U val, U *shared) {
   int lane = threadIdx.x % warpSize;
   int wid = threadIdx.x / warpSize;

   val = WarpReduceSum(val);  // Each warp performs partial reduction

+  __syncthreads();
   if (lane == 0) shared[wid] = val;  // Write reduced value to shared memory

   __syncthreads();  // Wait for all partial reductions
-
   // read from shared memory only if that warp existed
   val =
       (threadIdx.x < blockDim.x / warpSize) ? shared[lane] : static_cast<U>(0);

@@ -183,6 +182,8 @@ __global__ void LayerNormForward(const T *x, const U *scale, const U *bias,
                                  int64_t feature_size) {
   __shared__ U mean_share;
   __shared__ U var_share;
+  __shared__ U shared_mean[32];
+  __shared__ U shared_var[32];

   int64_t beg_idx = blockIdx.x * feature_size + threadIdx.x;
   int64_t end_idx = (blockIdx.x + 1) * feature_size;

@@ -196,8 +197,8 @@ __global__ void LayerNormForward(const T *x, const U *scale, const U *bias,
     var_val += (tmp * tmp);
   }

-  mean_val = BlockReduceSum<U>(mean_val);
-  var_val = BlockReduceSum<U>(var_val);
+  mean_val = BlockReduceSum<U>(mean_val, shared_mean);
+  var_val = BlockReduceSum<U>(var_val, shared_var);

   if (threadIdx.x == 0) {
     auto scale = static_cast<float>(1.) / static_cast<float>(feature_size);

@@ -541,8 +542,10 @@ __global__ void LayerNormBackwardGradientAll(
     }
   }

-  d_scale_partial = BlockReduceSum<U>(d_scale_partial);
-  d_bias_partial = BlockReduceSum<U>(d_bias_partial);
+  __shared__ U shared_scale[32];
+  __shared__ U shared_bias[32];
+  d_scale_partial = BlockReduceSum<U>(d_scale_partial, shared_scale);
+  d_bias_partial = BlockReduceSum<U>(d_bias_partial, shared_bias);

   if (threadIdx.x == 0) {
     d_scale[blockIdx.x + col_offset] = d_scale_partial;
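layer_norm_op.cu goes one step further than the barrier: the staging buffer is no longer a `static __shared__` array hidden inside BlockReduceSum but is passed in by the caller, so consecutive reductions (mean and variance in the forward pass, scale and bias gradients in the backward pass) each get their own 32-slot buffer and can no longer clobber one another. Below is a minimal sketch of that interface with a toy caller in place of Paddle's LayerNormForward kernel; MeanVarKernel and its arguments are illustrative, and the warp reduction is assumed to be the same shuffle-down pattern as in the sketch above.

template <typename U>
__forceinline__ __device__ U WarpReduceSum(U val) {
  // Same shuffle-down warp reduction as in the sketch above.
  for (int offset = warpSize / 2; offset > 0; offset >>= 1)
    val += __shfl_down_sync(0xffffffffu, val, offset);
  return val;
}

template <typename U>
__forceinline__ __device__ U BlockReduceSum(U val, U *shared) {
  int lane = threadIdx.x % warpSize;
  int wid = threadIdx.x / warpSize;

  val = WarpReduceSum(val);          // each warp reduces its own values

  __syncthreads();                   // don't restage while another warp may
                                     // still be reading an earlier result
  if (lane == 0) shared[wid] = val;  // one partial sum per warp
  __syncthreads();

  val = (threadIdx.x < blockDim.x / warpSize) ? shared[lane] : static_cast<U>(0);
  if (wid == 0) val = WarpReduceSum(val);  // warp 0 combines the partials
  return val;
}

// Toy caller with the same shape as the LayerNormForward change: the two
// back-to-back reductions use separate caller-owned staging buffers.
__global__ void MeanVarKernel(const float *x, float *mean, float *var,
                              int feature_size) {
  __shared__ float shared_mean[32];
  __shared__ float shared_var[32];

  float sum = 0.f, sq_sum = 0.f;
  for (int i = threadIdx.x; i < feature_size; i += blockDim.x) {
    float v = x[blockIdx.x * feature_size + i];
    sum += v;
    sq_sum += v * v;
  }

  sum = BlockReduceSum<float>(sum, shared_mean);
  sq_sum = BlockReduceSum<float>(sq_sum, shared_var);

  if (threadIdx.x == 0) {
    float m = sum / feature_size;
    mean[blockIdx.x] = m;
    var[blockIdx.x] = sq_sum / feature_size - m * m;
  }
}

Passing the buffer in makes the lifetime of the staging memory explicit at the call site, which is why the backward-pass hunk above can simply declare shared_scale and shared_bias right next to the two reductions that use them.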

paddle/fluid/operators/math/math_cuda_utils.h (+1)

@@ -188,6 +188,7 @@ __inline__ __device__ T blockReduceSum(T val, unsigned mask) {

   val = warpReduceSum<T>(val, mask);

+  __syncthreads();
   if (lane == 0) shared[wid] = val;

   __syncthreads();
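The variant in math_cuda_utils.h differs from the correlation_op.cu helper mainly in that it threads an explicit warp mask through to the warp-level reduction. A minimal sketch of that warp step, assuming a shuffle-down reduction rather than Paddle's exact shuffle pattern:

template <typename T>
__inline__ __device__ T warpReduceSum(T val, unsigned mask) {
  // The caller's mask names the lanes expected to participate in the shuffle.
  for (int offset = warpSize / 2; offset > 0; offset >>= 1)
    val += __shfl_down_sync(mask, val, offset);
  return val;
}

The added __syncthreads() plays the same role as in the other two files: the block-wide barrier keeps one warp from restaging shared[wid] while another is still reading the previous reduction's result; the mask only governs the warp-level shuffle.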
