Skip to content

Commit 856f904

Browse files
committed
Improve names & comments, SmallVector for parallel op stack
Signed-off-by: Carl Pearson <cwpears@sandia.gov>
1 parent e5f777a commit 856f904

File tree

1 file changed

+48
-43
lines changed

1 file changed

+48
-43
lines changed

mlir/lib/Dialect/Kokkos/Transforms/KokkosMdrangeIterationPass.cpp

Lines changed: 48 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -147,6 +147,8 @@ struct KokkosMdrangeIterationPass
147147

148148
// a context for expression evaluation
149149
struct Ctx {
150+
151+
// FIXME: llvm data structures
150152
std::unordered_map<std::string, int> values;
151153
};
152154

@@ -414,8 +416,7 @@ struct KokkosMdrangeIterationPass
414416
}
415417
};
416418

417-
// partial derivative df/dx
418-
419+
// partial derivative df/dx
419420
static std::shared_ptr<Expr> df_dx(Value &f, Value &x) {
420421
if (f == x) {
421422
llvm::outs() << "Info: df_dx of equal values\n";
@@ -585,21 +586,10 @@ static std::shared_ptr<Expr> df_dx(Value &f, Value &x) {
585586
}
586587

587588

588-
// return groups of induction variables for
589-
static std::vector<Value> all_induction_variables(std::vector<scf::ParallelOp> &ops) {
590-
std::vector<Value> vars;
591-
for (auto &op : ops) {
592-
for (auto &var : op.getInductionVars()) {
593-
vars.push_back(var);
594-
}
595-
}
596-
return vars;
597-
}
598589

599-
// FIXME: this returns things like this:
600-
// <block argument> of type 'index' at index: 2
601-
static std::string get_value_name(mlir::Value &value) {
602590

591+
// Get a unique name for the provided value
592+
static std::string get_value_name(mlir::Value &value) {
603593
if (mlir::isa<BlockArgument>(value)) {
604594
auto ba = mlir::cast<BlockArgument>(value);
605595
return std::string("block") +std::to_string(uintptr_t(ba.getOwner())) + "_arg" + std::to_string(ba.getArgNumber());
@@ -608,8 +598,9 @@ static std::string get_value_name(mlir::Value &value) {
608598
}
609599
}
610600

611-
612-
static std::shared_ptr<Expr> iteration_space_size(scf::ParallelOp &op, int dim) {
601+
// Get an expression representing the size of the iteration space of `op` in the
602+
// `dim` dimension.
603+
static std::shared_ptr<Expr> iteration_space_expr(scf::ParallelOp &op, int dim) {
613604

614605
auto lb = op.getLowerBound()[dim];
615606
auto ub = op.getUpperBound()[dim];
@@ -643,62 +634,74 @@ static std::shared_ptr<Expr> iteration_space_size(scf::ParallelOp &op, int dim)
643634
return Div::make(num, stExpr);
644635
}
645636

646-
// return an Expr representing the product of the iteration space of all dimensions
647-
static std::shared_ptr<Expr> trip_count_expr(scf::ParallelOp &op) {
637+
// Get an expression representing the size of the iteration space of `op`
638+
static std::shared_ptr<Expr> iteration_space_expr(scf::ParallelOp &op) {
648639
auto lowerBounds = op.getLowerBound();
649-
std::shared_ptr<Expr> total = iteration_space_size(op, 0);
640+
std::shared_ptr<Expr> total = iteration_space_expr(op, 0);
650641
for (unsigned i = 1; i < lowerBounds.size(); ++i) {
651-
total = Mul::make(total, iteration_space_size(op, i));
642+
total = Mul::make(total, iteration_space_expr(op, i));
652643
}
653644
return total;
654645
}
655646

656-
using ParallelTripCounts = llvm::DenseMap<scf::ParallelOp, std::shared_ptr<Expr>>;
647+
using IterationSpaceExprs = llvm::DenseMap<scf::ParallelOp, std::shared_ptr<Expr>>;
657648

658-
static ParallelTripCounts build_parallel_trip_counts(ModuleOp &mod) {
649+
// Get expressions represeting the iteration space for all parallel loops in the module
650+
static IterationSpaceExprs build_parallel_trip_counts(ModuleOp &mod) {
659651

660-
ParallelTripCounts PTC;
652+
IterationSpaceExprs ISE;
661653

662654
mod.walk([&](Operation *op) {
663655
if (auto parallelOp = dyn_cast<scf::ParallelOp>(op)) {
664-
665656
// create an expression representing the trip count for this loop
666-
std::shared_ptr<Expr> count = trip_count_expr(parallelOp);
667-
PTC[parallelOp] = count;
657+
std::shared_ptr<Expr> expr = iteration_space_expr(parallelOp);
658+
ISE[parallelOp] = expr;
668659

669660
// descend into the body of the loop
670-
ParallelTripCounts counts = build_parallel_trip_counts(parallelOp, count);
661+
IterationSpaceExprs exprs = build_parallel_trip_counts(parallelOp, expr);
662+
ISE.insert(exprs.begin(), exprs.end());
671663
}
672664
}); // walk
673665

674666

675-
return PTC;
667+
return ISE;
676668
}
677669

678-
static ParallelTripCounts build_parallel_trip_counts(scf::ParallelOp &parentOp, std::shared_ptr<Expr> cost) {
679-
ParallelTripCounts PTC;
670+
static IterationSpaceExprs build_parallel_trip_counts(scf::ParallelOp &parentOp, std::shared_ptr<Expr> cost) {
671+
IterationSpaceExprs ISE;
680672

681673
parentOp.getBody()->walk([&](Operation *op) {
682674
if (auto parallelOp = dyn_cast<scf::ParallelOp>(op)) {
683-
684675
// create an expression representing the trip count for this loop
685-
std::shared_ptr<Expr> count = trip_count_expr(parallelOp);
686-
count = Mul::make(count, cost->clone());
687-
PTC[parallelOp] = count;
676+
std::shared_ptr<Expr> expr = iteration_space_expr(parallelOp);
677+
ISE[parallelOp] = expr;
688678

689679
// descend into the body of the loop
690-
ParallelTripCounts counts = build_parallel_trip_counts(parallelOp, count);
680+
IterationSpaceExprs exprs = build_parallel_trip_counts(parallelOp, expr);
681+
ISE.insert(exprs.begin(), exprs.end());
691682
}
692683
});
693684

694-
return PTC;
685+
return ISE;
695686
}
696687

697688
// map of (Operation*, Value) -> Cost
698689
// map of the cost model for a given memref / induction variable pair
699690
using MemrefInductionCosts = llvm::DenseMap<std::pair<Operation*, mlir::Value>, Cost>;
691+
using ParallelOpStack = llvm::SmallVector<scf::ParallelOp, 4>;
700692

701-
static MemrefInductionCosts build_cost_table(ModuleOp &mod, ParallelTripCounts &tripCounts, std::vector<scf::ParallelOp> &stack) {
693+
// return all induction variables for all parallel ops
694+
static std::vector<Value> all_induction_variables(ParallelOpStack &ops) {
695+
std::vector<Value> vars;
696+
for (auto &op : ops) {
697+
for (auto &var : op.getInductionVars()) {
698+
vars.push_back(var);
699+
}
700+
}
701+
return vars;
702+
}
703+
704+
static MemrefInductionCosts build_cost_table(ModuleOp &mod, IterationSpaceExprs &tripCounts, ParallelOpStack &stack) {
702705

703706
MemrefInductionCosts MIC;
704707

@@ -715,13 +718,15 @@ static MemrefInductionCosts build_cost_table(ModuleOp &mod, ParallelTripCounts &
715718
return MIC;
716719
}
717720

718-
static MemrefInductionCosts build_cost_table(ModuleOp &mod, ParallelTripCounts &tripCounts) {
719-
std::vector<scf::ParallelOp> stack;
721+
722+
723+
static MemrefInductionCosts build_cost_table(ModuleOp &mod, IterationSpaceExprs &tripCounts) {
724+
ParallelOpStack stack;
720725
return build_cost_table(mod, tripCounts, stack);
721726
}
722727

723728
template <typename Memref>
724-
static MemrefInductionCosts get_costs(Memref &memrefOp, ParallelTripCounts &tripCounts, std::vector<scf::ParallelOp> &stack) {
729+
static MemrefInductionCosts get_costs(Memref &memrefOp, IterationSpaceExprs &tripCounts, ParallelOpStack &stack) {
725730
static_assert(std::is_same_v<Memref, memref::LoadOp> || std::is_same_v<Memref, memref::StoreOp>);
726731

727732
if constexpr (std::is_same_v<Memref, memref::LoadOp>) {
@@ -790,7 +795,7 @@ static MemrefInductionCosts get_costs(Memref &memrefOp, ParallelTripCounts &trip
790795
}
791796

792797
// FIXME: parentOp is also the back of the stack?
793-
static MemrefInductionCosts build_cost_table(scf::ParallelOp &parentOp, ParallelTripCounts &tripCounts, std::vector<scf::ParallelOp> &stack) {
798+
static MemrefInductionCosts build_cost_table(scf::ParallelOp &parentOp, IterationSpaceExprs &tripCounts, ParallelOpStack &stack) {
794799
MemrefInductionCosts MIC;
795800

796801
parentOp.getBody()->walk([&](Operation *op) {
@@ -965,7 +970,7 @@ static MemrefInductionCosts build_cost_table(scf::ParallelOp &parentOp, Parallel
965970
dump_ops(module);
966971

967972
llvm::outs() << "====\nbuild_parallel_trip_counts\n====\n";
968-
ParallelTripCounts tripCounts = build_parallel_trip_counts(module);
973+
IterationSpaceExprs tripCounts = build_parallel_trip_counts(module);
969974

970975
for (auto &kv : tripCounts) {
971976
const std::shared_ptr<Expr> &trip = kv.second;

0 commit comments

Comments
 (0)