@@ -57,60 +57,51 @@ void TrtMapOpsToMatrixMultiplyPass::ApplyImpl(ir::Graph* graph) const {
57
57
VLOG (4 ) << " trt map some ops to matrix_multiply" ;
58
58
GET_IR_NODE_FROM_SUBGRAPH (ops, ops, mul_matmul_matmul_v2);
59
59
GET_IR_NODE_FROM_SUBGRAPH (ops_out, ops_out, mul_matmul_matmul_v2);
60
- OpDesc desc (ops->Op ()->Block ());
61
- desc.SetType (" matrix_multiply" );
62
- desc.SetInput (" X" , {ops->Op ()->Input (" X" ).front ()});
63
- desc.SetInput (" Y" , {ops->Op ()->Input (" Y" ).front ()});
64
- desc.SetOutput (" Out" , {ops_out->Name ()});
65
-
66
- if (ops->Op ()->HasAttr (" transpose_X" ) || ops->Op ()->HasAttr (" trans_x" )) {
67
- if (ops->Op ()->HasAttr (" transpose_X" )) {
68
- desc.SetAttr (" transpose_x" , ops->Op ()->GetAttr (" transpose_X" ));
60
+ auto op_desc = ops->Op ();
61
+ op_desc->SetAttr (" original_type" , op_desc->Type ());
62
+ op_desc->SetType (" matrix_multiply" );
63
+ ops->RenameOp (" matrix_multiply" );
64
+
65
+ // OpDesc original_desc(*(ops->Op()));
66
+
67
+ if (op_desc->HasAttr (" transpose_X" ) || op_desc->HasAttr (" trans_x" )) {
68
+ if (op_desc->HasAttr (" transpose_X" )) {
69
+ op_desc->SetAttr (" transpose_x" , op_desc->GetAttr (" transpose_X" ));
69
70
} else {
70
- desc. SetAttr (" transpose_x" , ops-> Op () ->GetAttr (" trans_x" ));
71
+ op_desc-> SetAttr (" transpose_x" , op_desc ->GetAttr (" trans_x" ));
71
72
}
72
73
} else {
73
- desc. SetAttr (" transpose_x" , false );
74
+ op_desc-> SetAttr (" transpose_x" , false );
74
75
}
75
76
76
- if (ops-> Op ()-> HasAttr (" transpose_Y" ) || ops-> Op () ->HasAttr (" trans_y" )) {
77
- if (ops-> Op () ->HasAttr (" transpose_Y" )) {
78
- desc. SetAttr (" transpose_y" , ops-> Op () ->GetAttr (" transpose_Y" ));
77
+ if (op_desc-> HasAttr (" transpose_Y" ) || op_desc ->HasAttr (" trans_y" )) {
78
+ if (op_desc ->HasAttr (" transpose_Y" )) {
79
+ op_desc-> SetAttr (" transpose_y" , op_desc ->GetAttr (" transpose_Y" ));
79
80
} else {
80
- desc. SetAttr (" transpose_y" , ops-> Op () ->GetAttr (" trans_y" ));
81
+ op_desc-> SetAttr (" transpose_y" , op_desc ->GetAttr (" trans_y" ));
81
82
}
82
83
} else {
83
- desc.SetAttr (" transpose_y" , false );
84
- }
85
-
86
- if (ops->Op ()->HasAttr (" out_threshold" )) {
87
- desc.SetAttr (" out_threshold" , ops->Op ()->GetAttr (" out_threshold" ));
84
+ op_desc->SetAttr (" transpose_y" , false );
88
85
}
89
86
90
87
// Todo: remove attr(x_num_col_dims, y_num_col_dims, alpha)
91
- if (ops-> Op () ->HasAttr (" x_num_col_dims" )) {
92
- desc. SetAttr (" x_num_col_dims" , ops-> Op () ->GetAttr (" x_num_col_dims" ));
88
+ if (op_desc ->HasAttr (" x_num_col_dims" )) {
89
+ op_desc-> SetAttr (" x_num_col_dims" , op_desc ->GetAttr (" x_num_col_dims" ));
93
90
} else {
94
91
int32_t x_num_col_dims = -1 ;
95
- desc. SetAttr (" x_num_col_dims" , x_num_col_dims);
92
+ op_desc-> SetAttr (" x_num_col_dims" , x_num_col_dims);
96
93
}
97
94
98
95
// op_teller: Only support y_num_col_dims == y.rank - 1;
99
96
int32_t y_num_col_dims = -1 ;
100
- desc. SetAttr (" y_num_col_dims" , y_num_col_dims);
97
+ op_desc-> SetAttr (" y_num_col_dims" , y_num_col_dims);
101
98
102
99
float alpha = 1 ;
103
- if (ops-> Op () ->HasAttr (" alpha" )) {
104
- alpha = PADDLE_GET_CONST (float , ops-> Op () ->GetAttr (" alpha" ));
100
+ if (op_desc ->HasAttr (" alpha" )) {
101
+ alpha = PADDLE_GET_CONST (float , op_desc ->GetAttr (" alpha" ));
105
102
}
106
- desc. SetAttr (" alpha" , alpha);
103
+ op_desc-> SetAttr (" alpha" , alpha);
107
104
108
- auto matrix_multiply_node = g->CreateOpNode (&desc);
109
- for (auto node : ops->inputs ) {
110
- IR_NODE_LINK_TO (node, matrix_multiply_node);
111
- }
112
- IR_NODE_LINK_TO (matrix_multiply_node, ops_out);
113
- GraphSafeRemoveNodes (graph, {ops});
114
105
++found_count;
115
106
};
116
107
gpd (graph, handler);
0 commit comments