Skip to content

Commit e200954

Browse files
authored
[CINN] Adjust the code format in cinn (#55009)
* adjust the code format in cinn * fix merge conflict
1 parent e572568 commit e200954

File tree

8 files changed

+50
-50
lines changed

8 files changed

+50
-50
lines changed

paddle/cinn/frontend/decomposer/activation_test.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ TEST(Decomposer, relu) {
3838
std::vector<std::string> output_names = {out->id};
3939
std::vector<std::vector<int>> output_shapes = {{20, 10}};
4040
RunAndCheck<float>(
41-
builder, input_names, output_names, output_shapes, relu_cpu, -1, 1);
41+
&builder, input_names, output_names, output_shapes, relu_cpu, -1, 1);
4242
}
4343

4444
TEST(Decomposer, relu_grad) {
@@ -62,7 +62,7 @@ TEST(Decomposer, relu_grad) {
6262
std::vector<std::string> output_names = {dx->id};
6363
std::vector<std::vector<int>> output_shapes = {{20, 10}};
6464
RunAndCheck<float>(
65-
builder, input_names, output_names, output_shapes, relu_grad_cpu, -1, 1);
65+
&builder, input_names, output_names, output_shapes, relu_grad_cpu, -1, 1);
6666
}
6767

6868
TEST(Decomposer, softmax_decomposer) {

paddle/cinn/frontend/decomposer/broadcast_test.cc

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ TEST(Decomposer, elementwise_add_bcast0) {
2727
std::vector<std::string> input_names = {x.id().data(), y.id().data()};
2828
std::vector<std::string> output_names = {out->id};
2929
std::vector<std::vector<int>> output_shapes = {{4, 10, 20, 10}};
30-
RunAndCheckShape<float>(builder, input_names, output_names, output_shapes);
30+
RunAndCheckShape<float>(&builder, input_names, output_names, output_shapes);
3131
}
3232

3333
TEST(Decomposer, elementwise_add_bcase1) {
@@ -39,7 +39,7 @@ TEST(Decomposer, elementwise_add_bcase1) {
3939
std::vector<std::string> input_names = {x.id().data(), y.id().data()};
4040
std::vector<std::string> output_names = {out->id};
4141
std::vector<std::vector<int>> output_shapes = {{4, 10, 20, 10}};
42-
RunAndCheckShape<float>(builder, input_names, output_names, output_shapes);
42+
RunAndCheckShape<float>(&builder, input_names, output_names, output_shapes);
4343
}
4444

4545
TEST(Decomposer, elementwise_add_grad_bcast0) {
@@ -52,7 +52,7 @@ TEST(Decomposer, elementwise_add_grad_bcast0) {
5252
std::vector<std::string> input_names = {dout.id().data()};
5353
std::vector<std::string> output_names = {out_grads[0]->id, out_grads[1]->id};
5454
std::vector<std::vector<int>> output_shapes = {{4, 1, 20, 10}, {10, 20}};
55-
RunAndCheckShape<float>(builder, input_names, output_names, output_shapes);
55+
RunAndCheckShape<float>(&builder, input_names, output_names, output_shapes);
5656
}
5757

5858
TEST(Decomposer, elementwise_add_bcast1) {
@@ -80,7 +80,7 @@ TEST(Decomposer, elementwise_add_bcast1) {
8080
std::vector<std::string> output_names = {out->id};
8181
std::vector<std::vector<int>> output_shapes = {{32, 64, 32, 32}};
8282
RunAndCheck<float>(
83-
builder, input_names, output_names, output_shapes, add_cpu);
83+
&builder, input_names, output_names, output_shapes, add_cpu);
8484
}
8585

8686
TEST(Decomposer, elementwise_add_bcast1_2) {
@@ -108,7 +108,7 @@ TEST(Decomposer, elementwise_add_bcast1_2) {
108108
std::vector<std::string> output_names = {out->id};
109109
std::vector<std::vector<int>> output_shapes = {{32, 64, 32, 32}};
110110
RunAndCheck<float>(
111-
builder, input_names, output_names, output_shapes, add_cpu);
111+
&builder, input_names, output_names, output_shapes, add_cpu);
112112
}
113113

114114
TEST(Decomposer, elementwise_add_grad_bcast1) {
@@ -140,7 +140,7 @@ TEST(Decomposer, elementwise_add_grad_bcast1) {
140140
std::vector<std::string> output_names = {out_grads[0]->id, out_grads[1]->id};
141141
std::vector<std::vector<int>> output_shapes = {{32, 64, 32, 32}, {64}};
142142
RunAndCheck<float>(
143-
builder, input_names, output_names, output_shapes, add_grad_cpu);
143+
&builder, input_names, output_names, output_shapes, add_grad_cpu);
144144
}
145145

146146
TEST(Decomposer, elementwise_add_bcast2) {
@@ -165,7 +165,7 @@ TEST(Decomposer, elementwise_add_bcast2) {
165165
std::vector<std::string> output_names = {out->id};
166166
std::vector<std::vector<int>> output_shapes = {{32, 16}};
167167
RunAndCheck<float>(
168-
builder, input_names, output_names, output_shapes, add_cpu);
168+
&builder, input_names, output_names, output_shapes, add_cpu);
169169
}
170170

171171
TEST(Decomposer, elementwise_add_bcast2_2) {
@@ -190,7 +190,7 @@ TEST(Decomposer, elementwise_add_bcast2_2) {
190190
std::vector<std::string> output_names = {out->id};
191191
std::vector<std::vector<int>> output_shapes = {{32, 16}};
192192
RunAndCheck<float>(
193-
builder, input_names, output_names, output_shapes, add_cpu);
193+
&builder, input_names, output_names, output_shapes, add_cpu);
194194
}
195195

196196
TEST(Decomposer, elementwise_add_bcast2_3) {
@@ -217,7 +217,7 @@ TEST(Decomposer, elementwise_add_bcast2_3) {
217217
std::vector<std::string> output_names = {out->id};
218218
std::vector<std::vector<int>> output_shapes = {{32, 16}};
219219
RunAndCheck<int_ty>(
220-
builder, input_names, output_names, output_shapes, add_cpu);
220+
&builder, input_names, output_names, output_shapes, add_cpu);
221221
}
222222

223223
TEST(Decomposer, elementwise_add_grad_bcast2) {
@@ -244,7 +244,7 @@ TEST(Decomposer, elementwise_add_grad_bcast2) {
244244
std::vector<std::string> output_names = {out_grads[0]->id, out_grads[1]->id};
245245
std::vector<std::vector<int>> output_shapes = {{32, 16}, {1}};
246246
RunAndCheck<float>(
247-
builder, input_names, output_names, output_shapes, add_grad_cpu);
247+
&builder, input_names, output_names, output_shapes, add_grad_cpu);
248248
}
249249

250250
TEST(Decomposer, elementwise_add_same_dims) {
@@ -268,7 +268,7 @@ TEST(Decomposer, elementwise_add_same_dims) {
268268
std::vector<std::string> output_names = {out->id};
269269
std::vector<std::vector<int>> output_shapes = {{32, 16}};
270270
RunAndCheck<float>(
271-
builder, input_names, output_names, output_shapes, add_cpu);
271+
&builder, input_names, output_names, output_shapes, add_cpu);
272272
}
273273

274274
TEST(Decomposer, elementwise_add_grad_same_dims) {
@@ -295,7 +295,7 @@ TEST(Decomposer, elementwise_add_grad_same_dims) {
295295
std::vector<std::string> output_names = {out_grads[0]->id, out_grads[1]->id};
296296
std::vector<std::vector<int>> output_shapes = {{32, 16}, {32, 16}};
297297
RunAndCheck<float>(
298-
builder, input_names, output_names, output_shapes, add_grad_cpu);
298+
&builder, input_names, output_names, output_shapes, add_grad_cpu);
299299
}
300300

301301
} // namespace cinn::frontend

paddle/cinn/frontend/decomposer/elementwise_test.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ TEST(Decomposer, sum) {
4242
std::vector<std::string> output_names = {out->id};
4343
std::vector<std::vector<int>> output_shapes = {{32, 16}};
4444
RunAndCheck<float>(
45-
builder, input_names, output_names, output_shapes, sum_cpu);
45+
&builder, input_names, output_names, output_shapes, sum_cpu);
4646
}
4747

4848
} // namespace cinn::frontend

paddle/cinn/frontend/decomposer/test_helper.h

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -193,7 +193,7 @@ void RunDecomposer(Program* prog,
193193
const std::vector<std::string>& fetch_ids = {});
194194

195195
template <typename T>
196-
void RunAndCheckShape(NetBuilder& builder,
196+
void RunAndCheckShape(NetBuilder* builder,
197197
const std::vector<std::string>& input_names,
198198
const std::vector<std::string>& output_names,
199199
const std::vector<std::vector<int>>& output_shapes,
@@ -202,7 +202,7 @@ void RunAndCheckShape(NetBuilder& builder,
202202
T low = 0,
203203
T high = 1,
204204
const std::vector<std::string>& passes = {"Decomposer"}) {
205-
auto prog = builder.Build();
205+
auto prog = builder->Build();
206206
Target target = common::DefaultTarget();
207207
RunDecomposer(&prog, target, passes, output_names);
208208
auto graph = std::make_shared<hlir::framework::Graph>(prog, target);
@@ -238,7 +238,7 @@ void RunAndCheckShape(NetBuilder& builder,
238238
}
239239

240240
template <typename T>
241-
void RunAndCheck(NetBuilder& builder,
241+
void RunAndCheck(NetBuilder* builder,
242242
const std::vector<std::string>& input_names,
243243
const std::vector<std::string>& output_names,
244244
const std::vector<std::vector<int>>& output_shapes,

paddle/cinn/frontend/pass/fill_constant_rewriter_test.cc

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ TEST(FillConstantRewriter, remove_reshape_single) {
4242
std::vector<std::string> program_passes = {"FillConstantRewriter",
4343
"RemoveIdentity"};
4444
int num_removed_ops =
45-
tester.RunAndCheck(builder, program_passes, input_names, output_names);
45+
tester.RunAndCheck(&builder, program_passes, input_names, output_names);
4646
ASSERT_EQ(num_removed_ops, 2);
4747
}
4848

@@ -68,7 +68,7 @@ TEST(FillConstantRewriter, remove_reshape_with_fill_constant) {
6868
std::vector<std::string> program_passes = {"FillConstantRewriter",
6969
"RemoveIdentity"};
7070
int num_removed_ops =
71-
tester.RunAndCheck(builder, program_passes, input_names, output_names);
71+
tester.RunAndCheck(&builder, program_passes, input_names, output_names);
7272
ASSERT_EQ(num_removed_ops, 2);
7373
}
7474

@@ -93,7 +93,7 @@ TEST(FillConstantRewriter, remove_scale_single) {
9393
std::vector<std::string> program_passes = {"FillConstantRewriter",
9494
"RemoveIdentity"};
9595
int num_removed_ops =
96-
tester.RunAndCheck(builder, program_passes, input_names, output_names);
96+
tester.RunAndCheck(&builder, program_passes, input_names, output_names);
9797
ASSERT_EQ(num_removed_ops, 2);
9898
}
9999

@@ -118,7 +118,7 @@ TEST(FillConstantRewriter, remove_scale_with_fill_constant) {
118118
std::vector<std::string> program_passes = {"FillConstantRewriter",
119119
"RemoveIdentity"};
120120
int num_removed_ops =
121-
tester.RunAndCheck(builder, program_passes, input_names, output_names);
121+
tester.RunAndCheck(&builder, program_passes, input_names, output_names);
122122
ASSERT_EQ(num_removed_ops, 2);
123123
}
124124

@@ -150,7 +150,7 @@ TEST(FillConstantRewriter, remove_multi_scale_with_fill_constant) {
150150
std::vector<std::string> program_passes = {"FillConstantRewriter",
151151
"RemoveIdentity"};
152152
int num_removed_ops =
153-
tester.RunAndCheck(builder, program_passes, input_names, output_names);
153+
tester.RunAndCheck(&builder, program_passes, input_names, output_names);
154154
ASSERT_EQ(num_removed_ops, 4);
155155
}
156156

@@ -167,7 +167,7 @@ TEST(FillConstantRewriter, two_fill_constant) {
167167
std::vector<std::string> program_passes = {"FillConstantRewriter",
168168
"RemoveIdentity"};
169169
int num_removed_ops =
170-
tester.RunAndCheck(builder, program_passes, input_names, output_names);
170+
tester.RunAndCheck(&builder, program_passes, input_names, output_names);
171171
ASSERT_EQ(num_removed_ops, 0);
172172
}
173173

paddle/cinn/frontend/pass/remove_identity_test.cc

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ TEST(RemoveIdentity, remove_single) {
4040
std::vector<std::string> program_passes = {"RemoveIdentity",
4141
"DeadCodeEliminate"};
4242
int num_removed_ops =
43-
tester.RunAndCheck(builder, program_passes, input_names, output_names);
43+
tester.RunAndCheck(&builder, program_passes, input_names, output_names);
4444
ASSERT_EQ(num_removed_ops, 3);
4545
}
4646

@@ -63,7 +63,7 @@ TEST(RemoveIdentity, remove_branch) {
6363
std::vector<std::string> output_names = {reduce_sum_1->id, reduce_sum_2->id};
6464
std::vector<std::string> program_passes = {"RemoveIdentity"};
6565
int num_removed_ops =
66-
tester.RunAndCheck(builder, program_passes, input_names, output_names);
66+
tester.RunAndCheck(&builder, program_passes, input_names, output_names);
6767
ASSERT_EQ(num_removed_ops, 1);
6868
}
6969

@@ -92,7 +92,7 @@ TEST(RemoveIdentity, remove_multiple) {
9292
std::vector<std::string> output_names = {mul_1->id};
9393
std::vector<std::string> program_passes = {"RemoveIdentity"};
9494
int num_removed_ops =
95-
tester.RunAndCheck(builder, program_passes, input_names, output_names);
95+
tester.RunAndCheck(&builder, program_passes, input_names, output_names);
9696
ASSERT_EQ(num_removed_ops, 3);
9797
}
9898

@@ -121,7 +121,7 @@ TEST(RemoveIdentity, cannot_remove_fetch) {
121121
std::vector<std::string> output_names = {identity_2->id, mul_1->id};
122122
std::vector<std::string> program_passes = {"RemoveIdentity"};
123123
int num_removed_ops =
124-
tester.RunAndCheck(builder, program_passes, input_names, output_names);
124+
tester.RunAndCheck(&builder, program_passes, input_names, output_names);
125125
ASSERT_EQ(num_removed_ops, 1);
126126
}
127127

paddle/cinn/frontend/pass/test_helper.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -75,11 +75,11 @@ class PassTest {
7575
public:
7676
PassTest() { target_ = common::DefaultTarget(); }
7777

78-
int RunAndCheck(NetBuilder& builder,
78+
int RunAndCheck(NetBuilder* builder,
7979
const std::vector<std::string>& program_passes,
8080
const std::vector<std::string>& input_names,
8181
const std::vector<std::string>& output_names) {
82-
auto program = builder.Build();
82+
auto program = builder->Build();
8383
CHECK(IsValid(program)) << "The origin program is not valid.";
8484
int origin_program_size = program.size();
8585
LOG(INFO) << "Run origin program";

paddle/cinn/hlir/pass/fusion_merge_pass.cc

Lines changed: 20 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -176,7 +176,7 @@ class FusionMergePassHelper : public FusionHelperBase {
176176

177177
bool HorizontalFusion(
178178
GroupPtr producer,
179-
std::unordered_set<GroupPtr, Hasher, Comparator>& consumers) {
179+
const std::unordered_set<GroupPtr, Hasher, Comparator>& consumers) {
180180
VLOG(3) << "HorizontalFusion...!";
181181
if (consumers.size() <= 1) {
182182
return false;
@@ -249,7 +249,7 @@ class FusionMergePassHelper : public FusionHelperBase {
249249
return updated;
250250
}
251251

252-
void HorizontalFuse(GroupList& consumers) {
252+
void HorizontalFuse(const GroupList& consumers) {
253253
VLOG(3) << "HorizontalFuse Groups...";
254254
// create fusion group
255255
auto fused_group = std::make_shared<Graph::Group>();
@@ -400,8 +400,8 @@ class FusionMergePassHelper : public FusionHelperBase {
400400
}
401401

402402
bool VerticalFusion(
403-
GroupPtr& producer,
404-
std::unordered_set<GroupPtr, Hasher, Comparator>& consumers,
403+
const GroupPtr& producer,
404+
const std::unordered_set<GroupPtr, Hasher, Comparator>& consumers,
405405
bool recompute) {
406406
VLOG(3) << "VerticalFusion, Number of Consumers : " << consumers.size();
407407
auto& relation = fusion_relation_map_[producer->op_pattern_kind];
@@ -463,14 +463,14 @@ class FusionMergePassHelper : public FusionHelperBase {
463463
if (!recompute) {
464464
return false;
465465
} else {
466-
RecomputeEleGraph(producer, fuse_consumers_unsafe);
466+
RecomputeEleGraph(producer, &fuse_consumers_unsafe);
467467
VerticalFuse(producer, fuse_consumers_unsafe);
468468
return true;
469469
}
470470
}
471471

472472
if (fuse_consumers.size()) {
473-
SelectConsumerToFuse(producer, fuse_consumers);
473+
SelectConsumerToFuse(producer, &fuse_consumers);
474474
}
475475

476476
// if fusionable consumers exist
@@ -482,9 +482,9 @@ class FusionMergePassHelper : public FusionHelperBase {
482482
return false;
483483
}
484484

485-
void VerticalFuse(
486-
GroupPtr& producer,
487-
std::unordered_set<GroupPtr, Hasher, Comparator>& fusionable_consumers) {
485+
void VerticalFuse(const GroupPtr& producer,
486+
const std::unordered_set<GroupPtr, Hasher, Comparator>&
487+
fusionable_consumers) {
488488
VLOG(3) << "VerticalFuse...!";
489489
GroupList fused_groups;
490490
GroupPtr master_fuesd_group(nullptr);
@@ -671,19 +671,19 @@ class FusionMergePassHelper : public FusionHelperBase {
671671

672672
void RecomputeEleGraph(
673673
const GroupPtr& producer,
674-
std::unordered_set<GroupPtr, Hasher, Comparator>& fusionable_consumers) {
674+
std::unordered_set<GroupPtr, Hasher, Comparator>* fusionable_consumers) {
675675
if (producer->op_pattern_kind != framework::kElementWise) {
676676
SelectConsumerToFuse(producer, fusionable_consumers);
677677
}
678678
}
679679

680680
void SelectConsumerToFuse(
681681
const GroupPtr& producer,
682-
std::unordered_set<GroupPtr, Hasher, Comparator>& fusionable_consumers) {
682+
std::unordered_set<GroupPtr, Hasher, Comparator>* fusionable_consumers) {
683683
// if is const op
684684
if (is_const_group(this, producer)) {
685685
std::unordered_set<GroupPtr, Hasher, Comparator> candidates;
686-
for (auto& consumer : fusionable_consumers) {
686+
for (auto& consumer : *fusionable_consumers) {
687687
// if can be output node.
688688
if (is_same_shape(this, producer, consumer)) {
689689
candidates.insert(consumer);
@@ -707,10 +707,10 @@ class FusionMergePassHelper : public FusionHelperBase {
707707
CHECK_GE(producer->consumer_groups.size(), candidates.size());
708708
if (producer->consumer_groups.size() == 0 && candidates.size() == 0 &&
709709
output_nodes_set_.count(producer->CollectNodes()[0]) == 0) {
710-
producer->belong_groups.insert(*fusionable_consumers.begin());
710+
producer->belong_groups.insert(*fusionable_consumers->begin());
711711
}
712712

713-
fusionable_consumers = candidates;
713+
*fusionable_consumers = candidates;
714714
return;
715715
}
716716
// 1 to 1 fusion.
@@ -720,7 +720,7 @@ class FusionMergePassHelper : public FusionHelperBase {
720720

721721
if (FLAGS_enhance_vertical_fusion_with_recompute) {
722722
std::vector<GroupPtr> candidates;
723-
for (auto& consumer : fusionable_consumers) {
723+
for (auto& consumer : *fusionable_consumers) {
724724
if (consumer->op_pattern_kind == framework::kElementWise) {
725725
candidates.push_back(consumer);
726726
continue;
@@ -764,13 +764,13 @@ class FusionMergePassHelper : public FusionHelperBase {
764764
return lhs->op_pattern_kind < rhs->op_pattern_kind;
765765
});
766766

767-
fusionable_consumers.clear();
767+
fusionable_consumers->clear();
768768
if (candidates.size()) {
769-
fusionable_consumers.insert(*candidates.begin());
769+
fusionable_consumers->insert(*candidates.begin());
770770
}
771771
} else {
772772
std::unordered_set<GroupPtr, Hasher, Comparator> candidates;
773-
for (auto& consumer : fusionable_consumers) {
773+
for (auto& consumer : *fusionable_consumers) {
774774
if (consumer->op_pattern_kind == framework::kElementWise) {
775775
candidates.insert(consumer);
776776
continue;
@@ -787,9 +787,9 @@ class FusionMergePassHelper : public FusionHelperBase {
787787
}
788788
}
789789

790-
fusionable_consumers.clear();
790+
fusionable_consumers->clear();
791791
if (candidates.size()) {
792-
fusionable_consumers.insert(*candidates.begin());
792+
fusionable_consumers->insert(*candidates.begin());
793793
}
794794
}
795795
}

0 commit comments

Comments
 (0)