@@ -147,6 +147,8 @@ struct KokkosMdrangeIterationPass
147
147
148
148
// a context for expression evaluation
149
149
struct Ctx {
150
+
151
+ // FIXME: llvm data structures
150
152
std::unordered_map<std::string, int > values;
151
153
};
152
154
@@ -414,8 +416,7 @@ struct KokkosMdrangeIterationPass
414
416
}
415
417
};
416
418
417
- // partial derivative df/dx
418
-
419
+ // partial derivative df/dx
419
420
static std::shared_ptr<Expr> df_dx (Value &f, Value &x) {
420
421
if (f == x) {
421
422
llvm::outs () << " Info: df_dx of equal values\n " ;
@@ -585,21 +586,10 @@ static std::shared_ptr<Expr> df_dx(Value &f, Value &x) {
585
586
}
586
587
587
588
588
- // return groups of induction variables for
589
- static std::vector<Value> all_induction_variables (std::vector<scf::ParallelOp> &ops) {
590
- std::vector<Value> vars;
591
- for (auto &op : ops) {
592
- for (auto &var : op.getInductionVars ()) {
593
- vars.push_back (var);
594
- }
595
- }
596
- return vars;
597
- }
598
589
599
- // FIXME: this returns things like this:
600
- // <block argument> of type 'index' at index: 2
601
- static std::string get_value_name (mlir::Value &value) {
602
590
591
+ // Get a unique name for the provided value
592
+ static std::string get_value_name (mlir::Value &value) {
603
593
if (mlir::isa<BlockArgument>(value)) {
604
594
auto ba = mlir::cast<BlockArgument>(value);
605
595
return std::string (" block" ) +std::to_string (uintptr_t (ba.getOwner ())) + " _arg" + std::to_string (ba.getArgNumber ());
@@ -608,8 +598,9 @@ static std::string get_value_name(mlir::Value &value) {
608
598
}
609
599
}
610
600
611
-
612
- static std::shared_ptr<Expr> iteration_space_size (scf::ParallelOp &op, int dim) {
601
+ // Get an expression representing the size of the iteration space of `op` in the
602
+ // `dim` dimension.
603
+ static std::shared_ptr<Expr> iteration_space_expr (scf::ParallelOp &op, int dim) {
613
604
614
605
auto lb = op.getLowerBound ()[dim];
615
606
auto ub = op.getUpperBound ()[dim];
@@ -643,62 +634,74 @@ static std::shared_ptr<Expr> iteration_space_size(scf::ParallelOp &op, int dim)
643
634
return Div::make (num, stExpr);
644
635
}
645
636
646
- // return an Expr representing the product of the iteration space of all dimensions
647
- static std::shared_ptr<Expr> trip_count_expr (scf::ParallelOp &op) {
637
+ // Get an expression representing the size of the iteration space of `op`
638
+ static std::shared_ptr<Expr> iteration_space_expr (scf::ParallelOp &op) {
648
639
auto lowerBounds = op.getLowerBound ();
649
- std::shared_ptr<Expr> total = iteration_space_size (op, 0 );
640
+ std::shared_ptr<Expr> total = iteration_space_expr (op, 0 );
650
641
for (unsigned i = 1 ; i < lowerBounds.size (); ++i) {
651
- total = Mul::make (total, iteration_space_size (op, i));
642
+ total = Mul::make (total, iteration_space_expr (op, i));
652
643
}
653
644
return total;
654
645
}
655
646
656
- using ParallelTripCounts = llvm::DenseMap<scf::ParallelOp, std::shared_ptr<Expr>>;
647
+ using IterationSpaceExprs = llvm::DenseMap<scf::ParallelOp, std::shared_ptr<Expr>>;
657
648
658
- static ParallelTripCounts build_parallel_trip_counts (ModuleOp &mod) {
649
+ // Get expressions represeting the iteration space for all parallel loops in the module
650
+ static IterationSpaceExprs build_parallel_trip_counts (ModuleOp &mod) {
659
651
660
- ParallelTripCounts PTC ;
652
+ IterationSpaceExprs ISE ;
661
653
662
654
mod.walk ([&](Operation *op) {
663
655
if (auto parallelOp = dyn_cast<scf::ParallelOp>(op)) {
664
-
665
656
// create an expression representing the trip count for this loop
666
- std::shared_ptr<Expr> count = trip_count_expr (parallelOp);
667
- PTC [parallelOp] = count ;
657
+ std::shared_ptr<Expr> expr = iteration_space_expr (parallelOp);
658
+ ISE [parallelOp] = expr ;
668
659
669
660
// descend into the body of the loop
670
- ParallelTripCounts counts = build_parallel_trip_counts (parallelOp, count);
661
+ IterationSpaceExprs exprs = build_parallel_trip_counts (parallelOp, expr);
662
+ ISE.insert (exprs.begin (), exprs.end ());
671
663
}
672
664
}); // walk
673
665
674
666
675
- return PTC ;
667
+ return ISE ;
676
668
}
677
669
678
- static ParallelTripCounts build_parallel_trip_counts (scf::ParallelOp &parentOp, std::shared_ptr<Expr> cost) {
679
- ParallelTripCounts PTC ;
670
+ static IterationSpaceExprs build_parallel_trip_counts (scf::ParallelOp &parentOp, std::shared_ptr<Expr> cost) {
671
+ IterationSpaceExprs ISE ;
680
672
681
673
parentOp.getBody ()->walk ([&](Operation *op) {
682
674
if (auto parallelOp = dyn_cast<scf::ParallelOp>(op)) {
683
-
684
675
// create an expression representing the trip count for this loop
685
- std::shared_ptr<Expr> count = trip_count_expr (parallelOp);
686
- count = Mul::make (count, cost->clone ());
687
- PTC[parallelOp] = count;
676
+ std::shared_ptr<Expr> expr = iteration_space_expr (parallelOp);
677
+ ISE[parallelOp] = expr;
688
678
689
679
// descend into the body of the loop
690
- ParallelTripCounts counts = build_parallel_trip_counts (parallelOp, count);
680
+ IterationSpaceExprs exprs = build_parallel_trip_counts (parallelOp, expr);
681
+ ISE.insert (exprs.begin (), exprs.end ());
691
682
}
692
683
});
693
684
694
- return PTC ;
685
+ return ISE ;
695
686
}
696
687
697
688
// map of (Operation*, Value) -> Cost
698
689
// map of the cost model for a given memref / induction variable pair
699
690
using MemrefInductionCosts = llvm::DenseMap<std::pair<Operation*, mlir::Value>, Cost>;
691
+ using ParallelOpStack = llvm::SmallVector<scf::ParallelOp, 4 >;
700
692
701
- static MemrefInductionCosts build_cost_table (ModuleOp &mod, ParallelTripCounts &tripCounts, std::vector<scf::ParallelOp> &stack) {
693
+ // return all induction variables for all parallel ops
694
+ static std::vector<Value> all_induction_variables (ParallelOpStack &ops) {
695
+ std::vector<Value> vars;
696
+ for (auto &op : ops) {
697
+ for (auto &var : op.getInductionVars ()) {
698
+ vars.push_back (var);
699
+ }
700
+ }
701
+ return vars;
702
+ }
703
+
704
+ static MemrefInductionCosts build_cost_table (ModuleOp &mod, IterationSpaceExprs &tripCounts, ParallelOpStack &stack) {
702
705
703
706
MemrefInductionCosts MIC;
704
707
@@ -715,13 +718,15 @@ static MemrefInductionCosts build_cost_table(ModuleOp &mod, ParallelTripCounts &
715
718
return MIC;
716
719
}
717
720
718
- static MemrefInductionCosts build_cost_table (ModuleOp &mod, ParallelTripCounts &tripCounts) {
719
- std::vector<scf::ParallelOp> stack;
721
+
722
+
723
+ static MemrefInductionCosts build_cost_table (ModuleOp &mod, IterationSpaceExprs &tripCounts) {
724
+ ParallelOpStack stack;
720
725
return build_cost_table (mod, tripCounts, stack);
721
726
}
722
727
723
728
template <typename Memref>
724
- static MemrefInductionCosts get_costs (Memref &memrefOp, ParallelTripCounts &tripCounts, std::vector<scf::ParallelOp> &stack) {
729
+ static MemrefInductionCosts get_costs (Memref &memrefOp, IterationSpaceExprs &tripCounts, ParallelOpStack &stack) {
725
730
static_assert (std::is_same_v<Memref, memref::LoadOp> || std::is_same_v<Memref, memref::StoreOp>);
726
731
727
732
if constexpr (std::is_same_v<Memref, memref::LoadOp>) {
@@ -790,7 +795,7 @@ static MemrefInductionCosts get_costs(Memref &memrefOp, ParallelTripCounts &trip
790
795
}
791
796
792
797
// FIXME: parentOp is also the back of the stack?
793
- static MemrefInductionCosts build_cost_table (scf::ParallelOp &parentOp, ParallelTripCounts &tripCounts, std::vector<scf::ParallelOp> &stack) {
798
+ static MemrefInductionCosts build_cost_table (scf::ParallelOp &parentOp, IterationSpaceExprs &tripCounts, ParallelOpStack &stack) {
794
799
MemrefInductionCosts MIC;
795
800
796
801
parentOp.getBody ()->walk ([&](Operation *op) {
@@ -965,7 +970,7 @@ static MemrefInductionCosts build_cost_table(scf::ParallelOp &parentOp, Parallel
965
970
dump_ops (module);
966
971
967
972
llvm::outs () << " ====\n build_parallel_trip_counts\n ====\n " ;
968
- ParallelTripCounts tripCounts = build_parallel_trip_counts (module);
973
+ IterationSpaceExprs tripCounts = build_parallel_trip_counts (module);
969
974
970
975
for (auto &kv : tripCounts) {
971
976
const std::shared_ptr<Expr> &trip = kv.second ;
0 commit comments