|
| 1 | +// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. |
| 2 | +// |
| 3 | +// Licensed under the Apache License, Version 2.0 (the "License"); |
| 4 | +// you may not use this file except in compliance with the License. |
| 5 | +// You may obtain a copy of the License at |
| 6 | +// |
| 7 | +// http://www.apache.org/licenses/LICENSE-2.0 |
| 8 | +// |
| 9 | +// Unless required by applicable law or agreed to in writing, software |
| 10 | +// distributed under the License is distributed on an "AS IS" BASIS, |
| 11 | +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 12 | +// See the License for the specific language governing permissions and |
| 13 | +// limitations under the License. |
| 14 | + |
| 15 | +#include "paddle/phi/kernels/index_elementwise_get_grad_kernel.h" |
| 16 | + |
| 17 | +#include "paddle/phi/backends/gpu/gpu_context.h" |
| 18 | +#include "paddle/phi/common/bfloat16.h" |
| 19 | +#include "paddle/phi/core/kernel_registry.h" |
| 20 | +#include "paddle/phi/kernels/funcs/eigen/common.h" |
| 21 | +#include "paddle/phi/kernels/funcs/index_elementwise.cu.h" |
| 22 | +#include "paddle/phi/kernels/funcs/stride_utils.h" |
| 23 | + |
| 24 | +namespace phi { |
| 25 | + |
| 26 | +template <typename T, typename IndexT = int> |
| 27 | +void GPUIndexElementwisePutKernel(const phi::GPUContext& ctx, |
| 28 | + const DenseTensor& input, |
| 29 | + const DenseTensor& value, |
| 30 | + const std::vector<const DenseTensor*>& index, |
| 31 | + const std::vector<int64_t>& input_dims, |
| 32 | + const std::vector<int64_t>& input_strides, |
| 33 | + const std::vector<int64_t>& index_dims, |
| 34 | + const std::vector<int64_t>& index_strides, |
| 35 | + DenseTensor* output) { |
| 36 | + int64_t numel = 0; |
| 37 | + |
| 38 | + auto num_indices = index_dims.size(); |
| 39 | + |
| 40 | + auto sizes = std::array<int64_t, 25>{}; |
| 41 | + auto strides = std::array<int64_t, 25>{}; |
| 42 | + for (unsigned i = 0; i < num_indices; i++) { |
| 43 | + sizes[i] = index_dims[i]; |
| 44 | + strides[i] = index_strides[i]; |
| 45 | + } |
| 46 | + auto index_ptrs = funcs::GetIndexDataPtrs<IndexT>(index); |
| 47 | + |
| 48 | + std::array<int64_t*, 3> strides_array; |
| 49 | + std::vector<int64_t> desired_shape; |
| 50 | + |
| 51 | + funcs::IndexPutStride<3>(input_dims, |
| 52 | + input_strides, |
| 53 | + phi::SizeOf(input.dtype()), |
| 54 | + std::vector<int64_t>(), |
| 55 | + std::vector<int64_t>(), |
| 56 | + phi::SizeOf(value.dtype()), |
| 57 | + common::vectorize<int64_t>(index[0]->dims()), |
| 58 | + common::vectorize<int64_t>(index[0]->strides()), |
| 59 | + phi::SizeOf(index[0]->dtype()), |
| 60 | + &desired_shape, |
| 61 | + &strides_array, |
| 62 | + &numel); |
| 63 | + |
| 64 | + const int64_t* template_stride = strides_array[2]; |
| 65 | + PADDLE_ENFORCE( |
| 66 | + template_stride != nullptr, |
| 67 | + "strides_array[2] should not be nullptr in GPUIndexElementwiseGetKernel"); |
| 68 | + size_t stride_size = desired_shape.size(); |
| 69 | + std::vector<std::vector<int64_t>> strides_vector; |
| 70 | + strides_vector.reserve(num_indices + 2); |
| 71 | + |
| 72 | + for (int i = 0; i < 2; ++i) { |
| 73 | + if (i < strides_array.size() && strides_array[i] != nullptr) { |
| 74 | + strides_vector.emplace_back(strides_array[i], |
| 75 | + strides_array[i] + stride_size); |
| 76 | + } else { |
| 77 | + strides_vector.emplace_back(stride_size, 0); |
| 78 | + } |
| 79 | + } |
| 80 | + |
| 81 | + std::vector<int64_t> template_vec(template_stride, |
| 82 | + template_stride + stride_size); |
| 83 | + for (size_t i = 0; i < num_indices; ++i) { |
| 84 | + strides_vector.push_back(template_vec); |
| 85 | + } |
| 86 | + |
| 87 | + auto offset_calc = funcs::make_offset_calculator<3>( |
| 88 | + desired_shape.size(), desired_shape.data(), strides_vector); |
| 89 | + |
| 90 | + const int64_t N = numel; |
| 91 | + PADDLE_ENFORCE(N >= 0 && N <= std::numeric_limits<int32_t>::max(), |
| 92 | + "N >= 0 && N <= std::numeric_limits<int32_t>::max()"); |
| 93 | + |
| 94 | + constexpr int nt = 128; |
| 95 | + constexpr int vt = 4; |
| 96 | + const dim3 block(nt); |
| 97 | + const dim3 grid((N + block.x * vt - 1) / (block.x * vt)); |
| 98 | + auto stream = ctx.stream(); |
| 99 | + |
| 100 | + using dtype = funcs::OpaqueType<sizeof(T)>; |
| 101 | + |
| 102 | + const char* in_ptr = reinterpret_cast<const char*>(value.data<T>()); |
| 103 | + char* out_ptr = reinterpret_cast<char*>(output->data<T>()); |
| 104 | + |
| 105 | + funcs::index_elementwise_kernel<nt, vt> |
| 106 | + <<<grid, block, 0, stream>>>(N, [=] __device__(int idx) { |
| 107 | + const auto offsets = offset_calc.get(idx); |
| 108 | + char* const out_data = out_ptr + offsets[0]; |
| 109 | + const char* const in_data = in_ptr + offsets[1]; |
| 110 | + |
| 111 | + int64_t offset = 0; |
| 112 | +#pragma unroll |
| 113 | + for (int i = 0; i < num_indices; i++) { |
| 114 | + int64_t index = |
| 115 | + *reinterpret_cast<int64_t*>(index_ptrs[i] + offsets[2]); |
| 116 | + PADDLE_ENFORCE(-sizes[i] <= index && index < sizes[i], |
| 117 | + "index out of bounds"); |
| 118 | + if (index < 0) { |
| 119 | + index += sizes[i]; |
| 120 | + } |
| 121 | + offset += index * strides[i]; |
| 122 | + } |
| 123 | + *reinterpret_cast<dtype*>(out_data + offset) = |
| 124 | + *reinterpret_cast<const dtype*>(in_data); |
| 125 | + }); |
| 126 | +} |
| 127 | + |
| 128 | +template <typename T, typename Context> |
| 129 | +void IndexElementwiseGetGradKernel(const Context& ctx, |
| 130 | + const DenseTensor& x, |
| 131 | + const std::vector<const DenseTensor*>& index, |
| 132 | + const DenseTensor& out_grad, |
| 133 | + const std::vector<int64_t>& input_dims, |
| 134 | + const std::vector<int64_t>& input_strides, |
| 135 | + const std::vector<int64_t>& index_dims, |
| 136 | + const std::vector<int64_t>& index_strides, |
| 137 | + DenseTensor* x_grad) { |
| 138 | + ctx.template Alloc<T>(x_grad); |
| 139 | + auto dxt = phi::EigenVector<T>::Flatten(*x_grad); |
| 140 | + auto& place = *ctx.eigen_device(); |
| 141 | + dxt.device(place) = dxt.constant(static_cast<T>(0)); |
| 142 | + if (out_grad.numel() == 0) return; |
| 143 | + |
| 144 | + const auto& index_type = index[0]->dtype(); |
| 145 | + PADDLE_ENFORCE_EQ( |
| 146 | + index_type == phi::DataType::INT32 || index_type == phi::DataType::INT64, |
| 147 | + true, |
| 148 | + common::errors::InvalidArgument( |
| 149 | + "Index holds the wrong type, it holds [%s], but " |
| 150 | + "desires to be [%s] or [%s].", |
| 151 | + index_type, |
| 152 | + phi::DataType::INT32, |
| 153 | + phi::DataType::INT64)); |
| 154 | + |
| 155 | + if (index_type == phi::DataType::INT32) { |
| 156 | + GPUIndexElementwisePutKernel<T, int>(ctx, |
| 157 | + x, |
| 158 | + out_grad, |
| 159 | + index, |
| 160 | + input_dims, |
| 161 | + input_strides, |
| 162 | + index_dims, |
| 163 | + index_strides, |
| 164 | + x_grad); |
| 165 | + } else if (index_type == phi::DataType::INT64) { |
| 166 | + GPUIndexElementwisePutKernel<T, int64_t>(ctx, |
| 167 | + x, |
| 168 | + out_grad, |
| 169 | + index, |
| 170 | + input_dims, |
| 171 | + input_strides, |
| 172 | + index_dims, |
| 173 | + index_strides, |
| 174 | + x_grad); |
| 175 | + } |
| 176 | +} |
| 177 | + |
| 178 | +} // namespace phi |
| 179 | +PD_REGISTER_KERNEL(index_elementwise_get_grad, |
| 180 | + GPU, |
| 181 | + ALL_LAYOUT, |
| 182 | + phi::IndexElementwiseGetGradKernel, |
| 183 | + bool, |
| 184 | + float, |
| 185 | + double, |
| 186 | + int, |
| 187 | + int8_t, |
| 188 | + int64_t, |
| 189 | + int16_t, |
| 190 | + uint8_t, |
| 191 | + phi::dtype::float16, |
| 192 | + phi::dtype::bfloat16, |
| 193 | + phi::dtype::complex<float>, |
| 194 | + phi::dtype::complex<double>) {} |
0 commit comments