Skip to content

[HLSL] Rewrite semantics parsing #152537

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 34 additions & 0 deletions clang/include/clang/AST/Attr.h
Original file line number Diff line number Diff line change
Expand Up @@ -232,6 +232,40 @@ class HLSLAnnotationAttr : public InheritableAttr {
}
};

class HLSLSemanticAttr : public HLSLAnnotationAttr {
unsigned SemanticIndex = 0;
LLVM_PREFERRED_TYPE(bool)
unsigned SemanticIndexable : 1;
LLVM_PREFERRED_TYPE(bool)
unsigned SemanticExplicitIndex : 1;

protected:
HLSLSemanticAttr(ASTContext &Context, const AttributeCommonInfo &CommonInfo,
attr::Kind AK, bool IsLateParsed,
bool InheritEvenIfAlreadyPresent, bool SemanticIndexable)
: HLSLAnnotationAttr(Context, CommonInfo, AK, IsLateParsed,
InheritEvenIfAlreadyPresent) {
this->SemanticIndexable = SemanticIndexable;
this->SemanticExplicitIndex = false;
}

public:
bool isSemanticIndexable() const { return SemanticIndexable; }

void setSemanticIndex(unsigned SemanticIndex) {
this->SemanticIndex = SemanticIndex;
this->SemanticExplicitIndex = true;
}

unsigned getSemanticIndex() const { return SemanticIndex; }

// Implement isa/cast/dyncast/etc.
static bool classof(const Attr *A) {
return A->getKind() >= attr::FirstHLSLSemanticAttr &&
A->getKind() <= attr::LastHLSLSemanticAttr;
}
};

/// A parameter attribute which changes the argument-passing ABI rule
/// for the parameter.
class ParameterABIAttr : public InheritableParamAttr {
Expand Down
66 changes: 35 additions & 31 deletions clang/include/clang/Basic/Attr.td
Original file line number Diff line number Diff line change
Expand Up @@ -779,6 +779,16 @@ class DeclOrStmtAttr : InheritableAttr;
/// An attribute class for HLSL Annotations.
class HLSLAnnotationAttr : InheritableAttr;

class HLSLSemanticAttr<bit Indexable> : HLSLAnnotationAttr {
bit SemanticIndexable = Indexable;
int SemanticIndex = 0;
bit SemanticExplicitIndex = 0;

let Spellings = [];
let Subjects = SubjectList<[ParmVar, Field, Function]>;
let LangOpts = [HLSL];
}

/// A target-specific attribute. This class is meant to be used as a mixin
/// with InheritableAttr or Attr depending on the attribute's needs.
class TargetSpecificAttr<TargetSpec target> {
Expand Down Expand Up @@ -4873,27 +4883,6 @@ def HLSLNumThreads: InheritableAttr {
let Documentation = [NumThreadsDocs];
}

def HLSLSV_GroupThreadID: HLSLAnnotationAttr {
let Spellings = [HLSLAnnotation<"sv_groupthreadid">];
let Subjects = SubjectList<[ParmVar, Field]>;
let LangOpts = [HLSL];
let Documentation = [HLSLSV_GroupThreadIDDocs];
}

def HLSLSV_GroupID: HLSLAnnotationAttr {
let Spellings = [HLSLAnnotation<"sv_groupid">];
let Subjects = SubjectList<[ParmVar, Field]>;
let LangOpts = [HLSL];
let Documentation = [HLSLSV_GroupIDDocs];
}

def HLSLSV_GroupIndex: HLSLAnnotationAttr {
let Spellings = [HLSLAnnotation<"sv_groupindex">];
let Subjects = SubjectList<[ParmVar, GlobalVar]>;
let LangOpts = [HLSL];
let Documentation = [HLSLSV_GroupIndexDocs];
}

def HLSLResourceBinding: InheritableAttr {
let Spellings = [HLSLAnnotation<"register">];
let Subjects = SubjectList<[HLSLBufferObj, ExternalGlobalVar], ErrorDiag>;
Expand Down Expand Up @@ -4943,13 +4932,35 @@ def HLSLResourceBinding: InheritableAttr {
}];
}

def HLSLSV_Position : HLSLAnnotationAttr {
let Spellings = [HLSLAnnotation<"sv_position">];
let Subjects = SubjectList<[ParmVar, Field]>;
def HLSLUnparsedSemantic : HLSLAnnotationAttr {
let Spellings = [];
let Args = [DefaultIntArgument<"Index", 0>,
DefaultBoolArgument<"ExplicitIndex", 0>];
let Subjects = SubjectList<[ParmVar, Field, Function]>;
let LangOpts = [HLSL];
let Documentation = [InternalOnly];
}

def HLSLSV_Position : HLSLSemanticAttr</* Indexable= */ 1> {
let Documentation = [HLSLSV_PositionDocs];
}

def HLSLSV_GroupThreadID : HLSLSemanticAttr</* Indexable= */ 0> {
let Documentation = [HLSLSV_GroupThreadIDDocs];
}

def HLSLSV_GroupID : HLSLSemanticAttr</* Indexable= */ 0> {
let Documentation = [HLSLSV_GroupIDDocs];
}

def HLSLSV_GroupIndex : HLSLSemanticAttr</* Indexable= */ 0> {
let Documentation = [HLSLSV_GroupIndexDocs];
}

def HLSLSV_DispatchThreadID : HLSLSemanticAttr</* Indexable= */ 0> {
let Documentation = [HLSLSV_DispatchThreadIDDocs];
}

def HLSLPackOffset: HLSLAnnotationAttr {
let Spellings = [HLSLAnnotation<"packoffset">];
let LangOpts = [HLSL];
Expand All @@ -4962,13 +4973,6 @@ def HLSLPackOffset: HLSLAnnotationAttr {
}];
}

def HLSLSV_DispatchThreadID: HLSLAnnotationAttr {
let Spellings = [HLSLAnnotation<"sv_dispatchthreadid">];
let Subjects = SubjectList<[ParmVar, Field]>;
let LangOpts = [HLSL];
let Documentation = [HLSLSV_DispatchThreadIDDocs];
}

def HLSLShader : InheritableAttr {
let Spellings = [Microsoft<"shader">];
let Subjects = SubjectList<[HLSLEntry]>;
Expand Down
4 changes: 4 additions & 0 deletions clang/include/clang/Basic/DiagnosticFrontendKinds.td
Original file line number Diff line number Diff line change
Expand Up @@ -393,6 +393,10 @@ def warn_hlsl_langstd_minimal :
"recommend using %1 instead">,
InGroup<HLSLDXCCompat>;

def err_hlsl_semantic_missing : Error<"semantic annotations must be present "
"for all input and outputs of an entry "
"function or patch constant function">;

// ClangIR frontend errors
def err_cir_to_cir_transform_failed : Error<
"CIR-to-CIR transformation failed">, DefaultFatal;
Expand Down
5 changes: 2 additions & 3 deletions clang/include/clang/Basic/DiagnosticParseKinds.td
Original file line number Diff line number Diff line change
Expand Up @@ -1851,9 +1851,8 @@ def note_max_tokens_total_override : Note<"total token limit set here">;

def err_expected_semantic_identifier : Error<
"expected HLSL Semantic identifier">;
def err_invalid_declaration_in_hlsl_buffer : Error<
"invalid declaration inside %select{tbuffer|cbuffer}0">;
def err_unknown_hlsl_semantic : Error<"unknown HLSL semantic %0">;
def err_invalid_declaration_in_hlsl_buffer
: Error<"invalid declaration inside %select{tbuffer|cbuffer}0">;
def err_hlsl_separate_attr_arg_and_number : Error<"wrong argument format for hlsl attribute, use %0 instead">;
def ext_hlsl_access_specifiers : ExtWarn<
"access specifiers are a clang HLSL extension">,
Expand Down
5 changes: 5 additions & 0 deletions clang/include/clang/Basic/DiagnosticSemaKinds.td
Original file line number Diff line number Diff line change
Expand Up @@ -13058,6 +13058,11 @@ def err_hlsl_duplicate_parameter_modifier : Error<"duplicate parameter modifier
def err_hlsl_missing_semantic_annotation : Error<
"semantic annotations must be present for all parameters of an entry "
"function or patch constant function">;
def err_hlsl_unknown_semantic : Error<"unknown HLSL semantic %0">;
def err_hlsl_semantic_output_not_supported
: Error<"semantic %0 does not support output">;
def err_hlsl_semantic_indexing_not_supported
: Error<"semantic %0 does not allow indexing">;
def err_hlsl_init_priority_unsupported : Error<
"initializer priorities are not supported in HLSL">;

Expand Down
8 changes: 8 additions & 0 deletions clang/include/clang/Parse/Parser.h
Original file line number Diff line number Diff line change
Expand Up @@ -5188,6 +5188,14 @@ class Parser : public CodeCompletionHandler {
ParseHLSLAnnotations(Attrs, EndLoc);
}

struct ParsedSemantic {
StringRef Name = "";
unsigned Index = 0;
bool Explicit = false;
};

ParsedSemantic ParseHLSLSemantic();

void ParseHLSLAnnotations(ParsedAttributes &Attrs,
SourceLocation *EndLoc = nullptr,
bool CouldBeBitField = false);
Expand Down
25 changes: 21 additions & 4 deletions clang/include/clang/Sema/SemaHLSL.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#include "clang/AST/Attr.h"
#include "clang/AST/Type.h"
#include "clang/AST/TypeLoc.h"
#include "clang/Basic/DiagnosticSema.h"
#include "clang/Basic/SourceLocation.h"
#include "clang/Sema/SemaBase.h"
#include "llvm/ADT/SmallVector.h"
Expand Down Expand Up @@ -129,6 +130,7 @@ class SemaHLSL : public SemaBase {
bool ActOnUninitializedVarDecl(VarDecl *D);
void ActOnEndOfTranslationUnit(TranslationUnitDecl *TU);
void CheckEntryPoint(FunctionDecl *FD);
bool isSemanticValid(FunctionDecl *FD, DeclaratorDecl *D);
void CheckSemanticAnnotation(FunctionDecl *EntryPoint, const Decl *Param,
const HLSLAnnotationAttr *AnnotationAttr);
void DiagnoseAttrStageMismatch(
Expand Down Expand Up @@ -161,16 +163,31 @@ class SemaHLSL : public SemaBase {
void handleNumThreadsAttr(Decl *D, const ParsedAttr &AL);
void handleWaveSizeAttr(Decl *D, const ParsedAttr &AL);
void handleVkConstantIdAttr(Decl *D, const ParsedAttr &AL);
void handleSV_DispatchThreadIDAttr(Decl *D, const ParsedAttr &AL);
void handleSV_GroupThreadIDAttr(Decl *D, const ParsedAttr &AL);
void handleSV_GroupIDAttr(Decl *D, const ParsedAttr &AL);
void handleSV_PositionAttr(Decl *D, const ParsedAttr &AL);
void handlePackOffsetAttr(Decl *D, const ParsedAttr &AL);
void handleShaderAttr(Decl *D, const ParsedAttr &AL);
void handleResourceBindingAttr(Decl *D, const ParsedAttr &AL);
void handleParamModifierAttr(Decl *D, const ParsedAttr &AL);
bool handleResourceTypeAttr(QualType T, const ParsedAttr &AL);

template <typename T>
T *createSemanticAttr(const ParsedAttr &AL,
std::optional<unsigned> Location) {
T *Attr = ::new (getASTContext()) T(getASTContext(), AL);
if (Attr->isSemanticIndexable())
Attr->setSemanticIndex(Location ? *Location : 0);
else if (Location.has_value()) {
Diag(Attr->getLocation(), diag::err_hlsl_semantic_indexing_not_supported)
<< Attr->getAttrName()->getName();
return nullptr;
}

return Attr;
}

void diagnoseSystemSemanticAttr(Decl *D, const ParsedAttr &AL,
std::optional<unsigned> Index);
void handleSemanticAttr(Decl *D, const ParsedAttr &AL);

void handleVkExtBuiltinInputAttr(Decl *D, const ParsedAttr &AL);

bool CheckBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall);
Expand Down
7 changes: 6 additions & 1 deletion clang/lib/Basic/Attributes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -189,7 +189,12 @@ AttributeCommonInfo::Kind
AttributeCommonInfo::getParsedKind(const IdentifierInfo *Name,
const IdentifierInfo *ScopeName,
Syntax SyntaxUsed) {
return ::getAttrKind(normalizeName(Name, ScopeName, SyntaxUsed), SyntaxUsed);
AttributeCommonInfo::Kind Kind =
::getAttrKind(normalizeName(Name, ScopeName, SyntaxUsed), SyntaxUsed);
if (SyntaxUsed == AS_HLSLAnnotation &&
Kind == AttributeCommonInfo::Kind::UnknownAttribute)
return AttributeCommonInfo::Kind::AT_HLSLUnparsedSemantic;
return Kind;
}

AttributeCommonInfo::AttrArgsInfo
Expand Down
76 changes: 57 additions & 19 deletions clang/lib/CodeGen/CGHLSLRuntime.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
#include "clang/AST/RecursiveASTVisitor.h"
#include "clang/AST/Type.h"
#include "clang/Basic/TargetOptions.h"
#include "clang/Frontend/FrontendDiagnostic.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Frontend/HLSL/RootSignatureMetadata.h"
#include "llvm/IR/Constants.h"
Expand Down Expand Up @@ -383,47 +384,82 @@ static llvm::Value *createSPIRVBuiltinLoad(IRBuilder<> &B, llvm::Module &M,
return B.CreateLoad(Ty, GV);
}

llvm::Value *CGHLSLRuntime::emitInputSemantic(IRBuilder<> &B,
const ParmVarDecl &D,
llvm::Type *Ty) {
assert(D.hasAttrs() && "Entry parameter missing annotation attribute!");
if (D.hasAttr<HLSLSV_GroupIndexAttr>()) {
llvm::Value *
CGHLSLRuntime::emitSystemSemanticLoad(IRBuilder<> &B, llvm::Type *Type,
const clang::DeclaratorDecl *Decl,
SemanticInfo &ActiveSemantic) {
if (HLSLSV_GroupIndexAttr *S =
dyn_cast<HLSLSV_GroupIndexAttr>(ActiveSemantic.Semantic)) {
llvm::Function *GroupIndex =
CGM.getIntrinsic(getFlattenedThreadIdInGroupIntrinsic());
return B.CreateCall(FunctionCallee(GroupIndex));
}
if (D.hasAttr<HLSLSV_DispatchThreadIDAttr>()) {

if (HLSLSV_DispatchThreadIDAttr *S =
dyn_cast<HLSLSV_DispatchThreadIDAttr>(ActiveSemantic.Semantic)) {
llvm::Intrinsic::ID IntrinID = getThreadIdIntrinsic();
llvm::Function *ThreadIDIntrinsic =
llvm::Intrinsic::isOverloaded(IntrinID)
? CGM.getIntrinsic(IntrinID, {CGM.Int32Ty})
: CGM.getIntrinsic(IntrinID);
return buildVectorInput(B, ThreadIDIntrinsic, Ty);
return buildVectorInput(B, ThreadIDIntrinsic, Type);
}
if (D.hasAttr<HLSLSV_GroupThreadIDAttr>()) {

if (HLSLSV_GroupThreadIDAttr *S =
dyn_cast<HLSLSV_GroupThreadIDAttr>(ActiveSemantic.Semantic)) {
llvm::Intrinsic::ID IntrinID = getGroupThreadIdIntrinsic();
llvm::Function *GroupThreadIDIntrinsic =
llvm::Intrinsic::isOverloaded(IntrinID)
? CGM.getIntrinsic(IntrinID, {CGM.Int32Ty})
: CGM.getIntrinsic(IntrinID);
return buildVectorInput(B, GroupThreadIDIntrinsic, Ty);
return buildVectorInput(B, GroupThreadIDIntrinsic, Type);
}
if (D.hasAttr<HLSLSV_GroupIDAttr>()) {

if (HLSLSV_GroupIDAttr *S =
dyn_cast<HLSLSV_GroupIDAttr>(ActiveSemantic.Semantic)) {
llvm::Intrinsic::ID IntrinID = getGroupIdIntrinsic();
llvm::Function *GroupIDIntrinsic =
llvm::Intrinsic::isOverloaded(IntrinID)
? CGM.getIntrinsic(IntrinID, {CGM.Int32Ty})
: CGM.getIntrinsic(IntrinID);
return buildVectorInput(B, GroupIDIntrinsic, Ty);
return buildVectorInput(B, GroupIDIntrinsic, Type);
}
if (D.hasAttr<HLSLSV_PositionAttr>()) {
if (getArch() == llvm::Triple::spirv)
return createSPIRVBuiltinLoad(B, CGM.getModule(), Ty, "sv_position",
/* BuiltIn::Position */ 0);
llvm_unreachable("SV_Position semantic not implemented for this target.");

if (HLSLSV_PositionAttr *S =
dyn_cast<HLSLSV_PositionAttr>(ActiveSemantic.Semantic)) {
if (CGM.getTriple().getEnvironment() == Triple::EnvironmentType::Pixel)
return createSPIRVBuiltinLoad(B, CGM.getModule(), Type,
S->getAttrName()->getName(),
/* BuiltIn::FragCoord */ 15);
}
assert(false && "Unhandled parameter attribute");
return nullptr;

llvm_unreachable("non-handled system semantic. FIXME.");
}

llvm::Value *
CGHLSLRuntime::handleScalarSemanticLoad(IRBuilder<> &B, llvm::Type *Type,
const clang::DeclaratorDecl *Decl,
SemanticInfo &ActiveSemantic) {

if (!ActiveSemantic.Semantic) {
ActiveSemantic.Semantic = Decl->getAttr<HLSLSemanticAttr>();
if (!ActiveSemantic.Semantic) {
CGM.getDiags().Report(Decl->getInnerLocStart(),
diag::err_hlsl_semantic_missing);
return nullptr;
}
ActiveSemantic.Index = ActiveSemantic.Semantic->getSemanticIndex();
}

return emitSystemSemanticLoad(B, Type, Decl, ActiveSemantic);
}

llvm::Value *
CGHLSLRuntime::handleSemanticLoad(IRBuilder<> &B, llvm::Type *Type,
const clang::DeclaratorDecl *Decl,
SemanticInfo &ActiveSemantic) {
assert(!Type->isStructTy());
return handleScalarSemanticLoad(B, Type, Decl, ActiveSemantic);
}

void CGHLSLRuntime::emitEntryFunction(const FunctionDecl *FD,
Expand Down Expand Up @@ -468,8 +504,10 @@ void CGHLSLRuntime::emitEntryFunction(const FunctionDecl *FD,
Args.emplace_back(PoisonValue::get(Param.getType()));
continue;
}

const ParmVarDecl *PD = FD->getParamDecl(Param.getArgNo() - SRetOffset);
Args.push_back(emitInputSemantic(B, *PD, Param.getType()));
SemanticInfo ActiveSemantic = {nullptr, 0};
Args.push_back(handleSemanticLoad(B, Param.getType(), PD, ActiveSemantic));
}

CallInst *CI = B.CreateCall(FunctionCallee(Fn), Args, OB);
Expand Down
Loading