
Commit a9618bf

Add fused_attention_op: add impl wrappers. (PaddlePaddle#35903)
1 parent 6840cf5 commit a9618bf

6 files changed: +487 -8 lines changed

paddle/fluid/operators/elementwise/elementwise_op_impl.cu.h

Lines changed: 2 additions & 1 deletion
@@ -108,7 +108,8 @@ struct ElementwisePrimitiveCaller<InT, OutT, VecSize, Functor, 2, false> {

 template <typename InT, typename OutT, int VecSize, typename Functor>
 struct ElementwisePrimitiveCaller<InT, OutT, VecSize, Functor, 3, false> {
-  __device__ inline void operator()(Functor func, InT **args, OutT *result) {
+  __device__ inline void operator()(Functor func, InT (*args)[VecSize],
+                                    OutT *result) {
     kps::ElementwiseTernary<InT, OutT, VecSize, 1, 1, Functor>(
         result, args[0], args[1], args[2], func);
   }
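This hunk changes the ternary caller to take a pointer to fixed-size arrays, InT (*args)[VecSize], instead of a pointer-to-pointer InT **args, so the per-argument vector width stays in the parameter's type and args[0], args[1], args[2] are whole VecSize-wide operands. The host-only sketch below illustrates that parameter form; it uses hypothetical names (FmaFunctor, ElementwiseTernaryRef) and is not Paddle code.

#include <cstdio>

// Hypothetical ternary functor standing in for the Functor template argument.
struct FmaFunctor {
  float operator()(float a, float b, float c) const { return a * b + c; }
};

// Host-side analogue of the new signature: args is a pointer to arrays of
// length VecSize, so indexing it yields whole VecSize-wide chunks rather
// than raw pointers of unknown extent.
template <typename InT, typename OutT, int VecSize, typename Functor>
void ElementwiseTernaryRef(Functor func, InT (*args)[VecSize], OutT *result) {
  for (int i = 0; i < VecSize; ++i) {
    result[i] = func(args[0][i], args[1][i], args[2][i]);
  }
}

int main() {
  constexpr int kVecSize = 4;
  float args[3][kVecSize] = {{1, 2, 3, 4}, {10, 10, 10, 10}, {5, 5, 5, 5}};
  float result[kVecSize];
  ElementwiseTernaryRef<float, float, kVecSize>(FmaFunctor(), args, result);
  for (float v : result) printf("%g ", v);  // 15 25 35 45
  printf("\n");
  return 0;
}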

paddle/fluid/operators/fused/attention_layer_norm.h

Lines changed: 1 addition & 1 deletion
@@ -50,7 +50,7 @@ class AttnLayerNorm {
     }
   }

-  void ComputeBackward(const T* x_data, const T* y_data,
+  void ComputeBackward(const T* x_data, const T* d_y_data,
                        const LayerNormParamType<T>* scale_data,
                        const LayerNormParamType<T>* mean_data,
                        const LayerNormParamType<T>* var_data, T* d_x_data,

paddle/fluid/operators/fused/attn_bias_add.cu.h

Lines changed: 1 addition & 5 deletions
@@ -34,6 +34,7 @@ namespace cub = hipcub;
 #define LAUNCH_BOUNDS(BlockDim)
 #endif

+#include "paddle/fluid/operators/elementwise/elementwise_functor.h"
 #include "paddle/fluid/operators/elementwise/elementwise_op_broadcast.cu.h"
 #include "paddle/fluid/operators/kernel_primitives/kernel_primitives.h"
 #include "paddle/fluid/operators/reduce_ops/reduce_functor_op.h"
@@ -51,11 +52,6 @@ using CudnnDataType = platform::CudnnDataType<T>;
 template <typename T>
 using ReduceParamType = typename CudnnDataType<T>::BatchNormParamType;

-template <typename T>
-struct AddFunctor {
-  inline HOSTDEVICE T operator()(const T& a, const T& b) const { return a + b; }
-};
-
 template <typename InT, typename OutT, int ShapeSize, int VecSize,
           int DATA_PER_THREAD, typename Functor>
 __global__ void BroadcastKernelBinary(
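The deleted block above removes a per-file copy of AddFunctor in favor of the shared definition pulled in through the newly included elementwise_functor.h. The host-only sketch below (hypothetical names ElementwiseBinaryRef, not Paddle code) illustrates the pattern: a small stateless binary functor that any generic elementwise launcher can accept, so a single shared definition suffices instead of duplicating it in every header.

#include <cstdio>

// Stand-in for the shared AddFunctor; host-only version for illustration.
template <typename T>
struct AddFunctor {
  inline T operator()(const T& a, const T& b) const { return a + b; }
};

// Generic binary elementwise reference: any functor with this call signature
// can be plugged in, which is why one shared definition is preferable.
template <typename T, typename Functor>
void ElementwiseBinaryRef(const T* x, const T* y, T* out, int n, Functor f) {
  for (int i = 0; i < n; ++i) {
    out[i] = f(x[i], y[i]);
  }
}

int main() {
  float x[4] = {1, 2, 3, 4};
  float y[4] = {10, 20, 30, 40};
  float out[4];
  ElementwiseBinaryRef(x, y, out, 4, AddFunctor<float>());
  for (float v : out) printf("%g ", v);  // 11 22 33 44
  printf("\n");
  return 0;
}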
Lines changed: 159 additions & 0 deletions
@@ -0,0 +1,159 @@
+/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+http://www.apache.org/licenses/LICENSE-2.0
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#pragma once
+
+#include "paddle/fluid/operators/fused/attn_bias_add.cu.h"
+#include "paddle/fluid/operators/math/blas.h"
+#include "paddle/fluid/platform/float16.h"
+
+namespace paddle {
+namespace operators {
+
+// Supports gemm-nt and gemm-nn, which are used in fused_attention_op.
+template <typename T>
+class AttnMatMul {
+ public:
+  // (m, n, k) = bsz_seq, output_size, input_size
+  AttnMatMul(const platform::CUDADeviceContext& dev_ctx, bool transA,
+             bool transB, int bsz_seq, int output_size, int input_size,
+             bool compute_bias)
+      : dev_ctx_(dev_ctx),
+        transA_(transA),
+        transB_(transB),
+        bsz_seq_(bsz_seq),
+        output_size_(output_size),
+        input_size_(input_size),
+        compute_bias_(compute_bias) {}
+
+  ~AttnMatMul() {}
+
+  void ComputeForward(const T* weight_data, const T* input_data,
+                      const T* bias_data, T* output_data, T* bias_out_data) {
+    // Note: the blas.GEMM API in Paddle treats all inputs as row-major.
+    // here: (transA, transB): nt, input * weight.
+    CBLAS_TRANSPOSE transA = CblasNoTrans;
+    CBLAS_TRANSPOSE transB = CblasNoTrans;
+    if (transA_) {
+      transA = CblasTrans;
+    }
+    if (transB_) {
+      transB = CblasTrans;
+    }
+    T alpha = static_cast<T>(1.0);
+    T beta = static_cast<T>(0.0);
+
+    // here: (m, n, k) = bsz_seq, output_size, input_size, (input, weight, out)
+    auto blas = math::GetBlas<platform::CUDADeviceContext, T>(dev_ctx_);
+    blas.GEMM(transA, transB, bsz_seq_, output_size_, input_size_, alpha,
+              input_data, weight_data, beta, output_data);
+    if (compute_bias_) {
+      // compute output + bias
+      LaunchBiasAddFwKernel(dev_ctx_, bsz_seq_, output_size_, output_data,
+                            bias_data, bias_out_data);
+    }
+  }
+
+  void ComputeBackward(const T* input, const T* weight, const T* d_output,
+                       T* d_input, T* d_weight, T* d_bias) {
+    T alpha = static_cast<T>(1.0);
+    T beta = static_cast<T>(0.0);
+    auto blas = math::GetBlas<platform::CUDADeviceContext, T>(dev_ctx_);
+
+    CBLAS_TRANSPOSE dB_transA = CblasNoTrans;
+    CBLAS_TRANSPOSE dB_transB = CblasNoTrans;
+    CBLAS_TRANSPOSE dA_transA = CblasNoTrans;
+    CBLAS_TRANSPOSE dA_transB = CblasNoTrans;
+    int dB_m = 1;
+    int dB_n = 1;
+    int dB_k = 1;
+    int dA_m = 1;
+    int dA_n = 1;
+    int dA_k = 1;
+
+    T* dB_input_1_ptr = nullptr;
+    T* dB_input_2_ptr = nullptr;
+    T* dB_output_ptr = d_weight;
+
+    T* dA_input_1_ptr = nullptr;
+    T* dA_input_2_ptr = nullptr;
+    T* dA_output_ptr = d_input;
+
+    if (!transA_) {
+      // fw: gemm-nt
+      if (transB_) {
+        // bw: gemm-tn, dB = (dC)^t * A
+        dB_transA = CblasTrans;
+        dB_transB = CblasNoTrans;
+        dB_m = output_size_;
+        dB_n = input_size_;
+        dB_k = bsz_seq_;
+
+        // bw: gemm-nn, dA = dC * B
+        dA_transA = CblasNoTrans;
+        dA_transB = CblasNoTrans;
+        dA_m = bsz_seq_;
+        dA_n = input_size_;
+        dA_k = output_size_;
+
+        blas.GEMM(dB_transA, dB_transB, dB_m, dB_n, dB_k, alpha, d_output,
+                  input, beta, dB_output_ptr);
+        blas.GEMM(dA_transA, dA_transB, dA_m, dA_n, dA_k, alpha, d_output,
+                  weight, beta, dA_output_ptr);
+      } else {  // fw: gemm-nn
+        // bw: gemm-tn, dB = A^t * dC
+        dB_transA = CblasTrans;
+        dB_transB = CblasNoTrans;
+        dB_m = input_size_;
+        dB_n = output_size_;
+        dB_k = bsz_seq_;
+
+        // bw: gemm-nt, dA = dC * B^t
+        dA_transA = CblasNoTrans;
+        dA_transB = CblasTrans;
+        dA_m = bsz_seq_;
+        dA_n = input_size_;
+        dA_k = output_size_;
+
+        blas.GEMM(dB_transA, dB_transB, dB_m, dB_n, dB_k, alpha, input,
+                  d_output, beta, dB_output_ptr);
+        blas.GEMM(dA_transA, dA_transB, dA_m, dA_n, dA_k, alpha, d_output,
+                  weight, beta, dA_output_ptr);
+      }
+    } else if (transB_) {
+      PADDLE_THROW(platform::errors::InvalidArgument(
+          "The AttnMatMul wrapper does not support (transA=T, transB=T) "
+          "parameters."));
+    } else {
+      PADDLE_THROW(platform::errors::InvalidArgument(
+          "The AttnMatMul wrapper does not support (transA=T, transB=N) "
+          "parameters."));
+    }
+    if (compute_bias_) {
+      LaunchBiasAddBwKernel(dev_ctx_, bsz_seq_, output_size_, d_output, d_bias);
+    }
+  }
+
+ private:
+  const platform::CUDADeviceContext& dev_ctx_;
+
+  bool transA_;
+  bool transB_;
+
+  int bsz_seq_;
+  int output_size_;
+  int input_size_;
+
+  bool compute_bias_;
+};
+
+}  // namespace operators
+}  // namespace paddle
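As a sanity check on the shape bookkeeping in the new AttnMatMul wrapper, the host-only sketch below mimics the forward path for the gemm-nt case: input is (bsz_seq x input_size), weight is stored row-major as (output_size x input_size), the product with a transposed B gives (bsz_seq x output_size), and a per-column bias is then added, which is the role LaunchBiasAddFwKernel plays when compute_bias_ is set. The helper GemmRef and all names here are hypothetical illustration, not the Paddle implementation.

#include <cstdio>
#include <vector>

// Naive row-major GEMM reference: C(m x n) = A(m x k) * op(B), where op(B)
// is B(k x n) when trans_b is false and B(n x k)^T when trans_b is true.
void GemmRef(bool trans_b, int m, int n, int k, const float* A, const float* B,
             float* C) {
  for (int i = 0; i < m; ++i) {
    for (int j = 0; j < n; ++j) {
      float acc = 0.f;
      for (int p = 0; p < k; ++p) {
        float b = trans_b ? B[j * k + p] : B[p * n + j];
        acc += A[i * k + p] * b;
      }
      C[i * n + j] = acc;
    }
  }
}

int main() {
  // (m, n, k) = (bsz_seq, output_size, input_size), as in ComputeForward.
  const int bsz_seq = 2, output_size = 3, input_size = 4;
  std::vector<float> input(bsz_seq * input_size, 1.f);       // (m x k)
  std::vector<float> weight(output_size * input_size, 2.f);  // (n x k), gemm-nt
  std::vector<float> bias(output_size, 0.5f);
  std::vector<float> out(bsz_seq * output_size, 0.f);

  GemmRef(/*trans_b=*/true, bsz_seq, output_size, input_size, input.data(),
          weight.data(), out.data());

  // Bias add over the output columns.
  for (int i = 0; i < bsz_seq; ++i)
    for (int j = 0; j < output_size; ++j) out[i * output_size + j] += bias[j];

  // Each output element is input_size * 1 * 2 + 0.5 = 8.5.
  for (float v : out) printf("%g ", v);
  printf("\n");
  return 0;
}

The same bookkeeping applies to the backward path: in the gemm-nt case, dB = (dC)^T * A has shape (output_size x input_size), which is exactly the (dB_m, dB_n, dB_k) triple chosen in ComputeBackward above.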
