[Sparse]Fix the bug of elementwise_grad #52102

Merged: 6 commits, Mar 30, 2023
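
Summary of the change: the CPU gradient kernels for sparse elementwise add, subtract, multiply and divide previously built dx/dy by copying dout (or by computing the product/quotient directly into dx/dy), so the gradients inherited dout's sparsity pattern rather than the input's. This PR makes the gradients follow the input's pattern: the new CopyCooValues/CopyCsrValues helpers copy x's (or y's) indices into the gradient tensor, then walk the non-zeros of dout and the input together, copying dout's value where both have a non-zero and writing zero where only the input does. For instance, with element_size == 1, if x's flattened indices are {0, 3, 5, 7} and dout's are {0, 2, 5}, dx gets dout's values at positions 0 and 5 and explicit zeros at 3 and 7. The multiply/divide kernels now compute into a temporary sparse tensor and re-align it onto the input's pattern, and AllocCsrPtr/AllocCooPtr propagate the input's meta via set_meta.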
189 changes: 161 additions & 28 deletions paddle/phi/kernels/sparse/cpu/elementwise_grad_kernel.cc
@@ -27,6 +27,7 @@ limitations under the License. */
#include "paddle/phi/kernels/elementwise_kernel.h"
#include "paddle/phi/kernels/empty_kernel.h"
#include "paddle/phi/kernels/funcs/eigen/common.h"
#include "paddle/phi/kernels/funcs/sparse/flatten_indices.h"
#include "paddle/phi/kernels/sparse/empty_kernel.h"

namespace phi {
@@ -39,6 +40,7 @@ void AllocCsrPtr(const Context& dev_ctx,
DenseTensor dx_crows = phi::EmptyLike<IntT>(dev_ctx, x.crows());
DenseTensor dx_cols = phi::EmptyLike<IntT>(dev_ctx, x.cols());
DenseTensor dx_values = phi::EmptyLike<T>(dev_ctx, x.values());
dx->set_meta(x.meta());
dx->SetMember(dx_crows, dx_cols, dx_values, x.dims());
}

@@ -48,9 +50,117 @@ void AllocCooPtr(const Context& dev_ctx,
SparseCooTensor* dx) {
DenseTensor dx_indices = phi::EmptyLike<IntT>(dev_ctx, x.indices());
DenseTensor dx_values = phi::EmptyLike<T>(dev_ctx, x.values());
dx->set_meta(x.meta());
dx->SetMember(dx_indices, dx_values, x.dims(), x.coalesced());
}

template <typename T, typename IntT, typename Context>
void CopyCooValues(const Context& dev_ctx,
const SparseCooTensor& dout,
const SparseCooTensor& x,
SparseCooTensor* dx) {
Copy(dev_ctx, x.indices(), dev_ctx.GetPlace(), false, dx->mutable_indices());

const int sparse_dim = x.sparse_dim();
std::vector<IntT> sparse_offsets(sparse_dim), dout_indexs(dout.nnz()),
x_indexs(x.nnz());

phi::funcs::sparse::CalcOffsetsPerDim<IntT>(
dout.dims(), sparse_dim, sparse_offsets.data());

phi::funcs::sparse::FlattenIndices(dout.indices().data<IntT>(),
sparse_offsets.data(),
dout.nnz(),
sparse_dim,
0,
1,
dout_indexs.data());

phi::funcs::sparse::FlattenIndices(x.indices().data<IntT>(),
sparse_offsets.data(),
x.nnz(),
sparse_dim,
0,
1,
x_indexs.data());

size_t i = 0, j = 0;
T* dx_values_ptr = dx->mutable_values()->data<T>();
const T* dout_values_ptr = dout.values().data<T>();

int64_t element_size = 1;
for (auto j = 1; j < x.values().dims().size(); ++j) {
element_size *= x.values().dims()[j];
}

while (i < dout_indexs.size() && j < x_indexs.size()) {
if (dout_indexs[i] == x_indexs[j]) {
memcpy(dx_values_ptr + j * element_size,
dout_values_ptr + i * element_size,
element_size * sizeof(T));
++i;
++j;
} else if (dout_indexs[i] > x_indexs[j]) {
memset(dx_values_ptr + j * element_size, 0, element_size * sizeof(T));
++j;
} else {
++i;
}
}
while (j < x_indexs.size()) {
memset(dx_values_ptr + j * element_size, 0, element_size * sizeof(T));
++j;
}
}

template <typename T, typename IntT, typename Context>
void CopyCsrValues(const Context& dev_ctx,
const SparseCsrTensor& dout,
const SparseCsrTensor& x,
SparseCsrTensor* dx) {
Copy(dev_ctx, x.crows(), dev_ctx.GetPlace(), false, dx->mutable_crows());
Copy(dev_ctx, x.cols(), dev_ctx.GetPlace(), false, dx->mutable_cols());

const auto& x_dims = x.dims();
int batch = x_dims.size() == 2 ? 1 : x_dims[0];
int rows = x_dims.size() == 2 ? x_dims[0] : x_dims[1];

const IntT* x_crows_ptr = x.crows().data<IntT>();
const IntT* x_cols_ptr = x.cols().data<IntT>();

const IntT* dout_crows_ptr = dout.crows().data<IntT>();
const IntT* dout_cols_ptr = dout.cols().data<IntT>();
const T* dout_values_ptr = dout.values().data<T>();

T* dx_values_ptr = dx->mutable_values()->data<T>();

for (int b = 0; b < batch; b++) {
for (int r = 0; r < rows; r++) {
int x_start = x_crows_ptr[b * (rows + 1) + r];
int dout_start = dout_crows_ptr[b * (rows + 1) + r];
int x_row_nnz = x_crows_ptr[b * (rows + 1) + r + 1] - x_start;
int dout_row_nnz = dout_crows_ptr[b * (rows + 1) + r + 1] - dout_start;
int i = 0, j = 0;
while (i < x_row_nnz && j < dout_row_nnz) {
if (x_cols_ptr[x_start + i] == dout_cols_ptr[dout_start + j]) {
dx_values_ptr[x_start + i] = dout_values_ptr[dout_start + j];
++i;
++j;
} else if (x_cols_ptr[x_start + i] < dout_cols_ptr[dout_start + j]) {
dx_values_ptr[x_start + i] = static_cast<T>(0);
++i;
} else {
++j;
}
}
while (i < x_row_nnz) {
dx_values_ptr[x_start + i] = static_cast<T>(0);
++i;
}
}
}
}

template <typename T, typename IntT, typename Context>
void ElementWiseAddCsrGradCPUKernel(const Context& dev_ctx,
const SparseCsrTensor& x,
@@ -62,16 +172,16 @@ void ElementWiseAddCsrGradCPUKernel(const Context& dev_ctx,
if (dx != nullptr && dy == nullptr) {
VLOG(4) << "Special case when dy is not needed";
AllocCsrPtr<T, IntT>(dev_ctx, x, dx);
Copy(dev_ctx, dout, dev_ctx.GetPlace(), false, dx);
CopyCsrValues<T, IntT, Context>(dev_ctx, dout, x, dx);
} else if (dx == nullptr && dy != nullptr) {
VLOG(4) << "Special case when dx is not needed";
AllocCsrPtr<T, IntT>(dev_ctx, y, dy);
Copy(dev_ctx, dout, dev_ctx.GetPlace(), false, dy);
CopyCsrValues<T, IntT, Context>(dev_ctx, dout, y, dy);
} else {
AllocCsrPtr<T, IntT>(dev_ctx, x, dx);
AllocCsrPtr<T, IntT>(dev_ctx, y, dy);
Copy(dev_ctx, dout, dev_ctx.GetPlace(), false, dx);
Copy(dev_ctx, dout, dev_ctx.GetPlace(), false, dy);
CopyCsrValues<T, IntT, Context>(dev_ctx, dout, x, dx);
CopyCsrValues<T, IntT, Context>(dev_ctx, dout, y, dy);
}
}

@@ -84,12 +194,12 @@ void ElementWiseSubtractCsrGradCPUKernel(const Context& dev_ctx,
SparseCsrTensor* dy) {
if (dx) {
AllocCsrPtr<T, IntT>(dev_ctx, x, dx);
Copy(dev_ctx, dout, dev_ctx.GetPlace(), false, dx);
CopyCsrValues<T, IntT, Context>(dev_ctx, dout, x, dx);
}

if (dy) {
AllocCsrPtr<T, IntT>(dev_ctx, y, dy);
Copy(dev_ctx, dout, dev_ctx.GetPlace(), false, dy);
CopyCsrValues<T, IntT, Context>(dev_ctx, dout, y, dy);
phi::NegativeKernel<T, Context>(
dev_ctx, dout.values(), dy->mutable_values());
}
@@ -105,13 +215,19 @@ void ElementWiseMultiplyCsrGradCPUKernel(const Context& dev_ctx,
if (dx) {
// dout*y
AllocCsrPtr<T, IntT>(dev_ctx, x, dx);
sparse::ElementWiseMultiplyCsrKernel<T, Context>(dev_ctx, dout, y, dx);
SparseCsrTensor tmp_dx;
AllocCsrPtr<T, IntT>(dev_ctx, x, &tmp_dx);
sparse::ElementWiseMultiplyCsrKernel<T, Context>(dev_ctx, dout, y, &tmp_dx);
CopyCsrValues<T, IntT, Context>(dev_ctx, tmp_dx, x, dx);
}

if (dy) {
// dout*x
AllocCsrPtr<T, IntT>(dev_ctx, y, dy);
sparse::ElementWiseMultiplyCsrKernel<T, Context>(dev_ctx, dout, x, dy);
SparseCsrTensor tmp_dy;
AllocCsrPtr<T, IntT>(dev_ctx, y, &tmp_dy);
sparse::ElementWiseMultiplyCsrKernel<T, Context>(dev_ctx, dout, x, &tmp_dy);
CopyCsrValues<T, IntT, Context>(dev_ctx, tmp_dy, y, dy);
}
}

@@ -126,17 +242,24 @@ void ElementWiseDivideCsrGradCPUKernel(const Context& dev_ctx,
if (dx) {
// dout/y
AllocCsrPtr<T, IntT>(dev_ctx, x, dx);
sparse::ElementWiseDivideCsrKernel<T, Context>(dev_ctx, dout, y, dx);
SparseCsrTensor tmp_dx;
AllocCsrPtr<T, IntT>(dev_ctx, x, &tmp_dx);
sparse::ElementWiseDivideCsrKernel<T, Context>(dev_ctx, dout, y, &tmp_dx);
CopyCsrValues<T, IntT, Context>(dev_ctx, tmp_dx, x, dx);
}

if (dy) {
// -dout * out / y
AllocCsrPtr<T, IntT>(dev_ctx, y, dy);
Copy(dev_ctx, dout, dev_ctx.GetPlace(), false, dy);
SparseCsrTensor tmp_dy;
AllocCsrPtr<T, IntT>(dev_ctx, y, &tmp_dy);

Copy(dev_ctx, dout, dev_ctx.GetPlace(), false, &tmp_dy);
phi::NegativeKernel<T, Context>(
dev_ctx, dout.values(), dy->mutable_values());
auto tmp = sparse::ElementWiseMultiplyCsr<T, Context>(dev_ctx, *dy, out);
sparse::ElementWiseDivideCsrKernel<T, Context>(dev_ctx, tmp, y, dy);
dev_ctx, dout.values(), tmp_dy.mutable_values());
auto tmp = sparse::ElementWiseMultiplyCsr<T, Context>(dev_ctx, tmp_dy, out);
sparse::ElementWiseDivideCsrKernel<T, Context>(dev_ctx, tmp, y, &tmp_dy);
CopyCsrValues<T, IntT, Context>(dev_ctx, tmp_dy, y, dy);
}
}

@@ -151,16 +274,16 @@ void ElementWiseAddCooGradCPUKernel(const Context& dev_ctx,
if (dx != nullptr && dy == nullptr) {
VLOG(4) << "Special case when dy is not needed";
AllocCooPtr<T, IntT>(dev_ctx, x, dx);
Copy(dev_ctx, dout, dev_ctx.GetPlace(), false, dx);
CopyCooValues<T, IntT, Context>(dev_ctx, dout, x, dx);
} else if (dx == nullptr && dy != nullptr) {
VLOG(4) << "Special case when dx is not needed";
AllocCooPtr<T, IntT>(dev_ctx, y, dy);
Copy(dev_ctx, dout, dev_ctx.GetPlace(), false, dy);
CopyCooValues<T, IntT, Context>(dev_ctx, dout, y, dy);
} else {
AllocCooPtr<T, IntT>(dev_ctx, x, dx);
AllocCooPtr<T, IntT>(dev_ctx, y, dy);
Copy(dev_ctx, dout, dev_ctx.GetPlace(), false, dx);
Copy(dev_ctx, dout, dev_ctx.GetPlace(), false, dy);
CopyCooValues<T, IntT, Context>(dev_ctx, dout, x, dx);
CopyCooValues<T, IntT, Context>(dev_ctx, dout, y, dy);
}
}

@@ -173,12 +296,12 @@ void ElementWiseSubtractCooGradCPUKernel(const Context& dev_ctx,
SparseCooTensor* dy) {
if (dx) {
AllocCooPtr<T, IntT>(dev_ctx, x, dx);
Copy(dev_ctx, dout, dev_ctx.GetPlace(), false, dx);
CopyCooValues<T, IntT, Context>(dev_ctx, dout, x, dx);
}

if (dy) {
AllocCooPtr<T, IntT>(dev_ctx, y, dy);
Copy(dev_ctx, dout, dev_ctx.GetPlace(), false, dy);
CopyCooValues<T, IntT, Context>(dev_ctx, dout, y, dy);
phi::NegativeKernel<T, Context>(
dev_ctx, dout.values(), dy->mutable_values());
}
@@ -194,13 +317,19 @@ void ElementWiseMultiplyCooGradCPUKernel(const Context& dev_ctx,
if (dx) {
// dout*y
AllocCooPtr<T, IntT>(dev_ctx, x, dx);
sparse::ElementWiseMultiplyCooKernel<T, Context>(dev_ctx, dout, y, dx);
SparseCooTensor tmp_dx;
AllocCooPtr<T, IntT>(dev_ctx, x, &tmp_dx);
sparse::ElementWiseMultiplyCooKernel<T, Context>(dev_ctx, dout, y, &tmp_dx);
CopyCooValues<T, IntT, Context>(dev_ctx, tmp_dx, x, dx);
}

if (dy) {
// dout*x
AllocCooPtr<T, IntT>(dev_ctx, y, dy);
sparse::ElementWiseMultiplyCooKernel<T, Context>(dev_ctx, dout, x, dy);
SparseCooTensor tmp_dy;
AllocCooPtr<T, IntT>(dev_ctx, y, &tmp_dy);
sparse::ElementWiseMultiplyCooKernel<T, Context>(dev_ctx, dout, x, &tmp_dy);
CopyCooValues<T, IntT, Context>(dev_ctx, tmp_dy, y, dy);
}
}

@@ -215,22 +344,26 @@ void ElementWiseDivideCooGradCPUKernel(const Context& dev_ctx,
if (dx) {
// dout/y
AllocCooPtr<T, IntT>(dev_ctx, x, dx);
sparse::ElementWiseDivideCooKernel<T, Context>(dev_ctx, dout, y, dx);
SparseCooTensor tmp_dx;
AllocCooPtr<T, IntT>(dev_ctx, x, &tmp_dx);
sparse::ElementWiseDivideCooKernel<T, Context>(dev_ctx, dout, y, &tmp_dx);
CopyCooValues<T, IntT, Context>(dev_ctx, tmp_dx, x, dx);
}

if (dy) {
// -dout * out / y
AllocCooPtr<T, IntT>(dev_ctx, y, dy);
Copy(dev_ctx, dout, dev_ctx.GetPlace(), false, dy);
SparseCooTensor tmp_dy;
AllocCooPtr<T, IntT>(dev_ctx, y, &tmp_dy);
Copy(dev_ctx, dout, dev_ctx.GetPlace(), false, &tmp_dy);
phi::NegativeKernel<T, Context>(
dev_ctx, dout.values(), dy->mutable_values());
auto tmp = sparse::ElementWiseMultiplyCoo<T, Context>(dev_ctx, *dy, out);
sparse::ElementWiseDivideCooKernel<T, Context>(dev_ctx, tmp, y, dy);
dev_ctx, dout.values(), tmp_dy.mutable_values());
auto tmp = sparse::ElementWiseMultiplyCoo<T, Context>(dev_ctx, tmp_dy, out);
sparse::ElementWiseDivideCooKernel<T, Context>(dev_ctx, tmp, y, &tmp_dy);
CopyCooValues<T, IntT, Context>(dev_ctx, tmp_dy, y, dy);
}
}
// CPU Kernel end

// Kernel
template <typename T, typename Context>
void ElementWiseDivideCsrGradKernel(const Context& dev_ctx,
const SparseCsrTensor& x,
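
For illustration, here is a minimal standalone sketch (assumed data, not PaddlePaddle code) of the two-pointer alignment that CopyCooValues performs over flattened, sorted indices, with element_size == 1 for simplicity; CopyCsrValues applies the same idea per CSR row using column indices.

```cpp
// Minimal sketch of aligning dout's values onto x's sparsity pattern.
// Indices are flattened and sorted, as produced by FlattenIndices.
#include <cstdio>
#include <vector>

int main() {
  // Hypothetical flattened non-zero indices and values of dout, and indices of x.
  std::vector<long> dout_idx = {0, 2, 5};
  std::vector<float> dout_val = {1.f, 2.f, 3.f};
  std::vector<long> x_idx = {0, 3, 5, 7};

  std::vector<float> dx_val(x_idx.size(), 0.f);
  size_t i = 0, j = 0;
  while (i < dout_idx.size() && j < x_idx.size()) {
    if (dout_idx[i] == x_idx[j]) {
      dx_val[j++] = dout_val[i++];  // index present in both: copy the gradient
    } else if (dout_idx[i] > x_idx[j]) {
      dx_val[j++] = 0.f;            // index only in x: gradient slot stays zero
    } else {
      ++i;                          // index only in dout: no slot for it in dx
    }
  }
  // Any trailing x-only indices keep their zero initialization.
  for (float v : dx_val) std::printf("%g ", v);  // prints: 1 0 3 0
  std::printf("\n");
  return 0;
}
```

Positions present only in dout are skipped, since the input's gradient has no slot for them; positions present only in the input get an explicit zero, mirroring the memset calls in the kernel.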