@@ -825,6 +825,7 @@ class ReorderOneDNNHandler {
825
825
template <typename T>
826
826
class BinaryOneDNNHandler : public OneDNNHandlerNoCachingT <T, dnnl::binary> {
827
827
public:
828
+ bool use_broadcasting_hack;
828
829
BinaryOneDNNHandler (const dnnl::algorithm algo,
829
830
const int axis,
830
831
const dnnl::engine engine,
@@ -835,15 +836,17 @@ class BinaryOneDNNHandler : public OneDNNHandlerNoCachingT<T, dnnl::binary> {
835
836
float scale_x,
836
837
float scale_y,
837
838
float scale_out,
839
+ bool allow_hack,
838
840
const dnnl::post_ops& post_ops = dnnl::post_ops{})
839
841
: OneDNNHandlerNoCachingT<T, dnnl::binary>(engine, cpu_place) {
842
+ use_broadcasting_hack = false ;
840
843
const auto src_x_tz = vectorize (x->dims ());
841
844
const auto src_y_tz = vectorize (y->dims ());
842
845
// if output tensor(z) is nullptr then we are computing into oneDNN
843
846
// managed buffer
844
847
auto rankdiff = x->dims ().size () - y->dims ().size ();
845
- const auto dst_tz = (out == nullptr ) ? (rankdiff > 0 ? src_x_tz : src_y_tz)
846
- : vectorize (out->dims ());
848
+ auto dst_tz = (out == nullptr ) ? (rankdiff > 0 ? src_x_tz : src_y_tz)
849
+ : vectorize (out->dims ());
847
850
848
851
auto src0_md = x->mem_desc ();
849
852
auto src1_md = y->mem_desc ();
@@ -870,12 +873,48 @@ class BinaryOneDNNHandler : public OneDNNHandlerNoCachingT<T, dnnl::binary> {
870
873
}
871
874
src0_md = src0_md.reshape (dims0_ex);
872
875
}
873
- const auto dst_md =
874
- memory::desc (dst_tz, OneDNNGetDataType<T>(), OneDNNMemoryFormat::any);
875
876
876
877
auto attributes =
877
878
CreateAttributes (algo, scale_x, scale_y, scale_out, post_ops);
878
879
880
+ // Workaround for the U2++ model: the first tensor dimension is folded into
881
+ // the second one to enable oneDNN's optimized broadcasting. The output tensor
882
+ // is reshaped back afterwards, at the end of the kernel, after the computation
883
+ if (allow_hack && dst_tz.size () == 4 &&
884
+ src0_md.dims ()[2 ] != src1_md.dims ()[2 ]) {
885
+ auto are_strides_plain = [](int64_t * strides, int ndims) {
886
+ for (int i = 0 ; i < ndims - 1 ; ++i) {
887
+ if (strides[i] < strides[i + 1 ]) {
888
+ return false ;
889
+ }
890
+ }
891
+ return true ;
892
+ };
893
+
894
+ auto src0_strides = src0_md.data .format_desc .blocking .strides ;
895
+ auto src1_strides = src1_md.data .format_desc .blocking .strides ;
896
+ auto src0_dims = src0_md.dims ();
897
+ auto src1_dims = src1_md.dims ();
898
+
899
+ bool can_squeeze = src0_dims[0 ] == src1_dims[0 ] &&
900
+ src0_dims[1 ] == src1_dims[1 ] &&
901
+ src0_dims[3 ] == src1_dims[3 ];
902
+
903
+ if (can_squeeze && are_strides_plain (src0_strides, 4 ) &&
904
+ are_strides_plain (src1_strides, 4 )) {
905
+ src0_dims[1 ] *= dst_tz[0 ];
906
+ src1_dims[1 ] *= dst_tz[0 ];
907
+ dst_tz[1 ] *= dst_tz[0 ];
908
+ dst_tz.erase (dst_tz.begin ());
909
+ src0_md = src0_md.reshape ({src0_dims.begin () + 1 , src0_dims.end ()});
910
+ src1_md = src1_md.reshape ({src1_dims.begin () + 1 , src1_dims.end ()});
911
+ use_broadcasting_hack = true ;
912
+ }
913
+ }
914
+
915
+ auto dst_md =
916
+ memory::desc (dst_tz, OneDNNGetDataType<T>(), OneDNNMemoryFormat::any);
917
+
879
918
if (x->numel () < y->numel ()) {
880
919
if (algo == dnnl::algorithm::binary_sub) {
881
920
attributes = CreateAttributes (
0 commit comments