Skip to content

[SeparateConstOffsetFromGEP] Decompose constant xor operand if possible #150438

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
141 changes: 125 additions & 16 deletions llvm/lib/Transforms/Scalar/SeparateConstOffsetFromGEP.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -238,16 +238,17 @@ class ConstantOffsetExtractor {
/// \p PreservesNUW Outputs whether the extraction allows preserving the
/// GEP's nuw flag, if it has one.
static Value *Extract(Value *Idx, GetElementPtrInst *GEP,
User *&UserChainTail, bool &PreservesNUW);
User *&UserChainTail, bool &PreservesNUW,
DominatorTree *DT);

/// Looks for a constant offset from the given GEP index without extracting
/// it. It returns the numeric value of the extracted constant offset (0 if
/// failed). The meaning of the arguments are the same as Extract.
static int64_t Find(Value *Idx, GetElementPtrInst *GEP);
static int64_t Find(Value *Idx, GetElementPtrInst *GEP, DominatorTree *DT);

private:
ConstantOffsetExtractor(BasicBlock::iterator InsertionPt)
: IP(InsertionPt), DL(InsertionPt->getDataLayout()) {}
ConstantOffsetExtractor(BasicBlock::iterator InsertionPt, DominatorTree *DT)
: IP(InsertionPt), DT(DT), DL(InsertionPt->getDataLayout()) {}

/// Searches the expression that computes V for a non-zero constant C s.t.
/// V can be reassociated into the form V' + C. If the searching is
Expand Down Expand Up @@ -321,6 +322,20 @@ class ConstantOffsetExtractor {
bool CanTraceInto(bool SignExtended, bool ZeroExtended, BinaryOperator *BO,
bool NonNegative);

// Find the most dominating Xor with the same base operand.
BinaryOperator *findDominatingXor(Value *BaseOperand,
BinaryOperator *CurrentXor);

/// Check if Xor instruction should be considered for optimization.
bool shouldConsiderXor(BinaryOperator *XorInst);

/// Cache the information about Xor idiom.
struct XorRewriteInfo {
llvm::BinaryOperator *BaseXor = nullptr;
int64_t AdjustedOffset = 0;
};
std::optional<XorRewriteInfo> CachedXorInfo;

/// The path from the constant offset to the old GEP index. e.g., if the GEP
/// index is "a * b + (c + 5)". After running function find, UserChain[0] will
/// be the constant 5, UserChain[1] will be the subexpression "c + 5", and
Expand All @@ -336,6 +351,8 @@ class ConstantOffsetExtractor {
/// Insertion position of cloned instructions.
BasicBlock::iterator IP;

DominatorTree *DT;

const DataLayout &DL;
};

Expand Down Expand Up @@ -514,12 +531,14 @@ bool ConstantOffsetExtractor::CanTraceInto(bool SignExtended,
bool ZeroExtended,
BinaryOperator *BO,
bool NonNegative) {
// We only consider ADD, SUB and OR, because a non-zero constant found in
// We only consider ADD, SUB, OR and XOR, because a non-zero constant found in
// expressions composed of these operations can be easily hoisted as a
// constant offset by reassociation.
// constant offset by reassociation. XOR is a special case and can be folded
// in to gep if the constant is proven to be disjoint.
if (BO->getOpcode() != Instruction::Add &&
BO->getOpcode() != Instruction::Sub &&
BO->getOpcode() != Instruction::Or) {
BO->getOpcode() != Instruction::Or &&
BO->getOpcode() != Instruction::Xor) {
return false;
}

Expand All @@ -530,6 +549,10 @@ bool ConstantOffsetExtractor::CanTraceInto(bool SignExtended,
!cast<PossiblyDisjointInst>(BO)->isDisjoint())
return false;

// Handle Xor idiom.
if (BO->getOpcode() == Instruction::Xor)
return shouldConsiderXor(BO);

// FIXME: We don't currently support constants from the RHS of subs,
// when we are zero-extended, because we need a way to zero-extended
// them before they are negated.
Expand Down Expand Up @@ -740,6 +763,10 @@ Value *ConstantOffsetExtractor::removeConstOffset(unsigned ChainIndex) {
"UserChain, so no one should be used more than "
"once");

// Special case for Xor idiom.
if (BO->getOpcode() == Instruction::Xor)
return CachedXorInfo->BaseXor;

unsigned OpNo = (BO->getOperand(0) == UserChain[ChainIndex - 1] ? 0 : 1);
assert(BO->getOperand(OpNo) == UserChain[ChainIndex - 1]);
Value *NextInChain = removeConstOffset(ChainIndex - 1);
Expand Down Expand Up @@ -780,6 +807,80 @@ Value *ConstantOffsetExtractor::removeConstOffset(unsigned ChainIndex) {
return NewBO;
}

// Find the most dominating Xor with the same base operand.
BinaryOperator *
ConstantOffsetExtractor::findDominatingXor(Value *BaseOperand,
BinaryOperator *CurrentXor) {
BinaryOperator *MostDominatingXor = nullptr;
// Iterate over all instructions that use the BaseOperand.
for (User *U : BaseOperand->users()) {
auto *CandidateXor = dyn_cast<BinaryOperator>(U);

// Simple checks.
if (!CandidateXor || CandidateXor == CurrentXor)
continue;

// Check if the binary operator is a Xor with constant.
if (!match(CandidateXor, m_Xor(m_Specific(BaseOperand), m_ConstantInt())))
continue;

// After confirming the structure, check the dominance relationship.
if (DT->dominates(CandidateXor, CurrentXor))
// If we find a dominating Xor, keep it if it's the first one,
// or if it dominates the best candidate we've found so far.
if (!MostDominatingXor || DT->dominates(CandidateXor, MostDominatingXor))
MostDominatingXor = CandidateXor;
}

return MostDominatingXor;
}

// Check if Xor should be considered.
// Only the following idiom is considered.
// Example:
// %3 = xor i32 %2, 32
// %4 = xor i32 %2, 8224
// %6 = getelementptr half, ptr addrspace(3) %1, i32 %3
// %7 = getelementptr half, ptr addrspace(3) %1, i32 %4
// GEP that corresponds to %7, looks at the binary operator %4.
// In order for %4 to be considered, it should have a dominating xor with
// constant offset that is disjoint with an adjusted offset.
// If disjoint, %4 = xor i32 %2, 8224 can be treated as %4 = add i32 %3, 8192
bool ConstantOffsetExtractor::shouldConsiderXor(BinaryOperator *XorInst) {

Value *BaseOperand = nullptr;
ConstantInt *CurrentConst = nullptr;
if (!match(XorInst, m_Xor(m_Value(BaseOperand), m_ConstantInt(CurrentConst))))
return false;

// Find the most dominating Xor with the same base operand.
BinaryOperator *DominatingXor = findDominatingXor(BaseOperand, XorInst);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure we should be trying to find a dominating xor. I think it makes sense, at least in the initial PR, to always extract out the maximum disjoint constant.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, finding the maximum disjoint constant and use it is the best way to optimize.
Considering the linear layout of Triton, there may exist a xor that satifies the maximum disjoint constant (but, it is an assumption).

It would not hurt if we can find the maximum disjoint constant as I am already scanning all the xors for each GEP :)

Copy link
Contributor

@jrbyrnes jrbyrnes Jul 25, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ignoring the implementation details, I was thinking something like this --

bool ConstantOffsetExtractor::shouldConsiderXor(BinaryOperator *XorInst) {

  Value *BaseOperand = nullptr;
  ConstantInt *CurrentConst = nullptr;
  if (!match(XorInst, m_Xor(m_Value(BaseOperand), m_ConstantInt(CurrentConst))))
    return false;

  auto SQ = SimplifyQuery(DL, DT, nullptr, nullptr, false);
  auto KnownBitsValue = computeKnownBits(BaseOperand, SQ);
  auto KnownBitsConst = computeKnownBits(CurrentConst, SQ);

  auto DisjointBits = KnownBitsConst.One & KnownBitsValue.Zero;
  if (DisjointBits.isZero())
    return false;
  
  auto NonDisjointBits = KnownBitsConst.One & KnownBitsValue.getMaxValue();
  

  
  XorInst->setOperand(1, ConstantInt::get(CurrentConst->getType(), NonDisjointBits));
  CachedXorInfo = XorRewriteInfo{XorInst, DisjointBits.getSExtValue()};
  return true;
}

Not ignoring the implementation details, we shouldn't need CachedXorInfo -- we should add the appropriate logic in the appropriate places. This also implies we can delete the DT and findDominatingXor. This will create some redundant xors in your example, but maybe we should be doing CSE at the end of the pass or rely on a different pass to do CSE or otherwise reconcile the addressing modes

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually, we run CSE shortly after separate-const-offset-from-gep in our backend, so it should be fine.

if (!DominatingXor)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There are cases where we can extract a constant offset even if there is no dominating xor.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes. This is true for cases like this
%2 = select i1 %0, i32 0, i32 288
%3 = xor i32 %2, 12320
%4 = xor i32 %2, 8224
%5 = xor i32 %2, 4096
The disjoinedness is hard to prove with subtraction if %3 is considered as BaseXor.
In this case, all the constants need to be reasoned and come up with new base. This is what I did in my first cut.
%2 = select i1 %0, i32 0, i32 288
%tmp = xor i32 %2, 32
%3 = xor i32 %tmp, 12288
%4 = xor i32 %tmp, 8192
%5 = xor i32 %tmp, 4064
Let me know what you think.

If this patch goes through, I can build upon this to address the case above.
This is one of the reasons why I went with the earlier approach of traversing the entire function instead of one xor at a time.

Let me know if you have any examples to share Jeff.

return false;

// We expect the dominating instruction to also be a 'xor-const'.
ConstantInt *DominatingConst = nullptr;
if (!match(DominatingXor,
m_Xor(m_Specific(BaseOperand), m_ConstantInt(DominatingConst))))
return false;

// Calculate the adjusted offset (difference between constants)
APInt AdjustedOffset = CurrentConst->getValue() - DominatingConst->getValue();

// Check disjoint conditions
// 1. AdjustedOffset and DominatingConst should be disjoint
if ((AdjustedOffset & DominatingConst->getValue()) != 0)
return false;

// 2. DominatingXor and AdjustedOffset should be disjoint
if (!MaskedValueIsZero(DominatingXor, AdjustedOffset, SimplifyQuery(DL), 0))
return false;

// Cache the result.
CachedXorInfo = XorRewriteInfo{DominatingXor, AdjustedOffset.getSExtValue()};
return true;
}

/// A helper function to check if reassociating through an entry in the user
/// chain would invalidate the GEP's nuw flag.
static bool allowsPreservingNUW(const User *U) {
Expand All @@ -805,8 +906,8 @@ static bool allowsPreservingNUW(const User *U) {

Value *ConstantOffsetExtractor::Extract(Value *Idx, GetElementPtrInst *GEP,
User *&UserChainTail,
bool &PreservesNUW) {
ConstantOffsetExtractor Extractor(GEP->getIterator());
bool &PreservesNUW, DominatorTree *DT) {
ConstantOffsetExtractor Extractor(GEP->getIterator(), DT);
// Find a non-zero constant offset first.
APInt ConstantOffset =
Extractor.find(Idx, /* SignExtended */ false, /* ZeroExtended */ false,
Expand All @@ -825,12 +926,20 @@ Value *ConstantOffsetExtractor::Extract(Value *Idx, GetElementPtrInst *GEP,
return IdxWithoutConstOffset;
}

int64_t ConstantOffsetExtractor::Find(Value *Idx, GetElementPtrInst *GEP) {
int64_t ConstantOffsetExtractor::Find(Value *Idx, GetElementPtrInst *GEP,
DominatorTree *DT) {
// If Idx is an index of an inbound GEP, Idx is guaranteed to be non-negative.
return ConstantOffsetExtractor(GEP->getIterator())
.find(Idx, /* SignExtended */ false, /* ZeroExtended */ false,
GEP->isInBounds())
.getSExtValue();
ConstantOffsetExtractor Extractor(GEP->getIterator(), DT);
auto Offset = Extractor
.find(Idx, /* SignExtended */ false,
/* ZeroExtended */ false, GEP->isInBounds())
.getSExtValue();

// Return the disjoint offset for Xor.
if (Extractor.CachedXorInfo)
return Extractor.CachedXorInfo->AdjustedOffset;

return Offset;
}

bool SeparateConstOffsetFromGEP::canonicalizeArrayIndicesToIndexSize(
Expand Down Expand Up @@ -866,7 +975,7 @@ SeparateConstOffsetFromGEP::accumulateByteOffset(GetElementPtrInst *GEP,

// Tries to extract a constant offset from this GEP index.
int64_t ConstantOffset =
ConstantOffsetExtractor::Find(GEP->getOperand(I), GEP);
ConstantOffsetExtractor::Find(GEP->getOperand(I), GEP, DT);
if (ConstantOffset != 0) {
NeedsExtraction = true;
// A GEP may have multiple indices. We accumulate the extracted
Expand Down Expand Up @@ -1106,7 +1215,7 @@ bool SeparateConstOffsetFromGEP::splitGEP(GetElementPtrInst *GEP) {
User *UserChainTail;
bool PreservesNUW;
Value *NewIdx = ConstantOffsetExtractor::Extract(
OldIdx, GEP, UserChainTail, PreservesNUW);
OldIdx, GEP, UserChainTail, PreservesNUW, DT);
if (NewIdx != nullptr) {
// Switches to the index with the constant offset removed.
GEP->setOperand(I, NewIdx);
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 5
; RUN: opt -mtriple=amdgcn-amd-amdhsa -passes=separate-const-offset-from-gep \
; RUN: -S < %s | FileCheck %s

; Test that xor idiom.
; Xors with disjoint constants 4128,8224 and 12320 must be expressed in GEPs.
; Xors with non-disjoint constants 2336 and 8480, should not be optimized.
define amdgpu_kernel void @test1(i1 %0, ptr addrspace(3) %1) {
; CHECK-LABEL: define amdgpu_kernel void @test1(
; CHECK-SAME: i1 [[TMP0:%.*]], ptr addrspace(3) [[TMP1:%.*]]) {
; CHECK-NEXT: [[ENTRY:.*:]]
; CHECK-NEXT: [[TMP2:%.*]] = select i1 [[TMP0]], i32 0, i32 288
; CHECK-NEXT: [[TMP3:%.*]] = xor i32 [[TMP2]], 32
; CHECK-NEXT: [[TMP14:%.*]] = xor i32 [[TMP2]], 2336
; CHECK-NEXT: [[TMP5:%.*]] = xor i32 [[TMP2]], 8480
; CHECK-NEXT: [[TMP4:%.*]] = getelementptr half, ptr addrspace(3) [[TMP1]], i32 [[TMP3]]
; CHECK-NEXT: [[TMP16:%.*]] = getelementptr half, ptr addrspace(3) [[TMP1]], i32 [[TMP14]]
; CHECK-NEXT: [[TMP20:%.*]] = getelementptr half, ptr addrspace(3) [[TMP1]], i32 [[TMP3]]
; CHECK-NEXT: [[TMP21:%.*]] = getelementptr i8, ptr addrspace(3) [[TMP20]], i32 8192
; CHECK-NEXT: [[TMP7:%.*]] = getelementptr half, ptr addrspace(3) [[TMP1]], i32 [[TMP3]]
; CHECK-NEXT: [[TMP6:%.*]] = getelementptr i8, ptr addrspace(3) [[TMP7]], i32 16384
; CHECK-NEXT: [[TMP15:%.*]] = getelementptr half, ptr addrspace(3) [[TMP1]], i32 [[TMP5]]
; CHECK-NEXT: [[TMP25:%.*]] = getelementptr half, ptr addrspace(3) [[TMP1]], i32 [[TMP3]]
; CHECK-NEXT: [[TMP8:%.*]] = getelementptr i8, ptr addrspace(3) [[TMP25]], i32 24576
; CHECK-NEXT: [[TMP9:%.*]] = load <8 x half>, ptr addrspace(3) [[TMP4]], align 16
; CHECK-NEXT: [[TMP10:%.*]] = load <8 x half>, ptr addrspace(3) [[TMP16]], align 16
; CHECK-NEXT: [[TMP17:%.*]] = load <8 x half>, ptr addrspace(3) [[TMP21]], align 16
; CHECK-NEXT: [[TMP18:%.*]] = load <8 x half>, ptr addrspace(3) [[TMP6]], align 16
; CHECK-NEXT: [[TMP19:%.*]] = load <8 x half>, ptr addrspace(3) [[TMP15]], align 16
; CHECK-NEXT: [[TMP11:%.*]] = load <8 x half>, ptr addrspace(3) [[TMP8]], align 16
; CHECK-NEXT: [[TMP12:%.*]] = fadd <8 x half> [[TMP9]], [[TMP10]]
; CHECK-NEXT: [[TMP22:%.*]] = fadd <8 x half> [[TMP17]], [[TMP18]]
; CHECK-NEXT: [[TMP23:%.*]] = fadd <8 x half> [[TMP19]], [[TMP11]]
; CHECK-NEXT: [[TMP24:%.*]] = fadd <8 x half> [[TMP12]], [[TMP22]]
; CHECK-NEXT: [[TMP13:%.*]] = fadd <8 x half> [[TMP23]], [[TMP24]]
; CHECK-NEXT: store <8 x half> [[TMP13]], ptr addrspace(3) [[TMP1]], align 16
; CHECK-NEXT: ret void
;
entry:
%2 = select i1 %0, i32 0, i32 288
%3 = xor i32 %2, 32 ; Base
%4 = xor i32 %2, 2336 ; Not disjoint
%5 = xor i32 %2, 4128 ; Disjoint
%6 = xor i32 %2, 8224 ; Disjoint
%7 = xor i32 %2, 8480 ; Not disjoint
%8 = xor i32 %2, 12320 ; Disjoint
%9 = getelementptr half, ptr addrspace(3) %1, i32 %3
%10 = getelementptr half, ptr addrspace(3) %1, i32 %4
%11 = getelementptr half, ptr addrspace(3) %1, i32 %5
%12 = getelementptr half, ptr addrspace(3) %1, i32 %6
%13 = getelementptr half, ptr addrspace(3) %1, i32 %7
%14 = getelementptr half, ptr addrspace(3) %1, i32 %8
%15 = load <8 x half>, ptr addrspace(3) %9, align 16
%16 = load <8 x half>, ptr addrspace(3) %10, align 16
%17 = load <8 x half>, ptr addrspace(3) %11, align 16
%18 = load <8 x half>, ptr addrspace(3) %12, align 16
%19 = load <8 x half>, ptr addrspace(3) %13, align 16
%20 = load <8 x half>, ptr addrspace(3) %14, align 16
%21 = fadd <8 x half> %15, %16
%22 = fadd <8 x half> %17, %18
%23 = fadd <8 x half> %19, %20
%24 = fadd <8 x half> %21, %22
%25 = fadd <8 x half> %23, %24
store <8 x half> %25, ptr addrspace(3) %1, align 16
ret void
}