Skip to content

Add llvm.protected.field.ptr intrinsic and pre-ISel lowering. #151647

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 3 commits into
base: users/pcc/spr/main.add-llvmprotectedfieldptr-intrinsic-and-pre-isel-lowering
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 54 additions & 0 deletions llvm/docs/LangRef.rst
Original file line number Diff line number Diff line change
Expand Up @@ -31161,3 +31161,57 @@ This intrinsic is assumed to execute in the default :ref:`floating-point
environment <floatenv>` *except* for the rounding mode.
This intrinsic is not supported on all targets. Some targets may not support
all rounding modes.

'``llvm.protected.field.ptr``' Intrinsic
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Syntax:
"""""""

::

declare ptr @llvm.protected.field.ptr(ptr ptr, i64 disc, i1 use_hw_encoding)

Overview:
"""""""""

The '``llvm.protected.field.ptr``' intrinsic returns a pointer to the
storage location of a pointer that has special properties as described
below.

Arguments:
""""""""""

The first argument is the pointer specifying the location to store the
pointer. The second argument is the discriminator, which is used as an
input for the pointer encoding. The third argument specifies whether to
use a target-specific mechanism to encode the pointer.

Semantics:
""""""""""

This intrinsic returns a pointer which may be used to store a
pointer at the specified address that is encoded using the specified
discriminator. Stores via the pointer will cause the stored pointer to be
blended with the second argument before being stored. The blend operation
shall be either a weak but cheap and target-independent operation (if
the third argument is 0) or a stronger target-specific operation (if the
third argument is 1). When loading from the pointer, the inverse operation
is done on the loaded pointer after it is loaded. Specifically, when the
third argument is 1, the pointer is signed (using pointer authentication
instructions or emulated PAC if not supported by the hardware) using
the struct address before being stored, and authenticated after being
loaded. Note that it is currently unsupported to have the third argument
be 1 on targets other than AArch64. When the third argument is 0, it is
rotated left by 16 bits and the discriminator is subtracted before being
stored, and the discriminator is added and the pointer is rotated right
by 16 bits after being loaded.

If the pointer is used otherwise than for loading or storing (e.g. its
address escapes), that will disable all blending operations using
the deactivation symbol specified in the intrinsic's operand bundle.
The deactivation symbol operand bundle is copied onto any sign and auth
intrinsics that this intrinsic is lowered into. The intent is that the
deactivation symbol represents a field identifier.

This intrinsic is used to implement structure protection.
6 changes: 6 additions & 0 deletions llvm/include/llvm/IR/Intrinsics.td
Original file line number Diff line number Diff line change
Expand Up @@ -2850,6 +2850,12 @@ def int_experimental_convergence_anchor
def int_experimental_convergence_loop
: DefaultAttrsIntrinsic<[llvm_token_ty], [], [IntrNoMem, IntrConvergent]>;

//===----------------- Structure Protection Intrinsics --------------------===//

def int_protected_field_ptr :
DefaultAttrsIntrinsic<[llvm_ptr_ty], [llvm_ptr_ty, llvm_i64_ty, llvm_i1_ty],
[IntrNoMem, ImmArg<ArgIndex<2>>]>;

//===----------------------------------------------------------------------===//
// Target-specific intrinsics
//===----------------------------------------------------------------------===//
Expand Down
199 changes: 199 additions & 0 deletions llvm/lib/CodeGen/PreISelIntrinsicLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,11 @@
#include "llvm/CodeGen/TargetLowering.h"
#include "llvm/CodeGen/TargetPassConfig.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/GlobalValue.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/Metadata.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/RuntimeLibcalls.h"
#include "llvm/IR/Type.h"
Expand All @@ -37,6 +39,8 @@
#include "llvm/Transforms/Utils/LowerMemIntrinsics.h"
#include "llvm/Transforms/Utils/LowerVectorIntrinsics.h"

#include <set>
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do not use std::set unless you actually need an ordered set. Based on usage, you want SmallPtrSet.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done


using namespace llvm;

/// Threshold to leave statically sized memory intrinsic calls. Calls of known
Expand Down Expand Up @@ -461,6 +465,198 @@ bool PreISelIntrinsicLowering::expandMemIntrinsicUses(
return Changed;
}

namespace {

enum class PointerEncoding {
Rotate,
PACCopyable,
PACNonCopyable,
};

bool expandProtectedFieldPtr(Function &Intr) {
Module &M = *Intr.getParent();

std::set<GlobalValue *> DSsToDeactivate;
std::set<Instruction *> LoadsStores;

Type *Int8Ty = Type::getInt8Ty(M.getContext());
Type *Int64Ty = Type::getInt64Ty(M.getContext());
PointerType *PtrTy = PointerType::get(M.getContext(), 0);

Function *SignIntr =
Intrinsic::getOrInsertDeclaration(&M, Intrinsic::ptrauth_sign, {});
Function *AuthIntr =
Intrinsic::getOrInsertDeclaration(&M, Intrinsic::ptrauth_auth, {});

auto *EmuFnTy = FunctionType::get(Int64Ty, {Int64Ty, Int64Ty}, false);
FunctionCallee EmuSignIntr = M.getOrInsertFunction("__emupac_pacda", EmuFnTy);
FunctionCallee EmuAuthIntr = M.getOrInsertFunction("__emupac_autda", EmuFnTy);

auto CreateSign = [&](IRBuilder<> &B, Value *Val, Value *Disc,
OperandBundleDef DSBundle) {
Function *F = B.GetInsertBlock()->getParent();
Attribute FSAttr = F->getFnAttribute("target-features");
if (FSAttr.isValid() && FSAttr.getValueAsString().contains("+pauth"))
return B.CreateCall(SignIntr, {Val, B.getInt32(2), Disc}, DSBundle);
return B.CreateCall(EmuSignIntr, {Val, Disc}, DSBundle);
};

auto CreateAuth = [&](IRBuilder<> &B, Value *Val, Value *Disc,
OperandBundleDef DSBundle) {
Function *F = B.GetInsertBlock()->getParent();
Attribute FSAttr = F->getFnAttribute("target-features");
if (FSAttr.isValid() && FSAttr.getValueAsString().contains("+pauth"))
return B.CreateCall(AuthIntr, {Val, B.getInt32(2), Disc}, DSBundle);
return B.CreateCall(EmuAuthIntr, {Val, Disc}, DSBundle);
};

auto GetDeactivationSymbol = [&](CallInst *Call) -> GlobalValue * {
if (auto Bundle =
Call->getOperandBundle(LLVMContext::OB_deactivation_symbol))
return cast<GlobalValue>(Bundle->Inputs[0]);
return nullptr;
};

for (User *U : Intr.users()) {
auto *Call = cast<CallInst>(U);
auto *DS = GetDeactivationSymbol(Call);
std::set<PHINode *> VisitedPhis;

std::function<void(Instruction *)> FindLoadsStores;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do not use std::function for recursion. Make this a separate static function instead.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Was able to remove the recursion after removing phi handling, so this is just inline now.

FindLoadsStores = [&](Instruction *I) {
for (Use &U : I->uses()) {
if (auto *LI = dyn_cast<LoadInst>(U.getUser())) {
if (isa<PointerType>(LI->getType())) {
LoadsStores.insert(LI);
continue;
}
}
if (auto *SI = dyn_cast<StoreInst>(U.getUser())) {
if (U.getOperandNo() == 1 &&
isa<PointerType>(SI->getValueOperand()->getType())) {
LoadsStores.insert(SI);
continue;
}
}
if (auto *P = dyn_cast<PHINode>(U.getUser())) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Phi handling not tested?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Right. I checked whether this phi handling is actually necessary given #151649 (by building Fleetbench with/without the phi code removed and checking the number of defined deactivation symbols), and it turned out not to be, so I removed it, which allowed the code to be simplified significantly.

if (VisitedPhis.insert(P).second)
FindLoadsStores(P);
continue;
}
// Comparisons against null cannot be used to recover the original
// pointer so we allow them.
if (auto *CI = dyn_cast<ICmpInst>(U.getUser())) {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also noticed that this is missing a test, added.

if (auto *Op = dyn_cast<Constant>(CI->getOperand(0)))
if (Op->isNullValue())
continue;
if (auto *Op = dyn_cast<Constant>(CI->getOperand(1)))
if (Op->isNullValue())
continue;
}
if (DS)
DSsToDeactivate.insert(DS);
}
};

FindLoadsStores(Call);
}

for (Instruction *I : LoadsStores) {
std::set<Value *> Pointers;
std::set<Value *> Discs;
std::set<GlobalValue *> DSs;
std::set<PHINode *> VisitedPhis;
bool UseHWEncoding = false;

std::function<void(Value *)> FindFields;
FindFields = [&](Value *V) {
if (auto *Call = dyn_cast<CallInst>(V)) {
if (Call->getCalledOperand() == &Intr) {
Pointers.insert(Call->getArgOperand(0));
Discs.insert(Call->getArgOperand(1));
if (cast<ConstantInt>(Call->getArgOperand(2))->getZExtValue())
UseHWEncoding = true;
DSs.insert(GetDeactivationSymbol(Call));
return;
}
}
if (auto *P = dyn_cast<PHINode>(V)) {
if (VisitedPhis.insert(P).second)
for (Value *V : P->incoming_values())
FindFields(V);
return;
}
Pointers.insert(nullptr);
};
FindFields(isa<StoreInst>(I) ? cast<StoreInst>(I)->getPointerOperand()
: cast<LoadInst>(I)->getPointerOperand());
if (Pointers.size() != 1 || Discs.size() != 1 || DSs.size() != 1) {
for (GlobalValue *DS : DSs)
if (DS)
DSsToDeactivate.insert(DS);
continue;
}

GlobalValue *DS = *DSs.begin();
OperandBundleDef DSBundle("deactivation-symbol", DS);

if (auto *LI = dyn_cast<LoadInst>(I)) {
IRBuilder<> B(LI->getNextNode());
auto *LIInt = cast<Instruction>(B.CreatePtrToInt(LI, B.getInt64Ty()));
Value *Auth;
if (UseHWEncoding) {
Auth = CreateAuth(B, LIInt, *Discs.begin(), DSBundle);
} else {
Auth = B.CreateAdd(LIInt, *Discs.begin());
Auth = B.CreateIntrinsic(
Auth->getType(), Intrinsic::fshr,
{Auth, Auth, ConstantInt::get(Auth->getType(), 16)});
}
LI->replaceAllUsesWith(B.CreateIntToPtr(Auth, B.getPtrTy()));
LIInt->setOperand(0, LI);
} else if (auto *SI = dyn_cast<StoreInst>(I)) {
IRBuilder<> B(SI);
auto *SIValInt =
B.CreatePtrToInt(SI->getValueOperand(), B.getInt64Ty());
Value *Sign;
if (UseHWEncoding) {
Sign = CreateSign(B, SIValInt, *Discs.begin(), DSBundle);
} else {
Sign = B.CreateIntrinsic(
SIValInt->getType(), Intrinsic::fshl,
{SIValInt, SIValInt, ConstantInt::get(SIValInt->getType(), 16)});
}
SI->setOperand(0, B.CreateIntToPtr(Sign, B.getPtrTy()));
}
}

for (User *U : llvm::make_early_inc_range(Intr.users())) {
auto *Call = cast<CallInst>(U);
auto *Pointer = Call->getArgOperand(0);

Call->replaceAllUsesWith(Pointer);
Call->eraseFromParent();
}

if (!DSsToDeactivate.empty()) {
Constant *Nop =
ConstantExpr::getIntToPtr(ConstantInt::get(Int64Ty, 0xd503201f), PtrTy);
for (GlobalValue *OldDS : DSsToDeactivate) {
GlobalValue *DS = GlobalAlias::create(
Int8Ty, 0, GlobalValue::ExternalLinkage, OldDS->getName(), Nop, &M);
DS->setVisibility(GlobalValue::HiddenVisibility);
if (OldDS) {
DS->takeName(OldDS);
OldDS->replaceAllUsesWith(DS);
OldDS->eraseFromParent();
}
}
}
return true;
}

}

bool PreISelIntrinsicLowering::lowerIntrinsics(Module &M) const {
// Map unique constants to globals.
DenseMap<Constant *, GlobalVariable *> CMap;
Expand Down Expand Up @@ -598,6 +794,9 @@ bool PreISelIntrinsicLowering::lowerIntrinsics(Module &M) const {
return lowerUnaryVectorIntrinsicAsLoop(M, CI);
});
break;
case Intrinsic::protected_field_ptr:
Changed |= expandProtectedFieldPtr(F);
break;
}
}
return Changed;
Expand Down
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Use update_test_checks.py.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I tried it, but it wanted to delete the checks for the deactivation symbols. Do you know if there is a way to prevent this?

diff --git a/llvm/test/Transforms/PreISelIntrinsicLowering/protected-field-pointer.ll b/llvm/test/Transforms/PreISelIntrinsicLowering/protected-field-pointer.ll
index cb7e695bfd12..0b17b544e4d8 100644
--- a/llvm/test/Transforms/PreISelIntrinsicLowering/protected-field-pointer.ll
+++ b/llvm/test/Transforms/PreISelIntrinsicLowering/protected-field-pointer.ll
@@ -1,89 +1,143 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 5
 ; RUN: opt -passes=pre-isel-intrinsic-lowering -S < %s | FileCheck --check-prefixes=CHECK,NOPAUTH %s
 ; RUN: opt -passes=pre-isel-intrinsic-lowering -mattr=+pauth -S < %s | FileCheck --check-prefixes=CHECK,PAUTH %s
 
 target triple = "aarch64-unknown-linux-gnu"
 
-; CHECK: @ds1 = external global i8
 @ds1 = external global i8
-; CHECK: @ds2 = external global i8
 @ds2 = external global i8
-; CHECK: @ds3 = external global i8
 @ds3 = external global i8
-; CHECK: @ds4 = external global i8
 @ds4 = external global i8
-; CHECK: @ds5 = external global i8
 @ds5 = external global i8
-; CHECK: @ds6 = hidden alias i8, inttoptr (i64 3573751839 to ptr)
 @ds6 = external global i8
 

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Passing --check-globals should do that.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That did it, thanks.

Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
; RUN: opt -passes=pre-isel-intrinsic-lowering -S < %s | FileCheck --check-prefixes=CHECK,NOPAUTH %s
; RUN: opt -passes=pre-isel-intrinsic-lowering -mattr=+pauth -S < %s | FileCheck --check-prefixes=CHECK,PAUTH %s

target triple = "aarch64-unknown-linux-gnu"

; CHECK: @ds1 = external global i8
@ds1 = external global i8
; CHECK: @ds2 = external global i8
@ds2 = external global i8
; CHECK: @ds3 = hidden alias i8, inttoptr (i64 3573751839 to ptr)
@ds3 = external global i8

; CHECK: define ptr @f1
define ptr @f1(ptr %ptrptr) {
; CHECK: %ptr = load ptr, ptr %ptrptr, align 8
; CHECK: %1 = ptrtoint ptr %ptr to i64
; NOPAUTH: %2 = call i64 @__emupac_autda(i64 %1, i64 1) [ "deactivation-symbol"(ptr @ds1) ]
; PAUTH: %2 = call i64 @llvm.ptrauth.auth(i64 %1, i32 2, i64 1) [ "deactivation-symbol"(ptr @ds1) ]
; CHECK: %3 = inttoptr i64 %2 to ptr
; CHECK: ret ptr %3
%protptrptr = call ptr @llvm.protected.field.ptr(ptr %ptrptr, i64 1, i1 true) [ "deactivation-symbol"(ptr @ds1) ]
%ptr = load ptr, ptr %protptrptr
ret ptr %ptr
}

; CHECK: define void @f2
define void @f2(ptr %ptrptr, ptr %ptr) {
; CHECK: %1 = ptrtoint ptr %ptr to i64
; NOPAUTH: %2 = call i64 @__emupac_pacda(i64 %1, i64 2) [ "deactivation-symbol"(ptr @ds2) ]
; PAUTH: %2 = call i64 @llvm.ptrauth.sign(i64 %1, i32 2, i64 2) [ "deactivation-symbol"(ptr @ds2) ]
; CHECK: %3 = inttoptr i64 %2 to ptr
; CHECK: store ptr %3, ptr %ptrptr, align 8
; CHECK: ret void
%protptrptr = call ptr @llvm.protected.field.ptr(ptr %ptrptr, i64 2, i1 true) [ "deactivation-symbol"(ptr @ds2) ]
store ptr %ptr, ptr %protptrptr
ret void
}

; CHECK: define ptr @f3
define ptr @f3(ptr %ptrptr) {
; CHECK: ret ptr %ptrptr
%protptrptr = call ptr @llvm.protected.field.ptr(ptr %ptrptr, i64 3, i1 true) [ "deactivation-symbol"(ptr @ds3) ]
ret ptr %protptrptr
}

declare ptr @llvm.protected.field.ptr(ptr, i64, i1 immarg)
Loading
You are viewing a condensed version of this merge commit. You can view the full changes here.