Skip to content

Commit e3e9e5c

Browse files
committed
[XPU] copy index to xpu inside index_select; test=develop
1 parent c71f332 commit e3e9e5c

File tree

3 files changed

+22
-4
lines changed

3 files changed

+22
-4
lines changed

paddle/phi/api/yaml/ops.yaml

+2
Original file line numberDiff line numberDiff line change
@@ -1263,6 +1263,8 @@
12631263
func : index_select
12641264
data_type : x
12651265
backward : index_select_grad
1266+
data_transform :
1267+
skip_transform : index
12661268

12671269
- op : index_select_strided
12681270
args : (Tensor x, int64_t index, int axis = 0)

paddle/phi/backends/xpu/xpu1_op_list.cc

+1-1
Original file line numberDiff line numberDiff line change
@@ -361,7 +361,7 @@ XPUOpMap& get_kl1_ops() {
361361
{"where_index", XPUKernelSet({phi::DataType::BOOL})},
362362
{"where",
363363
XPUKernelSet({phi::DataType::INT32,
364-
// phi::DataType::INT64,
364+
phi::DataType::INT64,
365365
phi::DataType::FLOAT32})},
366366
// AddMore
367367
};

paddle/phi/kernels/xpu/index_select_kernel.cc

+19-3
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,8 @@
1313
// limitations under the License.
1414

1515
#include "paddle/phi/kernels/index_select_kernel.h"
16-
1716
#include "paddle/phi/backends/xpu/enforce_xpu.h"
17+
#include "paddle/phi/common/memory_utils.h"
1818
#include "paddle/phi/core/kernel_registry.h"
1919
#include "paddle/phi/core/utils/data_type.h"
2020

@@ -46,8 +46,23 @@ void IndexSelectKernel(const Context& ctx,
4646
int index_len = output->dims()[dim];
4747
T* out_data = ctx.template Alloc<T>(output);
4848
int r = 0;
49+
xpu::ctx_guard RAII_GUARD(ctx.x_context());
50+
const int8_t* index_ptr = nullptr;
51+
int byte_times = sizeof(index_type);
52+
if (index.place() == CPUPlace()) {
53+
index_ptr = RAII_GUARD.alloc_l3_or_gm<int8_t>(byte_times * index.numel());
54+
PADDLE_ENFORCE_XDNN_NOT_NULL(index_ptr);
55+
memory_utils::Copy(ctx.GetPlace(),
56+
reinterpret_cast<void*>(const_cast<int8_t*>(index_ptr)),
57+
CPUPlace(),
58+
reinterpret_cast<const void*>(index.data<int>()),
59+
byte_times * index.numel());
60+
} else {
61+
index_ptr = index.template data<int8_t>();
62+
}
4963
if (index_type == phi::DataType::INT64) {
50-
const int64_t* index_data = index.data<int64_t>();
64+
const int64_t* index_data =
65+
reinterpret_cast<const int64_t*>(const_cast<int8_t*>(index_ptr));
5166
r = xpu::gather<T, int64_t>(ctx.x_context(),
5267
in_data,
5368
index_data,
@@ -56,7 +71,8 @@ void IndexSelectKernel(const Context& ctx,
5671
index_len,
5772
dim);
5873
} else {
59-
const int* index_data = index.data<int>();
74+
const int* index_data =
75+
reinterpret_cast<const int*>(const_cast<int8_t*>(index_ptr));
6076
r = xpu::gather<T, int>(ctx.x_context(),
6177
in_data,
6278
index_data,

0 commit comments

Comments
 (0)