diff --git a/paddle/cinn/operator_fusion/graph_transformer/matcher.h b/paddle/cinn/operator_fusion/graph_transformer/matcher.h index 7d8ec000414292..798734a664d903 100644 --- a/paddle/cinn/operator_fusion/graph_transformer/matcher.h +++ b/paddle/cinn/operator_fusion/graph_transformer/matcher.h @@ -105,33 +105,29 @@ struct CanFuseRxTMatcher { } }; -struct CanFuseReduceTreeMatcher { +struct CanFuseReducePlusReduceMatcher { bool operator()(const PatternGraph& graph, const PatternNodePtr& node) { - return StmtPatternGraphMatcher()(graph, node) && - !node->downstream().empty() && - std::holds_alternative( - node->downstream().at(0)->stmt_pattern()) && + return StmtPatternGraphMatcher()(graph, node) && + node->loop_axis_mapping().reduce_axis_num > 0 && + node->downstream().size() == 1 && + node->downstream().at(0)->loop_axis_mapping().reduce_axis_num > 0 && graph.policy_manager() .template GetPolicy() ->CanFuse(node, node->downstream().at(0)) && - graph.policy_manager() - .template GetPolicy() - ->CanFuse(node, node->downstream().at(0)); + CanFuseReducePlusReduce( + node->loop_axis_mapping(), + node->downstream().at(0)->loop_axis_mapping()); } }; -struct CanFuseReduceTreeAndTrivialMatcher { +struct CanFuseReducePlusTrivialMatcher { bool operator()(const PatternGraph& graph, const PatternNodePtr& node) { - return StmtPatternGraphMatcher()(graph, node) && - !node->downstream().empty() && - std::holds_alternative( - node->downstream().at(0)->stmt_pattern()) && - node->downstream().at(0)->downstream().size() == 0 && + return StmtPatternGraphMatcher()(graph, node) && + node->loop_axis_mapping().reduce_axis_num > 0 && + node->downstream().size() == 1 && + node->downstream().at(0)->loop_axis_mapping().reduce_axis_num == 0 && graph.policy_manager() .template GetPolicy() - ->CanFuse(node, node->downstream().at(0)) && - graph.policy_manager() - .template GetPolicy() ->CanFuse(node, node->downstream().at(0)); } }; diff --git a/paddle/cinn/operator_fusion/graph_transformer/operation.h b/paddle/cinn/operator_fusion/graph_transformer/operation.h index cae9085698488d..17f43792a79c74 100644 --- a/paddle/cinn/operator_fusion/graph_transformer/operation.h +++ b/paddle/cinn/operator_fusion/graph_transformer/operation.h @@ -60,7 +60,7 @@ struct MergeTrivialPatternOperation { } }; -struct MergeReduceTreeOperation { +struct ReducePlusReduceFusionOperation { PatternNodePtr operator()(PatternGraph* graph, PatternNodePtr node) { PADDLE_ENFORCE_EQ( node->downstream().size(), @@ -69,18 +69,43 @@ struct MergeReduceTreeOperation { "The downstream of the ReduceTree node should be 1, but got %d.", node->downstream().size())); auto downstream = node->downstream().at(0); - auto merged_node = graph->MergeNode(node, downstream, MergePattern); + auto loop_sink_transform_opt = GetValidAdjacentLoopTransform( + node->loop_axis_mapping(), downstream->loop_axis_mapping(), false); + if (loop_sink_transform_opt == std::nullopt) return node; + auto loop_sink_transform = InsertSubstituteReduceAxis( + loop_sink_transform_opt.value(), node->loop_axis_mapping()); + VLOG(4) << "Start AnchorFusionOperation"; + VLOG(4) << "Upstream: \n" << node->DebugStr(); + VLOG(4) << "Downstream: \n" << downstream->DebugStr(); + const auto merge_pattern_fn = + [](const StmtPattern& upstream, + const StmtPattern& downstream) -> StmtPattern { + return AnchorPattern( + UniqueConcatVector(GetOpsInPattern(upstream), + GetOpsInPattern(downstream)), + std::make_shared(GetFusionTracker(upstream), + GetFusionTracker(downstream)), + LoopAxisMappingMerge(GetPatternLoopAxisMapping(upstream), + GetPatternLoopAxisMapping(downstream), + false)); + }; + auto merged_node = graph->MergeNode(node, downstream, merge_pattern_fn); + // Update tracker + auto node_tmp_id = GetNewTmpId(node->id()); + merged_node->AppendInstr(std::make_shared( + node->id(), node_tmp_id, loop_sink_transform)); + merged_node->AppendInstr(std::make_shared( + std::vector{node_tmp_id, downstream->id()}, + merged_node->id())); graph->RemoveNode(downstream); graph->RemoveNode(node); - VLOG(4) << "MergeReduceTreeOperation: \nupstream " << node->DebugStr() - << "\ndownstream " << downstream->DebugStr() << "\nmerged " - << merged_node->DebugStr(); + VLOG(4) << "Merged: \n" << merged_node->DebugStr(); merged_node->UpdateTracker(); return merged_node; } }; -struct MergeReduceTreeAndTrivialOperation { +struct ReducePlusTrivialFusionOperation { PatternNodePtr operator()(PatternGraph* graph, PatternNodePtr node) { PADDLE_ENFORCE_EQ( node->downstream().size(), @@ -89,25 +114,42 @@ struct MergeReduceTreeAndTrivialOperation { "The downstream of the ReduceTree node should be 1, but got %d.", node->downstream().size())); auto downstream = node->downstream().at(0); - VLOG(4) << "MergeReduceTreeAndTrivialOperation: \nupstream " - << node->DebugStr() << "\ndownstream " << downstream->DebugStr(); - auto fake_reduce_iter_idx = graph->policy_manager() - .template GetPolicy() - ->GetFakeReduceIterIdx(node, downstream); - const auto merge_pattern_fn = [&fake_reduce_iter_idx]( - const StmtPattern& first, - const StmtPattern& secend) { - auto rt_pattern = - std::get(MergePattern(first, secend)); - rt_pattern.fake_reduce_iter_idx = fake_reduce_iter_idx; - return rt_pattern; + auto loop_transform_pair = GetReducePlusTrivialLoopTransform( + node->loop_axis_mapping(), downstream->loop_axis_mapping()); + if (loop_transform_pair == std::nullopt) return node; + VLOG(4) << "Start AnchorFusionOperation"; + VLOG(4) << "Upstream: \n" << node->DebugStr(); + VLOG(4) << "Downstream: \n" << downstream->DebugStr(); + auto upstream_loop_transform = InsertSubstituteReduceAxis( + loop_transform_pair.value().first, node->loop_axis_mapping()); + auto downstream_loop_transform = loop_transform_pair.value().second; + const auto merge_pattern_fn = + [&](const StmtPattern& upstream, + const StmtPattern& downstream) -> StmtPattern { + return AnchorPattern( + UniqueConcatVector(GetOpsInPattern(upstream), + GetOpsInPattern(downstream)), + std::make_shared(GetFusionTracker(upstream), + GetFusionTracker(downstream)), + ReducePlusTrivialLoopAxisMappingMerge( + GetPatternLoopAxisMapping(upstream), + GetPatternLoopAxisMapping(downstream), + downstream_loop_transform)); }; - PatternNodePtr merged_node = - graph->MergeNode(node, downstream, merge_pattern_fn); - + auto merged_node = graph->MergeNode(node, downstream, merge_pattern_fn); + // Update tracker + auto node_tmp_id = GetNewTmpId(node->id()); + auto downstream_tmp_id = GetNewTmpId(downstream->id()); + merged_node->AppendInstr(std::make_shared( + node->id(), node_tmp_id, upstream_loop_transform)); + merged_node->AppendInstr(std::make_shared( + downstream->id(), downstream_tmp_id, downstream_loop_transform)); + merged_node->AppendInstr(std::make_shared( + std::vector{node_tmp_id, downstream_tmp_id}, + merged_node->id())); graph->RemoveNode(downstream); graph->RemoveNode(node); - VLOG(4) << "merged " << merged_node->DebugStr(); + VLOG(4) << "Merged: \n" << merged_node->DebugStr(); merged_node->UpdateTracker(); return merged_node; } @@ -202,6 +244,7 @@ struct AnchorFusionOperation { return merged_node; } }; + struct SplitRecomputeOperation { void operator()(PatternGraph* graph, PatternNodePtr upstream) { auto origin_name = upstream->id(); diff --git a/paddle/cinn/operator_fusion/pattern_fuser.h b/paddle/cinn/operator_fusion/pattern_fuser.h index b32d96a76b58ac..1ad048274ea547 100644 --- a/paddle/cinn/operator_fusion/pattern_fuser.h +++ b/paddle/cinn/operator_fusion/pattern_fuser.h @@ -142,68 +142,11 @@ static StmtPattern MergePatternImpl(const TrivialPattern& first, second.loop_axis_mapping())); } -// RR & RT - -static int InsertUpstreamIntoTree(const ReduceTreePattern& upstream, - ReduceTreePattern& downstream) { // NOLINT - auto is_direct_upstream = [&](const ReducePattern& upstream, - const ReducePattern& downstream) -> bool { - auto upstream_result = upstream.GetReduceOp()->result(0); - auto user_ops = FindUserOp(downstream.ops(), upstream_result); - return !user_ops.empty(); - }; - - if (is_direct_upstream(upstream.GetRootPattern(), - downstream.GetRootPattern())) { - downstream.InsertChild(upstream); - return 1; - } - int insert_num = 0; - for (auto& child : downstream.children()) { - insert_num += InsertUpstreamIntoTree(upstream, child); - } - return insert_num; -} - -static StmtPattern MergePatternImpl(const ReduceTreePattern& upstream, - const ReduceTreePattern& downstream) { - ReduceTreePattern result = ReduceTreePattern( - downstream.children(), - downstream.GetRootPattern(), - std::make_shared(upstream.tracker_, - downstream.tracker_)); // copy first. - int insert_num = InsertUpstreamIntoTree(upstream, result); - result.set_loop_axis_mapping(LoopAxisMappingMerge( - upstream.loop_axis_mapping(), downstream.loop_axis_mapping(), false)); - PADDLE_ENFORCE_EQ(insert_num, - 1, - ::common::errors::PreconditionNotMet( - "Must insert only once, but insert %d", insert_num)); - return result; -} - -static StmtPattern MergePatternImpl(const ReduceTreePattern& first, - const TrivialPattern& second) { - auto result = ReduceTreePlusTrivialPattern( - first, - second, - std::make_shared(first.tracker_, second.tracker_)); - result.set_loop_axis_mapping(ReducePlusTrivialLoopAxisMappingMerge( - first.loop_axis_mapping(), second.loop_axis_mapping())); - return result; -} - static StmtPattern MergePattern(const StmtPattern& first, const StmtPattern& second) { VLOG(4) << "MergePattern: " << GetPatternId(first) << " x " << GetPatternId(second); const auto PatternMatch = adt::match{ - [&](const ReduceTreePattern& lhs, const ReduceTreePattern& rhs) { - return MergePatternImpl(lhs, rhs); - }, - [&](const ReduceTreePattern& lhs, const TrivialPattern& rhs) { - return MergePatternImpl(lhs, rhs); - }, [&](const TrivialPattern& lhs, const ReducePattern& rhs) { return MergePatternImpl(lhs, rhs); }, diff --git a/paddle/cinn/operator_fusion/pattern_graph.cc b/paddle/cinn/operator_fusion/pattern_graph.cc index daa2e9770d26eb..c617e4de940301 100644 --- a/paddle/cinn/operator_fusion/pattern_graph.cc +++ b/paddle/cinn/operator_fusion/pattern_graph.cc @@ -31,12 +31,6 @@ std::vector PatternGraph::ClusterOps() { VLOG(4) << "[Group Cluster] After SinkTrivialPattern: "; PrintGraphInfo(); - // ReducePattern -> ReduceTreePattern - VLOG(4) << "[Group Cluster] Start ReduceLiftReduceTree"; - ReduceLiftReduceTree(); - VLOG(4) << "[Group Cluster] After ReduceLiftReduceTree: "; - PrintGraphInfo(); - // ReduceTreePattern + ReduceTreePattern fusion VLOG(4) << "[Group Cluster] Start ReduceTreeGrown"; ReduceTreeGrown(); @@ -153,24 +147,20 @@ void PatternGraph::SinkTrivialPattern() { MergeTrivialPatternOperation>(this); } -void PatternGraph::ReduceLiftReduceTree() { - GraphTransformer< - NodePattern, - And, StmtPatternGraphMatcher>, - LiftReduceToReduceTreeOperation>(this); -} - void PatternGraph::ReduceTreeGrown() { GraphTransformer>, - MergeReduceTreeOperation>(this); + StmtPatternGraphMatcher, + LiftToAnchorPatternOperation>(this); + + GraphTransformer(this); } void PatternGraph::ReduceTree_Trivial_Fusion() { - GraphTransformer< - NodePattern, - And>, - MergeReduceTreeAndTrivialOperation>(this); + GraphTransformer(this); } void PatternGraph::AnchorFusion() { diff --git a/paddle/cinn/operator_fusion/pir_graph_analyzing/loop_transform_analysis.cc b/paddle/cinn/operator_fusion/pir_graph_analyzing/loop_transform_analysis.cc index 7b368d9d6df19e..c397177d530527 100644 --- a/paddle/cinn/operator_fusion/pir_graph_analyzing/loop_transform_analysis.cc +++ b/paddle/cinn/operator_fusion/pir_graph_analyzing/loop_transform_analysis.cc @@ -200,6 +200,20 @@ LoopAxisMapping TrivialSinkLoopAxisMappingMerge( return result; } +std::vector GetTargetRelatedAxisIdx( + const std::vector& source, + const AxisTransformRoute& route) { + AxisTransformSimulator simulator(route, source); + auto source_related_ids = simulator.GetRelatedAxisIds(simulator.source_ids_); + std::vector result; + for (int i = 0; i < simulator.target_ids_.size(); ++i) { + if (source_related_ids.count(simulator.target_ids_[i])) { + result.push_back(i); + } + } + return result; +} + std::vector GetFakeReduceAxisIdx(const std::vector& loop, const AxisTransformRoute& route, int reduce_axis_num) { @@ -232,110 +246,55 @@ std::vector GetFakeReduceAxisIdx(const std::vector& loop, return fake_reduce_idx; } +bool CanFuseReducePlusReduce(const LoopAxisMapping& upstream, + const LoopAxisMapping& downstream) { + auto upstream_reduce_loop = + SliceVector(upstream.loop, + upstream.loop.size() - upstream.reduce_axis_num, + upstream.loop.size()); + auto downstream_reduce_loop = + SliceVector(downstream.loop, + downstream.loop.size() - downstream.reduce_axis_num, + downstream.loop.size()); + if (upstream.reduce_axis_num != downstream.reduce_axis_num || + upstream_reduce_loop != downstream_reduce_loop) { + return false; + } + auto loop_sink_route = GetLoopSinkRoute(upstream, downstream); + if (HasUnsupportedTransform(loop_sink_route)) return false; + auto downstream_related_axis = + GetTargetRelatedAxisIdx(upstream.loop, loop_sink_route); + for (auto idx : downstream_related_axis) { + if (idx >= downstream.loop.size() - downstream.reduce_axis_num) { + return false; + } + } + return true; +} + LoopAxisMapping ReducePlusTrivialLoopAxisMappingMerge( - const LoopAxisMapping& upstream, const LoopAxisMapping& downstream) { - // Signal downstream reduce plus trivial fusion loop is downstream trivial - // loop plus upstream reduce loop. + const LoopAxisMapping& upstream, + const LoopAxisMapping& downstream, + const AxisTransformRoute& downstream_loop_transform) { PADDLE_ENFORCE( upstream.reduce_axis_num > 0 && downstream.reduce_axis_num == 0, ::common::errors::InvalidArgument( "Upstream should be reduce pattern and " "downstream should be trivial pattern.")); - auto loop_sink_route = GetLoopSinkRoute(upstream, downstream); - if (HasUnsupportedTransform(loop_sink_route)) { - // TODO(huangjiyi): fix unsupported transform in RT fusion - auto result = LoopAxisMappingMergeImpl(upstream, downstream, false); - result.DisableLoopAxisMapping(); - return result; + auto result = LoopAxisMappingMergeImpl(upstream, downstream, false); + AxisTransformSimulator simulator(downstream_loop_transform, downstream.loop); + result.loop = simulator.out_shape_; + for (auto& route : result.input2loop) { + route.insert(route.end(), + downstream_loop_transform.begin(), + downstream_loop_transform.end()); } - auto reduce_axis_num = upstream.reduce_axis_num; - auto reduce_axis = ArangeVector( - upstream.loop.size() - reduce_axis_num, upstream.loop.size()); - auto reduce_loop = SliceVector(upstream.loop, - upstream.loop.size() - reduce_axis_num, - upstream.loop.size()); - // Check whether downstream trivial can reuse upstream reduce axis. - auto fake_reduce_idx = - GetFakeReduceAxisIdx(upstream.loop, loop_sink_route, reduce_axis_num); - VLOG(4) << "fake_reduce_idx: " << cinn::utils::Join(fake_reduce_idx, ","); - LoopAxisMapping result; - if (fake_reduce_idx.empty()) { - AxisTransform append_reduce_axis = - std::make_shared(reduce_axis, reduce_loop); - auto upstream_copy = upstream; - for (auto& route : upstream_copy.input2loop) { - route.push_back(append_reduce_axis); - } - upstream_copy.loop.insert( - upstream_copy.loop.end(), reduce_loop.begin(), reduce_loop.end()); - result = LoopAxisMappingMergeImpl(upstream_copy, downstream, false); - result.loop = ConcatVector(downstream.loop, reduce_loop); - AxisTransform delete_reduce_axis = std::make_shared( - ArangeVector(downstream.loop.size(), result.loop.size()), - reduce_loop); - for (auto& route : result.loop2output) { - route.insert(route.begin(), delete_reduce_axis); - } - auto fake_reduce_idx = ArangeVector( - downstream.loop.size(), downstream.loop.size() + reduce_axis_num); - AxisTransform append_fake_reduce_idx = - std::make_shared(fake_reduce_idx, reduce_loop); - for (int i = upstream.input2loop.size(); i < result.input2loop.size(); - ++i) { - result.input2loop[i].push_back(append_fake_reduce_idx); - } - } else { - // Transpose fake reduce axis to the end - auto perm = ArangeVector(0, downstream.loop.size()); - for (auto index : fake_reduce_idx) { - perm.push_back(index); - } - std::sort(fake_reduce_idx.begin(), fake_reduce_idx.end()); - std::reverse(fake_reduce_idx.begin(), fake_reduce_idx.end()); - for (auto index : fake_reduce_idx) { - perm.erase(perm.begin() + index); - } - result = LoopAxisMappingMergeImpl(upstream, downstream, false); - AxisTransformRoute fake_reduce_axis_transforms; - if (perm != ArangeVector(0, downstream.loop.size())) { - result.loop = TransposeVector(result.loop, perm); - auto transpose_trans = std::make_shared(perm); - fake_reduce_axis_transforms.push_back(transpose_trans); - } - // Check whether fake reduce axis reuse all reduce axis - if (fake_reduce_idx.size() < reduce_axis_num) { - std::vector one_reduce_axis; - for (int i = 0; i < reduce_loop.size(); ++i) { - bool has_reuse = false; - for (const auto& downstream_idx : fake_reduce_idx) { - if (reduce_loop[i] == downstream.loop[downstream_idx]) { - has_reuse = true; - break; - } - } - if (!has_reuse) { - PADDLE_ENFORCE_EQ(reduce_loop[i], - symbol::DimExpr(1), - ::common::errors::PreconditionNotMet( - "Reduce axis not been reused must be 1.")); - one_reduce_axis.push_back(downstream.loop.size() - - fake_reduce_idx.size() + i); - } - } - auto append_one_reduce_axis = - std::make_shared(one_reduce_axis); - fake_reduce_axis_transforms.push_back(append_one_reduce_axis); - } - for (auto& route : result.input2loop) { - route.insert(route.end(), - fake_reduce_axis_transforms.begin(), - fake_reduce_axis_transforms.end()); - } - for (auto& route : result.loop2output) { - route.insert(route.begin(), - fake_reduce_axis_transforms.begin(), - fake_reduce_axis_transforms.end()); - } + auto reverse_loop_transform = + ReverseTransformRoute(downstream_loop_transform); + for (auto& route : result.loop2output) { + route.insert(route.begin(), + reverse_loop_transform.begin(), + reverse_loop_transform.end()); } result.SimplifyForwardMapping(); result.SetReverseMapping(); @@ -385,6 +344,114 @@ LoopAxisMapping HorizontalLoopAxisMappingMerge(const LoopAxisMapping& source, return result; } +AxisTransformRoute InsertSubstituteReduceAxis(const AxisTransformRoute& route, + const LoopAxisMapping& source) { + auto result = route; + // Because reduce axis can not be transformed, we need to add + // same fake axis to substitute reduce axis for transformation. + std::vector append_reduce_axis = ArangeVector( + source.loop.size() - source.reduce_axis_num, source.loop.size()); + std::vector append_reduce_shape = + GatherVector(source.loop, append_reduce_axis); + result.insert(result.begin(), + std::make_shared(append_reduce_axis, + append_reduce_shape)); + // Remove substitute reduce axis. + AxisTransformSimulator simulator(result, source.loop); + std::vector delete_reduce_axis = ArangeVector( + simulator.target_ids_.size() - source.reduce_axis_num * 2, + simulator.target_ids_.size() - source.reduce_axis_num); + PADDLE_ENFORCE_GE( + simulator.target_ids_.size(), + source.reduce_axis_num * 2, + ::common::errors::InvalidArgument("Reduce axis num is not enough.")); + auto delete_reduce_shape = + GatherVector(simulator.out_shape_, delete_reduce_axis); + PADDLE_ENFORCE(append_reduce_shape == delete_reduce_shape, + ::common::errors::InvalidArgument( + "Reduce axis shape is not equal after transform.")); + result.push_back(std::make_shared(delete_reduce_axis, + delete_reduce_shape)); + return result; +} + +std::optional> +GetReducePlusTrivialLoopTransform(const LoopAxisMapping& upstream, + const LoopAxisMapping& downstream) { + PADDLE_ENFORCE( + upstream.reduce_axis_num > 0 && downstream.reduce_axis_num == 0, + ::common::errors::InvalidArgument( + "Upstream should be reduce pattern and " + "downstream should be trivial pattern.")); + auto loop_sink_route = GetLoopSinkRoute(upstream, downstream); + if (HasUnsupportedTransform(loop_sink_route)) { + return std::nullopt; + } + auto reduce_axis_num = upstream.reduce_axis_num; + auto non_reduce_idx = + ArangeVector(0, upstream.loop.size() - reduce_axis_num); + auto non_reduce_loop = GatherVector(upstream.loop, non_reduce_idx); + auto reduce_loop = GatherVectorExcept(upstream.loop, non_reduce_idx); + auto fake_reduce_idx = + GetFakeReduceAxisIdx(upstream.loop, loop_sink_route, reduce_axis_num); + VLOG(4) << "fake_reduce_idx: " << cinn::utils::Join(fake_reduce_idx, ","); + + bool can_merge = [&]() -> bool { + auto downstream_non_fake_reduce_loop = + GatherVectorExcept(downstream.loop, fake_reduce_idx); + return ShapeProductSmallerOrEqual(downstream_non_fake_reduce_loop, + non_reduce_loop); + }(); + if (!can_merge) return std::nullopt; + + AxisTransformRoute downstream_loop_transform; + if (fake_reduce_idx.empty()) { + downstream_loop_transform.push_back(std::make_shared( + ArangeVector(downstream.loop.size(), + downstream.loop.size() + reduce_axis_num), + reduce_loop)); + } else { + // Transpose fake reduce axis to the end + auto perm = ArangeVector(0, downstream.loop.size()); + for (auto index : fake_reduce_idx) { + perm.push_back(index); + } + std::sort(fake_reduce_idx.begin(), fake_reduce_idx.end()); + std::reverse(fake_reduce_idx.begin(), fake_reduce_idx.end()); + for (auto index : fake_reduce_idx) { + perm.erase(perm.begin() + index); + } + downstream_loop_transform.push_back( + std::make_shared(perm)); + // Append non reused reduce axis + std::vector append_axis; + std::vector append_loop; + if (fake_reduce_idx.size() < reduce_axis_num) { + for (int i = 0; i < reduce_loop.size(); ++i) { + bool has_reuse = false; + for (const auto& downstream_idx : fake_reduce_idx) { + if (reduce_loop[i] == downstream.loop[downstream_idx]) { + has_reuse = true; + break; + } + } + if (!has_reuse) { + append_axis.push_back(downstream.loop.size() + i); + append_loop.push_back(reduce_loop[i]); + } + } + } + downstream_loop_transform.push_back( + std::make_shared(append_axis, append_loop)); + } + auto upstream_loop_transform = + ConcatVector(loop_sink_route, downstream_loop_transform); + + return std::make_pair( + SimplifyTransformRoute(upstream_loop_transform, upstream.loop), + SimplifyTransformRoute(downstream_loop_transform, downstream.loop)); +} + std::optional GetValidLoopTransformRoute( const LoopAxisMapping& source, const LoopAxisMapping& target, diff --git a/paddle/cinn/operator_fusion/pir_graph_analyzing/loop_transform_analysis.h b/paddle/cinn/operator_fusion/pir_graph_analyzing/loop_transform_analysis.h index b352a42d847131..202914a0cb4351 100644 --- a/paddle/cinn/operator_fusion/pir_graph_analyzing/loop_transform_analysis.h +++ b/paddle/cinn/operator_fusion/pir_graph_analyzing/loop_transform_analysis.h @@ -24,9 +24,18 @@ LoopAxisMapping LoopAxisMappingMerge(const LoopAxisMapping& upstream, LoopAxisMapping TrivialSinkLoopAxisMappingMerge( const LoopAxisMapping& upstream, const LoopAxisMapping& downstream); LoopAxisMapping ReducePlusTrivialLoopAxisMappingMerge( - const LoopAxisMapping& upstream, const LoopAxisMapping& downstream); + const LoopAxisMapping& upstream, + const LoopAxisMapping& downstream, + const AxisTransformRoute& downstream_loop_transform); LoopAxisMapping HorizontalLoopAxisMappingMerge(const LoopAxisMapping& source, const LoopAxisMapping& target); +bool CanFuseReducePlusReduce(const LoopAxisMapping& upstream, + const LoopAxisMapping& downstream); +AxisTransformRoute InsertSubstituteReduceAxis(const AxisTransformRoute& route, + const LoopAxisMapping& source); +std::optional> +GetReducePlusTrivialLoopTransform(const LoopAxisMapping& upstream, + const LoopAxisMapping& downstream); // Try to find a valid axis transform route with specific direction between // upstream and downstream LoopAxisMapping. The following cases are considered diff --git a/paddle/cinn/operator_fusion/utils.cc b/paddle/cinn/operator_fusion/utils.cc index 643dac2f51e95f..9229bfcf06e5fb 100644 --- a/paddle/cinn/operator_fusion/utils.cc +++ b/paddle/cinn/operator_fusion/utils.cc @@ -288,6 +288,18 @@ bool ShapeProductEqual(const std::vector& in_shape, in_shape, out_shape, 0, in_shape.size(), 0, out_shape.size()); } +bool ShapeProductSmallerOrEqual(const std::vector& first, + const std::vector& second) { + if (first.empty()) return true; + const auto& first_product = GetShapeProduct(first); + const auto& second_product = GetShapeProduct(second); + if (first_product.isa() && second_product.isa()) { + return first_product.dyn_cast() <= + second_product.dyn_cast(); + } + return first_product == second_product; +} + std::vector> PartitionReshapeAxes( const std::vector& in_shape, const std::vector& out_shape) { diff --git a/paddle/cinn/operator_fusion/utils.h b/paddle/cinn/operator_fusion/utils.h index 78fdd0912e42e4..06f734c973d029 100644 --- a/paddle/cinn/operator_fusion/utils.h +++ b/paddle/cinn/operator_fusion/utils.h @@ -677,6 +677,9 @@ bool ShapeProductEqual(const std::vector& in_shape, bool ShapeProductEqual(const std::vector& in_shape, const std::vector& out_shape); +bool ShapeProductSmallerOrEqual(const std::vector& first, + const std::vector& second); + std::vector> PartitionReshapeAxes( const std::vector& in_shape, const std::vector& out_shape);