[SeparateConstOffsetFromGEP] Decompose constant xor operand if possible #150438
base: main
Try to transform XOR(A, B+C) into XOR(A,C) + B, where XOR(A,C) is part of the base for memory operations. This transformation is valid under the following conditions:
Check 1 - B and C are disjoint.
Check 2 - XOR(A,C) and B are disjoint.
This transformation can map these xors to a better addressing mode and eventually decompose them into GEPs.
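The two checks can be sanity-tested outside LLVM. The following is a minimal sketch in plain C++ over fixed-width integers, not code from the patch; the helper names `disjoint`, `xorSplitHolds` and `xorSplitHoldsForAll8Bit` are hypothetical. It relies on two facts: if B and C share no set bits then B + C equals B ^ C, and if XOR(A,C) and B share no set bits then XOR(A,C) ^ B equals XOR(A,C) + B.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical helper (not part of the patch): true when x and y share no
// set bits.
constexpr bool disjoint(uint32_t x, uint32_t y) { return (x & y) == 0; }

// Checks XOR(A, B + C) == XOR(A, C) + B for one concrete triple.
// When either precondition fails the identity is not claimed, so the
// function reports success without comparing anything.
bool xorSplitHolds(uint32_t a, uint32_t b, uint32_t c) {
  if (!disjoint(b, c) || !disjoint(a ^ c, b))
    return true;
  return (a ^ (b + c)) == ((a ^ c) + b);
}

// Exhaustively checks the identity for all 8-bit values of A, B and C.
bool xorSplitHoldsForAll8Bit() {
  for (uint32_t a = 0; a < 256; ++a)
    for (uint32_t b = 0; b < 256; ++b)
      for (uint32_t c = 0; c < 256; ++c)
        if (!xorSplitHolds(a, b, c))
          return false;
  return true;
}
```

For instance, with A = 288, C = 32 and B = 8192 both checks pass and both sides evaluate to 8448. With B = 2304 the second check fails (XOR(A,C) = 256 overlaps B), and the two sides differ (2048 versus 2560), which is why the second check is needed.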
@llvm/pr-subscribers-backend-amdgpu
Author: Sumanth Gundapaneni (sgundapa)
Changes
Try to transform XOR(A, B+C) into XOR(A,C) + B, where XOR(A,C) is part of the base for memory operations. This transformation is valid when B and C are disjoint and XOR(A,C) and B are disjoint. This transformation can map these xors to a better addressing mode and eventually decompose them into GEPs.
Full diff: https://github.com/llvm/llvm-project/pull/150438.diff
2 Files Affected:
diff --git a/llvm/lib/Transforms/Scalar/SeparateConstOffsetFromGEP.cpp b/llvm/lib/Transforms/Scalar/SeparateConstOffsetFromGEP.cpp
index 320b79203c0b3..203850c28787c 100644
--- a/llvm/lib/Transforms/Scalar/SeparateConstOffsetFromGEP.cpp
+++ b/llvm/lib/Transforms/Scalar/SeparateConstOffsetFromGEP.cpp
@@ -238,16 +238,17 @@ class ConstantOffsetExtractor {
/// \p PreservesNUW Outputs whether the extraction allows preserving the
/// GEP's nuw flag, if it has one.
static Value *Extract(Value *Idx, GetElementPtrInst *GEP,
- User *&UserChainTail, bool &PreservesNUW);
+ User *&UserChainTail, bool &PreservesNUW,
+ DominatorTree *DT);
/// Looks for a constant offset from the given GEP index without extracting
/// it. It returns the numeric value of the extracted constant offset (0 if
/// failed). The meaning of the arguments are the same as Extract.
- static int64_t Find(Value *Idx, GetElementPtrInst *GEP);
+ static int64_t Find(Value *Idx, GetElementPtrInst *GEP, DominatorTree *DT);
private:
- ConstantOffsetExtractor(BasicBlock::iterator InsertionPt)
- : IP(InsertionPt), DL(InsertionPt->getDataLayout()) {}
+ ConstantOffsetExtractor(BasicBlock::iterator InsertionPt, DominatorTree *DT)
+ : IP(InsertionPt), DT(DT), DL(InsertionPt->getDataLayout()) {}
/// Searches the expression that computes V for a non-zero constant C s.t.
/// V can be reassociated into the form V' + C. If the searching is
@@ -321,6 +322,20 @@ class ConstantOffsetExtractor {
bool CanTraceInto(bool SignExtended, bool ZeroExtended, BinaryOperator *BO,
bool NonNegative);
+ // Find the most dominating Xor with the same base operand.
+ BinaryOperator *findDominatingXor(Value *BaseOperand,
+ BinaryOperator *CurrentXor);
+
+ /// Check if Xor instruction should be considered for optimization.
+ bool shouldConsiderXor(BinaryOperator *XorInst);
+
+ /// Cache the information about Xor idiom.
+ struct XorRewriteInfo {
+ llvm::BinaryOperator *BaseXor = nullptr;
+ int64_t AdjustedOffset = 0;
+ };
+ std::optional<XorRewriteInfo> CachedXorInfo;
+
/// The path from the constant offset to the old GEP index. e.g., if the GEP
/// index is "a * b + (c + 5)". After running function find, UserChain[0] will
/// be the constant 5, UserChain[1] will be the subexpression "c + 5", and
@@ -336,6 +351,8 @@ class ConstantOffsetExtractor {
/// Insertion position of cloned instructions.
BasicBlock::iterator IP;
+ DominatorTree *DT;
+
const DataLayout &DL;
};
@@ -514,12 +531,14 @@ bool ConstantOffsetExtractor::CanTraceInto(bool SignExtended,
bool ZeroExtended,
BinaryOperator *BO,
bool NonNegative) {
- // We only consider ADD, SUB and OR, because a non-zero constant found in
+ // We only consider ADD, SUB, OR and XOR, because a non-zero constant found in
// expressions composed of these operations can be easily hoisted as a
- // constant offset by reassociation.
+ // constant offset by reassociation. XOR is a special case and can be folded
+ // in to gep if the constant is proven to be disjoint.
if (BO->getOpcode() != Instruction::Add &&
BO->getOpcode() != Instruction::Sub &&
- BO->getOpcode() != Instruction::Or) {
+ BO->getOpcode() != Instruction::Or &&
+ BO->getOpcode() != Instruction::Xor) {
return false;
}
@@ -530,6 +549,10 @@ bool ConstantOffsetExtractor::CanTraceInto(bool SignExtended,
!cast<PossiblyDisjointInst>(BO)->isDisjoint())
return false;
+ // Handle Xor idiom.
+ if (BO->getOpcode() == Instruction::Xor)
+ return shouldConsiderXor(BO);
+
// FIXME: We don't currently support constants from the RHS of subs,
// when we are zero-extended, because we need a way to zero-extended
// them before they are negated.
@@ -740,6 +763,10 @@ Value *ConstantOffsetExtractor::removeConstOffset(unsigned ChainIndex) {
"UserChain, so no one should be used more than "
"once");
+ // Special case for Xor idiom.
+ if (BO->getOpcode() == Instruction::Xor)
+ return CachedXorInfo->BaseXor;
+
unsigned OpNo = (BO->getOperand(0) == UserChain[ChainIndex - 1] ? 0 : 1);
assert(BO->getOperand(OpNo) == UserChain[ChainIndex - 1]);
Value *NextInChain = removeConstOffset(ChainIndex - 1);
@@ -780,6 +807,80 @@ Value *ConstantOffsetExtractor::removeConstOffset(unsigned ChainIndex) {
return NewBO;
}
+// Find the most dominating Xor with the same base operand.
+BinaryOperator *
+ConstantOffsetExtractor::findDominatingXor(Value *BaseOperand,
+ BinaryOperator *CurrentXor) {
+ BinaryOperator *MostDominatingXor = nullptr;
+ // Iterate over all instructions that use the BaseOperand.
+ for (User *U : BaseOperand->users()) {
+ auto *CandidateXor = dyn_cast<BinaryOperator>(U);
+
+ // Simple checks.
+ if (!CandidateXor || CandidateXor == CurrentXor)
+ continue;
+
+ // Check if the binary operator is a Xor with constant.
+ if (!match(CandidateXor, m_Xor(m_Specific(BaseOperand), m_ConstantInt())))
+ continue;
+
+ // After confirming the structure, check the dominance relationship.
+ if (DT->dominates(CandidateXor, CurrentXor))
+ // If we find a dominating Xor, keep it if it's the first one,
+ // or if it dominates the best candidate we've found so far.
+ if (!MostDominatingXor || DT->dominates(CandidateXor, MostDominatingXor))
+ MostDominatingXor = CandidateXor;
+ }
+
+ return MostDominatingXor;
+}
+
+// Check if Xor should be considered.
+// Only the following idiom is considered.
+// Example:
+// %3 = xor i32 %2, 32
+// %4 = xor i32 %2, 8224
+// %6 = getelementptr half, ptr addrspace(3) %1, i32 %3
+// %7 = getelementptr half, ptr addrspace(3) %1, i32 %4
+// GEP that corresponds to %7, looks at the binary operator %4.
+// In order for %4 to be considered, it should have a dominating xor with
+// constant offset that is disjoint with an adjusted offset.
+// If disjoint, %4 = xor i32 %2, 8224 can be treated as %4 = add i32 %3, 8192
+bool ConstantOffsetExtractor::shouldConsiderXor(BinaryOperator *XorInst) {
+
+ Value *BaseOperand = nullptr;
+ ConstantInt *CurrentConst = nullptr;
+ if (!match(XorInst, m_Xor(m_Value(BaseOperand), m_ConstantInt(CurrentConst))))
+ return false;
+
+ // Find the most dominating Xor with the same base operand.
+ BinaryOperator *DominatingXor = findDominatingXor(BaseOperand, XorInst);
+ if (!DominatingXor)
+ return false;
+
+ // We expect the dominating instruction to also be a 'xor-const'.
+ ConstantInt *DominatingConst = nullptr;
+ if (!match(DominatingXor,
+ m_Xor(m_Specific(BaseOperand), m_ConstantInt(DominatingConst))))
+ return false;
+
+ // Calculate the adjusted offset (difference between constants)
+ APInt AdjustedOffset = CurrentConst->getValue() - DominatingConst->getValue();
+
+ // Check disjoint conditions
+ // 1. AdjustedOffset and DominatingConst should be disjoint
+ if ((AdjustedOffset & DominatingConst->getValue()) != 0)
+ return false;
+
+ // 2. DominatingXor and AdjustedOffset should be disjoint
+ if (!MaskedValueIsZero(DominatingXor, AdjustedOffset, SimplifyQuery(DL), 0))
+ return false;
+
+ // Cache the result.
+ CachedXorInfo = XorRewriteInfo{DominatingXor, AdjustedOffset.getSExtValue()};
+ return true;
+}
+
/// A helper function to check if reassociating through an entry in the user
/// chain would invalidate the GEP's nuw flag.
static bool allowsPreservingNUW(const User *U) {
@@ -805,8 +906,8 @@ static bool allowsPreservingNUW(const User *U) {
Value *ConstantOffsetExtractor::Extract(Value *Idx, GetElementPtrInst *GEP,
User *&UserChainTail,
- bool &PreservesNUW) {
- ConstantOffsetExtractor Extractor(GEP->getIterator());
+ bool &PreservesNUW, DominatorTree *DT) {
+ ConstantOffsetExtractor Extractor(GEP->getIterator(), DT);
// Find a non-zero constant offset first.
APInt ConstantOffset =
Extractor.find(Idx, /* SignExtended */ false, /* ZeroExtended */ false,
@@ -825,12 +926,20 @@ Value *ConstantOffsetExtractor::Extract(Value *Idx, GetElementPtrInst *GEP,
return IdxWithoutConstOffset;
}
-int64_t ConstantOffsetExtractor::Find(Value *Idx, GetElementPtrInst *GEP) {
+int64_t ConstantOffsetExtractor::Find(Value *Idx, GetElementPtrInst *GEP,
+ DominatorTree *DT) {
// If Idx is an index of an inbound GEP, Idx is guaranteed to be non-negative.
- return ConstantOffsetExtractor(GEP->getIterator())
- .find(Idx, /* SignExtended */ false, /* ZeroExtended */ false,
- GEP->isInBounds())
- .getSExtValue();
+ ConstantOffsetExtractor Extractor(GEP->getIterator(), DT);
+ auto Offset = Extractor
+ .find(Idx, /* SignExtended */ false,
+ /* ZeroExtended */ false, GEP->isInBounds())
+ .getSExtValue();
+
+ // Return the disjoint offset for Xor.
+ if (Extractor.CachedXorInfo)
+ return Extractor.CachedXorInfo->AdjustedOffset;
+
+ return Offset;
}
bool SeparateConstOffsetFromGEP::canonicalizeArrayIndicesToIndexSize(
@@ -866,7 +975,7 @@ SeparateConstOffsetFromGEP::accumulateByteOffset(GetElementPtrInst *GEP,
// Tries to extract a constant offset from this GEP index.
int64_t ConstantOffset =
- ConstantOffsetExtractor::Find(GEP->getOperand(I), GEP);
+ ConstantOffsetExtractor::Find(GEP->getOperand(I), GEP, DT);
if (ConstantOffset != 0) {
NeedsExtraction = true;
// A GEP may have multiple indices. We accumulate the extracted
@@ -1106,7 +1215,7 @@ bool SeparateConstOffsetFromGEP::splitGEP(GetElementPtrInst *GEP) {
User *UserChainTail;
bool PreservesNUW;
Value *NewIdx = ConstantOffsetExtractor::Extract(
- OldIdx, GEP, UserChainTail, PreservesNUW);
+ OldIdx, GEP, UserChainTail, PreservesNUW, DT);
if (NewIdx != nullptr) {
// Switches to the index with the constant offset removed.
GEP->setOperand(I, NewIdx);
diff --git a/llvm/test/Transforms/SeparateConstOffsetFromGEP/AMDGPU/xor-idiom.ll b/llvm/test/Transforms/SeparateConstOffsetFromGEP/AMDGPU/xor-idiom.ll
new file mode 100644
index 0000000000000..a0d0de070e735
--- /dev/null
+++ b/llvm/test/Transforms/SeparateConstOffsetFromGEP/AMDGPU/xor-idiom.ll
@@ -0,0 +1,66 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 5
+; RUN: opt -mtriple=amdgcn-amd-amdhsa -passes=separate-const-offset-from-gep \
+; RUN: -S < %s | FileCheck %s
+
+; Test the xor idiom.
+; Xors with disjoint constants 4128,8224 and 12320 must be expressed in GEPs.
+; Xors with non-disjoint constants 2336 and 8480, should not be optimized.
+define amdgpu_kernel void @test1(i1 %0, ptr addrspace(3) %1) {
+; CHECK-LABEL: define amdgpu_kernel void @test1(
+; CHECK-SAME: i1 [[TMP0:%.*]], ptr addrspace(3) [[TMP1:%.*]]) {
+; CHECK-NEXT: [[ENTRY:.*:]]
+; CHECK-NEXT: [[TMP2:%.*]] = select i1 [[TMP0]], i32 0, i32 288
+; CHECK-NEXT: [[TMP3:%.*]] = xor i32 [[TMP2]], 32
+; CHECK-NEXT: [[TMP14:%.*]] = xor i32 [[TMP2]], 2336
+; CHECK-NEXT: [[TMP5:%.*]] = xor i32 [[TMP2]], 8480
+; CHECK-NEXT: [[TMP4:%.*]] = getelementptr half, ptr addrspace(3) [[TMP1]], i32 [[TMP3]]
+; CHECK-NEXT: [[TMP16:%.*]] = getelementptr half, ptr addrspace(3) [[TMP1]], i32 [[TMP14]]
+; CHECK-NEXT: [[TMP20:%.*]] = getelementptr half, ptr addrspace(3) [[TMP1]], i32 [[TMP3]]
+; CHECK-NEXT: [[TMP21:%.*]] = getelementptr i8, ptr addrspace(3) [[TMP20]], i32 8192
+; CHECK-NEXT: [[TMP7:%.*]] = getelementptr half, ptr addrspace(3) [[TMP1]], i32 [[TMP3]]
+; CHECK-NEXT: [[TMP6:%.*]] = getelementptr i8, ptr addrspace(3) [[TMP7]], i32 16384
+; CHECK-NEXT: [[TMP15:%.*]] = getelementptr half, ptr addrspace(3) [[TMP1]], i32 [[TMP5]]
+; CHECK-NEXT: [[TMP25:%.*]] = getelementptr half, ptr addrspace(3) [[TMP1]], i32 [[TMP3]]
+; CHECK-NEXT: [[TMP8:%.*]] = getelementptr i8, ptr addrspace(3) [[TMP25]], i32 24576
+; CHECK-NEXT: [[TMP9:%.*]] = load <8 x half>, ptr addrspace(3) [[TMP4]], align 16
+; CHECK-NEXT: [[TMP10:%.*]] = load <8 x half>, ptr addrspace(3) [[TMP16]], align 16
+; CHECK-NEXT: [[TMP17:%.*]] = load <8 x half>, ptr addrspace(3) [[TMP21]], align 16
+; CHECK-NEXT: [[TMP18:%.*]] = load <8 x half>, ptr addrspace(3) [[TMP6]], align 16
+; CHECK-NEXT: [[TMP19:%.*]] = load <8 x half>, ptr addrspace(3) [[TMP15]], align 16
+; CHECK-NEXT: [[TMP11:%.*]] = load <8 x half>, ptr addrspace(3) [[TMP8]], align 16
+; CHECK-NEXT: [[TMP12:%.*]] = fadd <8 x half> [[TMP9]], [[TMP10]]
+; CHECK-NEXT: [[TMP22:%.*]] = fadd <8 x half> [[TMP17]], [[TMP18]]
+; CHECK-NEXT: [[TMP23:%.*]] = fadd <8 x half> [[TMP19]], [[TMP11]]
+; CHECK-NEXT: [[TMP24:%.*]] = fadd <8 x half> [[TMP12]], [[TMP22]]
+; CHECK-NEXT: [[TMP13:%.*]] = fadd <8 x half> [[TMP23]], [[TMP24]]
+; CHECK-NEXT: store <8 x half> [[TMP13]], ptr addrspace(3) [[TMP1]], align 16
+; CHECK-NEXT: ret void
+;
+entry:
+ %2 = select i1 %0, i32 0, i32 288
+ %3 = xor i32 %2, 32 ; Base
+ %4 = xor i32 %2, 2336 ; Not disjoint
+ %5 = xor i32 %2, 4128 ; Disjoint
+ %6 = xor i32 %2, 8224 ; Disjoint
+ %7 = xor i32 %2, 8480 ; Not disjoint
+ %8 = xor i32 %2, 12320 ; Disjoint
+ %9 = getelementptr half, ptr addrspace(3) %1, i32 %3
+ %10 = getelementptr half, ptr addrspace(3) %1, i32 %4
+ %11 = getelementptr half, ptr addrspace(3) %1, i32 %5
+ %12 = getelementptr half, ptr addrspace(3) %1, i32 %6
+ %13 = getelementptr half, ptr addrspace(3) %1, i32 %7
+ %14 = getelementptr half, ptr addrspace(3) %1, i32 %8
+ %15 = load <8 x half>, ptr addrspace(3) %9, align 16
+ %16 = load <8 x half>, ptr addrspace(3) %10, align 16
+ %17 = load <8 x half>, ptr addrspace(3) %11, align 16
+ %18 = load <8 x half>, ptr addrspace(3) %12, align 16
+ %19 = load <8 x half>, ptr addrspace(3) %13, align 16
+ %20 = load <8 x half>, ptr addrspace(3) %14, align 16
+ %21 = fadd <8 x half> %15, %16
+ %22 = fadd <8 x half> %17, %18
+ %23 = fadd <8 x half> %19, %20
+ %24 = fadd <8 x half> %21, %22
+ %25 = fadd <8 x half> %23, %24
+ store <8 x half> %25, ptr addrspace(3) %1, align 16
+ ret void
+}
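To see why the test labels some constants "Disjoint" and others "Not disjoint", the two checks in shouldConsiderXor can be modelled with plain integers. This is a simplified sketch, not the patch's code: `xorFoldsToOffset` and its parameters are hypothetical names, and the `baseXorMaybeOne` mask stands in for what MaskedValueIsZero would derive from known bits.

```cpp
#include <cassert>
#include <cstdint>

// Simplified model of the checks in shouldConsiderXor. All names here are
// hypothetical; the patch itself works on APInt and calls MaskedValueIsZero.
//   baseConst       - constant of the dominating xor (32 in the test)
//   currentConst    - constant of the xor feeding the GEP being inspected
//   baseXorMaybeOne - bits that may be set in the dominating xor's result
// On success, stores the offset (in elements) and returns true.
bool xorFoldsToOffset(uint32_t baseConst, uint32_t currentConst,
                      uint32_t baseXorMaybeOne, uint32_t &offset) {
  // Difference between the two xor constants (wraps if current < base).
  uint32_t adjusted = currentConst - baseConst;
  // Check 1: the adjusted offset must not overlap the dominating constant.
  if ((adjusted & baseConst) != 0)
    return false;
  // Check 2: the adjusted offset must not overlap any bit that the
  // dominating xor's result might have set.
  if ((adjusted & baseXorMaybeOne) != 0)
    return false;
  offset = adjusted;
  return true;
}
```

In the test, %2 is either 0 or 288, so %3 = xor %2, 32 is either 32 or 256 and the mask of possibly set bits is 0x120. Under that assumption 4128, 8224 and 12320 give element offsets 4096, 8192 and 12288, which for `half` (2 bytes) correspond to the byte offsets 8192, 16384 and 24576 in the CHECK lines. The constants 2336 and 8480 fail the second check because their adjusted offsets contain bit 256.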
// Find the most dominating Xor with the same base operand.
BinaryOperator *DominatingXor = findDominatingXor(BaseOperand, XorInst);
if (!DominatingXor)
There are cases where we can extract a constant offset even if there is no dominating xor.
Yes. This is true for cases like this:
%2 = select i1 %0, i32 0, i32 288
%3 = xor i32 %2, 12320
%4 = xor i32 %2, 8224
%5 = xor i32 %2, 4096
Disjointness is hard to prove with subtraction if %3 is taken as the BaseXor.
In this case, all the constants need to be analyzed together to come up with a new base. This is what I did in my first cut.
%2 = select i1 %0, i32 0, i32 288
%tmp = xor i32 %2, 32
%3 = xor i32 %tmp, 12288
%4 = xor i32 %tmp, 8192
%5 = xor i32 %tmp, 4128
Let me know what you think.
If this patch goes through, I can build upon this to address the case above.
This is one of the reasons why I went with the earlier approach of traversing the entire function instead of one xor at a time.
Let me know if you have any examples to share, Jeff.
return false;

// Find the most dominating Xor with the same base operand.
BinaryOperator *DominatingXor = findDominatingXor(BaseOperand, XorInst);
I'm not sure we should be trying to find a dominating xor. I think it makes sense, at least in the initial PR, to always extract out the maximum disjoint constant.
Yes, finding the maximum disjoint constant and using it is the best way to optimize.
Considering the linear layout of Triton, there may exist a xor that satisfies the maximum disjoint constant (but that is an assumption).
It would not hurt to find the maximum disjoint constant, as I am already scanning all the xors for each GEP :)
Ignoring the implementation details, I was thinking something like this --
bool ConstantOffsetExtractor::shouldConsiderXor(BinaryOperator *XorInst) {
Value *BaseOperand = nullptr;
ConstantInt *CurrentConst = nullptr;
if (!match(XorInst, m_Xor(m_Value(BaseOperand), m_ConstantInt(CurrentConst))))
return false;
auto SQ = SimplifyQuery(DL, DT, nullptr, nullptr, false);
auto KnownBitsValue = computeKnownBits(BaseOperand, SQ);
auto KnownBitsConst = computeKnownBits(CurrentConst, SQ);
auto DisjointBits = KnownBitsConst.One & KnownBitsValue.Zero;
if (DisjointBits.isZero())
return false;
auto NonDisjointBits = KnownBitsConst.One & KnownBitsValue.getMaxValue();
XorInst->setOperand(1, ConstantInt::get(CurrentConst->getType(), NonDisjointBits));
CachedXorInfo = XorRewriteInfo{XorInst, DisjointBits.getSExtValue()};
return true;
}
Not ignoring the implementation details, we shouldn't need CachedXorInfo -- we should add the appropriate logic in the appropriate places. This also implies we can delete the DT and findDominatingXor. This will create some redundant xors in your example, but maybe we should be doing CSE at the end of the pass, or rely on a different pass to do CSE or otherwise reconcile the addressing modes.
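The known-bits split suggested above can be illustrated without LLVM. The sketch below is a plain-integer model under stated assumptions, not the reviewer's or the patch's implementation: `XorSplit` and `splitXorConstant` are hypothetical names, and `baseKnownZero` plays the role of `KnownBits::Zero` for the xor's non-constant operand.

```cpp
#include <cassert>
#include <cstdint>

// Result of splitting a xor constant (hypothetical type for illustration).
struct XorSplit {
  uint32_t remainingXorConst; // bits that must stay in the xor
  uint32_t disjointOffset;    // bits that can become an added offset
};

// Splits `constant` using the bits of the base operand that are known to be
// zero. Constant bits that land on known-zero base bits cannot interact with
// the base, so xor-ing them in is the same as adding them.
XorSplit splitXorConstant(uint32_t constant, uint32_t baseKnownZero) {
  XorSplit split;
  split.disjointOffset = constant & baseKnownZero;
  split.remainingXorConst = constant & ~baseKnownZero;
  return split;
}
```

As a usage example, take a base that is either 0 or 288, so every bit outside 0x120 is known zero. Splitting 8480 yields a remaining xor constant of 288 and an offset of 8192, and for both possible base values `base ^ 8480` equals `(base ^ 288) + 8192`. In this model no dominating xor is needed to extract the offset; the cost is that each xor keeps its own residual constant, which is where the redundant xors mentioned above would come from.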
Actually, we run CSE shortly after separate-const-offset-from-gep in our backend, so it should be fine.