Skip to content

Commit 81ef31d

Browse files
authored
[PHI] Fix paddle.dist for big tensor (#73064)
1 parent 479b646 commit 81ef31d

File tree

1 file changed

+3
-11
lines changed

1 file changed

+3
-11
lines changed

paddle/phi/kernels/gpu/dist_kernel.cu

Lines changed: 3 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -67,12 +67,8 @@ __global__ void ReduceSumWithSubtract(
6767
const T* x, const T* y, T* out, int64_t N, Functor func) {
6868
using MT = typename phi::dtype::MPTypeTrait<T>::Type;
6969
MT sum_val(0.0);
70-
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < N;
71-
i += blockDim.x * gridDim.x) {
72-
sum_val += func(x[i], y[i]);
73-
}
70+
CUDA_KERNEL_LOOP_TYPE(i, N, int64_t) { sum_val += func(x[i], y[i]); }
7471

75-
__syncthreads();
7672
sum_val = phi::funcs::BlockReduceSum<MT>(sum_val, FULL_MASK);
7773
if (threadIdx.x == 0) {
7874
out[blockIdx.x] = static_cast<T>(sum_val);
@@ -86,12 +82,10 @@ __global__ void ReduceMaxWithSubtract(const T* x,
8682
int64_t N) {
8783
using MT = typename phi::dtype::MPTypeTrait<T>::Type;
8884
MT max_val = std::numeric_limits<MT>::min();
89-
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < N;
90-
i += blockDim.x * gridDim.x) {
85+
CUDA_KERNEL_LOOP_TYPE(i, N, int64_t) {
9186
max_val = max(max_val, abs(static_cast<MT>(x[i]) - static_cast<MT>(y[i])));
9287
}
9388

94-
__syncthreads();
9589
max_val = phi::funcs::BlockReduceMax<MT>(max_val, FULL_MASK);
9690
if (threadIdx.x == 0) {
9791
out[blockIdx.x] = static_cast<T>(max_val);
@@ -105,12 +99,10 @@ __global__ void ReduceMinWithSubtract(const T* x,
10599
int64_t N) {
106100
using MT = typename phi::dtype::MPTypeTrait<T>::Type;
107101
MT min_val = std::numeric_limits<MT>::max();
108-
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < N;
109-
i += blockDim.x * gridDim.x) {
102+
CUDA_KERNEL_LOOP_TYPE(i, N, int64_t) {
110103
min_val = min(min_val, abs(static_cast<MT>(x[i]) - static_cast<MT>(y[i])));
111104
}
112105

113-
__syncthreads();
114106
min_val = phi::funcs::BlockReduceMin<MT>(min_val, FULL_MASK);
115107
if (threadIdx.x == 0) {
116108
out[blockIdx.x] = static_cast<T>(min_val);

0 commit comments

Comments
 (0)