From a911162a39baa0b3534e56855627f694cbb98951 Mon Sep 17 00:00:00 2001 From: Carl Pearson Date: Thu, 19 Dec 2024 13:13:33 -0700 Subject: [PATCH 01/37] mdrange: Initial skeleton Signed-off-by: Carl Pearson --- .../lapis/Dialect/Kokkos/Transforms/Passes.h | 2 + .../lapis/Dialect/Kokkos/Transforms/Passes.td | 14 +++++ .../Dialect/Kokkos/Transforms/CMakeLists.txt | 1 + .../Transforms/KokkosMdrangeIterationPass.cpp | 55 +++++++++++++++++++ .../Kokkos/Transforms/KokkosPasses.cpp | 1 - 5 files changed, 72 insertions(+), 1 deletion(-) create mode 100644 mlir/lib/Dialect/Kokkos/Transforms/KokkosMdrangeIterationPass.cpp diff --git a/mlir/include/lapis/Dialect/Kokkos/Transforms/Passes.h b/mlir/include/lapis/Dialect/Kokkos/Transforms/Passes.h index 39c4865d..90d1da6d 100644 --- a/mlir/include/lapis/Dialect/Kokkos/Transforms/Passes.h +++ b/mlir/include/lapis/Dialect/Kokkos/Transforms/Passes.h @@ -31,6 +31,8 @@ std::unique_ptr createKokkosMemorySpaceAssignmentPass(); std::unique_ptr createKokkosDualViewManagementPass(); +std::unique_ptr createKokkosMdrangeIterationPass(); + //===----------------------------------------------------------------------===// // Registration. 
//===----------------------------------------------------------------------===// diff --git a/mlir/include/lapis/Dialect/Kokkos/Transforms/Passes.td b/mlir/include/lapis/Dialect/Kokkos/Transforms/Passes.td index f5eec1b8..f8252d67 100644 --- a/mlir/include/lapis/Dialect/Kokkos/Transforms/Passes.td +++ b/mlir/include/lapis/Dialect/Kokkos/Transforms/Passes.td @@ -56,4 +56,18 @@ def KokkosDualViewManagement : Pass<"kokkos-dualview-management", "ModuleOp"> { ]; } +def KokkosMdrangeIteration : Pass<"kokkos-mdrange-iteration", "ModuleOp"> { + let summary = "Rearange MDRange to improve memory access patterns on GPU"; + let description = [{ + }]; + let constructor = "mlir::createKokkosMdrangeIterationPass()"; + let dependentDialects = [ + "arith::ArithDialect", + "func::FuncDialect", + "kokkos::KokkosDialect", + "memref::MemRefDialect", + "scf::SCFDialect" + ]; +} + #endif // MLIR_DIALECT_KOKKOS_TRANSFORMS_PASSES diff --git a/mlir/lib/Dialect/Kokkos/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Kokkos/Transforms/CMakeLists.txt index 795e336a..96efb935 100644 --- a/mlir/lib/Dialect/Kokkos/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/Kokkos/Transforms/CMakeLists.txt @@ -4,6 +4,7 @@ add_mlir_dialect_library(MLIRKokkosTransforms KokkosLoopMapping.cpp KokkosMemorySpaceAssignment.cpp KokkosDualViewManagement.cpp + KokkosMdrangeIterationPass.cpp ADDITIONAL_HEADER_DIRS ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Kokkos diff --git a/mlir/lib/Dialect/Kokkos/Transforms/KokkosMdrangeIterationPass.cpp b/mlir/lib/Dialect/Kokkos/Transforms/KokkosMdrangeIterationPass.cpp new file mode 100644 index 00000000..a3eb1697 --- /dev/null +++ b/mlir/lib/Dialect/Kokkos/Transforms/KokkosMdrangeIterationPass.cpp @@ -0,0 +1,55 @@ +//===- KokkosPasses.cpp - Passes for lowering to Kokkos dialect -------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Bufferization/IR/Bufferization.h" +#include "mlir/Dialect/EmitC/IR/EmitC.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/Func/Transforms/FuncConversions.h" +#include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "lapis/Dialect/Kokkos/IR/KokkosDialect.h" +#include "lapis/Dialect/Kokkos/Transforms/Passes.h" +#include "mlir/Dialect/SCF/Transforms/Patterns.h" +#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h" +#include "mlir/Dialect/SparseTensor/Transforms/Passes.h" //for SparseParallelizationStrategy + +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" + +namespace mlir { +#define GEN_PASS_DEF_PARALLELUNITSTEP +#define GEN_PASS_DEF_KOKKOSLOOPMAPPING +#define GEN_PASS_DEF_KOKKOSMEMORYSPACEASSIGNMENT + +#define GEN_PASS_DEF_KOKKOSMDRANGEITERATION + +#include "lapis/Dialect/Kokkos/Transforms/Passes.h.inc" +} // namespace mlir + +using namespace mlir; +using namespace mlir::kokkos; + +namespace { + +struct KokkosMdrangeIterationPass + : public impl::KokkosMdrangeIterationBase { + + KokkosMdrangeIterationPass() = default; + KokkosMdrangeIterationPass(const KokkosMdrangeIterationPass& pass) = default; + + void runOnOperation() override { + // do nothing + std::cerr << __FILE__ << ":" << __LINE__ << "\n"; + } +}; + +} // anonymous namespace + +std::unique_ptr mlir::createKokkosMdrangeIterationPass() { + return std::make_unique(); +} diff --git a/mlir/lib/Dialect/Kokkos/Transforms/KokkosPasses.cpp b/mlir/lib/Dialect/Kokkos/Transforms/KokkosPasses.cpp index 289cfd0c..f0a7b084 100644 --- a/mlir/lib/Dialect/Kokkos/Transforms/KokkosPasses.cpp +++ b/mlir/lib/Dialect/Kokkos/Transforms/KokkosPasses.cpp @@ -93,4 +93,3 @@ std::unique_ptr mlir::createKokkosMemorySpaceAssignmentPass() { return 
std::make_unique(); } - From aca4a69e79438b2f4cff38e1b9dd8e7116965246 Mon Sep 17 00:00:00 2001 From: Carl Pearson Date: Thu, 19 Dec 2024 13:36:10 -0700 Subject: [PATCH 02/37] mdrange: missing header --- .../Dialect/Kokkos/Transforms/KokkosMdrangeIterationPass.cpp | 2 ++ 1 file changed, 2 insertions(+) diff --git a/mlir/lib/Dialect/Kokkos/Transforms/KokkosMdrangeIterationPass.cpp b/mlir/lib/Dialect/Kokkos/Transforms/KokkosMdrangeIterationPass.cpp index a3eb1697..a9c44791 100644 --- a/mlir/lib/Dialect/Kokkos/Transforms/KokkosMdrangeIterationPass.cpp +++ b/mlir/lib/Dialect/Kokkos/Transforms/KokkosMdrangeIterationPass.cpp @@ -21,6 +21,8 @@ #include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include // is there an LLVM way to do this? + namespace mlir { #define GEN_PASS_DEF_PARALLELUNITSTEP #define GEN_PASS_DEF_KOKKOSLOOPMAPPING From 6322fbe4aacace53008febfa7198f1527b3ff76e Mon Sep 17 00:00:00 2001 From: Carl Pearson Date: Thu, 19 Dec 2024 13:38:26 -0700 Subject: [PATCH 03/37] mdrange: remove unneeded defs --- .../Dialect/Kokkos/Transforms/KokkosMdrangeIterationPass.cpp | 4 ---- 1 file changed, 4 deletions(-) diff --git a/mlir/lib/Dialect/Kokkos/Transforms/KokkosMdrangeIterationPass.cpp b/mlir/lib/Dialect/Kokkos/Transforms/KokkosMdrangeIterationPass.cpp index a9c44791..f10d6227 100644 --- a/mlir/lib/Dialect/Kokkos/Transforms/KokkosMdrangeIterationPass.cpp +++ b/mlir/lib/Dialect/Kokkos/Transforms/KokkosMdrangeIterationPass.cpp @@ -24,10 +24,6 @@ #include // is there an LLVM way to do this? 
namespace mlir { -#define GEN_PASS_DEF_PARALLELUNITSTEP -#define GEN_PASS_DEF_KOKKOSLOOPMAPPING -#define GEN_PASS_DEF_KOKKOSMEMORYSPACEASSIGNMENT - #define GEN_PASS_DEF_KOKKOSMDRANGEITERATION #include "lapis/Dialect/Kokkos/Transforms/Passes.h.inc" From bf2e025eb94bbedb0d9c7da3b8ce8c0741c5e372 Mon Sep 17 00:00:00 2001 From: Carl Pearson Date: Thu, 19 Dec 2024 13:52:41 -0700 Subject: [PATCH 04/37] mdrange: no-op test of no-op pass --- .../Kokkos/Transforms/KokkosMdrangeIterationPass.cpp | 3 --- mlir/test/Dialect/Kokkos/mdrange_0.mlir | 6 ++++++ mlir/test/Dialect/Kokkos/mdrange_0.mlir.gold | 6 ++++++ 3 files changed, 12 insertions(+), 3 deletions(-) create mode 100644 mlir/test/Dialect/Kokkos/mdrange_0.mlir create mode 100644 mlir/test/Dialect/Kokkos/mdrange_0.mlir.gold diff --git a/mlir/lib/Dialect/Kokkos/Transforms/KokkosMdrangeIterationPass.cpp b/mlir/lib/Dialect/Kokkos/Transforms/KokkosMdrangeIterationPass.cpp index f10d6227..d40d2761 100644 --- a/mlir/lib/Dialect/Kokkos/Transforms/KokkosMdrangeIterationPass.cpp +++ b/mlir/lib/Dialect/Kokkos/Transforms/KokkosMdrangeIterationPass.cpp @@ -21,8 +21,6 @@ #include "mlir/Transforms/GreedyPatternRewriteDriver.h" -#include // is there an LLVM way to do this? 
- namespace mlir { #define GEN_PASS_DEF_KOKKOSMDRANGEITERATION @@ -42,7 +40,6 @@ struct KokkosMdrangeIterationPass void runOnOperation() override { // do nothing - std::cerr << __FILE__ << ":" << __LINE__ << "\n"; } }; diff --git a/mlir/test/Dialect/Kokkos/mdrange_0.mlir b/mlir/test/Dialect/Kokkos/mdrange_0.mlir new file mode 100644 index 00000000..1cf7a523 --- /dev/null +++ b/mlir/test/Dialect/Kokkos/mdrange_0.mlir @@ -0,0 +1,6 @@ +// RUN: %lapis-opt %s --kokkos-mdrange-iteration | diff %s.gold - +module { + func.func @myfunc(%arg0: memref, %arg1: memref, %arg2: memref, %arg3: memref, %arg4: index) { + return + } +} diff --git a/mlir/test/Dialect/Kokkos/mdrange_0.mlir.gold b/mlir/test/Dialect/Kokkos/mdrange_0.mlir.gold new file mode 100644 index 00000000..f228edac --- /dev/null +++ b/mlir/test/Dialect/Kokkos/mdrange_0.mlir.gold @@ -0,0 +1,6 @@ +module { + func.func @myfunc(%arg0: memref, %arg1: memref, %arg2: memref, %arg3: memref, %arg4: index) { + return + } +} + From c1f6e48e93e8602c41420d39181d720a4481dd7a Mon Sep 17 00:00:00 2001 From: Carl Pearson Date: Tue, 14 Jan 2025 10:59:56 -0700 Subject: [PATCH 05/37] find some relevant stuff in the module --- .../Transforms/KokkosMdrangeIterationPass.cpp | 57 ++++++++++++++++++- mlir/test/Dialect/Kokkos/mdrange_1.mlir | 14 +++++ 2 files changed, 70 insertions(+), 1 deletion(-) create mode 100644 mlir/test/Dialect/Kokkos/mdrange_1.mlir diff --git a/mlir/lib/Dialect/Kokkos/Transforms/KokkosMdrangeIterationPass.cpp b/mlir/lib/Dialect/Kokkos/Transforms/KokkosMdrangeIterationPass.cpp index d40d2761..b9186439 100644 --- a/mlir/lib/Dialect/Kokkos/Transforms/KokkosMdrangeIterationPass.cpp +++ b/mlir/lib/Dialect/Kokkos/Transforms/KokkosMdrangeIterationPass.cpp @@ -38,8 +38,63 @@ struct KokkosMdrangeIterationPass KokkosMdrangeIterationPass() = default; KokkosMdrangeIterationPass(const KokkosMdrangeIterationPass& pass) = default; + + static void dump_ops(ModuleOp &mod) { + mod.walk([&](Operation *op) { + if (auto parallelOp = 
dyn_cast(op)) { + llvm::outs() << "Found scf.parallel operation:\n"; + llvm::outs() << "Induction variables and strides:\n"; + for (auto iv : llvm::zip(parallelOp.getInductionVars(), parallelOp.getStep())) { + std::get<0>(iv).print(llvm::outs()); + llvm::outs() << " with stride "; + std::get<1>(iv).print(llvm::outs()); + llvm::outs() << "\n"; + } + llvm::outs() << "\n\n"; + } + + if (auto memrefOp = dyn_cast(op)) { + llvm::outs() << "Found memref.load operation:\n"; + llvm::outs() << "MemRef: "; + memrefOp.getMemRef().print(llvm::outs()); + llvm::outs() << "\nIndex variables:\n"; + for (Value index : memrefOp.getIndices()) { + index.print(llvm::outs()); + llvm::outs() << "\n"; + } + if (auto memrefType = memrefOp.getMemRef().getType().dyn_cast()) { + llvm::outs() << "MemRef extents:\n"; + for (int64_t dim : memrefType.getShape()) { + llvm::outs() << dim << "\n"; + } + } + llvm::outs() << "\n\n"; + } + + if (auto memrefOp = dyn_cast(op)) { + llvm::outs() << "Found memref.store operation:\n"; + llvm::outs() << "MemRef: "; + memrefOp.getMemRef().print(llvm::outs()); + llvm::outs() << "\nIndex variables:\n"; + for (Value index : memrefOp.getIndices()) { + index.print(llvm::outs()); + llvm::outs() << "\n"; + } + if (auto memrefType = memrefOp.getMemRef().getType().dyn_cast()) { + llvm::outs() << "MemRef extents:\n"; + for (int64_t dim : memrefType.getShape()) { + llvm::outs() << dim << "\n"; + } + } + llvm::outs() << "\n\n"; + } + }); + } + + void runOnOperation() override { - // do nothing + ModuleOp module = getOperation(); + dump_ops(module); } }; diff --git a/mlir/test/Dialect/Kokkos/mdrange_1.mlir b/mlir/test/Dialect/Kokkos/mdrange_1.mlir new file mode 100644 index 00000000..5abf7e6d --- /dev/null +++ b/mlir/test/Dialect/Kokkos/mdrange_1.mlir @@ -0,0 +1,14 @@ +module { + func.func @example_function(%arg0: memref<10x20xf32>, %arg1: memref<10x20xf32>) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c10 = arith.constant 10 : index + %c20 = 
arith.constant 20 : index + + scf.parallel (%i, %j) = (%c0, %c0) to (%c10, %c20) step (%c1, %c1) { + %val = memref.load %arg0[%i, %j] : memref<10x20xf32> + memref.store %val, %arg1[%i, %j] : memref<10x20xf32> + } + return + } +} \ No newline at end of file From c0f7fac8a981effc22693f276eece7badedd0aa4 Mon Sep 17 00:00:00 2001 From: Carl Pearson Date: Thu, 16 Jan 2025 16:30:34 -0700 Subject: [PATCH 06/37] Symbolic reuse distance and additional utilities --- .../Transforms/KokkosMdrangeIterationPass.cpp | 685 +++++++++++++++++- mlir/test/Dialect/Kokkos/mdrange_1.mlir | 4 +- 2 files changed, 685 insertions(+), 4 deletions(-) diff --git a/mlir/lib/Dialect/Kokkos/Transforms/KokkosMdrangeIterationPass.cpp b/mlir/lib/Dialect/Kokkos/Transforms/KokkosMdrangeIterationPass.cpp index b9186439..81f7d24e 100644 --- a/mlir/lib/Dialect/Kokkos/Transforms/KokkosMdrangeIterationPass.cpp +++ b/mlir/lib/Dialect/Kokkos/Transforms/KokkosMdrangeIterationPass.cpp @@ -6,6 +6,11 @@ // //===----------------------------------------------------------------------===// +#include +#include +#include // pair +#include + #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Bufferization/IR/Bufferization.h" #include "mlir/Dialect/EmitC/IR/EmitC.h" @@ -32,12 +37,393 @@ using namespace mlir::kokkos; namespace { + +/* The basic idea: + +we have something like this + +scf.parallel (%i, %j) { (1) + memref.store[%j, %i] : memref<10, 20, f32> (2) + scf.parallel (%k, %l) { (3) + _ = memref.load[%i, %k] : memref<10, ?, f32> (4) + } +} + +Presuming that the scf.parallel will actually be implemented in a "layout-right" iteration order, and given that memrefs are layout right, how to do we want to order the scf.parallel induction variables? e.g. for (1), do we want (%i, %j) or (%j, %i)? +To answer, we build a cost model of each memref load/store, and choose the induction variable ordering for all scf.parallels that minimizes that cost. 
+The foundation of the cost model is the reuse distance of the memref, under the theory that accesses with better locality will be faster due to coalescing/caching. +The stride of the memref depends on whichever induction variable is the "right-most" one in the scf.parallel region, due to our "layout-right" iteration order assumption. + +Some examples: + +For (2), the reuse distance w.r.t. %i is 4 (sizeof f32), and the reuse distance with respect to %j is 20 * 4 (size of 1st dimension * sizeof f32) + +For (4), the reuse distance with respect to (%i) is 4 * whatever the 1st memref dimension is. +The reuse distance w.r.t %j is undefined (address does not change when %j changes). +The reuse distance w.r.t %k is 4. +The reuse distance w.r.t %l is undefined. + +The way to understand this is that if the index variable of the memref is some kind of simple function of the induction variable, we can compute the reuse distance. If it is not a function of the induction variable, or is a function of the induction variable but we don't know the function, we can't compute the reuse distance. + +------ + +So, what kind of simple functions can we compute? This takes the following approach: it tries to compute + +d(memref) / d(induction variable), the partial derivative of the accessed offset w.r.t the induction variable. We can ignore the base address, because its derivative w.r.t all induction variables is 0 + +Via the chain rule: +d(memref) / d(induction variable) = d(memref) / d(index variable) * + d(index variable) / d(induction variable) + +Let's take d(index variable) / d(induction variable) first. + +------ + +d(index variable) / d(induction variable) is computed by recursively following the inputs to the operation and applying differentiation rules. + +To make this problem tractable, we make two simplifying assumptions: + +We only care about results of the form df / dx = a * x +We only try to differentiate simple arithmetic functions, e.g. 
if f(x) = g(x) + h(x), df/dx = dg/dx + dh/dx. Similarly for multiply, divide, etc. Any other f we just give up and say who knows. + +------ + +d(memref) / d(index variable) is in principle simple, it's just the product of all strides of lower dimension than the index variable. In practice, however, most strides are unknown at compile time, so we won't be able to get an actual number, we'll just get expressions like + +d(memref) / d(index variable) = stride_0 * stride_2 * sizeof(datatype) + +if we're lucky, and all dimensions are known or there are no lower dimensions, then we can get an actual number. + +------ + +So in the end, d(memref) / d(induction variable) ends up being something of this form: + +stride_0 * stride_2 * sizeof(datatype) * a * x +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +memref / index var component + ^^^^^ + index var / induction var component + + +the second component might just be ???, and/or the first component might be a known integer number. + +------ + +In principle, each memref has a different cost for each induction variable ordering. +In practice, we just consider the induction variable that is incrementing the fastest for each memref - that is, the right-most induction variable for the closest enclosing loop. + +The reuse distance is just looked up in the previously computed table of d(memref) / d(induction variable) + +------ + +We generate all possible combinations of + * choose an induction variable from each parallel region to be the right-most one. +We compute the cost under each combination. 
+ * Since the cost expression will contain many unknowns, we do monte-carlo simulation of the cost model for each induction variable ordering +We choose the induction variable ordering with the lowest cost + +*/ struct KokkosMdrangeIterationPass : public impl::KokkosMdrangeIterationBase { KokkosMdrangeIterationPass() = default; KokkosMdrangeIterationPass(const KokkosMdrangeIterationPass& pass) = default; + // generate a log-random integer within a specified range + static int getLogRandomInt(int min, int max) { + // Create a random device to seed the random number generator + std::random_device rd; + // Use the Mersenne Twister engine for random number generation + std::mt19937 gen(rd()); + + // Create a uniform real distribution between log(min) and log(max) + std::uniform_real_distribution<> dis(std::log(min), std::log(max)); + + // Generate a random number in the log space and exponentiate it + double logRandom = dis(gen); + int logRandomInt = static_cast(std::exp(logRandom)); + + // Ensure the result is within the desired range + if (logRandomInt < min) logRandomInt = min; + if (logRandomInt > max) logRandomInt = max; + + return logRandomInt; + } + + // a context for expression evaluation + struct Ctx { + std::unordered_map values; + }; + + struct Expr { + + enum class Kind { + Add, Mul, Constant, Unknown + }; + + Expr(Kind kind) : kind_(kind) {} + Kind kind_; + + virtual int eval(const Ctx &ctx) = 0; + virtual void dump(llvm::raw_fd_ostream &os) = 0; + virtual ~Expr() {} + }; + + struct Add : public Expr { + Add(std::shared_ptr lhs, std::shared_ptr rhs) : Expr(Kind::Add), lhs_(lhs), rhs_(rhs) {} + std::shared_ptr lhs_; + std::shared_ptr rhs_; + + virtual int eval(const Ctx &ctx) override { + return lhs_->eval(ctx) + rhs_->eval(ctx); + } + + virtual void dump(llvm::raw_fd_ostream &os) override { + os << "("; + lhs_->dump(os); + os << "+"; + rhs_->dump(os); + os << ")"; + } + + static std::shared_ptr make(std::shared_ptr lhs, std::shared_ptr rhs) { + auto 
lhs_const = llvm::dyn_cast(lhs.get()); + auto rhs_const = llvm::dyn_cast(rhs.get()); + + if (lhs_const && lhs_const->value_ == 0) { + return rhs; + } else if (rhs_const && rhs_const->value_ == 0) { + return lhs; + } else if (rhs_const && lhs_const) { + return Constant::make(lhs_const->value_ + rhs_const->value_); + } + + return std::make_shared(lhs, rhs); + } + + static bool classof(const Expr *e) { + return e->kind_ == Expr::Kind::Add; + } + + }; + + struct Mul : public Expr { + Mul(std::shared_ptr lhs, std::shared_ptr rhs) : Expr(Kind::Mul), lhs_(lhs), rhs_(rhs) {} + std::shared_ptr lhs_; + std::shared_ptr rhs_; + + virtual int eval(const Ctx &ctx) override { + return lhs_->eval(ctx) * rhs_->eval(ctx); + } + + virtual void dump(llvm::raw_fd_ostream &os) override { + os << "("; + lhs_->dump(os); + os << "*"; + rhs_->dump(os); + os << ")"; + } + + static std::shared_ptr make(std::shared_ptr lhs, std::shared_ptr rhs) { + auto lhs_const = llvm::dyn_cast(lhs.get()); + auto rhs_const = llvm::dyn_cast(rhs.get()); + + if (rhs_const && lhs_const) { + return Constant::make(lhs_const->value_ * rhs_const->value_); + } else if (lhs_const && lhs_const->value_ == 1) { + return rhs; + } else if (lhs_const && lhs_const->value_ == 0) { + return Constant::make(0); + } else if (rhs_const && rhs_const->value_ == 1) { + return lhs; + } else if (rhs_const && rhs_const->value_ == 0) { + return Constant::make(0); + } + + return std::make_shared(lhs, rhs); + } + + static bool classof(const Expr *e) { + return e->kind_ == Expr::Kind::Mul; + } + + }; + + struct Constant : public Expr { + Constant(int value) : Expr(Kind::Constant), value_(value) {} + int value_; + + virtual int eval(const Ctx &ctx) override { + return value_; + } + + virtual void dump(llvm::raw_fd_ostream &os) override { + os << value_; + } + + static std::shared_ptr make(int c) { + return std::make_shared(c); + } + + static bool classof(const Expr *e) { + return e->kind_ == Expr::Kind::Constant; + } + + }; + + struct 
Unknown : public Expr { + Unknown(const std::string &name) : Expr(Kind::Unknown), name_(name) {} + std::string name_; + + virtual int eval(const Ctx &ctx) override { + return ctx.values.at(name_); + } + + virtual void dump(llvm::raw_fd_ostream &os) override { + os << "(" << name_ << ")"; + } + + static std::shared_ptr make(const std::string &name) { + return std::make_shared(name); + } + + static bool classof(const Expr *e) { + return e->kind_ == Expr::Kind::Unknown; + } + + }; + + + + // cost model for memref / induction variable pair + struct Cost { + + Cost(std::shared_ptr stride, std::shared_ptr count, int sf) : stride_(stride), count_(count), sf_(sf) {} + Cost() = default; + + std::shared_ptr stride_; // stride of the memref w.r.t an induction variable + std::shared_ptr count_; // number of times the memref is executed + int sf_; // scaling factor, 1 for load, 3 for store + }; + + // partial derivative df/dx + +static std::shared_ptr df_dx(Value &f, Value &x) { + if (f == x) { + llvm::outs() << "Info: df_dx of equal values\n"; + return Constant::make(1); + } else if (mlir::isa(f) && mlir::isa(x)) { + llvm::outs() << "Info: df_dx of different block arguments\n"; + return Constant::make(0); + } else { + // FIXME: what other scenarios if there is no defining op. 
+ if (auto fOp = f.getDefiningOp()) { + if (auto xOp = x.getDefiningOp()) { + return df_dx(fOp, xOp); + } + } + llvm::outs() << "ERROR: One of the values has no defining operation\n"; + return nullptr; + } +} + + // FIXME: better written as df_dx(f, x) I guess + static std::shared_ptr df_dx(Operation *df, Operation *dx) { + if (!df) { + llvm::outs() << "Warn: df_dx requested on null df\n"; + return nullptr; + } else if (!dx) { + llvm::outs() << "Warn: df_dx requested on null dx\n"; + return nullptr; + } else if (df == dx) { + // df/dx (dx) = 1 + return Constant::make(1); + } else if (auto constOp = dyn_cast(df)) { // f is + + return Constant::make(0); + } else if (auto addOp = dyn_cast(df)) { // f is + + // d(lhs + rhs)/dx = dlhs/dx + drhs/dx + Value lhs = addOp.getOperand(0); + Value rhs = addOp.getOperand(1); + std::shared_ptr dLhs = df_dx(lhs.getDefiningOp(), dx); + std::shared_ptr dRhs = df_dx(rhs.getDefiningOp(), dx); + if (dLhs && dRhs) { + return Add::make(dLhs, dRhs); + } + } else if (auto mulOp = dyn_cast(df)) { // f is * + // d(lhs * rhs)/dx = lhs * drhs/dx + rhs * dlhs/dx + // we'll only bother to compute this one if lhs or rhs is a constant + Value lhs = mulOp.getOperand(0); + Value rhs = mulOp.getOperand(1); + + if (auto lhsConst = lhs.getDefiningOp()) { // FIXME: is this all integral values? + // lhs is a constant, so the derivative is lhs * drhs/dx + std::shared_ptr dRhs = df_dx(rhs.getDefiningOp(), dx); + if (dRhs) { + return Mul::make(Constant::make(cast(lhsConst.getValue()).getInt()), dRhs); // FIXME: can this cast fail? + } + } + + if (auto rhsConst = rhs.getDefiningOp()) { // FIXME: is this all integral values? + // rhs is a constant, so the derivative is rhs * dlhs/dx + std::shared_ptr dLhs = df_dx(lhs.getDefiningOp(), dx); + if (dLhs) { + return Mul::make(Constant::make(cast(rhsConst.getValue()).getInt()), dLhs); // FIXME: can this cast fail? 
+ } + } + } // TODO: sub, div + + llvm::outs() << "WARN: unhandled case in df_dx of "; + df->print(llvm::outs()); + llvm::outs() << " w.r.t."; + dx->print(llvm::outs()); + return nullptr; + } + + + + // computes d(offset) / d(index variable) + // FIXME: is there something that is both a LoadOp and a StoreOp? + template + static std::shared_ptr do_di(Memref &memrefOp, Value indexVar) { + + static_assert(std::is_same_v || std::is_same_v, "Memref must be either LoadOp or StoreOp"); + + // find the index var + int indexVarDim = 0; + for (mlir::Value var : memrefOp.getIndices()) { + if (var == indexVar) { + + auto memrefType = dyn_cast(memrefOp.getMemRef().getType()); // FIXME: can this fail? + + // Get the size in bits of the element type + mlir::Type elementType = memrefType.getElementType(); + unsigned sizeInBytes = elementType.getIntOrFloatBitWidth() / CHAR_BIT; + + std::shared_ptr res = std::make_shared(sizeInBytes); + + auto memrefShape = memrefType.getShape(); + for (int dim = 0; dim < indexVarDim; ++dim) { + if (memrefShape[dim] == ShapedType::kDynamic) { + std::string name = memrefOp.getOperation()->getName().getStringRef().str() + "_extent" + std::to_string(dim); // FIXME: unique name for each memref dimension + res = std::make_shared(res, Unknown::make(name)); + } else { + res = std::make_shared(res, std::make_shared(memrefShape[dim])); + } + } + + return res; + } + ++indexVarDim; + } + + // memref address is not a function of this variable + llvm::outs() << "Info: "; + memrefOp.print(llvm::outs()); + llvm::outs() << " is not a function of "; + indexVar.print(llvm::outs()); + llvm::outs() << "\n"; + return std::make_shared(0); + } static void dump_ops(ModuleOp &mod) { mod.walk([&](Operation *op) { @@ -62,7 +448,7 @@ struct KokkosMdrangeIterationPass index.print(llvm::outs()); llvm::outs() << "\n"; } - if (auto memrefType = memrefOp.getMemRef().getType().dyn_cast()) { + if (auto memrefType = dyn_cast(memrefOp.getMemRef().getType())) { llvm::outs() << "MemRef 
extents:\n"; for (int64_t dim : memrefType.getShape()) { llvm::outs() << dim << "\n"; @@ -80,7 +466,7 @@ struct KokkosMdrangeIterationPass index.print(llvm::outs()); llvm::outs() << "\n"; } - if (auto memrefType = memrefOp.getMemRef().getType().dyn_cast()) { + if (auto memrefType = dyn_cast(memrefOp.getMemRef().getType())) { llvm::outs() << "MemRef extents:\n"; for (int64_t dim : memrefType.getShape()) { llvm::outs() << dim << "\n"; @@ -92,9 +478,304 @@ struct KokkosMdrangeIterationPass } +// return groups of induction variables for +static std::vector all_induction_variables(std::vector &ops) { + std::vector vars; + for (auto &op : ops) { + for (auto &var : op.getInductionVars()) { + vars.push_back(var); + } + } + return vars; +} + +// map of (Operation*, Value) -> Cost +// map of the cost model for a given memref / induction variable pair +// using MemrefInductionCosts = std::map, Cost>; +class MemrefInductionCosts { + +public: + using key_type = std::pair; + using value_type = Cost; + using iterator = typename std::vector>::iterator; + using const_iterator = typename std::vector>::const_iterator; + +private: + std::vector> data_; + + // Find an iterator to a key-value pair by key + auto find(const key_type& key) const { + return std::find_if(data_.begin(), data_.end(), + [&key](const auto& pair) { return pair.first == key; }); + } + + auto find(const key_type& key) { + return std::find_if(data_.begin(), data_.end(), + [&key](const auto& pair) { return pair.first == key; }); + } + +public: + // Access value by key without bounds checking + value_type& operator[](const key_type& key) { + auto it = find(key); + if (it != data_.end()) { + return it->second; + } else { + data_.emplace_back(key, value_type{}); + return data_.back().second; + } + } + + // Get an iterator to the beginning + iterator begin() { + return data_.begin(); + } + + // Get a const iterator to the beginning + const_iterator begin() const { + return data_.begin(); + } + + // Get an iterator to the 
end + iterator end() { + return data_.end(); + } + + // Get a const iterator to the end + const_iterator end() const { + return data_.end(); + } + + + +}; + +// return induction variables for each parallel op in the module +static std::vector> all_induction_vars(ModuleOp &mod) { + std::vector> ret; + mod.walk([&](Operation *op) { + if (auto parallelOp = dyn_cast(op)) { + std::vector indVars; + for (Value &var : parallelOp.getInductionVars()) { + indVars.push_back(var); + } + ret.push_back(indVars); + } + }); // walk + return ret; + } + +static MemrefInductionCosts analyze_cost(ModuleOp &mod, std::vector &stack) { + + MemrefInductionCosts MIC; + + mod.walk([&](Operation *op) { + // skip memrefs outside a parallel region + if (auto parallelOp = dyn_cast(op)) { + stack.push_back(parallelOp); + MemrefInductionCosts costs = analyze_cost(parallelOp, stack); + stack.pop_back(); + for (const auto &kv : costs) { + MIC[kv.first] = kv.second; + } + } + }); // walk + + return MIC; + } + +static MemrefInductionCosts analyze_cost(scf::ParallelOp &parentOp, std::vector &stack) { + + MemrefInductionCosts MIC; + + parentOp.getBody()->walk([&](Operation *op) { + if (auto parallelOp = dyn_cast(op)) { + stack.push_back(parallelOp); + MemrefInductionCosts costs = analyze_cost(parallelOp, stack); + stack.pop_back(); + for (const auto &kv : costs) { + MIC[kv.first] = kv.second; + } + } else if (auto memrefOp = dyn_cast(op)) { + llvm::outs() << "nested memref load!\n"; + + std::vector indVars = all_induction_variables(stack); + + // compute the partial derivative of each memref with respect to all induction variables via the chain rule: + // d(offset)/d(indvar) = sum( + // d(offset)/d(index) * d(index)/d(indvar), + // for each index in indices) + + for (Value indVar : indVars) { + std::shared_ptr dodi = Constant::make(0); + for (Value indexVar : memrefOp.getIndices()) { + auto e1 = do_di(memrefOp, indexVar); + + llvm::outs() << "pd of " << memrefOp << " w.r.t " << indexVar << "\n"; + if 
(e1) { + e1->dump(llvm::outs()); + } else { + llvm::outs() << " undefined "; + } + llvm::outs() << "\n"; + + auto e2 = df_dx(indexVar, indVar); + + llvm::outs() << "pd of " << indexVar << " w.r.t " << indVar << "\n"; + if (e2) { + e2->dump(llvm::outs()); + } else { + llvm::outs() << " undefined "; + } + llvm::outs() << "\n"; + + if (e1 && e2) { + dodi = Add::make(dodi, Mul::make(e1, e2)); + } else { + dodi = nullptr; + break; + } + } + + llvm::outs() << "pd of " << memrefOp << " w.r.t " << indVar << "\n"; + if (dodi) { + dodi->dump(llvm::outs()); + } else { + llvm::outs() << " undefined "; + } + llvm::outs() << "\n"; + + // FIXME: compute trip count + MIC[std::make_pair(memrefOp, indVar)] = Cost(dodi, Constant::make(1), 1 /*load cost*/); + } + + + } else if (auto memrefOp = dyn_cast(op)) { + llvm::outs() << "nested memref store!\n"; + + std::vector indVars = all_induction_variables(stack); + + // compute the partial derivative of each memref with respect to all induction variables via the chain rule: + // d(offset)/d(indvar) = sum( + // d(offset)/d(index) * d(index)/d(indvar), + // for each index in indices) + + for (Value indVar : indVars) { + std::shared_ptr dodi = Constant::make(0); + for (Value indexVar : memrefOp.getIndices()) { + auto e1 = do_di(memrefOp, indexVar); + + llvm::outs() << "pd of " << memrefOp << " w.r.t " << indexVar << "\n"; + if (e1) { + e1->dump(llvm::outs()); + } else { + llvm::outs() << " undefined "; + } + llvm::outs() << "\n"; + + auto e2 = df_dx(indexVar, indVar); + + llvm::outs() << "pd of " << indexVar << " w.r.t " << indVar << "\n"; + if (e2) { + e2->dump(llvm::outs()); + } else { + llvm::outs() << " undefined "; + } + llvm::outs() << "\n"; + + if (e1 && e2) { + dodi = Add::make(dodi, Mul::make(e1, e2)); + } else { + dodi = nullptr; + break; + } + } + + llvm::outs() << "pd of " << memrefOp << " w.r.t " << indVar << "\n"; + if (dodi) { + dodi->dump(llvm::outs()); + } else { + llvm::outs() << " undefined "; + } + llvm::outs() << "\n"; + 
+ // FIXME: compute trip count + MIC[std::make_pair(memrefOp, indVar)] = Cost(dodi, Constant::make(1), 3 /*store cost*/); + } + } + }); + + return MIC; + } + +/* +call f() on a vector containing all permutations of valid indices of the entries of vec +e.g vec = { + {1, 2}, + {3}, + {4, 5, 6} + }; +yields + +f( {0, 0, 0} ) +f( {0, 0, 1} ) +f( {0, 0, 2} ) +f( {1, 0, 0} ) +f( {1, 0, 1} ) +f( {1, 0, 2} ) +*/ +template +void walk_selections(const std::vector>& vec, Lambda &&f) { + if (vec.empty()) return; + + std::vector indices(vec.size(), 0); // Initialize indices to track positions in each vector + + while (true) { + // Print the current combination + for (size_t i = 0; i < vec.size(); ++i) { + // std::cout << indices[i] << " "; + f(indices); + } + std::cout << std::endl; + + // Find the rightmost vector that has more elements to iterate + size_t k = vec.size(); + while (k > 0) { + --k; + if (indices[k] < vec[k].size() - 1) { + ++indices[k]; + break; + } + indices[k] = 0; // Reset this index and move to the previous vector + } + + // If we've reset all indices, we're done + if (k == 0 && indices[0] == 0) { + break; + } + } +} + void runOnOperation() override { ModuleOp module = getOperation(); + llvm::outs() << "====\ndump_ops\n====\n"; dump_ops(module); + + llvm::outs() << "====\nanalyze_cost\n====\n"; + // FIXME: some helper function to tighten up `stack` scope + // model the cost + std::vector stack; + MemrefInductionCosts costs = analyze_cost(module, stack); + + llvm::outs() << "====\nall_induction_vars\n====\n"; + std::vector> parallelRegions = all_induction_vars(module); + for (const auto &r : parallelRegions) { + for (const auto &v : r) { + llvm::outs() << v << "\n"; + } + } + + llvm::outs() << "====\ndone\n====\n"; } }; diff --git a/mlir/test/Dialect/Kokkos/mdrange_1.mlir b/mlir/test/Dialect/Kokkos/mdrange_1.mlir index 5abf7e6d..53259e0e 100644 --- a/mlir/test/Dialect/Kokkos/mdrange_1.mlir +++ b/mlir/test/Dialect/Kokkos/mdrange_1.mlir @@ -1,5 +1,5 @@ module { 
- func.func @example_function(%arg0: memref<10x20xf32>, %arg1: memref<10x20xf32>) { + func.func @example_function(%arg0: memref<10x20xf32>, %arg1: memref) { %c0 = arith.constant 0 : index %c1 = arith.constant 1 : index %c10 = arith.constant 10 : index @@ -7,7 +7,7 @@ module { scf.parallel (%i, %j) = (%c0, %c0) to (%c10, %c20) step (%c1, %c1) { %val = memref.load %arg0[%i, %j] : memref<10x20xf32> - memref.store %val, %arg1[%i, %j] : memref<10x20xf32> + memref.store %val, %arg1[%i, %j] : memref } return } From ba6217ae2dd2904b6d7611a7479a70fd21828bcf Mon Sep 17 00:00:00 2001 From: Carl Pearson Date: Fri, 17 Jan 2025 11:21:24 -0700 Subject: [PATCH 07/37] Walk all possible parallel loop configurations --- .../Transforms/KokkosMdrangeIterationPass.cpp | 255 ++++++++++++++---- 1 file changed, 209 insertions(+), 46 deletions(-) diff --git a/mlir/lib/Dialect/Kokkos/Transforms/KokkosMdrangeIterationPass.cpp b/mlir/lib/Dialect/Kokkos/Transforms/KokkosMdrangeIterationPass.cpp index 81f7d24e..fa23bdc5 100644 --- a/mlir/lib/Dialect/Kokkos/Transforms/KokkosMdrangeIterationPass.cpp +++ b/mlir/lib/Dialect/Kokkos/Transforms/KokkosMdrangeIterationPass.cpp @@ -166,6 +166,7 @@ struct KokkosMdrangeIterationPass virtual int eval(const Ctx &ctx) = 0; virtual void dump(llvm::raw_fd_ostream &os) = 0; + virtual std::vector unknowns() const = 0; virtual ~Expr() {} }; @@ -186,6 +187,16 @@ struct KokkosMdrangeIterationPass os << ")"; } + virtual std::vector unknowns() const override { + std::vector ret; + for (auto &op : {lhs_, rhs_}) { + for (auto &name : op->unknowns()) { + ret.push_back(name); + } + } + return ret; + } + static std::shared_ptr make(std::shared_ptr lhs, std::shared_ptr rhs) { auto lhs_const = llvm::dyn_cast(lhs.get()); auto rhs_const = llvm::dyn_cast(rhs.get()); @@ -224,6 +235,16 @@ struct KokkosMdrangeIterationPass os << ")"; } + virtual std::vector unknowns() const override { + std::vector ret; + for (auto &op : {lhs_, rhs_}) { + for (auto &name : op->unknowns()) { + 
ret.push_back(name); + } + } + return ret; + } + static std::shared_ptr make(std::shared_ptr lhs, std::shared_ptr rhs) { auto lhs_const = llvm::dyn_cast(lhs.get()); auto rhs_const = llvm::dyn_cast(rhs.get()); @@ -247,6 +268,8 @@ struct KokkosMdrangeIterationPass return e->kind_ == Expr::Kind::Mul; } + + }; struct Constant : public Expr { @@ -261,6 +284,10 @@ struct KokkosMdrangeIterationPass os << value_; } + virtual std::vector unknowns() const override { + return {}; + } + static std::shared_ptr make(int c) { return std::make_shared(c); } @@ -268,7 +295,6 @@ struct KokkosMdrangeIterationPass static bool classof(const Expr *e) { return e->kind_ == Expr::Kind::Constant; } - }; struct Unknown : public Expr { @@ -283,6 +309,10 @@ struct KokkosMdrangeIterationPass os << "(" << name_ << ")"; } + virtual std::vector unknowns() const override { + return {name_}; + } + static std::shared_ptr make(const std::string &name) { return std::make_shared(name); } @@ -304,6 +334,13 @@ struct KokkosMdrangeIterationPass std::shared_ptr stride_; // stride of the memref w.r.t an induction variable std::shared_ptr count_; // number of times the memref is executed int sf_; // scaling factor, 1 for load, 3 for store + + std::vector unknowns() const { + std::vector ret; + for (auto &u : stride_->unknowns()) ret.push_back(u); + for (auto &u : count_->unknowns()) ret.push_back(u); + return ret; + } }; // partial derivative df/dx @@ -382,7 +419,6 @@ static std::shared_ptr df_dx(Value &f, Value &x) { // computes d(offset) / d(index variable) - // FIXME: is there something that is both a LoadOp and a StoreOp? 
template static std::shared_ptr do_di(Memref &memrefOp, Value indexVar) { @@ -489,20 +525,18 @@ static std::vector all_induction_variables(std::vector & return vars; } -// map of (Operation*, Value) -> Cost -// map of the cost model for a given memref / induction variable pair -// using MemrefInductionCosts = std::map, Cost>; -class MemrefInductionCosts { +template +class VecMap { public: - using key_type = std::pair; - using value_type = Cost; using iterator = typename std::vector>::iterator; using const_iterator = typename std::vector>::const_iterator; private: std::vector> data_; + +public: // Find an iterator to a key-value pair by key auto find(const key_type& key) const { return std::find_if(data_.begin(), data_.end(), @@ -514,7 +548,6 @@ class MemrefInductionCosts { [&key](const auto& pair) { return pair.first == key; }); } -public: // Access value by key without bounds checking value_type& operator[](const key_type& key) { auto it = find(key); @@ -526,46 +559,28 @@ class MemrefInductionCosts { } } - // Get an iterator to the beginning iterator begin() { return data_.begin(); } - // Get a const iterator to the beginning const_iterator begin() const { return data_.begin(); } - // Get an iterator to the end iterator end() { return data_.end(); } - // Get a const iterator to the end const_iterator end() const { return data_.end(); } - - - }; -// return induction variables for each parallel op in the module -static std::vector> all_induction_vars(ModuleOp &mod) { - std::vector> ret; - mod.walk([&](Operation *op) { - if (auto parallelOp = dyn_cast(op)) { - std::vector indVars; - for (Value &var : parallelOp.getInductionVars()) { - indVars.push_back(var); - } - ret.push_back(indVars); - } - }); // walk - return ret; - } +// map of (Operation*, Value) -> Cost +// map of the cost model for a given memref / induction variable pair +using MemrefInductionCosts = VecMap, Cost>; -static MemrefInductionCosts analyze_cost(ModuleOp &mod, std::vector &stack) { +static 
MemrefInductionCosts build_cost_table(ModuleOp &mod, std::vector &stack) { MemrefInductionCosts MIC; @@ -573,7 +588,7 @@ static MemrefInductionCosts analyze_cost(ModuleOp &mod, std::vector(op)) { stack.push_back(parallelOp); - MemrefInductionCosts costs = analyze_cost(parallelOp, stack); + MemrefInductionCosts costs = build_cost_table(parallelOp, stack); stack.pop_back(); for (const auto &kv : costs) { MIC[kv.first] = kv.second; @@ -584,14 +599,14 @@ static MemrefInductionCosts analyze_cost(ModuleOp &mod, std::vector &stack) { +static MemrefInductionCosts build_cost_table(scf::ParallelOp &parentOp, std::vector &stack) { MemrefInductionCosts MIC; parentOp.getBody()->walk([&](Operation *op) { if (auto parallelOp = dyn_cast(op)) { stack.push_back(parallelOp); - MemrefInductionCosts costs = analyze_cost(parallelOp, stack); + MemrefInductionCosts costs = build_cost_table(parallelOp, stack); stack.pop_back(); for (const auto &kv : costs) { MIC[kv.first] = kv.second; @@ -708,6 +723,56 @@ static MemrefInductionCosts analyze_cost(scf::ParallelOp &parentOp, std::vector< return MIC; } + struct ParallelConfig { + // permutation of induction variables for each parallel op + VecMap> perms_; + }; + + + static size_t get_num_induction_vars(scf::ParallelOp ¶llelOp) { + return parallelOp.getInductionVars().size(); + } + + template + void walk_configurations(scf::ParallelOp &parentOp, ParallelConfig cfg, Lambda &&f) { + bool found = false; + parentOp.getBody()->walk([&](Operation *op) { + if (auto parallelOp = dyn_cast(op)) { + found = true; + + // walk all configurations of this parallel op too + std::vector perm(get_num_induction_vars(parallelOp)); + std::iota(perm.begin(), perm.end(), 0); + do { + cfg.perms_[parallelOp] = perm; + walk_configurations(parallelOp, cfg, std::forward(f)); + } while (std::next_permutation(perm.begin(), perm.end())); + } + }); // walk + + // no nested parallel regions, no more configurations to go through, call f + if (!found) { + f(cfg); + } + } + + 
template + void walk_configurations(ModuleOp &mod, Lambda &&f) { + mod.walk([&](Operation *op) { + if (auto parallelOp = dyn_cast(op)) { + + std::vector perm(get_num_induction_vars(parallelOp)); + std::iota(perm.begin(), perm.end(), 0); + + do { + ParallelConfig cfg; + cfg.perms_[parallelOp] = perm; + walk_configurations(parallelOp, cfg, f); + } while (std::next_permutation(perm.begin(), perm.end())); + } + }); // walk + } + /* call f() on a vector containing all permutations of valid indices of the entries of vec e.g vec = { @@ -724,8 +789,8 @@ f( {1, 0, 0} ) f( {1, 0, 1} ) f( {1, 0, 2} ) */ -template -void walk_selections(const std::vector>& vec, Lambda &&f) { +template +void walk_selections(const std::vector>& vec, Lambda &&f) { if (vec.empty()) return; std::vector indices(vec.size(), 0); // Initialize indices to track positions in each vector @@ -736,7 +801,7 @@ void walk_selections(const std::vector>& vec, Lambda &&f) { // std::cout << indices[i] << " "; f(indices); } - std::cout << std::endl; + // std::cout << std::endl; // Find the rightmost vector that has more elements to iterate size_t k = vec.size(); @@ -756,24 +821,122 @@ void walk_selections(const std::vector>& vec, Lambda &&f) { } } + + + + + // model the cost of a module with a given parallel configuration + static size_t model_cost(ModuleOp &mod, const ParallelConfig &cfg, const MemrefInductionCosts &costTable) { + size_t cost = 0; + mod.walk([&](Operation *op) { + if (auto parallelOp = dyn_cast(op)) { + cost += model_cost(parallelOp, cfg, costTable); + } + }); // walk + return cost; + } + + // model the cost of a parallel operation with a given config + static size_t model_cost(scf::ParallelOp &parentOp, const ParallelConfig &cfg, const MemrefInductionCosts &costTable) { + + size_t cost = 0; + + parentOp.getBody()->walk([&](Operation *op) { + if (auto parallelOp = dyn_cast(op)) { + cost += model_cost(parallelOp, cfg, costTable); + } else if (auto memrefOp = dyn_cast(op)) { + cost += 
model_cost(parentOp, memrefOp, cfg, costTable); + } else if (auto memrefOp = dyn_cast(op)) { + cost += model_cost(parentOp, memrefOp, cfg, costTable); + } + }); + + return cost; + } + + template + static size_t model_cost(scf::ParallelOp ¶llelOp, MemrefOp &memrefOp, const ParallelConfig &cfg, const MemrefInductionCosts &costTable) { + static_assert(std::is_same_v || std::is_same_v); + + + llvm::outs() << "model cost of "; + memrefOp.print(llvm::outs()); + llvm::outs() << "\n"; + + if (auto it = cfg.perms_.find(parallelOp); it != cfg.perms_.end()) { + llvm::outs() << "found perm for memref's parent parallelOp in config\n"; + + const std::vector &perm = it->second; + Value rightMostVar = parallelOp.getInductionVars()[perm[perm.size() - 1]]; + + llvm::outs() << "right-most induction var is "; + rightMostVar.print(llvm::outs()); + llvm::outs() << "\n"; + + + // FIXME: why does this work? the table should expect key to be pair not pair + auto costKey = std::make_pair(memrefOp, rightMostVar); + if (auto jt = costTable.find(costKey); jt != costTable.end()) { + llvm::outs() << "found cost model in table\n"; + + Cost model = jt->second; + std::vector unknowns = model.unknowns(); + + // TODO: this context just says every stride is 10 + Ctx ctx; + for (auto &name : unknowns) { + ctx.values[name] = 10; + } + + size_t cost = model.stride_->eval(ctx); + return cost; + + } + + } + + size_t cost = 0; + return cost; + } + void runOnOperation() override { ModuleOp module = getOperation(); llvm::outs() << "====\ndump_ops\n====\n"; dump_ops(module); - llvm::outs() << "====\nanalyze_cost\n====\n"; + llvm::outs() << "====\nbuild_cost_table\n====\n"; // FIXME: some helper function to tighten up `stack` scope - // model the cost std::vector stack; - MemrefInductionCosts costs = analyze_cost(module, stack); + MemrefInductionCosts costTable = build_cost_table(module, stack); + + + llvm::outs() << "====\nmodel reordered induction vars\n====\n"; + size_t minCost = std::numeric_limits::max(); 
+ ParallelConfig minCfg; + walk_configurations(module, [&](const ParallelConfig &cfg){ - llvm::outs() << "====\nall_induction_vars\n====\n"; - std::vector> parallelRegions = all_induction_vars(module); - for (const auto &r : parallelRegions) { - for (const auto &v : r) { - llvm::outs() << v << "\n"; + llvm::outs() << "modeling ParallelConfig:\n"; + for (const auto &kv : cfg.perms_) { + kv.first->print(llvm::outs()); + llvm::outs() << " -> {"; + for(const auto &e : kv.second) { + llvm::outs() << e << ", "; + } + llvm::outs() << "}\n"; } - } + + size_t cost = model_cost(module, cfg, costTable); + llvm::outs() << "cost was " << cost << "\n"; + if (cost < minCost) { + llvm::outs() << "Info: new optimal! " << cost << "\n"; + minCost = cost; + minCfg = cfg; + } + + }); // walk_configurations + llvm::outs() << "min cost: " << minCost << "\n"; + + llvm::outs() << "====\nbuild new module\n====\n"; llvm::outs() << "====\ndone\n====\n"; } From b9d6325a21984d884d10bd785b8ad3397107f37c Mon Sep 17 00:00:00 2001 From: Carl Pearson Date: Fri, 17 Jan 2025 11:42:10 -0700 Subject: [PATCH 08/37] primitive monte-carlo --- .../Transforms/KokkosMdrangeIterationPass.cpp | 61 ++++++++++++------- 1 file changed, 40 insertions(+), 21 deletions(-) diff --git a/mlir/lib/Dialect/Kokkos/Transforms/KokkosMdrangeIterationPass.cpp b/mlir/lib/Dialect/Kokkos/Transforms/KokkosMdrangeIterationPass.cpp index fa23bdc5..4aa0615f 100644 --- a/mlir/lib/Dialect/Kokkos/Transforms/KokkosMdrangeIterationPass.cpp +++ b/mlir/lib/Dialect/Kokkos/Transforms/KokkosMdrangeIterationPass.cpp @@ -130,24 +130,19 @@ struct KokkosMdrangeIterationPass KokkosMdrangeIterationPass(const KokkosMdrangeIterationPass& pass) = default; // generate a log-random integer within a specified range - static int getLogRandomInt(int min, int max) { - // Create a random device to seed the random number generator - std::random_device rd; - // Use the Mersenne Twister engine for random number generation - std::mt19937 gen(rd()); - + static 
size_t log_random_int(std::mt19937 &gen, size_t min, size_t max) { // Create a uniform real distribution between log(min) and log(max) std::uniform_real_distribution<> dis(std::log(min), std::log(max)); // Generate a random number in the log space and exponentiate it double logRandom = dis(gen); - int logRandomInt = static_cast(std::exp(logRandom)); + size_t res = std::exp(logRandom); // Ensure the result is within the desired range - if (logRandomInt < min) logRandomInt = min; - if (logRandomInt > max) logRandomInt = max; + if (res < min) res = min; + if (res > max) res = max; - return logRandomInt; + return res; } // a context for expression evaluation @@ -854,6 +849,40 @@ void walk_selections(const std::vector>& vec, Lambda &&f) { return cost; } + + static size_t monte_carlo(const Cost &model, int n = 100, int seed = 31337) { + std::mt19937 gen(seed); + + std::vector costs; + + std::vector unknowns = model.unknowns(); + + for (int i = 0; i < n; i++) { + + // generate random values for all unknowns in cost model + Ctx ctx; + for (auto &name : unknowns) { + auto val = log_random_int(gen, 1, 1000000); + llvm::outs() << name << ": " << val << "\n"; + ctx.values[name] = val; + } + + costs.push_back(model.stride_->eval(ctx)); + } + + // FIXME: here we do median, is there a principled aggregation strategy? + // kth pctile cost? + // average of worst k? + // worst / average ("competitive ratio")? + // geometric mean? + // trimmed mean? + // + // robustness metrics? 
+ // coefficient of variation + std::sort(costs.begin(), costs.end()); + return costs[costs.size() / 2]; + } + template static size_t model_cost(scf::ParallelOp ¶llelOp, MemrefOp &memrefOp, const ParallelConfig &cfg, const MemrefInductionCosts &costTable) { static_assert(std::is_same_v || std::is_same_v); @@ -880,17 +909,7 @@ void walk_selections(const std::vector>& vec, Lambda &&f) { llvm::outs() << "found cost model in table\n"; Cost model = jt->second; - std::vector unknowns = model.unknowns(); - - // TODO: this context just says every stride is 10 - Ctx ctx; - for (auto &name : unknowns) { - ctx.values[name] = 10; - } - - size_t cost = model.stride_->eval(ctx); - return cost; - + return monte_carlo(model); } } From c98872fe8bd4e617b0fd4179704b210c5b7843c5 Mon Sep 17 00:00:00 2001 From: Carl Pearson Date: Fri, 17 Jan 2025 11:46:33 -0700 Subject: [PATCH 09/37] remove unused walk_selections Signed-off-by: Carl Pearson --- .../Transforms/KokkosMdrangeIterationPass.cpp | 52 ------------------- 1 file changed, 52 deletions(-) diff --git a/mlir/lib/Dialect/Kokkos/Transforms/KokkosMdrangeIterationPass.cpp b/mlir/lib/Dialect/Kokkos/Transforms/KokkosMdrangeIterationPass.cpp index 4aa0615f..d9727429 100644 --- a/mlir/lib/Dialect/Kokkos/Transforms/KokkosMdrangeIterationPass.cpp +++ b/mlir/lib/Dialect/Kokkos/Transforms/KokkosMdrangeIterationPass.cpp @@ -768,58 +768,6 @@ static MemrefInductionCosts build_cost_table(scf::ParallelOp &parentOp, std::vec }); // walk } -/* -call f() on a vector containing all permutations of valid indices of the entries of vec -e.g vec = { - {1, 2}, - {3}, - {4, 5, 6} - }; -yields - -f( {0, 0, 0} ) -f( {0, 0, 1} ) -f( {0, 0, 2} ) -f( {1, 0, 0} ) -f( {1, 0, 1} ) -f( {1, 0, 2} ) -*/ -template -void walk_selections(const std::vector>& vec, Lambda &&f) { - if (vec.empty()) return; - - std::vector indices(vec.size(), 0); // Initialize indices to track positions in each vector - - while (true) { - // Print the current combination - for (size_t i = 0; i 
< vec.size(); ++i) { - // std::cout << indices[i] << " "; - f(indices); - } - // std::cout << std::endl; - - // Find the rightmost vector that has more elements to iterate - size_t k = vec.size(); - while (k > 0) { - --k; - if (indices[k] < vec[k].size() - 1) { - ++indices[k]; - break; - } - indices[k] = 0; // Reset this index and move to the previous vector - } - - // If we've reset all indices, we're done - if (k == 0 && indices[0] == 0) { - break; - } - } -} - - - - - // model the cost of a module with a given parallel configuration static size_t model_cost(ModuleOp &mod, const ParallelConfig &cfg, const MemrefInductionCosts &costTable) { size_t cost = 0; From 9ceabce5102a5320aa029b00a830af82ba026ba5 Mon Sep 17 00:00:00 2001 From: Carl Pearson Date: Fri, 17 Jan 2025 12:19:45 -0700 Subject: [PATCH 10/37] Add Sub and Div (useful for trip counts) --- .../Transforms/KokkosMdrangeIterationPass.cpp | 193 +++++++++++++++--- 1 file changed, 159 insertions(+), 34 deletions(-) diff --git a/mlir/lib/Dialect/Kokkos/Transforms/KokkosMdrangeIterationPass.cpp b/mlir/lib/Dialect/Kokkos/Transforms/KokkosMdrangeIterationPass.cpp index d9727429..f3555a66 100644 --- a/mlir/lib/Dialect/Kokkos/Transforms/KokkosMdrangeIterationPass.cpp +++ b/mlir/lib/Dialect/Kokkos/Transforms/KokkosMdrangeIterationPass.cpp @@ -153,7 +153,7 @@ struct KokkosMdrangeIterationPass struct Expr { enum class Kind { - Add, Mul, Constant, Unknown + Add, Sub, Mul, Div, Constant, Unknown }; Expr(Kind kind) : kind_(kind) {} @@ -165,19 +165,13 @@ struct KokkosMdrangeIterationPass virtual ~Expr() {} }; - struct Add : public Expr { - Add(std::shared_ptr lhs, std::shared_ptr rhs) : Expr(Kind::Add), lhs_(lhs), rhs_(rhs) {} - std::shared_ptr lhs_; - std::shared_ptr rhs_; - - virtual int eval(const Ctx &ctx) override { - return lhs_->eval(ctx) + rhs_->eval(ctx); - } + struct Binary : public Expr { + Binary(Kind kind, const std::string sym, std::shared_ptr lhs, std::shared_ptr rhs) : Expr(kind), sym_(sym), lhs_(lhs), 
rhs_(rhs) {} virtual void dump(llvm::raw_fd_ostream &os) override { os << "("; lhs_->dump(os); - os << "+"; + os << sym_; rhs_->dump(os); os << ")"; } @@ -192,6 +186,19 @@ struct KokkosMdrangeIterationPass return ret; } + protected: + std::string sym_; + std::shared_ptr lhs_; + std::shared_ptr rhs_; + }; + + struct Add : public Binary { + Add(std::shared_ptr lhs, std::shared_ptr rhs) : Binary(Kind::Add, "+", lhs, rhs) {} + + virtual int eval(const Ctx &ctx) override { + return lhs_->eval(ctx) + rhs_->eval(ctx); + } + static std::shared_ptr make(std::shared_ptr lhs, std::shared_ptr rhs) { auto lhs_const = llvm::dyn_cast(lhs.get()); auto rhs_const = llvm::dyn_cast(rhs.get()); @@ -210,34 +217,40 @@ struct KokkosMdrangeIterationPass static bool classof(const Expr *e) { return e->kind_ == Expr::Kind::Add; } - }; - struct Mul : public Expr { - Mul(std::shared_ptr lhs, std::shared_ptr rhs) : Expr(Kind::Mul), lhs_(lhs), rhs_(rhs) {} - std::shared_ptr lhs_; - std::shared_ptr rhs_; + struct Sub : public Binary { + Sub(std::shared_ptr lhs, std::shared_ptr rhs) : Binary(Kind::Add, "-", lhs, rhs) {} virtual int eval(const Ctx &ctx) override { - return lhs_->eval(ctx) * rhs_->eval(ctx); + return lhs_->eval(ctx) + rhs_->eval(ctx); } - virtual void dump(llvm::raw_fd_ostream &os) override { - os << "("; - lhs_->dump(os); - os << "*"; - rhs_->dump(os); - os << ")"; - } + static std::shared_ptr make(std::shared_ptr lhs, std::shared_ptr rhs) { + auto lhs_const = llvm::dyn_cast(lhs.get()); + auto rhs_const = llvm::dyn_cast(rhs.get()); - virtual std::vector unknowns() const override { - std::vector ret; - for (auto &op : {lhs_, rhs_}) { - for (auto &name : op->unknowns()) { - ret.push_back(name); - } + if (lhs_const && lhs_const->value_ == 0) { + return Mul::make(rhs, Constant::make(-1)); + } else if (rhs_const && rhs_const->value_ == 0) { + return lhs; + } else if (rhs_const && lhs_const) { + return Constant::make(lhs_const->value_ - rhs_const->value_); } - return ret; + + return 
std::make_shared(lhs, rhs); + } + + static bool classof(const Expr *e) { + return e->kind_ == Expr::Kind::Sub; + } + }; + + struct Mul : public Binary { + Mul(std::shared_ptr lhs, std::shared_ptr rhs) : Binary(Kind::Mul, "*", lhs, rhs) {} + + virtual int eval(const Ctx &ctx) override { + return lhs_->eval(ctx) * rhs_->eval(ctx); } static std::shared_ptr make(std::shared_ptr lhs, std::shared_ptr rhs) { @@ -246,14 +259,18 @@ struct KokkosMdrangeIterationPass if (rhs_const && lhs_const) { return Constant::make(lhs_const->value_ * rhs_const->value_); - } else if (lhs_const && lhs_const->value_ == 1) { + } else if (lhs_const && lhs_const->value_ == 1) { // 1 * x return rhs; - } else if (lhs_const && lhs_const->value_ == 0) { + } else if (lhs_const && lhs_const->value_ == 0) { // 0 * x return Constant::make(0); - } else if (rhs_const && rhs_const->value_ == 1) { + } else if (rhs_const && rhs_const->value_ == 1) { // x * 1 return lhs; - } else if (rhs_const && rhs_const->value_ == 0) { + } else if (rhs_const && rhs_const->value_ == 0) { // x * 0 return Constant::make(0); + } else if (rhs_const && rhs_const->value_ == -1) { // x * -1 + return Constant::make(-lhs_const->value_); + } else if (lhs_const && lhs_const->value_ == -1) { // -1 * x + return Constant::make(-rhs_const->value_); } return std::make_shared(lhs, rhs); @@ -262,11 +279,38 @@ struct KokkosMdrangeIterationPass static bool classof(const Expr *e) { return e->kind_ == Expr::Kind::Mul; } + }; + struct Div : public Binary { + Div(std::shared_ptr lhs, std::shared_ptr rhs) : Binary(Kind::Mul, "/", lhs, rhs) {} + virtual int eval(const Ctx &ctx) override { + return lhs_->eval(ctx) / rhs_->eval(ctx); + } + + static std::shared_ptr make(std::shared_ptr lhs, std::shared_ptr rhs) { + auto lhs_const = llvm::dyn_cast(lhs.get()); + auto rhs_const = llvm::dyn_cast(rhs.get()); + if (rhs_const && lhs_const) { + return Constant::make(lhs_const->value_ * rhs_const->value_); + } else if (lhs_const && lhs_const->value_ == 0) { // 
0 / x + return Constant::make(0); + } else if (rhs_const && rhs_const->value_ == 1) { // x / 1 + return lhs; + } else if (rhs_const && rhs_const->value_ == -1) { // x / -1 + return Constant::make(-lhs_const->value_); + } + + return std::make_shared
(lhs, rhs); + } + + static bool classof(const Expr *e) { + return e->kind_ == Expr::Kind::Div; + } }; + struct Constant : public Expr { Constant(int value) : Expr(Kind::Constant), value_(value) {} int value_; @@ -571,6 +615,87 @@ class VecMap { } }; + +std::shared_ptr iteration_space_size(scf::ParallelOp &op, int dim) { +#if 0 + auto lb = op.getLowerBound()[dim]; + auto ub = op.getUpperBound()[dim]; + auto st = op.getStep()[dim]; + + std::shared_ptr lbExpr; + std::shared_ptr ubExpr; + std::shared_ptr stExpr; + + // Assuming the bounds and steps are constant integers for simplicity + if (auto lbConst = lb.getDefiningOp()) { + if (auto ubConst = ub.getDefiningOp()) { + if (auto stepConst = step.getDefiningOp()) { + int64_t lbValue = lbConst.value(); + int64_t ubValue = ubConst.value(); + int64_t stepValue = stepConst.value(); + + int64_t iterationSpaceSize = (ubValue - lbValue + stepValue - 1) / stepValue; + llvm::outs() << "Iteration space size for dimension " << i << ": " << iterationSpaceSize << "\n"; + } + } + } + + // (ub - lb + step - 1) / step + auto num = Add::make(Sub::make(ubExpr, lbExpr), Sub::make(stExpr, Constant::make(1))); + return Div::make(num, lbExpr) +#endif +} + +std::shared_ptr trip_count_expr(scf::ParallelOp &op) { +#if 0 + auto lowerBounds = parallelOp.getLowerBound(); + + std::shared_ptr total = iteration_space_size(0); + + for (unsigned i = 1; i < lowerBounds.size(); ++i) { + + total = Mul::make(total, iteration_space_size(i)); + + + + } +#endif +} + +using ParallelTripCounts = VecMap>; + +static ParallelTripCounts build_parallel_trip_counts(ModuleOp &mod) { +#if 0 + ParallelTripCounts PTC; + + mod.walk([&](Operation *op) { + // skip memrefs outside a parallel region + if (auto parallelOp = dyn_cast(op)) { + + + + + + + MemrefInductionCosts costs = build_parallel_trip_counts(parallelOp); + for (const auto &kv : costs) { + PTC[kv.first] = kv.second; + } + } + }); // walk + + + return PTC; +#endif +} + +static ParallelTripCounts 
build_parallel_trip_counts(scf::ParallelOp &parentOp) { + + ParallelTripCounts PTC; + return PTC; + +} + // map of (Operation*, Value) -> Cost // map of the cost model for a given memref / induction variable pair using MemrefInductionCosts = VecMap, Cost>; From 0a3368895d0749ebafc5e7323da1109e79af3237 Mon Sep 17 00:00:00 2001 From: Carl Pearson Date: Fri, 17 Jan 2025 12:20:07 -0700 Subject: [PATCH 11/37] Fix div symbol --- .../Dialect/Kokkos/Transforms/KokkosMdrangeIterationPass.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mlir/lib/Dialect/Kokkos/Transforms/KokkosMdrangeIterationPass.cpp b/mlir/lib/Dialect/Kokkos/Transforms/KokkosMdrangeIterationPass.cpp index f3555a66..7f3e56c7 100644 --- a/mlir/lib/Dialect/Kokkos/Transforms/KokkosMdrangeIterationPass.cpp +++ b/mlir/lib/Dialect/Kokkos/Transforms/KokkosMdrangeIterationPass.cpp @@ -282,7 +282,7 @@ struct KokkosMdrangeIterationPass }; struct Div : public Binary { - Div(std::shared_ptr lhs, std::shared_ptr rhs) : Binary(Kind::Mul, "/", lhs, rhs) {} + Div(std::shared_ptr lhs, std::shared_ptr rhs) : Binary(Kind::Div, "/", lhs, rhs) {} virtual int eval(const Ctx &ctx) override { return lhs_->eval(ctx) / rhs_->eval(ctx); From 9780455ca330908ec2ef50665656f61745131ac0 Mon Sep 17 00:00:00 2001 From: Carl Pearson Date: Thu, 23 Jan 2025 14:25:36 -0700 Subject: [PATCH 12/37] expressions for parallel trip counts Signed-off-by: Carl Pearson --- .../Transforms/KokkosMdrangeIterationPass.cpp | 140 ++++++++++++------ mlir/test/Dialect/Kokkos/mdrange_2.mlir | 13 ++ 2 files changed, 110 insertions(+), 43 deletions(-) create mode 100644 mlir/test/Dialect/Kokkos/mdrange_2.mlir diff --git a/mlir/lib/Dialect/Kokkos/Transforms/KokkosMdrangeIterationPass.cpp b/mlir/lib/Dialect/Kokkos/Transforms/KokkosMdrangeIterationPass.cpp index 7f3e56c7..c60a9291 100644 --- a/mlir/lib/Dialect/Kokkos/Transforms/KokkosMdrangeIterationPass.cpp +++ b/mlir/lib/Dialect/Kokkos/Transforms/KokkosMdrangeIterationPass.cpp @@ -162,6 +162,7 
@@ struct KokkosMdrangeIterationPass virtual int eval(const Ctx &ctx) = 0; virtual void dump(llvm::raw_fd_ostream &os) = 0; virtual std::vector unknowns() const = 0; + virtual std::shared_ptr clone() const = 0; virtual ~Expr() {} }; @@ -199,6 +200,10 @@ struct KokkosMdrangeIterationPass return lhs_->eval(ctx) + rhs_->eval(ctx); } + virtual std::shared_ptr clone() const override { + return make(lhs_->clone(), rhs_->clone()); + } + static std::shared_ptr make(std::shared_ptr lhs, std::shared_ptr rhs) { auto lhs_const = llvm::dyn_cast(lhs.get()); auto rhs_const = llvm::dyn_cast(rhs.get()); @@ -226,6 +231,10 @@ struct KokkosMdrangeIterationPass return lhs_->eval(ctx) + rhs_->eval(ctx); } + virtual std::shared_ptr clone() const override { + return make(lhs_->clone(), rhs_->clone()); + } + static std::shared_ptr make(std::shared_ptr lhs, std::shared_ptr rhs) { auto lhs_const = llvm::dyn_cast(lhs.get()); auto rhs_const = llvm::dyn_cast(rhs.get()); @@ -253,6 +262,10 @@ struct KokkosMdrangeIterationPass return lhs_->eval(ctx) * rhs_->eval(ctx); } + virtual std::shared_ptr clone() const override { + return make(lhs_->clone(), rhs_->clone()); + } + static std::shared_ptr make(std::shared_ptr lhs, std::shared_ptr rhs) { auto lhs_const = llvm::dyn_cast(lhs.get()); auto rhs_const = llvm::dyn_cast(rhs.get()); @@ -288,6 +301,10 @@ struct KokkosMdrangeIterationPass return lhs_->eval(ctx) / rhs_->eval(ctx); } + virtual std::shared_ptr clone() const override { + return make(lhs_->clone(), rhs_->clone()); + } + static std::shared_ptr make(std::shared_ptr lhs, std::shared_ptr rhs) { auto lhs_const = llvm::dyn_cast(lhs.get()); auto rhs_const = llvm::dyn_cast(rhs.get()); @@ -319,6 +336,10 @@ struct KokkosMdrangeIterationPass return value_; } + virtual std::shared_ptr clone() const override { + return make(value_); + } + virtual void dump(llvm::raw_fd_ostream &os) override { os << value_; } @@ -344,6 +365,10 @@ struct KokkosMdrangeIterationPass return ctx.values.at(name_); } + virtual 
std::shared_ptr clone() const override { + return make(name_); + } + virtual void dump(llvm::raw_fd_ostream &os) override { os << "(" << name_ << ")"; } @@ -616,8 +641,21 @@ class VecMap { }; -std::shared_ptr iteration_space_size(scf::ParallelOp &op, int dim) { -#if 0 +// FIXME: this returns things like this: +// of type 'index' at index: 2 +static std::string get_value_name(mlir::Value &value) { + + if (mlir::isa(value)) { + auto ba = mlir::cast(value); + return std::string("block") +std::to_string(uintptr_t(ba.getOwner())) + "_arg" + std::to_string(ba.getArgNumber()); + } else { + return value.getDefiningOp()->getName().getStringRef().str(); + } +} + + +static std::shared_ptr iteration_space_size(scf::ParallelOp &op, int dim) { + auto lb = op.getLowerBound()[dim]; auto ub = op.getUpperBound()[dim]; auto st = op.getStep()[dim]; @@ -626,74 +664,79 @@ std::shared_ptr iteration_space_size(scf::ParallelOp &op, int dim) { std::shared_ptr ubExpr; std::shared_ptr stExpr; - // Assuming the bounds and steps are constant integers for simplicity if (auto lbConst = lb.getDefiningOp()) { - if (auto ubConst = ub.getDefiningOp()) { - if (auto stepConst = step.getDefiningOp()) { - int64_t lbValue = lbConst.value(); - int64_t ubValue = ubConst.value(); - int64_t stepValue = stepConst.value(); - - int64_t iterationSpaceSize = (ubValue - lbValue + stepValue - 1) / stepValue; - llvm::outs() << "Iteration space size for dimension " << i << ": " << iterationSpaceSize << "\n"; - } - } + lbExpr = Constant::make(lbConst.value()); + } else { + lbExpr = Unknown::make(get_value_name(lb)); + } + + if (auto ubConst = ub.getDefiningOp()) { + ubExpr = Constant::make(ubConst.value()); + } else { + ubExpr = Unknown::make(get_value_name(ub)); + } + + if (auto stepConst = st.getDefiningOp()) { + stExpr = Constant::make(stepConst.value()); + } else { + stExpr = Unknown::make(get_value_name(st)); } // (ub - lb + step - 1) / step + // TODO: this could be a special DivCeil operation or something auto 
num = Add::make(Sub::make(ubExpr, lbExpr), Sub::make(stExpr, Constant::make(1))); - return Div::make(num, lbExpr) -#endif + return Div::make(num, stExpr); } -std::shared_ptr trip_count_expr(scf::ParallelOp &op) { -#if 0 - auto lowerBounds = parallelOp.getLowerBound(); - - std::shared_ptr total = iteration_space_size(0); - +// return an Expr representing the product of the iteration space of all dimensions +static std::shared_ptr trip_count_expr(scf::ParallelOp &op) { + auto lowerBounds = op.getLowerBound(); + std::shared_ptr total = iteration_space_size(op, 0); for (unsigned i = 1; i < lowerBounds.size(); ++i) { - - total = Mul::make(total, iteration_space_size(i)); - - - - } -#endif + total = Mul::make(total, iteration_space_size(op, i)); + } + return total; } using ParallelTripCounts = VecMap>; static ParallelTripCounts build_parallel_trip_counts(ModuleOp &mod) { -#if 0 + ParallelTripCounts PTC; mod.walk([&](Operation *op) { - // skip memrefs outside a parallel region if (auto parallelOp = dyn_cast(op)) { + // create an expression representing the trip count for this loop + std::shared_ptr count = trip_count_expr(parallelOp); + PTC[parallelOp] = count; - - - - - MemrefInductionCosts costs = build_parallel_trip_counts(parallelOp); - for (const auto &kv : costs) { - PTC[kv.first] = kv.second; - } + // descend into the body of the loop + ParallelTripCounts counts = build_parallel_trip_counts(parallelOp, count); } }); // walk return PTC; -#endif } -static ParallelTripCounts build_parallel_trip_counts(scf::ParallelOp &parentOp) { - +static ParallelTripCounts build_parallel_trip_counts(scf::ParallelOp &parentOp, std::shared_ptr cost) { ParallelTripCounts PTC; - return PTC; + parentOp.getBody()->walk([&](Operation *op) { + if (auto parallelOp = dyn_cast(op)) { + + // create an expression representing the trip count for this loop + std::shared_ptr count = trip_count_expr(parallelOp); + count = Mul::make(count, cost->clone()); + PTC[parallelOp] = count; + + // descend 
into the body of the loop + ParallelTripCounts counts = build_parallel_trip_counts(parallelOp, count); + } + }); + + return PTC; } // map of (Operation*, Value) -> Cost @@ -996,12 +1039,23 @@ static MemrefInductionCosts build_cost_table(scf::ParallelOp &parentOp, std::vec llvm::outs() << "====\ndump_ops\n====\n"; dump_ops(module); + llvm::outs() << "====\nbuild_parallel_trip_counts\n====\n"; + ParallelTripCounts tripCounts = build_parallel_trip_counts(module); + + for (auto &kv : tripCounts) { + const std::shared_ptr &trip = kv.second; + llvm::outs() << "parallel op: "; + kv.first.print(llvm::outs()); + llvm::outs() << " trip: "; + trip->dump(llvm::outs()); + llvm::outs() << "\n"; + } + llvm::outs() << "====\nbuild_cost_table\n====\n"; // FIXME: some helper function to tighten up `stack` scope std::vector stack; MemrefInductionCosts costTable = build_cost_table(module, stack); - llvm::outs() << "====\nmodel reordered induction vars\n====\n"; size_t minCost = std::numeric_limits::max(); ParallelConfig minCfg; diff --git a/mlir/test/Dialect/Kokkos/mdrange_2.mlir b/mlir/test/Dialect/Kokkos/mdrange_2.mlir new file mode 100644 index 00000000..af0d89ea --- /dev/null +++ b/mlir/test/Dialect/Kokkos/mdrange_2.mlir @@ -0,0 +1,13 @@ +module { + func.func @example_function(%arg0: memref<10x20xf32>, %arg1: memref, %loop_bound_i: index) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + + scf.parallel (%i, %j) = (%c0, %c0) to (%loop_bound_i, %c2) step (%c1, %c1) { + %val = memref.load %arg0[%i, %j] : memref<10x20xf32> + memref.store %val, %arg1[%i, %j] : memref + } + return + } +} From 02273e69ac921bbfaf90ab69bc6a03ac9093a179 Mon Sep 17 00:00:00 2001 From: Carl Pearson Date: Thu, 23 Jan 2025 14:38:11 -0700 Subject: [PATCH 13/37] Incorporate trip count into cost model --- .../Transforms/KokkosMdrangeIterationPass.cpp | 22 +++++++++---------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git 
a/mlir/lib/Dialect/Kokkos/Transforms/KokkosMdrangeIterationPass.cpp b/mlir/lib/Dialect/Kokkos/Transforms/KokkosMdrangeIterationPass.cpp index c60a9291..07373398 100644 --- a/mlir/lib/Dialect/Kokkos/Transforms/KokkosMdrangeIterationPass.cpp +++ b/mlir/lib/Dialect/Kokkos/Transforms/KokkosMdrangeIterationPass.cpp @@ -743,7 +743,7 @@ static ParallelTripCounts build_parallel_trip_counts(scf::ParallelOp &parentOp, // map of the cost model for a given memref / induction variable pair using MemrefInductionCosts = VecMap, Cost>; -static MemrefInductionCosts build_cost_table(ModuleOp &mod, std::vector &stack) { +static MemrefInductionCosts build_cost_table(ModuleOp &mod, ParallelTripCounts &tripCounts, std::vector &stack) { MemrefInductionCosts MIC; @@ -751,7 +751,7 @@ static MemrefInductionCosts build_cost_table(ModuleOp &mod, std::vector(op)) { stack.push_back(parallelOp); - MemrefInductionCosts costs = build_cost_table(parallelOp, stack); + MemrefInductionCosts costs = build_cost_table(parallelOp, tripCounts, stack); stack.pop_back(); for (const auto &kv : costs) { MIC[kv.first] = kv.second; @@ -762,14 +762,14 @@ static MemrefInductionCosts build_cost_table(ModuleOp &mod, std::vector &stack) { +static MemrefInductionCosts build_cost_table(scf::ParallelOp &parentOp, ParallelTripCounts &tripCounts, std::vector &stack) { MemrefInductionCosts MIC; parentOp.getBody()->walk([&](Operation *op) { if (auto parallelOp = dyn_cast(op)) { stack.push_back(parallelOp); - MemrefInductionCosts costs = build_cost_table(parallelOp, stack); + MemrefInductionCosts costs = build_cost_table(parallelOp, tripCounts, stack); stack.pop_back(); for (const auto &kv : costs) { MIC[kv.first] = kv.second; @@ -823,8 +823,8 @@ static MemrefInductionCosts build_cost_table(scf::ParallelOp &parentOp, std::vec } llvm::outs() << "\n"; - // FIXME: compute trip count - MIC[std::make_pair(memrefOp, indVar)] = Cost(dodi, Constant::make(1), 1 /*load cost*/); + std::shared_ptr tripCount = tripCounts[parentOp]; + 
MIC[std::make_pair(memrefOp, indVar)] = Cost(dodi, tripCount, 1 /*load cost*/); } @@ -877,8 +877,8 @@ static MemrefInductionCosts build_cost_table(scf::ParallelOp &parentOp, std::vec } llvm::outs() << "\n"; - // FIXME: compute trip count - MIC[std::make_pair(memrefOp, indVar)] = Cost(dodi, Constant::make(1), 3 /*store cost*/); + std::shared_ptr tripCount = tripCounts[parentOp]; + MIC[std::make_pair(memrefOp, indVar)] = Cost(dodi, tripCount, 3 /*store cost*/); } } }); @@ -979,11 +979,11 @@ static MemrefInductionCosts build_cost_table(scf::ParallelOp &parentOp, std::vec Ctx ctx; for (auto &name : unknowns) { auto val = log_random_int(gen, 1, 1000000); - llvm::outs() << name << ": " << val << "\n"; + // llvm::outs() << name << ": " << val << "\n"; ctx.values[name] = val; } - costs.push_back(model.stride_->eval(ctx)); + costs.push_back(model.stride_->eval(ctx) * model.count_->eval(ctx)); } // FIXME: here we do median, is there a principled aggregation strategy? @@ -1054,7 +1054,7 @@ static MemrefInductionCosts build_cost_table(scf::ParallelOp &parentOp, std::vec llvm::outs() << "====\nbuild_cost_table\n====\n"; // FIXME: some helper function to tighten up `stack` scope std::vector stack; - MemrefInductionCosts costTable = build_cost_table(module, stack); + MemrefInductionCosts costTable = build_cost_table(module, tripCounts, stack); llvm::outs() << "====\nmodel reordered induction vars\n====\n"; size_t minCost = std::numeric_limits::max(); From 641a3c962cc7d601bb0439e479519340c3dd0853 Mon Sep 17 00:00:00 2001 From: Carl Pearson Date: Fri, 24 Jan 2025 09:40:59 -0700 Subject: [PATCH 14/37] use llvm::DenseMap for ParallelTripCounts Signed-off-by: Carl Pearson --- .../Dialect/Kokkos/Transforms/KokkosMdrangeIterationPass.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mlir/lib/Dialect/Kokkos/Transforms/KokkosMdrangeIterationPass.cpp b/mlir/lib/Dialect/Kokkos/Transforms/KokkosMdrangeIterationPass.cpp index 07373398..8c4a12a3 100644 --- 
a/mlir/lib/Dialect/Kokkos/Transforms/KokkosMdrangeIterationPass.cpp +++ b/mlir/lib/Dialect/Kokkos/Transforms/KokkosMdrangeIterationPass.cpp @@ -698,7 +698,7 @@ static std::shared_ptr trip_count_expr(scf::ParallelOp &op) { return total; } -using ParallelTripCounts = VecMap>; +using ParallelTripCounts = llvm::DenseMap>; static ParallelTripCounts build_parallel_trip_counts(ModuleOp &mod) { From da5942a2de88b4e926e071973820e7d53fa150be Mon Sep 17 00:00:00 2001 From: Carl Pearson Date: Fri, 24 Jan 2025 09:42:46 -0700 Subject: [PATCH 15/37] use llvm::DenseMap for MemrefInductionCosts Signed-off-by: Carl Pearson --- .../Dialect/Kokkos/Transforms/KokkosMdrangeIterationPass.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mlir/lib/Dialect/Kokkos/Transforms/KokkosMdrangeIterationPass.cpp b/mlir/lib/Dialect/Kokkos/Transforms/KokkosMdrangeIterationPass.cpp index 8c4a12a3..e518ad96 100644 --- a/mlir/lib/Dialect/Kokkos/Transforms/KokkosMdrangeIterationPass.cpp +++ b/mlir/lib/Dialect/Kokkos/Transforms/KokkosMdrangeIterationPass.cpp @@ -741,7 +741,7 @@ static ParallelTripCounts build_parallel_trip_counts(scf::ParallelOp &parentOp, // map of (Operation*, Value) -> Cost // map of the cost model for a given memref / induction variable pair -using MemrefInductionCosts = VecMap, Cost>; +using MemrefInductionCosts = llvm::DenseMap, Cost>; static MemrefInductionCosts build_cost_table(ModuleOp &mod, ParallelTripCounts &tripCounts, std::vector &stack) { From 5090616044107c06155f319e9f5b6beff6940f8b Mon Sep 17 00:00:00 2001 From: Carl Pearson Date: Fri, 24 Jan 2025 09:44:01 -0700 Subject: [PATCH 16/37] use llvm::DenseMap in ParallelConfig Signed-off-by: Carl Pearson --- .../Dialect/Kokkos/Transforms/KokkosMdrangeIterationPass.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mlir/lib/Dialect/Kokkos/Transforms/KokkosMdrangeIterationPass.cpp b/mlir/lib/Dialect/Kokkos/Transforms/KokkosMdrangeIterationPass.cpp index e518ad96..d4adbdfd 100644 --- 
a/mlir/lib/Dialect/Kokkos/Transforms/KokkosMdrangeIterationPass.cpp +++ b/mlir/lib/Dialect/Kokkos/Transforms/KokkosMdrangeIterationPass.cpp @@ -888,7 +888,7 @@ static MemrefInductionCosts build_cost_table(scf::ParallelOp &parentOp, Parallel struct ParallelConfig { // permutation of induction variables for each parallel op - VecMap> perms_; + llvm::DenseMap> perms_; }; From cf6c5e7906a320d15c9aa42ad146dc28d56f6aa1 Mon Sep 17 00:00:00 2001 From: Carl Pearson Date: Fri, 24 Jan 2025 09:45:18 -0700 Subject: [PATCH 17/37] remove unused VecMap Signed-off-by: Carl Pearson --- .../Transforms/KokkosMdrangeIterationPass.cpp | 52 ------------------- 1 file changed, 52 deletions(-) diff --git a/mlir/lib/Dialect/Kokkos/Transforms/KokkosMdrangeIterationPass.cpp b/mlir/lib/Dialect/Kokkos/Transforms/KokkosMdrangeIterationPass.cpp index d4adbdfd..69cf13aa 100644 --- a/mlir/lib/Dialect/Kokkos/Transforms/KokkosMdrangeIterationPass.cpp +++ b/mlir/lib/Dialect/Kokkos/Transforms/KokkosMdrangeIterationPass.cpp @@ -589,58 +589,6 @@ static std::vector all_induction_variables(std::vector & return vars; } -template -class VecMap { - -public: - using iterator = typename std::vector>::iterator; - using const_iterator = typename std::vector>::const_iterator; - -private: - std::vector> data_; - - -public: - // Find an iterator to a key-value pair by key - auto find(const key_type& key) const { - return std::find_if(data_.begin(), data_.end(), - [&key](const auto& pair) { return pair.first == key; }); - } - - auto find(const key_type& key) { - return std::find_if(data_.begin(), data_.end(), - [&key](const auto& pair) { return pair.first == key; }); - } - - // Access value by key without bounds checking - value_type& operator[](const key_type& key) { - auto it = find(key); - if (it != data_.end()) { - return it->second; - } else { - data_.emplace_back(key, value_type{}); - return data_.back().second; - } - } - - iterator begin() { - return data_.begin(); - } - - const_iterator begin() const { - 
return data_.begin(); - } - - iterator end() { - return data_.end(); - } - - const_iterator end() const { - return data_.end(); - } -}; - - // FIXME: this returns things like this: // of type 'index' at index: 2 static std::string get_value_name(mlir::Value &value) { From 4801b959b97eaea7081a78626cd58d41fb88c728 Mon Sep 17 00:00:00 2001 From: Carl Pearson Date: Fri, 24 Jan 2025 09:52:08 -0700 Subject: [PATCH 18/37] better stack scoping Signed-off-by: Carl Pearson --- .../Kokkos/Transforms/KokkosMdrangeIterationPass.cpp | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/mlir/lib/Dialect/Kokkos/Transforms/KokkosMdrangeIterationPass.cpp b/mlir/lib/Dialect/Kokkos/Transforms/KokkosMdrangeIterationPass.cpp index 69cf13aa..c1dbfaca 100644 --- a/mlir/lib/Dialect/Kokkos/Transforms/KokkosMdrangeIterationPass.cpp +++ b/mlir/lib/Dialect/Kokkos/Transforms/KokkosMdrangeIterationPass.cpp @@ -701,7 +701,7 @@ static MemrefInductionCosts build_cost_table(ModuleOp &mod, ParallelTripCounts & stack.push_back(parallelOp); MemrefInductionCosts costs = build_cost_table(parallelOp, tripCounts, stack); stack.pop_back(); - for (const auto &kv : costs) { + for (const auto &kv : costs) { // FIXME: insert? 
MIC[kv.first] = kv.second; } } @@ -710,6 +710,11 @@ static MemrefInductionCosts build_cost_table(ModuleOp &mod, ParallelTripCounts & return MIC; } +static MemrefInductionCosts build_cost_table(ModuleOp &mod, ParallelTripCounts &tripCounts) { + std::vector stack; + return build_cost_table(mod, tripCounts, stack); +} + static MemrefInductionCosts build_cost_table(scf::ParallelOp &parentOp, ParallelTripCounts &tripCounts, std::vector &stack) { MemrefInductionCosts MIC; @@ -1000,9 +1005,7 @@ static MemrefInductionCosts build_cost_table(scf::ParallelOp &parentOp, Parallel } llvm::outs() << "====\nbuild_cost_table\n====\n"; - // FIXME: some helper function to tighten up `stack` scope - std::vector stack; - MemrefInductionCosts costTable = build_cost_table(module, tripCounts, stack); + MemrefInductionCosts costTable = build_cost_table(module, tripCounts); llvm::outs() << "====\nmodel reordered induction vars\n====\n"; size_t minCost = std::numeric_limits::max(); From 9f938d0bb91b4a68870427f1ddf93263d99153f5 Mon Sep 17 00:00:00 2001 From: Carl Pearson Date: Fri, 24 Jan 2025 10:09:29 -0700 Subject: [PATCH 19/37] factor out memref cost generation Signed-off-by: Carl Pearson --- .../Transforms/KokkosMdrangeIterationPass.cpp | 188 ++++++++---------- 1 file changed, 84 insertions(+), 104 deletions(-) diff --git a/mlir/lib/Dialect/Kokkos/Transforms/KokkosMdrangeIterationPass.cpp b/mlir/lib/Dialect/Kokkos/Transforms/KokkosMdrangeIterationPass.cpp index c1dbfaca..556b16cf 100644 --- a/mlir/lib/Dialect/Kokkos/Transforms/KokkosMdrangeIterationPass.cpp +++ b/mlir/lib/Dialect/Kokkos/Transforms/KokkosMdrangeIterationPass.cpp @@ -405,6 +405,13 @@ struct KokkosMdrangeIterationPass for (auto &u : count_->unknowns()) ret.push_back(u); return ret; } + + template + static constexpr int scale_factor() { + static_assert(std::is_same_v || std::is_same_v); + if constexpr (std::is_same_v) return 1; + else if constexpr (std::is_same_v) return 3; + } }; // partial derivative df/dx @@ -715,8 +722,77 
@@ static MemrefInductionCosts build_cost_table(ModuleOp &mod, ParallelTripCounts & return build_cost_table(mod, tripCounts, stack); } -static MemrefInductionCosts build_cost_table(scf::ParallelOp &parentOp, ParallelTripCounts &tripCounts, std::vector &stack) { +template +static MemrefInductionCosts get_costs(Memref &memrefOp, ParallelTripCounts &tripCounts, std::vector &stack) { + static_assert(std::is_same_v || std::is_same_v); + + if constexpr (std::is_same_v) { + llvm::outs() << "get_cost: memref::LoadOp\n"; + } else if constexpr (std::is_same_v) { + llvm::outs() << "get_cost: memref::StoreOp\n"; + } + + MemrefInductionCosts MIC; + std::vector indVars = all_induction_variables(stack); + if (stack.empty()) { + llvm::report_fatal_error("get_costs: memref is not enclosed in an scf::ParallelOp"); + } + scf::ParallelOp &parentOp = stack.back(); + + // compute the partial derivative of each memref with respect to all induction variables via the chain rule: + // d(offset)/d(indvar) = sum( + // d(offset)/d(index) * d(index)/d(indvar), + // for each index in indices) + + for (Value indVar : indVars) { + std::shared_ptr dodi = Constant::make(0); + for (Value indexVar : memrefOp.getIndices()) { + auto e1 = do_di(memrefOp, indexVar); + + llvm::outs() << "pd of " << memrefOp << " w.r.t " << indexVar << "\n"; + if (e1) { + e1->dump(llvm::outs()); + } else { + llvm::outs() << " undefined "; + } + llvm::outs() << "\n"; + + auto e2 = df_dx(indexVar, indVar); + + llvm::outs() << "pd of " << indexVar << " w.r.t " << indVar << "\n"; + if (e2) { + e2->dump(llvm::outs()); + } else { + llvm::outs() << " undefined "; + } + llvm::outs() << "\n"; + + if (e1 && e2) { + dodi = Add::make(dodi, Mul::make(e1, e2)); + } else { + dodi = nullptr; + break; + } + } + + llvm::outs() << "pd of " << memrefOp << " w.r.t " << indVar << "\n"; + if (dodi) { + dodi->dump(llvm::outs()); + } else { + llvm::outs() << " undefined "; + } + llvm::outs() << "\n"; + + std::shared_ptr tripCount = 
tripCounts[parentOp]; + + MIC[std::make_pair(memrefOp, indVar)] = Cost(dodi, tripCount, Cost::scale_factor()); + } + return MIC; +} + +// FIXME: parentOp is also the back of the stack? +static MemrefInductionCosts build_cost_table(scf::ParallelOp &parentOp, ParallelTripCounts &tripCounts, std::vector &stack) { MemrefInductionCosts MIC; parentOp.getBody()->walk([&](Operation *op) { @@ -724,114 +800,18 @@ static MemrefInductionCosts build_cost_table(scf::ParallelOp &parentOp, Parallel stack.push_back(parallelOp); MemrefInductionCosts costs = build_cost_table(parallelOp, tripCounts, stack); stack.pop_back(); - for (const auto &kv : costs) { + for (const auto &kv : costs) { // FIXME: insert? MIC[kv.first] = kv.second; } } else if (auto memrefOp = dyn_cast(op)) { - llvm::outs() << "nested memref load!\n"; - - std::vector indVars = all_induction_variables(stack); - - // compute the partial derivative of each memref with respect to all induction variables via the chain rule: - // d(offset)/d(indvar) = sum( - // d(offset)/d(index) * d(index)/d(indvar), - // for each index in indices) - - for (Value indVar : indVars) { - std::shared_ptr dodi = Constant::make(0); - for (Value indexVar : memrefOp.getIndices()) { - auto e1 = do_di(memrefOp, indexVar); - - llvm::outs() << "pd of " << memrefOp << " w.r.t " << indexVar << "\n"; - if (e1) { - e1->dump(llvm::outs()); - } else { - llvm::outs() << " undefined "; - } - llvm::outs() << "\n"; - - auto e2 = df_dx(indexVar, indVar); - - llvm::outs() << "pd of " << indexVar << " w.r.t " << indVar << "\n"; - if (e2) { - e2->dump(llvm::outs()); - } else { - llvm::outs() << " undefined "; - } - llvm::outs() << "\n"; - - if (e1 && e2) { - dodi = Add::make(dodi, Mul::make(e1, e2)); - } else { - dodi = nullptr; - break; - } - } - - llvm::outs() << "pd of " << memrefOp << " w.r.t " << indVar << "\n"; - if (dodi) { - dodi->dump(llvm::outs()); - } else { - llvm::outs() << " undefined "; - } - llvm::outs() << "\n"; - - std::shared_ptr tripCount = 
tripCounts[parentOp]; - MIC[std::make_pair(memrefOp, indVar)] = Cost(dodi, tripCount, 1 /*load cost*/); + MemrefInductionCosts costs = get_costs(memrefOp, tripCounts, stack); + for (const auto &kv : costs) { // FIXME: insert? + MIC[kv.first] = kv.second; } - - } else if (auto memrefOp = dyn_cast(op)) { - llvm::outs() << "nested memref store!\n"; - - std::vector indVars = all_induction_variables(stack); - - // compute the partial derivative of each memref with respect to all induction variables via the chain rule: - // d(offset)/d(indvar) = sum( - // d(offset)/d(index) * d(index)/d(indvar), - // for each index in indices) - - for (Value indVar : indVars) { - std::shared_ptr dodi = Constant::make(0); - for (Value indexVar : memrefOp.getIndices()) { - auto e1 = do_di(memrefOp, indexVar); - - llvm::outs() << "pd of " << memrefOp << " w.r.t " << indexVar << "\n"; - if (e1) { - e1->dump(llvm::outs()); - } else { - llvm::outs() << " undefined "; - } - llvm::outs() << "\n"; - - auto e2 = df_dx(indexVar, indVar); - - llvm::outs() << "pd of " << indexVar << " w.r.t " << indVar << "\n"; - if (e2) { - e2->dump(llvm::outs()); - } else { - llvm::outs() << " undefined "; - } - llvm::outs() << "\n"; - - if (e1 && e2) { - dodi = Add::make(dodi, Mul::make(e1, e2)); - } else { - dodi = nullptr; - break; - } - } - - llvm::outs() << "pd of " << memrefOp << " w.r.t " << indVar << "\n"; - if (dodi) { - dodi->dump(llvm::outs()); - } else { - llvm::outs() << " undefined "; - } - llvm::outs() << "\n"; - - std::shared_ptr tripCount = tripCounts[parentOp]; - MIC[std::make_pair(memrefOp, indVar)] = Cost(dodi, tripCount, 3 /*store cost*/); + MemrefInductionCosts costs = get_costs(memrefOp, tripCounts, stack); + for (const auto &kv : costs) { // FIXME: insert? 
+ MIC[kv.first] = kv.second; } } }); From e5f777a197381d6f68d2ec1e0c30a56d6d15b513 Mon Sep 17 00:00:00 2001 From: Carl Pearson Date: Fri, 24 Jan 2025 10:12:40 -0700 Subject: [PATCH 20/37] replace loop with DenseMap::insert Signed-off-by: Carl Pearson --- .../Transforms/KokkosMdrangeIterationPass.cpp | 16 ++++------------ 1 file changed, 4 insertions(+), 12 deletions(-) diff --git a/mlir/lib/Dialect/Kokkos/Transforms/KokkosMdrangeIterationPass.cpp b/mlir/lib/Dialect/Kokkos/Transforms/KokkosMdrangeIterationPass.cpp index 556b16cf..9e4e749f 100644 --- a/mlir/lib/Dialect/Kokkos/Transforms/KokkosMdrangeIterationPass.cpp +++ b/mlir/lib/Dialect/Kokkos/Transforms/KokkosMdrangeIterationPass.cpp @@ -708,9 +708,7 @@ static MemrefInductionCosts build_cost_table(ModuleOp &mod, ParallelTripCounts & stack.push_back(parallelOp); MemrefInductionCosts costs = build_cost_table(parallelOp, tripCounts, stack); stack.pop_back(); - for (const auto &kv : costs) { // FIXME: insert? - MIC[kv.first] = kv.second; - } + MIC.insert(costs.begin(), costs.end()); } }); // walk @@ -800,19 +798,13 @@ static MemrefInductionCosts build_cost_table(scf::ParallelOp &parentOp, Parallel stack.push_back(parallelOp); MemrefInductionCosts costs = build_cost_table(parallelOp, tripCounts, stack); stack.pop_back(); - for (const auto &kv : costs) { // FIXME: insert? - MIC[kv.first] = kv.second; - } + MIC.insert(costs.begin(), costs.end()); } else if (auto memrefOp = dyn_cast(op)) { MemrefInductionCosts costs = get_costs(memrefOp, tripCounts, stack); - for (const auto &kv : costs) { // FIXME: insert? - MIC[kv.first] = kv.second; - } + MIC.insert(costs.begin(), costs.end()); } else if (auto memrefOp = dyn_cast(op)) { MemrefInductionCosts costs = get_costs(memrefOp, tripCounts, stack); - for (const auto &kv : costs) { // FIXME: insert? 
- MIC[kv.first] = kv.second; - } + MIC.insert(costs.begin(), costs.end()); } }); From 856f9049e443cb85611c12645afaadaafb40a929 Mon Sep 17 00:00:00 2001 From: Carl Pearson Date: Fri, 24 Jan 2025 10:41:10 -0700 Subject: [PATCH 21/37] Improve names & comments, SmallVector for parallel op stack Signed-off-by: Carl Pearson --- .../Transforms/KokkosMdrangeIterationPass.cpp | 91 ++++++++++--------- 1 file changed, 48 insertions(+), 43 deletions(-) diff --git a/mlir/lib/Dialect/Kokkos/Transforms/KokkosMdrangeIterationPass.cpp b/mlir/lib/Dialect/Kokkos/Transforms/KokkosMdrangeIterationPass.cpp index 9e4e749f..84091582 100644 --- a/mlir/lib/Dialect/Kokkos/Transforms/KokkosMdrangeIterationPass.cpp +++ b/mlir/lib/Dialect/Kokkos/Transforms/KokkosMdrangeIterationPass.cpp @@ -147,6 +147,8 @@ struct KokkosMdrangeIterationPass // a context for expression evaluation struct Ctx { + + // FIXME: llvm data structures std::unordered_map values; }; @@ -414,8 +416,7 @@ struct KokkosMdrangeIterationPass } }; - // partial derivative df/dx - +// partial derivative df/dx static std::shared_ptr df_dx(Value &f, Value &x) { if (f == x) { llvm::outs() << "Info: df_dx of equal values\n"; @@ -585,21 +586,10 @@ static std::shared_ptr df_dx(Value &f, Value &x) { } -// return groups of induction variables for -static std::vector all_induction_variables(std::vector &ops) { - std::vector vars; - for (auto &op : ops) { - for (auto &var : op.getInductionVars()) { - vars.push_back(var); - } - } - return vars; -} -// FIXME: this returns things like this: -// of type 'index' at index: 2 -static std::string get_value_name(mlir::Value &value) { +// Get a unique name for the provided value +static std::string get_value_name(mlir::Value &value) { if (mlir::isa(value)) { auto ba = mlir::cast(value); return std::string("block") +std::to_string(uintptr_t(ba.getOwner())) + "_arg" + std::to_string(ba.getArgNumber()); @@ -608,8 +598,9 @@ static std::string get_value_name(mlir::Value &value) { } } - -static 
std::shared_ptr iteration_space_size(scf::ParallelOp &op, int dim) { +// Get an expression representing the size of the iteration space of `op` in the +// `dim` dimension. +static std::shared_ptr iteration_space_expr(scf::ParallelOp &op, int dim) { auto lb = op.getLowerBound()[dim]; auto ub = op.getUpperBound()[dim]; @@ -643,62 +634,74 @@ static std::shared_ptr iteration_space_size(scf::ParallelOp &op, int dim) return Div::make(num, stExpr); } -// return an Expr representing the product of the iteration space of all dimensions -static std::shared_ptr trip_count_expr(scf::ParallelOp &op) { +// Get an expression representing the size of the iteration space of `op` +static std::shared_ptr iteration_space_expr(scf::ParallelOp &op) { auto lowerBounds = op.getLowerBound(); - std::shared_ptr total = iteration_space_size(op, 0); + std::shared_ptr total = iteration_space_expr(op, 0); for (unsigned i = 1; i < lowerBounds.size(); ++i) { - total = Mul::make(total, iteration_space_size(op, i)); + total = Mul::make(total, iteration_space_expr(op, i)); } return total; } -using ParallelTripCounts = llvm::DenseMap>; +using IterationSpaceExprs = llvm::DenseMap>; -static ParallelTripCounts build_parallel_trip_counts(ModuleOp &mod) { +// Get expressions represeting the iteration space for all parallel loops in the module +static IterationSpaceExprs build_parallel_trip_counts(ModuleOp &mod) { - ParallelTripCounts PTC; + IterationSpaceExprs ISE; mod.walk([&](Operation *op) { if (auto parallelOp = dyn_cast(op)) { - // create an expression representing the trip count for this loop - std::shared_ptr count = trip_count_expr(parallelOp); - PTC[parallelOp] = count; + std::shared_ptr expr = iteration_space_expr(parallelOp); + ISE[parallelOp] = expr; // descend into the body of the loop - ParallelTripCounts counts = build_parallel_trip_counts(parallelOp, count); + IterationSpaceExprs exprs = build_parallel_trip_counts(parallelOp, expr); + ISE.insert(exprs.begin(), exprs.end()); } }); // walk - 
return PTC; + return ISE; } -static ParallelTripCounts build_parallel_trip_counts(scf::ParallelOp &parentOp, std::shared_ptr cost) { - ParallelTripCounts PTC; +static IterationSpaceExprs build_parallel_trip_counts(scf::ParallelOp &parentOp, std::shared_ptr cost) { + IterationSpaceExprs ISE; parentOp.getBody()->walk([&](Operation *op) { if (auto parallelOp = dyn_cast(op)) { - // create an expression representing the trip count for this loop - std::shared_ptr count = trip_count_expr(parallelOp); - count = Mul::make(count, cost->clone()); - PTC[parallelOp] = count; + std::shared_ptr expr = iteration_space_expr(parallelOp); + ISE[parallelOp] = expr; // descend into the body of the loop - ParallelTripCounts counts = build_parallel_trip_counts(parallelOp, count); + IterationSpaceExprs exprs = build_parallel_trip_counts(parallelOp, expr); + ISE.insert(exprs.begin(), exprs.end()); } }); - return PTC; + return ISE; } // map of (Operation*, Value) -> Cost // map of the cost model for a given memref / induction variable pair using MemrefInductionCosts = llvm::DenseMap, Cost>; +using ParallelOpStack = llvm::SmallVector; -static MemrefInductionCosts build_cost_table(ModuleOp &mod, ParallelTripCounts &tripCounts, std::vector &stack) { +// return all induction variables for all parallel ops +static std::vector all_induction_variables(ParallelOpStack &ops) { + std::vector vars; + for (auto &op : ops) { + for (auto &var : op.getInductionVars()) { + vars.push_back(var); + } + } + return vars; +} + +static MemrefInductionCosts build_cost_table(ModuleOp &mod, IterationSpaceExprs &tripCounts, ParallelOpStack &stack) { MemrefInductionCosts MIC; @@ -715,13 +718,15 @@ static MemrefInductionCosts build_cost_table(ModuleOp &mod, ParallelTripCounts & return MIC; } -static MemrefInductionCosts build_cost_table(ModuleOp &mod, ParallelTripCounts &tripCounts) { - std::vector stack; + + +static MemrefInductionCosts build_cost_table(ModuleOp &mod, IterationSpaceExprs &tripCounts) { + 
ParallelOpStack stack; return build_cost_table(mod, tripCounts, stack); } template -static MemrefInductionCosts get_costs(Memref &memrefOp, ParallelTripCounts &tripCounts, std::vector &stack) { +static MemrefInductionCosts get_costs(Memref &memrefOp, IterationSpaceExprs &tripCounts, ParallelOpStack &stack) { static_assert(std::is_same_v || std::is_same_v); if constexpr (std::is_same_v) { @@ -790,7 +795,7 @@ static MemrefInductionCosts get_costs(Memref &memrefOp, ParallelTripCounts &trip } // FIXME: parentOp is also the back of the stack? -static MemrefInductionCosts build_cost_table(scf::ParallelOp &parentOp, ParallelTripCounts &tripCounts, std::vector &stack) { +static MemrefInductionCosts build_cost_table(scf::ParallelOp &parentOp, IterationSpaceExprs &tripCounts, ParallelOpStack &stack) { MemrefInductionCosts MIC; parentOp.getBody()->walk([&](Operation *op) { @@ -965,7 +970,7 @@ static MemrefInductionCosts build_cost_table(scf::ParallelOp &parentOp, Parallel dump_ops(module); llvm::outs() << "====\nbuild_parallel_trip_counts\n====\n"; - ParallelTripCounts tripCounts = build_parallel_trip_counts(module); + IterationSpaceExprs tripCounts = build_parallel_trip_counts(module); for (auto &kv : tripCounts) { const std::shared_ptr &trip = kv.second; From 187a0c3b7bb42e2a8478ecf66b2e0793272dd8fe Mon Sep 17 00:00:00 2001 From: Carl Pearson Date: Fri, 24 Jan 2025 10:47:35 -0700 Subject: [PATCH 22/37] clone and replace module ops --- .../Kokkos/Transforms/KokkosMdrangeIterationPass.cpp | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/mlir/lib/Dialect/Kokkos/Transforms/KokkosMdrangeIterationPass.cpp b/mlir/lib/Dialect/Kokkos/Transforms/KokkosMdrangeIterationPass.cpp index 84091582..2c216d67 100644 --- a/mlir/lib/Dialect/Kokkos/Transforms/KokkosMdrangeIterationPass.cpp +++ b/mlir/lib/Dialect/Kokkos/Transforms/KokkosMdrangeIterationPass.cpp @@ -1012,6 +1012,16 @@ static MemrefInductionCosts build_cost_table(scf::ParallelOp &parentOp, Iteratio llvm::outs() << 
"====\nbuild new module\n====\n"; + // clone the existing module + ModuleOp newModule = module.clone(); + + // TODO: modify the parallel ops in the new module + + // overwrite the module with the new module + // Replace the original module with the new module. + module.getBody()->getOperations().clear(); + module.getBody()->getOperations().splice(module.getBody()->begin(), + newModule.getBody()->getOperations()); llvm::outs() << "====\ndone\n====\n"; } }; From fa9a22b3c4b357ef44e3993f19307337a9e0baac Mon Sep 17 00:00:00 2001 From: Carl Pearson Date: Fri, 24 Jan 2025 12:21:52 -0700 Subject: [PATCH 23/37] fix redundant scf.reduce, permute scf parallel --- .../Transforms/KokkosMdrangeIterationPass.cpp | 77 +++++++++++++++++++ mlir/test/Dialect/Kokkos/mdrange_1.mlir | 1 + 2 files changed, 78 insertions(+) diff --git a/mlir/lib/Dialect/Kokkos/Transforms/KokkosMdrangeIterationPass.cpp b/mlir/lib/Dialect/Kokkos/Transforms/KokkosMdrangeIterationPass.cpp index 2c216d67..27b7000d 100644 --- a/mlir/lib/Dialect/Kokkos/Transforms/KokkosMdrangeIterationPass.cpp +++ b/mlir/lib/Dialect/Kokkos/Transforms/KokkosMdrangeIterationPass.cpp @@ -964,8 +964,40 @@ static MemrefInductionCosts build_cost_table(scf::ParallelOp &parentOp, Iteratio return cost; } + using Permutation = llvm::SmallVector; + + // modify `parallelOp` so that its induction variables are permuted according to `permutation` + static void permute_parallel_op(scf::ParallelOp parallelOp, const Permutation &permutation) { + OpBuilder builder(parallelOp); + SmallVector newLowerBounds, newUpperBounds, newSteps; + + for (int index : permutation) { + newLowerBounds.push_back(parallelOp.getLowerBound()[index]); + newUpperBounds.push_back(parallelOp.getUpperBound()[index]); + newSteps.push_back(parallelOp.getStep()[index]); + } + + auto newParallelOp = builder.create( + parallelOp.getLoc(), newLowerBounds, newUpperBounds, newSteps); + + // Move the body of the original parallelOp to the new parallelOp. 
+ newParallelOp.getBody()->getTerminator()->erase(); // splicing in the new body has a terminator already + newParallelOp.getBody()->getOperations().splice( + newParallelOp.getBody()->begin(), parallelOp.getBody()->getOperations()); + + // replace uses of original induction variable perm[i] with new induction variable [i] + for (size_t i = 0; i < permutation.size(); ++i) { + parallelOp.getInductionVars()[permutation[i]].replaceAllUsesWith(newParallelOp.getInductionVars()[i]); + } + + parallelOp.erase(); + } + void runOnOperation() override { ModuleOp module = getOperation(); + + llvm::outs() << module << "\n"; + llvm::outs() << "====\ndump_ops\n====\n"; dump_ops(module); @@ -1012,16 +1044,61 @@ static MemrefInductionCosts build_cost_table(scf::ParallelOp &parentOp, Iteratio llvm::outs() << "====\nbuild new module\n====\n"; +#if 0 // clone the existing module ModuleOp newModule = module.clone(); // TODO: modify the parallel ops in the new module + newModule.walk([&](scf::ParallelOp parallelOp) { + + llvm::outs() << "modifying " << parallelOp << "\n"; + + // TODO: replace this placeholder permutation with the computed one + // fake permutation that just reverses stuff + Permutation permutation(parallelOp.getInductionVars().size()); + std::iota(permutation.begin(), permutation.end(), 0); + std::reverse(permutation.begin(), permutation.end()); + + llvm::outs() << "applying permutation "; + for (auto i : permutation) { + llvm::outs() << i << " "; + } + llvm::outs() << "\n"; + + permute_parallel_op(parallelOp, permutation); + }); + + + // FIXME: this seems like it might introduce an extra scf.reduce at the end + // of the parallel region, probably because it clones one and then one gets inserted + // --mlir-print-ir-after-failure // overwrite the module with the new module // Replace the original module with the new module. 
module.getBody()->getOperations().clear(); module.getBody()->getOperations().splice(module.getBody()->begin(), newModule.getBody()->getOperations()); +#else + // modify the parallel ops in the module + module.walk([&](scf::ParallelOp parallelOp) { + + llvm::outs() << "modifying " << parallelOp << "\n"; + + // TODO: replace this placeholder permutation with the computed one + // fake permutation that just reverses stuff + Permutation permutation(parallelOp.getInductionVars().size()); + std::iota(permutation.begin(), permutation.end(), 0); + std::reverse(permutation.begin(), permutation.end()); + + llvm::outs() << "applying permutation "; + for (auto i : permutation) { + llvm::outs() << i << " "; + } + llvm::outs() << "\n"; + + permute_parallel_op(parallelOp, permutation); + }); +#endif llvm::outs() << "====\ndone\n====\n"; } }; diff --git a/mlir/test/Dialect/Kokkos/mdrange_1.mlir b/mlir/test/Dialect/Kokkos/mdrange_1.mlir index 53259e0e..e9443098 100644 --- a/mlir/test/Dialect/Kokkos/mdrange_1.mlir +++ b/mlir/test/Dialect/Kokkos/mdrange_1.mlir @@ -8,6 +8,7 @@ module { scf.parallel (%i, %j) = (%c0, %c0) to (%c10, %c20) step (%c1, %c1) { %val = memref.load %arg0[%i, %j] : memref<10x20xf32> memref.store %val, %arg1[%i, %j] : memref + scf.reduce } return } From e878b59917327a49a4d438cd74fba4170491feaf Mon Sep 17 00:00:00 2001 From: Carl Pearson Date: Fri, 24 Jan 2025 12:29:59 -0700 Subject: [PATCH 24/37] add nested parallel test Signed-off-by: Carl Pearson --- mlir/test/Dialect/Kokkos/mdrange_3.mlir | 29 +++++++++++++++++++++++++ 1 file changed, 29 insertions(+) create mode 100644 mlir/test/Dialect/Kokkos/mdrange_3.mlir diff --git a/mlir/test/Dialect/Kokkos/mdrange_3.mlir b/mlir/test/Dialect/Kokkos/mdrange_3.mlir new file mode 100644 index 00000000..a45f4ff1 --- /dev/null +++ b/mlir/test/Dialect/Kokkos/mdrange_3.mlir @@ -0,0 +1,29 @@ +module { + func.func @nested_parallel(%arg0: index, %arg1: index, %arg2: index) { + %0 = memref.alloc(%arg0, %arg1) : memref + %1 = 
memref.alloc(%arg1, %arg2) : memref + %2 = memref.alloc(%arg0, %arg2) : memref + + %3 = arith.addi %arg0, %arg1 : index + %4 = arith.addi %arg1, %arg2 : index + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %c10 = arith.constant 10 : index + // Outer parallel loop with dynamic bounds and stride + scf.parallel (%i, %j) = (%arg0, %arg1) to (%3, %4) step (%arg1, %c2) { + // Inner parallel loop with constant bounds and stride + scf.parallel (%k) = (%c0) to (%c10) step (%c1) { + %val0 = memref.load %0[%i, %j] : memref + %val1 = memref.load %1[%j, %k] : memref + %result = arith.addf %val0, %val1 : f32 + memref.store %result, %2[%i, %k] : memref + } + } + + memref.dealloc %0 : memref + memref.dealloc %1 : memref + memref.dealloc %2 : memref + return + } +} \ No newline at end of file From 60fc12e663d8a61ccd6477aaac68a03bcc444df7 Mon Sep 17 00:00:00 2001 From: Carl Pearson Date: Fri, 24 Jan 2025 14:34:27 -0700 Subject: [PATCH 25/37] Improve naming, use SmallVector in ParallelConfig --- .../Transforms/KokkosMdrangeIterationPass.cpp | 24 +++++++++++++------ 1 file changed, 17 insertions(+), 7 deletions(-) diff --git a/mlir/lib/Dialect/Kokkos/Transforms/KokkosMdrangeIterationPass.cpp b/mlir/lib/Dialect/Kokkos/Transforms/KokkosMdrangeIterationPass.cpp index 27b7000d..6042d064 100644 --- a/mlir/lib/Dialect/Kokkos/Transforms/KokkosMdrangeIterationPass.cpp +++ b/mlir/lib/Dialect/Kokkos/Transforms/KokkosMdrangeIterationPass.cpp @@ -512,7 +512,7 @@ static std::shared_ptr df_dx(Value &f, Value &x) { auto memrefShape = memrefType.getShape(); for (int dim = 0; dim < indexVarDim; ++dim) { if (memrefShape[dim] == ShapedType::kDynamic) { - std::string name = memrefOp.getOperation()->getName().getStringRef().str() + "_extent" + std::to_string(dim); // FIXME: unique name for each memref dimension + std::string name = std::string("memref") + std::to_string(uintptr_t(memrefOp.getOperation())) + "_extent" + std::to_string(dim); res = 
std::make_shared(res, Unknown::make(name)); } else { res = std::make_shared(res, std::make_shared(memrefShape[dim])); @@ -594,7 +594,15 @@ static std::string get_value_name(mlir::Value &value) { auto ba = mlir::cast(value); return std::string("block") +std::to_string(uintptr_t(ba.getOwner())) + "_arg" + std::to_string(ba.getArgNumber()); } else { - return value.getDefiningOp()->getName().getStringRef().str(); + // mlir::Operation *op = value.getDefiningOp(); + + std::string name; + llvm::raw_string_ostream os(name); + value.print(os); + name = name.substr(0, name.find(' ')); // "%blah = ..." -> "%blah" + + // std::string name = op->getName().getStringRef().str() + std::to_string(uintptr_t(op)); + return name; } } @@ -816,9 +824,11 @@ static MemrefInductionCosts build_cost_table(scf::ParallelOp &parentOp, Iteratio return MIC; } + using Permutation = llvm::SmallVector; + struct ParallelConfig { // permutation of induction variables for each parallel op - llvm::DenseMap> perms_; + llvm::DenseMap perms_; }; @@ -834,7 +844,7 @@ static MemrefInductionCosts build_cost_table(scf::ParallelOp &parentOp, Iteratio found = true; // walk all configurations of this parallel op too - std::vector perm(get_num_induction_vars(parallelOp)); + Permutation perm(get_num_induction_vars(parallelOp)); std::iota(perm.begin(), perm.end(), 0); do { cfg.perms_[parallelOp] = perm; @@ -854,7 +864,7 @@ static MemrefInductionCosts build_cost_table(scf::ParallelOp &parentOp, Iteratio mod.walk([&](Operation *op) { if (auto parallelOp = dyn_cast(op)) { - std::vector perm(get_num_induction_vars(parallelOp)); + Permutation perm(get_num_induction_vars(parallelOp)); std::iota(perm.begin(), perm.end(), 0); do { @@ -941,7 +951,7 @@ static MemrefInductionCosts build_cost_table(scf::ParallelOp &parentOp, Iteratio if (auto it = cfg.perms_.find(parallelOp); it != cfg.perms_.end()) { llvm::outs() << "found perm for memref's parent parallelOp in config\n"; - const std::vector &perm = it->second; + const 
Permutation &perm = it->second; Value rightMostVar = parallelOp.getInductionVars()[perm[perm.size() - 1]]; llvm::outs() << "right-most induction var is "; @@ -964,7 +974,7 @@ static MemrefInductionCosts build_cost_table(scf::ParallelOp &parentOp, Iteratio return cost; } - using Permutation = llvm::SmallVector; + // modify `parallelOp` so that its induction variables are permuted according to `permutation` static void permute_parallel_op(scf::ParallelOp parallelOp, const Permutation &permutation) { From 67196bee62b29a2e0ecc84e8517126ea739d40d6 Mon Sep 17 00:00:00 2001 From: Carl Pearson Date: Fri, 24 Jan 2025 15:44:29 -0700 Subject: [PATCH 26/37] Fix walk_configurations --- .../Transforms/KokkosMdrangeIterationPass.cpp | 78 ++++++++++++++++--- 1 file changed, 68 insertions(+), 10 deletions(-) diff --git a/mlir/lib/Dialect/Kokkos/Transforms/KokkosMdrangeIterationPass.cpp b/mlir/lib/Dialect/Kokkos/Transforms/KokkosMdrangeIterationPass.cpp index 6042d064..23be84c3 100644 --- a/mlir/lib/Dialect/Kokkos/Transforms/KokkosMdrangeIterationPass.cpp +++ b/mlir/lib/Dialect/Kokkos/Transforms/KokkosMdrangeIterationPass.cpp @@ -839,6 +839,19 @@ static MemrefInductionCosts build_cost_table(scf::ParallelOp &parentOp, Iteratio template void walk_configurations(scf::ParallelOp &parentOp, ParallelConfig cfg, Lambda &&f) { bool found = false; +#if 1 + for (scf::ParallelOp parallelOp : parentOp.getBody()->getOps()) { + found = true; + + // walk all configurations of this parallel op too + Permutation perm(get_num_induction_vars(parallelOp)); + std::iota(perm.begin(), perm.end(), 0); + do { + cfg.perms_[parallelOp] = perm; + walk_configurations(parallelOp, cfg, std::forward(f)); + } while (std::next_permutation(perm.begin(), perm.end())); + } +#else parentOp.getBody()->walk([&](Operation *op) { if (auto parallelOp = dyn_cast(op)) { found = true; @@ -852,28 +865,60 @@ static MemrefInductionCosts build_cost_table(scf::ParallelOp &parentOp, Iteratio } while 
(std::next_permutation(perm.begin(), perm.end())); } }); // walk - +#endif // no nested parallel regions, no more configurations to go through, call f if (!found) { + llvm::outs() << "reached fully-nested configuration\n"; f(cfg); } } template - void walk_configurations(ModuleOp &mod, Lambda &&f) { - mod.walk([&](Operation *op) { - if (auto parallelOp = dyn_cast(op)) { + void walk_configurations(mlir::Operation *op, Lambda &&f) { + ParallelConfig cfg; + walk_configurations(op, std::forward(f), cfg); + } + // return true if op, or any of its nested children, were scf parallel + template + bool walk_configurations(mlir::Operation *op, Lambda &&f, const ParallelConfig &cfg) { + if (auto parallelOp = dyn_cast(op)) { + // create a permutation of induction variables Permutation perm(get_num_induction_vars(parallelOp)); std::iota(perm.begin(), perm.end(), 0); + // walk the children of this op with all different induction variable configurations do { - ParallelConfig cfg; - cfg.perms_[parallelOp] = perm; - walk_configurations(parallelOp, cfg, f); + ParallelConfig newCfg = cfg; + newCfg.perms_[parallelOp] = perm; + + bool anyParallel = false; + for (mlir::Region ®ion : op->getRegions()) { + for (mlir::Block &block : region.getBlocks()) { + for (mlir::Operation &nestedOp : block.getOperations()) { + anyParallel |= walk_configurations(&nestedOp, std::forward(f), newCfg); + } + } + } + + if (!anyParallel) { + llvm::outs() << "no parallel regions nested below this...\n" << *op + << "\n...invoking callable on ParallelConfig of " << newCfg.perms_.size() << "regions\n"; + f(newCfg); + } } while (std::next_permutation(perm.begin(), perm.end())); + return true; + } else { + bool anyParallel = false; + for (mlir::Region ®ion : op->getRegions()) { + for (mlir::Block &block : region.getBlocks()) { + for (mlir::Operation &nestedOp : block.getOperations()) { + anyParallel |= walk_configurations(&nestedOp, std::forward(f), cfg); + } + } } - }); // walk + return anyParallel; + } } // model 
the cost of a module with a given parallel configuration @@ -1044,7 +1089,17 @@ static MemrefInductionCosts build_cost_table(scf::ParallelOp &parentOp, Iteratio size_t cost = model_cost(module, cfg, costTable); llvm::outs() << "cost was " << cost << "\n"; if (cost < minCost) { - llvm::outs() << "Info: new optimal! " << cost << "\n"; + llvm::outs() << "Info: new optimal! cost=" << cost << "\n"; + + for (const auto &kv : cfg.perms_) { + llvm::outs() << kv.first << " with permutation: "; + for (const size_t e : kv.second) { + llvm::outs() << e << " "; + + } + llvm::outs() << "\n"; + } + minCost = cost; minCfg = cfg; } @@ -1094,12 +1149,15 @@ static MemrefInductionCosts build_cost_table(scf::ParallelOp &parentOp, Iteratio llvm::outs() << "modifying " << parallelOp << "\n"; +#if 1 + const Permutation &permutation = minCfg.perms_[parallelOp]; +#else // TODO: replace this placeholder permutation with the computed one // fake permutation that just reverses stuff Permutation permutation(parallelOp.getInductionVars().size()); std::iota(permutation.begin(), permutation.end(), 0); std::reverse(permutation.begin(), permutation.end()); - +#endif llvm::outs() << "applying permutation "; for (auto i : permutation) { llvm::outs() << i << " "; From 757ab3077bcb4cee7abae179fc45fee1f9125d3c Mon Sep 17 00:00:00 2001 From: Carl Pearson Date: Fri, 24 Jan 2025 15:58:21 -0700 Subject: [PATCH 27/37] helper fn for iterating over nested ops --- .../Transforms/KokkosMdrangeIterationPass.cpp | 43 ++++++++++++------- 1 file changed, 27 insertions(+), 16 deletions(-) diff --git a/mlir/lib/Dialect/Kokkos/Transforms/KokkosMdrangeIterationPass.cpp b/mlir/lib/Dialect/Kokkos/Transforms/KokkosMdrangeIterationPass.cpp index 23be84c3..826d192a 100644 --- a/mlir/lib/Dialect/Kokkos/Transforms/KokkosMdrangeIterationPass.cpp +++ b/mlir/lib/Dialect/Kokkos/Transforms/KokkosMdrangeIterationPass.cpp @@ -879,6 +879,22 @@ static MemrefInductionCosts build_cost_table(scf::ParallelOp &parentOp, Iteratio 
walk_configurations(op, std::forward(f), cfg); } + // call f on every operation immediately nested under op + template + void for_each_nested(mlir::Operation *op, Lambda &&f) { + + // assert f can be called on mlir::Operation* + static_assert(std::is_invocable_v); + + for (mlir::Region ®ion : op->getRegions()) { + for (mlir::Block &block : region.getBlocks()) { + for (mlir::Operation &nestedOp : block.getOperations()) { + f(&nestedOp); + } + } + } + } + // return true if op, or any of its nested children, were scf parallel template bool walk_configurations(mlir::Operation *op, Lambda &&f, const ParallelConfig &cfg) { @@ -893,35 +909,30 @@ static MemrefInductionCosts build_cost_table(scf::ParallelOp &parentOp, Iteratio newCfg.perms_[parallelOp] = perm; bool anyParallel = false; - for (mlir::Region ®ion : op->getRegions()) { - for (mlir::Block &block : region.getBlocks()) { - for (mlir::Operation &nestedOp : block.getOperations()) { - anyParallel |= walk_configurations(&nestedOp, std::forward(f), newCfg); - } - } - } + for_each_nested(op, [&](mlir::Operation *child){ + anyParallel |= walk_configurations(child, std::forward(f), newCfg); + }); + // if no parallels are nested underneath here, this is a complete + // configuration if (!anyParallel) { llvm::outs() << "no parallel regions nested below this...\n" << *op - << "\n...invoking callable on ParallelConfig of " << newCfg.perms_.size() << "regions\n"; + << "\n...invoking callable on ParallelConfig of " << newCfg.perms_.size() << " regions\n"; f(newCfg); } } while (std::next_permutation(perm.begin(), perm.end())); return true; } else { bool anyParallel = false; - for (mlir::Region ®ion : op->getRegions()) { - for (mlir::Block &block : region.getBlocks()) { - for (mlir::Operation &nestedOp : block.getOperations()) { - anyParallel |= walk_configurations(&nestedOp, std::forward(f), cfg); - } - } - } + for_each_nested(op, [&](mlir::Operation *child){ + anyParallel |= walk_configurations(child, std::forward(f), cfg); + 
}); return anyParallel; } } - // model the cost of a module with a given parallel configuration + // FIXME: this should be re-written recursively + // cost of a module with a given parallel configuration static size_t model_cost(ModuleOp &mod, const ParallelConfig &cfg, const MemrefInductionCosts &costTable) { size_t cost = 0; mod.walk([&](Operation *op) { From 1ac37d7dee19f6a033c7995d84a46f39b2ceadce Mon Sep 17 00:00:00 2001 From: Carl Pearson Date: Fri, 24 Jan 2025 16:15:26 -0700 Subject: [PATCH 28/37] remove redundant modeling calls --- .../Transforms/KokkosMdrangeIterationPass.cpp | 77 ++++++++----------- 1 file changed, 31 insertions(+), 46 deletions(-) diff --git a/mlir/lib/Dialect/Kokkos/Transforms/KokkosMdrangeIterationPass.cpp b/mlir/lib/Dialect/Kokkos/Transforms/KokkosMdrangeIterationPass.cpp index 826d192a..fb5a994a 100644 --- a/mlir/lib/Dialect/Kokkos/Transforms/KokkosMdrangeIterationPass.cpp +++ b/mlir/lib/Dialect/Kokkos/Transforms/KokkosMdrangeIterationPass.cpp @@ -931,36 +931,18 @@ static MemrefInductionCosts build_cost_table(scf::ParallelOp &parentOp, Iteratio } } - // FIXME: this should be re-written recursively // cost of a module with a given parallel configuration static size_t model_cost(ModuleOp &mod, const ParallelConfig &cfg, const MemrefInductionCosts &costTable) { size_t cost = 0; mod.walk([&](Operation *op) { - if (auto parallelOp = dyn_cast(op)) { - cost += model_cost(parallelOp, cfg, costTable); - } - }); // walk - return cost; - } - - // model the cost of a parallel operation with a given config - static size_t model_cost(scf::ParallelOp &parentOp, const ParallelConfig &cfg, const MemrefInductionCosts &costTable) { - - size_t cost = 0; - - parentOp.getBody()->walk([&](Operation *op) { - if (auto parallelOp = dyn_cast(op)) { - cost += model_cost(parallelOp, cfg, costTable); - } else if (auto memrefOp = dyn_cast(op)) { - cost += model_cost(parentOp, memrefOp, cfg, costTable); + if (auto memrefOp = dyn_cast(op)) { + cost += 
model_cost(memrefOp, cfg, costTable); } else if (auto memrefOp = dyn_cast(op)) { - cost += model_cost(parentOp, memrefOp, cfg, costTable); + cost += model_cost(memrefOp, cfg, costTable); } - }); - + }); // walk return cost; } - static size_t monte_carlo(const Cost &model, int n = 100, int seed = 31337) { std::mt19937 gen(seed); @@ -996,38 +978,41 @@ static MemrefInductionCosts build_cost_table(scf::ParallelOp &parentOp, Iteratio } template - static size_t model_cost(scf::ParallelOp ¶llelOp, MemrefOp &memrefOp, const ParallelConfig &cfg, const MemrefInductionCosts &costTable) { + static size_t model_cost(MemrefOp &memrefOp, const ParallelConfig &cfg, const MemrefInductionCosts &costTable) { static_assert(std::is_same_v || std::is_same_v); - - llvm::outs() << "model cost of "; - memrefOp.print(llvm::outs()); - llvm::outs() << "\n"; + llvm::outs() << "model cost of " << memrefOp << "...\n"; - if (auto it = cfg.perms_.find(parallelOp); it != cfg.perms_.end()) { - llvm::outs() << "found perm for memref's parent parallelOp in config\n"; + auto parentOp = memrefOp.getOperation()->getParentOp(); + if (auto parallelOp = dyn_cast(parentOp)) { + if (auto it = cfg.perms_.find(parallelOp); it != cfg.perms_.end()) { - const Permutation &perm = it->second; - Value rightMostVar = parallelOp.getInductionVars()[perm[perm.size() - 1]]; + const Permutation &perm = it->second; + Value rightMostVar = parallelOp.getInductionVars()[perm[perm.size() - 1]]; - llvm::outs() << "right-most induction var is "; - rightMostVar.print(llvm::outs()); - llvm::outs() << "\n"; - - - // FIXME: why does this work? 
the table should expect key to be pair not pair - auto costKey = std::make_pair(memrefOp, rightMostVar); - if (auto jt = costTable.find(costKey); jt != costTable.end()) { - llvm::outs() << "found cost model in table\n"; + llvm::outs() << "under permutation, right-most enclosing induction var is " << rightMostVar << "\n"; + - Cost model = jt->second; - return monte_carlo(model); + // FIXME: why does this work? the table should expect key to be pair not pair + auto costKey = std::make_pair(memrefOp, rightMostVar); + if (auto jt = costTable.find(costKey); jt != costTable.end()) { + Cost model = jt->second; + size_t cost = monte_carlo(model); + llvm::outs() << "..." << memrefOp << " contributes " << cost << "\n"; + return cost; + } else { + llvm::outs() << "couldn't find model for memref / induction variable combo\n"; + return 0; + } + + } else { + llvm::outs() << "couldn't find permutation for parent parallel op\n"; + return 0; } - + } else { + llvm::outs() << "memrefOp " << memrefOp << " has no parallel parent\n"; + return 0; } - - size_t cost = 0; - return cost; } From ed1602efce2c71084fa5368d9cb0db80739c69a6 Mon Sep 17 00:00:00 2001 From: Carl Pearson Date: Thu, 15 May 2025 14:01:22 -0600 Subject: [PATCH 29/37] incorporate enclosing parallel trip count --- .../Transforms/KokkosMdrangeIterationPass.cpp | 39 +++++++++++-------- 1 file changed, 22 insertions(+), 17 deletions(-) diff --git a/mlir/lib/Dialect/Kokkos/Transforms/KokkosMdrangeIterationPass.cpp b/mlir/lib/Dialect/Kokkos/Transforms/KokkosMdrangeIterationPass.cpp index fb5a994a..0a592c38 100644 --- a/mlir/lib/Dialect/Kokkos/Transforms/KokkosMdrangeIterationPass.cpp +++ b/mlir/lib/Dialect/Kokkos/Transforms/KokkosMdrangeIterationPass.cpp @@ -241,11 +241,11 @@ struct KokkosMdrangeIterationPass auto lhs_const = llvm::dyn_cast(lhs.get()); auto rhs_const = llvm::dyn_cast(rhs.get()); - if (lhs_const && lhs_const->value_ == 0) { + if (lhs_const && lhs_const->value_ == 0) { // 0 - x --> -x return Mul::make(rhs, 
Constant::make(-1)); - } else if (rhs_const && rhs_const->value_ == 0) { + } else if (rhs_const && rhs_const->value_ == 0) { // x - 0 --> x return lhs; - } else if (rhs_const && lhs_const) { + } else if (rhs_const && lhs_const) { // c1 - c2 --> c3 return Constant::make(lhs_const->value_ - rhs_const->value_); } @@ -639,7 +639,12 @@ static std::shared_ptr iteration_space_expr(scf::ParallelOp &op, int dim) // (ub - lb + step - 1) / step // TODO: this could be a special DivCeil operation or something auto num = Add::make(Sub::make(ubExpr, lbExpr), Sub::make(stExpr, Constant::make(1))); - return Div::make(num, stExpr); + auto ret = Div::make(num, stExpr); + + llvm::outs() << "Trip count of " << op << " dim " << dim << " = "; + ret->dump(llvm::outs()); + llvm::outs() << "\n"; + return ret; } // Get an expression representing the size of the iteration space of `op` @@ -675,14 +680,14 @@ static IterationSpaceExprs build_parallel_trip_counts(ModuleOp &mod) { return ISE; } -static IterationSpaceExprs build_parallel_trip_counts(scf::ParallelOp &parentOp, std::shared_ptr cost) { +static IterationSpaceExprs build_parallel_trip_counts(scf::ParallelOp &parentOp, std::shared_ptr parentCount) { IterationSpaceExprs ISE; parentOp.getBody()->walk([&](Operation *op) { if (auto parallelOp = dyn_cast(op)) { // create an expression representing the trip count for this loop std::shared_ptr expr = iteration_space_expr(parallelOp); - ISE[parallelOp] = expr; + ISE[parallelOp] = Mul::make(expr, parentCount); // descend into the body of the loop IterationSpaceExprs exprs = build_parallel_trip_counts(parallelOp, expr); @@ -733,6 +738,11 @@ static MemrefInductionCosts build_cost_table(ModuleOp &mod, IterationSpaceExprs return build_cost_table(mod, tripCounts, stack); } + + // compute the partial derivative of each memref with respect to all enclosing induction variables via the chain rule: + // d(offset)/d(indvar) = sum( + // d(offset)/d(index) * d(index)/d(indvar), + // for each index in 
indices) template static MemrefInductionCosts get_costs(Memref &memrefOp, IterationSpaceExprs &tripCounts, ParallelOpStack &stack) { static_assert(std::is_same_v || std::is_same_v); @@ -751,31 +761,26 @@ static MemrefInductionCosts get_costs(Memref &memrefOp, IterationSpaceExprs &tri } scf::ParallelOp &parentOp = stack.back(); - // compute the partial derivative of each memref with respect to all induction variables via the chain rule: - // d(offset)/d(indvar) = sum( - // d(offset)/d(index) * d(index)/d(indvar), - // for each index in indices) - for (Value indVar : indVars) { std::shared_ptr dodi = Constant::make(0); for (Value indexVar : memrefOp.getIndices()) { auto e1 = do_di(memrefOp, indexVar); - llvm::outs() << "pd of " << memrefOp << " w.r.t " << indexVar << "\n"; + llvm::outs() << "∂(" << memrefOp << ")/∂(" << indexVar << ") = "; if (e1) { e1->dump(llvm::outs()); } else { - llvm::outs() << " undefined "; + llvm::outs() << "undefined"; } llvm::outs() << "\n"; auto e2 = df_dx(indexVar, indVar); - llvm::outs() << "pd of " << indexVar << " w.r.t " << indVar << "\n"; + llvm::outs() << "∂(" << indexVar << ")/∂(" << indVar << ") = "; if (e2) { e2->dump(llvm::outs()); } else { - llvm::outs() << " undefined "; + llvm::outs() << "undefined"; } llvm::outs() << "\n"; @@ -787,11 +792,11 @@ static MemrefInductionCosts get_costs(Memref &memrefOp, IterationSpaceExprs &tri } } - llvm::outs() << "pd of " << memrefOp << " w.r.t " << indVar << "\n"; + llvm::outs() << "∂(" << memrefOp << ")/∂(" << indVar << ") = "; if (dodi) { dodi->dump(llvm::outs()); } else { - llvm::outs() << " undefined "; + llvm::outs() << "undefined"; } llvm::outs() << "\n"; From 403dd7dbd05dc28c59ee7b38589ce2c9135139d6 Mon Sep 17 00:00:00 2001 From: Carl Pearson Date: Thu, 15 May 2025 15:28:16 -0600 Subject: [PATCH 30/37] simplfy parallel region traversal Signed-off-by: Carl Pearson --- .../Transforms/KokkosMdrangeIterationPass.cpp | 125 ++++++------------ 1 file changed, 43 insertions(+), 82 
deletions(-) diff --git a/mlir/lib/Dialect/Kokkos/Transforms/KokkosMdrangeIterationPass.cpp b/mlir/lib/Dialect/Kokkos/Transforms/KokkosMdrangeIterationPass.cpp index 0a592c38..4dca93e2 100644 --- a/mlir/lib/Dialect/Kokkos/Transforms/KokkosMdrangeIterationPass.cpp +++ b/mlir/lib/Dialect/Kokkos/Transforms/KokkosMdrangeIterationPass.cpp @@ -841,49 +841,6 @@ static MemrefInductionCosts build_cost_table(scf::ParallelOp &parentOp, Iteratio return parallelOp.getInductionVars().size(); } - template - void walk_configurations(scf::ParallelOp &parentOp, ParallelConfig cfg, Lambda &&f) { - bool found = false; -#if 1 - for (scf::ParallelOp parallelOp : parentOp.getBody()->getOps()) { - found = true; - - // walk all configurations of this parallel op too - Permutation perm(get_num_induction_vars(parallelOp)); - std::iota(perm.begin(), perm.end(), 0); - do { - cfg.perms_[parallelOp] = perm; - walk_configurations(parallelOp, cfg, std::forward(f)); - } while (std::next_permutation(perm.begin(), perm.end())); - } -#else - parentOp.getBody()->walk([&](Operation *op) { - if (auto parallelOp = dyn_cast(op)) { - found = true; - - // walk all configurations of this parallel op too - Permutation perm(get_num_induction_vars(parallelOp)); - std::iota(perm.begin(), perm.end(), 0); - do { - cfg.perms_[parallelOp] = perm; - walk_configurations(parallelOp, cfg, std::forward(f)); - } while (std::next_permutation(perm.begin(), perm.end())); - } - }); // walk -#endif - // no nested parallel regions, no more configurations to go through, call f - if (!found) { - llvm::outs() << "reached fully-nested configuration\n"; - f(cfg); - } - } - - template - void walk_configurations(mlir::Operation *op, Lambda &&f) { - ParallelConfig cfg; - walk_configurations(op, std::forward(f), cfg); - } - // call f on every operation immediately nested under op template void for_each_nested(mlir::Operation *op, Lambda &&f) { @@ -900,40 +857,15 @@ static MemrefInductionCosts build_cost_table(scf::ParallelOp &parentOp, 
Iteratio } } - // return true if op, or any of its nested children, were scf parallel - template - bool walk_configurations(mlir::Operation *op, Lambda &&f, const ParallelConfig &cfg) { - if (auto parallelOp = dyn_cast(op)) { - // create a permutation of induction variables - Permutation perm(get_num_induction_vars(parallelOp)); - std::iota(perm.begin(), perm.end(), 0); - - // walk the children of this op with all different induction variable configurations - do { - ParallelConfig newCfg = cfg; - newCfg.perms_[parallelOp] = perm; - - bool anyParallel = false; - for_each_nested(op, [&](mlir::Operation *child){ - anyParallel |= walk_configurations(child, std::forward(f), newCfg); - }); - - // if no parallels are nested underneath here, this is a complete - // configuration - if (!anyParallel) { - llvm::outs() << "no parallel regions nested below this...\n" << *op - << "\n...invoking callable on ParallelConfig of " << newCfg.perms_.size() << " regions\n"; - f(newCfg); - } - } while (std::next_permutation(perm.begin(), perm.end())); - return true; - } else { - bool anyParallel = false; - for_each_nested(op, [&](mlir::Operation *child){ - anyParallel |= walk_configurations(child, std::forward(f), cfg); - }); - return anyParallel; - } + + static std::vector get_parallel_ops(ModuleOp &mod) { + std::vector ret; + mod.walk([&](Operation *op) { + if (auto parallelOp = dyn_cast(op)) { + ret.push_back(parallelOp); + } + }); + return ret; } // cost of a module with a given parallel configuration @@ -1049,6 +981,34 @@ static MemrefInductionCosts build_cost_table(scf::ParallelOp &parentOp, Iteratio parallelOp.erase(); } + template + void walk_configurations(std::vector &ops, Lambda &&f) { + ParallelConfig cfg; + walk_configurations(ops, std::forward(f), cfg); + } + +// return true if op, or any of its nested children, were scf parallel + template + void walk_configurations(std::vector &ops, Lambda &&f, const ParallelConfig &cfg) { + if (ops.empty()) { + f(cfg); + } else { + 
scf::ParallelOp &first = ops[0]; + std::vector rest; + for (size_t oi = 1; oi < ops.size(); ++oi) { + rest.push_back(ops[oi]); + } + Permutation perm(get_num_induction_vars(first)); + std::iota(perm.begin(), perm.end(), 0); + + do { + ParallelConfig newCfg = cfg; + newCfg.perms_[first] = perm; + walk_configurations(rest, std::forward(f), newCfg); + } while (std::next_permutation(perm.begin(), perm.end())); + } + } + void runOnOperation() override { ModuleOp module = getOperation(); @@ -1072,11 +1032,13 @@ static MemrefInductionCosts build_cost_table(scf::ParallelOp &parentOp, Iteratio llvm::outs() << "====\nbuild_cost_table\n====\n"; MemrefInductionCosts costTable = build_cost_table(module, tripCounts); - llvm::outs() << "====\nmodel reordered induction vars\n====\n"; + llvm::outs() << "====\nExtract parallel ops\n====\n"; + auto parallelOps = get_parallel_ops(module); + + llvm::outs() << "====\nModel Reordered Induction variables\n====\n"; size_t minCost = std::numeric_limits::max(); ParallelConfig minCfg; - walk_configurations(module, [&](const ParallelConfig &cfg){ - + walk_configurations(parallelOps, [&](const ParallelConfig &cfg){ llvm::outs() << "modeling ParallelConfig:\n"; for (const auto &kv : cfg.perms_) { kv.first->print(llvm::outs()); @@ -1105,11 +1067,10 @@ static MemrefInductionCosts build_cost_table(scf::ParallelOp &parentOp, Iteratio minCfg = cfg; } - }); // walk_configurations + }); llvm::outs() << "min cost: " << minCost << "\n"; llvm::outs() << "====\nbuild new module\n====\n"; - #if 0 // clone the existing module ModuleOp newModule = module.clone(); From 218d8e1302cd98a61626c6e54f6ec2ebab10ce6e Mon Sep 17 00:00:00 2001 From: Carl Pearson Date: Thu, 15 May 2025 16:26:46 -0600 Subject: [PATCH 31/37] handle nesting in cost table, left-most induction variable in cost Signed-off-by: Carl Pearson --- .../Transforms/KokkosMdrangeIterationPass.cpp | 48 +++++++++++-------- 1 file changed, 29 insertions(+), 19 deletions(-) diff --git 
a/mlir/lib/Dialect/Kokkos/Transforms/KokkosMdrangeIterationPass.cpp b/mlir/lib/Dialect/Kokkos/Transforms/KokkosMdrangeIterationPass.cpp index 4dca93e2..a3b11f7e 100644 --- a/mlir/lib/Dialect/Kokkos/Transforms/KokkosMdrangeIterationPass.cpp +++ b/mlir/lib/Dialect/Kokkos/Transforms/KokkosMdrangeIterationPass.cpp @@ -641,7 +641,7 @@ static std::shared_ptr iteration_space_expr(scf::ParallelOp &op, int dim) auto num = Add::make(Sub::make(ubExpr, lbExpr), Sub::make(stExpr, Constant::make(1))); auto ret = Div::make(num, stExpr); - llvm::outs() << "Trip count of " << op << " dim " << dim << " = "; + llvm::outs() << "Trip count (dim " << dim << ") of:\n" << op <<"\n: "; ret->dump(llvm::outs()); llvm::outs() << "\n"; return ret; @@ -661,23 +661,24 @@ using IterationSpaceExprs = llvm::DenseMap expr = iteration_space_expr(op); + ISE1[op] = expr; + }); - IterationSpaceExprs ISE; - - mod.walk([&](Operation *op) { - if (auto parallelOp = dyn_cast(op)) { - // create an expression representing the trip count for this loop - std::shared_ptr expr = iteration_space_expr(parallelOp); - ISE[parallelOp] = expr; - - // descend into the body of the loop - IterationSpaceExprs exprs = build_parallel_trip_counts(parallelOp, expr); - ISE.insert(exprs.begin(), exprs.end()); - } - }); // walk - + // fixup, incorporate parent trip counts into expression + IterationSpaceExprs ISE2; + mod.walk([&](scf::ParallelOp op){ + std::shared_ptr expr = ISE1[op]; + for (auto parent : enclosing_parallel_ops(op)) { + expr = Mul::make(expr, ISE1[parent]); + } + ISE2[op] = expr; + }); - return ISE; + return ISE2; } static IterationSpaceExprs build_parallel_trip_counts(scf::ParallelOp &parentOp, std::shared_ptr parentCount) { @@ -857,6 +858,15 @@ static MemrefInductionCosts build_cost_table(scf::ParallelOp &parentOp, Iteratio } } + static std::vector enclosing_parallel_ops(mlir::Operation *op) { + std::vector ops; + scf::ParallelOp parent = op->getParentOfType(); + while (parent) { + ops.push_back(parent); + 
parent = parent->getParentOfType(); + } + return ops; + } static std::vector get_parallel_ops(ModuleOp &mod) { std::vector ret; @@ -925,13 +935,13 @@ static MemrefInductionCosts build_cost_table(scf::ParallelOp &parentOp, Iteratio if (auto it = cfg.perms_.find(parallelOp); it != cfg.perms_.end()) { const Permutation &perm = it->second; - Value rightMostVar = parallelOp.getInductionVars()[perm[perm.size() - 1]]; + Value leftMostVar = parallelOp.getInductionVars()[perm[0]]; - llvm::outs() << "under permutation, right-most enclosing induction var is " << rightMostVar << "\n"; + llvm::outs() << "under permutation, left-most enclosing induction var is " << leftMostVar << "\n"; // FIXME: why does this work? the table should expect key to be pair not pair - auto costKey = std::make_pair(memrefOp, rightMostVar); + auto costKey = std::make_pair(memrefOp, leftMostVar); if (auto jt = costTable.find(costKey); jt != costTable.end()) { Cost model = jt->second; size_t cost = monte_carlo(model); From cc9eec0d26eb4a473e77a0ad33fb427123bceeaf Mon Sep 17 00:00:00 2001 From: Carl Pearson Date: Thu, 15 May 2025 17:01:19 -0600 Subject: [PATCH 32/37] LayoutRight cost, model cost of nested loops --- .../Transforms/KokkosMdrangeIterationPass.cpp | 35 ++++++++++--------- 1 file changed, 18 insertions(+), 17 deletions(-) diff --git a/mlir/lib/Dialect/Kokkos/Transforms/KokkosMdrangeIterationPass.cpp b/mlir/lib/Dialect/Kokkos/Transforms/KokkosMdrangeIterationPass.cpp index a3b11f7e..6dcc7e4a 100644 --- a/mlir/lib/Dialect/Kokkos/Transforms/KokkosMdrangeIterationPass.cpp +++ b/mlir/lib/Dialect/Kokkos/Transforms/KokkosMdrangeIterationPass.cpp @@ -422,7 +422,7 @@ static std::shared_ptr df_dx(Value &f, Value &x) { llvm::outs() << "Info: df_dx of equal values\n"; return Constant::make(1); } else if (mlir::isa(f) && mlir::isa(x)) { - llvm::outs() << "Info: df_dx of different block arguments\n"; + llvm::outs() << "Info: df_dx of different block arguments:\n"; return Constant::make(0); } else { // 
FIXME: what other scenarios if there is no defining op. @@ -505,12 +505,14 @@ static std::shared_ptr df_dx(Value &f, Value &x) { // Get the size in bits of the element type mlir::Type elementType = memrefType.getElementType(); - unsigned sizeInBytes = elementType.getIntOrFloatBitWidth() / CHAR_BIT; + const unsigned sizeInBytes = elementType.getIntOrFloatBitWidth() / CHAR_BIT; std::shared_ptr res = std::make_shared(sizeInBytes); + // LayoutRight: work in from the right, multiplying in dimensions auto memrefShape = memrefType.getShape(); - for (int dim = 0; dim < indexVarDim; ++dim) { + const int nDim = memrefShape.size(); + for (int dim = nDim - 1; dim > indexVarDim; --dim) { if (memrefShape[dim] == ShapedType::kDynamic) { std::string name = std::string("memref") + std::to_string(uintptr_t(memrefOp.getOperation())) + "_extent" + std::to_string(dim); res = std::make_shared(res, Unknown::make(name)); @@ -927,39 +929,38 @@ static MemrefInductionCosts build_cost_table(scf::ParallelOp &parentOp, Iteratio template static size_t model_cost(MemrefOp &memrefOp, const ParallelConfig &cfg, const MemrefInductionCosts &costTable) { static_assert(std::is_same_v || std::is_same_v); - llvm::outs() << "model cost of " << memrefOp << "...\n"; - auto parentOp = memrefOp.getOperation()->getParentOp(); - if (auto parallelOp = dyn_cast(parentOp)) { + auto parents = enclosing_parallel_ops(memrefOp); + + llvm::outs() << "... memref op above has " << parents.size() << " enclosing parallels\n"; + + size_t cost = 0; + for (scf::ParallelOp ¶llelOp : parents) { if (auto it = cfg.perms_.find(parallelOp); it != cfg.perms_.end()) { const Permutation &perm = it->second; Value leftMostVar = parallelOp.getInductionVars()[perm[0]]; - llvm::outs() << "under permutation, left-most enclosing induction var is " << leftMostVar << "\n"; + llvm::outs() << "... under permutation, left-most induction var of enclosing parallel is " << leftMostVar << "\n"; - // FIXME: why does this work? 
the table should expect key to be pair not pair auto costKey = std::make_pair(memrefOp, leftMostVar); if (auto jt = costTable.find(costKey); jt != costTable.end()) { Cost model = jt->second; - size_t cost = monte_carlo(model); - llvm::outs() << "..." << memrefOp << " contributes " << cost << "\n"; - return cost; + size_t parallelContrib = monte_carlo(model); + cost += parallelContrib; + llvm::outs() << "..." << memrefOp << " contributes " << parallelContrib << " (now " << cost << ")\n"; } else { - llvm::outs() << "couldn't find model for memref / induction variable combo\n"; + llvm::outs() << "WARN: couldn't find model for memref / induction variable combo\n"; return 0; } - } else { - llvm::outs() << "couldn't find permutation for parent parallel op\n"; + llvm::outs() << "WARN: couldn't find permutation for parent parallel op\n"; return 0; } - } else { - llvm::outs() << "memrefOp " << memrefOp << " has no parallel parent\n"; - return 0; } + return cost; } From a46ace9ec5f34a4c86108031b383fc2e3aac812e Mon Sep 17 00:00:00 2001 From: Carl Pearson Date: Fri, 16 May 2025 14:18:10 -0600 Subject: [PATCH 33/37] Simplify building costs of memrefs --- .../Transforms/KokkosMdrangeIterationPass.cpp | 84 +++++++------------ 1 file changed, 32 insertions(+), 52 deletions(-) diff --git a/mlir/lib/Dialect/Kokkos/Transforms/KokkosMdrangeIterationPass.cpp b/mlir/lib/Dialect/Kokkos/Transforms/KokkosMdrangeIterationPass.cpp index 6dcc7e4a..2a22790b 100644 --- a/mlir/lib/Dialect/Kokkos/Transforms/KokkosMdrangeIterationPass.cpp +++ b/mlir/lib/Dialect/Kokkos/Transforms/KokkosMdrangeIterationPass.cpp @@ -129,6 +129,10 @@ struct KokkosMdrangeIterationPass KokkosMdrangeIterationPass() = default; KokkosMdrangeIterationPass(const KokkosMdrangeIterationPass& pass) = default; + using ParallelOpVec = llvm::SmallVector; + using ValueVec = llvm::SmallVector; + using Permutation = llvm::SmallVector; + // generate a log-random integer within a specified range static size_t 
log_random_int(std::mt19937 &gen, size_t min, size_t max) { // Create a uniform real distribution between log(min) and log(max) @@ -704,11 +708,11 @@ static IterationSpaceExprs build_parallel_trip_counts(scf::ParallelOp &parentOp, // map of (Operation*, Value) -> Cost // map of the cost model for a given memref / induction variable pair using MemrefInductionCosts = llvm::DenseMap, Cost>; -using ParallelOpStack = llvm::SmallVector; + // return all induction variables for all parallel ops -static std::vector all_induction_variables(ParallelOpStack &ops) { - std::vector vars; +static ValueVec all_induction_variables(ParallelOpVec &ops) { + ValueVec vars; for (auto &op : ops) { for (auto &var : op.getInductionVars()) { vars.push_back(var); @@ -717,37 +721,12 @@ static std::vector all_induction_variables(ParallelOpStack &ops) { return vars; } -static MemrefInductionCosts build_cost_table(ModuleOp &mod, IterationSpaceExprs &tripCounts, ParallelOpStack &stack) { - - MemrefInductionCosts MIC; - - mod.walk([&](Operation *op) { - // skip memrefs outside a parallel region - if (auto parallelOp = dyn_cast(op)) { - stack.push_back(parallelOp); - MemrefInductionCosts costs = build_cost_table(parallelOp, tripCounts, stack); - stack.pop_back(); - MIC.insert(costs.begin(), costs.end()); - } - }); // walk - - return MIC; - } - - - -static MemrefInductionCosts build_cost_table(ModuleOp &mod, IterationSpaceExprs &tripCounts) { - ParallelOpStack stack; - return build_cost_table(mod, tripCounts, stack); -} - - // compute the partial derivative of each memref with respect to all enclosing induction variables via the chain rule: // d(offset)/d(indvar) = sum( // d(offset)/d(index) * d(index)/d(indvar), // for each index in indices) template -static MemrefInductionCosts get_costs(Memref &memrefOp, IterationSpaceExprs &tripCounts, ParallelOpStack &stack) { +static MemrefInductionCosts get_costs(Memref &memrefOp, IterationSpaceExprs &tripCounts) { static_assert(std::is_same_v || 
std::is_same_v); if constexpr (std::is_same_v) { @@ -758,11 +737,16 @@ static MemrefInductionCosts get_costs(Memref &memrefOp, IterationSpaceExprs &tri MemrefInductionCosts MIC; - std::vector indVars = all_induction_variables(stack); - if (stack.empty()) { - llvm::report_fatal_error("get_costs: memref is not enclosed in an scf::ParallelOp"); + // get all the parallel ops that enclose this memref + auto ancestors = enclosing_parallel_ops(memrefOp); + if (ancestors.empty()) { + llvm::outs() << "get_costs: memref is not enclosed in an scf::ParallelOp\n"; + return MIC; } - scf::ParallelOp &parentOp = stack.back(); + scf::ParallelOp &parentOp = *ancestors.begin(); + + ValueVec indVars = all_induction_variables(ancestors); + for (Value indVar : indVars) { std::shared_ptr dodi = Constant::make(0); @@ -810,21 +794,16 @@ static MemrefInductionCosts get_costs(Memref &memrefOp, IterationSpaceExprs &tri return MIC; } -// FIXME: parentOp is also the back of the stack? -static MemrefInductionCosts build_cost_table(scf::ParallelOp &parentOp, IterationSpaceExprs &tripCounts, ParallelOpStack &stack) { + static MemrefInductionCosts build_cost_table(ModuleOp &mod, IterationSpaceExprs &tripCounts) { MemrefInductionCosts MIC; - parentOp.getBody()->walk([&](Operation *op) { - if (auto parallelOp = dyn_cast(op)) { - stack.push_back(parallelOp); - MemrefInductionCosts costs = build_cost_table(parallelOp, tripCounts, stack); - stack.pop_back(); - MIC.insert(costs.begin(), costs.end()); - } else if (auto memrefOp = dyn_cast(op)) { - MemrefInductionCosts costs = get_costs(memrefOp, tripCounts, stack); + // compute for loads + mod.walk([&](Operation *op){ + if (auto memrefOp = dyn_cast(op)) { + MemrefInductionCosts costs = get_costs(memrefOp, tripCounts); MIC.insert(costs.begin(), costs.end()); } else if (auto memrefOp = dyn_cast(op)) { - MemrefInductionCosts costs = get_costs(memrefOp, tripCounts, stack); + MemrefInductionCosts costs = get_costs(memrefOp, tripCounts); 
MIC.insert(costs.begin(), costs.end()); } }); @@ -832,7 +811,8 @@ static MemrefInductionCosts build_cost_table(scf::ParallelOp &parentOp, Iteratio return MIC; } - using Permutation = llvm::SmallVector; + + struct ParallelConfig { // permutation of induction variables for each parallel op @@ -860,8 +840,8 @@ static MemrefInductionCosts build_cost_table(scf::ParallelOp &parentOp, Iteratio } } - static std::vector enclosing_parallel_ops(mlir::Operation *op) { - std::vector ops; + static ParallelOpVec enclosing_parallel_ops(mlir::Operation *op) { + ParallelOpVec ops; scf::ParallelOp parent = op->getParentOfType(); while (parent) { ops.push_back(parent); @@ -870,8 +850,8 @@ static MemrefInductionCosts build_cost_table(scf::ParallelOp &parentOp, Iteratio return ops; } - static std::vector get_parallel_ops(ModuleOp &mod) { - std::vector ret; + static ParallelOpVec get_parallel_ops(ModuleOp &mod) { + ParallelOpVec ret; mod.walk([&](Operation *op) { if (auto parallelOp = dyn_cast(op)) { ret.push_back(parallelOp); @@ -993,19 +973,19 @@ static MemrefInductionCosts build_cost_table(scf::ParallelOp &parentOp, Iteratio } template - void walk_configurations(std::vector &ops, Lambda &&f) { + void walk_configurations(ParallelOpVec &ops, Lambda &&f) { ParallelConfig cfg; walk_configurations(ops, std::forward(f), cfg); } // return true if op, or any of its nested children, were scf parallel template - void walk_configurations(std::vector &ops, Lambda &&f, const ParallelConfig &cfg) { + void walk_configurations(ParallelOpVec &ops, Lambda &&f, const ParallelConfig &cfg) { if (ops.empty()) { f(cfg); } else { scf::ParallelOp &first = ops[0]; - std::vector rest; + ParallelOpVec rest; for (size_t oi = 1; oi < ops.size(); ++oi) { rest.push_back(ops[oi]); } From 7c40ba54c7421ef4f547b974ab353d0c9e242c33 Mon Sep 17 00:00:00 2001 From: Carl Pearson Date: Fri, 16 May 2025 15:01:56 -0600 Subject: [PATCH 34/37] incorporate load/store scale factor, guard prints with macro --- 
.../Transforms/KokkosMdrangeIterationPass.cpp | 267 ++++++++++-------- 1 file changed, 151 insertions(+), 116 deletions(-) diff --git a/mlir/lib/Dialect/Kokkos/Transforms/KokkosMdrangeIterationPass.cpp b/mlir/lib/Dialect/Kokkos/Transforms/KokkosMdrangeIterationPass.cpp index 2a22790b..f7479cfd 100644 --- a/mlir/lib/Dialect/Kokkos/Transforms/KokkosMdrangeIterationPass.cpp +++ b/mlir/lib/Dialect/Kokkos/Transforms/KokkosMdrangeIterationPass.cpp @@ -123,9 +123,53 @@ We compute the cost under each combination. We chosoe the induction variable ordering with the lowest cost */ + +// Put these out here so we can overload operator<< easily +namespace { + + // a context for expression evaluation + struct Ctx { + std::unordered_map values; // FIXME: llvm data structures + }; + + struct Expr { + enum class Kind { + Add, Sub, Mul, Div, Constant, Unknown + }; + + Expr(Kind kind) : kind_(kind) {} + Kind kind_; + + virtual int eval(const Ctx &ctx) = 0; + virtual void dump(llvm::raw_fd_ostream &os) const = 0; + virtual void dump(llvm::raw_ostream &os) const = 0; + virtual std::vector unknowns() const = 0; + virtual std::shared_ptr clone() const = 0; + virtual ~Expr() {} + }; + +} // namespace + +static llvm::raw_fd_ostream & operator<<(llvm::raw_fd_ostream &os, const std::shared_ptr &e) { + e->dump(os); + return os; +} + +static llvm::raw_ostream & operator<<(llvm::raw_ostream &os, const std::shared_ptr &e) { + e->dump(os); + return os; +} + struct KokkosMdrangeIterationPass : public impl::KokkosMdrangeIterationBase { +#if 1 +#define MDRANGE_DEBUG(x) \ + llvm::outs() << x; +#else +#define MDRANGE_DEBUG(x) +#endif + KokkosMdrangeIterationPass() = default; KokkosMdrangeIterationPass(const KokkosMdrangeIterationPass& pass) = default; @@ -149,33 +193,13 @@ struct KokkosMdrangeIterationPass return res; } - // a context for expression evaluation - struct Ctx { - - // FIXME: llvm data structures - std::unordered_map values; - }; - - struct Expr { - enum class Kind { - Add, Sub, Mul, Div, 
Constant, Unknown - }; - - Expr(Kind kind) : kind_(kind) {} - Kind kind_; - - virtual int eval(const Ctx &ctx) = 0; - virtual void dump(llvm::raw_fd_ostream &os) = 0; - virtual std::vector unknowns() const = 0; - virtual std::shared_ptr clone() const = 0; - virtual ~Expr() {} - }; struct Binary : public Expr { Binary(Kind kind, const std::string sym, std::shared_ptr lhs, std::shared_ptr rhs) : Expr(kind), sym_(sym), lhs_(lhs), rhs_(rhs) {} - virtual void dump(llvm::raw_fd_ostream &os) override { + template + void dump_impl(OS &os) const { os << "("; lhs_->dump(os); os << sym_; @@ -183,6 +207,14 @@ struct KokkosMdrangeIterationPass os << ")"; } + virtual void dump(llvm::raw_fd_ostream &os) const override { + dump_impl(os); + } + + virtual void dump(llvm::raw_ostream &os) const override { + dump_impl(os); + } + virtual std::vector unknowns() const override { std::vector ret; for (auto &op : {lhs_, rhs_}) { @@ -346,7 +378,16 @@ struct KokkosMdrangeIterationPass return make(value_); } - virtual void dump(llvm::raw_fd_ostream &os) override { + virtual void dump(llvm::raw_fd_ostream &os) const override { + dump_impl(os); + } + + virtual void dump(llvm::raw_ostream &os) const override { + dump_impl(os); + } + + template + void dump_impl(OS &os) const { os << value_; } @@ -375,7 +416,16 @@ struct KokkosMdrangeIterationPass return make(name_); } - virtual void dump(llvm::raw_fd_ostream &os) override { + virtual void dump(llvm::raw_fd_ostream &os) const override { + dump_impl(os); + } + + virtual void dump(llvm::raw_ostream &os) const override { + dump_impl(os); + } + + template + void dump_impl(OS &os) const { os << "(" << name_ << ")"; } @@ -423,10 +473,10 @@ struct KokkosMdrangeIterationPass // partial derivative df/dx static std::shared_ptr df_dx(Value &f, Value &x) { if (f == x) { - llvm::outs() << "Info: df_dx of equal values\n"; + MDRANGE_DEBUG("Info: df_dx of equal values\n"); return Constant::make(1); } else if (mlir::isa(f) && mlir::isa(x)) { - llvm::outs() << 
"Info: df_dx of different block arguments:\n"; + MDRANGE_DEBUG("Info: df_dx of different block arguments:\n"); return Constant::make(0); } else { // FIXME: what other scenarios if there is no defining op. @@ -435,7 +485,7 @@ static std::shared_ptr df_dx(Value &f, Value &x) { return df_dx(fOp, xOp); } } - llvm::outs() << "ERROR: One of the values has no defining operation\n"; + MDRANGE_DEBUG("ERROR: One of the values has no defining operation\n"); return nullptr; } } @@ -443,10 +493,10 @@ static std::shared_ptr df_dx(Value &f, Value &x) { // FIXME: better written as df_dx(f, x) I guess static std::shared_ptr df_dx(Operation *df, Operation *dx) { if (!df) { - llvm::outs() << "Warn: df_dx requested on null df\n"; + MDRANGE_DEBUG("Warn: df_dx requested on null df\n"); return nullptr; } else if (!dx) { - llvm::outs() << "Warn: df_dx requested on null dx\n"; + MDRANGE_DEBUG("Warn: df_dx requested on null dx\n"); return nullptr; } else if (df == dx) { // df/dx (dx) = 1 @@ -485,10 +535,7 @@ static std::shared_ptr df_dx(Value &f, Value &x) { } } // TODO: sub, div - llvm::outs() << "WARN: unhandled case in df_dx of "; - df->print(llvm::outs()); - llvm::outs() << " w.r.t."; - dx->print(llvm::outs()); + MDRANGE_DEBUG("WARN: unhandled case in df_dx of " << df << " w.r.t " << dx << "\n"); return nullptr; } @@ -531,62 +578,54 @@ static std::shared_ptr df_dx(Value &f, Value &x) { } // memref address is not a function of this variable - llvm::outs() << "Info: "; - memrefOp.print(llvm::outs()); - llvm::outs() << " is not a function of "; - indexVar.print(llvm::outs()); - llvm::outs() << "\n"; + MDRANGE_DEBUG("Info: " << memrefOp << " is not a function of " << indexVar << "\n"); return std::make_shared(0); } static void dump_ops(ModuleOp &mod) { mod.walk([&](Operation *op) { if (auto parallelOp = dyn_cast(op)) { - llvm::outs() << "Found scf.parallel operation:\n"; - llvm::outs() << "Induction variables and strides:\n"; + MDRANGE_DEBUG("Found scf.parallel operation:\n"); + 
MDRANGE_DEBUG("Induction variables and strides:\n"); for (auto iv : llvm::zip(parallelOp.getInductionVars(), parallelOp.getStep())) { - std::get<0>(iv).print(llvm::outs()); - llvm::outs() << " with stride "; - std::get<1>(iv).print(llvm::outs()); - llvm::outs() << "\n"; + (void) iv; + MDRANGE_DEBUG(std::get<0>(iv) << " with stride " << std::get<1>(iv) << "\n"); } - llvm::outs() << "\n\n"; + MDRANGE_DEBUG("\n\n"); } if (auto memrefOp = dyn_cast(op)) { - llvm::outs() << "Found memref.load operation:\n"; - llvm::outs() << "MemRef: "; - memrefOp.getMemRef().print(llvm::outs()); - llvm::outs() << "\nIndex variables:\n"; + MDRANGE_DEBUG("Found memref.load operation:\n"); + MDRANGE_DEBUG("MemRef: " << memrefOp.getMemRef() << "\nIndex variables:\n"); for (Value index : memrefOp.getIndices()) { - index.print(llvm::outs()); - llvm::outs() << "\n"; + (void) index; + MDRANGE_DEBUG(index << "\n"); } if (auto memrefType = dyn_cast(memrefOp.getMemRef().getType())) { - llvm::outs() << "MemRef extents:\n"; + MDRANGE_DEBUG("MemRef extents:\n"); for (int64_t dim : memrefType.getShape()) { - llvm::outs() << dim << "\n"; + (void) dim; + MDRANGE_DEBUG(dim << "\n"); } } - llvm::outs() << "\n\n"; + MDRANGE_DEBUG("\n\n"); } if (auto memrefOp = dyn_cast(op)) { - llvm::outs() << "Found memref.store operation:\n"; - llvm::outs() << "MemRef: "; - memrefOp.getMemRef().print(llvm::outs()); - llvm::outs() << "\nIndex variables:\n"; + MDRANGE_DEBUG("Found memref.store operation:\n"); + MDRANGE_DEBUG("MemRef: " << memrefOp.getMemRef() << "\nIndex variables:\n"); for (Value index : memrefOp.getIndices()) { - index.print(llvm::outs()); - llvm::outs() << "\n"; + (void) index; + MDRANGE_DEBUG(index << "\n"); } if (auto memrefType = dyn_cast(memrefOp.getMemRef().getType())) { - llvm::outs() << "MemRef extents:\n"; + MDRANGE_DEBUG("MemRef extents:\n"); for (int64_t dim : memrefType.getShape()) { - llvm::outs() << dim << "\n"; + (void) dim; + MDRANGE_DEBUG(dim << "\n"); } } - llvm::outs() << "\n\n"; + 
MDRANGE_DEBUG("\n\n"); } }); } @@ -647,9 +686,7 @@ static std::shared_ptr iteration_space_expr(scf::ParallelOp &op, int dim) auto num = Add::make(Sub::make(ubExpr, lbExpr), Sub::make(stExpr, Constant::make(1))); auto ret = Div::make(num, stExpr); - llvm::outs() << "Trip count (dim " << dim << ") of:\n" << op <<"\n: "; - ret->dump(llvm::outs()); - llvm::outs() << "\n"; + MDRANGE_DEBUG("Trip count (dim " << dim << ") of:\n" << op << "\n: " << ret << "\n"); return ret; } @@ -730,9 +767,9 @@ static MemrefInductionCosts get_costs(Memref &memrefOp, IterationSpaceExprs &tri static_assert(std::is_same_v || std::is_same_v); if constexpr (std::is_same_v) { - llvm::outs() << "get_cost: memref::LoadOp\n"; + MDRANGE_DEBUG("get_cost: memref::LoadOp\n"); } else if constexpr (std::is_same_v) { - llvm::outs() << "get_cost: memref::StoreOp\n"; + MDRANGE_DEBUG("get_cost: memref::StoreOp\n"); } MemrefInductionCosts MIC; @@ -740,7 +777,7 @@ static MemrefInductionCosts get_costs(Memref &memrefOp, IterationSpaceExprs &tri // get all the parallel ops that enclose this memref auto ancestors = enclosing_parallel_ops(memrefOp); if (ancestors.empty()) { - llvm::outs() << "get_costs: memref is not enclosed in an scf::ParallelOp\n"; + MDRANGE_DEBUG("get_costs: memref is not enclosed in an scf::ParallelOp\n"); return MIC; } scf::ParallelOp &parentOp = *ancestors.begin(); @@ -753,23 +790,23 @@ static MemrefInductionCosts get_costs(Memref &memrefOp, IterationSpaceExprs &tri for (Value indexVar : memrefOp.getIndices()) { auto e1 = do_di(memrefOp, indexVar); - llvm::outs() << "∂(" << memrefOp << ")/∂(" << indexVar << ") = "; + MDRANGE_DEBUG("∂(" << memrefOp << ")/∂(" << indexVar << ") = "); if (e1) { - e1->dump(llvm::outs()); + MDRANGE_DEBUG(e1); } else { - llvm::outs() << "undefined"; + MDRANGE_DEBUG("undefined"); } - llvm::outs() << "\n"; + MDRANGE_DEBUG("\n"); auto e2 = df_dx(indexVar, indVar); - llvm::outs() << "∂(" << indexVar << ")/∂(" << indVar << ") = "; + MDRANGE_DEBUG("∂(" << indexVar << 
")/∂(" << indVar << ") = "); if (e2) { - e2->dump(llvm::outs()); + MDRANGE_DEBUG(e2); } else { - llvm::outs() << "undefined"; + MDRANGE_DEBUG("undefined"); } - llvm::outs() << "\n"; + MDRANGE_DEBUG("\n"); if (e1 && e2) { dodi = Add::make(dodi, Mul::make(e1, e2)); @@ -779,13 +816,13 @@ static MemrefInductionCosts get_costs(Memref &memrefOp, IterationSpaceExprs &tri } } - llvm::outs() << "∂(" << memrefOp << ")/∂(" << indVar << ") = "; + MDRANGE_DEBUG("∂(" << memrefOp << ")/∂(" << indVar << ") = "); if (dodi) { - dodi->dump(llvm::outs()); + MDRANGE_DEBUG(dodi); } else { - llvm::outs() << "undefined"; + MDRANGE_DEBUG("undefined"); } - llvm::outs() << "\n"; + MDRANGE_DEBUG("\n"); std::shared_ptr tripCount = tripCounts[parentOp]; @@ -890,7 +927,7 @@ static MemrefInductionCosts get_costs(Memref &memrefOp, IterationSpaceExprs &tri ctx.values[name] = val; } - costs.push_back(model.stride_->eval(ctx) * model.count_->eval(ctx)); + costs.push_back(model.stride_->eval(ctx) * model.count_->eval(ctx) * model.sf_); } // FIXME: here we do median, is there a principled aggregation strategy? @@ -909,11 +946,11 @@ static MemrefInductionCosts get_costs(Memref &memrefOp, IterationSpaceExprs &tri template static size_t model_cost(MemrefOp &memrefOp, const ParallelConfig &cfg, const MemrefInductionCosts &costTable) { static_assert(std::is_same_v || std::is_same_v); - llvm::outs() << "model cost of " << memrefOp << "...\n"; + MDRANGE_DEBUG("model cost of " << memrefOp << "...\n"); auto parents = enclosing_parallel_ops(memrefOp); - llvm::outs() << "... memref op above has " << parents.size() << " enclosing parallels\n"; + MDRANGE_DEBUG("... memref op above has " << parents.size() << " enclosing parallels\n"); size_t cost = 0; for (scf::ParallelOp ¶llelOp : parents) { @@ -922,7 +959,7 @@ static MemrefInductionCosts get_costs(Memref &memrefOp, IterationSpaceExprs &tri const Permutation &perm = it->second; Value leftMostVar = parallelOp.getInductionVars()[perm[0]]; - llvm::outs() << "... 
under permutation, left-most induction var of enclosing parallel is " << leftMostVar << "\n"; + MDRANGE_DEBUG("... under permutation, left-most induction var of enclosing parallel is " << leftMostVar << "\n"); // FIXME: why does this work? the table should expect key to be pair not pair auto costKey = std::make_pair(memrefOp, leftMostVar); @@ -930,13 +967,13 @@ static MemrefInductionCosts get_costs(Memref &memrefOp, IterationSpaceExprs &tri Cost model = jt->second; size_t parallelContrib = monte_carlo(model); cost += parallelContrib; - llvm::outs() << "..." << memrefOp << " contributes " << parallelContrib << " (now " << cost << ")\n"; + MDRANGE_DEBUG("..." << memrefOp << " contributes " << parallelContrib << " (now " << cost << ")\n"); } else { - llvm::outs() << "WARN: couldn't find model for memref / induction variable combo\n"; + MDRANGE_DEBUG("WARN: couldn't find model for memref / induction variable combo\n"); return 0; } } else { - llvm::outs() << "WARN: couldn't find permutation for parent parallel op\n"; + MDRANGE_DEBUG("WARN: couldn't find permutation for parent parallel op\n"); return 0; } } @@ -1003,55 +1040,51 @@ static MemrefInductionCosts get_costs(Memref &memrefOp, IterationSpaceExprs &tri void runOnOperation() override { ModuleOp module = getOperation(); - llvm::outs() << module << "\n"; + MDRANGE_DEBUG(module << "\n"); - llvm::outs() << "====\ndump_ops\n====\n"; + MDRANGE_DEBUG("====\ndump_ops\n====\n"); dump_ops(module); - llvm::outs() << "====\nbuild_parallel_trip_counts\n====\n"; + MDRANGE_DEBUG("====\nbuild_parallel_trip_counts\n====\n"); IterationSpaceExprs tripCounts = build_parallel_trip_counts(module); for (auto &kv : tripCounts) { const std::shared_ptr &trip = kv.second; - llvm::outs() << "parallel op: "; - kv.first.print(llvm::outs()); - llvm::outs() << " trip: "; - trip->dump(llvm::outs()); - llvm::outs() << "\n"; + MDRANGE_DEBUG("parallel op: " << kv.first << " trip: " << trip << "\n"); } - llvm::outs() << 
"====\nbuild_cost_table\n====\n"; + MDRANGE_DEBUG("====\nbuild_cost_table\n====\n"); MemrefInductionCosts costTable = build_cost_table(module, tripCounts); - llvm::outs() << "====\nExtract parallel ops\n====\n"; + MDRANGE_DEBUG("====\nExtract parallel ops\n====\n"); auto parallelOps = get_parallel_ops(module); - llvm::outs() << "====\nModel Reordered Induction variables\n====\n"; + MDRANGE_DEBUG("====\nModel Reordered Induction variables\n====\n"); size_t minCost = std::numeric_limits::max(); ParallelConfig minCfg; walk_configurations(parallelOps, [&](const ParallelConfig &cfg){ - llvm::outs() << "modeling ParallelConfig:\n"; + MDRANGE_DEBUG("modeling ParallelConfig:\n"); for (const auto &kv : cfg.perms_) { - kv.first->print(llvm::outs()); - llvm::outs() << " -> {"; + MDRANGE_DEBUG(kv.first << " -> {"); for(const auto &e : kv.second) { - llvm::outs() << e << ", "; + (void)e; + MDRANGE_DEBUG(e << ", "); } - llvm::outs() << "}\n"; + MDRANGE_DEBUG("}\n"); } size_t cost = model_cost(module, cfg, costTable); - llvm::outs() << "cost was " << cost << "\n"; + MDRANGE_DEBUG("cost was " << cost << "\n"); if (cost < minCost) { - llvm::outs() << "Info: new optimal! cost=" << cost << "\n"; + MDRANGE_DEBUG("Info: new optimal! 
cost=" << cost << "\n"); for (const auto &kv : cfg.perms_) { - llvm::outs() << kv.first << " with permutation: "; + MDRANGE_DEBUG(kv.first << " with permutation: "); for (const size_t e : kv.second) { - llvm::outs() << e << " "; + MDRANGE_DEBUG(e << " "); } - llvm::outs() << "\n"; + MDRANGE_DEBUG("\n"); } minCost = cost; @@ -1059,9 +1092,9 @@ static MemrefInductionCosts get_costs(Memref &memrefOp, IterationSpaceExprs &tri } }); - llvm::outs() << "min cost: " << minCost << "\n"; + MDRANGE_DEBUG("min cost: " << minCost << "\n"); - llvm::outs() << "====\nbuild new module\n====\n"; + MDRANGE_DEBUG("====\nbuild new module\n====\n"); #if 0 // clone the existing module ModuleOp newModule = module.clone(); @@ -1100,7 +1133,7 @@ static MemrefInductionCosts get_costs(Memref &memrefOp, IterationSpaceExprs &tri // modify the parallel ops in the module module.walk([&](scf::ParallelOp parallelOp) { - llvm::outs() << "modifying " << parallelOp << "\n"; + MDRANGE_DEBUG("modifying " << parallelOp << "\n"); #if 1 const Permutation &permutation = minCfg.perms_[parallelOp]; @@ -1111,16 +1144,17 @@ static MemrefInductionCosts get_costs(Memref &memrefOp, IterationSpaceExprs &tri std::iota(permutation.begin(), permutation.end(), 0); std::reverse(permutation.begin(), permutation.end()); #endif - llvm::outs() << "applying permutation "; + MDRANGE_DEBUG("applying permutation "); for (auto i : permutation) { - llvm::outs() << i << " "; + (void) i; + MDRANGE_DEBUG(i << " "); } - llvm::outs() << "\n"; + MDRANGE_DEBUG("\n"); permute_parallel_op(parallelOp, permutation); }); #endif - llvm::outs() << "====\ndone\n====\n"; + MDRANGE_DEBUG("====\ndone\n====\n"); } }; @@ -1129,3 +1163,4 @@ static MemrefInductionCosts get_costs(Memref &memrefOp, IterationSpaceExprs &tri std::unique_ptr mlir::createKokkosMdrangeIterationPass() { return std::make_unique(); } + From 88f48fbfdcc6e6eb83c024fddfb7b80176ec7e24 Mon Sep 17 00:00:00 2001 From: Carl Pearson Date: Fri, 16 May 2025 15:37:34 -0600 Subject: [PATCH 
35/37] More logging improvements Signed-off-by: Carl Pearson --- .../Transforms/KokkosMdrangeIterationPass.cpp | 22 +++++++++---------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/mlir/lib/Dialect/Kokkos/Transforms/KokkosMdrangeIterationPass.cpp b/mlir/lib/Dialect/Kokkos/Transforms/KokkosMdrangeIterationPass.cpp index f7479cfd..6d9b0f17 100644 --- a/mlir/lib/Dialect/Kokkos/Transforms/KokkosMdrangeIterationPass.cpp +++ b/mlir/lib/Dialect/Kokkos/Transforms/KokkosMdrangeIterationPass.cpp @@ -148,17 +148,16 @@ namespace { virtual ~Expr() {} }; -} // namespace - -static llvm::raw_fd_ostream & operator<<(llvm::raw_fd_ostream &os, const std::shared_ptr &e) { +template +OS & operator<<(OS &os, const std::shared_ptr &e) { e->dump(os); return os; } -static llvm::raw_ostream & operator<<(llvm::raw_ostream &os, const std::shared_ptr &e) { - e->dump(os); - return os; -} + +} // namespace + + struct KokkosMdrangeIterationPass : public impl::KokkosMdrangeIterationBase { @@ -910,7 +909,7 @@ static MemrefInductionCosts get_costs(Memref &memrefOp, IterationSpaceExprs &tri return cost; } - static size_t monte_carlo(const Cost &model, int n = 100, int seed = 31337) { + static size_t monte_carlo(const Cost &model, int n = 5, int seed = 31337) { std::mt19937 gen(seed); std::vector costs; @@ -922,8 +921,8 @@ static MemrefInductionCosts get_costs(Memref &memrefOp, IterationSpaceExprs &tri // generate random values for all unknowns in cost model Ctx ctx; for (auto &name : unknowns) { - auto val = log_random_int(gen, 1, 1000000); - // llvm::outs() << name << ": " << val << "\n"; + auto val = log_random_int(gen, 1, 100000); + MDRANGE_DEBUG(name << ": " << val << "\n"); ctx.values[name] = val; } @@ -1050,6 +1049,7 @@ static MemrefInductionCosts get_costs(Memref &memrefOp, IterationSpaceExprs &tri for (auto &kv : tripCounts) { const std::shared_ptr &trip = kv.second; + (void) trip; MDRANGE_DEBUG("parallel op: " << kv.first << " trip: " << trip << "\n"); } @@ -1081,8 +1081,8 
@@ static MemrefInductionCosts get_costs(Memref &memrefOp, IterationSpaceExprs &tri for (const auto &kv : cfg.perms_) { MDRANGE_DEBUG(kv.first << " with permutation: "); for (const size_t e : kv.second) { + (void) e; MDRANGE_DEBUG(e << " "); - } MDRANGE_DEBUG("\n"); } From fc9ea8dcf7d2f89dbb95825f41b583286ced8d4c Mon Sep 17 00:00:00 2001 From: Carl Pearson Date: Fri, 16 May 2025 15:37:48 -0600 Subject: [PATCH 36/37] more MDRange loop ordering tests --- .../{mdrange_1.mlir => mdrange_1a.mlir} | 1 + mlir/test/Dialect/Kokkos/mdrange_1b.mlir | 16 +++++++++ mlir/test/Dialect/Kokkos/mdrange_1c.mlir | 16 +++++++++ mlir/test/Dialect/Kokkos/mdrange_1d.mlir | 16 +++++++++ mlir/test/Dialect/Kokkos/mdrange_2.mlir | 13 ------- mlir/test/Dialect/Kokkos/mdrange_2a.mlir | 23 ++++++++++++ mlir/test/Dialect/Kokkos/mdrange_2b.mlir | 23 ++++++++++++ mlir/test/Dialect/Kokkos/mdrange_2c.mlir | 35 +++++++++++++++++++ 8 files changed, 130 insertions(+), 13 deletions(-) rename mlir/test/Dialect/Kokkos/{mdrange_1.mlir => mdrange_1a.mlir} (93%) create mode 100644 mlir/test/Dialect/Kokkos/mdrange_1b.mlir create mode 100644 mlir/test/Dialect/Kokkos/mdrange_1c.mlir create mode 100644 mlir/test/Dialect/Kokkos/mdrange_1d.mlir delete mode 100644 mlir/test/Dialect/Kokkos/mdrange_2.mlir create mode 100644 mlir/test/Dialect/Kokkos/mdrange_2a.mlir create mode 100644 mlir/test/Dialect/Kokkos/mdrange_2b.mlir create mode 100644 mlir/test/Dialect/Kokkos/mdrange_2c.mlir diff --git a/mlir/test/Dialect/Kokkos/mdrange_1.mlir b/mlir/test/Dialect/Kokkos/mdrange_1a.mlir similarity index 93% rename from mlir/test/Dialect/Kokkos/mdrange_1.mlir rename to mlir/test/Dialect/Kokkos/mdrange_1a.mlir index e9443098..7abe8ee0 100644 --- a/mlir/test/Dialect/Kokkos/mdrange_1.mlir +++ b/mlir/test/Dialect/Kokkos/mdrange_1a.mlir @@ -5,6 +5,7 @@ module { %c10 = arith.constant 10 : index %c20 = arith.constant 20 : index + // this needs to be reversed scf.parallel (%i, %j) = (%c0, %c0) to (%c10, %c20) step (%c1, %c1) { %val = 
memref.load %arg0[%i, %j] : memref<10x20xf32> memref.store %val, %arg1[%i, %j] : memref diff --git a/mlir/test/Dialect/Kokkos/mdrange_1b.mlir b/mlir/test/Dialect/Kokkos/mdrange_1b.mlir new file mode 100644 index 00000000..71ae14c9 --- /dev/null +++ b/mlir/test/Dialect/Kokkos/mdrange_1b.mlir @@ -0,0 +1,16 @@ +module { + func.func @example_function(%arg0: memref<10x20xf32>, %arg1: memref) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c10 = arith.constant 10 : index + %c20 = arith.constant 20 : index + + // okay as-is + scf.parallel (%i, %j) = (%c0, %c0) to (%c20, %c10) step (%c1, %c1) { + %val = memref.load %arg0[%j, %i] : memref<10x20xf32> + memref.store %val, %arg1[%j, %i] : memref + scf.reduce + } + return + } +} \ No newline at end of file diff --git a/mlir/test/Dialect/Kokkos/mdrange_1c.mlir b/mlir/test/Dialect/Kokkos/mdrange_1c.mlir new file mode 100644 index 00000000..f5208a0f --- /dev/null +++ b/mlir/test/Dialect/Kokkos/mdrange_1c.mlir @@ -0,0 +1,16 @@ +module { + func.func @example_function(%arg0: memref<20x20xf32>, %arg1: memref) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c20 = arith.constant 20 : index + + // store needs to be reversed, load is okay as-is, store is modeled + // as more expensive, so loop should be reversed + scf.parallel (%i, %j) = (%c0, %c0) to (%c20, %c20) step (%c1, %c1) { + %val = memref.load %arg0[%j, %i] : memref<20x20xf32> + memref.store %val, %arg1[%i, %j] : memref + scf.reduce + } + return + } +} \ No newline at end of file diff --git a/mlir/test/Dialect/Kokkos/mdrange_1d.mlir b/mlir/test/Dialect/Kokkos/mdrange_1d.mlir new file mode 100644 index 00000000..8bd0b9e0 --- /dev/null +++ b/mlir/test/Dialect/Kokkos/mdrange_1d.mlir @@ -0,0 +1,16 @@ +module { + func.func @example_function(%arg0: memref, %arg1: memref, %arg2: index) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + + // like 1c, but dynamic extents + // store needs to be reversed, load is okay as-is, 
store is modeled + // as more expensive, so loop should be reversed + scf.parallel (%i, %j) = (%c0, %c0) to (%arg2, %arg2) step (%c1, %c1) { + %val = memref.load %arg0[%j, %i] : memref + memref.store %val, %arg1[%i, %j] : memref + scf.reduce + } + return + } +} \ No newline at end of file diff --git a/mlir/test/Dialect/Kokkos/mdrange_2.mlir b/mlir/test/Dialect/Kokkos/mdrange_2.mlir deleted file mode 100644 index af0d89ea..00000000 --- a/mlir/test/Dialect/Kokkos/mdrange_2.mlir +++ /dev/null @@ -1,13 +0,0 @@ -module { - func.func @example_function(%arg0: memref<10x20xf32>, %arg1: memref, %loop_bound_i: index) { - %c0 = arith.constant 0 : index - %c1 = arith.constant 1 : index - %c2 = arith.constant 2 : index - - scf.parallel (%i, %j) = (%c0, %c0) to (%loop_bound_i, %c2) step (%c1, %c1) { - %val = memref.load %arg0[%i, %j] : memref<10x20xf32> - memref.store %val, %arg1[%i, %j] : memref - } - return - } -} diff --git a/mlir/test/Dialect/Kokkos/mdrange_2a.mlir b/mlir/test/Dialect/Kokkos/mdrange_2a.mlir new file mode 100644 index 00000000..2c5d6d99 --- /dev/null +++ b/mlir/test/Dialect/Kokkos/mdrange_2a.mlir @@ -0,0 +1,23 @@ +module { + func.func @example_function(%arg0: memref, + %ub_i: index, + %ub_j: index, + %ub_k: index) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + + scf.parallel (%i, %j) = (%c0, %c0) to (%ub_i, %ub_j) step (%c1, %c1) { + + // this loop wants the outer loop indices to be reversed + scf.parallel (%k) = (%c0) to (%ub_k) step (%c1) { + %val = memref.load %arg0[%i, %j] : memref + memref.store %val, %arg0[%i, %j] : memref + scf.reduce + } + + scf.reduce + } + return + } +} diff --git a/mlir/test/Dialect/Kokkos/mdrange_2b.mlir b/mlir/test/Dialect/Kokkos/mdrange_2b.mlir new file mode 100644 index 00000000..4b6d338c --- /dev/null +++ b/mlir/test/Dialect/Kokkos/mdrange_2b.mlir @@ -0,0 +1,23 @@ +module { + func.func @example_function(%arg0: memref, + %ub_i: index, + %ub_j: index, + %ub_k: index) { + 
%c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + + scf.parallel (%i, %j) = (%c0, %c0) to (%ub_i, %ub_j) step (%c1, %c1) { + + // this loop wants the outer loop indices to stay the same + scf.parallel (%k) = (%c0) to (%ub_k) step (%c1) { + %val = memref.load %arg0[%j, %i] : memref + memref.store %val, %arg0[%j, %i] : memref + scf.reduce + } + + scf.reduce + } + return + } +} diff --git a/mlir/test/Dialect/Kokkos/mdrange_2c.mlir b/mlir/test/Dialect/Kokkos/mdrange_2c.mlir new file mode 100644 index 00000000..e805a6b5 --- /dev/null +++ b/mlir/test/Dialect/Kokkos/mdrange_2c.mlir @@ -0,0 +1,35 @@ + +module { + func.func @example_function(%arg0: memref, + %ub_i: index, + %ub_j: index, + %ub_k: index) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + + scf.parallel (%i, %j) = (%c0, %c0) to (%ub_i, %ub_j) step (%c1, %c1) { + + // this loop wants the outer loop indices to be reversed + scf.parallel (%k) = (%c0) to (%ub_k) step (%c1) { + %val = memref.load %arg0[%i, %j] : memref + memref.store %val, %arg0[%i, %j] : memref + scf.reduce + } + + // this loop wants the outer loop indices to stay the same + // this loop has double the step -> 1/2 the trip count + // the other loop should influence the order more strongly + scf.parallel (%k) = (%c0) to (%ub_k) step (%c2) { + %val = memref.load %arg0[%j, %i] : memref + memref.store %val, %arg0[%j, %i] : memref + scf.reduce + } + + scf.reduce + } + return + } +} + + From 297b31a421f33a5658023533741e5fb6dcf8b244 Mon Sep 17 00:00:00 2001 From: Carl Pearson Date: Fri, 16 May 2025 16:05:39 -0600 Subject: [PATCH 37/37] prevent overflow in Expr eval, more consistent MC simulations Signed-off-by: Carl Pearson --- .../Transforms/KokkosMdrangeIterationPass.cpp | 63 ++++++++++++------- 1 file changed, 42 insertions(+), 21 deletions(-) diff --git a/mlir/lib/Dialect/Kokkos/Transforms/KokkosMdrangeIterationPass.cpp 
b/mlir/lib/Dialect/Kokkos/Transforms/KokkosMdrangeIterationPass.cpp index 6d9b0f17..9c1849fd 100644 --- a/mlir/lib/Dialect/Kokkos/Transforms/KokkosMdrangeIterationPass.cpp +++ b/mlir/lib/Dialect/Kokkos/Transforms/KokkosMdrangeIterationPass.cpp @@ -10,6 +10,7 @@ #include #include // pair #include +#include #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Bufferization/IR/Bufferization.h" @@ -140,7 +141,7 @@ namespace { Expr(Kind kind) : kind_(kind) {} Kind kind_; - virtual int eval(const Ctx &ctx) = 0; + virtual size_t eval(const Ctx &ctx) = 0; virtual void dump(llvm::raw_fd_ostream &os) const = 0; virtual void dump(llvm::raw_ostream &os) const = 0; virtual std::vector unknowns() const = 0; @@ -233,7 +234,7 @@ struct KokkosMdrangeIterationPass struct Add : public Binary { Add(std::shared_ptr lhs, std::shared_ptr rhs) : Binary(Kind::Add, "+", lhs, rhs) {} - virtual int eval(const Ctx &ctx) override { + virtual size_t eval(const Ctx &ctx) override { return lhs_->eval(ctx) + rhs_->eval(ctx); } @@ -264,7 +265,7 @@ struct KokkosMdrangeIterationPass struct Sub : public Binary { Sub(std::shared_ptr lhs, std::shared_ptr rhs) : Binary(Kind::Add, "-", lhs, rhs) {} - virtual int eval(const Ctx &ctx) override { + virtual size_t eval(const Ctx &ctx) override { return lhs_->eval(ctx) + rhs_->eval(ctx); } @@ -295,7 +296,7 @@ struct KokkosMdrangeIterationPass struct Mul : public Binary { Mul(std::shared_ptr lhs, std::shared_ptr rhs) : Binary(Kind::Mul, "*", lhs, rhs) {} - virtual int eval(const Ctx &ctx) override { + virtual size_t eval(const Ctx &ctx) override { return lhs_->eval(ctx) * rhs_->eval(ctx); } @@ -334,7 +335,7 @@ struct KokkosMdrangeIterationPass struct Div : public Binary { Div(std::shared_ptr lhs, std::shared_ptr rhs) : Binary(Kind::Div, "/", lhs, rhs) {} - virtual int eval(const Ctx &ctx) override { + virtual size_t eval(const Ctx &ctx) override { return lhs_->eval(ctx) / rhs_->eval(ctx); } @@ -369,7 +370,7 @@ struct KokkosMdrangeIterationPass 
Constant(int value) : Expr(Kind::Constant), value_(value) {} int value_; - virtual int eval(const Ctx &ctx) override { + virtual size_t eval(const Ctx &ctx) override { return value_; } @@ -407,7 +408,7 @@ struct KokkosMdrangeIterationPass Unknown(const std::string &name) : Expr(Kind::Unknown), name_(name) {} std::string name_; - virtual int eval(const Ctx &ctx) override { + virtual size_t eval(const Ctx &ctx) override { return ctx.values.at(name_); } @@ -452,7 +453,7 @@ struct KokkosMdrangeIterationPass std::shared_ptr stride_; // stride of the memref w.r.t an induction variable std::shared_ptr count_; // number of times the memref is executed - int sf_; // scaling factor, 1 for load, 3 for store + size_t sf_; // scaling factor, 1 for load, 3 for store std::vector unknowns() const { std::vector ret; @@ -897,36 +898,41 @@ static MemrefInductionCosts get_costs(Memref &memrefOp, IterationSpaceExprs &tri } // cost of a module with a given parallel configuration - static size_t model_cost(ModuleOp &mod, const ParallelConfig &cfg, const MemrefInductionCosts &costTable) { + static size_t model_cost(ModuleOp &mod, const ParallelConfig &cfg, const MemrefInductionCosts &costTable, const std::vector &unknowns) { size_t cost = 0; mod.walk([&](Operation *op) { if (auto memrefOp = dyn_cast(op)) { - cost += model_cost(memrefOp, cfg, costTable); + cost += model_cost(memrefOp, cfg, costTable, unknowns); } else if (auto memrefOp = dyn_cast(op)) { - cost += model_cost(memrefOp, cfg, costTable); + cost += model_cost(memrefOp, cfg, costTable, unknowns); } }); // walk return cost; } - static size_t monte_carlo(const Cost &model, int n = 5, int seed = 31337) { + // TODO: unknowns is all unknowns combined from all models + // this means unknowns is the same for all calls here + // this is a bit confusing since model also has model.unknowns() which is a subset + static size_t monte_carlo(const std::vector &unknowns, const Cost &model, int n = 500, int seed = 31337) { std::mt19937 gen(seed);
std::vector costs; - std::vector unknowns = model.unknowns(); - for (int i = 0; i < n; i++) { + // MDRANGE_DEBUG("MC iteration " << i << ":\n"); + // generate random values for all unknowns in cost model Ctx ctx; - for (auto &name : unknowns) { + for (auto &unk : unknowns) { auto val = log_random_int(gen, 1, 100000); - MDRANGE_DEBUG(name << ": " << val << "\n"); - ctx.values[name] = val; + // MDRANGE_DEBUG(unk << ": " << val << "\n"); + ctx.values[unk] = val; } - costs.push_back(model.stride_->eval(ctx) * model.count_->eval(ctx) * model.sf_); + const size_t cost = model.stride_->eval(ctx) * model.count_->eval(ctx) * model.sf_; + // MDRANGE_DEBUG("MC iteration " << i << " cost=" << cost << "\n"); + costs.push_back(cost); } // FIXME: here we do median, is there a principled aggregation strategy? @@ -943,7 +949,7 @@ static MemrefInductionCosts get_costs(Memref &memrefOp, IterationSpaceExprs &tri } template - static size_t model_cost(MemrefOp &memrefOp, const ParallelConfig &cfg, const MemrefInductionCosts &costTable) { + static size_t model_cost(MemrefOp &memrefOp, const ParallelConfig &cfg, const MemrefInductionCosts &costTable, const std::vector &unknowns) { static_assert(std::is_same_v || std::is_same_v); MDRANGE_DEBUG("model cost of " << memrefOp << "...\n"); @@ -964,7 +970,7 @@ static MemrefInductionCosts get_costs(Memref &memrefOp, IterationSpaceExprs &tri auto costKey = std::make_pair(memrefOp, leftMostVar); if (auto jt = costTable.find(costKey); jt != costTable.end()) { Cost model = jt->second; - size_t parallelContrib = monte_carlo(model); + size_t parallelContrib = monte_carlo(unknowns, model); cost += parallelContrib; MDRANGE_DEBUG("..." 
<< memrefOp << " contributes " << parallelContrib << " (now " << cost << ")\n"); } else { @@ -1056,6 +1062,21 @@ static MemrefInductionCosts get_costs(Memref &memrefOp, IterationSpaceExprs &tri MDRANGE_DEBUG("====\nbuild_cost_table\n====\n"); MemrefInductionCosts costTable = build_cost_table(module, tripCounts); + MDRANGE_DEBUG("====\nunknowns:\n====\n"); + std::vector unknowns; + for (const auto &kv : costTable) { + for (const std::string &unk : kv.second.unknowns()) { + if (unknowns.end() == std::find(unknowns.begin(), unknowns.end(), unk)) { + unknowns.push_back(unk); + } + } + } + MDRANGE_DEBUG(unknowns.size() << " unknowns:\n"); + std::sort(unknowns.begin(), unknowns.end()); + for (const std::string &unk : unknowns) { + MDRANGE_DEBUG(unk << "\n"); + } + MDRANGE_DEBUG("====\nExtract parallel ops\n====\n"); auto parallelOps = get_parallel_ops(module); @@ -1073,7 +1094,7 @@ static MemrefInductionCosts get_costs(Memref &memrefOp, IterationSpaceExprs &tri MDRANGE_DEBUG("}\n"); } - size_t cost = model_cost(module, cfg, costTable); + size_t cost = model_cost(module, cfg, costTable, unknowns); MDRANGE_DEBUG("cost was " << cost << "\n"); if (cost < minCost) { MDRANGE_DEBUG("Info: new optimal! cost=" << cost << "\n");