Skip to content

Commit 0dcbc9c

Browse files
committed
Make ash backend use ComputeShader trait to reduce code duplication
1 parent 581cf4a commit 0dcbc9c

File tree

4 files changed

+116
-50
lines changed

4 files changed

+116
-50
lines changed

tests/difftests/lib/src/scaffold/compute/backend.rs

Lines changed: 54 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,19 +17,32 @@ pub enum BufferUsage {
1717
Uniform,
1818
}
1919

20+
use super::SpirvShader;
21+
2022
/// A generic trait for compute backends
2123
pub trait ComputeBackend: Sized {
2224
/// Initialize the backend
2325
fn init() -> Result<Self>;
2426

25-
/// Create and run a compute shader with multiple buffers
27+
/// Create and run a compute shader with multiple buffers from raw SPIRV bytes
2628
fn run_compute(
2729
&self,
2830
spirv_bytes: &[u8],
2931
entry_point: &str,
3032
dispatch: [u32; 3],
3133
buffers: Vec<BufferConfig>,
3234
) -> Result<Vec<Vec<u8>>>;
35+
36+
/// Create and run a compute shader with multiple buffers from a shader object
37+
fn run_compute_shader<S: SpirvShader>(
38+
&self,
39+
shader: &S,
40+
dispatch: [u32; 3],
41+
buffers: Vec<BufferConfig>,
42+
) -> Result<Vec<Vec<u8>>> {
43+
let (spirv_bytes, entry_point) = shader.spirv_bytes()?;
44+
self.run_compute(&spirv_bytes, &entry_point, dispatch, buffers)
45+
}
3346
}
3447

3548
/// A compute test that can run on any backend
@@ -82,3 +95,43 @@ impl<B: ComputeBackend> ComputeTest<B> {
8295
anyhow::bail!("No storage buffer output found")
8396
}
8497
}
98+
99+
/// A compute test that can run on any backend using a shader object
100+
pub struct ComputeShaderTest<B: ComputeBackend, S: SpirvShader> {
101+
backend: B,
102+
shader: S,
103+
dispatch: [u32; 3],
104+
buffers: Vec<BufferConfig>,
105+
}
106+
107+
impl<B: ComputeBackend, S: SpirvShader> ComputeShaderTest<B, S> {
108+
pub fn new(shader: S, dispatch: [u32; 3], buffers: Vec<BufferConfig>) -> Result<Self> {
109+
Ok(Self {
110+
backend: B::init()?,
111+
shader,
112+
dispatch,
113+
buffers,
114+
})
115+
}
116+
117+
pub fn run(self) -> Result<Vec<Vec<u8>>> {
118+
self.backend
119+
.run_compute_shader(&self.shader, self.dispatch, self.buffers)
120+
}
121+
122+
pub fn run_test(self, config: &Config) -> Result<()> {
123+
let buffers = self.buffers.clone();
124+
let outputs = self.run()?;
125+
// Write the first storage buffer output to the file
126+
for (output, buffer_config) in outputs.iter().zip(&buffers) {
127+
if matches!(buffer_config.usage, BufferUsage::Storage) && !output.is_empty() {
128+
use std::fs::File;
129+
use std::io::Write;
130+
let mut f = File::create(&config.output_path)?;
131+
f.write_all(output)?;
132+
return Ok(());
133+
}
134+
}
135+
anyhow::bail!("No storage buffer output found")
136+
}
137+
}

tests/difftests/lib/src/scaffold/compute/mod.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,8 @@ mod backend;
33
mod wgpu;
44

55
pub use ash::AshBackend;
6-
pub use backend::{BufferConfig, BufferUsage, ComputeBackend, ComputeTest};
6+
pub use backend::{BufferConfig, BufferUsage, ComputeBackend, ComputeShaderTest, ComputeTest};
77
pub use wgpu::{
8-
RustComputeShader, WgpuBackend, WgpuComputeTest, WgpuComputeTestMultiBuffer,
9-
WgpuComputeTestPushConstants, WgslComputeShader,
8+
RustComputeShader, SpirvShader, WgpuBackend, WgpuComputeTest, WgpuComputeTestMultiBuffer,
9+
WgpuComputeTestPushConstants, WgpuShader, WgslComputeShader,
1010
};

tests/difftests/lib/src/scaffold/compute/wgpu.rs

Lines changed: 52 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,15 @@ use super::backend::{self, ComputeBackend};
1818
pub type BufferConfig = backend::BufferConfig;
1919
pub type BufferUsage = backend::BufferUsage;
2020

21-
/// Trait that creates a shader module and provides its entry point.
22-
pub trait ComputeShader {
21+
/// Trait for shaders that can provide SPIRV bytes.
22+
pub trait SpirvShader {
23+
/// Returns the SPIRV bytes and entry point name.
24+
fn spirv_bytes(&self) -> anyhow::Result<(Vec<u8>, String)>;
25+
}
26+
27+
/// Trait for shaders that can create wgpu modules.
28+
pub trait WgpuShader {
29+
/// Creates a wgpu shader module.
2330
fn create_module(
2431
&self,
2532
device: &wgpu::Device,
@@ -29,25 +36,46 @@ pub trait ComputeShader {
2936
/// A compute shader written in Rust compiled with spirv-builder.
3037
pub struct RustComputeShader {
3138
pub path: PathBuf,
39+
pub target: String,
40+
pub capabilities: Vec<spirv_builder::Capability>,
3241
}
3342

3443
impl RustComputeShader {
3544
pub fn new<P: Into<PathBuf>>(path: P) -> Self {
36-
Self { path: path.into() }
45+
Self {
46+
path: path.into(),
47+
target: "spirv-unknown-vulkan1.1".to_string(),
48+
capabilities: Vec::new(),
49+
}
50+
}
51+
52+
pub fn with_target<P: Into<PathBuf>>(path: P, target: impl Into<String>) -> Self {
53+
Self {
54+
path: path.into(),
55+
target: target.into(),
56+
capabilities: Vec::new(),
57+
}
58+
}
59+
60+
pub fn with_capability(mut self, capability: spirv_builder::Capability) -> Self {
61+
self.capabilities.push(capability);
62+
self
3763
}
3864
}
3965

40-
impl ComputeShader for RustComputeShader {
41-
fn create_module(
42-
&self,
43-
device: &wgpu::Device,
44-
) -> anyhow::Result<(wgpu::ShaderModule, Option<String>)> {
45-
let builder = SpirvBuilder::new(&self.path, "spirv-unknown-vulkan1.1")
66+
impl SpirvShader for RustComputeShader {
67+
fn spirv_bytes(&self) -> anyhow::Result<(Vec<u8>, String)> {
68+
let mut builder = SpirvBuilder::new(&self.path, &self.target)
4669
.print_metadata(spirv_builder::MetadataPrintout::None)
4770
.release(true)
4871
.multimodule(false)
4972
.shader_panic_strategy(spirv_builder::ShaderPanicStrategy::SilentExit)
5073
.preserve_bindings(true);
74+
75+
for capability in &self.capabilities {
76+
builder = builder.capability(*capability);
77+
}
78+
5179
let artifact = builder.build().context("SpirvBuilder::build() failed")?;
5280

5381
if artifact.entry_points.len() != 1 {
@@ -66,6 +94,17 @@ impl ComputeShader for RustComputeShader {
6694
}
6795
};
6896

97+
Ok((shader_bytes, entry_point))
98+
}
99+
}
100+
101+
impl WgpuShader for RustComputeShader {
102+
fn create_module(
103+
&self,
104+
device: &wgpu::Device,
105+
) -> anyhow::Result<(wgpu::ShaderModule, Option<String>)> {
106+
let (shader_bytes, entry_point) = self.spirv_bytes()?;
107+
69108
if shader_bytes.len() % 4 != 0 {
70109
anyhow::bail!("SPIR-V binary length is not a multiple of 4");
71110
}
@@ -93,7 +132,7 @@ impl WgslComputeShader {
93132
}
94133
}
95134

96-
impl ComputeShader for WgslComputeShader {
135+
impl WgpuShader for WgslComputeShader {
97136
fn create_module(
98137
&self,
99138
device: &wgpu::Device,
@@ -133,7 +172,7 @@ pub struct WgpuComputeTestPushConstants<S> {
133172

134173
impl<S> WgpuComputeTest<S>
135174
where
136-
S: ComputeShader,
175+
S: WgpuShader,
137176
{
138177
pub fn new(shader: S, dispatch: [u32; 3], output_bytes: u64) -> Self {
139178
Self {
@@ -544,7 +583,7 @@ impl Default for RustComputeShader {
544583

545584
impl<S> WgpuComputeTestMultiBuffer<S>
546585
where
547-
S: ComputeShader,
586+
S: WgpuShader,
548587
{
549588
pub fn new(shader: S, dispatch: [u32; 3], buffers: Vec<BufferConfig>) -> Self {
550589
Self {
@@ -714,7 +753,7 @@ where
714753

715754
impl<S> WgpuComputeTestPushConstants<S>
716755
where
717-
S: ComputeShader,
756+
S: WgpuShader,
718757
{
719758
pub fn new(
720759
shader: S,

tests/difftests/tests/arch/workgroup_memory/workgroup_memory-ash/src/main.rs

Lines changed: 7 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -17,37 +17,10 @@ fn main() {
1717
// Run the actual test on other platforms
1818
#[cfg(not(target_os = "macos"))]
1919
{
20-
use difftest::scaffold::compute::{AshBackend, BufferConfig, BufferUsage, ComputeTest};
21-
use difftest::spirv_builder::{
22-
Capability, MetadataPrintout, ModuleResult, ShaderPanicStrategy, SpirvBuilder,
23-
};
24-
use std::fs;
25-
26-
// Build the Rust shader to SPIR-V
27-
let builder = SpirvBuilder::new(".", "spirv-unknown-vulkan1.2")
28-
.print_metadata(MetadataPrintout::None)
29-
.release(true)
30-
.multimodule(false)
31-
.shader_panic_strategy(ShaderPanicStrategy::SilentExit)
32-
.preserve_bindings(true)
33-
.capability(Capability::VulkanMemoryModel);
34-
35-
let artifact = builder.build().expect("Failed to build SPIR-V");
36-
37-
if artifact.entry_points.len() != 1 {
38-
panic!(
39-
"Expected exactly one entry point, found {}",
40-
artifact.entry_points.len()
41-
);
42-
}
43-
let entry_point = artifact.entry_points.into_iter().next().unwrap();
44-
45-
let spirv_bytes = match artifact.module {
46-
ModuleResult::SingleModule(path) => {
47-
fs::read(&path).expect("Failed to read SPIR-V file")
48-
}
49-
ModuleResult::MultiModule(_) => panic!("Unexpected multi-module result"),
20+
use difftest::scaffold::compute::{
21+
AshBackend, BufferConfig, BufferUsage, ComputeShaderTest, RustComputeShader,
5022
};
23+
use difftest::spirv_builder::Capability;
5124

5225
// Initialize input buffer with values to sum
5326
let input_data: Vec<u32> = (1..=64).collect();
@@ -66,9 +39,10 @@ fn main() {
6639
},
6740
];
6841

69-
let test = ComputeTest::<AshBackend>::new(
70-
spirv_bytes,
71-
entry_point,
42+
let shader = RustComputeShader::with_target(".", "spirv-unknown-vulkan1.2")
43+
.with_capability(Capability::VulkanMemoryModel);
44+
let test = ComputeShaderTest::<AshBackend, _>::new(
45+
shader,
7246
[1, 1, 1], // Single workgroup with 64 threads
7347
buffers,
7448
)

0 commit comments

Comments
 (0)