-
Notifications
You must be signed in to change notification settings - Fork 5.7k
[XPU] update StridedCopyKernel #72030
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 2 commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -44,6 +44,17 @@ void StridedCopyKernel(const Context& dev_ctx, | |
input.numel(), | ||
out->numel())); | ||
|
||
if (input.numel() <= 0) { | ||
return; | ||
} | ||
|
||
PADDLE_ENFORCE_NOT_NULL(out->data<T>(), | ||
common::errors::InvalidArgument( | ||
"StridedCopyKernel's out tensor must complete " | ||
"mutable data before call kernel.")); | ||
|
||
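// The XPU operator below has performance problems, so it is temporarily disabled and replaced with a stopgap: copy the data to the CPU, compute following the CPU kernel logic, then copy the result back to the XPU. | ||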
/* | ||
// use XPUCopyTypeTrait to deal with double and int16_t copy instead of | ||
// XPUTypeTrait | ||
using XPUType = typename XPUCopyTypeTrait<T>::Type; | ||
|
@@ -68,6 +79,63 @@ void StridedCopyKernel(const Context& dev_ctx, | |
common::vectorize<int64_t>(out->strides())); | ||
PADDLE_ENFORCE_XDNN_SUCCESS(r, "strided_copy"); | ||
} | ||
*/ | ||
|
||
// CPU buffer for input | ||
char* input_on_cpu = new char[input.Holder()->size()]; | ||
memory_utils::Copy(CPUPlace(), | ||
static_cast<void*>(input_on_cpu), | ||
dev_ctx.GetPlace(), | ||
static_cast<const void*>(input.Holder()->ptr()), | ||
input.Holder()->size()); | ||
|
||
// CPU buffer for out | ||
char* output_on_cpu = new char[out->Holder()->size()]; | ||
memory_utils::Copy(CPUPlace(), | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. 这里需要Wait一下吗?我看底层用的是xpu_memcpy_async。 There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. 如单独沟通,继续保留使用 |
||
static_cast<void*>(output_on_cpu), | ||
dev_ctx.GetPlace(), | ||
static_cast<const void*>(out->Holder()->ptr()), | ||
out->Holder()->size()); | ||
|
||
// follow paddle/phi/kernels/cpu/strided_copy_kernel.cc | ||
const T* input_data = | ||
reinterpret_cast<T*>(input_on_cpu + input.meta().offset); | ||
int input_rank = input.dims().size(); | ||
const int64_t* input_dims = input.dims().Get(); | ||
const int64_t* input_stride = input.strides().Get(); | ||
|
||
T* output_data = reinterpret_cast<T*>(output_on_cpu + offset); | ||
int output_rank = meta.dims.size(); | ||
const int64_t* output_dims = meta.dims.Get(); | ||
const int64_t* output_stride = meta.strides.Get(); | ||
|
||
auto numel = input.numel(); | ||
|
||
for (int64_t i = 0; i < numel; i++) { | ||
int64_t input_offset = 0; | ||
int64_t index_tmp = i; | ||
for (int dim = input_rank - 1; dim >= 0; --dim) { | ||
input_offset += (index_tmp % input_dims[dim]) * input_stride[dim]; | ||
index_tmp = index_tmp / input_dims[dim]; | ||
} | ||
int64_t output_offset = 0; | ||
index_tmp = i; | ||
for (int dim = output_rank - 1; dim >= 0; --dim) { | ||
output_offset += (index_tmp % output_dims[dim]) * output_stride[dim]; | ||
index_tmp = index_tmp / output_dims[dim]; | ||
} | ||
output_data[output_offset] = input_data[input_offset]; | ||
} | ||
|
||
// copy out tensor, from cpu to xpu | ||
memory_utils::Copy(dev_ctx.GetPlace(), | ||
static_cast<void*>(out->Holder()->ptr()), | ||
CPUPlace(), | ||
static_cast<const void*>(output_on_cpu), | ||
out->Holder()->size()); | ||
|
||
delete[] input_on_cpu; | ||
delete[] output_on_cpu; | ||
} | ||
|
||
} // namespace phi | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
目前Paddle repo里暂时没有中文注释,换成英文吧~
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done