Skip to content

Commit 7006838

Browse files
committed
[API] fix paddle.dist with big tensor
1 parent f51e3ff commit 7006838

File tree

8 files changed: +151 additions, -151 deletions (lines changed)

paddle/phi/kernels/cpu/p_norm_grad_kernel.cc

+12 -12
Original file line number | Diff line number | Diff line change
@@ -24,21 +24,21 @@ namespace phi {
2424

2525
inline void GetDims(const phi::DDim& dim,
2626
int axis,
27-
int* pre,
28-
int* n,
29-
int* post,
27+
int64_t* pre,
28+
int64_t* n,
29+
int64_t* post,
3030
bool asvector) {
3131
*pre = 1;
3232
*post = 1;
33-
*n = static_cast<int>(dim[axis]);
33+
*n = dim[axis];
3434
if (asvector) {
35-
*n = static_cast<int>(product(dim));
35+
*n = product(dim);
3636
} else {
3737
for (int i = 0; i < axis; ++i) {
38-
(*pre) *= static_cast<int>(dim[i]);
38+
(*pre) *= dim[i];
3939
}
4040
for (int i = axis + 1; i < dim.size(); ++i) {
41-
(*post) *= static_cast<int>(dim[i]);
41+
(*post) *= dim[i];
4242
}
4343
}
4444
}
@@ -64,10 +64,10 @@ void PNormGradKernel(const Context& dev_ctx,
6464
auto xdim = in_x->dims();
6565

6666
if (axis < 0) axis = xdim.size() + axis;
67-
int pre, n, post;
67+
int64_t pre, n, post;
6868
GetDims(xdim, axis, &pre, &n, &post, asvector);
69-
Eigen::DSizes<int, 3> shape(pre, n, post);
70-
Eigen::DSizes<int, 3> rshape(pre, 1, post);
69+
Eigen::DSizes<int64_t, 3> shape(pre, n, post);
70+
Eigen::DSizes<int64_t, 3> rshape(pre, 1, post);
7171

7272
auto* place = dev_ctx.eigen_device();
7373

@@ -81,8 +81,8 @@ void PNormGradKernel(const Context& dev_ctx,
8181
auto norm = norm_e.reshape(rshape);
8282
auto norm_dy = norm_dy_e.reshape(rshape);
8383

84-
Eigen::DSizes<int, 1> rdim(1);
85-
Eigen::DSizes<int, 3> bcast(1, n, 1);
84+
Eigen::DSizes<int64_t, 1> rdim(1);
85+
Eigen::DSizes<int64_t, 3> bcast(1, n, 1);
8686

8787
if (porder == 0) {
8888
phi::funcs::SetConstant<Context, T> set_zero;

paddle/phi/kernels/cpu/p_norm_kernel.cc

+11 -11
Original file line number | Diff line number | Diff line change
@@ -26,21 +26,21 @@ namespace phi {
2626

2727
inline void GetDims(const phi::DDim& dim,
2828
int axis,
29-
int* pre,
30-
int* n,
31-
int* post,
29+
int64_t* pre,
30+
int64_t* n,
31+
int64_t* post,
3232
bool asvector) {
3333
*pre = 1;
3434
*post = 1;
35-
*n = static_cast<int>(dim[axis]);
35+
*n = static_cast<int64_t>(dim[axis]);
3636
if (asvector) {
37-
*n = static_cast<int>(product(dim));
37+
*n = static_cast<int64_t>(product(dim));
3838
} else {
3939
for (int i = 0; i < axis; ++i) {
40-
(*pre) *= static_cast<int>(dim[i]);
40+
(*pre) *= static_cast<int64_t>(dim[i]);
4141
}
4242
for (int i = axis + 1; i < dim.size(); ++i) {
43-
(*post) *= static_cast<int>(dim[i]);
43+
(*post) *= static_cast<int64_t>(dim[i]);
4444
}
4545
}
4646
}
@@ -59,7 +59,7 @@ void PNormKernel(const Context& dev_ctx,
5959

6060
auto xdim = in_x->dims();
6161
if (axis < 0) axis = xdim.size() + axis;
62-
int pre = 0, n = 0, post = 0;
62+
int64_t pre = 0, n = 0, post = 0;
6363
GetDims(xdim, axis, &pre, &n, &post, asvector);
6464

6565
if (x.numel() == 0) {
@@ -73,8 +73,8 @@ void PNormKernel(const Context& dev_ctx,
7373

7474
auto* place = dev_ctx.eigen_device();
7575

76-
Eigen::DSizes<int, 3> shape(pre, n, post);
77-
Eigen::DSizes<int, 2> norm_shape(pre, post);
76+
Eigen::DSizes<int64_t, 3> shape(pre, n, post);
77+
Eigen::DSizes<int64_t, 2> norm_shape(pre, post);
7878

7979
auto x_e = phi::EigenVector<T>::Flatten(*in_x);
8080
auto norm_e = phi::EigenVector<T>::Flatten(*out);
@@ -86,7 +86,7 @@ void PNormKernel(const Context& dev_ctx,
8686
// p=inf means the maximum of |xr|
8787
// p=-inf means the minimum of |xr|
8888
// otherwise, Lp-norm = pow(sum(pow(|xr|, p)), 1/p)
89-
Eigen::DSizes<int, 1> rdim(1);
89+
Eigen::DSizes<int64_t, 1> rdim(1);
9090
if (porder == 0) {
9191
norm.device(*place) = (xr != xr.constant(0)).template cast<T>().sum(rdim);
9292
} else if (porder == INFINITY) {

paddle/phi/kernels/funcs/blas/blas_impl.h

+6 -6
Original file line number | Diff line number | Diff line change
@@ -78,31 +78,31 @@ struct CBlas<phi::dtype::bfloat16> {
7878
}
7979

8080
template <typename... ARGS>
81-
static void VADD(int n,
81+
static void VADD(int64_t n,
8282
const phi::dtype::bfloat16 *x,
8383
const phi::dtype::bfloat16 *y,
8484
phi::dtype::bfloat16 *z) {
85-
for (int i = 0; i < n; ++i) {
85+
for (int64_t i = 0; i < n; ++i) {
8686
z[i] = x[i] + y[i];
8787
}
8888
}
8989

9090
template <typename... ARGS>
91-
static void VMUL(int n,
91+
static void VMUL(int64_t n,
9292
const phi::dtype::bfloat16 *x,
9393
const phi::dtype::bfloat16 *y,
9494
phi::dtype::bfloat16 *z) {
95-
for (int i = 0; i < n; ++i) {
95+
for (int64_t i = 0; i < n; ++i) {
9696
z[i] = x[i] * y[i];
9797
}
9898
}
9999

100100
template <typename... ARGS>
101-
static void VSUB(int n,
101+
static void VSUB(int64_t n,
102102
const phi::dtype::bfloat16 *x,
103103
const phi::dtype::bfloat16 *y,
104104
phi::dtype::bfloat16 *z) {
105-
for (int i = 0; i < n; ++i) {
105+
for (int64_t i = 0; i < n; ++i) {
106106
z[i] = x[i] - y[i];
107107
}
108108
}

0 commit comments

Comments
 (0)