Skip to content

Commit 57c0faf

Browse files
Merge commit 'bea27e37b6585e602322d3206dd0b8fcefe8523a'
2 parents b5c46f0 + bea27e3 commit 57c0faf

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

50 files changed

+2113
-422
lines changed

include/triton/Dialect/Gluon/Transforms/Passes.td

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,4 +35,14 @@ def GluonInline: Pass<"gluon-inline"> {
3535
let dependentDialects = [];
3636
}
3737

38+
def GluonSimplifyControlFlow: Pass<"gluon-slimplify-control-flow"> {
39+
let summary = "simplifications for control flow ops";
40+
41+
let description = [{
42+
This pass applies a reduced set of simplification
43+
and canonicalization patterns to the module.
44+
}];
45+
let dependentDialects = [];
46+
}
47+
3848
#endif

include/triton/Dialect/TritonGPU/IR/LinearLayoutConversions.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -135,6 +135,11 @@ LinearLayout chooseScaledMfmaScaleLayout(MLIRContext *ctx, int dotOperandIdx,
135135
ArrayRef<unsigned> tilesPerWarp,
136136
ArrayRef<unsigned> warpsPerCTA);
137137

138+
LinearLayout chooseScaledWmmaScaleLayout(
139+
MLIRContext *ctx, int dotOperandIdx,
140+
const std::vector<std::vector<int32_t>> &dotOperandWarpBasis,
141+
ArrayRef<int64_t> dotOperandShape);
142+
138143
LinearLayout getSM120DotScaledScaleLayout(MLIRContext *ctx, int dotOperandIdx,
139144
ArrayRef<int64_t> dotOperandShape,
140145
ArrayRef<unsigned> tilesPerWarp,

include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1307,8 +1307,7 @@ Row |
13071307
let hasCustomAssemblyFormat = 1;
13081308

13091309
let extraClassDeclaration = extraDistributedDeclaration # [{
1310-
SmallVector<int64_t> getRepForOperand(ArrayRef<int64_t> operandShape,
1311-
Type elemType, int opIdx) const;
1310+
SmallVector<int64_t> getRepForOperand(ArrayRef<int64_t> operandShape, int kDim, int opIdx) const;
13121311
SmallVector<unsigned> getRepOrderForOperand(int opIdx) const;
13131312

13141313
static SmallVector<unsigned, 3> getDefaultInstrShape() {

include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -574,4 +574,23 @@ def TTG_WarpReturnOp : TTG_Op<"warp_return", [
574574
let assemblyFormat = "attr-dict";
575575
}
576576

577+
def TTG_LocalBarrierOp : TTG_Op<"local_barrier"> {
578+
let summary = "Synchronizes execution and shared memory reads/writes for all threads in a CTA.";
579+
let description = [{
580+
The `local_barrier` op synchronizes the execution and all operations
581+
between shared memory and registers for all threads in a CTA.
582+
It is used to coordinate communication between the threads of the CTA.
583+
584+
This operation waits until all threads in the CTA have reached a `local_barrier`
585+
and operations between shared memory and registers made by these threads prior
586+
to the op are visible to all threads in the CTA.
587+
588+
Data hazards between threads accessing the same memory can be avoided by synchronizing the
589+
CTA in-between these accesses with a `local_barrier`.
590+
591+
A `local_barrier` operation does not provide synchronization guarantees on global memory.
592+
}];
593+
let assemblyFormat = "attr-dict";
594+
}
595+
577596
#endif // TRITONGPU_OPS

lib/Analysis/Membar.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -159,20 +159,20 @@ void MembarOrFenceAnalysis::visitTerminator(
159159

160160
void MembarAnalysis::insertBarrier(Operation *op, OpBuilder *builder) {
161161
OpBuilder::InsertionGuard g(*builder);
162-
auto barrierOp = builder->create<gpu::BarrierOp>(op->getLoc());
162+
auto barrierOp = builder->create<triton::gpu::LocalBarrierOp>(op->getLoc());
163163
}
164164

165165
void MembarAnalysis::update(Operation *op, BlockInfo *blockInfo,
166166
FuncBlockInfoMapT *funcBlockInfoMap,
167167
OpBuilder *builder) {
168-
if (isa<gpu::BarrierOp>(op)) {
168+
if (isa<gpu::BarrierOp, triton::gpu::LocalBarrierOp>(op)) {
169169
// If the current op is a barrier, we sync previous reads and writes
170170
blockInfo->sync();
171171
return;
172172
}
173173

174174
if (isa<triton::gpu::AsyncWaitOp, triton::nvidia_gpu::TMAStoreWaitOp>(op) &&
175-
!isa<gpu::BarrierOp>(op->getNextNode())) {
175+
!isa<gpu::BarrierOp, triton::gpu::LocalBarrierOp>(op->getNextNode())) {
176176
// If the current op is an async wait and the next op is not a barrier we
177177
// insert a barrier op and sync
178178
builder->setInsertionPointAfter(op);

lib/Conversion/TritonGPUToLLVM/MemoryOpToLLVM.cpp

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
#include "mlir/Conversion/LLVMCommon/Pattern.h"
22
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
3+
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
34
#include "mlir/IR/PatternMatch.h"
45
#include "triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h"
56
#include "triton/Conversion/TritonGPUToLLVM/TargetInfoBase.h"
@@ -232,6 +233,25 @@ struct LocalStoreOpConversion
232233
const TargetInfoBase &targetInfo;
233234
};
234235

236+
// Conversion pattern for `triton::gpu::LocalBarrierOp`: every matched op is
// replaced by an `mlir::gpu::BarrierOp`.
class LocalBarrierOpConversion
237+
: public ConvertOpToLLVMPattern<triton::gpu::LocalBarrierOp> {
238+
public:
239+
// Forwards the type converter and pattern benefit to the base pattern.
LocalBarrierOpConversion(const LLVMTypeConverter &converter,
240+
PatternBenefit benefit)
241+
: ConvertOpToLLVMPattern<triton::gpu::LocalBarrierOp>(converter,
242+
benefit) {}
243+
using OpAdaptor = typename triton::gpu::LocalBarrierOp::Adaptor;
244+
245+
// Replaces `op` with a newly created `mlir::gpu::BarrierOp` and reports
// success unconditionally; `adaptor` is not used.
LogicalResult
246+
matchAndRewrite(triton::gpu::LocalBarrierOp op, OpAdaptor adaptor,
247+
ConversionPatternRewriter &rewriter) const override {
248+
249+
rewriter.replaceOpWithNewOp<mlir::gpu::BarrierOp>(op);
250+
251+
return success();
252+
}
253+
};
254+
235255
} // namespace
236256

237257
void mlir::triton::populateMemoryOpToLLVMPatterns(
@@ -243,4 +263,5 @@ void mlir::triton::populateMemoryOpToLLVMPatterns(
243263
patterns.add<LocalDeallocOpConversion>(typeConverter, benefit);
244264
patterns.add<LocalLoadOpConversion>(typeConverter, targetInfo, benefit);
245265
patterns.add<LocalStoreOpConversion>(typeConverter, targetInfo, benefit);
266+
patterns.add<LocalBarrierOpConversion>(typeConverter, benefit);
246267
}

lib/Dialect/Gluon/Transforms/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ add_triton_library(GluonTransforms
22
Canonicalize.cpp
33
Inline.cpp
44
ResolveAutoEncodings.cpp
5+
SimplifyControlFlow.cpp
56

67
DEPENDS
78
GluonTransformsIncGen

lib/Dialect/Gluon/Transforms/Inline.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ struct Inline : public gluon::impl::GluonInlineBase<Inline> {
2222
void Inline::runOnOperation() {
2323
mlir::PassManager pm(&getContext());
2424
pm.addPass(createInlinerPass(/*opPipelines=*/{}, [](OpPassManager &pm) {
25-
pm.addPass(gluon::createGluonCanonicalize());
25+
pm.addPass(gluon::createGluonSimplifyControlFlow());
2626
}));
2727
if (failed(pm.run(getOperation())))
2828
return signalPassFailure();
Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
#include "mlir/IR/OperationSupport.h"
2+
#include "triton/Dialect/Gluon/Transforms/Passes.h"
3+
4+
#include "triton/Dialect/TritonGPU/Transforms/Utility.h"
5+
6+
#include "mlir/Dialect/Arith/IR/Arith.h"
7+
#include "mlir/Dialect/ControlFlow/IR/ControlFlow.h"
8+
#include "mlir/Dialect/SCF/IR/SCF.h"
9+
#include "mlir/Pass/Pass.h"
10+
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
11+
12+
using namespace mlir;
13+
using namespace triton;
14+
15+
namespace mlir::triton::gluon {
16+
#define GEN_PASS_DEF_GLUONSIMPLIFYCONTROLFLOW
17+
#include "triton/Dialect/Gluon/Transforms/Passes.h.inc"
18+
} // namespace mlir::triton::gluon
19+
20+
namespace {
21+
// Pass implementation built on the generated `GluonSimplifyControlFlowBase`
// class; the actual work is done in `runOnOperation()`, defined out of line.
struct SimplifyControlFlow
22+
: public gluon::impl::GluonSimplifyControlFlowBase<SimplifyControlFlow> {
23+
void runOnOperation() override;
24+
};
25+
} // namespace
26+
27+
// Collects the canonicalization patterns of the `scf` and `cf` dialects (both
// the dialect-level ones and those of every registered op in each dialect),
// adds the for-op dead-argument-elimination patterns, and applies the whole
// set greedily to the current operation with constant CSE disabled.
void SimplifyControlFlow::runOnOperation() {
28+
MLIRContext *ctx = &getContext();
29+
RewritePatternSet patterns(&getContext());
30+
31+
// Populate `scf` and `cf` canonicalizers.
// NOTE(review): the results of `getLoadedDialect` are dereferenced without a
// null check — confirm both dialects are guaranteed to be loaded in the
// context before this pass runs.
32+
ctx->getLoadedDialect<scf::SCFDialect>()->getCanonicalizationPatterns(
33+
patterns);
34+
ctx->getLoadedDialect<cf::ControlFlowDialect>()->getCanonicalizationPatterns(
35+
patterns);
36+
// Per-op canonicalization patterns for every registered `scf` op.
for (mlir::RegisteredOperationName op : ctx->getRegisteredOperationsByDialect(
37+
scf::SCFDialect::getDialectNamespace()))
38+
op.getCanonicalizationPatterns(patterns, ctx);
39+
// Per-op canonicalization patterns for every registered `cf` op.
for (mlir::RegisteredOperationName op : ctx->getRegisteredOperationsByDialect(
40+
cf::ControlFlowDialect::getDialectNamespace()))
41+
op.getCanonicalizationPatterns(patterns, ctx);
42+
populateForOpDeadArgumentElimination(patterns);
43+
44+
GreedyRewriteConfig config;
45+
// This is intended to run before AutoLayouts are resolved, in which case
46+
// CSEing constants can lead to additional layout conflicts.
47+
config.enableConstantCSE(false);
48+
// The driver's result is explicitly discarded, so a failure reported by the
// driver does not fail the pass.
(void)applyPatternsGreedily(getOperation(), std::move(patterns), config);
49+
}

lib/Dialect/TritonGPU/IR/Dialect.cpp

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2300,12 +2300,11 @@ AMDWmmaEncodingAttr::getRepOrderForOperand(int opIdx) const {
23002300
}
23012301

23022302
SmallVector<int64_t>
2303-
AMDWmmaEncodingAttr::getRepForOperand(ArrayRef<int64_t> operandShape,
2304-
Type elemType, int opIdx) const {
2303+
AMDWmmaEncodingAttr::getRepForOperand(ArrayRef<int64_t> operandShape, int kDim,
2304+
int opIdx) const {
23052305
auto mnkDim = getInstrShape();
2306-
auto operandTileShape = opIdx == 0
2307-
? SmallVector<int64_t>{mnkDim[0], mnkDim[2]}
2308-
: SmallVector<int64_t>{mnkDim[2], mnkDim[1]};
2306+
SmallVector<int64_t, 2> operandTileShape{opIdx == 0 ? mnkDim[0] : kDim,
2307+
opIdx == 0 ? kDim : mnkDim[1]};
23092308

23102309
assert(operandTileShape.size() == 2);
23112310
auto warpsPerCTA = getWarpsPerCTA();

0 commit comments

Comments
 (0)