Skip to content

Commit d77230f

Browse files
committed
Address Tom's review comments
1 parent 2e2b2ca commit d77230f

File tree

7 files changed

+258
-198
lines changed

7 files changed

+258
-198
lines changed

flang/include/flang/Optimizer/OpenMP/Passes.td

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -112,7 +112,7 @@ def GenericLoopConversionPass
112112
];
113113
}
114114

115-
def SimdOnlyPass : Pass<"omp-simd-only", "mlir::func::FuncOp"> {
115+
def SimdOnlyPass : Pass<"omp-simd-only", "mlir::ModuleOp"> {
116116
let summary = "Filters out non-simd OpenMP constructs";
117117
let dependentDialects = ["mlir::omp::OpenMPDialect"];
118118
}

flang/lib/Lower/OpenMP/ClauseProcessor.h

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -208,7 +208,10 @@ void ClauseProcessor::processTODO(mlir::Location currentLocation,
208208
if (!x)
209209
return;
210210
unsigned version = semaCtx.langOptions().OpenMPVersion;
211-
if (!semaCtx.langOptions().OpenMPSimd)
211+
bool isSimdDirective = llvm::omp::getOpenMPDirectiveName(directive, version)
212+
.upper()
213+
.find("SIMD") != llvm::StringRef::npos;
214+
if (!semaCtx.langOptions().OpenMPSimd || isSimdDirective)
212215
TODO(currentLocation,
213216
"Unhandled clause " + llvm::omp::getOpenMPClauseName(id).upper() +
214217
" in " +

flang/lib/Optimizer/OpenMP/SimdOnly.cpp

Lines changed: 129 additions & 88 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,24 @@
1+
//===-- SimdOnly.cpp ------------------------------------------------------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
19
#include "flang/Optimizer/Builder/FIRBuilder.h"
210
#include "flang/Optimizer/Transforms/Utils.h"
11+
#include "mlir/Dialect/Arith/IR/Arith.h"
312
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
413
#include "mlir/Dialect/Func/IR/FuncOps.h"
514
#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
6-
#include "mlir/IR/IRMapping.h"
15+
#include "mlir/IR/MLIRContext.h"
16+
#include "mlir/IR/Operation.h"
17+
#include "mlir/IR/PatternMatch.h"
718
#include "mlir/Pass/Pass.h"
8-
#include "mlir/Transforms/DialectConversion.h"
19+
#include "mlir/Support/LLVM.h"
920
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
10-
#include <llvm/Support/Debug.h>
11-
#include <mlir/IR/MLIRContext.h>
12-
#include <mlir/IR/Operation.h>
13-
#include <mlir/IR/PatternMatch.h>
14-
#include <mlir/Support/LLVM.h>
21+
#include "llvm/Support/Debug.h"
1522

1623
namespace flangomp {
1724
#define GEN_PASS_DEF_SIMDONLYPASS
@@ -44,8 +51,15 @@ class SimdOnlyConversionPattern : public mlir::RewritePattern {
4451
return rewriter.notifyMatchFailure(op, "Op is a plain SimdOp");
4552
}
4653

47-
if (op->getParentOfType<mlir::omp::SimdOp>())
48-
return rewriter.notifyMatchFailure(op, "Op is nested under a SimdOp");
54+
if (op->getParentOfType<mlir::omp::SimdOp>() &&
55+
(mlir::isa<mlir::omp::YieldOp>(op) ||
56+
mlir::isa<mlir::omp::LoopNestOp>(op) ||
57+
mlir::isa<mlir::omp::WsloopOp>(op) ||
58+
mlir::isa<mlir::omp::WorkshareLoopWrapperOp>(op) ||
59+
mlir::isa<mlir::omp::DistributeOp>(op) ||
60+
mlir::isa<mlir::omp::TaskloopOp>(op) ||
61+
mlir::isa<mlir::omp::TerminatorOp>(op)))
62+
return rewriter.notifyMatchFailure(op, "Op is part of a simd construct");
4963

5064
if (!mlir::isa<mlir::func::FuncOp>(op->getParentOp()) &&
5165
(mlir::isa<mlir::omp::TerminatorOp>(op) ||
@@ -67,6 +81,28 @@ class SimdOnlyConversionPattern : public mlir::RewritePattern {
6781
LLVM_DEBUG(llvm::dbgs() << "SimdOnlyPass matched OpenMP op:\n");
6882
LLVM_DEBUG(op->dump());
6983

84+
auto eraseUnlessUsedBySimd = [&](mlir::Operation *ompOp,
85+
mlir::StringAttr name) {
86+
if (auto uses =
87+
mlir::SymbolTable::getSymbolUses(name, op->getParentOp())) {
88+
for (auto &use : *uses)
89+
if (mlir::isa<mlir::omp::SimdOp>(use.getUser()))
90+
return rewriter.notifyMatchFailure(op,
91+
"Op used by a simd construct");
92+
}
93+
rewriter.eraseOp(ompOp);
94+
return mlir::success();
95+
};
96+
97+
if (auto ompOp = mlir::dyn_cast<mlir::omp::PrivateClauseOp>(op))
98+
return eraseUnlessUsedBySimd(ompOp, ompOp.getSymNameAttr());
99+
if (auto ompOp = mlir::dyn_cast<mlir::omp::DeclareReductionOp>(op))
100+
return eraseUnlessUsedBySimd(ompOp, ompOp.getSymNameAttr());
101+
if (auto ompOp = mlir::dyn_cast<mlir::omp::CriticalDeclareOp>(op))
102+
return eraseUnlessUsedBySimd(ompOp, ompOp.getSymNameAttr());
103+
if (auto ompOp = mlir::dyn_cast<mlir::omp::DeclareMapperOp>(op))
104+
return eraseUnlessUsedBySimd(ompOp, ompOp.getSymNameAttr());
105+
70106
// Erase ops that don't need any special handling
71107
if (mlir::isa<mlir::omp::BarrierOp>(op) ||
72108
mlir::isa<mlir::omp::FlushOp>(op) ||
@@ -87,75 +123,19 @@ class SimdOnlyConversionPattern : public mlir::RewritePattern {
87123
fir::FirOpBuilder builder(rewriter, op);
88124
mlir::Location loc = op->getLoc();
89125

90-
auto inlineSimpleOp = [&](mlir::Operation *ompOp) -> bool {
91-
if (!ompOp)
92-
return false;
93-
94-
llvm::SmallVector<std::pair<mlir::Value, mlir::BlockArgument>>
95-
blockArgsPairs;
96-
if (auto iface =
97-
mlir::dyn_cast<mlir::omp::BlockArgOpenMPOpInterface>(op)) {
98-
iface.getBlockArgsPairs(blockArgsPairs);
99-
for (auto [value, argument] : blockArgsPairs)
100-
rewriter.replaceAllUsesWith(argument, value);
101-
}
102-
103-
if (ompOp->getRegion(0).getBlocks().size() == 1) {
104-
auto &block = *ompOp->getRegion(0).getBlocks().begin();
105-
// This block is about to be removed so any arguments should have been
106-
// replaced by now.
107-
block.eraseArguments(0, block.getNumArguments());
108-
if (auto terminatorOp =
109-
mlir::dyn_cast<mlir::omp::TerminatorOp>(block.back())) {
110-
rewriter.eraseOp(terminatorOp);
111-
}
112-
rewriter.inlineBlockBefore(&block, op, {});
113-
} else {
114-
// When dealing with multi-block regions we need to fix up the control
115-
// flow
116-
auto *origBlock = ompOp->getBlock();
117-
auto *newBlock = rewriter.splitBlock(origBlock, ompOp->getIterator());
118-
auto *innerFrontBlock = &ompOp->getRegion(0).getBlocks().front();
119-
builder.setInsertionPointToEnd(origBlock);
120-
builder.create<mlir::cf::BranchOp>(loc, innerFrontBlock);
121-
// We are no longer passing any arguments to the first block in the
122-
// region, so this should be safe to erase.
123-
innerFrontBlock->eraseArguments(0, innerFrontBlock->getNumArguments());
124-
125-
for (auto &innerBlock : ompOp->getRegion(0).getBlocks()) {
126-
// Remove now-unused block arguments
127-
for (auto arg : innerBlock.getArguments()) {
128-
if (arg.getUses().empty())
129-
innerBlock.eraseArgument(arg.getArgNumber());
130-
}
131-
if (auto terminatorOp =
132-
mlir::dyn_cast<mlir::omp::TerminatorOp>(innerBlock.back())) {
133-
builder.setInsertionPointToEnd(&innerBlock);
134-
builder.create<mlir::cf::BranchOp>(loc, newBlock);
135-
rewriter.eraseOp(terminatorOp);
136-
}
137-
}
138-
139-
rewriter.inlineRegionBefore(ompOp->getRegion(0), newBlock);
140-
}
141-
142-
rewriter.eraseOp(op);
143-
return true;
144-
};
145-
146126
if (auto ompOp = mlir::dyn_cast<mlir::omp::LoopNestOp>(op)) {
147127
mlir::Type indexType = builder.getIndexType();
148128
mlir::Type oldIndexType = ompOp.getIVs().begin()->getType();
149129
builder.setInsertionPoint(op);
150-
auto one = builder.create<mlir::arith::ConstantIndexOp>(loc, 1);
130+
auto one = mlir::arith::ConstantIndexOp::create(builder, loc, 1);
151131

152132
// Generate the new loop nest
153133
mlir::Block *nestBody = nullptr;
154134
fir::DoLoopOp outerLoop = nullptr;
155135
llvm::SmallVector<mlir::Value> loopIndArgs;
156136
for (auto extent : ompOp.getLoopUpperBounds()) {
157137
auto ub = builder.createConvert(loc, indexType, extent);
158-
auto doLoop = builder.create<fir::DoLoopOp>(loc, one, ub, one, false);
138+
auto doLoop = fir::DoLoopOp::create(builder, loc, one, ub, one, false);
159139
nestBody = doLoop.getBody();
160140
builder.setInsertionPointToStart(nestBody);
161141
// Convert the indices to the type used inside the loop if needed
@@ -185,11 +165,12 @@ class SimdOnlyConversionPattern : public mlir::RewritePattern {
185165
}
186166

187167
// Remove omp.yield at the end of the loop body
188-
if (auto yieldOp = mlir::dyn_cast<mlir::omp::YieldOp>(nestBody->back()))
168+
if (auto yieldOp =
169+
mlir::dyn_cast<mlir::omp::YieldOp>(nestBody->back())) {
170+
assert("omp.loop_nests's omp.yield has no operands" &&
171+
yieldOp->getNumOperands() == 0);
189172
rewriter.eraseOp(yieldOp);
190-
// DoLoopOp does not support multi-block regions, thus if we're dealing
191-
// with multiple blocks we need to convert it into basic control-flow
192-
// operations.
173+
}
193174
} else {
194175
rewriter.inlineRegionBefore(ompOp->getRegion(0), nestBody);
195176
auto indVarArg = outerLoop->getRegion(0).front().getArgument(0);
@@ -199,6 +180,9 @@ class SimdOnlyConversionPattern : public mlir::RewritePattern {
199180
if (indVarArg.getType() != indexType)
200181
indVarArg.setType(indexType);
201182

183+
// fir.do_loop, unlike omp.loop_nest does not support multi-block
184+
// regions. If we're dealing with multiple blocks inside omp.loop_nest,
185+
// we need to convert it into basic control-flow operations instead.
202186
auto loopBlocks =
203187
fir::convertDoLoopToCFG(outerLoop, rewriter, false, false);
204188
auto *conditionalBlock = loopBlocks.first;
@@ -237,7 +221,9 @@ class SimdOnlyConversionPattern : public mlir::RewritePattern {
237221
if (auto yieldOp =
238222
mlir::dyn_cast<mlir::omp::YieldOp>(loopBlock->back())) {
239223
builder.setInsertionPointToEnd(loopBlock);
240-
builder.create<mlir::cf::BranchOp>(loc, lastBlock);
224+
mlir::cf::BranchOp::create(builder, loc, lastBlock);
225+
assert("omp.loop_nests's omp.yield has no operands" &&
226+
yieldOp->getNumOperands() == 0);
241227
rewriter.eraseOp(yieldOp);
242228
}
243229
}
@@ -255,16 +241,16 @@ class SimdOnlyConversionPattern : public mlir::RewritePattern {
255241

256242
if (auto atomicReadOp = mlir::dyn_cast<mlir::omp::AtomicReadOp>(op)) {
257243
builder.setInsertionPoint(op);
258-
auto loadOp = builder.create<fir::LoadOp>(loc, atomicReadOp.getX());
259-
auto storeOp = builder.create<fir::StoreOp>(loc, loadOp.getResult(),
260-
atomicReadOp.getV());
244+
auto loadOp = fir::LoadOp::create(builder, loc, atomicReadOp.getX());
245+
auto storeOp = fir::StoreOp::create(builder, loc, loadOp.getResult(),
246+
atomicReadOp.getV());
261247
rewriter.replaceOp(op, storeOp);
262248
return mlir::success();
263249
}
264250

265251
if (auto atomicWriteOp = mlir::dyn_cast<mlir::omp::AtomicWriteOp>(op)) {
266-
auto storeOp = builder.create<fir::StoreOp>(loc, atomicWriteOp.getExpr(),
267-
atomicWriteOp.getX());
252+
auto storeOp = fir::StoreOp::create(builder, loc, atomicWriteOp.getExpr(),
253+
atomicWriteOp.getX());
268254
rewriter.replaceOp(op, storeOp);
269255
return mlir::success();
270256
}
@@ -276,7 +262,7 @@ class SimdOnlyConversionPattern : public mlir::RewritePattern {
276262
builder.setInsertionPointToStart(&block);
277263

278264
// Load the update `x` operand and replace its uses within the block
279-
auto loadOp = builder.create<fir::LoadOp>(loc, atomicUpdateOp.getX());
265+
auto loadOp = fir::LoadOp::create(builder, loc, atomicUpdateOp.getX());
280266
rewriter.replaceUsesWithIf(
281267
block.getArgument(0), loadOp.getResult(),
282268
[&](auto &op) { return op.get().getParentBlock() == &block; });
@@ -286,14 +272,14 @@ class SimdOnlyConversionPattern : public mlir::RewritePattern {
286272
auto yieldOp = mlir::cast<mlir::omp::YieldOp>(block.back());
287273
assert("only one yield operand" && yieldOp->getNumOperands() == 1);
288274
builder.setInsertionPointAfter(yieldOp);
289-
builder.create<fir::StoreOp>(loc, yieldOp->getOperand(0),
290-
atomicUpdateOp.getX());
275+
fir::StoreOp::create(builder, loc, yieldOp->getOperand(0),
276+
atomicUpdateOp.getX());
291277
rewriter.eraseOp(yieldOp);
292278

293279
// Inline the final block and remove the now-empty op
294280
assert("only one block argument" && block.getNumArguments() == 1);
295281
block.eraseArguments(0, block.getNumArguments());
296-
rewriter.inlineBlockBefore(&block, op, {});
282+
rewriter.inlineBlockBefore(&block, atomicUpdateOp, {});
297283
rewriter.eraseOp(op);
298284
return mlir::success();
299285
}
@@ -305,6 +291,64 @@ class SimdOnlyConversionPattern : public mlir::RewritePattern {
305291
return mlir::success();
306292
}
307293

294+
auto inlineSimpleOp = [&](mlir::Operation *ompOp) -> bool {
295+
if (!ompOp)
296+
return false;
297+
298+
assert("OpenMP operation has one region" && ompOp->getNumRegions() == 1);
299+
300+
llvm::SmallVector<std::pair<mlir::Value, mlir::BlockArgument>>
301+
blockArgsPairs;
302+
if (auto iface =
303+
mlir::dyn_cast<mlir::omp::BlockArgOpenMPOpInterface>(op)) {
304+
iface.getBlockArgsPairs(blockArgsPairs);
305+
for (auto [value, argument] : blockArgsPairs)
306+
rewriter.replaceAllUsesWith(argument, value);
307+
}
308+
309+
if (ompOp->getRegion(0).getBlocks().size() == 1) {
310+
auto &block = *ompOp->getRegion(0).getBlocks().begin();
311+
// This block is about to be removed so any arguments should have been
312+
// replaced by now.
313+
block.eraseArguments(0, block.getNumArguments());
314+
if (auto terminatorOp =
315+
mlir::dyn_cast<mlir::omp::TerminatorOp>(block.back())) {
316+
rewriter.eraseOp(terminatorOp);
317+
}
318+
rewriter.inlineBlockBefore(&block, ompOp, {});
319+
} else {
320+
// When dealing with multi-block regions we need to fix up the control
321+
// flow
322+
auto *origBlock = ompOp->getBlock();
323+
auto *newBlock = rewriter.splitBlock(origBlock, ompOp->getIterator());
324+
auto *innerFrontBlock = &ompOp->getRegion(0).getBlocks().front();
325+
builder.setInsertionPointToEnd(origBlock);
326+
mlir::cf::BranchOp::create(builder, loc, innerFrontBlock);
327+
// We are no longer passing any arguments to the first block in the
328+
// region, so this should be safe to erase.
329+
innerFrontBlock->eraseArguments(0, innerFrontBlock->getNumArguments());
330+
331+
for (auto &innerBlock : ompOp->getRegion(0).getBlocks()) {
332+
// Remove now-unused block arguments
333+
for (auto arg : innerBlock.getArguments()) {
334+
if (arg.getUses().empty())
335+
innerBlock.eraseArgument(arg.getArgNumber());
336+
}
337+
if (auto terminatorOp =
338+
mlir::dyn_cast<mlir::omp::TerminatorOp>(innerBlock.back())) {
339+
builder.setInsertionPointToEnd(&innerBlock);
340+
mlir::cf::BranchOp::create(builder, loc, newBlock);
341+
rewriter.eraseOp(terminatorOp);
342+
}
343+
}
344+
345+
rewriter.inlineRegionBefore(ompOp->getRegion(0), newBlock);
346+
}
347+
348+
rewriter.eraseOp(op);
349+
return true;
350+
};
351+
308352
if (inlineSimpleOp(mlir::dyn_cast<mlir::omp::TeamsOp>(op)) ||
309353
inlineSimpleOp(mlir::dyn_cast<mlir::omp::ParallelOp>(op)) ||
310354
inlineSimpleOp(mlir::dyn_cast<mlir::omp::SingleOp>(op)) ||
@@ -324,7 +368,7 @@ class SimdOnlyConversionPattern : public mlir::RewritePattern {
324368
inlineSimpleOp(mlir::dyn_cast<mlir::omp::MaskedOp>(op)))
325369
return mlir::success();
326370

327-
op->emitOpError("OpenMP operation left unhandled after SimdOnly pass.");
371+
op->emitOpError("left unhandled after SimdOnly pass.");
328372
return mlir::failure();
329373
}
330374
};
@@ -335,10 +379,7 @@ class SimdOnlyPass : public flangomp::impl::SimdOnlyPassBase<SimdOnlyPass> {
335379
SimdOnlyPass() = default;
336380

337381
void runOnOperation() override {
338-
mlir::func::FuncOp func = getOperation();
339-
340-
if (func.isDeclaration())
341-
return;
382+
mlir::ModuleOp module = getOperation();
342383

343384
mlir::MLIRContext *context = &getContext();
344385
mlir::RewritePatternSet patterns(context);
@@ -350,8 +391,8 @@ class SimdOnlyPass : public flangomp::impl::SimdOnlyPassBase<SimdOnlyPass> {
350391
mlir::GreedySimplifyRegionLevel::Disabled);
351392

352393
if (mlir::failed(
353-
mlir::applyPatternsGreedily(func, std::move(patterns), config))) {
354-
mlir::emitError(func.getLoc(), "error in simd-only conversion pass");
394+
mlir::applyPatternsGreedily(module, std::move(patterns), config))) {
395+
mlir::emitError(module.getLoc(), "error in simd-only conversion pass");
355396
signalPassFailure();
356397
}
357398
}

flang/lib/Optimizer/Transforms/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ add_flang_library(FIRTransforms
3535
GenRuntimeCallsForTest.cpp
3636
SimplifyFIROperations.cpp
3737
OptimizeArrayRepacking.cpp
38+
Utils.cpp
3839

3940
DEPENDS
4041
CUFAttrs

0 commit comments

Comments
 (0)