Skip to content
Open
33 changes: 33 additions & 0 deletions llvm/include/llvm/BinaryFormat/DXContainer.h
Original file line number Diff line number Diff line change
Expand Up @@ -214,6 +214,39 @@ enum class ShaderVisibility : uint32_t {
#include "DXContainerConstants.def"
};

inline dxil::ResourceClass
toResourceClass(dxbc::DescriptorRangeType RangeType) {
using namespace dxbc;
switch (RangeType) {
case DescriptorRangeType::SRV:
return dxil::ResourceClass::SRV;
case DescriptorRangeType::UAV:
return dxil::ResourceClass::UAV;
case DescriptorRangeType::CBV:
return dxil::ResourceClass::CBuffer;
case DescriptorRangeType::Sampler:
return dxil::ResourceClass::Sampler;
}
llvm_unreachable("Unknown DescriptorRangeType");
}

inline dxil::ResourceClass toResourceClass(dxbc::RootParameterType Type) {
using namespace dxbc;
switch (Type) {
case RootParameterType::Constants32Bit:
return dxil::ResourceClass::CBuffer;
case RootParameterType::SRV:
return dxil::ResourceClass::SRV;
case RootParameterType::UAV:
return dxil::ResourceClass::UAV;
case RootParameterType::CBV:
return dxil::ResourceClass::CBuffer;
case dxbc::RootParameterType::DescriptorTable:
llvm_unreachable("DescriptorTable is not convertible to ResourceClass");
}
llvm_unreachable("Unknown RootParameterType");
}

LLVM_ABI ArrayRef<EnumEntry<ShaderVisibility>> getShaderVisibility();

#define SHADER_VISIBILITY(Val, Enum) \
Expand Down
90 changes: 90 additions & 0 deletions llvm/include/llvm/Frontend/HLSL/RootSignatureMetadata.h
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,96 @@ class RootSignatureValidationError
}
};

class OffsetOverflowError : public ErrorInfo<OffsetOverflowError> {
public:
static char ID;
dxbc::DescriptorRangeType Type;
uint32_t Register;
uint32_t Space;

OffsetOverflowError(dxbc::DescriptorRangeType Type, uint32_t Register,
uint32_t Space)
: Type(Type), Register(Register), Space(Space) {}

void log(raw_ostream &OS) const override {
OS << "Cannot append range with implicit lower bound after an unbounded "
"range "
<< getResourceClassName(toResourceClass(Type))
<< "(register=" << Register << ", space=" << Space << ").";
}

std::error_code convertToErrorCode() const override {
return llvm::inconvertibleErrorCode();
}
};

class ShaderRegisterOverflowError
: public ErrorInfo<ShaderRegisterOverflowError> {
public:
static char ID;
dxbc::DescriptorRangeType Type;
uint32_t Register;
uint32_t Space;

ShaderRegisterOverflowError(dxbc::DescriptorRangeType Type, uint32_t Register,
uint32_t Space)
: Type(Type), Register(Register), Space(Space) {}

void log(raw_ostream &OS) const override {
OS << "Overflow for shader register range: "
<< getResourceClassName(toResourceClass(Type))
<< "(register=" << Register << ", space=" << Space << ").";
}

std::error_code convertToErrorCode() const override {
return llvm::inconvertibleErrorCode();
}
};

class DescriptorRangeOverflowError
: public ErrorInfo<DescriptorRangeOverflowError> {
public:
static char ID;
dxbc::DescriptorRangeType Type;
uint32_t Register;
uint32_t Space;

DescriptorRangeOverflowError(dxbc::DescriptorRangeType Type,
uint32_t Register, uint32_t Space)
: Type(Type), Register(Register), Space(Space) {}

void log(raw_ostream &OS) const override {
OS << "Overflow for descriptor range: "
<< getResourceClassName(toResourceClass(Type))
<< "(register=" << Register << ", space=" << Space << ").";
}

std::error_code convertToErrorCode() const override {
return llvm::inconvertibleErrorCode();
}
};

class TableSamplerMixinError : public ErrorInfo<TableSamplerMixinError> {
public:
static char ID;
dxbc::DescriptorRangeType Type;
uint32_t Location;

TableSamplerMixinError(dxbc::DescriptorRangeType Type, uint32_t Location)
: Type(Type), Location(Location) {}

void log(raw_ostream &OS) const override {
OS << "Samplers cannot be mixed with other "
<< "resource types in a descriptor table, "
<< getResourceClassName(toResourceClass(Type))
<< "(location=" << Location << ")";
}

std::error_code convertToErrorCode() const override {
return llvm::inconvertibleErrorCode();
}
};

class GenericRSMetadataError : public ErrorInfo<GenericRSMetadataError> {
public:
LLVM_ABI static char ID;
Expand Down
7 changes: 7 additions & 0 deletions llvm/include/llvm/Frontend/HLSL/RootSignatureValidations.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#include "llvm/ADT/IntervalMap.h"
#include "llvm/Frontend/HLSL/HLSLRootSignature.h"
#include "llvm/Support/Compiler.h"
#include <cstdint>

namespace llvm {
namespace hlsl {
Expand All @@ -40,6 +41,12 @@ LLVM_ABI bool verifyMaxAnisotropy(uint32_t MaxAnisotropy);
LLVM_ABI bool verifyComparisonFunc(uint32_t ComparisonFunc);
LLVM_ABI bool verifyBorderColor(uint32_t BorderColor);
LLVM_ABI bool verifyLOD(float LOD);
LLVM_ABI bool verifyRegisterOverflow(uint64_t Register,
uint32_t NumDescriptors);
LLVM_ABI uint64_t updateAppendingRegister(uint64_t AppendingRegisterRegister,
uint64_t NumDescriptors,
uint64_t Offset);
LLVM_ABI bool verifyOffsetOverflow(uint64_t Register);

} // namespace rootsig
} // namespace hlsl
Expand Down
71 changes: 71 additions & 0 deletions llvm/lib/Frontend/HLSL/RootSignatureMetadata.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,11 @@ namespace rootsig {
char GenericRSMetadataError::ID;
char InvalidRSMetadataFormat::ID;
char InvalidRSMetadataValue::ID;
char TableSamplerMixinError::ID;
char ShaderRegisterOverflowError::ID;
char OffsetOverflowError::ID;
char DescriptorRangeOverflowError::ID;

template <typename T> char RootSignatureValidationError<T>::ID;

static std::optional<uint32_t> extractMdIntValue(MDNode *Node,
Expand Down Expand Up @@ -513,6 +518,62 @@ Error MetadataParser::parseRootSignatureElement(mcdxbc::RootSignatureDesc &RSD,
llvm_unreachable("Unhandled RootSignatureElementKind enum.");
}

Error validateDescriptorTableSamplerMixin(mcdxbc::DescriptorTable Table,
uint32_t Location) {
bool HasSampler = false;
bool HasOtherRangeType = false;
dxbc::DescriptorRangeType OtherRangeType;

for (const dxbc::RTS0::v2::DescriptorRange &Range : Table.Ranges) {
dxbc::DescriptorRangeType RangeType =
static_cast<dxbc::DescriptorRangeType>(Range.RangeType);

if (RangeType == dxbc::DescriptorRangeType::Sampler) {
HasSampler = true;
} else {
HasOtherRangeType = true;
OtherRangeType = RangeType;
}
}

// Samplers cannot be mixed with other resources in a descriptor table.
if (HasSampler && HasOtherRangeType)
return make_error<TableSamplerMixinError>(OtherRangeType, Location);
return Error::success();
}

Error validateDescriptorTableRegisterOverflow(mcdxbc::DescriptorTable Table,
uint32_t Location) {
uint64_t AppendingRegister = 0;

for (const dxbc::RTS0::v2::DescriptorRange &Range : Table.Ranges) {
dxbc::DescriptorRangeType RangeType =
static_cast<dxbc::DescriptorRangeType>(Range.RangeType);

uint64_t StartSlot = AppendingRegister;
if (Range.OffsetInDescriptorsFromTableStart != ~0U)
StartSlot = Range.OffsetInDescriptorsFromTableStart;

if (verifyOffsetOverflow(StartSlot))
return make_error<OffsetOverflowError>(
RangeType, Range.BaseShaderRegister, Range.RegisterSpace);

if (verifyRegisterOverflow(Range.BaseShaderRegister, Range.NumDescriptors))
return make_error<ShaderRegisterOverflowError>(
RangeType, Range.BaseShaderRegister, Range.RegisterSpace);

if (verifyRegisterOverflow(StartSlot, Range.NumDescriptors))
return make_error<DescriptorRangeOverflowError>(
RangeType, Range.BaseShaderRegister, Range.RegisterSpace);

AppendingRegister =
updateAppendingRegister(StartSlot, Range.NumDescriptors,
Range.OffsetInDescriptorsFromTableStart);
}

return Error::success();
}

Error MetadataParser::validateRootSignature(
const mcdxbc::RootSignatureDesc &RSD) {
Error DeferredErrs = Error::success();
Expand Down Expand Up @@ -597,6 +658,16 @@ Error MetadataParser::validateRootSignature(
joinErrors(std::move(DeferredErrs),
make_error<RootSignatureValidationError<uint32_t>>(
"DescriptorFlag", Range.Flags));

if (Error Err =
validateDescriptorTableSamplerMixin(Table, Info.Location)) {
DeferredErrs = joinErrors(std::move(DeferredErrs), std::move(Err));
}

if (Error Err =
validateDescriptorTableRegisterOverflow(Table, Info.Location)) {
DeferredErrs = joinErrors(std::move(DeferredErrs), std::move(Err));
}
}
break;
}
Expand Down
25 changes: 25 additions & 0 deletions llvm/lib/Frontend/HLSL/RootSignatureValidations.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,31 @@ bool verifyBorderColor(uint32_t BorderColor) {

bool verifyLOD(float LOD) { return !std::isnan(LOD); }

bool verifyOffsetOverflow(uint64_t Register) {
if (Register > ~0U)
return true;
return false;
}

bool verifyRegisterOverflow(uint64_t Register, uint32_t NumDescriptors) {
if (NumDescriptors == ~0U)
return false;

uint64_t UpperBound =
(uint64_t)Register + (uint64_t)NumDescriptors - (uint64_t)1U;
if (UpperBound > ~0U)
return true;

return false;
}

uint64_t updateAppendingRegister(uint64_t AppendingRegister,
uint64_t NumDescriptors, uint64_t Offset) {
if (NumDescriptors == ~0U)
return (uint64_t)~0U + (uint64_t)1ULL;
return Offset == ~0U ? AppendingRegister + NumDescriptors
: Offset + NumDescriptors;
}
} // namespace rootsig
} // namespace hlsl
} // namespace llvm
56 changes: 13 additions & 43 deletions llvm/lib/Target/DirectX/DXILPostOptimizationValidation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,38 +25,6 @@
using namespace llvm;
using namespace llvm::dxil;

static ResourceClass toResourceClass(dxbc::DescriptorRangeType RangeType) {
using namespace dxbc;
switch (RangeType) {
case DescriptorRangeType::SRV:
return ResourceClass::SRV;
case DescriptorRangeType::UAV:
return ResourceClass::UAV;
case DescriptorRangeType::CBV:
return ResourceClass::CBuffer;
case DescriptorRangeType::Sampler:
return ResourceClass::Sampler;
}
llvm_unreachable("Unknown DescriptorRangeType");
}

static ResourceClass toResourceClass(dxbc::RootParameterType Type) {
using namespace dxbc;
switch (Type) {
case RootParameterType::Constants32Bit:
return ResourceClass::CBuffer;
case RootParameterType::SRV:
return ResourceClass::SRV;
case RootParameterType::UAV:
return ResourceClass::UAV;
case RootParameterType::CBV:
return ResourceClass::CBuffer;
case dxbc::RootParameterType::DescriptorTable:
llvm_unreachable("DescriptorTable is not convertible to ResourceClass");
}
llvm_unreachable("Unknown RootParameterType");
}

static void reportInvalidDirection(Module &M, DXILResourceMap &DRM) {
for (const auto &UAV : DRM.uavs()) {
if (UAV.CounterDirection != ResourceCounterDirection::Invalid)
Expand Down Expand Up @@ -237,10 +205,10 @@ getRootDescriptorsBindingInfo(const mcdxbc::RootSignatureDesc &RSD,
return RDs;
}

static void validateRootSignature(Module &M,
const mcdxbc::RootSignatureDesc &RSD,
dxil::ModuleMetadataInfo &MMI,
DXILResourceMap &DRM) {
static void validateRootSignatureBindings(Module &M,
const mcdxbc::RootSignatureDesc &RSD,
dxil::ModuleMetadataInfo &MMI,
DXILResourceMap &DRM) {

hlsl::BindingInfoBuilder Builder;
dxbc::ShaderVisibility Visibility = tripleToVisibility(MMI.ShaderProfile);
Expand Down Expand Up @@ -268,10 +236,11 @@ static void validateRootSignature(Module &M,
case dxbc::RootParameterType::CBV: {
dxbc::RTS0::v2::RootDescriptor Desc =
RSD.ParametersContainer.getRootDescriptor(ParamInfo.Location);
Builder.trackBinding(toResourceClass(static_cast<dxbc::RootParameterType>(
ParamInfo.Header.ParameterType)),
Desc.RegisterSpace, Desc.ShaderRegister,
Desc.ShaderRegister, &ParamInfo);
Builder.trackBinding(
dxbc::toResourceClass(static_cast<dxbc::RootParameterType>(
ParamInfo.Header.ParameterType)),
Desc.RegisterSpace, Desc.ShaderRegister, Desc.ShaderRegister,
&ParamInfo);

break;
}
Expand All @@ -285,7 +254,7 @@ static void validateRootSignature(Module &M,
? Range.BaseShaderRegister
: Range.BaseShaderRegister + Range.NumDescriptors - 1;
Builder.trackBinding(
toResourceClass(
dxbc::toResourceClass(
static_cast<dxbc::DescriptorRangeType>(Range.RangeType)),
Range.RegisterSpace, Range.BaseShaderRegister, UpperBound,
&ParamInfo);
Expand Down Expand Up @@ -346,8 +315,9 @@ static void reportErrors(Module &M, DXILResourceMap &DRM,
assert(!DRBI.hasImplicitBinding() && "implicit bindings should be handled in "
"DXILResourceImplicitBinding pass");

if (mcdxbc::RootSignatureDesc *RSD = getRootSignature(RSBI, MMI))
validateRootSignature(M, *RSD, MMI, DRM);
if (mcdxbc::RootSignatureDesc *RSD = getRootSignature(RSBI, MMI)) {
validateRootSignatureBindings(M, *RSD, MMI, DRM);
}
}

PreservedAnalyses
Expand Down
Loading