Skip to content

[CINN] Reconstruct Reduce single downstream fusion #72315

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 13 additions & 17 deletions paddle/cinn/operator_fusion/graph_transformer/matcher.h
Original file line number Diff line number Diff line change
Expand Up @@ -105,33 +105,29 @@ struct CanFuseRxTMatcher {
}
};

struct CanFuseReduceTreeMatcher {
struct CanFuseReducePlusReduceMatcher {
bool operator()(const PatternGraph& graph, const PatternNodePtr& node) {
return StmtPatternGraphMatcher<ReduceTreePattern>()(graph, node) &&
!node->downstream().empty() &&
std::holds_alternative<ReduceTreePattern>(
node->downstream().at(0)->stmt_pattern()) &&
return StmtPatternGraphMatcher<AnchorPattern>()(graph, node) &&
node->loop_axis_mapping().reduce_axis_num > 0 &&
node->downstream().size() == 1 &&
node->downstream().at(0)->loop_axis_mapping().reduce_axis_num > 0 &&
graph.policy_manager()
.template GetPolicy<GeneralTopoPolicy>()
->CanFuse(node, node->downstream().at(0)) &&
graph.policy_manager()
.template GetPolicy<RelativeJudgePolicy>()
->CanFuse(node, node->downstream().at(0));
CanFuseReducePlusReduce(
node->loop_axis_mapping(),
node->downstream().at(0)->loop_axis_mapping());
}
};

struct CanFuseReduceTreeAndTrivialMatcher {
struct CanFuseReducePlusTrivialMatcher {
bool operator()(const PatternGraph& graph, const PatternNodePtr& node) {
return StmtPatternGraphMatcher<ReduceTreePattern>()(graph, node) &&
!node->downstream().empty() &&
std::holds_alternative<TrivialPattern>(
node->downstream().at(0)->stmt_pattern()) &&
node->downstream().at(0)->downstream().size() == 0 &&
return StmtPatternGraphMatcher<AnchorPattern>()(graph, node) &&
node->loop_axis_mapping().reduce_axis_num > 0 &&
node->downstream().size() == 1 &&
node->downstream().at(0)->loop_axis_mapping().reduce_axis_num == 0 &&
graph.policy_manager()
.template GetPolicy<GeneralTopoPolicy>()
->CanFuse(node, node->downstream().at(0)) &&
graph.policy_manager()
.template GetPolicy<RelativeJudgePolicy>()
->CanFuse(node, node->downstream().at(0));
}
};
Expand Down
87 changes: 65 additions & 22 deletions paddle/cinn/operator_fusion/graph_transformer/operation.h
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ struct MergeTrivialPatternOperation {
}
};

struct MergeReduceTreeOperation {
struct ReducePlusReduceFusionOperation {
PatternNodePtr operator()(PatternGraph* graph, PatternNodePtr node) {
PADDLE_ENFORCE_EQ(
node->downstream().size(),
Expand All @@ -69,18 +69,43 @@ struct MergeReduceTreeOperation {
"The downstream of the ReduceTree node should be 1, but got %d.",
node->downstream().size()));
auto downstream = node->downstream().at(0);
auto merged_node = graph->MergeNode(node, downstream, MergePattern);
auto loop_sink_transform_opt = GetValidAdjacentLoopTransform(
node->loop_axis_mapping(), downstream->loop_axis_mapping(), false);
if (loop_sink_transform_opt == std::nullopt) return node;
auto loop_sink_transform = InsertSubstituteReduceAxis(
loop_sink_transform_opt.value(), node->loop_axis_mapping());
VLOG(4) << "Start AnchorFusionOperation";
VLOG(4) << "Upstream: \n" << node->DebugStr();
VLOG(4) << "Downstream: \n" << downstream->DebugStr();
const auto merge_pattern_fn =
[](const StmtPattern& upstream,
const StmtPattern& downstream) -> StmtPattern {
return AnchorPattern(
UniqueConcatVector(GetOpsInPattern(upstream),
GetOpsInPattern(downstream)),
std::make_shared<FusionTracker>(GetFusionTracker(upstream),
GetFusionTracker(downstream)),
LoopAxisMappingMerge(GetPatternLoopAxisMapping(upstream),
GetPatternLoopAxisMapping(downstream),
false));
};
auto merged_node = graph->MergeNode(node, downstream, merge_pattern_fn);
// Update tracker
auto node_tmp_id = GetNewTmpId(node->id());
merged_node->AppendInstr(std::make_shared<AxisTransformInstr>(
node->id(), node_tmp_id, loop_sink_transform));
merged_node->AppendInstr(std::make_shared<CombineInstr>(
std::vector<std::string>{node_tmp_id, downstream->id()},
merged_node->id()));
graph->RemoveNode(downstream);
graph->RemoveNode(node);
VLOG(4) << "MergeReduceTreeOperation: \nupstream " << node->DebugStr()
<< "\ndownstream " << downstream->DebugStr() << "\nmerged "
<< merged_node->DebugStr();
VLOG(4) << "Merged: \n" << merged_node->DebugStr();
merged_node->UpdateTracker();
return merged_node;
}
};

struct MergeReduceTreeAndTrivialOperation {
struct ReducePlusTrivialFusionOperation {
PatternNodePtr operator()(PatternGraph* graph, PatternNodePtr node) {
PADDLE_ENFORCE_EQ(
node->downstream().size(),
Expand All @@ -89,25 +114,42 @@ struct MergeReduceTreeAndTrivialOperation {
"The downstream of the ReduceTree node should be 1, but got %d.",
node->downstream().size()));
auto downstream = node->downstream().at(0);
VLOG(4) << "MergeReduceTreeAndTrivialOperation: \nupstream "
<< node->DebugStr() << "\ndownstream " << downstream->DebugStr();
auto fake_reduce_iter_idx = graph->policy_manager()
.template GetPolicy<RelativeJudgePolicy>()
->GetFakeReduceIterIdx(node, downstream);
const auto merge_pattern_fn = [&fake_reduce_iter_idx](
const StmtPattern& first,
const StmtPattern& secend) {
auto rt_pattern =
std::get<ReduceTreePlusTrivialPattern>(MergePattern(first, secend));
rt_pattern.fake_reduce_iter_idx = fake_reduce_iter_idx;
return rt_pattern;
auto loop_transform_pair = GetReducePlusTrivialLoopTransform(
node->loop_axis_mapping(), downstream->loop_axis_mapping());
if (loop_transform_pair == std::nullopt) return node;
VLOG(4) << "Start AnchorFusionOperation";
VLOG(4) << "Upstream: \n" << node->DebugStr();
VLOG(4) << "Downstream: \n" << downstream->DebugStr();
auto upstream_loop_transform = InsertSubstituteReduceAxis(
loop_transform_pair.value().first, node->loop_axis_mapping());
auto downstream_loop_transform = loop_transform_pair.value().second;
const auto merge_pattern_fn =
[&](const StmtPattern& upstream,
const StmtPattern& downstream) -> StmtPattern {
return AnchorPattern(
UniqueConcatVector(GetOpsInPattern(upstream),
GetOpsInPattern(downstream)),
std::make_shared<FusionTracker>(GetFusionTracker(upstream),
GetFusionTracker(downstream)),
ReducePlusTrivialLoopAxisMappingMerge(
GetPatternLoopAxisMapping(upstream),
GetPatternLoopAxisMapping(downstream),
downstream_loop_transform));
};
PatternNodePtr merged_node =
graph->MergeNode(node, downstream, merge_pattern_fn);

auto merged_node = graph->MergeNode(node, downstream, merge_pattern_fn);
// Update tracker
auto node_tmp_id = GetNewTmpId(node->id());
auto downstream_tmp_id = GetNewTmpId(downstream->id());
merged_node->AppendInstr(std::make_shared<AxisTransformInstr>(
node->id(), node_tmp_id, upstream_loop_transform));
merged_node->AppendInstr(std::make_shared<AxisTransformInstr>(
downstream->id(), downstream_tmp_id, downstream_loop_transform));
merged_node->AppendInstr(std::make_shared<CombineInstr>(
std::vector<std::string>{node_tmp_id, downstream_tmp_id},
merged_node->id()));
graph->RemoveNode(downstream);
graph->RemoveNode(node);
VLOG(4) << "merged " << merged_node->DebugStr();
VLOG(4) << "Merged: \n" << merged_node->DebugStr();
merged_node->UpdateTracker();
return merged_node;
}
Expand Down Expand Up @@ -202,6 +244,7 @@ struct AnchorFusionOperation {
return merged_node;
}
};

struct SplitRecomputeOperation {
void operator()(PatternGraph* graph, PatternNodePtr upstream) {
auto origin_name = upstream->id();
Expand Down
57 changes: 0 additions & 57 deletions paddle/cinn/operator_fusion/pattern_fuser.h
Original file line number Diff line number Diff line change
Expand Up @@ -142,68 +142,11 @@ static StmtPattern MergePatternImpl(const TrivialPattern& first,
second.loop_axis_mapping()));
}

// RR & RT: ReduceTree + ReduceTree and ReduceTree + Trivial merges

// Inserts `upstream` as a child somewhere inside the `downstream` reduce tree.
//
// If FindUserOp reports at least one op in downstream's root pattern for
// result(0) of upstream's reduce op, `upstream` is inserted directly under
// `downstream` and the search stops there. Otherwise every child subtree of
// `downstream` is searched recursively.
//
// Returns the total number of insertions performed over the whole tree.
static int InsertUpstreamIntoTree(const ReduceTreePattern& upstream,
                                  ReduceTreePattern& downstream) {  // NOLINT
  // True when FindUserOp yields a non-empty set for the producer's first
  // reduce result among the consumer's ops.
  const auto feeds_directly = [](const ReducePattern& producer,
                                 const ReducePattern& consumer) -> bool {
    auto reduce_result = producer.GetReduceOp()->result(0);
    auto consuming_ops = FindUserOp(consumer.ops(), reduce_result);
    return !consuming_ops.empty();
  };

  // Direct hit at this level: attach here and do not descend any further.
  if (feeds_directly(upstream.GetRootPattern(), downstream.GetRootPattern())) {
    downstream.InsertChild(upstream);
    return 1;
  }

  // No direct use at the root, so accumulate insertions from each subtree.
  int inserted_total = 0;
  for (auto& subtree : downstream.children()) {
    inserted_total += InsertUpstreamIntoTree(upstream, subtree);
  }
  return inserted_total;
}

// Merges an upstream reduce tree into a downstream reduce tree.
//
// The result starts as a copy of `downstream` (same children and root) with a
// new FusionTracker built from both inputs' trackers. `upstream` is then
// inserted into that copy, and the loop axis mapping is set to the merge of
// both inputs' mappings. Enforces that exactly one insertion happened.
static StmtPattern MergePatternImpl(const ReduceTreePattern& upstream,
                                    const ReduceTreePattern& downstream) {
  // Copy first: `downstream` is const, so the insertion mutates the copy.
  ReduceTreePattern merged_tree(
      downstream.children(),
      downstream.GetRootPattern(),
      std::make_shared<FusionTracker>(upstream.tracker_,
                                      downstream.tracker_));

  int insert_num = InsertUpstreamIntoTree(upstream, merged_tree);

  merged_tree.set_loop_axis_mapping(
      LoopAxisMappingMerge(upstream.loop_axis_mapping(),
                           downstream.loop_axis_mapping(),
                           false));

  // `upstream` must have been attached at exactly one place in the tree.
  PADDLE_ENFORCE_EQ(insert_num,
                    1,
                    ::common::errors::PreconditionNotMet(
                        "Must insert only once, but insert %d", insert_num));
  return merged_tree;
}

// Merges a reduce tree (`first`) with a trivial pattern (`second`) into a
// ReduceTreePlusTrivialPattern.
//
// The merged pattern gets a new FusionTracker built from both inputs'
// trackers, and its loop axis mapping is the result of
// ReducePlusTrivialLoopAxisMappingMerge over both inputs' mappings.
static StmtPattern MergePatternImpl(const ReduceTreePattern& first,
                                    const TrivialPattern& second) {
  ReduceTreePlusTrivialPattern fused(
      first,
      second,
      std::make_shared<FusionTracker>(first.tracker_, second.tracker_));

  fused.set_loop_axis_mapping(ReducePlusTrivialLoopAxisMappingMerge(
      first.loop_axis_mapping(), second.loop_axis_mapping()));
  return fused;
}

static StmtPattern MergePattern(const StmtPattern& first,
const StmtPattern& second) {
VLOG(4) << "MergePattern: " << GetPatternId(first) << " x "
<< GetPatternId(second);
const auto PatternMatch = adt::match{
[&](const ReduceTreePattern& lhs, const ReduceTreePattern& rhs) {
return MergePatternImpl(lhs, rhs);
},
[&](const ReduceTreePattern& lhs, const TrivialPattern& rhs) {
return MergePatternImpl(lhs, rhs);
},
[&](const TrivialPattern& lhs, const ReducePattern& rhs) {
return MergePatternImpl(lhs, rhs);
},
Expand Down
28 changes: 9 additions & 19 deletions paddle/cinn/operator_fusion/pattern_graph.cc
Original file line number Diff line number Diff line change
Expand Up @@ -31,12 +31,6 @@ std::vector<PatternNodePtr> PatternGraph::ClusterOps() {
VLOG(4) << "[Group Cluster] After SinkTrivialPattern: ";
PrintGraphInfo();

// ReducePattern -> ReduceTreePattern
VLOG(4) << "[Group Cluster] Start ReduceLiftReduceTree";
ReduceLiftReduceTree();
VLOG(4) << "[Group Cluster] After ReduceLiftReduceTree: ";
PrintGraphInfo();

// ReduceTreePattern + ReduceTreePattern fusion
VLOG(4) << "[Group Cluster] Start ReduceTreeGrown";
ReduceTreeGrown();
Expand Down Expand Up @@ -153,24 +147,20 @@ void PatternGraph::SinkTrivialPattern() {
MergeTrivialPatternOperation>(this);
}

void PatternGraph::ReduceLiftReduceTree() {
GraphTransformer<
NodePattern,
And<DownstreamSmallerThan<2>, StmtPatternGraphMatcher<ReducePattern>>,
LiftReduceToReduceTreeOperation>(this);
}

void PatternGraph::ReduceTreeGrown() {
GraphTransformer<NodePattern,
And<CanFuseReduceTreeMatcher, Not<IsOutputNodeMatcher>>,
MergeReduceTreeOperation>(this);
StmtPatternGraphMatcher<ReducePattern>,
LiftToAnchorPatternOperation>(this);

GraphTransformer<NodePattern,
CanFuseReducePlusReduceMatcher,
ReducePlusReduceFusionOperation>(this);
}

void PatternGraph::ReduceTree_Trivial_Fusion() {
GraphTransformer<
NodePattern,
And<CanFuseReduceTreeAndTrivialMatcher, Not<IsOutputNodeMatcher>>,
MergeReduceTreeAndTrivialOperation>(this);
GraphTransformer<NodePattern,
CanFuseReducePlusTrivialMatcher,
ReducePlusTrivialFusionOperation>(this);
}

void PatternGraph::AnchorFusion() {
Expand Down
Loading
Loading