[XPU] update StridedCopyKernel #72030

Merged on Apr 3, 2025 (3 commits)
Changes from 2 commits
68 changes: 68 additions & 0 deletions paddle/phi/kernels/xpu/strided_copy_kernel.cc
@@ -44,6 +44,17 @@ void StridedCopyKernel(const Context& dev_ctx,
input.numel(),
out->numel()));

if (input.numel() <= 0) {
  return;
}

PADDLE_ENFORCE_NOT_NULL(out->data<T>(),
                        common::errors::InvalidArgument(
                            "StridedCopyKernel's out tensor must complete "
                            "mutable data before call kernel."));

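// The XPU operator below has performance problems, so it is temporarily
// disabled and replaced with a stopgap: copy to the CPU, compute with the CPU
// kernel's logic, then copy the result back to the XPU.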
/*
Contributor (review comment): The Paddle repo currently has no Chinese comments, so please switch this one to English.

Contributor Author (reply): Done.

// use XPUCopyTypeTrait to deal with double and int16_t copy instead of
// XPUTypeTrait
using XPUType = typename XPUCopyTypeTrait<T>::Type;
@@ -68,6 +79,63 @@ void StridedCopyKernel(const Context& dev_ctx,
common::vectorize<int64_t>(out->strides()));
PADDLE_ENFORCE_XDNN_SUCCESS(r, "strided_copy");
}
*/

// CPU buffer for input
char* input_on_cpu = new char[input.Holder()->size()];
memory_utils::Copy(CPUPlace(),
                   static_cast<void*>(input_on_cpu),
                   dev_ctx.GetPlace(),
                   static_cast<const void*>(input.Holder()->ptr()),
                   input.Holder()->size());

// CPU buffer for out
char* output_on_cpu = new char[out->Holder()->size()];
memory_utils::Copy(CPUPlace(),
Contributor (review comment): Does this need a Wait here? The underlying implementation appears to use xpu_memcpy_async. In kernels we generally use phi::Copy: passing a stream to phi::Copy means the copy is non-blocking, and a null stream means it must block.

Contributor Author (reply): As discussed separately, this keeps memory_utils::Copy and adds a synchronization step.
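
The concern raised in this thread, reading a host buffer before an asynchronous copy has finished, can be illustrated with a minimal generic sketch. It does not use the XPU runtime or any Paddle API: `CopyAsync` and `CopyBlocking` are hypothetical names, and `std::async` merely stands in for a device stream.

```cpp
#include <cassert>
#include <cstddef>
#include <cstring>
#include <future>
#include <vector>

// Hypothetical stand-in for an asynchronous device-to-host copy. The copy runs
// on another thread, and the returned future plays the role of the stream
// whose completion the caller must wait for.
std::future<void> CopyAsync(void* dst, const void* src, std::size_t bytes) {
  return std::async(std::launch::async,
                    [=] { std::memcpy(dst, src, bytes); });
}

// Blocking variant: issue the asynchronous copy, then wait for it, so dst is
// safe to read as soon as this function returns.
void CopyBlocking(void* dst, const void* src, std::size_t bytes) {
  std::future<void> done = CopyAsync(dst, src, bytes);
  done.wait();  // synchronization point
}
```

In the kernel, the equivalent synchronization point would have to sit between the device-to-host copies and the first host-side read of the buffers.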

                   static_cast<void*>(output_on_cpu),
                   dev_ctx.GetPlace(),
                   static_cast<const void*>(out->Holder()->ptr()),
                   out->Holder()->size());

// follow paddle/phi/kernels/cpu/strided_copy_kernel.cc
const T* input_data =
    reinterpret_cast<T*>(input_on_cpu + input.meta().offset);
int input_rank = input.dims().size();
const int64_t* input_dims = input.dims().Get();
const int64_t* input_stride = input.strides().Get();

T* output_data = reinterpret_cast<T*>(output_on_cpu + offset);
int output_rank = meta.dims.size();
const int64_t* output_dims = meta.dims.Get();
const int64_t* output_stride = meta.strides.Get();

auto numel = input.numel();

for (int64_t i = 0; i < numel; i++) {
  int64_t input_offset = 0;
  int64_t index_tmp = i;
  for (int dim = input_rank - 1; dim >= 0; --dim) {
    input_offset += (index_tmp % input_dims[dim]) * input_stride[dim];
    index_tmp = index_tmp / input_dims[dim];
  }
  int64_t output_offset = 0;
  index_tmp = i;
  for (int dim = output_rank - 1; dim >= 0; --dim) {
    output_offset += (index_tmp % output_dims[dim]) * output_stride[dim];
    index_tmp = index_tmp / output_dims[dim];
  }
  output_data[output_offset] = input_data[input_offset];
}

// copy out tensor, from cpu to xpu
memory_utils::Copy(dev_ctx.GetPlace(),
                   static_cast<void*>(out->Holder()->ptr()),
                   CPUPlace(),
                   static_cast<const void*>(output_on_cpu),
                   out->Holder()->size());

delete[] input_on_cpu;
delete[] output_on_cpu;
}

} // namespace phi
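
The index arithmetic in the host-side loop above can be tried in isolation. The following is a minimal sketch, not the Paddle implementation: `LinearIndexToOffset` and `StridedCopyOnHost` are hypothetical helper names, and offsets are counted in elements rather than bytes. Each logical element index is decomposed from the innermost dimension outward, once with the source strides and once with the destination strides.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Map a logical (row-major) element index to a storage offset in elements,
// peeling off the innermost dimension first.
int64_t LinearIndexToOffset(int64_t index,
                            const std::vector<int64_t>& dims,
                            const std::vector<int64_t>& strides) {
  assert(dims.size() == strides.size());
  int64_t offset = 0;
  for (int dim = static_cast<int>(dims.size()) - 1; dim >= 0; --dim) {
    offset += (index % dims[dim]) * strides[dim];
    index /= dims[dim];
  }
  return offset;
}

// Hypothetical helper: copy every logical element from one strided layout to
// another. Both layouts must describe the same number of elements.
template <typename T>
void StridedCopyOnHost(const T* src,
                       const std::vector<int64_t>& src_dims,
                       const std::vector<int64_t>& src_strides,
                       T* dst,
                       const std::vector<int64_t>& dst_dims,
                       const std::vector<int64_t>& dst_strides) {
  int64_t numel = 1;
  for (int64_t d : src_dims) numel *= d;
  for (int64_t i = 0; i < numel; ++i) {
    dst[LinearIndexToOffset(i, dst_dims, dst_strides)] =
        src[LinearIndexToOffset(i, src_dims, src_strides)];
  }
}
```

For example, copying a contiguous 2x3 buffer `{0, 1, 2, 3, 4, 5}` (strides `{3, 1}`) into a destination with strides `{1, 2}` writes the elements in column-major order, so the destination storage ends up as `{0, 3, 1, 4, 2, 5}`.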
3 changes: 3 additions & 0 deletions test/indexing/CMakeLists.txt
@@ -7,3 +7,6 @@ string(REPLACE ".py" "" TEST_OPS "${TEST_OPS}")
foreach(TEST_OP ${TEST_OPS})
py_test_modules(${TEST_OP} MODULES ${TEST_OP})
endforeach()

set_tests_properties(test_setitem_appendix
PROPERTIES ENVIRONMENT "FLAGS_use_stride_kernel=1")
4 changes: 0 additions & 4 deletions test/indexing/test_setitem_appendix.py
@@ -194,10 +194,6 @@ def test_tensor(self):
self.accuracy_check(x, y)


@unittest.skipIf(
    paddle.core.is_compiled_with_xpu(),
    "There are some bugs on XPU.",
)
class TestSetitemDygraphCombinedIndex(unittest.TestCase):
    def accuracy_check(self, numpy_array, paddle_t):
        np.testing.assert_allclose(numpy_array, paddle_t.numpy())