
Commit fa1fa75

Authored by FrostML, ZeyuChen, and guoshengCS
Transformer decoding support fuse qkv (#1455)
* decoding support fuseqkv
* update force version
* fp16
* alternative
* force decoding support global cublashandle and cublaslthandle
* update
* update
* update
* rm ref

Co-authored-by: Zeyu Chen <chenzeyu01@baidu.com>
Co-authored-by: Guo Sheng <whucsgs@163.com>
1 parent 5f2862a commit fa1fa75

9 files changed: +141 −68 lines
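Note: the fused-QKV path concatenates each decoder layer's query, key, and value projection weights along the output dimension so one GEMM produces all three projections per layer instead of three separate ones. A minimal NumPy sketch of the idea (shapes and variable names are illustrative, not part of this PR):

import numpy as np

# Fuse Q/K/V projection weights along the output (last) axis.
d_model = 512
w_q = np.random.rand(d_model, d_model).astype("float32")
w_k = np.random.rand(d_model, d_model).astype("float32")
w_v = np.random.rand(d_model, d_model).astype("float32")
w_qkv = np.concatenate((w_q, w_k, w_v), axis=-1)   # [d_model, 3 * d_model]

x = np.random.rand(4, d_model).astype("float32")   # 4 decoder tokens
q, k, v = np.split(x @ w_qkv, 3, axis=-1)          # one GEMM instead of three

assert np.allclose(q, x @ w_q)
assert np.allclose(k, x @ w_k)
assert np.allclose(v, x @ w_v)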

paddlenlp/ops/faster_transformer/sample/decoding_sample.py

Lines changed: 1 addition & 1 deletion
@@ -35,7 +35,7 @@ def parse_args():
     parser = argparse.ArgumentParser()
     parser.add_argument(
         "--config",
-        default="./sample/config/decoding.sample.yaml",
+        default="./faster_transformer/sample/config/decoding.sample.yaml",
         type=str,
         help="Path of the config file. ")
     parser.add_argument(

paddlenlp/ops/faster_transformer/src/cublas_handle.cc

Lines changed: 1 addition & 1 deletion
@@ -25,4 +25,4 @@ CublasHandle* CublasHandle::GetInstance() {
 CublasHandle::~CublasHandle() {
   cublasDestroy(cublas_handle_);
   cublasLtDestroy(cublaslt_handle_);
-}
+}

paddlenlp/ops/faster_transformer/src/cublas_handle.h

Lines changed: 1 addition & 1 deletion
@@ -55,4 +55,4 @@ class CublasHandle {
   cublasLtHandle_t cublaslt_handle_;
 
   ~CublasHandle();
-};
+};

paddlenlp/ops/faster_transformer/src/fusion_decoding_op.cu

Lines changed: 9 additions & 4 deletions
@@ -21,7 +21,6 @@ limitations under the License. */
 #include <sstream>
 #include <vector>
 
-#include "cublas_handle.h"
 #include "fastertransformer/cuda/cub/cub.cuh"
 #include "fusion_decoding_op.h"
 #include "pd_traits.h"
@@ -125,6 +124,10 @@ std::vector<paddle::Tensor> decoding_kernel(
   DecoderInitParam<DataType_>* params =
       new DecoderInitParam<DataType_>[num_layer_];
 
+  auto q_weight_shape = self_attn_query_weight[0].shape();
+  auto k_weight_shape = self_attn_key_weight[0].shape();
+  bool fuse_qkv = (q_weight_shape[1] == k_weight_shape[1]) ? false : true;
+
   for (int i = 0; i < num_layer_; i++) {
     params[i].stream = stream;
     params[i].cublas_handle = CublasHandle::GetInstance()->cublas_handle_;
@@ -261,7 +264,8 @@ std::vector<paddle::Tensor> decoding_kernel(
         start_id_,
         end_id_,
         beam_search_diversity_rate_,
-        true);  // is_fuse_topk_softMax
+        true,  // is_fuse_topk_softMax
+        fuse_qkv);
 
     decoding_beam_search_->forward(params, decoding_params);
 
@@ -283,7 +287,7 @@ std::vector<paddle::Tensor> decoding_kernel(
         end_id_,
         beam_search_diversity_rate_,
         true,  // is_fuse_topk_softMax
-        false,  // is_fuse_qkv
+        fuse_qkv,
         true,  // keep_alive_beam
         alpha);
 
@@ -307,7 +311,8 @@ std::vector<paddle::Tensor> decoding_kernel(
        start_id_,
        end_id_,
        candidate_num_,
-       probability_threshold_);
+       probability_threshold_,
+       fuse_qkv);
 
    decoding_sampling_->forward(params, decoding_params);
 
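Note: the op does not receive an explicit fuse_qkv flag; it infers it from the parameter shapes. When the Python layer passes the fused [d_model, 3 * d_model] parameter in the query-weight slot, its second dimension no longer matches the key weight's, and fuse_qkv flips to true. A small Python sketch of that check under the same assumption (shapes are illustrative):

def infer_fuse_qkv(q_weight_shape, k_weight_shape):
    # A fused Q/K/V weight is three projections wide; an unfused one matches K.
    return q_weight_shape[1] != k_weight_shape[1]

assert infer_fuse_qkv([512, 1536], [512, 512])      # fused QKV parameter
assert not infer_fuse_qkv([512, 512], [512, 512])   # separate Q/K/V weights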
paddlenlp/ops/faster_transformer/src/fusion_decoding_op.h

Lines changed: 2 additions & 0 deletions
@@ -16,6 +16,8 @@ limitations under the License. */
 #include <string>
 #include <vector>
 
+#include "cublas_handle.h"
+
 #include "fastertransformer/decoding_beamsearch.h"
 #include "fastertransformer/decoding_sampling.h"
 #include "fastertransformer/open_decoder.h"

paddlenlp/ops/faster_transformer/src/fusion_force_decoding_op.cu

Lines changed: 39 additions & 44 deletions
@@ -83,20 +83,18 @@ std::vector<paddle::Tensor> decoding_kernel(
     paddle::Tensor& output_ids,
     paddle::Tensor& parent_ids,
     paddle::Tensor& sequence_length,
-    std::string decoding_strategy,
-    int beam_size,
-    int topk,
-    float topp,
-    int head_num_,
-    int size_per_head_,
-    int num_layer_,
-    int start_id_,
-    int end_id_,
-    int64_t max_seq_len_,
-    float beam_search_diversity_rate_,
-    float alpha,
-    cublasHandle_t cublas_handle_,
-    cublasLtHandle_t cublaslt_handle_,
+    const std::string& decoding_strategy,
+    const int beam_size,
+    const int topk,
+    const float topp,
+    const int head_num_,
+    const int size_per_head_,
+    const int num_layer_,
+    const int start_id_,
+    const int end_id_,
+    const int64_t max_seq_len_,
+    const float beam_search_diversity_rate_,
+    const float alpha,
     cudaStream_t stream) {
   int beam_width_ = (decoding_strategy == "beam_search" ||
                      decoding_strategy == "beam_search_v2")
@@ -119,8 +117,9 @@ std::vector<paddle::Tensor> decoding_kernel(
   typedef typename traits_::data_t data_t_;
 
   DecodingInitParam<DataType_> decoding_params;
-  decoding_params.cublas_handle = cublas_handle_;
-  decoding_params.cublaslt_handle = cublaslt_handle_;
+  decoding_params.cublas_handle = CublasHandle::GetInstance()->cublas_handle_;
+  decoding_params.cublaslt_handle =
+      CublasHandle::GetInstance()->cublaslt_handle_;
 
   decoding_params.output_ids = output_ids.mutable_data<int>(input.place());
   decoding_params.parent_ids = parent_ids.mutable_data<int>(input.place());
@@ -156,10 +155,14 @@ std::vector<paddle::Tensor> decoding_kernel(
   DecoderInitParam<DataType_>* params =
       new DecoderInitParam<DataType_>[num_layer_];
 
+  auto q_weight_shape = self_attn_query_weight[0].shape();
+  auto k_weight_shape = self_attn_key_weight[0].shape();
+  bool fuse_qkv = (q_weight_shape[1] == k_weight_shape[1]) ? false : true;
+
   for (int i = 0; i < num_layer_; i++) {
     params[i].stream = stream;
-    params[i].cublas_handle = cublas_handle_;
-    params[i].cublaslt_handle = cublaslt_handle_;
+    params[i].cublas_handle = CublasHandle::GetInstance()->cublas_handle_;
+    params[i].cublaslt_handle = CublasHandle::GetInstance()->cublaslt_handle_;
 
     if (decoding_strategy == "beam_search" ||
         decoding_strategy == "beam_search_v2") {
@@ -292,7 +295,8 @@ std::vector<paddle::Tensor> decoding_kernel(
         start_id_,
         end_id_,
         beam_search_diversity_rate_,
-        true);  // is_fuse_topk_softMax
+        true,  // is_fuse_topk_softMax
+        fuse_qkv);  // is_fuse_qkv
 
     decoding_beam_search_->forward(params, decoding_params);
 
@@ -314,7 +318,7 @@ std::vector<paddle::Tensor> decoding_kernel(
        end_id_,
        beam_search_diversity_rate_,
        true,  // is_fuse_topk_softMax
-       false,  // is_fuse_qkv
+       fuse_qkv,  // is_fuse_qkv
        true,  // keep_alive_beam
        alpha);
 
@@ -338,7 +342,8 @@ std::vector<paddle::Tensor> decoding_kernel(
        start_id_,
        end_id_,
        candidate_num_,
-       probability_threshold_);
+       probability_threshold_,
+       fuse_qkv);
 
    decoding_sampling_->forward(params, decoding_params);
 
@@ -392,24 +397,20 @@ std::vector<paddle::Tensor> DecodingCUDAForward(
     paddle::Tensor& output_ids,
     paddle::Tensor& parent_ids,
     paddle::Tensor& sequence_length,
-    std::string decoding_strategy,
-    int beam_size,
-    int topk,
-    float topp,
-    int n_head,
-    int size_per_head,
-    int num_layer,
-    int bos_id,
-    int eos_id,
-    int64_t max_len,
-    float beam_search_diversity_rate,
-    float alpha) {
+    const std::string& decoding_strategy,
+    const int beam_size,
+    const int topk,
+    const float topp,
+    const int n_head,
+    const int size_per_head,
+    const int num_layer,
+    const int bos_id,
+    const int eos_id,
+    const int64_t max_len,
+    const float beam_search_diversity_rate,
+    const float alpha) {
   auto stream = input.stream();
-  cublasHandle_t cublas_handle_;
-  cublasCreate(&cublas_handle_);
-  cublasLtHandle_t cublaslt_handle_;
-  cublasLtCreate(&cublaslt_handle_);
-  cublasSetStream(cublas_handle_, stream);
+  cublasSetStream(CublasHandle::GetInstance()->cublas_handle_, stream);
 
   std::vector<paddle::Tensor> ret;
 
@@ -466,8 +467,6 @@ std::vector<paddle::Tensor> DecodingCUDAForward(
           max_len,
           beam_search_diversity_rate,
           alpha,
-          cublas_handle_,
-          cublaslt_handle_,
           stream);
       break;
     }
@@ -523,8 +522,6 @@ std::vector<paddle::Tensor> DecodingCUDAForward(
          max_len,
          beam_search_diversity_rate,
          alpha,
-         cublas_handle_,
-         cublaslt_handle_,
          stream);
      break;
    }
@@ -536,7 +533,5 @@ std::vector<paddle::Tensor> DecodingCUDAForward(
     }
   }
 
-  cublasDestroy(cublas_handle_);
-  cublasLtDestroy(cublaslt_handle_);
   return ret;
 }

paddlenlp/ops/faster_transformer/src/fusion_force_decoding_op.h

Lines changed: 14 additions & 12 deletions
@@ -16,6 +16,8 @@ limitations under the License. */
 #include <string>
 #include <vector>
 
+#include "cublas_handle.h"
+
 #include "fastertransformer/decoding_beamsearch.h"
 #include "fastertransformer/decoding_sampling.h"
 #include "fastertransformer/open_decoder.h"
@@ -67,15 +69,15 @@ std::vector<paddle::Tensor> DecodingCUDAForward(
     paddle::Tensor& output_ids,
     paddle::Tensor& parent_ids,
     paddle::Tensor& sequence_length,
-    std::string decoding_strategy,
-    int beam_size,
-    int topk,
-    float topp,
-    int n_head,
-    int size_per_head,
-    int num_layer,
-    int bos_id,
-    int eos_id,
-    int64_t max_len,
-    float beam_search_diversity_rate,
-    float alpha);
+    const std::string& decoding_strategy,
+    const int beam_size,
+    const int topk,
+    const float topp,
+    const int n_head,
+    const int size_per_head,
+    const int num_layer,
+    const int bos_id,
+    const int eos_id,
+    const int64_t max_len,
+    const float beam_search_diversity_rate,
+    const float alpha);

paddlenlp/ops/faster_transformer/transformer/decoding.py

Lines changed: 40 additions & 3 deletions
@@ -619,6 +619,13 @@ def __init__(self,
         )
         load("FasterTransformer", verbose=True)
 
+        size_per_head = d_model / n_head
+        # fuse_qkv can only support size_per_head is one of [32, 64, 128].
+        if size_per_head in [32, 64, 128]:
+            self._fuse_qkv = True
+        else:
+            self._fuse_qkv = False
+
         super(InferTransformerDecoding, self).__init__()
         for arg, value in locals().items():
             if arg not in [
@@ -715,11 +722,41 @@ def __init__(self,
         self.ffn_out_weight = []
         self.ffn_out_bias = []
 
-        for mod in decoder.layers:
+        for i, mod in enumerate(decoder.layers):
             self.slf_ln_weight.append(mod.norm1.weight)
             self.slf_ln_bias.append(mod.norm1.bias)
-            self.slf_q_weight.append(mod.self_attn.q_proj.weight)
-            self.slf_q_bias.append(mod.self_attn.q_proj.bias)
+
+            if self._fuse_qkv:
+                q_weight_shape = mod.self_attn.q_proj.weight.shape
+                k_weight_shape = mod.self_attn.k_proj.weight.shape
+                v_weight_shape = mod.self_attn.v_proj.weight.shape
+
+                q_weights = self.create_parameter(
+                    shape=[
+                        q_weight_shape[0], q_weight_shape[1] + k_weight_shape[1]
+                        + v_weight_shape[1]
+                    ],
+                    dtype="float16" if use_fp16_decoding else "float32")
+                setattr(self, "slf_q_weight_" + str(i), q_weights)
+                self.slf_q_weight.append(
+                    getattr(self, "slf_q_weight_" + str(i)))
+
+                q_bias_shape = mod.self_attn.q_proj.bias.shape
+                k_bias_shape = mod.self_attn.k_proj.bias.shape
+                v_bias_shape = mod.self_attn.v_proj.bias.shape
+
+                q_biases = self.create_parameter(
+                    shape=[
+                        q_bias_shape[0] + k_bias_shape[0] + v_bias_shape[0]
+                    ],
+                    dtype="float16" if use_fp16_decoding else "float32",
+                    is_bias=True)
+                setattr(self, "slf_q_bias_" + str(i), q_biases)
+                self.slf_q_bias.append(getattr(self, "slf_q_bias_" + str(i)))
+            else:
+                self.slf_q_weight.append(mod.self_attn.q_proj.weight)
+                self.slf_q_bias.append(mod.self_attn.q_proj.bias)
+
             self.slf_k_weight.append(mod.self_attn.k_proj.weight)
             self.slf_k_bias.append(mod.self_attn.k_proj.bias)
             self.slf_v_weight.append(mod.self_attn.v_proj.weight)
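Note: for the fused path, each fused parameter is registered as a named attribute (slf_q_weight_<i>, slf_q_bias_<i>) via create_parameter + setattr, so it shows up in the layer's state dict and can be filled from the trained Q/K/V projections at load time. A minimal standalone Paddle sketch of that registration pattern (class name and sizes are made up for illustration):

import paddle

class FusedQKVDemo(paddle.nn.Layer):
    def __init__(self, num_layers=2, d_model=8):
        super(FusedQKVDemo, self).__init__()
        self.slf_q_weight = []
        for i in range(num_layers):
            # Empty fused weight of width 3 * d_model; values are filled later
            # by concatenating the trained q/k/v projection weights.
            w = self.create_parameter(
                shape=[d_model, 3 * d_model], dtype="float32")
            setattr(self, "slf_q_weight_" + str(i), w)
            self.slf_q_weight.append(getattr(self, "slf_q_weight_" + str(i)))

print(list(FusedQKVDemo().state_dict().keys()))
# ['slf_q_weight_0', 'slf_q_weight_1']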
paddlenlp/ops/faster_transformer/transformer/faster_transformer.py

Lines changed: 34 additions & 2 deletions
@@ -272,9 +272,25 @@ def load(self, init_from_params):
         model_dict["decoder.pos_encoder.weight"] = position_encoding_init(
             self.max_length, self.d_model)
 
+        if self.decoding._fuse_qkv:
+            for item in self.state_dict():
+                if "decoder" in item and "self_attn.q_proj" in item:
+                    num_layer = item.split(".")[3]
+                    param_type = item.split(".")[-1]
+
+                    model_dict["decoding.slf_q_" + param_type + "_" +
+                               num_layer] = np.concatenate(
+                                   (model_dict[item], model_dict[
+                                       "transformer.decoder.layers." + num_layer
+                                       + ".self_attn.k_proj." + param_type],
+                                    model_dict["transformer.decoder.layers." +
+                                               num_layer + ".self_attn.v_proj."
+                                               + param_type]),
+                                   axis=-1)
+
         if self.use_fp16_decoding:
             for item in self.state_dict():
-                if "decoder" in item:
+                if "decoder" in item or "decoding.slf" in item:
                     model_dict[item] = np.float16(model_dict[item])
             model_dict["decoding_linear.weight"] = np.float16(model_dict[
                 "decoding_linear.weight"])
@@ -377,9 +393,25 @@ def export_params(self, init_from_params, place):
         model_dict["decoder.pos_encoder.weight"] = position_encoding_init(
             self.max_length, self.d_model)
 
+        if self.decoding._fuse_qkv:
+            for item in self.state_dict():
+                if "decoder" in item and "self_attn.q_proj" in item:
+                    num_layer = item.split(".")[3]
+                    param_type = item.split(".")[-1]
+
+                    model_dict["decoding.slf_q_" + param_type + "_" +
+                               num_layer] = np.concatenate(
+                                   (model_dict[item], model_dict[
+                                       "transformer.decoder.layers." + num_layer
+                                       + ".self_attn.k_proj." + param_type],
+                                    model_dict["transformer.decoder.layers." +
+                                               num_layer + ".self_attn.v_proj."
+                                               + param_type]),
+                                   axis=-1)
+
         if self.use_fp16_decoding:
             for item in self.state_dict():
-                if "decoder" in item:
+                if "decoder" in item or "decoding.slf" in item:
                     model_dict[item] = np.float16(model_dict[item])
             model_dict["decoding_linear.weight"] = np.float16(model_dict[
                 "decoding_linear.weight"])
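Note: at load/export time the fused parameters are produced by plain NumPy concatenation of the trained q/k/v projections along the last axis, and the fp16 cast afterwards also has to cover the new "decoding.slf*" keys. A self-contained sketch of that state-dict rewrite, mirroring the key layout above (toy tensors, illustrative only):

import numpy as np

def fuse_layer_qkv(model_dict, layer, param_type="weight"):
    # Concatenate q/k/v along the last axis into the fused decoding parameter.
    prefix = "transformer.decoder.layers." + layer + ".self_attn."
    fused = np.concatenate(
        (model_dict[prefix + "q_proj." + param_type],
         model_dict[prefix + "k_proj." + param_type],
         model_dict[prefix + "v_proj." + param_type]),
        axis=-1)
    model_dict["decoding.slf_q_" + param_type + "_" + layer] = fused
    return fused

d = 4
toy = {
    "transformer.decoder.layers.0.self_attn." + p + "_proj.weight":
    np.ones((d, d), dtype="float32")
    for p in ("q", "k", "v")
}
assert fuse_layer_qkv(toy, "0").shape == (d, 3 * d)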