-
Notifications
You must be signed in to change notification settings - Fork 14.8k
Add llvm.protected.field.ptr intrinsic and pre-ISel lowering. #151647
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: users/pcc/spr/main.add-llvmprotectedfieldptr-intrinsic-and-pre-isel-lowering
Are you sure you want to change the base?
Changes from 1 commit
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -21,9 +21,11 @@ | |
#include "llvm/CodeGen/TargetLowering.h" | ||
#include "llvm/CodeGen/TargetPassConfig.h" | ||
#include "llvm/IR/Function.h" | ||
#include "llvm/IR/GlobalValue.h" | ||
#include "llvm/IR/IRBuilder.h" | ||
#include "llvm/IR/Instructions.h" | ||
#include "llvm/IR/IntrinsicInst.h" | ||
#include "llvm/IR/Metadata.h" | ||
#include "llvm/IR/Module.h" | ||
#include "llvm/IR/RuntimeLibcalls.h" | ||
#include "llvm/IR/Type.h" | ||
|
@@ -37,6 +39,8 @@ | |
#include "llvm/Transforms/Utils/LowerMemIntrinsics.h" | ||
#include "llvm/Transforms/Utils/LowerVectorIntrinsics.h" | ||
|
||
#include <set> | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Do not use There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Done |
||
|
||
using namespace llvm; | ||
|
||
/// Threshold to leave statically sized memory intrinsic calls. Calls of known | ||
|
@@ -461,6 +465,198 @@ bool PreISelIntrinsicLowering::expandMemIntrinsicUses( | |
return Changed; | ||
} | ||
|
||
namespace { | ||
|
||
enum class PointerEncoding { | ||
Rotate, | ||
PACCopyable, | ||
PACNonCopyable, | ||
}; | ||
|
||
bool expandProtectedFieldPtr(Function &Intr) { | ||
Module &M = *Intr.getParent(); | ||
|
||
std::set<GlobalValue *> DSsToDeactivate; | ||
std::set<Instruction *> LoadsStores; | ||
|
||
Type *Int8Ty = Type::getInt8Ty(M.getContext()); | ||
Type *Int64Ty = Type::getInt64Ty(M.getContext()); | ||
PointerType *PtrTy = PointerType::get(M.getContext(), 0); | ||
|
||
Function *SignIntr = | ||
Intrinsic::getOrInsertDeclaration(&M, Intrinsic::ptrauth_sign, {}); | ||
Function *AuthIntr = | ||
Intrinsic::getOrInsertDeclaration(&M, Intrinsic::ptrauth_auth, {}); | ||
|
||
auto *EmuFnTy = FunctionType::get(Int64Ty, {Int64Ty, Int64Ty}, false); | ||
FunctionCallee EmuSignIntr = M.getOrInsertFunction("__emupac_pacda", EmuFnTy); | ||
FunctionCallee EmuAuthIntr = M.getOrInsertFunction("__emupac_autda", EmuFnTy); | ||
|
||
auto CreateSign = [&](IRBuilder<> &B, Value *Val, Value *Disc, | ||
OperandBundleDef DSBundle) { | ||
Function *F = B.GetInsertBlock()->getParent(); | ||
Attribute FSAttr = F->getFnAttribute("target-features"); | ||
if (FSAttr.isValid() && FSAttr.getValueAsString().contains("+pauth")) | ||
return B.CreateCall(SignIntr, {Val, B.getInt32(2), Disc}, DSBundle); | ||
return B.CreateCall(EmuSignIntr, {Val, Disc}, DSBundle); | ||
}; | ||
|
||
auto CreateAuth = [&](IRBuilder<> &B, Value *Val, Value *Disc, | ||
OperandBundleDef DSBundle) { | ||
Function *F = B.GetInsertBlock()->getParent(); | ||
Attribute FSAttr = F->getFnAttribute("target-features"); | ||
if (FSAttr.isValid() && FSAttr.getValueAsString().contains("+pauth")) | ||
return B.CreateCall(AuthIntr, {Val, B.getInt32(2), Disc}, DSBundle); | ||
return B.CreateCall(EmuAuthIntr, {Val, Disc}, DSBundle); | ||
}; | ||
|
||
auto GetDeactivationSymbol = [&](CallInst *Call) -> GlobalValue * { | ||
if (auto Bundle = | ||
Call->getOperandBundle(LLVMContext::OB_deactivation_symbol)) | ||
return cast<GlobalValue>(Bundle->Inputs[0]); | ||
return nullptr; | ||
}; | ||
|
||
for (User *U : Intr.users()) { | ||
auto *Call = cast<CallInst>(U); | ||
auto *DS = GetDeactivationSymbol(Call); | ||
std::set<PHINode *> VisitedPhis; | ||
|
||
std::function<void(Instruction *)> FindLoadsStores; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Do not use std::function for recursion. Make this a separate static function instead. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Was able to remove the recursion after removing phi handling, so this is just inline now. |
||
FindLoadsStores = [&](Instruction *I) { | ||
for (Use &U : I->uses()) { | ||
if (auto *LI = dyn_cast<LoadInst>(U.getUser())) { | ||
if (isa<PointerType>(LI->getType())) { | ||
LoadsStores.insert(LI); | ||
continue; | ||
} | ||
} | ||
if (auto *SI = dyn_cast<StoreInst>(U.getUser())) { | ||
if (U.getOperandNo() == 1 && | ||
isa<PointerType>(SI->getValueOperand()->getType())) { | ||
LoadsStores.insert(SI); | ||
continue; | ||
} | ||
} | ||
if (auto *P = dyn_cast<PHINode>(U.getUser())) { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Phi handling not tested? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Right. I checked whether this phi handling is actually necessary given #151649 (by building Fleetbench with/without the phi code removed and checking the number of defined deactivation symbols), and it turned out not to be, so I removed it, which allowed the code to be simplified significantly. |
||
if (VisitedPhis.insert(P).second) | ||
FindLoadsStores(P); | ||
continue; | ||
} | ||
// Comparisons against null cannot be used to recover the original | ||
// pointer so we allow them. | ||
if (auto *CI = dyn_cast<ICmpInst>(U.getUser())) { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Also noticed that this is missing a test, added. |
||
if (auto *Op = dyn_cast<Constant>(CI->getOperand(0))) | ||
if (Op->isNullValue()) | ||
continue; | ||
if (auto *Op = dyn_cast<Constant>(CI->getOperand(1))) | ||
if (Op->isNullValue()) | ||
continue; | ||
} | ||
if (DS) | ||
DSsToDeactivate.insert(DS); | ||
} | ||
}; | ||
|
||
FindLoadsStores(Call); | ||
} | ||
|
||
for (Instruction *I : LoadsStores) { | ||
std::set<Value *> Pointers; | ||
std::set<Value *> Discs; | ||
std::set<GlobalValue *> DSs; | ||
std::set<PHINode *> VisitedPhis; | ||
bool UseHWEncoding = false; | ||
|
||
std::function<void(Value *)> FindFields; | ||
FindFields = [&](Value *V) { | ||
if (auto *Call = dyn_cast<CallInst>(V)) { | ||
if (Call->getCalledOperand() == &Intr) { | ||
Pointers.insert(Call->getArgOperand(0)); | ||
Discs.insert(Call->getArgOperand(1)); | ||
if (cast<ConstantInt>(Call->getArgOperand(2))->getZExtValue()) | ||
UseHWEncoding = true; | ||
DSs.insert(GetDeactivationSymbol(Call)); | ||
return; | ||
} | ||
} | ||
if (auto *P = dyn_cast<PHINode>(V)) { | ||
if (VisitedPhis.insert(P).second) | ||
for (Value *V : P->incoming_values()) | ||
FindFields(V); | ||
return; | ||
} | ||
Pointers.insert(nullptr); | ||
}; | ||
FindFields(isa<StoreInst>(I) ? cast<StoreInst>(I)->getPointerOperand() | ||
: cast<LoadInst>(I)->getPointerOperand()); | ||
if (Pointers.size() != 1 || Discs.size() != 1 || DSs.size() != 1) { | ||
for (GlobalValue *DS : DSs) | ||
if (DS) | ||
DSsToDeactivate.insert(DS); | ||
continue; | ||
} | ||
|
||
GlobalValue *DS = *DSs.begin(); | ||
OperandBundleDef DSBundle("deactivation-symbol", DS); | ||
|
||
if (auto *LI = dyn_cast<LoadInst>(I)) { | ||
IRBuilder<> B(LI->getNextNode()); | ||
auto *LIInt = cast<Instruction>(B.CreatePtrToInt(LI, B.getInt64Ty())); | ||
Value *Auth; | ||
if (UseHWEncoding) { | ||
Auth = CreateAuth(B, LIInt, *Discs.begin(), DSBundle); | ||
} else { | ||
Auth = B.CreateAdd(LIInt, *Discs.begin()); | ||
Auth = B.CreateIntrinsic( | ||
Auth->getType(), Intrinsic::fshr, | ||
{Auth, Auth, ConstantInt::get(Auth->getType(), 16)}); | ||
} | ||
LI->replaceAllUsesWith(B.CreateIntToPtr(Auth, B.getPtrTy())); | ||
LIInt->setOperand(0, LI); | ||
} else if (auto *SI = dyn_cast<StoreInst>(I)) { | ||
IRBuilder<> B(SI); | ||
auto *SIValInt = | ||
B.CreatePtrToInt(SI->getValueOperand(), B.getInt64Ty()); | ||
Value *Sign; | ||
if (UseHWEncoding) { | ||
Sign = CreateSign(B, SIValInt, *Discs.begin(), DSBundle); | ||
} else { | ||
Sign = B.CreateIntrinsic( | ||
SIValInt->getType(), Intrinsic::fshl, | ||
{SIValInt, SIValInt, ConstantInt::get(SIValInt->getType(), 16)}); | ||
} | ||
SI->setOperand(0, B.CreateIntToPtr(Sign, B.getPtrTy())); | ||
} | ||
} | ||
|
||
for (User *U : llvm::make_early_inc_range(Intr.users())) { | ||
auto *Call = cast<CallInst>(U); | ||
auto *Pointer = Call->getArgOperand(0); | ||
|
||
Call->replaceAllUsesWith(Pointer); | ||
Call->eraseFromParent(); | ||
} | ||
|
||
if (!DSsToDeactivate.empty()) { | ||
Constant *Nop = | ||
ConstantExpr::getIntToPtr(ConstantInt::get(Int64Ty, 0xd503201f), PtrTy); | ||
for (GlobalValue *OldDS : DSsToDeactivate) { | ||
GlobalValue *DS = GlobalAlias::create( | ||
Int8Ty, 0, GlobalValue::ExternalLinkage, OldDS->getName(), Nop, &M); | ||
DS->setVisibility(GlobalValue::HiddenVisibility); | ||
if (OldDS) { | ||
DS->takeName(OldDS); | ||
OldDS->replaceAllUsesWith(DS); | ||
OldDS->eraseFromParent(); | ||
} | ||
} | ||
} | ||
return true; | ||
} | ||
|
||
} | ||
|
||
bool PreISelIntrinsicLowering::lowerIntrinsics(Module &M) const { | ||
// Map unique constants to globals. | ||
DenseMap<Constant *, GlobalVariable *> CMap; | ||
|
@@ -598,6 +794,9 @@ bool PreISelIntrinsicLowering::lowerIntrinsics(Module &M) const { | |
return lowerUnaryVectorIntrinsicAsLoop(M, CI); | ||
}); | ||
break; | ||
case Intrinsic::protected_field_ptr: | ||
Changed |= expandProtectedFieldPtr(F); | ||
break; | ||
} | ||
} | ||
return Changed; | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Use update_test_checks.py. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I tried it, but it wanted to delete the checks for the deactivation symbols. Do you know if there is a way to prevent this?
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Passing There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. That did it, thanks. |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,46 @@ | ||
; RUN: opt -passes=pre-isel-intrinsic-lowering -S < %s | FileCheck --check-prefixes=CHECK,NOPAUTH %s | ||
; RUN: opt -passes=pre-isel-intrinsic-lowering -mattr=+pauth -S < %s | FileCheck --check-prefixes=CHECK,PAUTH %s | ||
|
||
target triple = "aarch64-unknown-linux-gnu" | ||
|
||
; CHECK: @ds1 = external global i8 | ||
@ds1 = external global i8 | ||
; CHECK: @ds2 = external global i8 | ||
@ds2 = external global i8 | ||
; CHECK: @ds3 = hidden alias i8, inttoptr (i64 3573751839 to ptr) | ||
@ds3 = external global i8 | ||
|
||
; CHECK: define ptr @f1 | ||
define ptr @f1(ptr %ptrptr) { | ||
; CHECK: %ptr = load ptr, ptr %ptrptr, align 8 | ||
; CHECK: %1 = ptrtoint ptr %ptr to i64 | ||
; NOPAUTH: %2 = call i64 @__emupac_autda(i64 %1, i64 1) [ "deactivation-symbol"(ptr @ds1) ] | ||
; PAUTH: %2 = call i64 @llvm.ptrauth.auth(i64 %1, i32 2, i64 1) [ "deactivation-symbol"(ptr @ds1) ] | ||
; CHECK: %3 = inttoptr i64 %2 to ptr | ||
; CHECK: ret ptr %3 | ||
%protptrptr = call ptr @llvm.protected.field.ptr(ptr %ptrptr, i64 1, i1 true) [ "deactivation-symbol"(ptr @ds1) ] | ||
nikic marked this conversation as resolved.
Show resolved
Hide resolved
|
||
%ptr = load ptr, ptr %protptrptr | ||
ret ptr %ptr | ||
} | ||
|
||
; CHECK: define void @f2 | ||
define void @f2(ptr %ptrptr, ptr %ptr) { | ||
; CHECK: %1 = ptrtoint ptr %ptr to i64 | ||
; NOPAUTH: %2 = call i64 @__emupac_pacda(i64 %1, i64 2) [ "deactivation-symbol"(ptr @ds2) ] | ||
; PAUTH: %2 = call i64 @llvm.ptrauth.sign(i64 %1, i32 2, i64 2) [ "deactivation-symbol"(ptr @ds2) ] | ||
; CHECK: %3 = inttoptr i64 %2 to ptr | ||
; CHECK: store ptr %3, ptr %ptrptr, align 8 | ||
; CHECK: ret void | ||
%protptrptr = call ptr @llvm.protected.field.ptr(ptr %ptrptr, i64 2, i1 true) [ "deactivation-symbol"(ptr @ds2) ] | ||
store ptr %ptr, ptr %protptrptr | ||
ret void | ||
} | ||
|
||
; CHECK: define ptr @f3 | ||
define ptr @f3(ptr %ptrptr) { | ||
; CHECK: ret ptr %ptrptr | ||
%protptrptr = call ptr @llvm.protected.field.ptr(ptr %ptrptr, i64 3, i1 true) [ "deactivation-symbol"(ptr @ds3) ] | ||
ret ptr %protptrptr | ||
} | ||
|
||
declare ptr @llvm.protected.field.ptr(ptr, i64, i1 immarg) |
Uh oh!
There was an error while loading. Please reload this page.