Skip to content

Commit 33b0c09

Browse files
committed
const folding: difftest for bitshift
1 parent b971247 commit 33b0c09

File tree

16 files changed

+310
-2
lines changed

16 files changed

+310
-2
lines changed

tests/difftests/tests/Cargo.lock

Lines changed: 58 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

tests/difftests/tests/Cargo.toml

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,10 @@ members = [
2727
"lang/core/ops/matrix_ops/matrix_ops-rust",
2828
"lang/core/ops/matrix_ops/matrix_ops-wgsl",
2929
"lang/core/ops/bitwise_ops/bitwise_ops-rust",
30+
"lang/core/ops/const_fold_int/const-fold-cpu",
31+
"lang/core/ops/const_fold_int/const-fold-shader",
32+
"lang/core/ops/const_fold_int/dynamic-values-cpu",
33+
"lang/core/ops/const_fold_int/dynamic-values-shader",
3034
"lang/core/ops/bitwise_ops/bitwise_ops-wgsl",
3135
"lang/core/ops/trig_ops/trig_ops-rust",
3236
"lang/core/ops/trig_ops/trig_ops-wgsl",
@@ -46,8 +50,6 @@ unexpected_cfgs = { level = "allow", check-cfg = [
4650

4751
[workspace.dependencies]
4852
spirv-std = { path = "../../../crates/spirv-std", version = "=0.9.0" }
49-
spirv-std-types = { path = "../../../crates/spirv-std/shared", version = "=0.9.0" }
50-
spirv-std-macros = { path = "../../../crates/spirv-std/macros", version = "=0.9.0" }
5153
difftest = { path = "../../../tests/difftests/lib" }
5254
# External dependencies that need to be mentioned more than once.
5355
num-traits = { version = "0.2.15", default-features = false }
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
[package]
2+
name = "const_fold_int-const-fold-cpu"
3+
edition.workspace = true
4+
5+
[lints]
6+
workspace = true
7+
8+
# GPU deps
9+
[dependencies]
10+
spirv-std.workspace = true
11+
num_enum = { version = "0.7.4", default-features = false }
12+
13+
# CPU deps (for the test harness)
14+
[target.'cfg(not(target_arch = "spirv"))'.dependencies]
15+
difftest.workspace = true
16+
bytemuck.workspace = true
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
use crate::{INTERESTING_PATTERNS, Variants};
2+
use difftest::config::Config;
3+
4+
pub fn run(variant: Variants) {
5+
let config = Config::from_path(std::env::args().nth(1).unwrap()).unwrap();
6+
let result = variant
7+
.eval(&INTERESTING_PATTERNS)
8+
.into_iter()
9+
.flatten()
10+
.flatten()
11+
.collect::<Vec<_>>();
12+
config.write_result(&result).unwrap()
13+
}
Lines changed: 108 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,108 @@
1+
#![cfg_attr(target_arch = "spirv", no_std)]
2+
#![allow(arithmetic_overflow)]
3+
4+
#[cfg(not(target_arch = "spirv"))]
5+
pub mod cpu_driver;
6+
pub mod shader;
7+
#[cfg(not(target_arch = "spirv"))]
8+
pub mod shader_driver;
9+
10+
use num_enum::{IntoPrimitive, TryFromPrimitive};
11+
12+
macro_rules! op_u {
13+
($value:expr) => {
14+
[
15+
($value << 0) as u32,
16+
($value << 1) as u32,
17+
($value << 2) as u32,
18+
($value << 30) as u32,
19+
($value << 31) as u32,
20+
($value << 32) as u32,
21+
($value << 33) as u32,
22+
($value << 34) as u32,
23+
($value >> 0) as u32,
24+
($value >> 1) as u32,
25+
($value >> 2) as u32,
26+
($value >> 30) as u32,
27+
($value >> 31) as u32,
28+
($value >> 32) as u32,
29+
($value >> 33) as u32,
30+
($value >> 34) as u32,
31+
]
32+
};
33+
}
34+
35+
macro_rules! op_i {
36+
($value:expr) => {
37+
op_u!($value as i32)
38+
};
39+
}
40+
41+
macro_rules! interesting_patterns {
42+
($op_name:ident) => {
43+
[
44+
$op_name!(0u32),
45+
$op_name!(1u32),
46+
$op_name!(0xFFFFFFFFu32),
47+
$op_name!(0xDEADBEEFu32),
48+
$op_name!(0b10011001100110011001100110011001u32),
49+
$op_name!(0b10000000000000000000000000000001u32),
50+
$op_name!(0x12345678u32),
51+
$op_name!(0x87654321u32),
52+
]
53+
};
54+
}
55+
56+
macro_rules! identity {
57+
($expr:expr) => {
58+
$expr
59+
};
60+
}
61+
62+
pub const INTERESTING_PATTERNS: [u32; 8] = interesting_patterns!(identity);
63+
64+
#[repr(u32)]
65+
#[derive(Copy, Clone, Debug, Eq, PartialEq, TryFromPrimitive, IntoPrimitive)]
66+
pub enum Variants {
67+
/// const folding in rust-gpu
68+
ConstFold,
69+
// /// `const {}` expr for const eval within rustc
70+
// ConstExpr,
71+
/// dynamic values from `input_patterns`
72+
DynamicValues,
73+
}
74+
75+
pub type EvalResult = [[[u32; 16]; 8]; 2];
76+
77+
impl Variants {
78+
pub fn eval(&self, input_patterns: &[u32; 8]) -> EvalResult {
79+
match self {
80+
Variants::ConstFold => [interesting_patterns!(op_u), interesting_patterns!(op_i)],
81+
// Variants::ConstExpr => {
82+
// const { [interesting_patterns!(op_u), interesting_patterns!(op_i)] }
83+
// }
84+
Variants::DynamicValues => [
85+
[
86+
op_u!(input_patterns[0]),
87+
op_u!(input_patterns[1]),
88+
op_u!(input_patterns[2]),
89+
op_u!(input_patterns[3]),
90+
op_u!(input_patterns[4]),
91+
op_u!(input_patterns[5]),
92+
op_u!(input_patterns[6]),
93+
op_u!(input_patterns[7]),
94+
],
95+
[
96+
op_i!(input_patterns[0]),
97+
op_i!(input_patterns[1]),
98+
op_i!(input_patterns[2]),
99+
op_i!(input_patterns[3]),
100+
op_i!(input_patterns[4]),
101+
op_i!(input_patterns[5]),
102+
op_i!(input_patterns[6]),
103+
op_i!(input_patterns[7]),
104+
],
105+
],
106+
}
107+
}
108+
}
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
use const_fold_int_const_fold_cpu::Variants;
2+
3+
fn main() {
4+
const_fold_int_const_fold_cpu::cpu_driver::run(Variants::ConstFold);
5+
}
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
use crate::{EvalResult, Variants};
2+
use spirv_std::spirv;
3+
4+
#[spirv(compute(threads(1)))]
5+
pub fn main_cs(
6+
#[spirv(storage_buffer, descriptor_set = 0, binding = 0)] variant: &u32,
7+
#[spirv(storage_buffer, descriptor_set = 0, binding = 1)] input_patterns: &[u32; 8],
8+
#[spirv(storage_buffer, descriptor_set = 0, binding = 2)] output: &mut EvalResult,
9+
) {
10+
if let Ok(variant) = Variants::try_from(*variant) {
11+
*output = variant.eval(input_patterns);
12+
}
13+
}
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
use crate::{EvalResult, INTERESTING_PATTERNS, Variants};
2+
use difftest::config::Config;
3+
use difftest::scaffold::compute::{BufferConfig, RustComputeShader, WgpuComputeTestMultiBuffer};
4+
5+
pub fn run(variant: Variants) {
6+
let config = Config::from_path(std::env::args().nth(1).unwrap()).unwrap();
7+
8+
let test = WgpuComputeTestMultiBuffer::new(
9+
RustComputeShader::default(),
10+
[64, 1, 1],
11+
Vec::from(&[
12+
BufferConfig::read_only(&[u32::from(variant)]),
13+
BufferConfig::read_only(&INTERESTING_PATTERNS),
14+
BufferConfig::writeback(size_of::<EvalResult>()),
15+
]),
16+
);
17+
18+
test.run_test(&config).unwrap();
19+
}
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
[package]
2+
name = "const_fold_int-const-fold-shader"
3+
edition.workspace = true
4+
5+
[lints]
6+
workspace = true
7+
8+
[lib]
9+
crate-type = ["dylib"]
10+
11+
# GPU deps
12+
[dependencies]
13+
const_fold_int-const-fold-cpu = { path = "../const-fold-cpu" }
14+
15+
# CPU deps (for the test harness)
16+
[target.'cfg(not(target_arch = "spirv"))'.dependencies]
17+
difftest.workspace = true
18+
bytemuck.workspace = true
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
#![cfg_attr(target_arch = "spirv", no_std)]
2+
#![allow(arithmetic_overflow)]
3+
4+
pub use const_fold_int_const_fold_cpu::shader::main_cs;

0 commit comments

Comments
 (0)