@@ -228,26 +228,37 @@ void XPUDimUniqueKernelImpl(const Context& dev_ctx,
228
228
inverse_cpu[ori_idx_cpu[0 ]] = 0 ;
229
229
IndexT unique_len = 1 ;
230
230
IndexT repeat_cnt = 1 ;
231
- for (IndexT i = 1 ; i < axis_len; ++i) {
232
- int cnt_cpu = 0 ;
233
- int * cnt_xpu = RAII_GUARD.alloc_l3_or_gm <int >(1 );
234
- r = xpu::nonzero_count<bool >(dev_ctx.x_context (),
235
- compare_results + (i - 1 ) * slice_size,
236
- cnt_xpu,
237
- slice_size);
238
- PADDLE_ENFORCE_XDNN_SUCCESS (r, " nonzero_count" );
239
- memory_utils::Copy (
240
- phi::CPUPlace (), &cnt_cpu, dev_ctx.GetPlace (), cnt_xpu, sizeof (int ));
241
- if (cnt_cpu != slice_size) {
242
- unique_axis.push_back (i);
243
- indices_cpu.push_back (ori_idx_cpu[i]);
244
- counts_cpu.push_back (repeat_cnt);
245
- ++unique_len;
246
- repeat_cnt = 1 ;
247
- } else {
248
- ++repeat_cnt;
231
+ if (axis_len > 1 ) {
232
+ DenseTensor adj_identical_cpu;
233
+ adj_identical_cpu.Resize ({axis_len - 1 });
234
+ bool * adj_identical_cpu_data =
235
+ dev_ctx.template HostAlloc <bool >(&adj_identical_cpu);
236
+ auto * adj_identical_xpu = RAII_GUARD.alloc_l3_or_gm <bool >(axis_len - 1 );
237
+ r = xpu::reduce_all<bool >(dev_ctx.x_context (),
238
+ compare_results,
239
+ adj_identical_xpu,
240
+ {axis_len - 1 , slice_size},
241
+ {1 });
242
+ PADDLE_ENFORCE_XDNN_SUCCESS (r, " reduce_all" );
243
+
244
+ memory_utils::Copy (phi::CPUPlace (),
245
+ adj_identical_cpu_data,
246
+ dev_ctx.GetPlace (),
247
+ adj_identical_xpu,
248
+ (axis_len - 1 ) * sizeof (bool ));
249
+
250
+ for (IndexT i = 1 ; i < axis_len; ++i) {
251
+ if (!adj_identical_cpu_data[i - 1 ]) {
252
+ unique_axis.push_back (i);
253
+ indices_cpu.push_back (ori_idx_cpu[i]);
254
+ counts_cpu.push_back (repeat_cnt);
255
+ ++unique_len;
256
+ repeat_cnt = 1 ;
257
+ } else {
258
+ ++repeat_cnt;
259
+ }
260
+ inverse_cpu[ori_idx_cpu[i]] = unique_len - 1 ;
249
261
}
250
- inverse_cpu[ori_idx_cpu[i]] = unique_len - 1 ;
251
262
}
252
263
counts_cpu.push_back (repeat_cnt);
253
264
DDim out_dims = x.dims ();
0 commit comments