
Commit b888f04

[XPU] update StridedCopyKernel (PaddlePaddle#72030)
* [XPU] update StridedCopyKernel
* fix ut env
* follow comments
1 parent 2ded875 commit b888f04

File tree: 3 files changed, +96 -17 lines changed

  paddle/phi/kernels/xpu/strided_copy_kernel.cc
  test/indexing/CMakeLists.txt
  test/indexing/test_setitem_appendix.py

paddle/phi/kernels/xpu/strided_copy_kernel.cc  (+93 -13)

@@ -1,20 +1,21 @@
-/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-http://www.apache.org/licenses/LICENSE-2.0
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License. */
+// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
 
 #include "paddle/phi/kernels/strided_copy_kernel.h"
-#include "paddle/phi/backends/gpu/gpu_launch_config.h"
 #include "paddle/phi/backends/xpu/enforce_xpu.h"
+#include "paddle/phi/common/memory_utils.h"
 #include "paddle/phi/core/kernel_registry.h"
-#include "paddle/phi/kernels/xpu/xpu_api_wrapper.h"
-
 namespace phi {
 
 template <typename T, typename Context>
@@ -44,6 +45,20 @@ void StridedCopyKernel(const Context& dev_ctx,
                         input.numel(),
                         out->numel()));
 
+  if (input.numel() <= 0) {
+    return;
+  }
+
+  PADDLE_ENFORCE_NOT_NULL(out->data<T>(),
+                          common::errors::InvalidArgument(
+                              "StridedCopyKernel's out tensor must complete "
+                              "mutable data before call kernel."));
+
+  // The following XPU operators have performance issues and are temporarily
+  // disabled. A temporary workaround has been implemented: "First copy data to
+  // CPU, perform computation using CPU operator logic, then copy results back
+  // to XPU".
+  /*
   // use XPUCopyTypeTrait to deal with double and int16_t copy instead of
   // XPUTypeTrait
   using XPUType = typename XPUCopyTypeTrait<T>::Type;
@@ -68,6 +83,71 @@ void StridedCopyKernel(const Context& dev_ctx,
                                        common::vectorize<int64_t>(out->strides()));
     PADDLE_ENFORCE_XDNN_SUCCESS(r, "strided_copy");
   }
+  */
+
+  // wait before copy
+  dev_ctx.Wait();
+
+  // CPU buffer for input
+  char* input_on_cpu = new char[input.Holder()->size()];
+  memory_utils::Copy(CPUPlace(),
+                     static_cast<void*>(input_on_cpu),
+                     dev_ctx.GetPlace(),
+                     static_cast<const void*>(input.Holder()->ptr()),
+                     input.Holder()->size());
+
+  // CPU buffer for out
+  char* output_on_cpu = new char[out->Holder()->size()];
+  memory_utils::Copy(CPUPlace(),
+                     static_cast<void*>(output_on_cpu),
+                     dev_ctx.GetPlace(),
+                     static_cast<const void*>(out->Holder()->ptr()),
+                     out->Holder()->size());
+
+  // wait after copy
+  dev_ctx.Wait();
+
+  // follow paddle/phi/kernels/cpu/strided_copy_kernel.cc
+  const T* input_data =
+      reinterpret_cast<T*>(input_on_cpu + input.meta().offset);
+  int input_rank = input.dims().size();
+  const int64_t* input_dims = input.dims().Get();
+  const int64_t* input_stride = input.strides().Get();
+
+  T* output_data = reinterpret_cast<T*>(output_on_cpu + offset);
+  int output_rank = meta.dims.size();
+  const int64_t* output_dims = meta.dims.Get();
+  const int64_t* output_stride = meta.strides.Get();
+
+  auto numel = input.numel();
+
+  for (int64_t i = 0; i < numel; i++) {
+    int64_t input_offset = 0;
+    int64_t index_tmp = i;
+    for (int dim = input_rank - 1; dim >= 0; --dim) {
+      input_offset += (index_tmp % input_dims[dim]) * input_stride[dim];
+      index_tmp = index_tmp / input_dims[dim];
+    }
+    int64_t output_offset = 0;
+    index_tmp = i;
+    for (int dim = output_rank - 1; dim >= 0; --dim) {
+      output_offset += (index_tmp % output_dims[dim]) * output_stride[dim];
+      index_tmp = index_tmp / output_dims[dim];
+    }
+    output_data[output_offset] = input_data[input_offset];
+  }
+
+  // copy out tensor, from cpu to xpu
+  memory_utils::Copy(dev_ctx.GetPlace(),
+                     static_cast<void*>(out->Holder()->ptr()),
+                     CPUPlace(),
+                     static_cast<const void*>(output_on_cpu),
+                     out->Holder()->size());
+  // wait after copy
+  dev_ctx.Wait();
+
+  delete[] input_on_cpu;
+  delete[] output_on_cpu;
 }
 
 }  // namespace phi
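
The core of the change above is the CPU fallback: each of the input's numel() elements has its linear index decomposed into per-dimension coordinates using the dims, and those coordinates are re-projected through the input and output strides to obtain the source and destination memory offsets. A minimal, self-contained Python sketch of that mapping (illustration only, not part of the commit; the helper names are made up):

def strided_offset(linear_index, dims, strides):
    """Map a linear element index to a strided memory offset (in elements)."""
    offset = 0
    remaining = linear_index
    for dim in range(len(dims) - 1, -1, -1):  # innermost dimension first
        offset += (remaining % dims[dim]) * strides[dim]
        remaining //= dims[dim]
    return offset

def strided_copy(src, src_dims, src_strides, dst, dst_dims, dst_strides):
    numel = 1
    for d in src_dims:
        numel *= d
    for i in range(numel):
        dst[strided_offset(i, dst_dims, dst_strides)] = \
            src[strided_offset(i, src_dims, src_strides)]

# Example: copy a contiguous 2x3 block into a transposed (column-major) layout.
src = [0.0, 1.0, 2.0, 3.0, 4.0, 5.0]   # dims (2, 3), strides (3, 1)
dst = [0.0] * 6                        # dims (2, 3), strides (1, 2)
strided_copy(src, (2, 3), (3, 1), dst, (2, 3), (1, 2))
print(dst)  # [0.0, 3.0, 1.0, 4.0, 2.0, 5.0]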

test/indexing/CMakeLists.txt  (+3)

@@ -7,3 +7,6 @@ string(REPLACE ".py" "" TEST_OPS "${TEST_OPS}")
 foreach(TEST_OP ${TEST_OPS})
   py_test_modules(${TEST_OP} MODULES ${TEST_OP})
 endforeach()
+
+set_tests_properties(test_setitem_appendix
+                     PROPERTIES ENVIRONMENT "FLAGS_use_stride_kernel=1")
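
The property above makes CTest export FLAGS_use_stride_kernel=1 for test_setitem_appendix, so the strided (view) kernel path, including the XPU StridedCopyKernel changed in this commit, is actually exercised. A rough sketch of reproducing that setup by hand (illustration only; it assumes the flag is read from the environment at startup and is also writable at runtime):

import os

# Mirror the ENVIRONMENT property from CMakeLists.txt before paddle is imported.
os.environ["FLAGS_use_stride_kernel"] = "1"

import paddle

# Assumed alternative: exported flags can usually also be toggled at runtime.
paddle.set_flags({"FLAGS_use_stride_kernel": True})
print(paddle.get_flags(["FLAGS_use_stride_kernel"]))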

test/indexing/test_setitem_appendix.py  (-4)

@@ -194,10 +194,6 @@ def test_tensor(self):
         self.accuracy_check(x, y)
 
 
-@unittest.skipIf(
-    paddle.core.is_compiled_with_xpu(),
-    "There are some bugs on XPU.",
-)
 class TestSetitemDygraphCombinedIndex(unittest.TestCase):
     def accuracy_check(self, numpy_array, paddle_t):
         np.testing.assert_allclose(numpy_array, paddle_t.numpy())
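
Removing the skipIf means TestSetitemDygraphCombinedIndex now also runs on XPU builds. The test body is outside this diff; as a hypothetical illustration of the kind of check accuracy_check() performs, a combined-index assignment could be verified against NumPy like this (the tensor shapes and index pattern are assumptions, not taken from the test):

import numpy as np
import paddle

# Hypothetical combined index: an integer tensor on axis 0 together with
# slices on the remaining axes; the same assignment is replayed in numpy.
x = paddle.arange(24, dtype="float32").reshape([2, 3, 4])
x_np = x.numpy()

rows = paddle.to_tensor([0, 1])
x[rows, :, 2] = 7.0
x_np[np.array([0, 1]), :, 2] = 7.0

# Same comparison as accuracy_check() in the test file.
np.testing.assert_allclose(x_np, x.numpy())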
