Commit a923e01

[NPU] Support fused_linear_param_grad_add (PaddlePaddle#976)
1 parent 8db52f4 commit a923e01

2 files changed: 298 additions, 0 deletions

Lines changed: 133 additions & 0 deletions
@@ -0,0 +1,133 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "kernels/funcs/npu_funcs.h"
#include "kernels/funcs/npu_op_runner.h"

namespace custom_kernel {

template <typename T, typename Context>
void FusedLinearParamGradAdd(const Context &dev_ctx,
                             const phi::DenseTensor &x,
                             const phi::DenseTensor &dout,
                             const paddle::optional<phi::DenseTensor> &dweight,
                             const paddle::optional<phi::DenseTensor> &dbias,
                             bool multi_precision,
                             bool has_bias,
                             phi::DenseTensor *dweight_out,
                             phi::DenseTensor *dbias_out) {
  using MT = typename phi::dtype::MPTypeTrait<T>::Type;

  int64_t K = x.dims()[x.dims().size() - 1];
  int64_t N = dout.dims()[dout.dims().size() - 1];

  // Flatten all leading dims: x -> [M, K], dout -> [M, N].
  phi::DenseTensor reshape_x(x);
  reshape_x.Resize({x.numel() / K, K});
  phi::DenseTensor reshape_dout(dout);
  reshape_dout.Resize({dout.numel() / N, N});

  if (dweight_out && dweight) {
    *dweight_out = dweight.get();
    if (multi_precision) {
      PADDLE_ENFORCE_EQ(
          dweight_out->dtype(),
          phi::CppTypeToDataType<MT>::Type(),
          phi::errors::InvalidArgument("Invalid data type error."));
    } else {
      PADDLE_ENFORCE_EQ(
          dweight_out->dtype(),
          phi::CppTypeToDataType<T>::Type(),
          phi::errors::InvalidArgument("Invalid data type error."));
    }
  } else {
    dweight_out->Resize(phi::make_ddim({K, N}));
    if (multi_precision) {
      dev_ctx.template Alloc<MT>(dweight_out);
    } else {
      dev_ctx.template Alloc<T>(dweight_out);
    }
  }

  if (has_bias && dbias_out) {
    dev_ctx.template Alloc<T>(dbias_out);
  }

  float alpha = 1.0;
  float beta = 1.0;

  // If no existing weight grad is passed in, accumulate onto a zero tensor.
  phi::DenseTensor new_dweight;
  if (dweight) {
    new_dweight = dweight.get();
  } else {
    phi::DenseTensorMeta dweight_meta = {x.dtype(), {K, N}};
    new_dweight.set_meta(dweight_meta);
    FillNpuTensorWithConstant<T>(&new_dweight, dev_ctx, static_cast<T>(0));
    new_dweight.Resize({K, N});
  }

  int64_t trans_a = 1;
  int64_t trans_b = 0;
  int8_t cube_math_type = 0;
  bool keep_dim = false;
  // dweight_out = alpha * reshape_x^T * reshape_dout + beta * new_dweight
  EXEC_NPU_CMD(aclnnGemm,
               dev_ctx,
               reshape_x,
               reshape_dout,
               new_dweight,
               alpha,
               beta,
               trans_a,
               trans_b,
               *dweight_out,
               cube_math_type);

  if (has_bias) {
    phi::IntArray axis = {0};

    // If no existing bias grad is passed in, accumulate onto a zero tensor.
    phi::DenseTensor new_dbias;
    if (dbias) {
      new_dbias = dbias.get();
    } else {
      phi::DenseTensorMeta new_dbias_meta = {x.dtype(), {N}};
      new_dbias.set_meta(new_dbias_meta);
      FillNpuTensorWithConstant<T>(&new_dbias, dev_ctx, static_cast<T>(0));
    }

    auto dst_dtype = ConvertToNpuDtype(reshape_dout.dtype());

    // bias_sum = sum(reshape_dout, axis=0)
    phi::DenseTensor bias_sum;
    phi::DenseTensorMeta bias_sum_meta = {x.dtype(), {N}};
    bias_sum.set_meta(bias_sum_meta);
    dev_ctx.template Alloc<T>(&bias_sum);
    EXEC_NPU_CMD(aclnnReduceSum,
                 dev_ctx,
                 reshape_dout,
                 axis,
                 keep_dim,
                 dst_dtype,
                 bias_sum);
    // dbias_out = bias_sum + add_alpha * new_dbias
    phi::Scalar add_alpha = 1.0;
    EXEC_NPU_CMD(aclnnAdd, dev_ctx, bias_sum, new_dbias, add_alpha, *dbias_out);
  }
}

}  // namespace custom_kernel

PD_REGISTER_PLUGIN_KERNEL(fused_linear_param_grad_add,
                          npu,
                          ALL_LAYOUT,
                          custom_kernel::FusedLinearParamGradAdd,
                          float,
                          double,
                          phi::dtype::float16,
                          phi::dtype::bfloat16) {}
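
For reference, a minimal NumPy sketch of the math this kernel fuses is given below; the function and variable names are illustrative only, and the logic mirrors the run_ground_truth helper in the test file that follows.

import numpy as np


def fused_linear_param_grad_add_ref(x, dout, dweight=None, dbias=None, has_bias=True):
    # Flatten all leading dimensions: x -> [M, K], dout -> [M, N].
    x2d = x.reshape(-1, x.shape[-1])
    dout2d = dout.reshape(-1, dout.shape[-1])

    # Weight gradient: x^T @ dout, accumulated into any existing dweight
    # (the aclnnGemm call above).
    dweight_out = x2d.T @ dout2d
    if dweight is not None:
        dweight_out = dweight_out + dweight
    if not has_bias:
        return dweight_out

    # Bias gradient: column sum of dout, accumulated into any existing dbias
    # (the aclnnReduceSum + aclnnAdd calls above).
    dbias_out = dout2d.sum(axis=0)
    if dbias is not None:
        dbias_out = dbias_out + dbias
    return dweight_out, dbias_out
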
Lines changed: 165 additions & 0 deletions
@@ -0,0 +1,165 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import numpy as np

import paddle
from paddle import _C_ops
from npu_utils import check_soc_version


def promote_dtype(x):
    if x.dtype in [paddle.float16, paddle.bfloat16]:
        return x.astype(paddle.float32)
    else:
        return x


def recreate(x, multi_precision):
    if isinstance(x, (list, tuple)):
        return [recreate(item, multi_precision) for item in x]

    if x is None:
        return None

    if multi_precision:
        x = promote_dtype(x)

    return paddle.to_tensor(x.numpy())


def run_ground_truth(x, dy, dweight, dbias, multi_precision, has_bias):
    x, dy, dweight, dbias = recreate([x, dy, dweight, dbias], multi_precision)

    dweight_tmp = paddle.matmul(
        x.reshape([-1, x.shape[-1]]),
        dy.reshape([-1, dy.shape[-1]]),
        transpose_x=True,
    )
    if dweight is None:
        dweight = dweight_tmp
    else:
        assert dweight.shape == dweight_tmp.shape
        assert dweight.dtype == dweight_tmp.dtype
        dweight += dweight_tmp

    if has_bias:
        dbias_tmp = dy.reshape([-1, dy.shape[-1]]).sum(axis=0)
        if dbias is None:
            dbias = dbias_tmp
        else:
            assert dbias.shape == dbias_tmp.shape
            assert dbias.dtype == dbias_tmp.dtype
            dbias += dbias_tmp

        return promote_dtype(dweight).numpy(), promote_dtype(dbias).numpy()
    else:
        return promote_dtype(dweight).numpy()


def run_fused_linear_param_grad_add(x, dy, dweight, dbias, multi_precision, has_bias):
    dweight_new, dbias_new = _C_ops.fused_linear_param_grad_add(
        x, dy, dweight, dbias, multi_precision, has_bias
    )
    if dweight is not None:
        # The op accumulates into the existing dweight in place.
        assert dweight_new.data_ptr() == dweight.data_ptr()
    if has_bias:
        return (
            promote_dtype(dweight_new).numpy(),
            promote_dtype(dbias_new).numpy(),
        )
    else:
        return promote_dtype(dweight_new).numpy()


class TestMainClassBase(unittest.TestCase):
    def setUp(self):
        self.shape = [3, 4, 32]
        self.output_size = 128
        self.dtype = paddle.float16
        self.config()  # let subclasses override the default dtype

    def config(self):
        pass

    def rand(self, shape, dtype=None):
        x = np.random.randint(low=-5, high=5, size=shape)
        x = paddle.to_tensor(x)
        return x.astype(dtype or self.dtype)

    def generate_rand_inputs(self, has_dweight, has_dbias, multi_precision, has_bias):
        x_shape = self.shape
        dy_shape = self.shape[:-1] + [self.output_size]
        dweight_shape = [self.shape[-1], self.output_size]
        dbias_shape = [self.output_size]

        x = self.rand(x_shape)
        dy = self.rand(dy_shape)
        if has_dweight:
            dweight = self.rand(dweight_shape)
            if multi_precision:
                dweight = promote_dtype(dweight)
        else:
            dweight = None

        if has_bias and has_dbias:
            dbias = self.rand(dbias_shape)
            if multi_precision:
                dbias = promote_dtype(dbias)
        else:
            dbias = None
        return x, dy, dweight, dbias

    def check_main(self, has_dweight, has_dbias, multi_precision, has_bias):
        x, dy, dweight, dbias = self.generate_rand_inputs(
            has_dweight, has_dbias, multi_precision, has_bias
        )
        res1 = run_ground_truth(x, dy, dweight, dbias, multi_precision, has_bias)
        res2 = run_fused_linear_param_grad_add(
            x, dy, dweight, dbias, multi_precision, has_bias
        )
        self.assertEqual(len(res1), len(res2))
        for r1, r2 in zip(res1, res2):
            max_diff = np.max(np.abs(r1 - r2))
            self.assertLess(max_diff, 1e-10)

    @check_soc_version
    def test_main(self):
        for has_dweight in [False, True]:
            for has_bias in [False, True]:
                for has_dbias in [False, True]:
                    for multi_precision in [False, True]:
                        self.check_main(
                            has_dweight, has_dbias, multi_precision, has_bias
                        )


class TestMainClassBF16(TestMainClassBase):
    def config(self):
        self.dtype = paddle.bfloat16


class TestMainClassFP32(TestMainClassBase):
    def config(self):
        self.dtype = paddle.float32


class TestMainClassFP64(TestMainClassBase):
    def config(self):
        self.dtype = paddle.float64


if __name__ == "__main__":
    unittest.main()
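
As a usage note, a hedged sketch of how the new kernel can be exercised directly in dygraph mode; the device string, shapes, and dtypes are illustrative, and the call signature matches run_fused_linear_param_grad_add above.

import paddle
from paddle import _C_ops

paddle.set_device("npu")  # assumes the custom NPU plugin is installed

x = paddle.rand([3, 4, 32]).astype("float16")       # activations, [..., K]
dy = paddle.rand([3, 4, 128]).astype("float16")     # upstream gradient, [..., N]
dweight = paddle.zeros([32, 128], dtype="float16")  # existing weight gradient, [K, N]

# Accumulates x^T @ dy into dweight in place and returns sum(dy) as the bias gradient.
dweight_out, dbias_out = _C_ops.fused_linear_param_grad_add(
    x, dy, dweight, None, False, True
)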
