Skip to content

Commit 75e1d29

Browse files
lj970926 and BeingGod
authored and committed
optimize unique and index_put (PaddlePaddle#56582)
1 parent 1387c46 commit 75e1d29

File tree

3 files changed

+45
-22
lines changed

3 files changed

+45
-22
lines changed

paddle/phi/kernels/funcs/index_put_utils.h

+9-1
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
#pragma once
1616

1717
#include <vector>
18+
#include "paddle/phi/backends/context_pool.h"
1819
#include "paddle/phi/common/int_array.h"
1920
#include "paddle/phi/common/memory_utils.h"
2021
#include "paddle/phi/common/place.h"
@@ -106,7 +107,14 @@ std::vector<const phi::DenseTensor*> DealWithBoolIndices(
106107
SplitWithNumKernel<int64_t, Context>(
107108
dev_ctx, nonzero_indices, rank, 1, integer_indices);
108109
#ifdef PADDLE_WITH_XPU
109-
dev_ctx.Wait();
110+
auto place = dev_ctx.GetPlace();
111+
if (place.GetType() == phi::AllocationType::XPU) {
112+
auto& pool = phi::DeviceContextPool::Instance();
113+
auto* xpu_ctx = static_cast<phi::XPUContext*>(pool.Get(place));
114+
if (xpu_ctx->x_context()->xpu_stream) {
115+
dev_ctx.Wait();
116+
}
117+
}
110118
#endif
111119

112120
} else if ((indices_v[i]->dtype() == phi::DataType::INT64) ||

paddle/phi/kernels/xpu/index_put_kernel.cc

+6-2
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,9 @@ void XPUDealWithIndices(const Context& dev_ctx,
6565
}
6666

6767
StackKernel<int64_t, Context>(dev_ctx, tmp_indices_ptr, -1, out);
68-
dev_ctx.Wait();
68+
if (dev_ctx.x_context()->xpu_stream) {
69+
dev_ctx.Wait();
70+
}
6971
}
7072

7173
template <typename T, typename Context>
@@ -140,7 +142,9 @@ void IndexPutKernel(const Context& dev_ctx,
140142
index_shape,
141143
accumulate);
142144
PADDLE_ENFORCE_XDNN_SUCCESS(r, "index_put");
143-
dev_ctx.Wait();
145+
if (dev_ctx.x_context()->xpu_stream) {
146+
dev_ctx.Wait();
147+
}
144148
}
145149
} // namespace phi
146150

paddle/phi/kernels/xpu/unique_kernel.cc

+30-19
Original file line numberDiff line numberDiff line change
@@ -228,26 +228,37 @@ void XPUDimUniqueKernelImpl(const Context& dev_ctx,
228228
inverse_cpu[ori_idx_cpu[0]] = 0;
229229
IndexT unique_len = 1;
230230
IndexT repeat_cnt = 1;
231-
for (IndexT i = 1; i < axis_len; ++i) {
232-
int cnt_cpu = 0;
233-
int* cnt_xpu = RAII_GUARD.alloc_l3_or_gm<int>(1);
234-
r = xpu::nonzero_count<bool>(dev_ctx.x_context(),
235-
compare_results + (i - 1) * slice_size,
236-
cnt_xpu,
237-
slice_size);
238-
PADDLE_ENFORCE_XDNN_SUCCESS(r, "nonzero_count");
239-
memory_utils::Copy(
240-
phi::CPUPlace(), &cnt_cpu, dev_ctx.GetPlace(), cnt_xpu, sizeof(int));
241-
if (cnt_cpu != slice_size) {
242-
unique_axis.push_back(i);
243-
indices_cpu.push_back(ori_idx_cpu[i]);
244-
counts_cpu.push_back(repeat_cnt);
245-
++unique_len;
246-
repeat_cnt = 1;
247-
} else {
248-
++repeat_cnt;
231+
if (axis_len > 1) {
232+
DenseTensor adj_identical_cpu;
233+
adj_identical_cpu.Resize({axis_len - 1});
234+
bool* adj_identical_cpu_data =
235+
dev_ctx.template HostAlloc<bool>(&adj_identical_cpu);
236+
auto* adj_identical_xpu = RAII_GUARD.alloc_l3_or_gm<bool>(axis_len - 1);
237+
r = xpu::reduce_all<bool>(dev_ctx.x_context(),
238+
compare_results,
239+
adj_identical_xpu,
240+
{axis_len - 1, slice_size},
241+
{1});
242+
PADDLE_ENFORCE_XDNN_SUCCESS(r, "reduce_all");
243+
244+
memory_utils::Copy(phi::CPUPlace(),
245+
adj_identical_cpu_data,
246+
dev_ctx.GetPlace(),
247+
adj_identical_xpu,
248+
(axis_len - 1) * sizeof(bool));
249+
250+
for (IndexT i = 1; i < axis_len; ++i) {
251+
if (!adj_identical_cpu_data[i - 1]) {
252+
unique_axis.push_back(i);
253+
indices_cpu.push_back(ori_idx_cpu[i]);
254+
counts_cpu.push_back(repeat_cnt);
255+
++unique_len;
256+
repeat_cnt = 1;
257+
} else {
258+
++repeat_cnt;
259+
}
260+
inverse_cpu[ori_idx_cpu[i]] = unique_len - 1;
249261
}
250-
inverse_cpu[ori_idx_cpu[i]] = unique_len - 1;
251262
}
252263
counts_cpu.push_back(repeat_cnt);
253264
DDim out_dims = x.dims();

0 commit comments

Comments (0)