Commit bad882c

Merge pull request PaddlePaddle#858 from gongel/ft_decoder_op

[Feat] Add custom op for FasterTransformer decoder

Author: gongenlei
Parents: 15b3dad + 8019078

10 files changed: +1359 −4 lines (only a subset of the changed files is shown below)

paddlenlp/ops/CMakeLists.txt (+7 −2)
@@ -26,6 +26,7 @@ option(USE_TENSORRT "Compile with TensorRT."
 option(WITH_TRANSFORMER "Compile with Transformer" ON)
 option(WITH_GPT "Compile with GPT" OFF)
 option(WITH_UNIFIED "Compile with Unified Transformer" ON)
+option(WITH_DECODER "Compile with Transformer Decoder" ON)
 
 if(NOT WITH_GPU)
   message(FATAL_ERROR "Faster transformer custom op doesn't support CPU. Please add the flag -DWITH_GPU=ON to use GPU. ")
@@ -43,8 +44,12 @@ if(WITH_UNIFIED)
   list(APPEND decoding_op_files fusion_unified_decoding_op.cc fusion_unified_decoding_op.cu)
 endif()
 
-if(NOT WITH_TRANSFORMER AND NOT WITH_GPT)
-  message(FATAL_ERROR "-DWITH_TRANSFORMER=ON or/and -DWITH_GPT=ON must be set to use FasterTransformer. ")
+if(WITH_DECODER)
+  list(APPEND decoder_op_files fusion_decoder_op.cc fusion_decoder_op.cu)
+endif()
+
+if(NOT WITH_TRANSFORMER AND NOT WITH_GPT AND NOT WITH_DECODER)
+  message(FATAL_ERROR "-DWITH_TRANSFORMER=ON or/and -DWITH_GPT=ON or/and -DWITH_DECODER=ON must be set to use FasterTransformer. ")
 endif()
 
 set(CUDA_PATH ${CUDA_TOOLKIT_ROOT_DIR})

paddlenlp/ops/__init__.py (+1 −0)
@@ -14,6 +14,7 @@
 
 from .faster_transformer.transformer.decoding import *
 from .faster_transformer.transformer.faster_transformer import *
+from .faster_transformer.transformer.decoder import *
 from .einsum import *
 from .distributed import *
 from . import optimizer
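
With this re-export in place, FasterDecoder becomes importable directly from paddlenlp.ops. A minimal construction sketch (not part of the commit), using the hyperparameter values from the sample config added below; the decoder_lib path is the samples' default and assumes the op has been built with -DWITH_DECODER=ON:

    import paddle
    from paddlenlp.ops import FasterDecoder  # newly re-exported by this change

    paddle.set_device("gpu")  # the fused decoder op is GPU-only (see CMakeLists above)

    # Values mirror the decoder sample config below.
    model = FasterDecoder(
        src_vocab_size=38512,
        trg_vocab_size=38512,
        max_length=257,  # max_length + 1, as in the sample scripts
        num_encoder_layers=6,
        num_decoder_layers=6,
        n_head=8,
        d_model=512,
        d_inner_hid=2048,
        dropout=0.1,
        weight_sharing=True,
        bos_id=0,
        eos_id=1,
        max_out_len=256,
        decoder_lib="../../build/lib/libdecoder_op.so",
        use_fp16_decoder=False)
    model.eval()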
New file (+39) — path not shown in this view; the sample scripts below load it as ./config/decoder.sample.yaml
@@ -0,0 +1,39 @@
+# Batch size during inference.
+infer_batch_size: 8
+max_out_len: 256
+
+# Hyperparams for the model:
+# The following five vocabulary-related configurations are set
+# automatically according to the passed vocabulary path and special tokens.
+# Size of source word dictionary.
+src_vocab_size: 38512
+# Size of target word dictionary.
+trg_vocab_size: 38512
+# Index for <bos> token
+bos_idx: 0
+# Index for <eos> token
+eos_idx: 1
+# Index for <unk> token
+unk_idx: 2
+# Max length of sequences, deciding the size of the position encoding table.
+max_length: 256
+# The dimension of word embeddings, which is also the last dimension of
+# the input and output of multi-head attention, the position-wise
+# feed-forward networks, and the encoder and decoder.
+d_model: 512
+# Size of the hidden layer in the position-wise feed-forward networks.
+d_inner_hid: 2048
+# Number of heads used in multi-head attention.
+n_head: 8
+# Number of sub-layers to be stacked in the encoder.
+num_encoder_layers: 6
+# Number of sub-layers to be stacked in the decoder.
+num_decoder_layers: 6
+# Dropout rate.
+dropout: 0.1
+# The flag indicating whether to share embedding and softmax weights.
+# Source and target vocabularies must be the same for weight sharing.
+weight_sharing: True
+
+# Path of trained parameters, used for prediction.
+init_from_params: base_trained_models/step_final
New file (+126) — path not shown in this view; a sample script that benchmarks the fused decoder layer directly
@@ -0,0 +1,126 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import os
+from attrdict import AttrDict
+import argparse
+import time
+
+import yaml
+from pprint import pprint
+import paddle
+
+from paddlenlp.ops import FasterDecoder
+from paddlenlp.utils.log import logger
+
+
+def parse_args():
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--config",
+        default="./config/decoder.sample.yaml",
+        type=str,
+        help="Path of the config file. ")
+    parser.add_argument(
+        "--decoder_lib",
+        default="../../build/lib/libdecoder_op.so",
+        type=str,
+        help="Path of libdecoder_op.so. ")
+    parser.add_argument(
+        "--use_fp16_decoder",
+        action="store_true",
+        help="Whether to use the fp16 decoder to predict. ")
+    args = parser.parse_args()
+    return args
+
+
+def do_predict(args):
+    place = "gpu"
+    paddle.set_device(place)
+
+    # Define the model.
+    transformer = FasterDecoder(
+        src_vocab_size=args.src_vocab_size,
+        trg_vocab_size=args.trg_vocab_size,
+        max_length=args.max_length + 1,
+        num_encoder_layers=args.num_encoder_layers,
+        num_decoder_layers=args.num_decoder_layers,
+        n_head=args.n_head,
+        d_model=args.d_model,
+        d_inner_hid=args.d_inner_hid,
+        dropout=args.dropout,
+        weight_sharing=args.weight_sharing,
+        bos_id=args.bos_idx,
+        eos_id=args.eos_idx,
+        max_out_len=args.max_out_len,
+        decoder_lib=args.decoder_lib,
+        use_fp16_decoder=args.use_fp16_decoder)
+
+    # Load the checkpoint.
+    transformer.load(
+        os.path.join(args.init_from_params, "transformer.pdparams"))
+    # Set evaluate mode.
+    transformer.eval()
+
+    # Generate input data randomly.
+    dec_input = paddle.randn(
+        shape=[args.infer_batch_size, 1, args.d_model], dtype='float32')
+    enc_output = paddle.randn(
+        shape=[args.infer_batch_size, args.max_length, args.d_model],
+        dtype='float32')
+    mem_seq_lens = paddle.full(
+        shape=[args.infer_batch_size, 1],
+        fill_value=args.max_length,
+        dtype='int32')
+    dtype = 'float32'
+    if args.use_fp16_decoder:
+        dtype = 'float16'
+        dec_input = paddle.cast(dec_input, dtype=dtype)
+        enc_output = paddle.cast(enc_output, dtype=dtype)
+    self_cache = paddle.zeros(
+        shape=[
+            args.num_decoder_layers, 2, 0, args.infer_batch_size, args.d_model
+        ],
+        dtype=dtype)
+    mem_cache = paddle.zeros(
+        shape=[
+            args.num_decoder_layers, 2, args.infer_batch_size, args.max_length,
+            args.d_model
+        ],
+        dtype=dtype)
+
+    with paddle.no_grad():
+        for i in range(100):
+            # Start timing after 50 warmup iterations.
+            if i == 50:
+                start = time.time()
+            dec_output, self_cache, mem_cache = transformer.decoder(
+                from_tensor=dec_input,
+                memory_tensor=enc_output,
+                mem_seq_len=mem_seq_lens,
+                self_cache=self_cache,
+                mem_cache=mem_cache)
+        logger.info("Average test time for decoder is %f ms" % (
+            (time.time() - start) / 50 * 1000))
+
+
+if __name__ == "__main__":
+    ARGS = parse_args()
+    yaml_file = ARGS.config
+    with open(yaml_file, 'rt') as f:
+        args = AttrDict(yaml.safe_load(f))
+    args.decoder_lib = ARGS.decoder_lib
+    args.use_fp16_decoder = ARGS.use_fp16_decoder
+    pprint(args)
+
+    do_predict(args)
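
A hedged reading of the call pattern above, inferred from the tensor shapes in this sample and not documented in the diff itself: self_cache holds per-layer self-attention keys and values, and its third (time) dimension starts at 0 and appears to grow by one entry per step, while mem_cache holds fixed-size keys/values projected from the encoder output. A sketch of multi-step use, feeding the caches back in; producing the next dec_input from dec_output (argmax plus embedding lookup) is model-specific and omitted here:

    import paddle

    # Assumes transformer, dec_input, enc_output, mem_seq_lens, self_cache and
    # mem_cache are set up exactly as in the sample above.
    with paddle.no_grad():
        for step in range(max_out_len):
            dec_output, self_cache, mem_cache = transformer.decoder(
                from_tensor=dec_input,     # [batch, 1, d_model]: one step at a time
                memory_tensor=enc_output,  # [batch, src_len, d_model]
                mem_seq_len=mem_seq_lens,  # [batch, 1] true source lengths
                self_cache=self_cache,     # grows along dim 2 each step
                mem_cache=mem_cache)       # filled on the first step, reused after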
New file (+100) — path not shown in this view; a sample script that benchmarks the full FasterDecoder forward pass
@@ -0,0 +1,100 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import os
+from attrdict import AttrDict
+import argparse
+import time
+
+import yaml
+from pprint import pprint
+import paddle
+from paddlenlp.ops import FasterDecoder
+from paddlenlp.utils.log import logger
+
+
+def parse_args():
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--config",
+        default="./config/decoder.sample.yaml",
+        type=str,
+        help="Path of the config file. ")
+    parser.add_argument(
+        "--decoder_lib",
+        default="../../build/lib/libdecoder_op.so",
+        type=str,
+        help="Path of libdecoder_op.so. ")
+    parser.add_argument(
+        "--use_fp16_decoder",
+        action="store_true",
+        help="Whether to use the fp16 decoder to predict. ")
+    args = parser.parse_args()
+    return args
+
+
+def do_predict(args):
+    place = "gpu"
+    paddle.set_device(place)
+
+    # Define the model.
+    transformer = FasterDecoder(
+        src_vocab_size=args.src_vocab_size,
+        trg_vocab_size=args.trg_vocab_size,
+        max_length=args.max_length + 1,
+        num_encoder_layers=args.num_encoder_layers,
+        num_decoder_layers=args.num_decoder_layers,
+        n_head=args.n_head,
+        d_model=args.d_model,
+        d_inner_hid=args.d_inner_hid,
+        dropout=args.dropout,
+        weight_sharing=args.weight_sharing,
+        bos_id=args.bos_idx,
+        eos_id=args.eos_idx,
+        max_out_len=args.max_out_len,
+        decoder_lib=args.decoder_lib,
+        use_fp16_decoder=args.use_fp16_decoder)
+
+    # Load the checkpoint.
+    transformer.load(
+        os.path.join(args.init_from_params, "transformer.pdparams"))
+    # Set evaluate mode.
+    transformer.eval()
+
+    # Generate src_word randomly.
+    src_word = paddle.randint(
+        0,
+        args.src_vocab_size,
+        shape=[args.infer_batch_size, args.max_length],
+        dtype='int64')
+
+    with paddle.no_grad():
+        for i in range(100):
+            # Start timing after 50 warmup iterations.
+            if i == 50:
+                start = time.time()
+            finished_seq, finished_scores = transformer(src_word=src_word)
+        logger.info("Average test time for decoder is %f ms" % (
+            (time.time() - start) / 50 * 1000))
+
+
+if __name__ == "__main__":
+    ARGS = parse_args()
+    yaml_file = ARGS.config
+    with open(yaml_file, 'rt') as f:
+        args = AttrDict(yaml.safe_load(f))
+    args.decoder_lib = ARGS.decoder_lib
+    args.use_fp16_decoder = ARGS.use_fp16_decoder
+    pprint(args)
+
+    do_predict(args)

paddlenlp/ops/faster_transformer/src/CMakeLists.txt (+6 −2)
@@ -145,7 +145,7 @@ if(ON_INFER)
     set(DEPS ${DEPS} shlwapi.lib)
   endif(NOT WIN32)
 
-  cuda_add_library(pd_infer_custom_op ${decoding_op_files} SHARED)
+  cuda_add_library(pd_infer_custom_op ${decoding_op_files} ${decoder_op_files} SHARED)
   add_dependencies(pd_infer_custom_op extern_${THIRD_PARTY_NAME})
   string(REPLACE "/" ";" DEMO_PATH ${DEMO})
 
@@ -269,4 +269,8 @@ else(ON_INFER)
   add_library(decoding_op SHARED ${decoding_op_files})
   add_dependencies(decoding_op extern_${THIRD_PARTY_NAME} boost)
   target_link_libraries(decoding_op PRIVATE -lcublas -lcudart ${lib_link} ${ft_lib_link})
-endif()
+
+  add_library(decoder_op SHARED ${decoder_op_files})
+  add_dependencies(decoder_op extern_${THIRD_PARTY_NAME} boost)
+  target_link_libraries(decoder_op PRIVATE -lcublas -lcudart -ldecoder ${lib_link})
+endif()
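
For orientation (not part of the diff): in the non-ON_INFER branch, the new decoder_op target links FasterTransformer's decoder library (-ldecoder) and produces the libdecoder_op.so that the sample scripts expect. A minimal sketch, assuming the samples' default out-of-tree build layout, of locating and forwarding that artifact:

    import os

    # Default path used by the sample scripts above; adjust to your build tree.
    decoder_lib = os.path.abspath("../../build/lib/libdecoder_op.so")
    if not os.path.isfile(decoder_lib):
        raise FileNotFoundError(
            "libdecoder_op.so not found; build with -DWITH_GPU=ON -DWITH_DECODER=ON first")

    # The samples forward this path via FasterDecoder(decoder_lib=...), which
    # registers the fused decoder custom op before running.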
