Commit 0abf756

Added workaround for elementwise oneDNN kernel (#47080)
* return proper state
* fix for dims
* fix
1 parent 06ef3f0 commit 0abf756

2 files changed: 59 additions & 8 deletions

paddle/fluid/operators/elementwise/mkldnn/elementwise_mkldnn_op.h

Lines changed: 16 additions & 4 deletions
@@ -78,6 +78,7 @@ class EltwiseMKLDNNKernel : public framework::OpKernel<T> {
                                 scale_x,
                                 scale_y,
                                 scale_o,
+                                true,
                                 get_post_ops(ctx));

     // oneDNN's binary is optimized for broadcasting y into x, so in other case
@@ -126,7 +127,14 @@ class EltwiseMKLDNNKernel : public framework::OpKernel<T> {
     binary_prim->execute(astream, args);
     astream.wait();

-    z->set_mem_desc(dst_memory->get_desc());
+    if (handler.use_broadcasting_hack == false) {
+      z->set_mem_desc(dst_memory->get_desc());
+    } else {
+      auto dims = dst_memory->get_desc().dims();
+      dims.insert(dims.begin(), x->dims()[0]);
+      dims[1] /= dims[0];
+      z->set_mem_desc(dst_memory->get_desc().reshape(dims));
+    }
   }
 };

@@ -210,7 +218,8 @@ class EltwiseMKLDNNGradKernel : public ElemwiseGradKernel<T> {
                                      dx,
                                      1.0f,
                                      1.0f,
-                                     1.0f);
+                                     1.0f,
+                                     false);

     const auto src_dout_memory = binary_handler.AcquireSrcMemory(dout);
     const auto src_y_memory = binary_handler.AcquireSecondSrcMemory(y);
@@ -276,7 +285,8 @@ class EltwiseMKLDNNGradKernel : public ElemwiseGradKernel<T> {
                                      nullptr,
                                      1.0f,
                                      1.0f,
-                                     1.0f);
+                                     1.0f,
+                                     false);

       src_1_memory = binary_handler.AcquireSecondSrcMemory(x);

@@ -291,7 +301,8 @@ class EltwiseMKLDNNGradKernel : public ElemwiseGradKernel<T> {
                                      nullptr,
                                      1.0f,
                                      1.0f,
-                                     1.0f);
+                                     1.0f,
+                                     false);

       post_op_memory = post_op_binary_handler.AcquireSrcMemory(y);

@@ -310,6 +321,7 @@ class EltwiseMKLDNNGradKernel : public ElemwiseGradKernel<T> {
                                      -1.0f,
                                      1.0f,
                                      1.0f,
+                                     false,
                                      po);

       src_1_memory = binary_handler.AcquireSecondSrcMemory(out);
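
Note on the new else branch in the forward kernel: when the handler applied the broadcasting workaround, the output was computed in a collapsed 3-D shape with the batch dimension folded into dim 1, so the kernel rebuilds the original rank by prepending x's batch size and dividing it back out of the folded dimension. A minimal standalone sketch of that arithmetic, with an assumed folded shape {6, 5, 4} and batch size 2 (illustrative values, not taken from the kernel):

// Standalone sketch (illustrative, not kernel code) of the reshape-back step.
// Assumed values: folded output dims {6, 5, 4} and batch size 2.
#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  std::vector<int64_t> dims = {6, 5, 4};  // {N*C, H, W} as oneDNN produced it
  const int64_t batch = 2;                // plays the role of x->dims()[0]

  dims.insert(dims.begin(), batch);  // prepend N    -> {2, 6, 5, 4}
  dims[1] /= dims[0];                // N*C / N = C  -> {2, 3, 5, 4}

  for (int64_t d : dims) std::cout << d << ' ';  // prints: 2 3 5 4
  std::cout << '\n';
  return 0;
}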

paddle/phi/backends/onednn/onednn_reuse.h

Lines changed: 43 additions & 4 deletions
@@ -825,6 +825,7 @@ class ReorderOneDNNHandler {
 template <typename T>
 class BinaryOneDNNHandler : public OneDNNHandlerNoCachingT<T, dnnl::binary> {
  public:
+  bool use_broadcasting_hack;
   BinaryOneDNNHandler(const dnnl::algorithm algo,
                       const int axis,
                       const dnnl::engine engine,
@@ -835,15 +836,17 @@ class BinaryOneDNNHandler : public OneDNNHandlerNoCachingT<T, dnnl::binary> {
                       float scale_x,
                       float scale_y,
                       float scale_out,
+                      bool allow_hack,
                       const dnnl::post_ops& post_ops = dnnl::post_ops{})
       : OneDNNHandlerNoCachingT<T, dnnl::binary>(engine, cpu_place) {
+    use_broadcasting_hack = false;
     const auto src_x_tz = vectorize(x->dims());
     const auto src_y_tz = vectorize(y->dims());
     // if output tensor(z) is nullptr then we are computing into oneDNN
     // managed buffer
     auto rankdiff = x->dims().size() - y->dims().size();
-    const auto dst_tz = (out == nullptr) ? (rankdiff > 0 ? src_x_tz : src_y_tz)
-                                         : vectorize(out->dims());
+    auto dst_tz = (out == nullptr) ? (rankdiff > 0 ? src_x_tz : src_y_tz)
+                                   : vectorize(out->dims());

     auto src0_md = x->mem_desc();
     auto src1_md = y->mem_desc();
@@ -870,12 +873,48 @@ class BinaryOneDNNHandler : public OneDNNHandlerNoCachingT<T, dnnl::binary> {
       }
       src0_md = src0_md.reshape(dims0_ex);
     }
-    const auto dst_md =
-        memory::desc(dst_tz, OneDNNGetDataType<T>(), OneDNNMemoryFormat::any);

     auto attributes =
         CreateAttributes(algo, scale_x, scale_y, scale_out, post_ops);

+    // Workaround for the U2++ model which deletes the first tensor dimension
+    // to enable oneDNN's optimized broadcasting. The output tensor is reshaped
+    // back at the end of the kernel, after the computation.
+    if (allow_hack && dst_tz.size() == 4 &&
+        src0_md.dims()[2] != src1_md.dims()[2]) {
+      auto are_strides_plain = [](int64_t* strides, int ndims) {
+        for (int i = 0; i < ndims - 1; ++i) {
+          if (strides[i] < strides[i + 1]) {
+            return false;
+          }
+        }
+        return true;
+      };
+
+      auto src0_strides = src0_md.data.format_desc.blocking.strides;
+      auto src1_strides = src1_md.data.format_desc.blocking.strides;
+      auto src0_dims = src0_md.dims();
+      auto src1_dims = src1_md.dims();
+
+      bool can_squeeze = src0_dims[0] == src1_dims[0] &&
+                         src0_dims[1] == src1_dims[1] &&
+                         src0_dims[3] == src1_dims[3];
+
+      if (can_squeeze && are_strides_plain(src0_strides, 4) &&
+          are_strides_plain(src1_strides, 4)) {
+        src0_dims[1] *= dst_tz[0];
+        src1_dims[1] *= dst_tz[0];
+        dst_tz[1] *= dst_tz[0];
+        dst_tz.erase(dst_tz.begin());
+        src0_md = src0_md.reshape({src0_dims.begin() + 1, src0_dims.end()});
+        src1_md = src1_md.reshape({src1_dims.begin() + 1, src1_dims.end()});
+        use_broadcasting_hack = true;
+      }
+    }
+
+    auto dst_md =
+        memory::desc(dst_tz, OneDNNGetDataType<T>(), OneDNNMemoryFormat::any);
+
     if (x->numel() < y->numel()) {
       if (algo == dnnl::algorithm::binary_sub) {
         attributes = CreateAttributes(
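
Note on the handler change: the hack only fires when the caller allows it, the output is 4-D, the two inputs differ in dim 2 alone (dims 0, 1 and 3 match), and both memory descriptors use plain, non-blocked strides. The batch dimension is then folded into dim 1 so oneDNN sees a 3-D problem that broadcasts y into x along a single dimension. A standalone sketch of that decision and folding under an assumed pair of shapes (the handler reshapes memory descriptors rather than plain vectors, so the containers here are illustrative):

// Standalone sketch (illustrative, not handler code) of the squeeze decision.
// Assumed shapes: x = {2, 3, 5, 4}, y = {2, 3, 1, 4}, i.e. y is broadcast
// along dim 2 only.
#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  std::vector<int64_t> src0 = {2, 3, 5, 4};  // x dims
  std::vector<int64_t> src1 = {2, 3, 1, 4};  // y dims
  std::vector<int64_t> dst = src0;           // output dims

  // Eligible only when dims 0, 1 and 3 match (dim 2 is the broadcast axis).
  bool can_squeeze =
      src0[0] == src1[0] && src0[1] == src1[1] && src0[3] == src1[3];

  if (can_squeeze) {
    // Fold the batch dim into dim 1, then drop it:
    // {2, 3, 5, 4} -> {6, 5, 4} and {2, 3, 1, 4} -> {6, 1, 4}.
    src0[1] *= dst[0];
    src1[1] *= dst[0];
    dst[1] *= dst[0];
    src0.erase(src0.begin());
    src1.erase(src1.begin());
    dst.erase(dst.begin());
  }

  for (int64_t d : src1) std::cout << d << ' ';  // prints: 6 1 4
  std::cout << '\n';
  return 0;
}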

0 commit comments