13
13
// limitations under the License.
14
14
15
15
#include " paddle/phi/kernels/index_select_kernel.h"
16
-
17
16
#include " paddle/phi/backends/xpu/enforce_xpu.h"
17
+ #include " paddle/phi/common/memory_utils.h"
18
18
#include " paddle/phi/core/kernel_registry.h"
19
19
#include " paddle/phi/core/utils/data_type.h"
20
20
@@ -46,8 +46,23 @@ void IndexSelectKernel(const Context& ctx,
46
46
int index_len = output->dims ()[dim];
47
47
T* out_data = ctx.template Alloc <T>(output);
48
48
int r = 0 ;
49
+ xpu::ctx_guard RAII_GUARD (ctx.x_context ());
50
+ const int8_t * index_ptr = nullptr ;
51
+ int byte_times = sizeof (index_type);
52
+ if (index .place () == CPUPlace ()) {
53
+ index_ptr = RAII_GUARD.alloc_l3_or_gm <int8_t >(byte_times * index .numel ());
54
+ PADDLE_ENFORCE_XDNN_NOT_NULL (index_ptr);
55
+ memory_utils::Copy (ctx.GetPlace (),
56
+ reinterpret_cast <void *>(const_cast <int8_t *>(index_ptr)),
57
+ CPUPlace (),
58
+ reinterpret_cast <const void *>(index .data <int >()),
59
+ byte_times * index .numel ());
60
+ } else {
61
+ index_ptr = index .template data <int8_t >();
62
+ }
49
63
if (index_type == phi::DataType::INT64) {
50
- const int64_t * index_data = index .data <int64_t >();
64
+ const int64_t * index_data =
65
+ reinterpret_cast <const int64_t *>(const_cast <int8_t *>(index_ptr));
51
66
r = xpu::gather<T, int64_t >(ctx.x_context (),
52
67
in_data,
53
68
index_data,
@@ -56,7 +71,8 @@ void IndexSelectKernel(const Context& ctx,
56
71
index_len,
57
72
dim);
58
73
} else {
59
- const int * index_data = index .data <int >();
74
+ const int * index_data =
75
+ reinterpret_cast <const int *>(const_cast <int8_t *>(index_ptr));
60
76
r = xpu::gather<T, int >(ctx.x_context (),
61
77
in_data,
62
78
index_data,
0 commit comments