Skip to content

Commit f299e40

Browse files
committed
[HLSL] Add support for input semantics in structs
This commit adds the support for semantics annotations on structs, but only for inputs. Due to the current semantics implemented, we cannot test much more than nesting/shadowing. Once user semantics are implemented, we'll be able to test arrays in structs and more complex cases. As-is, this commit has one weakness vs DXC: semantics type validation is not looking at the inner-most type, but the outermost type: ```hlsl struct Inner { uint tid; }; Inner inner : SV_GroupID ``` This sample would fail today because `SV_GroupID` require the type to be an integer. This works in DXC as the inner type is a integer. Because GroupIndex is not correctly validated, I uses this semantic to test the inheritance/shadowing. But this will need to be fixed in a later commit. Requires llvm#152537
1 parent d155488 commit f299e40

File tree

7 files changed

+202
-3
lines changed

7 files changed

+202
-3
lines changed

clang/lib/CodeGen/CGHLSLRuntime.cpp

Lines changed: 60 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -454,11 +454,51 @@ CGHLSLRuntime::handleScalarSemanticLoad(IRBuilder<> &B, llvm::Type *Type,
454454
return emitSystemSemanticLoad(B, Type, Decl, ActiveSemantic);
455455
}
456456

457+
llvm::Value *
458+
CGHLSLRuntime::handleStructSemanticLoad(IRBuilder<> &B, llvm::Type *Type,
459+
const clang::DeclaratorDecl *Decl,
460+
SemanticInfo &ActiveSemantic) {
461+
const llvm::StructType *ST = cast<StructType>(Type);
462+
const clang::RecordDecl *RD = Decl->getType()->getAsRecordDecl();
463+
464+
assert(std::distance(RD->field_begin(), RD->field_end()) ==
465+
ST->getNumElements());
466+
467+
if (!ActiveSemantic.Semantic) {
468+
ActiveSemantic.Semantic = Decl->getAttr<HLSLSemanticAttr>();
469+
ActiveSemantic.Index = ActiveSemantic.Semantic
470+
? ActiveSemantic.Semantic->getSemanticIndex()
471+
: 0;
472+
}
473+
474+
llvm::Value *Aggregate = llvm::PoisonValue::get(Type);
475+
auto FieldDecl = RD->field_begin();
476+
for (unsigned I = 0; I < ST->getNumElements(); ++I) {
477+
SemanticInfo Info = ActiveSemantic;
478+
llvm::Value *ChildValue =
479+
handleSemanticLoad(B, ST->getElementType(I), *FieldDecl, Info);
480+
if (!ChildValue) {
481+
CGM.getDiags().Report(Decl->getInnerLocStart(),
482+
diag::note_hlsl_semantic_used_here)
483+
<< Decl;
484+
return nullptr;
485+
}
486+
if (ActiveSemantic.Semantic)
487+
ActiveSemantic = Info;
488+
489+
Aggregate = B.CreateInsertValue(Aggregate, ChildValue, I);
490+
++FieldDecl;
491+
}
492+
493+
return Aggregate;
494+
}
495+
457496
llvm::Value *
458497
CGHLSLRuntime::handleSemanticLoad(IRBuilder<> &B, llvm::Type *Type,
459498
const clang::DeclaratorDecl *Decl,
460499
SemanticInfo &ActiveSemantic) {
461-
assert(!Type->isStructTy());
500+
if (Type->isStructTy())
501+
return handleStructSemanticLoad(B, Type, Decl, ActiveSemantic);
462502
return handleScalarSemanticLoad(B, Type, Decl, ActiveSemantic);
463503
}
464504

@@ -507,8 +547,25 @@ void CGHLSLRuntime::emitEntryFunction(const FunctionDecl *FD,
507547
}
508548

509549
const ParmVarDecl *PD = FD->getParamDecl(Param.getArgNo() - SRetOffset);
510-
SemanticInfo ActiveSemantic = {nullptr, 0};
511-
Args.push_back(handleSemanticLoad(B, Param.getType(), PD, ActiveSemantic));
550+
llvm::Value *SemanticValue = nullptr;
551+
if (HLSLParamModifierAttr *MA = PD->getAttr<HLSLParamModifierAttr>()) {
552+
llvm_unreachable("Not handled yet");
553+
} else {
554+
llvm::Type *ParamType =
555+
Param.hasByValAttr() ? Param.getParamByValType() : Param.getType();
556+
SemanticInfo ActiveSemantic = {nullptr, 0};
557+
SemanticValue = handleSemanticLoad(B, ParamType, PD, ActiveSemantic);
558+
if (!SemanticValue)
559+
return;
560+
if (Param.hasByValAttr()) {
561+
llvm::Value *Var = B.CreateAlloca(Param.getParamByValType());
562+
B.CreateStore(SemanticValue, Var);
563+
SemanticValue = Var;
564+
}
565+
}
566+
567+
assert(SemanticValue);
568+
Args.push_back(SemanticValue);
512569
}
513570

514571
CallInst *CI = B.CreateCall(FunctionCallee(Fn), Args, OB);

clang/lib/CodeGen/CGHLSLRuntime.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -153,6 +153,10 @@ class CGHLSLRuntime {
153153
const clang::DeclaratorDecl *Decl,
154154
SemanticInfo &ActiveSemantic);
155155

156+
llvm::Value *handleStructSemanticLoad(llvm::IRBuilder<> &B, llvm::Type *Type,
157+
const clang::DeclaratorDecl *Decl,
158+
SemanticInfo &ActiveSemantic);
159+
156160
llvm::Value *handleSemanticLoad(llvm::IRBuilder<> &B, llvm::Type *Type,
157161
const clang::DeclaratorDecl *Decl,
158162
SemanticInfo &ActiveSemantic);
Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.3-library -x hlsl -emit-llvm -finclude-default-header -disable-llvm-passes -o - %s | FileCheck %s --check-prefixes=CHECK,CHECK-DXIL -DTARGET=dx
2+
// RUN: %clang_cc1 -triple spirv-linux-vulkan-library -x hlsl -emit-llvm -finclude-default-header -disable-llvm-passes -o - %s | FileCheck %s --check-prefixes=CHECK,CHECK-SPIRV -DTARGET=spv
3+
4+
5+
struct Input {
6+
uint Idx : SV_DispatchThreadID;
7+
8+
};
9+
10+
// Make sure SV_DispatchThreadID translated into dx.thread.id.
11+
12+
// CHECK: define void @foo()
13+
// CHECK-DXIL: %[[#ID:]] = call i32 @llvm.[[TARGET]].thread.id(i32 0)
14+
// CHECK-SPIRV: %[[#ID:]] = call i32 @llvm.[[TARGET]].thread.id.i32(i32 0)
15+
// CHECK: %[[#TMP:]] = insertvalue %struct.Input poison, i32 %[[#ID]], 0
16+
// CHECK: %[[#VAR:]] = alloca %struct.Input, align 8
17+
// CHECK: store %struct.Input %[[#TMP]], ptr %[[#VAR]], align 4
18+
// CHECK-DXIL: call void @{{.*}}foo{{.*}}(ptr %[[#VAR]])
19+
// CHECK-SPIRV: call spir_func void @{{.*}}foo{{.*}}(ptr %[[#VAR]])
20+
[shader("compute")]
21+
[numthreads(8,8,1)]
22+
void foo(Input input) {}
23+
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.3-library -x hlsl -emit-llvm -finclude-default-header -disable-llvm-passes -o - %s | FileCheck %s --check-prefixes=CHECK,CHECK-DXIL -DTARGET=dx
2+
// RUN: %clang_cc1 -triple spirv-linux-vulkan-library -x hlsl -emit-llvm -finclude-default-header -disable-llvm-passes -o - %s | FileCheck %s --check-prefixes=CHECK,CHECK-SPIRV -DTARGET=spv
3+
4+
5+
struct Input {
6+
uint Idx : SV_DispatchThreadID;
7+
uint Gid : SV_GroupID;
8+
};
9+
10+
// Make sure SV_DispatchThreadID translated into dx.thread.id.
11+
12+
// CHECK: define void @foo()
13+
// CHECK-DXIL: %[[#ID:]] = call i32 @llvm.[[TARGET]].thread.id(i32 0)
14+
// CHECK-SPIRV: %[[#ID:]] = call i32 @llvm.[[TARGET]].thread.id.i32(i32 0)
15+
// CHECK: %[[#TMP1:]] = insertvalue %struct.Input poison, i32 %[[#ID]], 0
16+
// CHECK-DXIL: %[[#GID:]] = call i32 @llvm.[[TARGET]].group.id(i32 0)
17+
// CHECK-SPIRV:%[[#GID:]] = call i32 @llvm.[[TARGET]].group.id.i32(i32 0)
18+
// CHECK: %[[#TMP2:]] = insertvalue %struct.Input %[[#TMP1]], i32 %[[#GID]], 1
19+
// CHECK: %[[#VAR:]] = alloca %struct.Input, align 8
20+
// CHECK: store %struct.Input %[[#TMP2]], ptr %[[#VAR]], align 4
21+
// CHECK-DXIL: call void @{{.*}}foo{{.*}}(ptr %[[#VAR]])
22+
// CHECK-SPIRV: call spir_func void @{{.*}}foo{{.*}}(ptr %[[#VAR]])
23+
[shader("compute")]
24+
[numthreads(8,8,1)]
25+
void foo(Input input) {}
Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.3-library -x hlsl -emit-llvm -finclude-default-header -disable-llvm-passes -o - %s | FileCheck %s --check-prefixes=CHECK,CHECK-DXIL -DTARGET=dx
2+
// RUN: %clang_cc1 -triple spirv-linux-vulkan-library -x hlsl -emit-llvm -finclude-default-header -disable-llvm-passes -o - %s | FileCheck %s --check-prefixes=CHECK,CHECK-SPIRV -DTARGET=spv
3+
4+
5+
struct Inner {
6+
uint Gid;
7+
};
8+
9+
struct Input {
10+
uint Idx : SV_DispatchThreadID;
11+
Inner inner : SV_GroupIndex;
12+
};
13+
14+
// Make sure SV_DispatchThreadID translated into dx.thread.id.
15+
16+
// CHECK: define void @foo()
17+
// CHECK-DXIL: %[[#ID:]] = call i32 @llvm.[[TARGET]].thread.id(i32 0)
18+
// CHECK-SPIRV: %[[#ID:]] = call i32 @llvm.[[TARGET]].thread.id.i32(i32 0)
19+
// CHECK: %[[#TMP1:]] = insertvalue %struct.Input poison, i32 %[[#ID]], 0
20+
// CHECK-DXIL: %[[#GID:]] = call i32 @llvm.dx.flattened.thread.id.in.group()
21+
// CHECK-SPIRV:%[[#GID:]] = call i32 @llvm.spv.flattened.thread.id.in.group()
22+
// CHECK: %[[#TMP2:]] = insertvalue %struct.Inner poison, i32 %[[#GID]], 0
23+
// CHECK: %[[#TMP3:]] = insertvalue %struct.Input %[[#TMP1]], %struct.Inner %[[#TMP2]], 1
24+
// CHECK: %[[#VAR:]] = alloca %struct.Input, align 8
25+
// CHECK: store %struct.Input %[[#TMP3]], ptr %[[#VAR]], align 4
26+
// CHECK-DXIL: call void @{{.*}}foo{{.*}}(ptr %[[#VAR]])
27+
// CHECK-SPIRV: call spir_func void @{{.*}}foo{{.*}}(ptr %[[#VAR]])
28+
[shader("compute")]
29+
[numthreads(8,8,1)]
30+
void foo(Input input) {}
Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.3-library -x hlsl -emit-llvm -finclude-default-header -disable-llvm-passes -o - %s | FileCheck %s --check-prefixes=CHECK,CHECK-DXIL -DTARGET=dx
2+
// RUN: %clang_cc1 -triple spirv-linux-vulkan-library -x hlsl -emit-llvm -finclude-default-header -disable-llvm-passes -o - %s | FileCheck %s --check-prefixes=CHECK,CHECK-SPIRV -DTARGET=spv
3+
4+
5+
struct Inner {
6+
uint Gid : SV_GroupID;
7+
};
8+
9+
struct Input {
10+
uint Idx : SV_DispatchThreadID;
11+
Inner inner : SV_GroupIndex;
12+
};
13+
14+
// Make sure SV_DispatchThreadID translated into dx.thread.id.
15+
16+
// CHECK: define void @foo()
17+
// CHECK-DXIL: %[[#ID:]] = call i32 @llvm.[[TARGET]].thread.id(i32 0)
18+
// CHECK-SPIRV: %[[#ID:]] = call i32 @llvm.[[TARGET]].thread.id.i32(i32 0)
19+
// CHECK: %[[#TMP1:]] = insertvalue %struct.Input poison, i32 %[[#ID]], 0
20+
// CHECK-DXIL: %[[#GID:]] = call i32 @llvm.dx.flattened.thread.id.in.group()
21+
// CHECK-SPIRV:%[[#GID:]] = call i32 @llvm.spv.flattened.thread.id.in.group()
22+
// CHECK: %[[#TMP2:]] = insertvalue %struct.Inner poison, i32 %[[#GID]], 0
23+
// CHECK: %[[#TMP3:]] = insertvalue %struct.Input %[[#TMP1]], %struct.Inner %[[#TMP2]], 1
24+
// CHECK: %[[#VAR:]] = alloca %struct.Input, align 8
25+
// CHECK: store %struct.Input %[[#TMP3]], ptr %[[#VAR]], align 4
26+
// CHECK-DXIL: call void @{{.*}}foo{{.*}}(ptr %[[#VAR]])
27+
// CHECK-SPIRV: call spir_func void @{{.*}}foo{{.*}}(ptr %[[#VAR]])
28+
[shader("compute")]
29+
[numthreads(8,8,1)]
30+
void foo(Input input) {}
Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.3-library -x hlsl -emit-llvm -finclude-default-header -disable-llvm-passes -o - %s | FileCheck %s --check-prefixes=CHECK,CHECK-DXIL -DTARGET=dx
2+
// RUN: %clang_cc1 -triple spirv-linux-vulkan-library -x hlsl -emit-llvm -finclude-default-header -disable-llvm-passes -o - %s | FileCheck %s --check-prefixes=CHECK,CHECK-SPIRV -DTARGET=spv
3+
4+
5+
struct Inner {
6+
uint Gid : SV_GroupID;
7+
};
8+
9+
struct Input {
10+
uint Idx : SV_DispatchThreadID;
11+
Inner inner;
12+
};
13+
14+
// Make sure SV_DispatchThreadID translated into dx.thread.id.
15+
16+
// CHECK: define void @foo()
17+
// CHECK-DXIL: %[[#ID:]] = call i32 @llvm.[[TARGET]].thread.id(i32 0)
18+
// CHECK-SPIRV: %[[#ID:]] = call i32 @llvm.[[TARGET]].thread.id.i32(i32 0)
19+
// CHECK: %[[#TMP1:]] = insertvalue %struct.Input poison, i32 %[[#ID]], 0
20+
// CHECK-DXIL: %[[#GID:]] = call i32 @llvm.[[TARGET]].group.id(i32 0)
21+
// CHECK-SPIRV:%[[#GID:]] = call i32 @llvm.[[TARGET]].group.id.i32(i32 0)
22+
// CHECK: %[[#TMP2:]] = insertvalue %struct.Inner poison, i32 %[[#GID]], 0
23+
// CHECK: %[[#TMP3:]] = insertvalue %struct.Input %[[#TMP1]], %struct.Inner %[[#TMP2]], 1
24+
// CHECK: %[[#VAR:]] = alloca %struct.Input, align 8
25+
// CHECK: store %struct.Input %[[#TMP3]], ptr %[[#VAR]], align 4
26+
// CHECK-DXIL: call void @{{.*}}foo{{.*}}(ptr %[[#VAR]])
27+
// CHECK-SPIRV: call spir_func void @{{.*}}foo{{.*}}(ptr %[[#VAR]])
28+
[shader("compute")]
29+
[numthreads(8,8,1)]
30+
void foo(Input input) {}

0 commit comments

Comments
 (0)