
Commit 59db5ef

add tensorrt ut and refine interface.
test=release/1.0.0
1 parent: 644bad1

15 files changed: +180 -27 lines

paddle/fluid/inference/analysis/analyzer_tester.cc

Lines changed: 7 additions & 3 deletions
@@ -37,12 +37,16 @@ TEST(Analyzer, analysis_without_tensorrt) {
 TEST(Analyzer, analysis_with_tensorrt) {
   FLAGS_IA_enable_tensorrt_subgraph_engine = true;
   Argument argument;
+  argument.Set<int>("minimum_subgraph_size", new int(0));
+  argument.Set<int>("max_batch_size", new int(3));
+  argument.Set<int>("workspace_size", new int(1 << 20));
+  argument.Set<std::string>("precision_mode", new std::string("FP32"));
   argument.fluid_model_dir.reset(new std::string(FLAGS_inference_model_dir));
   Analyzer analyser;
   analyser.Run(&argument);
 }
 
-void TestWord2vecPrediction(const std::string &model_path) {
+void TestWord2vecPrediction(const std::string& model_path) {
   NativeConfig config;
   config.model_dir = model_path;
   config.use_gpu = false;
@@ -73,8 +77,8 @@ void TestWord2vecPrediction(const std::string &model_path) {
   // The outputs' buffers are in CPU memory.
   for (size_t i = 0; i < std::min(5UL, num_elements); i++) {
     LOG(INFO) << "data: "
-              << static_cast<float *>(outputs.front().data.data())[i];
-    PADDLE_ENFORCE(static_cast<float *>(outputs.front().data.data())[i],
+              << static_cast<float*>(outputs.front().data.data())[i];
+    PADDLE_ENFORCE(static_cast<float*>(outputs.front().data.data())[i],
                    result[i]);
   }
 }
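
Note: the test above shows the pattern this commit standardizes on. TensorRT options travel through the analysis Argument's key-value store instead of global flags. A minimal standalone sketch of that Set/Get contract, with a hypothetical KVStore standing in for the real analysis::Argument:

// Standalone sketch of the Set/Get contract used in the tests, assuming the
// real analysis::Argument stores type-erased, owned pointers under string
// keys. KVStore is an illustrative stand-in, not the actual class.
#include <cassert>
#include <map>
#include <memory>
#include <string>

class KVStore {
 public:
  template <typename T>
  void Set(const std::string &key, T *value) {
    // shared_ptr<void> captures T's deleter at construction, so the stored
    // value is destroyed correctly even though its static type is erased.
    fields_[key] = std::shared_ptr<void>(value);
  }

  template <typename T>
  T &Get(const std::string &key) const {
    return *static_cast<T *>(fields_.at(key).get());
  }

 private:
  std::map<std::string, std::shared_ptr<void>> fields_;
};

int main() {
  KVStore argument;
  argument.Set<int>("max_batch_size", new int(3));
  argument.Set<std::string>("precision_mode", new std::string("FP32"));
  assert(argument.Get<int>("max_batch_size") == 3);
  assert(argument.Get<std::string>("precision_mode") == "FP32");
  return 0;
}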

paddle/fluid/inference/analysis/data_flow_graph_to_fluid_pass.cc

Lines changed: 7 additions & 2 deletions
@@ -97,8 +97,10 @@ void DataFlowGraphToFluidPass::AddFluidOp(Node *node) {
   }
 }
 
-void CreateTrtEngineOp(Node *node, const DataFlowGraph &graph,
+void CreateTrtEngineOp(Node *node, Argument *argument,
                        framework::proto::BlockDesc *block) {
+  PADDLE_ENFORCE(argument->main_dfg.get());
+  const DataFlowGraph &graph = *(argument->main_dfg);
   static int counter{0};
   PADDLE_ENFORCE(node->IsFunctionBlock());
   framework::OpDesc desc;
@@ -204,7 +206,10 @@ void CreateTrtEngineOp(Node *node, const DataFlowGraph &graph,
 
   PADDLE_ENFORCE(!block->vars().empty(), "the block has no var-desc");
   // Set attrs
+
   SetAttr(desc.Proto(), "subgraph", block->SerializeAsString());
+  SetAttr(desc.Proto(), "max_batch_size", argument->Get<int>("max_batch_size"));
+  SetAttr(desc.Proto(), "workspace_size", argument->Get<int>("workspace_size"));
   SetAttr(desc.Proto(), "engine_uniq_key", "trt-" + std::to_string(counter++));
   SetAttr(desc.Proto(), "parameters", ExtractParameters(graph.nodes.nodes()));
   SetAttr(desc.Proto(), "output_name_mapping", output_mapping);
@@ -248,7 +253,7 @@ void DataFlowGraphToFluidPass::AddEngineOp(Node *node) {
   *block_desc.Proto()->mutable_vars() =
       argument_->origin_program_desc->blocks(0).vars();
   PADDLE_ENFORCE(!block_desc.Proto()->vars().empty());
-  CreateTrtEngineOp(node, *argument_->main_dfg, block_desc.Proto());
+  CreateTrtEngineOp(node, argument_, block_desc.Proto());
   auto *main_block = desc_->mutable_blocks(framework::kRootBlockIndex);
   auto *op = main_block->add_ops();
   PADDLE_ENFORCE(!node->pb_msg().empty(), "failed to set desc for block");

paddle/fluid/inference/analysis/subgraph_splitter.cc

Lines changed: 2 additions & 0 deletions
@@ -309,6 +309,8 @@ void SubGraphFuse::operator()() { ReplaceNodesWithSubGraphs(); }
 void SubGraphFuse::ReplaceNodesWithSubGraphs() {
   auto subgraphs = SubGraphSplitter(graph_, node_inside_subgraph_teller_)();
   for (auto &subgraph : subgraphs) {
+    if (subgraph.size() <= argument_->Get<int>("minimum_subgraph_size"))
+      continue;
     std::unordered_set<Node *> subgraph_uniq(subgraph.begin(), subgraph.end());
     // replace this sub-graph with the first node. Two steps: 1. Create a Block
     // Node that contains this subgraph 2. Mark the nodes inside the sub-graph
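
Note that the new guard compares with <=, so a subgraph must be strictly larger than minimum_subgraph_size to be fused; with the default of 3 from MixedRTConfig, a 3-node subgraph is still skipped. A toy illustration of that decision, with made-up sizes:

// Toy illustration of the `size <= minimum_subgraph_size` guard above;
// the candidate sizes are invented for the example.
#include <iostream>
#include <vector>

int main() {
  const int minimum_subgraph_size = 3;
  const std::vector<int> candidate_sizes = {1, 3, 4, 9};
  for (int size : candidate_sizes) {
    if (size <= minimum_subgraph_size) {
      std::cout << "skip " << size << "-node subgraph\n";
      continue;  // too small to amortize the TRT engine overhead
    }
    std::cout << "fuse " << size << "-node subgraph into a TRT engine op\n";
  }
  return 0;
}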

paddle/fluid/inference/analysis/subgraph_splitter.h

Lines changed: 7 additions & 2 deletions
@@ -20,6 +20,7 @@ limitations under the License. */
 
 #include <vector>
 
+#include "paddle/fluid/inference/analysis/argument.h"
 #include "paddle/fluid/inference/analysis/data_flow_graph.h"
 #include "paddle/fluid/inference/analysis/node.h"
 
@@ -63,8 +64,11 @@ class SubGraphFuse {
  public:
   using NodeInsideSubgraphTeller = SubGraphSplitter::NodeInsideSubgraphTeller;
 
-  SubGraphFuse(DataFlowGraph *graph, const NodeInsideSubgraphTeller &teller)
-      : graph_(graph), node_inside_subgraph_teller_(teller) {}
+  SubGraphFuse(DataFlowGraph *graph, const NodeInsideSubgraphTeller &teller,
+               Argument *argument)
+      : graph_(graph),
+        node_inside_subgraph_teller_(teller),
+        argument_(argument) {}
 
   // The main method which run all the logic.
   void operator()();
@@ -76,6 +80,7 @@ class SubGraphFuse {
  private:
   DataFlowGraph *graph_;
   NodeInsideSubgraphTeller node_inside_subgraph_teller_;
+  Argument *argument_;
 };
 
 }  // namespace analysis

paddle/fluid/inference/analysis/subgraph_splitter_tester.cc

Lines changed: 3 additions & 1 deletion
@@ -66,10 +66,12 @@ TEST(SubGraphSplitter, Split) {
 TEST(SubGraphSplitter, Fuse) {
   auto desc = LoadProgramDesc(FLAGS_inference_model_dir + "/__model__");
   auto dfg = ProgramDescToDFG(desc);
+  Argument argument;
+  argument.Set<int>("minimum_subgraph_size", new int(3));
 
   size_t count0 = dfg.nodes.size();
 
-  SubGraphFuse fuse(&dfg, teller);
+  SubGraphFuse fuse(&dfg, teller, &argument);
   fuse();
 
   int count1 = 0;

paddle/fluid/inference/analysis/tensorrt_subgraph_pass.cc

Lines changed: 1 addition & 1 deletion
@@ -24,7 +24,7 @@ TensorRTSubGraphPass::TensorRTSubGraphPass(
     : node_inside_subgraph_teller_(teller) {}
 
 void TensorRTSubGraphPass::Run(DataFlowGraph *graph) {
-  SubGraphFuse(graph, node_inside_subgraph_teller_)();
+  SubGraphFuse(graph, node_inside_subgraph_teller_, argument_)();
   VLOG(4) << "debug info "
           << graph->HumanReadableInfo(false /*show_values*/,
                                       true /*show_functions*/);

paddle/fluid/inference/analysis/tensorrt_subgraph_pass.h

Lines changed: 5 additions & 1 deletion
@@ -33,7 +33,10 @@ class TensorRTSubGraphPass : public DataFlowGraphPass {
 
   explicit TensorRTSubGraphPass(const NodeInsideSubgraphTeller& teller);
 
-  bool Initialize(Argument* argument) override { return true; }
+  bool Initialize(Argument* argument) override {
+    argument_ = argument;
+    return true;
+  }
 
   // This class get a sub-graph as input and determine whether to transform this
   // sub-graph into TensorRT.
@@ -46,6 +49,7 @@ class TensorRTSubGraphPass : public DataFlowGraphPass {
 
  private:
   NodeInsideSubgraphTeller node_inside_subgraph_teller_;
+  Argument* argument_;
 };
 
 }  // namespace analysis

paddle/fluid/inference/analysis/tensorrt_subgraph_pass_tester.cc

Lines changed: 4 additions & 0 deletions
@@ -36,6 +36,10 @@ TEST(TensorRTSubGraphPass, main) {
   };
 
   Argument argument(FLAGS_inference_model_dir);
+  argument.Set<int>("minimum_subgraph_size", new int(0));
+  argument.Set<int>("max_batch_size", new int(3));
+  argument.Set<int>("workspace_size", new int(1 << 20));
+  argument.Set<std::string>("precision_mode", new std::string("FP32"));
 
   DFG_GraphvizDrawPass::Config config{FLAGS_dot_dir, "origin"};
   DFG_GraphvizDrawPass::Config config1{FLAGS_dot_dir, "fusion"};

paddle/fluid/inference/api/api_tensorrt_subgraph_engine.cc

Lines changed: 8 additions & 2 deletions
@@ -35,8 +35,6 @@ class TensorRTSubgraphPredictor : public NativePaddlePredictor {
   bool Init(const std::shared_ptr<framework::Scope>& parent_scope) {
     FLAGS_IA_enable_tensorrt_subgraph_engine = true;
     VLOG(3) << "Predictor::init()";
-    FLAGS_tensorrt_max_batch_size = config_.max_batch_size;
-    FLAGS_tensorrt_workspace_size = config_.workspace_size;
     if (config_.use_gpu) {
       place_ = paddle::platform::CUDAPlace(config_.device);
     } else {
@@ -92,6 +90,14 @@ class TensorRTSubgraphPredictor : public NativePaddlePredictor {
   void OptimizeInferenceProgram() {
     // Analyze inference_program
     Argument argument;
+
+    argument.Set<int>("minimum_subgraph_size",
+                      new int(config_.minimum_subgraph_size));
+    argument.Set<int>("max_batch_size", new int(config_.max_batch_size));
+    argument.Set<int>("workspace_size", new int(config_.workspace_size));
+    argument.Set<std::string>("precision_mode",
+                              new std::string(config_.precision_mode));
+
     if (!config_.model_dir.empty()) {
       argument.fluid_model_dir.reset(new std::string(config_.model_dir));
     } else {

paddle/fluid/inference/api/paddle_inference_api.h

Lines changed: 8 additions & 0 deletions
@@ -194,6 +194,14 @@ struct MixedRTConfig : public NativeConfig {
   // For workspace_size, refer it from here:
   // https://docs.nvidia.com/deeplearning/sdk/tensorrt-developer-guide/index.html#troubleshooting
   int workspace_size{1 << 30};
+  // Ops that can be converted to TensorRT layers are aggregated into
+  // subgraphs for TRT execution. This field sets the minimum number of
+  // nodes a subgraph must have to be fused; the default is 3.
+  int minimum_subgraph_size = 3;
+  // Reserved configuration.
+  // Only "FP32" is supported now; "FP16" and "INT8" will be added later.
+  std::string precision_mode = "FP32";
 };
 
 // NOTE WIP, not stable yet.
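
With the two new fields, end-user configuration looks roughly like the sketch below. The model path is a placeholder; the factory call mirrors the one in the new trt_models_tester.cc further down.

// Sketch of configuring the mixed TensorRT predictor with the new fields.
// "/path/to/model" is a hypothetical path; the rest follows the API above.
#include "paddle/fluid/inference/api/paddle_inference_api.h"

int main() {
  paddle::contrib::MixedRTConfig config;
  config.model_dir = "/path/to/model";
  config.use_gpu = true;
  config.device = 0;
  config.fraction_of_gpu_memory = 0.2;
  config.max_batch_size = 4;
  config.workspace_size = 1 << 30;
  config.minimum_subgraph_size = 3;  // fuse only subgraphs larger than 3 nodes
  config.precision_mode = "FP32";    // "FP16"/"INT8" are not supported yet
  auto predictor = paddle::CreatePaddlePredictor<
      paddle::contrib::MixedRTConfig,
      paddle::PaddleEngineKind::kAutoMixedTensorRT>(config);
  return 0;
}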

paddle/fluid/inference/tests/api/CMakeLists.txt

Lines changed: 10 additions & 0 deletions
@@ -85,3 +85,13 @@ if (WITH_ANAKIN AND WITH_MKL) # only needed in CI
         DEPS inference_anakin_api_shared dynload_cuda SERIAL)
   endif()
 endif()
+
+if(WITH_GPU AND TENSORRT_FOUND)
+  set(TRT_MODEL_INSTALL_DIR "${INFERENCE_DEMO_INSTALL_DIR}/trt")
+  if (NOT EXISTS ${TRT_MODEL_INSTALL_DIR})
+    inference_download_and_uncompress(${TRT_MODEL_INSTALL_DIR} ${INFERENCE_URL}/tensorrt_test "trt_test_models.tar.gz")
+  endif()
+  cc_test(test_trt_models SRCS trt_models_tester.cc
+          ARGS --dirname=${TRT_MODEL_INSTALL_DIR}/trt_test_models
+          DEPS paddle_inference_tensorrt_subgraph_engine)
+endif()
paddle/fluid/inference/tests/api/trt_models_tester.cc (new file)

Lines changed: 106 additions & 0 deletions
@@ -0,0 +1,106 @@
+// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <gflags/gflags.h>
+#include <glog/logging.h>
+#include <gtest/gtest.h>
+#include "paddle/fluid/inference/analysis/analyzer.h"
+#include "paddle/fluid/inference/api/paddle_inference_api.h"
+
+namespace paddle {
+using paddle::contrib::MixedRTConfig;
+
+DEFINE_string(dirname, "", "Directory of the inference model.");
+
+NativeConfig GetConfigNative() {
+  NativeConfig config;
+  config.model_dir = FLAGS_dirname;
+  // LOG(INFO) << "dirname " << config.model_dir;
+  config.fraction_of_gpu_memory = 0.45;
+  config.use_gpu = true;
+  config.device = 0;
+  return config;
+}
+
+MixedRTConfig GetConfigTRT() {
+  MixedRTConfig config;
+  config.model_dir = FLAGS_dirname;
+  config.use_gpu = true;
+  config.fraction_of_gpu_memory = 0.2;
+  config.device = 0;
+  config.max_batch_size = 3;
+  return config;
+}
+
+void CompareTensorRTWithFluid(int batch_size, std::string model_dirname) {
+  NativeConfig config0 = GetConfigNative();
+  config0.model_dir = model_dirname;
+
+  MixedRTConfig config1 = GetConfigTRT();
+  config1.model_dir = model_dirname;
+  config1.max_batch_size = batch_size;
+
+  auto predictor0 =
+      CreatePaddlePredictor<NativeConfig, PaddleEngineKind::kNative>(config0);
+  auto predictor1 =
+      CreatePaddlePredictor<MixedRTConfig,
+                            PaddleEngineKind::kAutoMixedTensorRT>(config1);
+  // Prepare inputs
+  int height = 224;
+  int width = 224;
+  float *data = new float[batch_size * 3 * height * width];
+  memset(data, 0, sizeof(float) * (batch_size * 3 * height * width));
+  data[0] = 1.0f;
+
+  // Prepare inputs
+  PaddleTensor tensor;
+  tensor.name = "input_0";
+  tensor.shape = std::vector<int>({batch_size, 3, height, width});
+  tensor.data = PaddleBuf(static_cast<void *>(data),
+                          sizeof(float) * (batch_size * 3 * height * width));
+  tensor.dtype = PaddleDType::FLOAT32;
+  std::vector<PaddleTensor> paddle_tensor_feeds(1, tensor);
+
+  // Prepare outputs
+  std::vector<PaddleTensor> outputs0;
+  std::vector<PaddleTensor> outputs1;
+  CHECK(predictor0->Run(paddle_tensor_feeds, &outputs0));
+
+  CHECK(predictor1->Run(paddle_tensor_feeds, &outputs1, batch_size));
+
+  // Get output.
+  ASSERT_EQ(outputs0.size(), 1UL);
+  ASSERT_EQ(outputs1.size(), 1UL);
+
+  const size_t num_elements = outputs0.front().data.length() / sizeof(float);
+  const size_t num_elements1 = outputs1.front().data.length() / sizeof(float);
+  EXPECT_EQ(num_elements, num_elements1);
+
+  auto *data0 = static_cast<float *>(outputs0.front().data.data());
+  auto *data1 = static_cast<float *>(outputs1.front().data.data());
+
+  ASSERT_GT(num_elements, 0UL);
+  for (size_t i = 0; i < std::min(num_elements, num_elements1); i++) {
+    EXPECT_NEAR(data0[i], data1[i], 1e-3);
+  }
+}
+
+TEST(trt_models_test, main) {
+  std::vector<std::string> infer_models = {"mobilenet", "resnet50",
+                                           "resnext50"};
+  for (auto &model_dir : infer_models) {
+    CompareTensorRTWithFluid(1, FLAGS_dirname + "/" + model_dir);
+  }
+}
+}  // namespace paddle
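
A usage note on the tester: the input is allocated with new float[...] and handed to PaddleBuf(void*, size_t), which, judging from this usage, wraps caller-owned memory rather than taking ownership, so the buffer is never freed in the test. In user code, automatic storage does the same job; a sketch under that ownership assumption:

// Sketch of preparing the same input with caller-owned automatic storage,
// assuming PaddleBuf(void*, size_t) only wraps the memory. The buffer must
// outlive the predictor->Run() call that consumes the tensor.
#include <vector>
#include "paddle/fluid/inference/api/paddle_inference_api.h"

paddle::PaddleTensor MakeInput(std::vector<float> *buffer, int batch_size,
                               int height, int width) {
  buffer->assign(static_cast<size_t>(batch_size) * 3 * height * width, 0.f);
  (*buffer)[0] = 1.0f;

  paddle::PaddleTensor tensor;
  tensor.name = "input_0";
  tensor.shape = {batch_size, 3, height, width};
  tensor.data =
      paddle::PaddleBuf(buffer->data(), buffer->size() * sizeof(float));
  tensor.dtype = paddle::PaddleDType::FLOAT32;
  return tensor;
}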

paddle/fluid/operators/tensorrt_engine_op.cc

Lines changed: 2 additions & 2 deletions
@@ -22,8 +22,6 @@
 namespace paddle {
 
 DEFINE_int32(tensorrt_engine_batch_size, 1, "the batch_size of TensorRT");
-DEFINE_int32(tensorrt_max_batch_size, 1, "TensorRT maximum batch size");
-DEFINE_int32(tensorrt_workspace_size, 16 << 20, "TensorRT workspace size");
 
 namespace operators {
 
@@ -34,6 +32,8 @@ class TensorRTEngineOpMaker : public framework::OpProtoAndCheckerMaker {
     AddOutput("Ys", "A list of outputs").AsDuplicable();
     AddAttr<std::string>("subgraph", "the subgraph.");
     AddAttr<std::string>("engine_uniq_key", "unique key for the TRT engine.");
+    AddAttr<int>("max_batch_size", "the maximum batch size.");
+    AddAttr<int>("workspace_size", "the workspace size.");
     AddComment("TensorRT engine operator.");
   }
 };

paddle/fluid/operators/tensorrt_engine_op.h

Lines changed: 6 additions & 7 deletions
@@ -28,8 +28,6 @@
 namespace paddle {
 
 DECLARE_int32(tensorrt_engine_batch_size);
-DECLARE_int32(tensorrt_max_batch_size);
-DECLARE_int32(tensorrt_workspace_size);
 
 namespace operators {
 
@@ -92,14 +90,14 @@ class TensorRTEngineKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& context) const override {
     auto engine_name = context.Attr<std::string>("engine_uniq_key");
+    int max_batch_size = context.Attr<int>("max_batch_size");
     if (!Singleton<TRT_EngineManager>::Global().HasEngine(engine_name)) {
       Prepare(context);
     }
     auto* engine = Singleton<TRT_EngineManager>::Global().Get(engine_name);
     auto input_names = context.op().Inputs("Xs");
     PADDLE_ENFORCE(!input_names.empty(), "should pass more than one inputs");
-    PADDLE_ENFORCE_LE(FLAGS_tensorrt_engine_batch_size,
-                      FLAGS_tensorrt_max_batch_size);
+    PADDLE_ENFORCE_LE(FLAGS_tensorrt_engine_batch_size, max_batch_size);
 
     std::vector<std::string> output_maps =
         context.Attr<std::vector<std::string>>("output_name_mapping");
@@ -173,8 +171,9 @@ class TensorRTEngineKernel : public framework::OpKernel<T> {
     // Get the ProgramDesc and pass to convert.
     framework::proto::BlockDesc block_desc;
     block_desc.ParseFromString(context.Attr<std::string>("subgraph"));
-    int max_batch = FLAGS_tensorrt_max_batch_size;
-    auto max_workspace = FLAGS_tensorrt_workspace_size;
+    int max_batch_size = context.Attr<int>("max_batch_size");
+    int workspace_size = context.Attr<int>("workspace_size");
+
     auto params = context.Attr<std::vector<std::string>>("parameters");
     std::unordered_set<std::string> parameters;
     for (const auto& param : params) {
@@ -186,7 +185,7 @@ class TensorRTEngineKernel : public framework::OpKernel<T> {
 
     // TODO(Superjomn) replace this with a different stream
     auto* engine = Singleton<TRT_EngineManager>::Global().Create(
-        max_batch, max_workspace, nullptr /*engine hold its own stream*/,
+        max_batch_size, workspace_size, nullptr /*engine hold its own stream*/,
         context.Attr<std::string>("engine_uniq_key"),
         boost::get<platform::CUDAPlace>(context.GetPlace()).device);
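
The net effect of the last two files: engine limits move from process-wide gflags to per-op attributes, so two engine ops in one program can carry different limits, which a single global flag could not express. A toy stand-in making that concrete (FakeEngineOp is hypothetical, not the real TensorRTEngineOp):

// Toy illustration of per-op attributes versus process-wide flags.
#include <cassert>
#include <map>
#include <string>

struct FakeEngineOp {  // hypothetical stand-in for TensorRTEngineOp
  std::map<std::string, int> int_attrs;
  int Attr(const std::string &name) const { return int_attrs.at(name); }
};

int main() {
  FakeEngineOp op_a{{{"max_batch_size", 3}, {"workspace_size", 1 << 20}}};
  FakeEngineOp op_b{{{"max_batch_size", 16}, {"workspace_size", 1 << 30}}};
  // Under the old DEFINE_int32 flags both ops would have shared one value.
  assert(op_a.Attr("max_batch_size") != op_b.Attr("max_batch_size"));
  assert(op_a.Attr("workspace_size") < op_b.Attr("workspace_size"));
  return 0;
}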
