From 6f80ea0bbcdc391459621e8f730a9dfd56057d96 Mon Sep 17 00:00:00 2001 From: lilujia Date: Tue, 22 Apr 2025 17:05:27 +0800 Subject: [PATCH] [XPU] fix index's datatype, using int64 instead of int, part 1 (a-f) --- paddle/phi/kernels/funcs/norm_utils.h | 26 +++ .../phi/kernels/xpu/activation_grad_kernel.cc | 8 +- paddle/phi/kernels/xpu/activation_kernel.cc | 10 +- paddle/phi/kernels/xpu/add_n_kernel.cc | 2 +- .../kernels/xpu/affine_channel_grad_kernel.cc | 7 +- .../phi/kernels/xpu/affine_channel_kernel.cc | 7 +- paddle/phi/kernels/xpu/amp_kernel.cc | 2 +- paddle/phi/kernels/xpu/arg_min_max_kernel.cc | 8 +- paddle/phi/kernels/xpu/argsort_grad_kernel.cc | 16 +- paddle/phi/kernels/xpu/argsort_kernel.cc | 179 +++--------------- .../phi/kernels/xpu/batch_norm_grad_kernel.cc | 10 +- paddle/phi/kernels/xpu/batch_norm_kernel.cc | 2 +- paddle/phi/kernels/xpu/bmm_xpu_utils.h | 8 +- paddle/phi/kernels/xpu/c_concat_kernel.cc | 4 +- paddle/phi/kernels/xpu/c_embedding_kernel.cc | 6 - paddle/phi/kernels/xpu/clip_by_norm_kernel.cc | 4 +- .../kernels/xpu/concat_and_split_functor.cc | 10 +- paddle/phi/kernels/xpu/concat_grad_kernel.cc | 6 +- paddle/phi/kernels/xpu/concat_kernel.cc | 4 +- paddle/phi/kernels/xpu/conv_grad_kernel.cc | 2 +- paddle/phi/kernels/xpu/conv_kernel.cc | 6 +- .../kernels/xpu/conv_transpose_grad_kernel.cc | 36 ++-- .../phi/kernels/xpu/conv_transpose_kernel.cc | 57 +++--- .../kernels/xpu/cross_entropy_grad_kernel.cc | 10 +- .../phi/kernels/xpu/cross_entropy_kernel.cc | 2 +- paddle/phi/kernels/xpu/cum_kernel.cc | 4 +- .../xpu/deformable_conv_grad_kernel.cc | 45 +++-- .../phi/kernels/xpu/deformable_conv_kernel.cc | 39 ++-- paddle/phi/kernels/xpu/diag_kernel.cc | 6 +- .../xpu/distribute_fpn_proposals_kernel.cc | 2 +- paddle/phi/kernels/xpu/elementwise.h | 2 +- .../phi/kernels/xpu/elementwise_add_kernel.cc | 6 +- .../phi/kernels/xpu/embedding_grad_kernel.cc | 12 +- paddle/phi/kernels/xpu/embedding_kernel.cc | 13 +- paddle/phi/kernels/xpu/expand_as_kernel.cc | 5 +- paddle/phi/kernels/xpu/expand_kernel.cc | 2 +- .../phi/kernels/xpu/flash_attn_grad_kernel.cc | 10 +- paddle/phi/kernels/xpu/flash_attn_kernel.cc | 20 +- paddle/phi/kernels/xpu/full_kernel.cc | 2 +- .../xpu/fused_attention_grad_kernel.cc | 14 +- .../phi/kernels/xpu/fused_attention_kernel.cc | 40 ++-- paddle/phi/kernels/xpu/gather_grad_kernel.cc | 2 +- paddle/phi/kernels/xpu/gather_kernel.cc | 2 +- .../phi/kernels/xpu/gather_nd_grad_kernel.cc | 2 +- paddle/phi/kernels/xpu/gather_nd_kernel.cc | 2 +- .../phi/kernels/xpu/group_norm_grad_kernel.cc | 2 +- .../phi/kernels/xpu/index_put_grad_kernel.cc | 2 +- .../kernels/xpu/interpolate_grad_kernel.cc | 2 +- .../phi/kernels/xpu/kldiv_loss_grad_kernel.cc | 2 +- paddle/phi/kernels/xpu/kldiv_loss_kernel.cc | 2 +- paddle/phi/kernels/xpu/logical_kernel.cc | 2 +- paddle/phi/kernels/xpu/reduce.h | 2 +- .../kernels/xpu/scatter_nd_add_grad_kernel.cc | 2 +- .../phi/kernels/xpu/set_value_grad_kernel.cc | 2 +- paddle/phi/kernels/xpu/set_value_kernel.cc | 2 +- paddle/phi/kernels/xpu/softmax_kernel.cc | 2 +- .../phi/kernels/xpu/take_along_axis_kernel.cc | 2 +- paddle/phi/kernels/xpu/tile_kernel.cc | 2 +- paddle/phi/kernels/xpu/unique_kernel.cc | 4 +- paddle/phi/kernels/xpu/xpu_api_wrapper.h | 114 +++++------ .../kernels/xpu/xpu_fused_common_function.h | 6 +- 61 files changed, 347 insertions(+), 465 deletions(-) diff --git a/paddle/phi/kernels/funcs/norm_utils.h b/paddle/phi/kernels/funcs/norm_utils.h index c3a3b07ae08cc..204255ba17542 100644 --- a/paddle/phi/kernels/funcs/norm_utils.h +++ 
b/paddle/phi/kernels/funcs/norm_utils.h @@ -46,5 +46,31 @@ inline void ExtractNCWHD(const phi::DDim &dims, : 1; } } + +inline void ExtractNCWHD(const phi::DDim &dims, + const DataLayout &data_layout, + int64_t *N, + int64_t *C, + int64_t *H, + int64_t *W, + int64_t *D) { + *N = dims[0]; + if (dims.size() == 2) { + *C = dims[1]; + *H = 1; + *W = 1; + *D = 1; + } else { + *C = data_layout == DataLayout::kNCHW ? dims[1] : dims[dims.size() - 1]; + *H = data_layout == DataLayout::kNCHW ? dims[2] : dims[1]; + *W = dims.size() > 3 + ? (data_layout == DataLayout::kNCHW ? dims[3] : dims[2]) + : 1; + *D = dims.size() > 4 + ? (data_layout == DataLayout::kNCHW ? dims[4] : dims[3]) + : 1; + } +} + } // namespace funcs } // namespace phi diff --git a/paddle/phi/kernels/xpu/activation_grad_kernel.cc b/paddle/phi/kernels/xpu/activation_grad_kernel.cc index 9793cabc07b75..58ec9f2d7e962 100644 --- a/paddle/phi/kernels/xpu/activation_grad_kernel.cc +++ b/paddle/phi/kernels/xpu/activation_grad_kernel.cc @@ -144,7 +144,7 @@ int xpu_activation_backward(const Context& dev_ctx, const XPUType*, const XPUType*, XPUType*, - int)> func) { + int64_t)> func) { /* TODO: relu tanh sigmoid are inplace */ const XPUType* x_data = nullptr; const XPUType* y_data = nullptr; @@ -446,9 +446,9 @@ void PowGradKernel(const Context& dev_ctx, T* x_grad = dx->data(); // check dims: all dims should equal - auto x_dims = common::vectorize(x.dims()); - auto dy_dims = common::vectorize(dout.dims()); - auto dx_dims = common::vectorize(dx->dims()); + auto x_dims = common::vectorize(x.dims()); + auto dy_dims = common::vectorize(dout.dims()); + auto dx_dims = common::vectorize(dx->dims()); PADDLE_ENFORCE_EQ(x_dims, dy_dims, errors::PreconditionNotMet("x_dims should match dy_dims.")); diff --git a/paddle/phi/kernels/xpu/activation_kernel.cc b/paddle/phi/kernels/xpu/activation_kernel.cc index 9e23e2ba36fb2..c8617ceb428eb 100644 --- a/paddle/phi/kernels/xpu/activation_kernel.cc +++ b/paddle/phi/kernels/xpu/activation_kernel.cc @@ -72,7 +72,7 @@ int xpu_activation_func( const Context& dev_ctx, const DenseTensor& x, DenseTensor* out, - std::function func) { + std::function func) { int r = func(dev_ctx.x_context(), reinterpret_cast(x.data()), reinterpret_cast(out->data()), @@ -85,8 +85,8 @@ int xpu_activation_func_with_max_x_y( const Context& dev_ctx, const DenseTensor& x, DenseTensor* out, - std::function< - int(xpu::Context*, const XPUType*, XPUType*, int, const float*, float*)> + std::function func) { // does not support "const float* max_x, float* max_y" now int r = func(dev_ctx.x_context(), @@ -106,7 +106,7 @@ int xpu_activation_1attr_func(const Context& dev_ctx, std::function func) { @@ -130,7 +130,7 @@ int xpu_activation_2attr_func(const Context& dev_ctx, std::function(out->at(j).data())); // int sum(Context* ctx, const std::vector& x_list, T* - // y, int len); + // y, int64_t len); int r = xpu::sum(dev_ctx.x_context(), ptrs, reinterpret_cast(out->at(j).data()), diff --git a/paddle/phi/kernels/xpu/affine_channel_grad_kernel.cc b/paddle/phi/kernels/xpu/affine_channel_grad_kernel.cc index 146e820722a17..c7c4fe5a6daff 100644 --- a/paddle/phi/kernels/xpu/affine_channel_grad_kernel.cc +++ b/paddle/phi/kernels/xpu/affine_channel_grad_kernel.cc @@ -43,9 +43,10 @@ void AffineChannelGradXPUKernel(const Context& dev_ctx, const phi::DataLayout layout = common::StringToDataLayout(data_layout); auto dims = x->dims(); - int N = dims[0]; - int C = layout == phi::DataLayout::kNCHW ? 
dims[1] : dims[dims.size() - 1]; - int HxW = x->numel() / N / C; + int64_t N = dims[0]; + int64_t C = + (layout == phi::DataLayout::kNCHW) ? dims[1] : dims[dims.size() - 1]; + int64_t HxW = x->numel() / N / C; auto* dy_d = dy->data(); auto* scale_d = scale->data(); diff --git a/paddle/phi/kernels/xpu/affine_channel_kernel.cc b/paddle/phi/kernels/xpu/affine_channel_kernel.cc index 5674b1093a68e..a149fab405a82 100644 --- a/paddle/phi/kernels/xpu/affine_channel_kernel.cc +++ b/paddle/phi/kernels/xpu/affine_channel_kernel.cc @@ -39,9 +39,10 @@ void AffineChannelXPUKernel(const Context& dev_ctx, const phi::DataLayout layout = common::StringToDataLayout(data_layout); auto dims = x->dims(); - int N = dims[0]; - int C = layout == phi::DataLayout::kNCHW ? dims[1] : dims[dims.size() - 1]; - int HxW = x->numel() / N / C; + int64_t N = dims[0]; + int64_t C = + layout == phi::DataLayout::kNCHW ? dims[1] : dims[dims.size() - 1]; + int64_t HxW = x->numel() / N / C; auto* scale_d = scale->data(); auto* bias_d = bias->data(); diff --git a/paddle/phi/kernels/xpu/amp_kernel.cc b/paddle/phi/kernels/xpu/amp_kernel.cc index 110e91e0db06a..23fd3709144fe 100644 --- a/paddle/phi/kernels/xpu/amp_kernel.cc +++ b/paddle/phi/kernels/xpu/amp_kernel.cc @@ -67,7 +67,7 @@ void UpdateLossScalingKernel(const Context& dev_ctx, for (size_t i = 0; i < xs.size(); ++i) { auto* out = outs[i]; T* out_data = dev_ctx.template Alloc(out); - int num = out->numel(); + int64_t num = out->numel(); if (cpu_found_inf_data) { VLOG(1) << "-- UpdateLossScaling: Find infinite grads. --"; int r = 0; diff --git a/paddle/phi/kernels/xpu/arg_min_max_kernel.cc b/paddle/phi/kernels/xpu/arg_min_max_kernel.cc index 344431858fd34..418b42e734fbd 100644 --- a/paddle/phi/kernels/xpu/arg_min_max_kernel.cc +++ b/paddle/phi/kernels/xpu/arg_min_max_kernel.cc @@ -49,7 +49,7 @@ void ArgMaxKernel(const Context& dev_ctx, dtype)); // TODO(ZHUI): fix dtype of out DDim x_dims; - int axis_val = axis.to(); + int64_t axis_val = axis.to(); if (flatten) { x_dims = common::make_ddim({x.numel()}); // if flatten, the axis just as 0 @@ -58,7 +58,7 @@ void ArgMaxKernel(const Context& dev_ctx, x_dims = x.dims(); if (axis_val < 0) axis_val += x_dims.size(); } - auto xdims_vec = common::vectorize(x_dims); + auto xdims_vec = common::vectorize(x_dims); if (dtype != DataType::INT32) { dev_ctx.template Alloc(out); if (x.dims().size() == 0) { @@ -130,7 +130,7 @@ void ArgMinKernel(const Context& dev_ctx, dtype)); DDim x_dims; - int axis_val = axis.to(); + int64_t axis_val = axis.to(); if (flatten) { x_dims = common::make_ddim({x.numel()}); // If flatten, the axis just as 0 @@ -139,7 +139,7 @@ void ArgMinKernel(const Context& dev_ctx, x_dims = x.dims(); if (axis_val < 0) axis_val += x_dims.size(); } - auto xdims_vec = common::vectorize(x_dims); + auto xdims_vec = common::vectorize(x_dims); if (dtype != DataType::INT32) { dev_ctx.template Alloc(out); if (x.dims().size() == 0) { diff --git a/paddle/phi/kernels/xpu/argsort_grad_kernel.cc b/paddle/phi/kernels/xpu/argsort_grad_kernel.cc index 3e1ef0c0d15d3..c71a266fbda47 100644 --- a/paddle/phi/kernels/xpu/argsort_grad_kernel.cc +++ b/paddle/phi/kernels/xpu/argsort_grad_kernel.cc @@ -51,15 +51,15 @@ void ArgsortGradKernel(const Context& dev_ctx, if (axis == -1 || axis + 1 == in_dims.size()) { is_need_transpose = false; } - int len_before = common::product(common::slice_ddim(in_dims, 0, axis)); - int len_after = + auto len_before = common::product(common::slice_ddim(in_dims, 0, axis)); + auto len_after = 
common::product(common::slice_ddim(in_dims, axis + 1, in_dims.size())); - int m = len_before * len_after; - int n = in_dims[axis]; - int len = m * n; - std::vector permute_vec{0, 2, 1}; - std::vector data_shape{len_before, n, len_after}; - std::vector data_shape_trans{len_before, len_after, n}; + auto m = len_before * len_after; + auto n = in_dims[axis]; + auto len = m * n; + std::vector permute_vec{0, 2, 1}; + std::vector data_shape{len_before, n, len_after}; + std::vector data_shape_trans{len_before, len_after, n}; const int64_t* indices_data = indices.data(); const T* out_grad_data = out_grad.data(); diff --git a/paddle/phi/kernels/xpu/argsort_kernel.cc b/paddle/phi/kernels/xpu/argsort_kernel.cc index 7b221cff91d03..892bd000d7ea8 100644 --- a/paddle/phi/kernels/xpu/argsort_kernel.cc +++ b/paddle/phi/kernels/xpu/argsort_kernel.cc @@ -26,54 +26,47 @@ static inline void xpu_argsort(xpu::Context* ctx, const T* input_data, T* output_data, TID* indices_data, - int m, - int n, + int64_t m, + int64_t n, bool descending, bool stable) { int ret; if (stable) { ret = xpu::stable_sort( ctx, input_data, output_data, indices_data, m, n, descending); + PADDLE_ENFORCE_XDNN_SUCCESS(ret, "stable_sort"); } else { ret = xpu::sort(ctx, input_data, output_data, indices_data, m, n, descending); + PADDLE_ENFORCE_XDNN_SUCCESS(ret, "sort"); } - PADDLE_ENFORCE_XDNN_SUCCESS(ret, "sort"); } template static inline void xpu_transpose(xpu::Context* ctx, const T* x, T* y, - const std::vector& xshape, - const std::vector& permute) { + const std::vector& xshape, + const std::vector& permute) { int ret = xpu::transpose(ctx, x, y, xshape, permute); PADDLE_ENFORCE_XDNN_SUCCESS(ret, "transpose"); } -template -static inline void xpu_cast(xpu::Context* ctx, const TX* x, TY* y, int len) { - int ret = xpu::cast(ctx, x, y, len); - PADDLE_ENFORCE_XDNN_SUCCESS(ret, "cast"); -} - -template +template struct XPUArgsort { void operator()(xpu::Context* ctx, const T* input_data, T* output_data, int64_t* indices_data, - const std::vector& data_shape, - const std::vector& permute, + const std::vector& data_shape, + const std::vector& permute, bool descending, bool stable) { xpu::ctx_guard RAII_GUARD(ctx); - int m = data_shape[0] * data_shape[2]; - int n = data_shape[1]; - int len = data_shape[0] * data_shape[1] * data_shape[2]; - std::vector trans_data_shape{ + int64_t m = data_shape[0] * data_shape[2]; + int64_t n = data_shape[1]; + int64_t len = data_shape[0] * data_shape[1] * data_shape[2]; + std::vector trans_data_shape{ data_shape[0], data_shape[2], data_shape[1]}; T* input_data_trans = RAII_GUARD.alloc_l3_or_gm(len); @@ -96,87 +89,6 @@ struct XPUArgsort { } }; -template -struct XPUArgsort { - void operator()(xpu::Context* ctx, - const T* input_data, - T* output_data, - int64_t* indices_data, - const std::vector& data_shape, - const std::vector& permute, - bool descending, - bool stable) { - xpu::ctx_guard RAII_GUARD(ctx); - int m = data_shape[0] * data_shape[2]; - int n = data_shape[1]; - int len = data_shape[0] * data_shape[1] * data_shape[2]; - std::vector trans_data_shape{ - data_shape[0], data_shape[2], data_shape[1]}; - - T* input_data_trans = RAII_GUARD.alloc_l3_or_gm(len); - T* output_data_trans = RAII_GUARD.alloc_l3_or_gm(len); - int* indices_data_trans = RAII_GUARD.alloc_l3_or_gm(len); - int64_t* cast_data_int64 = RAII_GUARD.alloc_l3_or_gm(len); - - xpu_transpose(ctx, input_data, input_data_trans, data_shape, permute); - xpu_argsort(ctx, - input_data_trans, - output_data_trans, - indices_data_trans, - m, - n, - descending, - 
stable); - xpu_transpose( - ctx, output_data_trans, output_data, trans_data_shape, permute); - xpu_cast(ctx, indices_data_trans, cast_data_int64, len); - xpu_transpose( - ctx, cast_data_int64, indices_data, trans_data_shape, permute); - } -}; - -template <> -struct XPUArgsort { - void operator()(xpu::Context* ctx, - const int64_t* input_data, - int64_t* output_data, - int64_t* indices_data, - const std::vector& data_shape, - const std::vector& permute, - bool descending, - bool stable) { - xpu::ctx_guard RAII_GUARD(ctx); - int m = data_shape[0] * data_shape[2]; - int n = data_shape[1]; - int len = data_shape[0] * data_shape[1] * data_shape[2]; - std::vector trans_data_shape{ - data_shape[0], data_shape[2], data_shape[1]}; - - int* input_data_trans = RAII_GUARD.alloc_l3_or_gm(len); - int* output_data_trans = RAII_GUARD.alloc_l3_or_gm(len); - int* indices_data_trans = RAII_GUARD.alloc_l3_or_gm(len); - int* cast_data_int = RAII_GUARD.alloc_l3_or_gm(len); - int64_t* cast_data_int64 = RAII_GUARD.alloc_l3_or_gm(len); - - xpu_cast(ctx, input_data, cast_data_int, len); - xpu_transpose(ctx, cast_data_int, input_data_trans, data_shape, permute); - xpu_argsort(ctx, - input_data_trans, - output_data_trans, - indices_data_trans, - m, - n, - descending, - stable); - - xpu_cast(ctx, output_data_trans, cast_data_int64, len); - xpu_transpose(ctx, cast_data_int64, output_data, trans_data_shape, permute); - xpu_cast(ctx, indices_data_trans, cast_data_int64, len); - xpu_transpose( - ctx, cast_data_int64, indices_data, trans_data_shape, permute); - } -}; - template void ArgsortKernel(const Context& dev_ctx, const DenseTensor& input, @@ -188,7 +100,7 @@ void ArgsortKernel(const Context& dev_ctx, auto in_dims = input.dims(); auto rank = in_dims.size(); axis = (axis < 0) ? 
(in_dims.size() + axis) : axis; - int n = in_dims[axis]; + int64_t n = in_dims[axis]; auto input_data = input.data(); auto output_data = dev_ctx.template Alloc(output); @@ -200,62 +112,23 @@ void ArgsortKernel(const Context& dev_ctx, return; } - int len_before = common::product(common::slice_ddim(in_dims, 0, axis)); - int len_after = + int64_t len_before = common::product(common::slice_ddim(in_dims, 0, axis)); + int64_t len_after = common::product(common::slice_ddim(in_dims, axis + 1, in_dims.size())); - std::vector permute_vec{0, 2, 1}; - std::vector data_shape{len_before, n, len_after}; - - bool int64_need_cast = false; - bool index_need_cast = false; - if (std::is_same::value) { - if ((n > 10240) && (n <= 16384)) { - int64_need_cast = true; - } - if ((n > 8192) && (n <= 10240)) { - index_need_cast = true; - } - } else { - if ((n > 10240) && (n <= 16384)) { - index_need_cast = true; - } - } + std::vector permute_vec{0, 2, 1}; + std::vector data_shape{len_before, n, len_after}; using XPUType = typename XPUTypeTrait::Type; - if (int64_need_cast) { - XPUArgsort()( - dev_ctx.x_context(), - reinterpret_cast(input_data), - reinterpret_cast(output_data), - indices_data, - data_shape, - permute_vec, - descending, - stable); - } else if (index_need_cast) { - XPUArgsort()( - dev_ctx.x_context(), - reinterpret_cast(input_data), - reinterpret_cast(output_data), - indices_data, - data_shape, - permute_vec, - descending, - stable); - } else { - XPUArgsort()( - dev_ctx.x_context(), - reinterpret_cast(input_data), - reinterpret_cast(output_data), - indices_data, - data_shape, - permute_vec, - descending, - stable); - } + XPUArgsort()(dev_ctx.x_context(), + reinterpret_cast(input_data), + reinterpret_cast(output_data), + indices_data, + data_shape, + permute_vec, + descending, + stable); } - } // namespace phi PD_REGISTER_KERNEL(argsort, diff --git a/paddle/phi/kernels/xpu/batch_norm_grad_kernel.cc b/paddle/phi/kernels/xpu/batch_norm_grad_kernel.cc index 0e14a6fbb7504..c84954927b332 100644 --- a/paddle/phi/kernels/xpu/batch_norm_grad_kernel.cc +++ b/paddle/phi/kernels/xpu/batch_norm_grad_kernel.cc @@ -28,9 +28,9 @@ static int CalculateInvBNY(xpu::Context *ctx, const T *bias, const T *mean, const T *variance, - const int N, - const int C, - const int M, + const int64_t N, + const int64_t C, + const int64_t M, const T *y) { PADDLE_ENFORCE_EQ(x, y, @@ -58,7 +58,7 @@ template static int CalculateInvVar(xpu::Context *ctx, const T *var, const T epsilon, - const int C, + const int64_t C, T *epsilon_data, T *inv_var) { int r1 = constant(ctx, epsilon_data, 1, epsilon); @@ -124,7 +124,7 @@ void BatchNormGradKernel(const Context &dev_ctx, "But received: the size of input's dimensions is [%d]", x_dims.size())); - int N = -1, C = -1, H = -1, W = -1, D = -1; + int64_t N = -1, C = -1, H = -1, W = -1, D = -1; funcs::ExtractNCWHD(x_dims, data_layout_val, &N, &C, &H, &W, &D); N = (N == 0) ? 1 : N; C = (C == 0) ? 1 : C; diff --git a/paddle/phi/kernels/xpu/batch_norm_kernel.cc b/paddle/phi/kernels/xpu/batch_norm_kernel.cc index ebd36d1f5f5ee..892e0397231a8 100644 --- a/paddle/phi/kernels/xpu/batch_norm_kernel.cc +++ b/paddle/phi/kernels/xpu/batch_norm_kernel.cc @@ -60,7 +60,7 @@ void BatchNormKernel(const Context& dev_ctx, "But received: the size of input's dimensions is [%d]", x_dims.size())); - int N = -1, C = -1, H = -1, W = -1, D = -1; + int64_t N = -1, C = -1, H = -1, W = -1, D = -1; funcs::ExtractNCWHD(x_dims, data_layout, &N, &C, &H, &W, &D); N = (N == 0) ? 1 : N; C = (C == 0) ? 
1 : C; diff --git a/paddle/phi/kernels/xpu/bmm_xpu_utils.h b/paddle/phi/kernels/xpu/bmm_xpu_utils.h index c7c6bfe2bed64..2155f7184a674 100644 --- a/paddle/phi/kernels/xpu/bmm_xpu_utils.h +++ b/paddle/phi/kernels/xpu/bmm_xpu_utils.h @@ -35,10 +35,10 @@ static void MatMulXPUFunction(const DenseTensor& x, ColumnMatrixFromVector(y_dims), 0, trans_y); T* data_c = out->data(); - int m = mat_dim_a.height_; - int n = mat_dim_b.width_; - int k = mat_dim_a.width_; - int batch_size = mat_dim_a.batch_size_; + int64_t m = mat_dim_a.height_; + int64_t n = mat_dim_b.width_; + int64_t k = mat_dim_a.width_; + int64_t batch_size = mat_dim_a.batch_size_; // batch matmul int fc_calc_type = FCCalcType(); decltype(&xblas_fc_batch_wrapper) diff --git a/paddle/phi/kernels/xpu/c_concat_kernel.cc b/paddle/phi/kernels/xpu/c_concat_kernel.cc index e50050e378253..5790c6e7029a9 100644 --- a/paddle/phi/kernels/xpu/c_concat_kernel.cc +++ b/paddle/phi/kernels/xpu/c_concat_kernel.cc @@ -76,8 +76,8 @@ void CConcatKernel(const Context& dev_ctx, int axis = x->dims().size() - 1; auto out_dims = x->dims(); out_dims[out_dims.size() - 1] *= nranks; - int rows_per_tensor = x->dims()[0]; - int offset = 0; + int64_t rows_per_tensor = x->dims()[0]; + int64_t offset = 0; for (int i = 0; i < nranks; i++) { phi::DenseTensor temp = temp_out.Slice(offset, offset + rows_per_tensor); inputs.emplace_back(temp); diff --git a/paddle/phi/kernels/xpu/c_embedding_kernel.cc b/paddle/phi/kernels/xpu/c_embedding_kernel.cc index 4b72b041e9a44..51e0c4d02f453 100644 --- a/paddle/phi/kernels/xpu/c_embedding_kernel.cc +++ b/paddle/phi/kernels/xpu/c_embedding_kernel.cc @@ -33,12 +33,6 @@ void CEmbeddingKernel(const Context& dev_ctx, const int64_t height = w.dims()[0]; const int64_t width = w.dims()[1]; - // int embedding(Context* ctx, const T* x, const TID* indices, T* y, int xm, - // int n, int ym, int padding_idx, TID start_index = 0); - - // xm: table height: number of entries of table. - // n: embedding dim: number of float value within single entry. - // ym: number of elements of input ids. 
const auto& index_type = ids.dtype(); if (index_type == phi::DataType::INT32) { int r = xpu::paddle_embedding(dev_ctx.x_context(), diff --git a/paddle/phi/kernels/xpu/clip_by_norm_kernel.cc b/paddle/phi/kernels/xpu/clip_by_norm_kernel.cc index 329c5cb28b791..fa089551ae01d 100644 --- a/paddle/phi/kernels/xpu/clip_by_norm_kernel.cc +++ b/paddle/phi/kernels/xpu/clip_by_norm_kernel.cc @@ -34,8 +34,8 @@ void ClipByNormKernel(const Context& dev_ctx, "Please check if it is created correctly.")); const auto& x_dims = input->dims(); - std::vector xshape(x_dims.size()); - std::vector rdims(x_dims.size()); + std::vector xshape(x_dims.size()); + std::vector rdims(x_dims.size()); for (int i = 0; i < x_dims.size(); i++) { xshape[i] = x_dims[i]; rdims[i] = i; diff --git a/paddle/phi/kernels/xpu/concat_and_split_functor.cc b/paddle/phi/kernels/xpu/concat_and_split_functor.cc index 08d2832107d70..58c4732826c90 100644 --- a/paddle/phi/kernels/xpu/concat_and_split_functor.cc +++ b/paddle/phi/kernels/xpu/concat_and_split_functor.cc @@ -38,9 +38,9 @@ class ConcatFunctor { int num = input.size(); auto input_dims = input[0].dims(); - std::vector> xdims_list(num); + std::vector> xdims_list(num); for (int i = 0; i < num; ++i) { - std::vector tmp_dims(input_dims.size()); + std::vector tmp_dims(input_dims.size()); for (int j = 0; j < input_dims.size(); ++j) { tmp_dims[j] = input[i].dims()[j]; } @@ -89,9 +89,9 @@ class SplitFunctor { if (input_dims.size() == 0) { input_dims = {1}; } - std::vector split_list(num); - std::vector xdims_list(input_dims.size()); - int total_length = 0; + std::vector split_list(num); + std::vector xdims_list(input_dims.size()); + int64_t total_length = 0; for (int i = 0; i < num; ++i) { auto ins_i_dims = ins[i]->dims(); // special for 0-dim shape diff --git a/paddle/phi/kernels/xpu/concat_grad_kernel.cc b/paddle/phi/kernels/xpu/concat_grad_kernel.cc index a005f1f0fe4cc..729516ec7212b 100644 --- a/paddle/phi/kernels/xpu/concat_grad_kernel.cc +++ b/paddle/phi/kernels/xpu/concat_grad_kernel.cc @@ -70,9 +70,9 @@ void ConcatGradKernel(const Context& dev_ctx, out_grad.dims().size())); auto input_dims = x[0]->dims(); - std::vector split_list(x.size()); - std::vector xdims_list(input_dims.size()); - int total_length = 0; + std::vector split_list(x.size()); + std::vector xdims_list(input_dims.size()); + int64_t total_length = 0; for (size_t i = 0; i < x.size(); ++i) { split_list[i] = x[i]->dims()[axis]; total_length += x[i]->dims()[axis]; diff --git a/paddle/phi/kernels/xpu/concat_kernel.cc b/paddle/phi/kernels/xpu/concat_kernel.cc index 2834b143f4cad..e34e52168d8f7 100644 --- a/paddle/phi/kernels/xpu/concat_kernel.cc +++ b/paddle/phi/kernels/xpu/concat_kernel.cc @@ -90,13 +90,13 @@ void ConcatKernel(const Context& dev_ctx, } } - std::vector> xdims_list; + std::vector> xdims_list; std::vector ptrs; for (unsigned int i = 0; i < x.size(); ++i) { if (x[i] && x[i]->numel() > 0) { ptrs.push_back(reinterpret_cast(x[i]->data())); int size = x[i]->dims().size(); - std::vector tmp_dims(size); + std::vector tmp_dims(size); for (int j = 0; j < size; ++j) { tmp_dims[j] = x[i]->dims()[j]; } diff --git a/paddle/phi/kernels/xpu/conv_grad_kernel.cc b/paddle/phi/kernels/xpu/conv_grad_kernel.cc index 20c8e1d3f2290..ff5a327ff3fe9 100644 --- a/paddle/phi/kernels/xpu/conv_grad_kernel.cc +++ b/paddle/phi/kernels/xpu/conv_grad_kernel.cc @@ -238,7 +238,7 @@ void Conv3DGradKernel(const Context& dev_ctx, UpdatePaddingAndDilation( &paddings, &dilations, padding_algorithm, in_data_dims, strides, ksize); - int batch_size = 
input.dims()[0]; + int64_t batch_size = input.dims()[0]; int64_t img_c = input.dims()[1]; int64_t img_d = input.dims()[2]; int64_t img_h = input.dims()[3]; diff --git a/paddle/phi/kernels/xpu/conv_kernel.cc b/paddle/phi/kernels/xpu/conv_kernel.cc index e405e121e8cbb..c51dd9b2eeba9 100644 --- a/paddle/phi/kernels/xpu/conv_kernel.cc +++ b/paddle/phi/kernels/xpu/conv_kernel.cc @@ -85,7 +85,8 @@ void ConvKernel(const Context& dev_ctx, if (data_format == "NHWC") { filter_data_tmp = RAII_GUARD.alloc(filter.numel()); PADDLE_ENFORCE_XDNN_NOT_NULL(filter_data_tmp); - std::vector filter_shape = common::vectorize(filter.dims()); + std::vector filter_shape = + common::vectorize(filter.dims()); int r = xpu::transpose(dev_ctx.x_context(), filter_data, filter_data_tmp, @@ -228,7 +229,8 @@ void Conv3DKernel(const Context& dev_ctx, if (data_format == "NDHWC") { filter_data_tmp = RAII_GUARD.alloc(filter.numel()); PADDLE_ENFORCE_XDNN_NOT_NULL(filter_data_tmp); - std::vector filter_shape = common::vectorize(filter.dims()); + std::vector filter_shape = + common::vectorize(filter.dims()); int r = xpu::transpose(dev_ctx.x_context(), filter_data, filter_data_tmp, diff --git a/paddle/phi/kernels/xpu/conv_transpose_grad_kernel.cc b/paddle/phi/kernels/xpu/conv_transpose_grad_kernel.cc index 5c911475af25f..4dfafc8660b78 100644 --- a/paddle/phi/kernels/xpu/conv_transpose_grad_kernel.cc +++ b/paddle/phi/kernels/xpu/conv_transpose_grad_kernel.cc @@ -41,8 +41,12 @@ void Conv2dTransposeGradKernel(const Context& ctx, DenseTensor filter_ = filter; if (!dx && !dfilter) return; - std::vector paddings_ = paddings; - std::vector dilations_ = dilations; + std::vector strides_ = + std::vector(strides.begin(), strides.end()); + std::vector paddings_ = + std::vector(paddings.begin(), paddings.end()); + std::vector dilations_ = + std::vector(dilations.begin(), dilations.end()); PADDLE_ENFORCE_EQ( data_format == "NHWC" || data_format == "NDHWC", @@ -52,17 +56,21 @@ void Conv2dTransposeGradKernel(const Context& ctx, DDim in_data_dims = slice_ddim(x.dims(), 2, x.dims().size()); DDim filter_data_dims = slice_ddim(filter_.dims(), 2, filter_.dims().size()); - std::vector ksize = common::vectorize(filter_data_dims); - UpdatePaddingAndDilation( - &paddings_, &dilations_, padding_algorithm, in_data_dims, strides, ksize); + std::vector ksize = common::vectorize(filter_data_dims); + UpdatePaddingAndDilation(&paddings_, + &dilations_, + padding_algorithm, + in_data_dims, + strides_, + ksize); - const int batch_size = static_cast(x.dims()[0]); - const int img_yc = static_cast(x.dims()[1]); - const int img_yh = static_cast(x.dims()[2]); - const int img_yw = static_cast(x.dims()[3]); - const int img_xc = static_cast(dout.dims()[1]); - const int img_xh = static_cast(dout.dims()[2]); - const int img_xw = static_cast(dout.dims()[3]); + const int64_t batch_size = x.dims()[0]; + const int64_t img_yc = x.dims()[1]; + const int64_t img_yh = x.dims()[2]; + const int64_t img_yw = x.dims()[3]; + const int64_t img_xc = dout.dims()[1]; + const int64_t img_xh = dout.dims()[2]; + const int64_t img_xw = dout.dims()[3]; if (dx) { ctx.template Alloc(dx); } @@ -88,7 +96,7 @@ void Conv2dTransposeGradKernel(const Context& ctx, img_xh, img_xw, ksize, - strides, + strides_, paddings_, dilations_, groups, @@ -115,7 +123,7 @@ void Conv2dTransposeGradKernel(const Context& ctx, img_xh, img_xw, ksize, - strides, + strides_, paddings_, dilations_, groups, diff --git a/paddle/phi/kernels/xpu/conv_transpose_kernel.cc b/paddle/phi/kernels/xpu/conv_transpose_kernel.cc index 
64c24633febc5..79183f83000d8 100644 --- a/paddle/phi/kernels/xpu/conv_transpose_kernel.cc +++ b/paddle/phi/kernels/xpu/conv_transpose_kernel.cc @@ -27,21 +27,6 @@ namespace xpudnn = baidu::xpu::xpudnn; #endif namespace phi { -// target_len == 2 || target_len == 4 -inline std::vector vector_extend(const std::vector& src, - int target_len) { - if (target_len == 2 && src.size() == 1) { - return {src[0], src[0]}; - } - if (target_len == 4 && src.size() == 1) { - return {src[0], src[0], src[0], src[0]}; - } - if (target_len == 4 && src.size() == 2) { - return {src[0], src[0], src[1], src[1]}; - } - return src; -} - template void Conv2dTransposeKernel(const Context& ctx, const DenseTensor& x, @@ -83,11 +68,11 @@ void Conv2dTransposeKernel(const Context& ctx, strides_, ksize); - const int64_t batch_size = static_cast(x.dims()[0]); - const int64_t img_yc = static_cast(x.dims()[1]); - const int64_t img_xc = static_cast(out->dims()[1]); - const int64_t img_xh = static_cast(out->dims()[2]); - const int64_t img_xw = static_cast(out->dims()[3]); + const int64_t batch_size = x.dims()[0]; + const int64_t img_yc = x.dims()[1]; + const int64_t img_xc = out->dims()[1]; + const int64_t img_xh = out->dims()[2]; + const int64_t img_xw = out->dims()[3]; int fc_calc_type = FCCalcType(); if (fc_calc_type == XPUFCCalcType::FC_INT32) { @@ -223,17 +208,25 @@ void Conv2dTransposeKernel(const Context& ctx, PADDLE_ENFORCE_XDNN_SUCCESS(r, "conv2d_transpose_fusion_v2"); } #else - std::vector ksize = common::vectorize(filter_data_dims); - std::vector paddings_ = paddings; - std::vector dilations_ = dilations; - UpdatePaddingAndDilation( - &paddings_, &dilations_, padding_algorithm, in_data_dims, strides, ksize); + std::vector ksize = common::vectorize(filter_data_dims); + std::vector strides_ = + std::vector(strides.begin(), strides.end()); + std::vector paddings_ = + std::vector(paddings.begin(), paddings.end()); + std::vector dilations_ = + std::vector(dilations.begin(), dilations.end()); + UpdatePaddingAndDilation(&paddings_, + &dilations_, + padding_algorithm, + in_data_dims, + strides_, + ksize); - const int batch_size = static_cast(x.dims()[0]); - const int img_yc = static_cast(x.dims()[1]); - const int img_xc = static_cast(out->dims()[1]); - const int img_xh = static_cast(out->dims()[2]); - const int img_xw = static_cast(out->dims()[3]); + const int64_t batch_size = x.dims()[0]; + const int64_t img_yc = x.dims()[1]; + const int64_t img_xc = out->dims()[1]; + const int64_t img_xh = out->dims()[2]; + const int64_t img_xw = out->dims()[3]; int fc_calc_type = FCCalcType(); if (fc_calc_type == XPUFCCalcType::FC_INT32) { @@ -306,8 +299,8 @@ void Conv2dTransposeKernel(const Context& ctx, } else { // xpu::conv2d_transpose_v2 do not support int_with_ll now // use xpu::conv2d_transpose - int img_yh = static_cast(x.dims()[2]); - int img_yw = static_cast(x.dims()[3]); + int64_t img_yh = x.dims()[2]; + int64_t img_yw = x.dims()[3]; int r = xpu::conv2d_transpose( ctx.x_context(), x.data(), diff --git a/paddle/phi/kernels/xpu/cross_entropy_grad_kernel.cc b/paddle/phi/kernels/xpu/cross_entropy_grad_kernel.cc index abe1bcd9074d5..a4572fea4187c 100644 --- a/paddle/phi/kernels/xpu/cross_entropy_grad_kernel.cc +++ b/paddle/phi/kernels/xpu/cross_entropy_grad_kernel.cc @@ -36,10 +36,10 @@ void CrossEntropyWithSoftmaxGradKernel(const Context& dev_ctx, const int rank = logit_grad->dims().size(); const int axis = phi::funcs::CanonicalAxis(axis_in, rank); - const int n = phi::funcs::SizeToAxis(axis, logit_grad->dims()); - const int d = 
phi::funcs::SizeFromAxis(axis, logit_grad->dims()); + const int64_t n = phi::funcs::SizeToAxis(axis, logit_grad->dims()); + const int64_t d = phi::funcs::SizeFromAxis(axis, logit_grad->dims()); - int r = XPU_SUCCESS; + int r = 0; if (axis == rank - 1) { if (soft_label) { @@ -89,9 +89,9 @@ void CrossEntropyWithSoftmaxGradKernel(const Context& dev_ctx, PADDLE_ENFORCE_XDNN_SUCCESS(r, "hard_softmax_with_cross_entropy_grad"); } } else { - int t = logit_grad->dims()[axis]; + int64_t t = logit_grad->dims()[axis]; xpu::ctx_guard RAII_GUARD(dev_ctx.x_context()); - int len = softmax.numel(); + int64_t len = softmax.numel(); XPUType* trans_logit = RAII_GUARD.alloc_l3_or_gm(len); PADDLE_ENFORCE_XDNN_NOT_NULL(trans_logit); diff --git a/paddle/phi/kernels/xpu/cross_entropy_kernel.cc b/paddle/phi/kernels/xpu/cross_entropy_kernel.cc index c8e800d502e61..a574100165aac 100644 --- a/paddle/phi/kernels/xpu/cross_entropy_kernel.cc +++ b/paddle/phi/kernels/xpu/cross_entropy_kernel.cc @@ -45,7 +45,7 @@ void CrossEntropyWithSoftmaxKernel(const Context& dev_ctx, auto softmax_data = reinterpret_cast(softmax->data()); auto loss_data = reinterpret_cast(loss->data()); - int r = XPU_SUCCESS; + int r = 0; xpu::ctx_guard RAII_GUARD(dev_ctx.x_context()); if (!use_softmax) { // For cross entropy only cases, logits are outputs of softmax diff --git a/paddle/phi/kernels/xpu/cum_kernel.cc b/paddle/phi/kernels/xpu/cum_kernel.cc index a4dc4db547e72..29f7614ee83f8 100644 --- a/paddle/phi/kernels/xpu/cum_kernel.cc +++ b/paddle/phi/kernels/xpu/cum_kernel.cc @@ -40,12 +40,12 @@ void CumsumKernel(const Context& dev_ctx, } // prepare for call xdnn api - std::vector x_shape = common::vectorize(x.dims()); + std::vector x_shape = common::vectorize(x.dims()); int axis_as_int = axis.to(); if (flatten) { // flatten to 1-dim vector - x_shape = {static_cast(x.numel())}; + x_shape = {x.numel()}; axis_as_int = 0; } else { // not flatten diff --git a/paddle/phi/kernels/xpu/deformable_conv_grad_kernel.cc b/paddle/phi/kernels/xpu/deformable_conv_grad_kernel.cc index 45b1d33a9f7ff..c582b52e835b3 100644 --- a/paddle/phi/kernels/xpu/deformable_conv_grad_kernel.cc +++ b/paddle/phi/kernels/xpu/deformable_conv_grad_kernel.cc @@ -74,7 +74,7 @@ void DeformableConvGradKernel(const Context& dev_ctx, "Filter high and weight should less than 8 on xpu " "in deformable_conv_grad op.")); - const int batch_size = static_cast(x.dims()[0]); + const int64_t batch_size = x.dims()[0]; std::vector output_shape_vec(common::vectorize(out_grad.dims())); const T* output_grad_ptr = out_grad.data(); const T* input_ptr = x.data(); @@ -102,18 +102,17 @@ void DeformableConvGradKernel(const Context& dev_ctx, dmask_data, errors::ResourceExhausted("XPU has no enough memory")); } - int input_dim = x.numel() / x.dims()[0]; - int input_offset_dim = offset.numel() / offset.dims()[0]; - int input_mask_dim = mask->numel() / mask->dims()[0]; - int output_dim = + int64_t input_dim = x.numel() / x.dims()[0]; + int64_t input_offset_dim = offset.numel() / offset.dims()[0]; + int64_t input_mask_dim = mask->numel() / mask->dims()[0]; + int64_t output_dim = output_shape_vec[1] * output_shape_vec[2] * output_shape_vec[3]; - std::vector ksize{static_cast(filter.dims()[2]), - static_cast(filter.dims()[3])}; - int n = im2col_step; - int c = x.dims()[1]; - int h = x.dims()[2]; - int w = x.dims()[3]; - int f = filter.dims()[0]; + std::vector ksize{filter.dims()[2], filter.dims()[3]}; + int64_t n = static_cast(im2col_step); + int64_t c = x.dims()[1]; + int64_t h = x.dims()[2]; + int64_t w = 
x.dims()[3]; + int64_t f = filter.dims()[0]; T* filter_grad_tmp = RAII_GUARD.alloc_l3_or_gm(filter_grad->numel()); PADDLE_ENFORCE_NOT_NULL( @@ -136,27 +135,27 @@ void DeformableConvGradKernel(const Context& dev_ctx, dev_ctx.x_context(), filter_grad_tmp, filter.numel(), zero); PADDLE_ENFORCE_XDNN_SUCCESS(r_filter, "constant"); - for (int i = 0; i < batch_size / im2col_step; ++i) { + for (int64_t i = 0; i < batch_size / n; ++i) { int r = xpu::deformable_conv_grad( dev_ctx.x_context(), - input_ptr + i * im2col_step * input_dim, + input_ptr + i * n * input_dim, filter_ptr, - offset_ptr + i * im2col_step * input_offset_dim, - mask_ptr + i * im2col_step * input_mask_dim, - output_grad_ptr + i * im2col_step * output_dim, - dx_data + i * im2col_step * input_dim, + offset_ptr + i * n * input_offset_dim, + mask_ptr + i * n * input_mask_dim, + output_grad_ptr + i * n * output_dim, + dx_data + i * n * input_dim, filter_grad_tmp, - doffset_data + i * im2col_step * input_offset_dim, - dmask_data + i * im2col_step * input_mask_dim, + doffset_data + i * n * input_offset_dim, + dmask_data + i * n * input_mask_dim, n, c, h, w, f, ksize, - strides, - paddings, - dilations, + std::vector{strides.begin(), strides.end()}, + std::vector{paddings.begin(), paddings.end()}, + std::vector{dilations.begin(), dilations.end()}, groups, deformable_groups, nullptr, diff --git a/paddle/phi/kernels/xpu/deformable_conv_kernel.cc b/paddle/phi/kernels/xpu/deformable_conv_kernel.cc index 29c5d6896f3ed..1c3c023660660 100644 --- a/paddle/phi/kernels/xpu/deformable_conv_kernel.cc +++ b/paddle/phi/kernels/xpu/deformable_conv_kernel.cc @@ -53,7 +53,7 @@ void DeformableConvKernel(const Context& dev_ctx, "Filter high and weight should less than 8 on xpu " "in deformable_conv op.")); - const int batch_size = static_cast(x.dims()[0]); + const int64_t batch_size = x.dims()[0]; std::vector output_shape_vec(common::vectorize(out->dims())); const T* input_ptr = x.data(); @@ -66,36 +66,35 @@ void DeformableConvKernel(const Context& dev_ctx, const int zero = 0; int r = xpu::constant(dev_ctx.x_context(), output_prt, out->numel(), zero); PADDLE_ENFORCE_XDNN_SUCCESS(r, "constant"); - int input_dim = x.numel() / x.dims()[0]; - int input_offset_dim = offset.numel() / offset.dims()[0]; - int input_mask_dim = mask->numel() / mask->dims()[0]; - int output_dim = + int64_t input_dim = x.numel() / x.dims()[0]; + int64_t input_offset_dim = offset.numel() / offset.dims()[0]; + int64_t input_mask_dim = mask->numel() / mask->dims()[0]; + int64_t output_dim = output_shape_vec[1] * output_shape_vec[2] * output_shape_vec[3]; - std::vector ksize{static_cast(filter.dims()[2]), - static_cast(filter.dims()[3])}; - int n = im2col_step; - int c = x.dims()[1]; - int h = x.dims()[2]; - int w = x.dims()[3]; - int f = filter.dims()[0]; + std::vector ksize{filter.dims()[2], filter.dims()[3]}; + int64_t n = static_cast(im2col_step); + int64_t c = x.dims()[1]; + int64_t h = x.dims()[2]; + int64_t w = x.dims()[3]; + int64_t f = filter.dims()[0]; - for (int i = 0; i < batch_size / im2col_step; ++i) { + for (int64_t i = 0; i < batch_size / n; ++i) { int r = xpu::deformable_conv( dev_ctx.x_context(), - input_ptr + i * im2col_step * input_dim, + input_ptr + i * n * input_dim, filter_ptr, - offset_ptr + i * im2col_step * input_offset_dim, - mask_ptr + i * im2col_step * input_mask_dim, - output_prt + i * im2col_step * output_dim, + offset_ptr + i * n * input_offset_dim, + mask_ptr + i * n * input_mask_dim, + output_prt + i * n * output_dim, n, c, h, w, f, ksize, - strides, - 
paddings, - dilations, + std::vector{strides.begin(), strides.end()}, + std::vector{paddings.begin(), paddings.end()}, + std::vector{dilations.begin(), dilations.end()}, groups, deformable_groups, nullptr, diff --git a/paddle/phi/kernels/xpu/diag_kernel.cc b/paddle/phi/kernels/xpu/diag_kernel.cc index 89c991742e83c..8fb84f6a04aab 100644 --- a/paddle/phi/kernels/xpu/diag_kernel.cc +++ b/paddle/phi/kernels/xpu/diag_kernel.cc @@ -31,11 +31,11 @@ void DiagKernel(const Context& dev_ctx, dev_ctx.template Alloc(out); auto* out_data = reinterpret_cast(out->data()); - auto x_shape = common::vectorize(x.dims()); - auto out_shape = common::vectorize(out->dims()); + auto x_shape = common::vectorize(x.dims()); + auto out_shape = common::vectorize(out->dims()); if (x.dims().size() == 0) { - x_shape = std::vector({1}); + x_shape = std::vector({1}); } int r = xpu::diag(dev_ctx.x_context(), diff --git a/paddle/phi/kernels/xpu/distribute_fpn_proposals_kernel.cc b/paddle/phi/kernels/xpu/distribute_fpn_proposals_kernel.cc index 6a3aee8356901..201a02548b27c 100644 --- a/paddle/phi/kernels/xpu/distribute_fpn_proposals_kernel.cc +++ b/paddle/phi/kernels/xpu/distribute_fpn_proposals_kernel.cc @@ -41,7 +41,7 @@ static void Sort(const XPUContext& dev_ctx, DenseTensor index_t; index_t.Resize({value.numel()}); int* index = dev_ctx.template HostAlloc(&index_t); - for (int i = 0; i < value.numel(); ++i) { + for (int64_t i = 0; i < value.numel(); ++i) { index[i] = i; } diff --git a/paddle/phi/kernels/xpu/elementwise.h b/paddle/phi/kernels/xpu/elementwise.h index 9956aa2214b30..3cd2b21e22583 100644 --- a/paddle/phi/kernels/xpu/elementwise.h +++ b/paddle/phi/kernels/xpu/elementwise.h @@ -78,7 +78,7 @@ void XPUElementwise(const XPUContext& dev_ctx, } } - int ret = xpu::SUCCESS; + int ret = 0; // For [2, 3] + [] --> [2, 3] + [1, 1] // For [] + [2, 3] --> [1, 1] + [2, 3] diff --git a/paddle/phi/kernels/xpu/elementwise_add_kernel.cc b/paddle/phi/kernels/xpu/elementwise_add_kernel.cc index 81edc43058596..9d778b7a089ae 100644 --- a/paddle/phi/kernels/xpu/elementwise_add_kernel.cc +++ b/paddle/phi/kernels/xpu/elementwise_add_kernel.cc @@ -47,7 +47,7 @@ void AddKernel(const Context& dev_ctx, const float* x_data = x.data(); float* z_data = out->data(); - int ret = xpu::SUCCESS; + int ret = 0; if (y.dtype() == phi::DataType::BFLOAT16) { using YType = DataTypeToCppType::type; using XPUYType = typename XPUTypeTrait::Type; @@ -100,8 +100,8 @@ void GradAddXPUKernel(const Context& dev_ctx, using XPUType = typename XPUTypeTrait::Type; dev_ctx.template Alloc(out); - auto x_shape = common::vectorize(x.dims()); - auto y_shape = common::vectorize(y.dims()); + auto x_shape = common::vectorize(x.dims()); + auto y_shape = common::vectorize(y.dims()); int r = xpu::broadcast_add(dev_ctx.x_context(), reinterpret_cast(x.data()), reinterpret_cast(y.data()), diff --git a/paddle/phi/kernels/xpu/embedding_grad_kernel.cc b/paddle/phi/kernels/xpu/embedding_grad_kernel.cc index c22c469a7ec11..7f528f0efc794 100644 --- a/paddle/phi/kernels/xpu/embedding_grad_kernel.cc +++ b/paddle/phi/kernels/xpu/embedding_grad_kernel.cc @@ -41,12 +41,6 @@ void EmbeddingGradKernel(const Context& ctx, } int64_t ids_numel = ids_t->numel(); - PADDLE_ENFORCE_EQ( - ids_numel <= std::numeric_limits::max(), - true, - common::errors::OutOfRange( - "Number of ids greater than int32_t::max , please check " - "number of ids in LookupTableV2GradXPUKernel.")); auto& dev_ctx = ctx; xpu::ctx_guard RAII_GUARD(ctx.x_context()); @@ -63,9 +57,9 @@ void EmbeddingGradKernel(const Context& ctx, 
const T* d_output_data = d_output_t->data(); T* d_table_data = dev_ctx.template Alloc(d_table_t); - int xm = d_table_t->dims()[0]; - int ym = static_cast(ids_numel); - int n = d_table_t->dims()[1]; + int64_t xm = d_table_t->dims()[0]; + int64_t ym = ids_numel; + int64_t n = d_table_t->dims()[1]; int r = xpu::embedding_grad( dev_ctx.x_context(), diff --git a/paddle/phi/kernels/xpu/embedding_kernel.cc b/paddle/phi/kernels/xpu/embedding_kernel.cc index 9f0a8a592ea8b..a003ce55c952e 100644 --- a/paddle/phi/kernels/xpu/embedding_kernel.cc +++ b/paddle/phi/kernels/xpu/embedding_kernel.cc @@ -44,17 +44,10 @@ void EmbeddingKernel(const Context &ctx, auto *table = table_t->data(); auto *output = dev_ctx.template Alloc(output_t); - PADDLE_ENFORCE_EQ( - ids_numel <= std::numeric_limits::max(), - true, - common::errors::OutOfRange( - "Number of ids greater than int32_t::max , please check " - "number of ids in LookupTableV2XPUKernel.")); - - int ym = static_cast(ids_numel); + int64_t ym = ids_numel; - size_t xm = table_t->dims()[0]; - size_t n = table_t->dims()[1]; + int64_t xm = table_t->dims()[0]; + int64_t n = table_t->dims()[1]; int r; xpu::ctx_guard RAII_GUARD(ctx.x_context()); diff --git a/paddle/phi/kernels/xpu/expand_as_kernel.cc b/paddle/phi/kernels/xpu/expand_as_kernel.cc index 699247313afcd..df875c7fbc104 100644 --- a/paddle/phi/kernels/xpu/expand_as_kernel.cc +++ b/paddle/phi/kernels/xpu/expand_as_kernel.cc @@ -27,8 +27,7 @@ void ExpandAs(const Context& context, const std::vector& target_shape, DenseTensor* out) { using XPUType = typename XPUTypeTrait::Type; - auto in_dims = x.dims(); - auto vec_in_dims = common::vectorize(in_dims); + auto vec_in_dims = common::vectorize(x.dims()); auto diff = target_shape.size() - vec_in_dims.size(); vec_in_dims.insert(vec_in_dims.begin(), diff, 1); for (size_t i = 0; i < vec_in_dims.size(); ++i) { @@ -67,7 +66,7 @@ void ExpandAs(const Context& context, auto& x_shape = vec_in_dims; auto out_shape = common::vectorize(out_dims); - int r = XPU_SUCCESS; + int r = 0; if (std::is_same::value) { auto x_data = reinterpret_cast(x.data()); diff --git a/paddle/phi/kernels/xpu/expand_kernel.cc b/paddle/phi/kernels/xpu/expand_kernel.cc index 1f03c97075c5b..e51f5dc8daf8c 100644 --- a/paddle/phi/kernels/xpu/expand_kernel.cc +++ b/paddle/phi/kernels/xpu/expand_kernel.cc @@ -103,7 +103,7 @@ void ExpandKernel(const Context& ctx, out_shape = {1}; } - int r = XPU_SUCCESS; + int r = 0; if (std::is_same::value) { auto x_data = reinterpret_cast(x.data()); auto out_data = reinterpret_cast(out->data()); diff --git a/paddle/phi/kernels/xpu/flash_attn_grad_kernel.cc b/paddle/phi/kernels/xpu/flash_attn_grad_kernel.cc index 3f66719493ae1..d7241c58a12e2 100644 --- a/paddle/phi/kernels/xpu/flash_attn_grad_kernel.cc +++ b/paddle/phi/kernels/xpu/flash_attn_grad_kernel.cc @@ -37,13 +37,13 @@ void FlashAttnGradKernelBase( const paddle::optional& attn_mask, const paddle::optional& startend_row_indices, const DenseTensor& dout, - const int batch_size, + const int64_t batch_size, const Scalar& max_seqlen_q_, const Scalar& max_seqlen_k_, - const int num_heads, - const int num_heads_k, - const int head_size, - const int head_size_v, + const int64_t num_heads, + const int64_t num_heads_k, + const int64_t head_size, + const int64_t head_size_v, float scale, float dropout, bool causal, diff --git a/paddle/phi/kernels/xpu/flash_attn_kernel.cc b/paddle/phi/kernels/xpu/flash_attn_kernel.cc index 7f69adbbbf718..4ebfa0c6e5d70 100644 --- a/paddle/phi/kernels/xpu/flash_attn_kernel.cc +++ 
b/paddle/phi/kernels/xpu/flash_attn_kernel.cc @@ -33,13 +33,13 @@ void FlashAttnKernelBase( const paddle::optional& fixed_seed_offset, const paddle::optional& attn_mask, const paddle::optional& startend_row_indices, - const int batch_size, + const int64_t batch_size, const Scalar& max_seqlen_q_, const Scalar& max_seqlen_k_, - const int num_heads, - const int num_heads_k, - const int head_size, - const int head_size_v, + const int64_t num_heads, + const int64_t num_heads_k, + const int64_t head_size, + const int64_t head_size_v, float scale, float dropout, bool causal, @@ -268,11 +268,11 @@ void FlashAttnUnpaddedKernel( // q, k, v [batch_size * seq_len, num_heads, head_dim] std::vector dims = common::vectorize(q.dims()); - const int batch_size = cu_seqlens_q.numel() - 1; - const int num_heads = dims[1]; - const int head_size = dims[2]; - const int num_heads_k = k.dims()[1]; - const int head_size_v = v.dims()[2]; + const int64_t batch_size = cu_seqlens_q.numel() - 1; + const int64_t num_heads = dims[1]; + const int64_t head_size = dims[2]; + const int64_t num_heads_k = k.dims()[1]; + const int64_t head_size_v = v.dims()[2]; #ifndef PADDLE_WITH_XPU_XRE5 // lod info, only support qlod == klod std::vector qlod_vec(batch_size + 1, 0); diff --git a/paddle/phi/kernels/xpu/full_kernel.cc b/paddle/phi/kernels/xpu/full_kernel.cc index 8d2152c829c1f..99bc7d11b01db 100644 --- a/paddle/phi/kernels/xpu/full_kernel.cc +++ b/paddle/phi/kernels/xpu/full_kernel.cc @@ -107,7 +107,7 @@ void FullBatchSizeLikeKernel(const Context& dev_ctx, if (x.lod().size() && x_batch_size_dim == 0) { // set the correct batch size for the DenseTensor. auto odims = out->dims(); - odims[out_batch_size_dim] = static_cast(x.lod().back().size()) - 1; + odims[out_batch_size_dim] = x.lod().back().size() - 1; FullKernel(dev_ctx, common::vectorize(odims), val, dtype, out); } FullLikeKernel(dev_ctx, x, val, dtype, out); diff --git a/paddle/phi/kernels/xpu/fused_attention_grad_kernel.cc b/paddle/phi/kernels/xpu/fused_attention_grad_kernel.cc index 0df5ec0703151..732b69537630a 100644 --- a/paddle/phi/kernels/xpu/fused_attention_grad_kernel.cc +++ b/paddle/phi/kernels/xpu/fused_attention_grad_kernel.cc @@ -53,7 +53,7 @@ void FusedAttentionGradKernel( const DenseTensor &fmha_out, const DenseTensor &out_linear_out, const DenseTensor &dropout_mask_out, - int num_heads, + int num_heads_, // unused bool transpose_qkv_wb, bool pre_layer_norm, float epsilon, @@ -209,11 +209,11 @@ void FusedAttentionGradKernel( const auto input_x_dims = x.dims(); const auto qkv_w_dims = qkv_weight.dims(); - int batch_size = input_x_dims[0]; - int seq_len = input_x_dims[1]; - int embed_dims = input_x_dims[2]; - num_heads = qkv_w_dims[1]; - int head_dims = qkv_w_dims[2]; + int64_t batch_size = input_x_dims[0]; + int64_t seq_len = input_x_dims[1]; + int64_t embed_dims = input_x_dims[2]; + int64_t num_heads = qkv_w_dims[1]; + int64_t head_dims = qkv_w_dims[2]; xpu::Context *xpu_ctx = dev_ctx.x_context(); xpu::ctx_guard RAII_GUARD(xpu_ctx); @@ -334,7 +334,7 @@ void FusedAttentionGradKernel( {0LL}); PADDLE_ENFORCE_XDNN_SUCCESS(r, "reduce_sum"); { - int qkv_size = batch_size * seq_len * num_heads * head_dims; + int64_t qkv_size = batch_size * seq_len * num_heads * head_dims; const XPUTypeT *q_out_ptr = qkv_transpose_out_ptr; const XPUTypeT *k_out_ptr = q_out_ptr + qkv_size; const XPUTypeT *v_out_ptr = k_out_ptr + qkv_size; diff --git a/paddle/phi/kernels/xpu/fused_attention_kernel.cc b/paddle/phi/kernels/xpu/fused_attention_kernel.cc index 8fcb4efac1cc6..9c23641d1ac0e 
100644
--- a/paddle/phi/kernels/xpu/fused_attention_kernel.cc
+++ b/paddle/phi/kernels/xpu/fused_attention_kernel.cc
@@ -34,7 +34,7 @@ void FusedAttentionKernel(const Context &dev_ctx,
     const paddle::optional &out_linear_bias,
     const paddle::optional &ln_scale_2,
     const paddle::optional &ln_bias_2,
-    int num_heads,
+    int num_heads_,  // unused
     bool transpose_qkv_wb,
     bool pre_layer_norm,
     float epsilon,
@@ -118,11 +118,11 @@ void FusedAttentionKernel(const Context &dev_ctx,
   const auto input_x_dims = x.dims();
   const auto qkv_w_dims = qkv_weight.dims();
-  int batch_size = input_x_dims[0];
-  int seq_len = input_x_dims[1];
-  int embed_dims = input_x_dims[2];
-  num_heads = qkv_w_dims[1];
-  int head_dims = qkv_w_dims[2];
+  int64_t batch_size = input_x_dims[0];
+  int64_t seq_len = input_x_dims[1];
+  int64_t embed_dims = input_x_dims[2];
+  int64_t num_heads = qkv_w_dims[1];
+  int64_t head_dims = qkv_w_dims[2];
   // input pointers
   const XPUTypeT *input_x_ptr = reinterpret_cast(x.data());
@@ -205,14 +205,14 @@ void FusedAttentionKernel(const Context &dev_ctx,
   XPUTypeT *qkv_ptr = NULL;         // qkv[batch_size, num_heads, seq_len, head_dims]
   XPUTypeT *linear_out_ptr = NULL;  // x4, x5 [batch_size, seq_len, embed_dims]
-  int temp_size_1 = batch_size * seq_len * 3 * num_heads * head_dims;
-  int temp_size_2 = batch_size * num_heads * seq_len * seq_len;
-  int temp_size_3 = batch_size * num_heads * seq_len * head_dims;
-  int temp_size_4 = batch_size * seq_len * embed_dims;
+  int64_t temp_size_1 = batch_size * seq_len * 3 * num_heads * head_dims;
+  int64_t temp_size_2 = batch_size * num_heads * seq_len * seq_len;
+  int64_t temp_size_3 = batch_size * num_heads * seq_len * head_dims;
+  int64_t temp_size_4 = batch_size * seq_len * embed_dims;
-  std::vector temp_vec = {
+  std::vector temp_vec = {
       temp_size_1, temp_size_2, temp_size_3, temp_size_4};
-  std::sort(temp_vec.begin(), temp_vec.end(), std::greater());
+  std::sort(temp_vec.begin(), temp_vec.end(), std::greater());
   XPUTypeT *max_gm_ptr = RAII_GUARD.alloc(temp_vec[0]);
   PADDLE_ENFORCE_XDNN_NOT_NULL(max_gm_ptr);
   qkv_before_transpose_ptr = max_gm_ptr;
@@ -287,16 +287,16 @@ void FusedAttentionKernel(const Context &dev_ctx,
     PADDLE_ENFORCE_XDNN_SUCCESS(r, "transpose");
-    int qkv_every_size = batch_size * seq_len * num_heads * head_dims;
+    int64_t qkv_every_size = batch_size * seq_len * num_heads * head_dims;
     {
       float alpha = 1.0 / sqrt(head_dims);
-      r = scale(xpu_ctx,
-                qkv_transpose_out_ptr,
-                qkv_transpose_out_ptr,
-                qkv_every_size,
-                false,
-                alpha,
-                0.0f);
+      r = xpu::scale(xpu_ctx,
+                     qkv_transpose_out_ptr,
+                     qkv_transpose_out_ptr,
+                     qkv_every_size,
+                     false,
+                     alpha,
+                     0.0f);
       PADDLE_ENFORCE_XDNN_SUCCESS(r, "scale");
     }
diff --git a/paddle/phi/kernels/xpu/gather_grad_kernel.cc b/paddle/phi/kernels/xpu/gather_grad_kernel.cc
index 32e0c4630e6a6..a3d1d41d01102 100644
--- a/paddle/phi/kernels/xpu/gather_grad_kernel.cc
+++ b/paddle/phi/kernels/xpu/gather_grad_kernel.cc
@@ -61,7 +61,7 @@ void GatherGradKernel(const Context& dev_ctx,
   dev_ctx.template Alloc(x_grad);
   using XPUType = typename XPUTypeTrait::Type;
-  int r = XPU_SUCCESS;
+  int r = 0;
   if (index_type == DataType::INT32) {
     r = xpu::gather_grad(
         dev_ctx.x_context(),
diff --git a/paddle/phi/kernels/xpu/gather_kernel.cc b/paddle/phi/kernels/xpu/gather_kernel.cc
index 528358e612e11..1531aaa6afb9d 100644
--- a/paddle/phi/kernels/xpu/gather_kernel.cc
+++ b/paddle/phi/kernels/xpu/gather_kernel.cc
@@ -57,7 +57,7 @@ void GatherKernel(const Context& dev_ctx,
   using XPUType = typename XPUTypeTrait::Type;
-  int r = XPU_SUCCESS;
+  int r = 0;
   if (index_type == DataType::INT32) {
     r = xpu::paddle_gather(
         dev_ctx.x_context(),
diff --git a/paddle/phi/kernels/xpu/gather_nd_grad_kernel.cc b/paddle/phi/kernels/xpu/gather_nd_grad_kernel.cc
index 3001d590cba8d..ee7823e10e9a5 100644
--- a/paddle/phi/kernels/xpu/gather_nd_grad_kernel.cc
+++ b/paddle/phi/kernels/xpu/gather_nd_grad_kernel.cc
@@ -29,7 +29,7 @@ void GatherNdGradKernel(const Context &ctx,
   using XPUType = typename XPUTypeTrait::Type;
   ctx.template Alloc(x_grad);
-  int r = XPU_SUCCESS;
+  int r = 0;
   XPUType *dx_data = reinterpret_cast(x_grad->data());
   r = xpu::constant(
       ctx.x_context(), dx_data, x_grad->numel(), static_cast(0));
diff --git a/paddle/phi/kernels/xpu/gather_nd_kernel.cc b/paddle/phi/kernels/xpu/gather_nd_kernel.cc
index f62355c68861b..4f3f46e6a65a1 100644
--- a/paddle/phi/kernels/xpu/gather_nd_kernel.cc
+++ b/paddle/phi/kernels/xpu/gather_nd_kernel.cc
@@ -85,7 +85,7 @@ void GatherNdKernel(const Context &ctx,
   xpu::VectorParam x_vec = {
       x_shape.data(), static_cast(x_shape.size()), nullptr};
-  int ret = XPU_SUCCESS;
+  int ret = 0;
 #ifndef PADDLE_WITH_XPU_PLUGIN
   if (index_type == DataType::INT32) {
     ret = xpu::gather_nd(
diff --git a/paddle/phi/kernels/xpu/group_norm_grad_kernel.cc b/paddle/phi/kernels/xpu/group_norm_grad_kernel.cc
index dc2074342d780..38d5b2dcf54dd 100644
--- a/paddle/phi/kernels/xpu/group_norm_grad_kernel.cc
+++ b/paddle/phi/kernels/xpu/group_norm_grad_kernel.cc
@@ -43,7 +43,7 @@ void GroupNormGradKernel(const Context& dev_ctx,
                          DenseTensor* d_bias) {
   using XPUType = typename XPUTypeTrait::Type;
   xpu::ctx_guard RAII_GUARD(dev_ctx.x_context());
-  int ret = xpu::SUCCESS;
+  int ret = 0;
   const DataLayout data_layout = common::StringToDataLayout(data_layout_str);
   const auto scale_ptr = scale.get_ptr();
   const auto bias_ptr = bias.get_ptr();
diff --git a/paddle/phi/kernels/xpu/index_put_grad_kernel.cc b/paddle/phi/kernels/xpu/index_put_grad_kernel.cc
index 664cc71845e2a..30f5138f1a85f 100644
--- a/paddle/phi/kernels/xpu/index_put_grad_kernel.cc
+++ b/paddle/phi/kernels/xpu/index_put_grad_kernel.cc
@@ -73,7 +73,7 @@ void IndexPutGradKernel(const Context& dev_ctx,
     std::copy(xshape.begin() + int_indices_v.size(),
               xshape.end(),
               value_shape_bd.begin() + index_shape.size() - 1);
-    int ret = xpu::SUCCESS;
+    int ret = 0;
     using XPUType = typename XPUTypeTrait::Type;
     if (x_grad) {
       phi::Copy(dev_ctx, out_grad, dev_ctx.GetPlace(), false, x_grad);
diff --git a/paddle/phi/kernels/xpu/interpolate_grad_kernel.cc b/paddle/phi/kernels/xpu/interpolate_grad_kernel.cc
index 054856862bc15..434b9d8125a89 100644
--- a/paddle/phi/kernels/xpu/interpolate_grad_kernel.cc
+++ b/paddle/phi/kernels/xpu/interpolate_grad_kernel.cc
@@ -116,7 +116,7 @@ void InterpolateGradKernel(
   x_grad->Resize(dim_grad);
   dev_ctx.template Alloc(x_grad);
-  int r = XPU_SUCCESS;
+  int r = 0;
   r = xpu::constant(dev_ctx.x_context(),
                     x_grad->data(),
                     x_grad->numel(),
diff --git a/paddle/phi/kernels/xpu/kldiv_loss_grad_kernel.cc b/paddle/phi/kernels/xpu/kldiv_loss_grad_kernel.cc
index a9c1ea43529c7..91eb132fcb234 100644
--- a/paddle/phi/kernels/xpu/kldiv_loss_grad_kernel.cc
+++ b/paddle/phi/kernels/xpu/kldiv_loss_grad_kernel.cc
@@ -33,7 +33,7 @@ void KLDivLossGradKernel(const Context& dev_ctx,
     return;
   }
-  int r = XPU_SUCCESS;
+  int r = 0;
   if (log_target) {
     xpu::ctx_guard RAII_GUARD(dev_ctx.x_context());
diff --git a/paddle/phi/kernels/xpu/kldiv_loss_kernel.cc b/paddle/phi/kernels/xpu/kldiv_loss_kernel.cc
index 1377038ee2e26..30d441c3fd6df 100644
--- a/paddle/phi/kernels/xpu/kldiv_loss_kernel.cc
+++ b/paddle/phi/kernels/xpu/kldiv_loss_kernel.cc
@@ -32,7 +32,7 @@ void KLDivLossKernel(const Context& dev_ctx,
     return;
   }
-  int r = XPU_SUCCESS;
+  int r = 0;
   if (log_target) {
     xpu::ctx_guard RAII_GUARD(dev_ctx.x_context());
diff --git a/paddle/phi/kernels/xpu/logical_kernel.cc b/paddle/phi/kernels/xpu/logical_kernel.cc
index d94e670be3d2b..d62c8947c996e 100644
--- a/paddle/phi/kernels/xpu/logical_kernel.cc
+++ b/paddle/phi/kernels/xpu/logical_kernel.cc
@@ -40,7 +40,7 @@ void LogicalBinaryKernel(
     std::string funcname = "logical") {
   dev_ctx.template Alloc(out);
-  int r = xpu::SUCCESS;
+  int r = 0;
   const auto* x_data = x.data();
   const auto* y_data = y.data();
   auto* out_data = out->data();
diff --git a/paddle/phi/kernels/xpu/reduce.h b/paddle/phi/kernels/xpu/reduce.h
index 9d90f390d0605..65a5f289d483a 100644
--- a/paddle/phi/kernels/xpu/reduce.h
+++ b/paddle/phi/kernels/xpu/reduce.h
@@ -83,7 +83,7 @@ int XPUReduce(const Context& dev_ctx,
   std::vector reduce_dims;
   GetReduceDims(x.dims(), dims, reduce_all, &reduce_dims);
-  int r = xpu::SUCCESS;
+  int r = 0;
   if (reduce_dims.size() == 0) {
     r = xpu::copy(dev_ctx.x_context(),
                   reinterpret_cast(x_data),
diff --git a/paddle/phi/kernels/xpu/scatter_nd_add_grad_kernel.cc b/paddle/phi/kernels/xpu/scatter_nd_add_grad_kernel.cc
index bc08afbb7f6da..afc6d7ef20296 100644
--- a/paddle/phi/kernels/xpu/scatter_nd_add_grad_kernel.cc
+++ b/paddle/phi/kernels/xpu/scatter_nd_add_grad_kernel.cc
@@ -26,7 +26,7 @@ void ScatterNdAddGradKernel(const Context &ctx,
                             DenseTensor *x_grad,
                             DenseTensor *updates_grad) {
   using XPUType = typename XPUTypeTrait::Type;
-  int ret = xpu::SUCCESS;
+  int ret = 0;
   const T *out_grad_data = out_grad.data();
   if (x_grad) {
     auto *x_grad_data = ctx.template Alloc(x_grad);
diff --git a/paddle/phi/kernels/xpu/set_value_grad_kernel.cc b/paddle/phi/kernels/xpu/set_value_grad_kernel.cc
index 0df30d9cf50fb..db6e15fecc26d 100644
--- a/paddle/phi/kernels/xpu/set_value_grad_kernel.cc
+++ b/paddle/phi/kernels/xpu/set_value_grad_kernel.cc
@@ -140,7 +140,7 @@ void SetValueGradImpl(const Context& dev_ctx,
   }
   phi::funcs::SetConstant set_zero;
-  int r = XPU_SUCCESS;
+  int r = 0;
   if (x_grad) {
     // Set gradient of `Input`
diff --git a/paddle/phi/kernels/xpu/set_value_kernel.cc b/paddle/phi/kernels/xpu/set_value_kernel.cc
index cd12e0c0847f0..09bfaa37ab108 100644
--- a/paddle/phi/kernels/xpu/set_value_kernel.cc
+++ b/paddle/phi/kernels/xpu/set_value_kernel.cc
@@ -136,7 +136,7 @@ void SetValueImpl(const Context& dev_ctx,
   // be two ops points to the output in graph: op1 -> output <- set_value.
   // In this case, we have to find a way to handle the running order of
   // set_value is what we want.
-  int r = XPU_SUCCESS;
+  int r = 0;
   out->Resize(in.dims());
   dev_ctx.template Alloc(out);
diff --git a/paddle/phi/kernels/xpu/softmax_kernel.cc b/paddle/phi/kernels/xpu/softmax_kernel.cc
index 60b35fb4e92cd..fe6a2dca62c70 100644
--- a/paddle/phi/kernels/xpu/softmax_kernel.cc
+++ b/paddle/phi/kernels/xpu/softmax_kernel.cc
@@ -46,7 +46,7 @@ void SoftmaxKernel(const Context& dev_ctx,
     x_dims.push_back(x.dims()[i]);
   }
-  int r = XPU_SUCCESS;
+  int r = 0;
   auto version =
       phi::backends::xpu::get_xpu_version(dev_ctx.GetPlace().GetDeviceId());
   if (version == phi::backends::xpu::XPUVersion::XPU1) {
diff --git a/paddle/phi/kernels/xpu/take_along_axis_kernel.cc b/paddle/phi/kernels/xpu/take_along_axis_kernel.cc
index 8100f6c15ac7f..31dcfb16d179f 100644
--- a/paddle/phi/kernels/xpu/take_along_axis_kernel.cc
+++ b/paddle/phi/kernels/xpu/take_along_axis_kernel.cc
@@ -62,7 +62,7 @@ void TakeAlongAxisKernel(const Context& dev_ctx,
   }
   using XPUType = typename XPUTypeTrait::Type;
-  int r = XPU_SUCCESS;
+  int r = 0;
 #ifndef PADDLE_WITH_XPU_PLUGIN
   if (index_dtype == DataType::INT32) {
     r = xpu::gather(dev_ctx.x_context(),
diff --git a/paddle/phi/kernels/xpu/tile_kernel.cc b/paddle/phi/kernels/xpu/tile_kernel.cc
index 37e725d918e0d..622c222f59c97 100644
--- a/paddle/phi/kernels/xpu/tile_kernel.cc
+++ b/paddle/phi/kernels/xpu/tile_kernel.cc
@@ -115,7 +115,7 @@ void TileKernel(const Context& dev_ctx,
   }
   xpu::ctx_guard RAII_GUARD(dev_ctx.x_context());
-  int ret = XPU_SUCCESS;
+  int ret = 0;
   if (std::is_same::value) {
     ret = xpu::broadcast(dev_ctx.x_context(),
                          reinterpret_cast(x.data()),
diff --git a/paddle/phi/kernels/xpu/unique_kernel.cc b/paddle/phi/kernels/xpu/unique_kernel.cc
index 626f3759a0c4d..eb7523121ee70 100644
--- a/paddle/phi/kernels/xpu/unique_kernel.cc
+++ b/paddle/phi/kernels/xpu/unique_kernel.cc
@@ -40,7 +40,7 @@ void XPUFlattenUniqueKernelImpl(const Context& dev_ctx,
   using XPUType = typename XPUTypeTrait::Type;
   const auto* x_data = x.data();
   int64_t x_len = x.numel();
-  int r = XPU_SUCCESS;
+  int r = 0;
   xpu::ctx_guard RAII_GUARD(dev_ctx.x_context());
   int64_t unique_len_cpu = 0;
   int64_t* unique_len_xpu = RAII_GUARD.alloc_l3_or_gm(1);
@@ -116,7 +116,7 @@ void XPUDimUniqueKernelImpl(const Context& dev_ctx,
                             DenseTensor* counts) {
   using XPUType = typename XPUTypeTrait::Type;
   xpu::ctx_guard RAII_GUARD(dev_ctx.x_context());
-  int r = xpu::SUCCESS;
+  int r = 0;
   const auto* x_data = x.data();
   auto* x_trans_data = RAII_GUARD.alloc_l3_or_gm(x.numel());
   std::vector permute(x.dims().size());
diff --git a/paddle/phi/kernels/xpu/xpu_api_wrapper.h b/paddle/phi/kernels/xpu/xpu_api_wrapper.h
index d8e77ff8e1747..50dd185aa2832 100644
--- a/paddle/phi/kernels/xpu/xpu_api_wrapper.h
+++ b/paddle/phi/kernels/xpu/xpu_api_wrapper.h
@@ -99,15 +99,15 @@ inline XPUFCCalcType FCCalcType() {
 }
 struct XpuFcInfo {
-  int bs;
-  int m;
-  int n;
-  int k;
+  int64_t bs;
+  int64_t m;
+  int64_t n;
+  int64_t k;
   bool trans_x;
   bool trans_y;
-  int stride_x;
-  int stride_y;
-  int stride_out;
+  int64_t stride_x;
+  int64_t stride_y;
+  int64_t stride_out;
   float* max_x;
   float* max_y;
   float* max_out;
@@ -139,10 +139,10 @@ struct XpuFcInfo {
         scale_y(nullptr),
         scale_x_mode(0),
         scale_y_mode(0) {}
-  void InitFcInfo(int bs,
-                  int m,
-                  int n,
-                  int k,
+  void InitFcInfo(int64_t bs,
+                  int64_t m,
+                  int64_t n,
+                  int64_t k,
                   bool trans_x,
                   bool trans_y,
                   float* max_x,
@@ -244,17 +244,17 @@ static void xblas_fc_wrapper(xpu::Context* ctx,
                              const XPUType* x,
                              const XPUType* w,
                              XPUType* y,
-                             int m,
-                             int n,
-                             int k,
+                             int64_t m,
+                             int64_t n,
+                             int64_t k,
                              bool x_trans,
                              bool w_trans,
                              const float* x_maxptr,
                              const float* w_maxptr,
                              float* y_maxptr,
-                             int ldx,
-                             int ldw,
-                             int ldy,
+                             int64_t ldx,
+                             int64_t ldw,
+                             int64_t ldy,
                              float alpha,
                              float beta,
                              const float* bias,
@@ -327,11 +327,11 @@ static void xblas_fc_wrapper(xpu::Context* ctx,
     if constexpr (std::is_same::value) {
       if (std::getenv("XPU_PADDLE_FC_BFLOAT16_XTE") != nullptr) {
         const int MAXPTR_N = ctx->max_ptr_size();
-        int x_len = m * k;
+        int64_t x_len = m * k;
         XPUTypeFP16* x_fp16 = nullptr;
         x_fp16 = RAII_GUARD.alloc_l3_or_gm(x_len);
         PADDLE_ENFORCE_XDNN_NOT_NULL(x_fp16);
-        int w_len = k * n;
+        int64_t w_len = k * n;
         XPUTypeFP16* w_fp16 = nullptr;
         w_fp16 = RAII_GUARD.alloc_l3_or_gm(w_len);
         PADDLE_ENFORCE_XDNN_NOT_NULL(w_fp16);
@@ -454,17 +454,17 @@ static void xblas_fc_wrapper(xpu::Context* ctx,
       const XPUType* x,       \
       const XPUType* w,       \
       XPUType* y,             \
-      int m,                  \
-      int n,                  \
-      int k,                  \
+      int64_t m,              \
+      int64_t n,              \
+      int64_t k,              \
       bool x_trans,           \
       bool w_trans,           \
       const float* x_maxptr,  \
       const float* w_maxptr,  \
       float* y_maxptr,        \
-      int ldx,                \
-      int ldw,                \
-      int ldy,                \
+      int64_t ldx,            \
+      int64_t ldw,            \
+      int64_t ldy,            \
      float alpha,            \
      float beta,             \
      const float* bias,      \
@@ -491,20 +491,20 @@ DECLARE_UNSUPPORTED_XBLAS_FC_WRAPPER(float, XPUTypeFP16)
 template 
 static void xblas_fc_batch_wrapper(xpu::Context* xpu_ctx,
-                                   int bs,
+                                   int64_t bs,
                                    bool trans_x,
                                    bool trans_w,
-                                   int m,
-                                   int n,
-                                   int k,
+                                   int64_t m,
+                                   int64_t n,
+                                   int64_t k,
                                    float alpha,
                                    const XPUType* x,
-                                   int stride_x,
+                                   int64_t stride_x,
                                    const XPUType* w,
-                                   int stride_w,
+                                   int64_t stride_w,
                                    float beta,
                                    XPUType* y,
-                                   int stride_y,
+                                   int64_t stride_y,
                                    const float* x_maxptr,
                                    const float* w_maxptr) {
 #ifdef PADDLE_WITH_XPU_XRE5
@@ -554,20 +554,20 @@ static void xblas_fc_batch_wrapper(xpu::Context* xpu_ctx,
   template <>                            \
   void xblas_fc_batch_wrapper(           \
       xpu::Context * xpu_ctx,            \
-      int bs,                            \
+      int64_t bs,                        \
       bool trans_x,                      \
       bool trans_w,                      \
-      int m,                             \
-      int n,                             \
-      int k,                             \
+      int64_t m,                         \
+      int64_t n,                         \
+      int64_t k,                         \
       float alpha,                       \
       const XPUType* x,                  \
-      int stride_x,                      \
+      int64_t stride_x,                  \
       const XPUType* w,                  \
-      int stride_w,                      \
+      int64_t stride_w,                  \
       float beta,                        \
       XPUType* y,                        \
-      int stride_y,                      \
+      int64_t stride_y,                  \
       const float* x_maxptr,             \
       const float* w_maxptr) {           \
     int r = xpu::Error_t::INVALID_PARAM; \
@@ -651,13 +651,13 @@ static void MatMulXPUFunction(
     xblas_fc_batch_api = &xblas_fc_batch_wrapper;
   }
-  int m = fcinfo.m;
-  int n = fcinfo.n;
-  int k = fcinfo.k;
-  int batch_size = fcinfo.bs;
-  int ldx = fcinfo.stride_x;
-  int ldy = fcinfo.stride_y;
-  int ldout = fcinfo.stride_out;
+  int64_t m = fcinfo.m;
+  int64_t n = fcinfo.n;
+  int64_t k = fcinfo.k;
+  int64_t batch_size = fcinfo.bs;
+  int64_t ldx = fcinfo.stride_x;
+  int64_t ldy = fcinfo.stride_y;
+  int64_t ldout = fcinfo.stride_out;
   bool trans_x = fcinfo.trans_x;
   bool trans_y = fcinfo.trans_y;
   float* max_x = fcinfo.max_x;
@@ -723,20 +723,20 @@ static void MatMulXPUFunction(
   }
   // batch matmul
   xblas_fc_batch_api(xpu_ctx,     // Context* ctx,
-                     batch_size,  // int batch_size,
+                     batch_size,  // int64_t batch_size,
                      trans_x,     // bool x_trans,
                      trans_y,     // bool w_trans,
-                     m,           // int m,
-                     n,           // int n,
-                     k,           // int k,
+                     m,           // int64_t m,
+                     n,           // int64_t n,
+                     k,           // int64_t k,
                      alpha,       // float alpha,
                      x_data,      // const TX* x,
-                     ldx,         // int stride_a,
+                     ldx,         // int64_t stride_a,
                      y_data,      // const TW* w,
-                     ldy,         // int stride_b,
+                     ldy,         // int64_t stride_b,
                      beta,        // float beta,
                      reinterpret_cast(out),  // TY* y,
-                     ldout,       // int stride_c,
+                     ldout,       // int64_t stride_c,
                      max_x,       // const float* x_maxptr,
                      max_y);      // const float* w_maxptr
 }
@@ -761,10 +761,10 @@ MatmulGradFcInfo(xpu::Context* xpu_ctx,
   float* max_dout = NULL;
   int maxptr_size = xpu_ctx->max_ptr_size();
   uint64_t l3_size = uint64_t(xpu_ctx->_l3_mgr.get_size());
-  int bs = (dout_shape.bs <= 1) ? (1) : (dout_shape.bs);
-  int dx_size = bs * dout_shape.m * dout_shape.k;
-  int dy_size = bs * dout_shape.k * dout_shape.n;
-  int dout_size = bs * dout_shape.m * dout_shape.n;
+  int64_t bs = (dout_shape.bs <= 1) ? (1) : (dout_shape.bs);
+  int64_t dx_size = bs * dout_shape.m * dout_shape.k;
+  int64_t dy_size = bs * dout_shape.k * dout_shape.n;
+  int64_t dout_size = bs * dout_shape.m * dout_shape.n;
   if (trans_x && trans_y) {
     copy_to_l3 = l3_size >= (dout_size * 2 + dy_size) * sizeof(T);
   } else if (trans_x) {
diff --git a/paddle/phi/kernels/xpu/xpu_fused_common_function.h b/paddle/phi/kernels/xpu/xpu_fused_common_function.h
index 1aac7ff1392a3..1a868003db41f 100644
--- a/paddle/phi/kernels/xpu/xpu_fused_common_function.h
+++ b/paddle/phi/kernels/xpu/xpu_fused_common_function.h
@@ -71,9 +71,9 @@ void Dropout(xpu::Context *xpu_ctx,
              T *mask,
              T *y,
              const XPUDropoutParam &param,
-             int len) {
+             int64_t len) {
   using XPUType = typename XPUTypeTrait::Type;
-  int r = XPU_SUCCESS;
+  int r = 0;
   if (param.dropout_prob == 0.0f) {
     r = xpu::copy(xpu_ctx,
                   reinterpret_cast(x),
@@ -123,7 +123,7 @@ void DropoutGrad(xpu::Context *xpu_ctx,
                  const T *mask,
                  T *dx,
                  const XPUDropoutParam &param,
-                 int len) {
+                 int64_t len) {
   using XPUType = typename XPUTypeTrait::Type;
   if (param.dropout_prob == 0.0f) {
     int r = xpu::copy(xpu_ctx,