
Commit f41ccbd

[PHI] Migrate matmul kernel (#48162)
* cleanup unused code
* unify is_int8 is_bfloat16
* Simplify matmul_v2 FWD kernel
* remove RunKernel methods
* remove import namespace
* remove headers
* clean fluid/phi cross imports
* remove fluid axpy_handler
* delete fluid methods
* activations
* OneDNNMemDesc
* MKLDNNFormatForSize
* MatchShapeToLayout
* MKLDNNMemoryFormat
* MKLDNNFormat
* ReorderMKLDNNHandler
* to_void_cast
* review suggestions
* interpolate
* remove fluid dependency
* init
* ExecuteMatMulV2
* rm fluid kernel
* matmul_grad
* remove mutable_data
* mul_grad
* matmul fwd
* add extra attr
* temp disable passes
* re-enable passes
* workaround for matmul+act
* fix for matmul+eltwise_add
* fix typo
* merge bugfix #48364
* remove merge conflict
1 parent c928a35 commit f41ccbd

File tree

4 files changed: +186 −21 lines


paddle/fluid/operators/mkldnn/matmul_v2_mkldnn_op.cc

Lines changed: 5 additions & 13 deletions
```diff
@@ -381,7 +381,7 @@ void ExecuteMatMulV2(const ExecutionContext &ctx,
 }

 template <typename T>
-class MatMulV2MKLDNNKernel : public paddle::framework::OpKernel<T> {
+class MatMulMKLDNNKernel : public paddle::framework::OpKernel<T> {
  public:
  void Compute(const ExecutionContext &ctx) const override {
    if (ctx.HasAttr("head_number")) {
@@ -696,21 +696,13 @@ class MatMulGradMKLDNNKernel : public paddle::framework::OpKernel<T> {
 REGISTER_OP_KERNEL(matmul,
                    MKLDNN,
                    ::paddle::platform::CPUPlace,
-                   MatMulV2MKLDNNKernel<float>,
-                   MatMulV2MKLDNNKernel<paddle::platform::bfloat16>,
-                   MatMulV2MKLDNNKernel<int8_t>,
-                   MatMulV2MKLDNNKernel<uint8_t>);
+                   MatMulMKLDNNKernel<float>,
+                   MatMulMKLDNNKernel<paddle::platform::bfloat16>,
+                   MatMulMKLDNNKernel<int8_t>,
+                   MatMulMKLDNNKernel<uint8_t>);

 REGISTER_OP_KERNEL(matmul_grad,
                    MKLDNN,
                    ::paddle::platform::CPUPlace,
                    MatMulGradMKLDNNKernel<float>,
                    MatMulGradMKLDNNKernel<paddle::platform::bfloat16>);
-
-REGISTER_OP_KERNEL(matmul_v2,
-                   MKLDNN,
-                   ::paddle::platform::CPUPlace,
-                   MatMulV2MKLDNNKernel<float>,
-                   MatMulV2MKLDNNKernel<paddle::platform::bfloat16>,
-                   MatMulV2MKLDNNKernel<int8_t>,
-                   MatMulV2MKLDNNKernel<uint8_t>);
```
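The net effect of this file: the fluid kernel class now backs only the legacy `matmul` op (hence the rename from `MatMulV2MKLDNNKernel`), and the separate `matmul_v2` registration is dropped because that op is now served by the new PHI kernel added below. As a minimal sketch of the one-templated-class, many-dtypes pattern the registration macro relies on (`ToyKernel` and the driver are hypothetical stand-ins, not Paddle APIs):

```cpp
// Hypothetical sketch: one templated kernel body is instantiated once per
// registered element type, which is what listing MatMulMKLDNNKernel<float>,
// <bfloat16>, <int8_t>, <uint8_t> in REGISTER_OP_KERNEL achieves.
#include <cstdint>
#include <iostream>

template <typename T>
struct ToyKernel {
  void Compute(const T* x, const T* y, T* out, int n) const {
    for (int i = 0; i < n; ++i) out[i] = x[i] * y[i];  // shared logic
  }
};

int main() {
  float xf[] = {1.f, 2.f}, yf[] = {3.f, 4.f}, of[2];
  ToyKernel<float>{}.Compute(xf, yf, of, 2);   // float instantiation

  int8_t xi[] = {1, 2}, yi[] = {3, 4}, oi[2];
  ToyKernel<int8_t>{}.Compute(xi, yi, oi, 2);  // int8 instantiation

  std::cout << of[0] << ' ' << int(oi[1]) << '\n';  // prints "3 8"
}
```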

paddle/fluid/operators/ops_extra_info.h

Lines changed: 3 additions & 1 deletion
```diff
@@ -98,6 +98,7 @@ const std::unordered_map<std::string, ExtraAttrPropertySet>
         {"fuse_alpha", ExtraAttrProperty::ONEDNN},
         {"fuse_beta", ExtraAttrProperty::ONEDNN},
         {"fuse_relu", ExtraAttrProperty::ONEDNN},
+        {"fused_output_scale", ExtraAttrProperty::ONEDNN},
         {"fuse_residual_connection", ExtraAttrProperty::ONEDNN},
         {"fuse_with_relu", ExtraAttrProperty::ONEDNN},
         {"fused_reshape_Out", ExtraAttrProperty::ONEDNN},
@@ -221,7 +222,8 @@ class ExtraInfoUtils {
   std::unordered_map<std::string, std::vector<std::string>>
       g_extra_input_names_map_ = {{"conv2d", {"Bias", "ResidualData"}},
                                   {"conv2d_transpose", {"Bias"}},
-                                  {"conv2d_grad", {"Bias"}}};
+                                  {"conv2d_grad", {"Bias"}},
+                                  {"matmul_v2", {"ResidualData"}}};
   std::vector<std::string> empty_extra_input_names_;
 };
```
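For context: `g_extra_input_names_map_` is consulted when oneDNN-only extra inputs are wired into an op, so registering `matmul_v2` with `ResidualData` is what lets the kernel later fetch the residual tensor via `GetDnnInput`. A minimal sketch of the lookup pattern, assuming illustrative names rather than the actual `ExtraInfoUtils` interface:

```cpp
// Sketch of the lookup behind g_extra_input_names_map_: ops absent from the
// map fall back to a shared empty vector, so callers never hold a dangling
// reference. ExtraInputs is a hypothetical stand-in, not the Paddle class.
#include <string>
#include <unordered_map>
#include <vector>

class ExtraInputs {
 public:
  const std::vector<std::string>& For(const std::string& op) const {
    auto it = map_.find(op);
    return it == map_.end() ? empty_ : it->second;
  }

 private:
  std::unordered_map<std::string, std::vector<std::string>> map_ = {
      {"conv2d", {"Bias", "ResidualData"}},
      {"matmul_v2", {"ResidualData"}},  // the entry added by this commit
  };
  std::vector<std::string> empty_;
};
```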
paddle/phi/backends/onednn/onednn_reuse.h

Lines changed: 14 additions & 7 deletions
```diff
@@ -1874,9 +1874,11 @@ class MatmulOneDNNHandler : public OneDNNHandlerNoCachingT<XT, dnnl::matmul> {
     if (scale_out != 1.0f) {
       matmul_attrs.set_output_scales(0, {scale_out});
     }
+    const auto* residual_data = dev_ctx.HasDnnInput("ResidualData")
+                                    ? dev_ctx.GetDnnInput("ResidualData")
+                                    : nullptr;

-    if (dev_ctx.HasDnnInput("ResidualData")) {
-      auto* residual_data = dev_ctx.GetDnnInput("ResidualData");
+    if (residual_data) {
       auto residual_data_tz = vectorize(residual_data->dims());
       auto residual_data_md = memory::desc(residual_data_tz,
                                            OneDNNGetDataType<OT>(),
@@ -1893,9 +1895,11 @@ class MatmulOneDNNHandler : public OneDNNHandlerNoCachingT<XT, dnnl::matmul> {

     AppendActivation(dev_ctx, post_operations);

-    if (dev_ctx.HasDnnAttr("fused_output_scale")) {
-      float scale_alpha =
-          PADDLE_GET_CONST(float, dev_ctx.GetDnnAttr("fused_output_scale"));
+    const float scale_alpha =
+        dev_ctx.HasDnnAttr("fused_output_scale")
+            ? PADDLE_GET_CONST(float, dev_ctx.GetDnnAttr("fused_output_scale"))
+            : 1.0f;
+    if (scale_alpha != 1.0f) {
       post_operations.append_eltwise(
           1.0, dnnl::algorithm::eltwise_linear, scale_alpha, 0.0f);
     }
@@ -2014,8 +2018,11 @@ void ExecuteMatmul(const OneDNNContext& dev_ctx,
       {DNNL_ARG_WEIGHTS, *weights_memory_p},
       {DNNL_ARG_DST, *dst_memory_p}};

-  if (dev_ctx.HasDnnInput("ResidualData")) {
-    auto* residual_data = dev_ctx.GetDnnInput("ResidualData");
+  const auto* residual_data = dev_ctx.HasDnnInput("ResidualData")
+                                  ? dev_ctx.GetDnnInput("ResidualData")
+                                  : nullptr;
+
+  if (residual_data) {
     const auto residual_data_memory_p = handler.AcquireSrcMemory(residual_data);
     matmul_args.insert({DNNL_ARG_ATTR_MULTIPLE_POST_OP(0) | DNNL_ARG_SRC_1,
                         *residual_data_memory_p});
```
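These hunks hoist the `ResidualData` lookup into a single nullable pointer and make the `fused_output_scale` post-op conditional on the scale actually differing from 1.0f, so a no-op scale no longer appends an eltwise primitive. A sketch of how the two fused post-ops combine on the primitive attributes, assuming the oneDNN 2.x C++ API that this diff itself uses (in oneDNN 3.x, `append_eltwise` drops the leading scale argument and output scales become runtime arguments):

```cpp
// Minimal sketch (not the Paddle handler): build dnnl::primitive_attr with a
// residual-add binary post-op followed by an eltwise_linear output scale.
#include "dnnl.hpp"

dnnl::primitive_attr MakeMatmulAttrs(const dnnl::memory::desc& residual_md,
                                     float scale_alpha) {
  dnnl::primitive_attr attrs;
  dnnl::post_ops ops;
  // Residual connection: out += ResidualData; the tensor is supplied at
  // execution time via DNNL_ARG_ATTR_MULTIPLE_POST_OP(0) | DNNL_ARG_SRC_1.
  ops.append_binary(dnnl::algorithm::binary_add, residual_md);
  // fused_output_scale: out = scale_alpha * out + 0, skipped when it is 1.0f.
  if (scale_alpha != 1.0f) {
    ops.append_eltwise(1.0f, dnnl::algorithm::eltwise_linear, scale_alpha,
                       0.0f);
  }
  attrs.set_post_ops(ops);
  return attrs;
}
```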
(new file; path not captured in this extract — the migrated PHI oneDNN matmul kernel)

Lines changed: 164 additions & 0 deletions

```cpp
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/phi/kernels/matmul_kernel.h"

#include "paddle/phi/backends/onednn/onednn_reuse.h"
#include "paddle/phi/core/kernel_registry.h"

namespace phi {

DDim GetDimsForInput(const OneDNNContext &dev_ctx,
                     DDim input_dims,
                     std::string input_name) {
  auto shape =
      dev_ctx.HasDnnAttr("fused_reshape_" + input_name)
          ? PADDLE_GET_CONST(std::vector<int>,
                             dev_ctx.GetDnnAttr("fused_reshape_" + input_name))
          : std::vector<int>();
  auto axis = dev_ctx.HasDnnAttr("fused_transpose_" + input_name)
                  ? PADDLE_GET_CONST(
                        std::vector<int>,
                        dev_ctx.GetDnnAttr("fused_transpose_" + input_name))
                  : std::vector<int>();
  if (!shape.empty() && !axis.empty()) {
    return input_dims.reshape(shape).transpose(axis);
  }
  return input_dims;
}

void CalculateMatrixDims(const std::vector<int64_t> &x_dims,
                         const std::vector<int64_t> &y_dims,
                         std::vector<int64_t> *x_bd_dims,
                         std::vector<int64_t> *y_bd_dims,
                         DenseTensor *out,
                         const bool is_output_fused) {
  if (x_dims.size() == 1) {
    (*x_bd_dims)[(*x_bd_dims).size() - 1] = x_dims[0];
  } else if (x_dims.size() == 2) {
    (*x_bd_dims)[(*x_bd_dims).size() - 1] = x_dims[1];
    (*x_bd_dims)[(*x_bd_dims).size() - 2] = x_dims[0];
  } else {
    for (size_t i = 0; i < x_dims.size(); ++i) {
      (*x_bd_dims)[(*x_bd_dims).size() - x_dims.size() + i] = x_dims[i];
    }
  }
  if (y_dims.size() == 1) {
    (*y_bd_dims)[(*x_bd_dims).size() - 2] = y_dims[0];
  } else if (y_dims.size() == 2) {
    (*y_bd_dims)[(*y_bd_dims).size() - 1] = y_dims[1];
    (*y_bd_dims)[(*y_bd_dims).size() - 2] = y_dims[0];
  } else {
    for (size_t i = 0; i < y_dims.size(); ++i) {
      (*y_bd_dims)[(*y_bd_dims).size() - y_dims.size() + i] = y_dims[i];
    }
  }

  if (!is_output_fused && x_dims.size() > 2 && y_dims.size() > 2) {
    auto out_dims = vectorize(out->dims());
    for (size_t i = 0; i < (*x_bd_dims).size() - 2; ++i) {
      PADDLE_ENFORCE_EQ(
          (*x_bd_dims)[i] == (*y_bd_dims)[i] || (*x_bd_dims)[i] == 1 ||
              (*y_bd_dims)[i] == 1,
          true,
          errors::InvalidArgument(
              "Tensor dimensions are incorrect for broadcasting."
              "Dimensions in X and Y must be same or equal to 1, but "
              "received x_dim[%d]=%d and y_dims[%d]= %d",
              i,
              (*x_bd_dims)[i],
              i,
              (*y_bd_dims)[i]));
      (out_dims)[i] = std::max((*x_bd_dims)[i], (*y_bd_dims)[i]);
    }
    out->Resize(make_ddim((out_dims)));
  }
}

template <typename T, typename Context>
void MatmulKernel(const Context &dev_ctx,
                  const DenseTensor &x,
                  const DenseTensor &y,
                  bool transpose_x,
                  bool transpose_y,
                  DenseTensor *out) {
  if (dev_ctx.HasDnnAttr("head_number")) {
    const auto head_number =
        PADDLE_GET_CONST(int, dev_ctx.GetDnnAttr("head_number"));
    PADDLE_ENFORCE_EQ(
        head_number,
        1,
        errors::Unimplemented(
            "oneDNN matmul doesn't support multiple heads. Expected "
            "head_number=1. But received `head_number` is %d",
            head_number));
  }

  constexpr bool is_int8 = funcs::is_int8<T>();
  constexpr bool is_bfloat16 = funcs::is_bfloat16<T>();
  const bool force_fp32_output =
      dev_ctx.HasDnnAttr("force_fp32_output")
          ? PADDLE_GET_CONST(bool, dev_ctx.GetDnnAttr("force_fp32_output"))
          : false;

  bool fuse_relu = false;
  if (dev_ctx.HasDnnAttr("fuse_activation")) {
    auto act_type =
        PADDLE_GET_CONST(std::string, dev_ctx.GetDnnAttr("fuse_activation"));
    if (act_type == "relu" || act_type == "relu6") {
      fuse_relu = true;
    }
  }

  auto x_dims = vectorize(GetDimsForInput(dev_ctx, x.dims(), "X"));
  auto y_dims = vectorize(GetDimsForInput(dev_ctx, y.dims(), "Y"));

  int ndims = std::max(x_dims.size(), y_dims.size());
  ndims = std::max(ndims, 3);

  std::vector<int64_t> x_bd_dims(ndims, 1);
  std::vector<int64_t> y_bd_dims(ndims, 1);

  CalculateMatrixDims(x_dims,
                      y_dims,
                      &x_bd_dims,
                      &y_bd_dims,
                      out,
                      funcs::IsOutputFused(dev_ctx));

  if (force_fp32_output || ((!is_int8) && (!is_bfloat16))) {
    funcs::ExecuteMatmul<T, float>(
        dev_ctx, x, y, x_bd_dims, y_bd_dims, transpose_x, transpose_y, out);
  } else if (is_bfloat16) {
    funcs::ExecuteMatmul<T, paddle::platform::bfloat16>(
        dev_ctx, x, y, x_bd_dims, y_bd_dims, transpose_x, transpose_y, out);
  } else if (fuse_relu) {
    funcs::ExecuteMatmul<T, uint8_t>(
        dev_ctx, x, y, x_bd_dims, y_bd_dims, transpose_x, transpose_y, out);
  } else {
    funcs::ExecuteMatmul<T, int8_t>(
        dev_ctx, x, y, x_bd_dims, y_bd_dims, transpose_x, transpose_y, out);
  }
}

}  // namespace phi

PD_REGISTER_KERNEL(matmul,
                   OneDNN,
                   ONEDNN,
                   phi::MatmulKernel,
                   float,
                   phi::dtype::bfloat16,
                   int8_t,
                   uint8_t) {}
```
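The dispatch at the bottom of `MatmulKernel` picks the output element type: float when `force_fp32_output` is set or the input is neither int8 nor bfloat16, bfloat16 for bf16 inputs, and uint8 versus int8 for quantized inputs depending on whether a fused ReLU guarantees a non-negative result. Before that, both operand shapes are right-aligned into rank-`ndims` vectors padded with 1s so oneDNN sees broadcast-compatible batched operands; a standalone toy sketch of that padding (not the PHI code itself):

```cpp
// Worked sketch of the shape padding done by CalculateMatrixDims: both shapes
// are right-aligned into ndims-long vectors filled with 1s before matmul.
#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  std::vector<int64_t> x_dims = {2, 3, 4};  // batched matrix
  std::vector<int64_t> y_dims = {4, 5};     // plain matrix

  // Common rank of at least 3, mirroring the ndims computation in the kernel.
  const size_t ndims =
      std::max({x_dims.size(), y_dims.size(), static_cast<size_t>(3)});
  std::vector<int64_t> x_bd(ndims, 1), y_bd(ndims, 1);
  std::copy_backward(x_dims.begin(), x_dims.end(), x_bd.end());
  std::copy_backward(y_dims.begin(), y_dims.end(), y_bd.end());

  // Prints "2 3 4 | 1 4 5": the leading 1 in y broadcasts over the batch,
  // so the output batch dim is max(2, 1) = 2.
  for (auto d : x_bd) std::cout << d << ' ';
  std::cout << "| ";
  for (auto d : y_bd) std::cout << d << ' ';
  std::cout << '\n';
}
```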
