Skip to content

Commit 55dd3f2

Browse files
[Paddle-TRT] revert to paddle op when matrix_multiply cannot enter into paddle-trt (#54278)
* revert to paddle op when matrix_multiply cannot enter into paddle-trt
1 parent d6c8ca8 commit 55dd3f2

File tree

2 files changed

+38
-33
lines changed

2 files changed

+38
-33
lines changed

paddle/fluid/framework/ir/trt_map_ops_to_matrix_multiply_pass.cc

+24-33
Original file line numberDiff line numberDiff line change
@@ -57,60 +57,51 @@ void TrtMapOpsToMatrixMultiplyPass::ApplyImpl(ir::Graph* graph) const {
5757
VLOG(4) << "trt map some ops to matrix_multiply";
5858
GET_IR_NODE_FROM_SUBGRAPH(ops, ops, mul_matmul_matmul_v2);
5959
GET_IR_NODE_FROM_SUBGRAPH(ops_out, ops_out, mul_matmul_matmul_v2);
60-
OpDesc desc(ops->Op()->Block());
61-
desc.SetType("matrix_multiply");
62-
desc.SetInput("X", {ops->Op()->Input("X").front()});
63-
desc.SetInput("Y", {ops->Op()->Input("Y").front()});
64-
desc.SetOutput("Out", {ops_out->Name()});
65-
66-
if (ops->Op()->HasAttr("transpose_X") || ops->Op()->HasAttr("trans_x")) {
67-
if (ops->Op()->HasAttr("transpose_X")) {
68-
desc.SetAttr("transpose_x", ops->Op()->GetAttr("transpose_X"));
60+
auto op_desc = ops->Op();
61+
op_desc->SetAttr("original_type", op_desc->Type());
62+
op_desc->SetType("matrix_multiply");
63+
ops->RenameOp("matrix_multiply");
64+
65+
// OpDesc original_desc(*(ops->Op()));
66+
67+
if (op_desc->HasAttr("transpose_X") || op_desc->HasAttr("trans_x")) {
68+
if (op_desc->HasAttr("transpose_X")) {
69+
op_desc->SetAttr("transpose_x", op_desc->GetAttr("transpose_X"));
6970
} else {
70-
desc.SetAttr("transpose_x", ops->Op()->GetAttr("trans_x"));
71+
op_desc->SetAttr("transpose_x", op_desc->GetAttr("trans_x"));
7172
}
7273
} else {
73-
desc.SetAttr("transpose_x", false);
74+
op_desc->SetAttr("transpose_x", false);
7475
}
7576

76-
if (ops->Op()->HasAttr("transpose_Y") || ops->Op()->HasAttr("trans_y")) {
77-
if (ops->Op()->HasAttr("transpose_Y")) {
78-
desc.SetAttr("transpose_y", ops->Op()->GetAttr("transpose_Y"));
77+
if (op_desc->HasAttr("transpose_Y") || op_desc->HasAttr("trans_y")) {
78+
if (op_desc->HasAttr("transpose_Y")) {
79+
op_desc->SetAttr("transpose_y", op_desc->GetAttr("transpose_Y"));
7980
} else {
80-
desc.SetAttr("transpose_y", ops->Op()->GetAttr("trans_y"));
81+
op_desc->SetAttr("transpose_y", op_desc->GetAttr("trans_y"));
8182
}
8283
} else {
83-
desc.SetAttr("transpose_y", false);
84-
}
85-
86-
if (ops->Op()->HasAttr("out_threshold")) {
87-
desc.SetAttr("out_threshold", ops->Op()->GetAttr("out_threshold"));
84+
op_desc->SetAttr("transpose_y", false);
8885
}
8986

9087
// Todo: remove attr(x_num_col_dims, y_num_col_dims, alpha)
91-
if (ops->Op()->HasAttr("x_num_col_dims")) {
92-
desc.SetAttr("x_num_col_dims", ops->Op()->GetAttr("x_num_col_dims"));
88+
if (op_desc->HasAttr("x_num_col_dims")) {
89+
op_desc->SetAttr("x_num_col_dims", op_desc->GetAttr("x_num_col_dims"));
9390
} else {
9491
int32_t x_num_col_dims = -1;
95-
desc.SetAttr("x_num_col_dims", x_num_col_dims);
92+
op_desc->SetAttr("x_num_col_dims", x_num_col_dims);
9693
}
9794

9895
// op_teller: Only support y_num_col_dims == y.rank - 1;
9996
int32_t y_num_col_dims = -1;
100-
desc.SetAttr("y_num_col_dims", y_num_col_dims);
97+
op_desc->SetAttr("y_num_col_dims", y_num_col_dims);
10198

10299
float alpha = 1;
103-
if (ops->Op()->HasAttr("alpha")) {
104-
alpha = PADDLE_GET_CONST(float, ops->Op()->GetAttr("alpha"));
100+
if (op_desc->HasAttr("alpha")) {
101+
alpha = PADDLE_GET_CONST(float, op_desc->GetAttr("alpha"));
105102
}
106-
desc.SetAttr("alpha", alpha);
103+
op_desc->SetAttr("alpha", alpha);
107104

108-
auto matrix_multiply_node = g->CreateOpNode(&desc);
109-
for (auto node : ops->inputs) {
110-
IR_NODE_LINK_TO(node, matrix_multiply_node);
111-
}
112-
IR_NODE_LINK_TO(matrix_multiply_node, ops_out);
113-
GraphSafeRemoveNodes(graph, {ops});
114105
++found_count;
115106
};
116107
gpd(graph, handler);

paddle/fluid/inference/analysis/ir_passes/tensorrt_subgraph_pass.cc

+14
Original file line numberDiff line numberDiff line change
@@ -224,6 +224,20 @@ void analysis::TensorRtSubgraphPass::ApplyImpl(
224224
->SetAllNodesLowerToTrt(use_cuda_graph);
225225
}
226226
}
227+
228+
// some ops are only implemented in paddle-trt,
229+
// but not in paddle, so we should revert it.
230+
for (auto *op_node : framework::ir::TopologyVarientSort(
231+
*graph, static_cast<framework::ir::SortKind>(0))) {
232+
if (op_node->Op()->Type() == "matrix_multiply") {
233+
auto origin_type =
234+
op_node->Op()->GetAttrIfExists<std::string>("original_type");
235+
LOG(WARNING) << "matrix_multiply can't enter into paddle-trt, "
236+
<< "we will revert to " << origin_type;
237+
op_node->Op()->SetType(origin_type);
238+
op_node->RenameOp(origin_type);
239+
}
240+
}
227241
}
228242

229243
std::string GenerateEngineKey(const std::set<std::string> &engine_inputs,

0 commit comments

Comments (0)