From 263667deaac8db9c2a1e2d78c7fdcfefab081c56 Mon Sep 17 00:00:00 2001 From: Christian Legnitto Date: Fri, 4 Jul 2025 11:15:51 +0200 Subject: [PATCH 01/24] Add 7 new difftests and push constants support New tests: - bitwise_ops: bit manipulation operations - trig_ops: trigonometric functions - control_flow_complex: nested loops and complex control flow - vector_swizzle: vector component access and swizzling - memory_barriers: workgroup memory synchronization - vector_extract_insert: dynamic vector element access - push_constants: push constants in compute shaders Infrastructure changes: - Add WgpuComputeTestPushConstants for push constant support - Enable PUSH_CONSTANTS feature in wgpu when needed - Register all existing unregistered tests in workspace --- .../difftests/lib/src/scaffold/compute/mod.rs | 5 +- .../lib/src/scaffold/compute/wgpu.rs | 480 +++++++++++++++++- tests/difftests/tests/Cargo.lock | 228 +++++++++ tests/difftests/tests/Cargo.toml | 29 ++ .../atomic_ops/atomic_ops-rust/Cargo.toml | 20 + .../atomic_ops/atomic_ops-rust/src/lib.rs | 43 ++ .../atomic_ops/atomic_ops-rust/src/main.rs | 37 ++ .../atomic_ops/atomic_ops-wgsl/Cargo.toml | 10 + .../atomic_ops/atomic_ops-wgsl/shader.wgsl | 34 ++ .../atomic_ops/atomic_ops-wgsl/src/main.rs | 33 ++ .../memory_barriers-rust/Cargo.toml | 17 + .../memory_barriers-rust/src/lib.rs | 66 +++ .../memory_barriers-rust/src/main.rs | 27 + .../memory_barriers-wgsl/Cargo.toml | 9 + .../memory_barriers-wgsl/shader.wgsl | 60 +++ .../memory_barriers-wgsl/src/main.rs | 27 + tests/difftests/tests/arch/mod.rs | 3 + .../push_constants-rust/Cargo.toml | 19 + .../push_constants-rust/src/lib.rs | 83 +++ .../push_constants-rust/src/main.rs | 53 ++ .../push_constants-wgsl/Cargo.toml | 10 + .../push_constants-wgsl/shader.wgsl | 79 +++ .../push_constants-wgsl/src/main.rs | 53 ++ .../vector_extract_insert-rust/Cargo.toml | 18 + .../vector_extract_insert-rust/src/lib.rs | 78 +++ .../vector_extract_insert-rust/src/main.rs | 16 + .../vector_extract_insert-wgsl/Cargo.toml | 9 + .../vector_extract_insert-wgsl/shader.wgsl | 72 +++ .../vector_extract_insert-wgsl/src/main.rs | 16 + .../workgroup_memory-rust/Cargo.toml | 20 + .../workgroup_memory-rust/src/lib.rs | 73 +++ .../workgroup_memory-rust/src/main.rs | 37 ++ .../workgroup_memory-wgsl/Cargo.toml | 10 + .../workgroup_memory-wgsl/shader.wgsl | 59 +++ .../workgroup_memory-wgsl/src/main.rs | 33 ++ .../control_flow/control_flow-rust/Cargo.toml | 20 + .../control_flow/control_flow-rust/src/lib.rs | 103 ++++ .../control_flow-rust/src/main.rs | 37 ++ .../control_flow/control_flow-wgsl/Cargo.toml | 10 + .../control_flow-wgsl/shader.wgsl | 115 +++++ .../control_flow-wgsl/src/main.rs | 33 ++ .../control_flow_complex-rust/Cargo.toml | 17 + .../control_flow_complex-rust/src/lib.rs | 111 ++++ .../control_flow_complex-rust/src/main.rs | 15 + .../control_flow_complex-wgsl/Cargo.toml | 9 + .../control_flow_complex-wgsl/shader.wgsl | 120 +++++ .../control_flow_complex-wgsl/src/main.rs | 15 + .../bitwise_ops/bitwise_ops-rust/Cargo.toml | 17 + .../bitwise_ops/bitwise_ops-rust/src/lib.rs | 38 ++ .../bitwise_ops/bitwise_ops-rust/src/main.rs | 16 + .../bitwise_ops/bitwise_ops-wgsl/Cargo.toml | 9 + .../bitwise_ops/bitwise_ops-wgsl/shader.wgsl | 41 ++ .../bitwise_ops/bitwise_ops-wgsl/src/main.rs | 16 + .../ops/math_ops/math_ops-rust/Cargo.toml | 20 + .../ops/math_ops/math_ops-rust/src/lib.rs | 55 ++ .../ops/math_ops/math_ops-rust/src/main.rs | 58 +++ .../ops/math_ops/math_ops-wgsl/Cargo.toml | 10 + .../ops/math_ops/math_ops-wgsl/shader.wgsl | 51 ++ .../ops/math_ops/math_ops-wgsl/src/main.rs | 54 ++ .../ops/matrix_ops/matrix_ops-rust/Cargo.toml | 20 + .../ops/matrix_ops/matrix_ops-rust/src/lib.rs | 156 ++++++ .../matrix_ops/matrix_ops-rust/src/main.rs | 58 +++ .../ops/matrix_ops/matrix_ops-wgsl/Cargo.toml | 10 + .../matrix_ops/matrix_ops-wgsl/shader.wgsl | 177 +++++++ .../matrix_ops/matrix_ops-wgsl/src/main.rs | 54 ++ .../tests/lang/core/ops/matrix_ops/mod.rs | 16 + .../ops/trig_ops/trig_ops-rust/Cargo.toml | 17 + .../ops/trig_ops/trig_ops-rust/src/lib.rs | 51 ++ .../ops/trig_ops/trig_ops-rust/src/main.rs | 15 + .../ops/trig_ops/trig_ops-wgsl/Cargo.toml | 9 + .../ops/trig_ops/trig_ops-wgsl/shader.wgsl | 47 ++ .../ops/trig_ops/trig_ops-wgsl/src/main.rs | 15 + .../ops/vector_ops/vector_ops-rust/Cargo.toml | 20 + .../ops/vector_ops/vector_ops-rust/src/lib.rs | 141 +++++ .../vector_ops/vector_ops-rust/src/main.rs | 58 +++ .../ops/vector_ops/vector_ops-wgsl/Cargo.toml | 10 + .../vector_ops/vector_ops-wgsl/shader.wgsl | 136 +++++ .../vector_ops/vector_ops-wgsl/src/main.rs | 54 ++ .../vector_swizzle-rust/Cargo.toml | 18 + .../vector_swizzle-rust/src/lib.rs | 71 +++ .../vector_swizzle-rust/src/main.rs | 15 + .../vector_swizzle-wgsl/Cargo.toml | 9 + .../vector_swizzle-wgsl/shader.wgsl | 64 +++ .../vector_swizzle-wgsl/src/main.rs | 15 + .../array_access/array_access-rust/Cargo.toml | 20 + .../array_access/array_access-rust/src/lib.rs | 48 ++ .../array_access-rust/src/main.rs | 46 ++ .../array_access/array_access-wgsl/Cargo.toml | 10 + .../array_access-wgsl/shader.wgsl | 47 ++ .../array_access-wgsl/src/main.rs | 42 ++ 90 files changed, 4321 insertions(+), 4 deletions(-) create mode 100644 tests/difftests/tests/arch/atomic_ops/atomic_ops-rust/Cargo.toml create mode 100644 tests/difftests/tests/arch/atomic_ops/atomic_ops-rust/src/lib.rs create mode 100644 tests/difftests/tests/arch/atomic_ops/atomic_ops-rust/src/main.rs create mode 100644 tests/difftests/tests/arch/atomic_ops/atomic_ops-wgsl/Cargo.toml create mode 100644 tests/difftests/tests/arch/atomic_ops/atomic_ops-wgsl/shader.wgsl create mode 100644 tests/difftests/tests/arch/atomic_ops/atomic_ops-wgsl/src/main.rs create mode 100644 tests/difftests/tests/arch/memory_barriers/memory_barriers-rust/Cargo.toml create mode 100644 tests/difftests/tests/arch/memory_barriers/memory_barriers-rust/src/lib.rs create mode 100644 tests/difftests/tests/arch/memory_barriers/memory_barriers-rust/src/main.rs create mode 100644 tests/difftests/tests/arch/memory_barriers/memory_barriers-wgsl/Cargo.toml create mode 100644 tests/difftests/tests/arch/memory_barriers/memory_barriers-wgsl/shader.wgsl create mode 100644 tests/difftests/tests/arch/memory_barriers/memory_barriers-wgsl/src/main.rs create mode 100644 tests/difftests/tests/arch/mod.rs create mode 100644 tests/difftests/tests/arch/push_constants/push_constants-rust/Cargo.toml create mode 100644 tests/difftests/tests/arch/push_constants/push_constants-rust/src/lib.rs create mode 100644 tests/difftests/tests/arch/push_constants/push_constants-rust/src/main.rs create mode 100644 tests/difftests/tests/arch/push_constants/push_constants-wgsl/Cargo.toml create mode 100644 tests/difftests/tests/arch/push_constants/push_constants-wgsl/shader.wgsl create mode 100644 tests/difftests/tests/arch/push_constants/push_constants-wgsl/src/main.rs create mode 100644 tests/difftests/tests/arch/vector_extract_insert/vector_extract_insert-rust/Cargo.toml create mode 100644 tests/difftests/tests/arch/vector_extract_insert/vector_extract_insert-rust/src/lib.rs create mode 100644 tests/difftests/tests/arch/vector_extract_insert/vector_extract_insert-rust/src/main.rs create mode 100644 tests/difftests/tests/arch/vector_extract_insert/vector_extract_insert-wgsl/Cargo.toml create mode 100644 tests/difftests/tests/arch/vector_extract_insert/vector_extract_insert-wgsl/shader.wgsl create mode 100644 tests/difftests/tests/arch/vector_extract_insert/vector_extract_insert-wgsl/src/main.rs create mode 100644 tests/difftests/tests/arch/workgroup_memory/workgroup_memory-rust/Cargo.toml create mode 100644 tests/difftests/tests/arch/workgroup_memory/workgroup_memory-rust/src/lib.rs create mode 100644 tests/difftests/tests/arch/workgroup_memory/workgroup_memory-rust/src/main.rs create mode 100644 tests/difftests/tests/arch/workgroup_memory/workgroup_memory-wgsl/Cargo.toml create mode 100644 tests/difftests/tests/arch/workgroup_memory/workgroup_memory-wgsl/shader.wgsl create mode 100644 tests/difftests/tests/arch/workgroup_memory/workgroup_memory-wgsl/src/main.rs create mode 100644 tests/difftests/tests/lang/control_flow/control_flow-rust/Cargo.toml create mode 100644 tests/difftests/tests/lang/control_flow/control_flow-rust/src/lib.rs create mode 100644 tests/difftests/tests/lang/control_flow/control_flow-rust/src/main.rs create mode 100644 tests/difftests/tests/lang/control_flow/control_flow-wgsl/Cargo.toml create mode 100644 tests/difftests/tests/lang/control_flow/control_flow-wgsl/shader.wgsl create mode 100644 tests/difftests/tests/lang/control_flow/control_flow-wgsl/src/main.rs create mode 100644 tests/difftests/tests/lang/control_flow_complex/control_flow_complex-rust/Cargo.toml create mode 100644 tests/difftests/tests/lang/control_flow_complex/control_flow_complex-rust/src/lib.rs create mode 100644 tests/difftests/tests/lang/control_flow_complex/control_flow_complex-rust/src/main.rs create mode 100644 tests/difftests/tests/lang/control_flow_complex/control_flow_complex-wgsl/Cargo.toml create mode 100644 tests/difftests/tests/lang/control_flow_complex/control_flow_complex-wgsl/shader.wgsl create mode 100644 tests/difftests/tests/lang/control_flow_complex/control_flow_complex-wgsl/src/main.rs create mode 100644 tests/difftests/tests/lang/core/ops/bitwise_ops/bitwise_ops-rust/Cargo.toml create mode 100644 tests/difftests/tests/lang/core/ops/bitwise_ops/bitwise_ops-rust/src/lib.rs create mode 100644 tests/difftests/tests/lang/core/ops/bitwise_ops/bitwise_ops-rust/src/main.rs create mode 100644 tests/difftests/tests/lang/core/ops/bitwise_ops/bitwise_ops-wgsl/Cargo.toml create mode 100644 tests/difftests/tests/lang/core/ops/bitwise_ops/bitwise_ops-wgsl/shader.wgsl create mode 100644 tests/difftests/tests/lang/core/ops/bitwise_ops/bitwise_ops-wgsl/src/main.rs create mode 100644 tests/difftests/tests/lang/core/ops/math_ops/math_ops-rust/Cargo.toml create mode 100644 tests/difftests/tests/lang/core/ops/math_ops/math_ops-rust/src/lib.rs create mode 100644 tests/difftests/tests/lang/core/ops/math_ops/math_ops-rust/src/main.rs create mode 100644 tests/difftests/tests/lang/core/ops/math_ops/math_ops-wgsl/Cargo.toml create mode 100644 tests/difftests/tests/lang/core/ops/math_ops/math_ops-wgsl/shader.wgsl create mode 100644 tests/difftests/tests/lang/core/ops/math_ops/math_ops-wgsl/src/main.rs create mode 100644 tests/difftests/tests/lang/core/ops/matrix_ops/matrix_ops-rust/Cargo.toml create mode 100644 tests/difftests/tests/lang/core/ops/matrix_ops/matrix_ops-rust/src/lib.rs create mode 100644 tests/difftests/tests/lang/core/ops/matrix_ops/matrix_ops-rust/src/main.rs create mode 100644 tests/difftests/tests/lang/core/ops/matrix_ops/matrix_ops-wgsl/Cargo.toml create mode 100644 tests/difftests/tests/lang/core/ops/matrix_ops/matrix_ops-wgsl/shader.wgsl create mode 100644 tests/difftests/tests/lang/core/ops/matrix_ops/matrix_ops-wgsl/src/main.rs create mode 100644 tests/difftests/tests/lang/core/ops/matrix_ops/mod.rs create mode 100644 tests/difftests/tests/lang/core/ops/trig_ops/trig_ops-rust/Cargo.toml create mode 100644 tests/difftests/tests/lang/core/ops/trig_ops/trig_ops-rust/src/lib.rs create mode 100644 tests/difftests/tests/lang/core/ops/trig_ops/trig_ops-rust/src/main.rs create mode 100644 tests/difftests/tests/lang/core/ops/trig_ops/trig_ops-wgsl/Cargo.toml create mode 100644 tests/difftests/tests/lang/core/ops/trig_ops/trig_ops-wgsl/shader.wgsl create mode 100644 tests/difftests/tests/lang/core/ops/trig_ops/trig_ops-wgsl/src/main.rs create mode 100644 tests/difftests/tests/lang/core/ops/vector_ops/vector_ops-rust/Cargo.toml create mode 100644 tests/difftests/tests/lang/core/ops/vector_ops/vector_ops-rust/src/lib.rs create mode 100644 tests/difftests/tests/lang/core/ops/vector_ops/vector_ops-rust/src/main.rs create mode 100644 tests/difftests/tests/lang/core/ops/vector_ops/vector_ops-wgsl/Cargo.toml create mode 100644 tests/difftests/tests/lang/core/ops/vector_ops/vector_ops-wgsl/shader.wgsl create mode 100644 tests/difftests/tests/lang/core/ops/vector_ops/vector_ops-wgsl/src/main.rs create mode 100644 tests/difftests/tests/lang/core/ops/vector_swizzle/vector_swizzle-rust/Cargo.toml create mode 100644 tests/difftests/tests/lang/core/ops/vector_swizzle/vector_swizzle-rust/src/lib.rs create mode 100644 tests/difftests/tests/lang/core/ops/vector_swizzle/vector_swizzle-rust/src/main.rs create mode 100644 tests/difftests/tests/lang/core/ops/vector_swizzle/vector_swizzle-wgsl/Cargo.toml create mode 100644 tests/difftests/tests/lang/core/ops/vector_swizzle/vector_swizzle-wgsl/shader.wgsl create mode 100644 tests/difftests/tests/lang/core/ops/vector_swizzle/vector_swizzle-wgsl/src/main.rs create mode 100644 tests/difftests/tests/storage_class/array_access/array_access-rust/Cargo.toml create mode 100644 tests/difftests/tests/storage_class/array_access/array_access-rust/src/lib.rs create mode 100644 tests/difftests/tests/storage_class/array_access/array_access-rust/src/main.rs create mode 100644 tests/difftests/tests/storage_class/array_access/array_access-wgsl/Cargo.toml create mode 100644 tests/difftests/tests/storage_class/array_access/array_access-wgsl/shader.wgsl create mode 100644 tests/difftests/tests/storage_class/array_access/array_access-wgsl/src/main.rs diff --git a/tests/difftests/lib/src/scaffold/compute/mod.rs b/tests/difftests/lib/src/scaffold/compute/mod.rs index 0478faecfd..47dc6cafbb 100644 --- a/tests/difftests/lib/src/scaffold/compute/mod.rs +++ b/tests/difftests/lib/src/scaffold/compute/mod.rs @@ -1,2 +1,5 @@ mod wgpu; -pub use wgpu::{RustComputeShader, WgpuComputeTest, WgslComputeShader}; +pub use wgpu::{ + BufferConfig, BufferUsage, RustComputeShader, WgpuComputeTest, WgpuComputeTestMultiBuffer, + WgpuComputeTestPushConstants, WgslComputeShader, +}; diff --git a/tests/difftests/lib/src/scaffold/compute/wgpu.rs b/tests/difftests/lib/src/scaffold/compute/wgpu.rs index 08c32f0937..20b3491a44 100644 --- a/tests/difftests/lib/src/scaffold/compute/wgpu.rs +++ b/tests/difftests/lib/src/scaffold/compute/wgpu.rs @@ -109,6 +109,36 @@ pub struct WgpuComputeTest { output_bytes: u64, } +/// More flexible compute test that supports multiple buffers. +pub struct WgpuComputeTestMultiBuffer { + shader: S, + dispatch: [u32; 3], + buffers: Vec, +} + +/// Compute test that supports push constants. +pub struct WgpuComputeTestPushConstants { + shader: S, + dispatch: [u32; 3], + buffers: Vec, + push_constants_size: u32, + push_constants_data: Vec, +} + +#[derive(Clone)] +pub struct BufferConfig { + pub size: u64, + pub usage: BufferUsage, + pub initial_data: Option>, +} + +#[derive(Clone, Copy, PartialEq)] +pub enum BufferUsage { + Storage, + StorageReadOnly, + Uniform, +} + impl WgpuComputeTest where S: ComputeShader, @@ -122,6 +152,10 @@ where } fn init() -> anyhow::Result<(wgpu::Device, wgpu::Queue)> { + Self::init_with_features(wgpu::Features::empty()) + } + + fn init_with_features(features: wgpu::Features) -> anyhow::Result<(wgpu::Device, wgpu::Queue)> { block_on(async { let instance = wgpu::Instance::new(&wgpu::InstanceDescriptor { #[cfg(target_os = "linux")] @@ -143,10 +177,13 @@ where .request_device(&wgpu::DeviceDescriptor { label: Some("wgpu Device"), #[cfg(target_os = "linux")] - required_features: wgpu::Features::SPIRV_SHADER_PASSTHROUGH, + required_features: wgpu::Features::SPIRV_SHADER_PASSTHROUGH | features, #[cfg(not(target_os = "linux"))] - required_features: wgpu::Features::empty(), - required_limits: wgpu::Limits::default(), + required_features: features, + required_limits: wgpu::Limits { + max_push_constant_size: 128, + ..wgpu::Limits::default() + }, memory_hints: Default::default(), trace: Default::default(), }) @@ -335,3 +372,440 @@ impl Default for RustComputeShader { Self::new(PathBuf::from(manifest_dir)) } } + +impl WgpuComputeTestMultiBuffer +where + S: ComputeShader, +{ + pub fn new(shader: S, dispatch: [u32; 3], buffers: Vec) -> Self { + Self { + shader, + dispatch, + buffers, + } + } + + pub fn new_with_sizes(shader: S, dispatch: [u32; 3], sizes: &[u64]) -> Self { + let buffers = sizes + .iter() + .map(|&size| BufferConfig { + size, + usage: BufferUsage::Storage, + initial_data: None, + }) + .collect(); + Self::new(shader, dispatch, buffers) + } + + pub fn run(self) -> anyhow::Result>> { + let (device, queue) = WgpuComputeTest::::init()?; + let (module, entrypoint) = self.shader.create_module(&device)?; + let pipeline = device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor { + label: Some("Compute Pipeline"), + layout: None, + module: &module, + entry_point: entrypoint.as_deref(), + compilation_options: PipelineCompilationOptions::default(), + cache: None, + }); + + // Create buffers. + let mut gpu_buffers = Vec::new(); + + for (i, buffer_config) in self.buffers.iter().enumerate() { + let usage = match buffer_config.usage { + BufferUsage::Storage => wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_SRC, + BufferUsage::StorageReadOnly => { + wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_SRC + } + BufferUsage::Uniform => wgpu::BufferUsages::UNIFORM, + }; + + let buffer = if let Some(initial_data) = &buffer_config.initial_data { + device.create_buffer_init(&wgpu::util::BufferInitDescriptor { + label: Some(&format!("Buffer {}", i)), + contents: initial_data, + usage, + }) + } else { + let buffer = device.create_buffer(&wgpu::BufferDescriptor { + label: Some(&format!("Buffer {}", i)), + size: buffer_config.size, + usage, + mapped_at_creation: true, + }); + { + // Zero the buffer. + let initial_data = vec![0u8; buffer_config.size as usize]; + let mut mapping = buffer.slice(..).get_mapped_range_mut(); + mapping.copy_from_slice(&initial_data); + } + buffer.unmap(); + buffer + }; + + gpu_buffers.push(buffer); + } + + // Create bind entries after all buffers are created + let bind_entries: Vec<_> = gpu_buffers + .iter() + .enumerate() + .map(|(i, buffer)| wgpu::BindGroupEntry { + binding: i as u32, + resource: buffer.as_entire_binding(), + }) + .collect(); + + let bind_group = device.create_bind_group(&wgpu::BindGroupDescriptor { + layout: &pipeline.get_bind_group_layout(0), + entries: &bind_entries, + label: Some("Compute Bind Group"), + }); + + let mut encoder = device.create_command_encoder(&wgpu::CommandEncoderDescriptor { + label: Some("Compute Encoder"), + }); + { + let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor { + label: Some("Compute Pass"), + timestamp_writes: Default::default(), + }); + pass.set_pipeline(&pipeline); + pass.set_bind_group(0, &bind_group, &[]); + pass.dispatch_workgroups(self.dispatch[0], self.dispatch[1], self.dispatch[2]); + } + + // Create staging buffers and copy results. + let mut staging_buffers = Vec::new(); + for (i, buffer_config) in self.buffers.iter().enumerate() { + if matches!( + buffer_config.usage, + BufferUsage::Storage | BufferUsage::StorageReadOnly + ) { + let staging_buffer = device.create_buffer(&wgpu::BufferDescriptor { + label: Some(&format!("Staging Buffer {}", i)), + size: buffer_config.size, + usage: wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST, + mapped_at_creation: false, + }); + encoder.copy_buffer_to_buffer( + &gpu_buffers[i], + 0, + &staging_buffer, + 0, + buffer_config.size, + ); + staging_buffers.push(Some(staging_buffer)); + } else { + staging_buffers.push(None); + } + } + + queue.submit(Some(encoder.finish())); + + // Read back results. + let mut results = Vec::new(); + for staging_buffer in staging_buffers.into_iter() { + if let Some(buffer) = staging_buffer { + let buffer_slice = buffer.slice(..); + let (sender, receiver) = futures::channel::oneshot::channel(); + buffer_slice.map_async(wgpu::MapMode::Read, move |res| { + let _ = sender.send(res); + }); + device.poll(wgpu::PollType::Wait)?; + block_on(receiver) + .context("mapping canceled")? + .context("mapping failed")?; + let data = buffer_slice.get_mapped_range().to_vec(); + buffer.unmap(); + results.push(data); + } else { + results.push(Vec::new()); + } + } + + Ok(results) + } + + pub fn run_test(self, config: &Config) -> anyhow::Result<()> { + let buffers = self.buffers.clone(); + let outputs = self.run()?; + // Write the first storage buffer output to the file. + for (output, buffer_config) in outputs.iter().zip(&buffers) { + if matches!(buffer_config.usage, BufferUsage::Storage) && !output.is_empty() { + let mut f = File::create(&config.output_path)?; + f.write_all(output)?; + return Ok(()); + } + } + anyhow::bail!("No storage buffer output found") + } +} + +impl WgpuComputeTestPushConstants +where + S: ComputeShader, +{ + pub fn new( + shader: S, + dispatch: [u32; 3], + buffers: Vec, + push_constants_size: u32, + push_constants_data: Vec, + ) -> Self { + Self { + shader, + dispatch, + buffers, + push_constants_size, + push_constants_data, + } + } + + pub fn run(self) -> anyhow::Result>> { + let (device, queue) = + WgpuComputeTest::::init_with_features(wgpu::Features::PUSH_CONSTANTS)?; + let (module, entrypoint) = self.shader.create_module(&device)?; + + // Create pipeline layout with push constants + let bind_group_layout = device.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor { + label: Some("Bind Group Layout"), + entries: &self + .buffers + .iter() + .enumerate() + .map(|(i, buffer_config)| wgpu::BindGroupLayoutEntry { + binding: i as u32, + visibility: wgpu::ShaderStages::COMPUTE, + ty: match buffer_config.usage { + BufferUsage::Storage => wgpu::BindingType::Buffer { + ty: wgpu::BufferBindingType::Storage { read_only: false }, + has_dynamic_offset: false, + min_binding_size: None, + }, + BufferUsage::StorageReadOnly => wgpu::BindingType::Buffer { + ty: wgpu::BufferBindingType::Storage { read_only: true }, + has_dynamic_offset: false, + min_binding_size: None, + }, + BufferUsage::Uniform => wgpu::BindingType::Buffer { + ty: wgpu::BufferBindingType::Uniform, + has_dynamic_offset: false, + min_binding_size: None, + }, + }, + count: None, + }) + .collect::>(), + }); + + let pipeline_layout = device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor { + label: Some("Pipeline Layout"), + bind_group_layouts: &[&bind_group_layout], + push_constant_ranges: &[wgpu::PushConstantRange { + stages: wgpu::ShaderStages::COMPUTE, + range: 0..self.push_constants_size, + }], + }); + + let pipeline = device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor { + label: Some("Compute Pipeline"), + layout: Some(&pipeline_layout), + module: &module, + entry_point: entrypoint.as_deref(), + compilation_options: PipelineCompilationOptions::default(), + cache: None, + }); + + // Create buffers. + let mut gpu_buffers = Vec::new(); + + for (i, buffer_config) in self.buffers.iter().enumerate() { + let usage = match buffer_config.usage { + BufferUsage::Storage => wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_SRC, + BufferUsage::StorageReadOnly => { + wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_SRC + } + BufferUsage::Uniform => wgpu::BufferUsages::UNIFORM, + }; + + let buffer = if let Some(initial_data) = &buffer_config.initial_data { + device.create_buffer_init(&wgpu::util::BufferInitDescriptor { + label: Some(&format!("Buffer {}", i)), + contents: initial_data, + usage, + }) + } else { + let buffer = device.create_buffer(&wgpu::BufferDescriptor { + label: Some(&format!("Buffer {}", i)), + size: buffer_config.size, + usage, + mapped_at_creation: true, + }); + { + // Zero the buffer. + let initial_data = vec![0u8; buffer_config.size as usize]; + let mut mapping = buffer.slice(..).get_mapped_range_mut(); + mapping.copy_from_slice(&initial_data); + } + buffer.unmap(); + buffer + }; + + gpu_buffers.push(buffer); + } + + // Create bind entries after all buffers are created + let bind_entries: Vec<_> = gpu_buffers + .iter() + .enumerate() + .map(|(i, buffer)| wgpu::BindGroupEntry { + binding: i as u32, + resource: buffer.as_entire_binding(), + }) + .collect(); + + let bind_group = device.create_bind_group(&wgpu::BindGroupDescriptor { + layout: &bind_group_layout, + entries: &bind_entries, + label: Some("Compute Bind Group"), + }); + + let mut encoder = device.create_command_encoder(&wgpu::CommandEncoderDescriptor { + label: Some("Compute Encoder"), + }); + { + let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor { + label: Some("Compute Pass"), + timestamp_writes: Default::default(), + }); + pass.set_pipeline(&pipeline); + pass.set_bind_group(0, &bind_group, &[]); + pass.set_push_constants(0, &self.push_constants_data); + pass.dispatch_workgroups(self.dispatch[0], self.dispatch[1], self.dispatch[2]); + } + + // Create staging buffers and copy results. + let mut staging_buffers = Vec::new(); + for (i, buffer_config) in self.buffers.iter().enumerate() { + if matches!( + buffer_config.usage, + BufferUsage::Storage | BufferUsage::StorageReadOnly + ) { + let staging_buffer = device.create_buffer(&wgpu::BufferDescriptor { + label: Some(&format!("Staging Buffer {}", i)), + size: buffer_config.size, + usage: wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST, + mapped_at_creation: false, + }); + encoder.copy_buffer_to_buffer( + &gpu_buffers[i], + 0, + &staging_buffer, + 0, + buffer_config.size, + ); + staging_buffers.push(Some(staging_buffer)); + } else { + staging_buffers.push(None); + } + } + + queue.submit(Some(encoder.finish())); + + // Read back results. + let mut results = Vec::new(); + for staging_buffer in staging_buffers.into_iter() { + if let Some(buffer) = staging_buffer { + let buffer_slice = buffer.slice(..); + let (sender, receiver) = futures::channel::oneshot::channel(); + buffer_slice.map_async(wgpu::MapMode::Read, move |res| { + let _ = sender.send(res); + }); + device.poll(wgpu::PollType::Wait)?; + block_on(receiver) + .context("mapping canceled")? + .context("mapping failed")?; + let data = buffer_slice.get_mapped_range().to_vec(); + buffer.unmap(); + results.push(data); + } else { + results.push(Vec::new()); + } + } + + Ok(results) + } + + pub fn run_test(self, config: &Config) -> anyhow::Result<()> { + let buffers = self.buffers.clone(); + let results = self.run()?; + // Write first storage buffer output to file. + for (_i, (data, buffer_config)) in results.iter().zip(&buffers).enumerate() { + if buffer_config.usage == BufferUsage::Storage && !data.is_empty() { + let mut f = File::create(&config.output_path)?; + f.write_all(data)?; + return Ok(()); + } + } + anyhow::bail!("No storage buffer output found") + } +} + +// Convenience implementation for WgpuComputeTestPushConstants +impl WgpuComputeTestPushConstants { + pub fn new_with_data( + shader: RustComputeShader, + dispatch: [u32; 3], + sizes: &[u64], + push_constant_data: &T, + ) -> Self { + let buffers = sizes + .iter() + .map(|&size| BufferConfig { + size, + usage: BufferUsage::Storage, + initial_data: None, + }) + .collect(); + let push_constants_data = bytemuck::bytes_of(push_constant_data).to_vec(); + let push_constants_size = push_constants_data.len() as u32; + + Self::new( + shader, + dispatch, + buffers, + push_constants_size, + push_constants_data, + ) + } +} + +impl WgpuComputeTestPushConstants { + pub fn new_with_data( + shader: WgslComputeShader, + dispatch: [u32; 3], + sizes: &[u64], + push_constant_data: &T, + ) -> Self { + let buffers = sizes + .iter() + .map(|&size| BufferConfig { + size, + usage: BufferUsage::Storage, + initial_data: None, + }) + .collect(); + let push_constants_data = bytemuck::bytes_of(push_constant_data).to_vec(); + let push_constants_size = push_constants_data.len() as u32; + + Self::new( + shader, + dispatch, + buffers, + push_constants_size, + push_constants_data, + ) + } +} diff --git a/tests/difftests/tests/Cargo.lock b/tests/difftests/tests/Cargo.lock index 7ad114e41e..c1cc81b10f 100644 --- a/tests/difftests/tests/Cargo.lock +++ b/tests/difftests/tests/Cargo.lock @@ -23,6 +23,23 @@ version = "1.0.98" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e16d2d3311acee920a9eb8d33b8cbc1787ce4a264e85f964c2404b969bdcd487" +[[package]] +name = "array_access-rust" +version = "0.0.0" +dependencies = [ + "bytemuck", + "difftest", + "spirv-std", +] + +[[package]] +name = "array_access-wgsl" +version = "0.0.0" +dependencies = [ + "bytemuck", + "difftest", +] + [[package]] name = "arrayvec" version = "0.7.6" @@ -38,6 +55,23 @@ dependencies = [ "libloading", ] +[[package]] +name = "atomic_ops-rust" +version = "0.0.0" +dependencies = [ + "bytemuck", + "difftest", + "spirv-std", +] + +[[package]] +name = "atomic_ops-wgsl" +version = "0.0.0" +dependencies = [ + "bytemuck", + "difftest", +] + [[package]] name = "autocfg" version = "1.4.0" @@ -74,6 +108,21 @@ dependencies = [ "serde", ] +[[package]] +name = "bitwise_ops-rust" +version = "0.0.0" +dependencies = [ + "difftest", + "spirv-std", +] + +[[package]] +name = "bitwise_ops-wgsl" +version = "0.0.0" +dependencies = [ + "difftest", +] + [[package]] name = "block" version = "0.1.6" @@ -161,6 +210,38 @@ dependencies = [ "unicode-width", ] +[[package]] +name = "control_flow-rust" +version = "0.0.0" +dependencies = [ + "bytemuck", + "difftest", + "spirv-std", +] + +[[package]] +name = "control_flow-wgsl" +version = "0.0.0" +dependencies = [ + "bytemuck", + "difftest", +] + +[[package]] +name = "control_flow_complex-rust" +version = "0.0.0" +dependencies = [ + "difftest", + "spirv-std", +] + +[[package]] +name = "control_flow_complex-wgsl" +version = "0.0.0" +dependencies = [ + "difftest", +] + [[package]] name = "core-foundation" version = "0.9.4" @@ -613,12 +694,61 @@ dependencies = [ "libc", ] +[[package]] +name = "math_ops-rust" +version = "0.0.0" +dependencies = [ + "bytemuck", + "difftest", + "spirv-std", +] + +[[package]] +name = "math_ops-wgsl" +version = "0.0.0" +dependencies = [ + "bytemuck", + "difftest", +] + +[[package]] +name = "matrix_ops-rust" +version = "0.0.0" +dependencies = [ + "bytemuck", + "difftest", + "spirv-std", +] + +[[package]] +name = "matrix_ops-wgsl" +version = "0.0.0" +dependencies = [ + "bytemuck", + "difftest", +] + [[package]] name = "memchr" version = "2.7.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "78ca9ab1a0babb1e7d5695e3530886289c18cf2f87ec19a575a0abdce112e3a3" +[[package]] +name = "memory_barriers-rust" +version = "0.0.0" +dependencies = [ + "difftest", + "spirv-std", +] + +[[package]] +name = "memory_barriers-wgsl" +version = "0.0.0" +dependencies = [ + "difftest", +] + [[package]] name = "metal" version = "0.31.0" @@ -789,6 +919,23 @@ version = "1.0.16" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "afbdc74edc00b6f6a218ca6a5364d6226a259d4b8ea1af4a0ea063f27e179f4d" +[[package]] +name = "push_constants-rust" +version = "0.0.0" +dependencies = [ + "bytemuck", + "difftest", + "spirv-std", +] + +[[package]] +name = "push_constants-wgsl" +version = "0.0.0" +dependencies = [ + "bytemuck", + "difftest", +] + [[package]] name = "quote" version = "1.0.40" @@ -1119,6 +1266,21 @@ dependencies = [ "syn", ] +[[package]] +name = "trig_ops-rust" +version = "0.0.0" +dependencies = [ + "difftest", + "spirv-std", +] + +[[package]] +name = "trig_ops-wgsl" +version = "0.0.0" +dependencies = [ + "difftest", +] + [[package]] name = "unicode-ident" version = "1.0.18" @@ -1131,6 +1293,55 @@ version = "0.1.14" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7dd6e30e90baa6f72411720665d41d89b9a3d039dc45b8faea1ddd07f617f6af" +[[package]] +name = "vector_extract_insert-rust" +version = "0.0.0" +dependencies = [ + "difftest", + "glam", + "spirv-std", +] + +[[package]] +name = "vector_extract_insert-wgsl" +version = "0.0.0" +dependencies = [ + "difftest", +] + +[[package]] +name = "vector_ops-rust" +version = "0.0.0" +dependencies = [ + "bytemuck", + "difftest", + "spirv-std", +] + +[[package]] +name = "vector_ops-wgsl" +version = "0.0.0" +dependencies = [ + "bytemuck", + "difftest", +] + +[[package]] +name = "vector_swizzle-rust" +version = "0.0.0" +dependencies = [ + "difftest", + "glam", + "spirv-std", +] + +[[package]] +name = "vector_swizzle-wgsl" +version = "0.0.0" +dependencies = [ + "difftest", +] + [[package]] name = "version_check" version = "0.9.5" @@ -1530,6 +1741,23 @@ dependencies = [ "bitflags 2.9.0", ] +[[package]] +name = "workgroup_memory-rust" +version = "0.0.0" +dependencies = [ + "bytemuck", + "difftest", + "spirv-std", +] + +[[package]] +name = "workgroup_memory-wgsl" +version = "0.0.0" +dependencies = [ + "bytemuck", + "difftest", +] + [[package]] name = "xml-rs" version = "0.8.25" diff --git a/tests/difftests/tests/Cargo.toml b/tests/difftests/tests/Cargo.toml index 34f08f9250..d5f4b492da 100644 --- a/tests/difftests/tests/Cargo.toml +++ b/tests/difftests/tests/Cargo.toml @@ -3,6 +3,34 @@ resolver = "2" members = [ "simple-compute/simple-compute-rust", "simple-compute/simple-compute-wgsl", + "arch/atomic_ops/atomic_ops-rust", + "arch/atomic_ops/atomic_ops-wgsl", + "arch/workgroup_memory/workgroup_memory-rust", + "arch/workgroup_memory/workgroup_memory-wgsl", + "arch/memory_barriers/memory_barriers-rust", + "arch/memory_barriers/memory_barriers-wgsl", + "arch/vector_extract_insert/vector_extract_insert-rust", + "arch/vector_extract_insert/vector_extract_insert-wgsl", + "arch/push_constants/push_constants-rust", + "arch/push_constants/push_constants-wgsl", + "storage_class/array_access/array_access-rust", + "storage_class/array_access/array_access-wgsl", + "lang/control_flow/control_flow-rust", + "lang/control_flow/control_flow-wgsl", + "lang/control_flow_complex/control_flow_complex-rust", + "lang/control_flow_complex/control_flow_complex-wgsl", + "lang/core/ops/math_ops/math_ops-rust", + "lang/core/ops/math_ops/math_ops-wgsl", + "lang/core/ops/vector_ops/vector_ops-rust", + "lang/core/ops/vector_ops/vector_ops-wgsl", + "lang/core/ops/matrix_ops/matrix_ops-rust", + "lang/core/ops/matrix_ops/matrix_ops-wgsl", + "lang/core/ops/bitwise_ops/bitwise_ops-rust", + "lang/core/ops/bitwise_ops/bitwise_ops-wgsl", + "lang/core/ops/trig_ops/trig_ops-rust", + "lang/core/ops/trig_ops/trig_ops-wgsl", + "lang/core/ops/vector_swizzle/vector_swizzle-rust", + "lang/core/ops/vector_swizzle/vector_swizzle-wgsl", ] [workspace.package] @@ -24,6 +52,7 @@ difftest = { path = "../../../tests/difftests/lib" } # External dependencies that need to be mentioned more than once. num-traits = { version = "0.2.15", default-features = false } glam = { version = ">=0.22, <=0.29", default-features = false } +bytemuck = { version = "1.14", features = ["derive"] } # Enable incremental by default in release mode. [profile.release] diff --git a/tests/difftests/tests/arch/atomic_ops/atomic_ops-rust/Cargo.toml b/tests/difftests/tests/arch/atomic_ops/atomic_ops-rust/Cargo.toml new file mode 100644 index 0000000000..431fdac31a --- /dev/null +++ b/tests/difftests/tests/arch/atomic_ops/atomic_ops-rust/Cargo.toml @@ -0,0 +1,20 @@ +[package] +name = "atomic_ops-rust" +edition.workspace = true + +[lints] +workspace = true + +[lib] +crate-type = ["dylib"] + +# Common deps +[dependencies] + +# GPU deps +spirv-std.workspace = true + +# CPU deps +[target.'cfg(not(target_arch = "spirv"))'.dependencies] +difftest.workspace = true +bytemuck.workspace = true \ No newline at end of file diff --git a/tests/difftests/tests/arch/atomic_ops/atomic_ops-rust/src/lib.rs b/tests/difftests/tests/arch/atomic_ops/atomic_ops-rust/src/lib.rs new file mode 100644 index 0000000000..f80b63d41f --- /dev/null +++ b/tests/difftests/tests/arch/atomic_ops/atomic_ops-rust/src/lib.rs @@ -0,0 +1,43 @@ +#![no_std] + +use spirv_std::arch::{atomic_i_add, atomic_i_sub, atomic_u_max, atomic_u_min}; +use spirv_std::memory::{Scope, Semantics}; +use spirv_std::spirv; + +#[spirv(compute(threads(32)))] +pub fn main_cs( + #[spirv(storage_buffer, descriptor_set = 0, binding = 0)] counters: &mut [u32], + #[spirv(storage_buffer, descriptor_set = 0, binding = 1)] output: &mut [u32], + #[spirv(global_invocation_id)] global_id: spirv_std::glam::UVec3, +) { + const SCOPE: u32 = Scope::Workgroup as u32; + const SEMANTICS: u32 = Semantics::NONE.bits(); + + let tid = global_id.x; + + // All threads participate in atomic operations + // Each thread adds 1 to the first counter + unsafe { atomic_i_add::<_, SCOPE, SEMANTICS>(&mut counters[0], 1) }; + + // Each thread subtracts 1 from the second counter + unsafe { atomic_i_sub::<_, SCOPE, SEMANTICS>(&mut counters[1], 1) }; + + // Each thread tries to set minimum with their thread ID + unsafe { atomic_u_min::<_, SCOPE, SEMANTICS>(&mut counters[2], tid) }; + + // Each thread tries to set maximum with their thread ID + unsafe { atomic_u_max::<_, SCOPE, SEMANTICS>(&mut counters[3], tid) }; + + // Thread 0 stores the final values after all operations complete + if tid == 0 { + // Use atomic loads to ensure we read the final values + unsafe { + spirv_std::arch::workgroup_memory_barrier_with_group_sync(); + } + output[0] = counters[0]; // Should be initial + 32 + output[1] = counters[1]; // Should be initial - 32 + output[2] = counters[2]; // Should be min(initial, 0) + output[3] = counters[3]; // Should be max(initial, 31) + output[4] = counters[4]; // Unchanged + } +} diff --git a/tests/difftests/tests/arch/atomic_ops/atomic_ops-rust/src/main.rs b/tests/difftests/tests/arch/atomic_ops/atomic_ops-rust/src/main.rs new file mode 100644 index 0000000000..bcef4ca050 --- /dev/null +++ b/tests/difftests/tests/arch/atomic_ops/atomic_ops-rust/src/main.rs @@ -0,0 +1,37 @@ +#[cfg(not(target_arch = "spirv"))] +fn main() { + use difftest::config::Config; + use difftest::scaffold::compute::{ + BufferConfig, BufferUsage, RustComputeShader, WgpuComputeTestMultiBuffer, + }; + + let config = Config::from_path(std::env::args().nth(1).unwrap()).unwrap(); + + // Initialize counter buffer with test values + let counter_data = vec![100u32, 50, 20, 5, 0]; + let counter_bytes = bytemuck::cast_slice(&counter_data).to_vec(); + + let buffers = vec![ + BufferConfig { + size: 20, // 5 u32 values + usage: BufferUsage::Storage, + initial_data: Some(counter_bytes), + }, + BufferConfig { + size: 20, // 5 u32 values for output + usage: BufferUsage::Storage, + initial_data: None, + }, + ]; + + let test = WgpuComputeTestMultiBuffer::new( + RustComputeShader::default(), + [1, 1, 1], // Single workgroup with 32 threads + buffers, + ); + + test.run_test(&config).unwrap(); +} + +#[cfg(target_arch = "spirv")] +fn main() {} diff --git a/tests/difftests/tests/arch/atomic_ops/atomic_ops-wgsl/Cargo.toml b/tests/difftests/tests/arch/atomic_ops/atomic_ops-wgsl/Cargo.toml new file mode 100644 index 0000000000..a0ba03295a --- /dev/null +++ b/tests/difftests/tests/arch/atomic_ops/atomic_ops-wgsl/Cargo.toml @@ -0,0 +1,10 @@ +[package] +name = "atomic_ops-wgsl" +edition.workspace = true + +[lints] +workspace = true + +[dependencies] +difftest.workspace = true +bytemuck.workspace = true \ No newline at end of file diff --git a/tests/difftests/tests/arch/atomic_ops/atomic_ops-wgsl/shader.wgsl b/tests/difftests/tests/arch/atomic_ops/atomic_ops-wgsl/shader.wgsl new file mode 100644 index 0000000000..511c60fbb6 --- /dev/null +++ b/tests/difftests/tests/arch/atomic_ops/atomic_ops-wgsl/shader.wgsl @@ -0,0 +1,34 @@ +@group(0) @binding(0) +var counters: array, 5>; + +@group(0) @binding(1) +var output: array; + +@compute @workgroup_size(32, 1, 1) +fn main_cs(@builtin(global_invocation_id) global_id: vec3) { + let tid = global_id.x; + + // All threads participate in atomic operations + // Each thread adds 1 to the first counter + atomicAdd(&counters[0], 1u); + + // Each thread subtracts 1 from the second counter + atomicSub(&counters[1], 1u); + + // Each thread tries to set minimum with their thread ID + atomicMin(&counters[2], tid); + + // Each thread tries to set maximum with their thread ID + atomicMax(&counters[3], tid); + + // Thread 0 stores the final values after all operations complete + if (tid == 0u) { + // Synchronize to ensure all atomic operations are complete + workgroupBarrier(); + output[0] = atomicLoad(&counters[0]); // Should be initial + 32 + output[1] = atomicLoad(&counters[1]); // Should be initial - 32 + output[2] = atomicLoad(&counters[2]); // Should be min(initial, 0) + output[3] = atomicLoad(&counters[3]); // Should be max(initial, 31) + output[4] = atomicLoad(&counters[4]); // Unchanged + } +} \ No newline at end of file diff --git a/tests/difftests/tests/arch/atomic_ops/atomic_ops-wgsl/src/main.rs b/tests/difftests/tests/arch/atomic_ops/atomic_ops-wgsl/src/main.rs new file mode 100644 index 0000000000..2221c6ed4e --- /dev/null +++ b/tests/difftests/tests/arch/atomic_ops/atomic_ops-wgsl/src/main.rs @@ -0,0 +1,33 @@ +use difftest::config::Config; +use difftest::scaffold::compute::{ + BufferConfig, BufferUsage, WgpuComputeTestMultiBuffer, WgslComputeShader, +}; + +fn main() { + let config = Config::from_path(std::env::args().nth(1).unwrap()).unwrap(); + + // Initialize counter buffer with test values + let counter_data = vec![100u32, 50, 20, 5, 0]; + let counter_bytes = bytemuck::cast_slice(&counter_data).to_vec(); + + let buffers = vec![ + BufferConfig { + size: 20, // 5 u32 values + usage: BufferUsage::Storage, + initial_data: Some(counter_bytes), + }, + BufferConfig { + size: 20, // 5 u32 values for output + usage: BufferUsage::Storage, + initial_data: None, + }, + ]; + + let test = WgpuComputeTestMultiBuffer::new( + WgslComputeShader::default(), + [1, 1, 1], // Single workgroup with 32 threads + buffers, + ); + + test.run_test(&config).unwrap(); +} diff --git a/tests/difftests/tests/arch/memory_barriers/memory_barriers-rust/Cargo.toml b/tests/difftests/tests/arch/memory_barriers/memory_barriers-rust/Cargo.toml new file mode 100644 index 0000000000..6b7ea4ccd3 --- /dev/null +++ b/tests/difftests/tests/arch/memory_barriers/memory_barriers-rust/Cargo.toml @@ -0,0 +1,17 @@ +[package] +name = "memory_barriers-rust" +edition.workspace = true + +[lints] +workspace = true + +[lib] +crate-type = ["dylib"] + +# GPU deps +[dependencies] +spirv-std.workspace = true + +# CPU deps (for the test harness) +[target.'cfg(not(target_arch = "spirv"))'.dependencies] +difftest.workspace = true \ No newline at end of file diff --git a/tests/difftests/tests/arch/memory_barriers/memory_barriers-rust/src/lib.rs b/tests/difftests/tests/arch/memory_barriers/memory_barriers-rust/src/lib.rs new file mode 100644 index 0000000000..c1db9678d5 --- /dev/null +++ b/tests/difftests/tests/arch/memory_barriers/memory_barriers-rust/src/lib.rs @@ -0,0 +1,66 @@ +#![no_std] + +use spirv_std::arch::workgroup_memory_barrier_with_group_sync; +use spirv_std::spirv; + +#[spirv(compute(threads(64)))] +pub fn main_cs( + #[spirv(storage_buffer, descriptor_set = 0, binding = 0)] input: &[u32], + #[spirv(storage_buffer, descriptor_set = 0, binding = 1)] output: &mut [u32], + #[spirv(workgroup)] shared: &mut [u32; 64], + #[spirv(global_invocation_id)] global_id: spirv_std::glam::UVec3, + #[spirv(local_invocation_id)] local_id: spirv_std::glam::UVec3, +) { + let tid = global_id.x as usize; + let lid = local_id.x as usize; + + if tid < input.len() && tid < output.len() && lid < 64 { + // Load data into shared memory + shared[lid] = input[tid]; + + // Workgroup barrier to ensure all threads have loaded their data + unsafe { + workgroup_memory_barrier_with_group_sync(); + } + + // Perform operations on shared memory + let mut result = shared[lid]; + + // Different threads perform different operations + if lid % 4 == 0 { + // Read from neighboring thread's data + if lid + 1 < 64 { + result += shared[lid + 1]; + } + } else if lid % 4 == 1 { + // Read from previous thread's data + if lid > 0 { + result += shared[lid - 1]; + } + } else if lid % 4 == 2 { + // Sum reduction within groups of 4 + if lid + 2 < 64 { + result = shared[lid] + shared[lid + 1] + shared[lid + 2]; + } + } else { + // XOR with wrapped neighbor + result ^= shared[(lid + 32) % 64]; + } + + // Another barrier before writing back + unsafe { + workgroup_memory_barrier_with_group_sync(); + } + + // Write result back to shared memory + shared[lid] = result; + + // Memory barrier to ensure writes are visible + unsafe { + workgroup_memory_barrier_with_group_sync(); + } + + // Final read and output + output[tid] = shared[lid]; + } +} diff --git a/tests/difftests/tests/arch/memory_barriers/memory_barriers-rust/src/main.rs b/tests/difftests/tests/arch/memory_barriers/memory_barriers-rust/src/main.rs new file mode 100644 index 0000000000..3857ec8a89 --- /dev/null +++ b/tests/difftests/tests/arch/memory_barriers/memory_barriers-rust/src/main.rs @@ -0,0 +1,27 @@ +use difftest::config::Config; +use difftest::scaffold::compute::{ + BufferConfig, BufferUsage, RustComputeShader, WgpuComputeTestMultiBuffer, +}; + +fn main() { + let config = Config::from_path(std::env::args().nth(1).unwrap()).unwrap(); + + let buffer_size = 256; + let initial_data: Vec = (0..64).collect(); + let initial_bytes: Vec = initial_data.iter().flat_map(|&x| x.to_ne_bytes()).collect(); + + let test = WgpuComputeTestMultiBuffer::new(RustComputeShader::default(), [1, 1, 1], vec![ + BufferConfig { + size: buffer_size, + usage: BufferUsage::StorageReadOnly, + initial_data: Some(initial_bytes), + }, + BufferConfig { + size: buffer_size, + usage: BufferUsage::Storage, + initial_data: None, + }, + ]); + + test.run_test(&config).unwrap(); +} diff --git a/tests/difftests/tests/arch/memory_barriers/memory_barriers-wgsl/Cargo.toml b/tests/difftests/tests/arch/memory_barriers/memory_barriers-wgsl/Cargo.toml new file mode 100644 index 0000000000..6e245ef29a --- /dev/null +++ b/tests/difftests/tests/arch/memory_barriers/memory_barriers-wgsl/Cargo.toml @@ -0,0 +1,9 @@ +[package] +name = "memory_barriers-wgsl" +edition.workspace = true + +[lints] +workspace = true + +[dependencies] +difftest.workspace = true \ No newline at end of file diff --git a/tests/difftests/tests/arch/memory_barriers/memory_barriers-wgsl/shader.wgsl b/tests/difftests/tests/arch/memory_barriers/memory_barriers-wgsl/shader.wgsl new file mode 100644 index 0000000000..fbed8bf9a9 --- /dev/null +++ b/tests/difftests/tests/arch/memory_barriers/memory_barriers-wgsl/shader.wgsl @@ -0,0 +1,60 @@ +@group(0) @binding(0) +var input: array; + +@group(0) @binding(1) +var output: array; + +var shared_mem: array; + +@compute @workgroup_size(64) +fn main( + @builtin(global_invocation_id) global_id: vec3, + @builtin(local_invocation_id) local_id: vec3 +) { + let tid = global_id.x; + let lid = local_id.x; + + if (tid < arrayLength(&input) && tid < arrayLength(&output) && lid < 64u) { + // Load data into shared memory + shared_mem[lid] = input[tid]; + + // Workgroup barrier to ensure all threads have loaded their data + workgroupBarrier(); + + // Perform operations on shared memory + var result = shared_mem[lid]; + + // Different threads perform different operations + if (lid % 4u == 0u) { + // Read from neighboring thread's data + if (lid + 1u < 64u) { + result = result + shared_mem[lid + 1u]; + } + } else if (lid % 4u == 1u) { + // Read from previous thread's data + if (lid > 0u) { + result = result + shared_mem[lid - 1u]; + } + } else if (lid % 4u == 2u) { + // Sum reduction within groups of 4 + if (lid + 2u < 64u) { + result = shared_mem[lid] + shared_mem[lid + 1u] + shared_mem[lid + 2u]; + } + } else { + // XOR with wrapped neighbor + result = result ^ shared_mem[(lid + 32u) % 64u]; + } + + // Another barrier before writing back + workgroupBarrier(); + + // Write result back to shared memory + shared_mem[lid] = result; + + // Memory barrier to ensure writes are visible + workgroupBarrier(); + + // Final read and output + output[tid] = shared_mem[lid]; + } +} \ No newline at end of file diff --git a/tests/difftests/tests/arch/memory_barriers/memory_barriers-wgsl/src/main.rs b/tests/difftests/tests/arch/memory_barriers/memory_barriers-wgsl/src/main.rs new file mode 100644 index 0000000000..d0ba32b0c4 --- /dev/null +++ b/tests/difftests/tests/arch/memory_barriers/memory_barriers-wgsl/src/main.rs @@ -0,0 +1,27 @@ +use difftest::config::Config; +use difftest::scaffold::compute::{ + BufferConfig, BufferUsage, WgpuComputeTestMultiBuffer, WgslComputeShader, +}; + +fn main() { + let config = Config::from_path(std::env::args().nth(1).unwrap()).unwrap(); + + let buffer_size = 256; + let initial_data: Vec = (0..64).collect(); + let initial_bytes: Vec = initial_data.iter().flat_map(|&x| x.to_ne_bytes()).collect(); + + let test = WgpuComputeTestMultiBuffer::new(WgslComputeShader::default(), [1, 1, 1], vec![ + BufferConfig { + size: buffer_size, + usage: BufferUsage::StorageReadOnly, + initial_data: Some(initial_bytes), + }, + BufferConfig { + size: buffer_size, + usage: BufferUsage::Storage, + initial_data: None, + }, + ]); + + test.run_test(&config).unwrap(); +} diff --git a/tests/difftests/tests/arch/mod.rs b/tests/difftests/tests/arch/mod.rs new file mode 100644 index 0000000000..a49014f209 --- /dev/null +++ b/tests/difftests/tests/arch/mod.rs @@ -0,0 +1,3 @@ +pub mod atomic_ops; +pub mod workgroup_memory; +pub mod workgroup_sizes; \ No newline at end of file diff --git a/tests/difftests/tests/arch/push_constants/push_constants-rust/Cargo.toml b/tests/difftests/tests/arch/push_constants/push_constants-rust/Cargo.toml new file mode 100644 index 0000000000..2ba5468de4 --- /dev/null +++ b/tests/difftests/tests/arch/push_constants/push_constants-rust/Cargo.toml @@ -0,0 +1,19 @@ +[package] +name = "push_constants-rust" +edition.workspace = true + +[lints] +workspace = true + +[lib] +crate-type = ["dylib"] + +# GPU deps +[dependencies] +spirv-std.workspace = true +bytemuck.workspace = true + +# CPU deps (for the test harness) +[target.'cfg(not(target_arch = "spirv"))'.dependencies] +difftest.workspace = true +bytemuck.workspace = true \ No newline at end of file diff --git a/tests/difftests/tests/arch/push_constants/push_constants-rust/src/lib.rs b/tests/difftests/tests/arch/push_constants/push_constants-rust/src/lib.rs new file mode 100644 index 0000000000..6f02be077c --- /dev/null +++ b/tests/difftests/tests/arch/push_constants/push_constants-rust/src/lib.rs @@ -0,0 +1,83 @@ +#![no_std] + +#[allow(unused_imports)] +use spirv_std::num_traits::Float; +use spirv_std::spirv; + +#[repr(C)] +#[derive(Copy, Clone, bytemuck::Pod, bytemuck::Zeroable)] +pub struct PushConstants { + multiplier: f32, + offset: f32, + flags: u32, + count: u32, +} + +#[spirv(compute(threads(64)))] +pub fn main_cs( + #[spirv(push_constant)] push_constants: &PushConstants, + #[spirv(storage_buffer, descriptor_set = 0, binding = 0)] input: &[f32], + #[spirv(storage_buffer, descriptor_set = 0, binding = 1)] output: &mut [f32], + #[spirv(global_invocation_id)] global_id: spirv_std::glam::UVec3, +) { + let tid = global_id.x as usize; + let count = push_constants.count as usize; + + if tid < input.len() && tid < output.len() && tid < count { + let value = input[tid]; + + // Apply different operations based on flags + let result = match push_constants.flags { + 0 => { + // Linear transformation + value * push_constants.multiplier + push_constants.offset + } + 1 => { + // Quadratic transformation + value * value * push_constants.multiplier + push_constants.offset + } + 2 => { + // Sine wave modulation + (value * push_constants.multiplier).sin() + push_constants.offset + } + 3 => { + // Exponential transformation + (value * push_constants.multiplier).exp() + push_constants.offset + } + 4 => { + // Logarithmic transformation (with protection against negative values) + if value > 0.0 { + (value * push_constants.multiplier).ln() + push_constants.offset + } else { + push_constants.offset + } + } + 5 => { + // Reciprocal transformation (with protection against division by zero) + if value.abs() > 0.001 { + push_constants.multiplier / value + push_constants.offset + } else { + push_constants.offset + } + } + 6 => { + // Power transformation + value.powf(push_constants.multiplier) + push_constants.offset + } + 7 => { + // Modulo operation (treating multiplier as divisor) + if push_constants.multiplier > 0.0 { + (value % push_constants.multiplier) + push_constants.offset + } else { + push_constants.offset + } + } + _ => { + // Default: just add offset + value + push_constants.offset + } + }; + + output[tid] = result; + } +} diff --git a/tests/difftests/tests/arch/push_constants/push_constants-rust/src/main.rs b/tests/difftests/tests/arch/push_constants/push_constants-rust/src/main.rs new file mode 100644 index 0000000000..671f0a4974 --- /dev/null +++ b/tests/difftests/tests/arch/push_constants/push_constants-rust/src/main.rs @@ -0,0 +1,53 @@ +use difftest::config::Config; +use difftest::scaffold::compute::{ + BufferConfig, BufferUsage, RustComputeShader, WgpuComputeTestPushConstants, +}; + +#[repr(C)] +#[derive(Copy, Clone, bytemuck::Pod, bytemuck::Zeroable)] +pub struct PushConstants { + multiplier: f32, + offset: f32, + flags: u32, + count: u32, +} + +fn main() { + let config = Config::from_path(std::env::args().nth(1).unwrap()).unwrap(); + + let num_elements = 256; + let buffer_size = num_elements * 4; // 4 bytes per f32 + + // Create input data + let input_data: Vec = (0..num_elements).map(|i| i as f32 * 0.1).collect(); + let input_bytes: Vec = input_data.iter().flat_map(|&x| x.to_ne_bytes()).collect(); + + // Create push constants data + let push_constants = PushConstants { + multiplier: 2.5, + offset: 1.0, + flags: 0, // Linear transformation + count: num_elements as u32, + }; + + let test = WgpuComputeTestPushConstants::new( + RustComputeShader::default(), + [4, 1, 1], // 256 / 64 = 4 workgroups + vec![ + BufferConfig { + size: buffer_size as u64, + usage: BufferUsage::StorageReadOnly, + initial_data: Some(input_bytes), + }, + BufferConfig { + size: buffer_size as u64, + usage: BufferUsage::Storage, + initial_data: None, + }, + ], + std::mem::size_of::() as u32, + bytemuck::bytes_of(&push_constants).to_vec(), + ); + + test.run_test(&config).unwrap(); +} diff --git a/tests/difftests/tests/arch/push_constants/push_constants-wgsl/Cargo.toml b/tests/difftests/tests/arch/push_constants/push_constants-wgsl/Cargo.toml new file mode 100644 index 0000000000..930605fdc0 --- /dev/null +++ b/tests/difftests/tests/arch/push_constants/push_constants-wgsl/Cargo.toml @@ -0,0 +1,10 @@ +[package] +name = "push_constants-wgsl" +edition.workspace = true + +[lints] +workspace = true + +[dependencies] +difftest.workspace = true +bytemuck.workspace = true \ No newline at end of file diff --git a/tests/difftests/tests/arch/push_constants/push_constants-wgsl/shader.wgsl b/tests/difftests/tests/arch/push_constants/push_constants-wgsl/shader.wgsl new file mode 100644 index 0000000000..1977d38ff2 --- /dev/null +++ b/tests/difftests/tests/arch/push_constants/push_constants-wgsl/shader.wgsl @@ -0,0 +1,79 @@ +struct PushConstants { + multiplier: f32, + offset: f32, + flags: u32, + count: u32, +} + +var push_constants: PushConstants; + +@group(0) @binding(0) +var input: array; + +@group(0) @binding(1) +var output: array; + +@compute @workgroup_size(64) +fn main(@builtin(global_invocation_id) global_id: vec3) { + let tid = global_id.x; + let count = push_constants.count; + + if (tid < arrayLength(&input) && tid < arrayLength(&output) && tid < count) { + let value = input[tid]; + var result: f32; + + // Apply different operations based on flags + switch push_constants.flags { + case 0u: { + // Linear transformation + result = value * push_constants.multiplier + push_constants.offset; + } + case 1u: { + // Quadratic transformation + result = value * value * push_constants.multiplier + push_constants.offset; + } + case 2u: { + // Sine wave modulation + result = sin(value * push_constants.multiplier) + push_constants.offset; + } + case 3u: { + // Exponential transformation + result = exp(value * push_constants.multiplier) + push_constants.offset; + } + case 4u: { + // Logarithmic transformation (with protection against negative values) + if (value > 0.0) { + result = log(value * push_constants.multiplier) + push_constants.offset; + } else { + result = push_constants.offset; + } + } + case 5u: { + // Reciprocal transformation (with protection against division by zero) + if (abs(value) > 0.001) { + result = push_constants.multiplier / value + push_constants.offset; + } else { + result = push_constants.offset; + } + } + case 6u: { + // Power transformation + result = pow(value, push_constants.multiplier) + push_constants.offset; + } + case 7u: { + // Modulo operation (treating multiplier as divisor) + if (push_constants.multiplier > 0.0) { + result = (value % push_constants.multiplier) + push_constants.offset; + } else { + result = push_constants.offset; + } + } + default: { + // Default: just add offset + result = value + push_constants.offset; + } + } + + output[tid] = result; + } +} \ No newline at end of file diff --git a/tests/difftests/tests/arch/push_constants/push_constants-wgsl/src/main.rs b/tests/difftests/tests/arch/push_constants/push_constants-wgsl/src/main.rs new file mode 100644 index 0000000000..04f8b0ab79 --- /dev/null +++ b/tests/difftests/tests/arch/push_constants/push_constants-wgsl/src/main.rs @@ -0,0 +1,53 @@ +use difftest::config::Config; +use difftest::scaffold::compute::{ + BufferConfig, BufferUsage, WgpuComputeTestPushConstants, WgslComputeShader, +}; + +#[repr(C)] +#[derive(Copy, Clone, bytemuck::Pod, bytemuck::Zeroable)] +pub struct PushConstants { + multiplier: f32, + offset: f32, + flags: u32, + count: u32, +} + +fn main() { + let config = Config::from_path(std::env::args().nth(1).unwrap()).unwrap(); + + let num_elements = 256; + let buffer_size = num_elements * 4; // 4 bytes per f32 + + // Create input data + let input_data: Vec = (0..num_elements).map(|i| i as f32 * 0.1).collect(); + let input_bytes: Vec = input_data.iter().flat_map(|&x| x.to_ne_bytes()).collect(); + + // Create push constants data + let push_constants = PushConstants { + multiplier: 2.5, + offset: 1.0, + flags: 0, // Linear transformation + count: num_elements as u32, + }; + + let test = WgpuComputeTestPushConstants::new( + WgslComputeShader::default(), + [4, 1, 1], // 256 / 64 = 4 workgroups + vec![ + BufferConfig { + size: buffer_size as u64, + usage: BufferUsage::StorageReadOnly, + initial_data: Some(input_bytes), + }, + BufferConfig { + size: buffer_size as u64, + usage: BufferUsage::Storage, + initial_data: None, + }, + ], + std::mem::size_of::() as u32, + bytemuck::bytes_of(&push_constants).to_vec(), + ); + + test.run_test(&config).unwrap(); +} diff --git a/tests/difftests/tests/arch/vector_extract_insert/vector_extract_insert-rust/Cargo.toml b/tests/difftests/tests/arch/vector_extract_insert/vector_extract_insert-rust/Cargo.toml new file mode 100644 index 0000000000..2481bd472d --- /dev/null +++ b/tests/difftests/tests/arch/vector_extract_insert/vector_extract_insert-rust/Cargo.toml @@ -0,0 +1,18 @@ +[package] +name = "vector_extract_insert-rust" +edition.workspace = true + +[lints] +workspace = true + +[lib] +crate-type = ["dylib"] + +# GPU deps +[dependencies] +spirv-std.workspace = true +glam.workspace = true + +# CPU deps (for the test harness) +[target.'cfg(not(target_arch = "spirv"))'.dependencies] +difftest.workspace = true \ No newline at end of file diff --git a/tests/difftests/tests/arch/vector_extract_insert/vector_extract_insert-rust/src/lib.rs b/tests/difftests/tests/arch/vector_extract_insert/vector_extract_insert-rust/src/lib.rs new file mode 100644 index 0000000000..078f868600 --- /dev/null +++ b/tests/difftests/tests/arch/vector_extract_insert/vector_extract_insert-rust/src/lib.rs @@ -0,0 +1,78 @@ +#![no_std] +#![cfg_attr(target_arch = "spirv", feature(asm_experimental_arch))] + +use glam::Vec4; +use spirv_std::arch::{vector_extract_dynamic, vector_insert_dynamic}; +use spirv_std::spirv; + +#[spirv(compute(threads(64)))] +pub fn main_cs( + #[spirv(storage_buffer, descriptor_set = 0, binding = 0)] input: &[[f32; 4]], + #[spirv(storage_buffer, descriptor_set = 0, binding = 1)] indices: &[u32], + #[spirv(storage_buffer, descriptor_set = 0, binding = 2)] output: &mut [[f32; 4]], + #[spirv(global_invocation_id)] global_id: spirv_std::glam::UVec3, +) { + let tid = global_id.x as usize; + + if tid < input.len() && tid < indices.len() && tid < output.len() { + let vec = Vec4::from_array(input[tid]); + let index = (indices[tid] % 4) as usize; // Ensure index is within bounds + + // Test various extract and insert operations + let result = match tid % 8 { + 0 => { + // Extract a component dynamically + let extracted = unsafe { vector_extract_dynamic(vec, index) }; + Vec4::new(extracted, extracted, extracted, extracted) + } + 1 => { + // Insert a new value at dynamic index + unsafe { vector_insert_dynamic(vec, index, 42.0) } + } + 2 => { + // Extract and double the value, then insert back + let extracted = unsafe { vector_extract_dynamic(vec, index) }; + unsafe { vector_insert_dynamic(vec, index, extracted * 2.0) } + } + 3 => { + // Swap two components using extract/insert + let idx1 = index; + let idx2 = (index + 1) % 4; + let val1 = unsafe { vector_extract_dynamic(vec, idx1) }; + let val2 = unsafe { vector_extract_dynamic(vec, idx2) }; + let temp = unsafe { vector_insert_dynamic(vec, idx1, val2) }; + unsafe { vector_insert_dynamic(temp, idx2, val1) } + } + 4 => { + // Set all components to the value at dynamic index + let val = unsafe { vector_extract_dynamic(vec, index) }; + Vec4::new(val, val, val, val) + } + 5 => { + // Rotate components based on index + let mut result = vec; + for i in 0..4 { + let src_idx = (i + index) % 4; + let val = unsafe { vector_extract_dynamic(vec, src_idx) }; + result = unsafe { vector_insert_dynamic(result, i, val) }; + } + result + } + 6 => { + // Insert sum of all components at dynamic index + let sum = vec.x + vec.y + vec.z + vec.w; + unsafe { vector_insert_dynamic(vec, index, sum) } + } + 7 => { + // Extract from one index, insert at another + let src_idx = index; + let dst_idx = (index + 2) % 4; + let val = unsafe { vector_extract_dynamic(vec, src_idx) }; + unsafe { vector_insert_dynamic(vec, dst_idx, val) } + } + _ => vec, + }; + + output[tid] = result.to_array(); + } +} diff --git a/tests/difftests/tests/arch/vector_extract_insert/vector_extract_insert-rust/src/main.rs b/tests/difftests/tests/arch/vector_extract_insert/vector_extract_insert-rust/src/main.rs new file mode 100644 index 0000000000..89dc766576 --- /dev/null +++ b/tests/difftests/tests/arch/vector_extract_insert/vector_extract_insert-rust/src/main.rs @@ -0,0 +1,16 @@ +use difftest::config::Config; +use difftest::scaffold::compute::{RustComputeShader, WgpuComputeTestMultiBuffer}; + +fn main() { + let config = Config::from_path(std::env::args().nth(1).unwrap()).unwrap(); + + let buffer_size = 1024; + let test = + WgpuComputeTestMultiBuffer::new_with_sizes(RustComputeShader::default(), [64, 1, 1], &[ + buffer_size, + buffer_size, + buffer_size, + ]); + + test.run_test(&config).unwrap(); +} diff --git a/tests/difftests/tests/arch/vector_extract_insert/vector_extract_insert-wgsl/Cargo.toml b/tests/difftests/tests/arch/vector_extract_insert/vector_extract_insert-wgsl/Cargo.toml new file mode 100644 index 0000000000..b64dae659c --- /dev/null +++ b/tests/difftests/tests/arch/vector_extract_insert/vector_extract_insert-wgsl/Cargo.toml @@ -0,0 +1,9 @@ +[package] +name = "vector_extract_insert-wgsl" +edition.workspace = true + +[lints] +workspace = true + +[dependencies] +difftest.workspace = true \ No newline at end of file diff --git a/tests/difftests/tests/arch/vector_extract_insert/vector_extract_insert-wgsl/shader.wgsl b/tests/difftests/tests/arch/vector_extract_insert/vector_extract_insert-wgsl/shader.wgsl new file mode 100644 index 0000000000..43199e558f --- /dev/null +++ b/tests/difftests/tests/arch/vector_extract_insert/vector_extract_insert-wgsl/shader.wgsl @@ -0,0 +1,72 @@ +@group(0) @binding(0) var input: array>; +@group(0) @binding(1) var indices: array; +@group(0) @binding(2) var output: array>; + +@compute @workgroup_size(64) +fn main(@builtin(global_invocation_id) global_id: vec3) { + let tid = global_id.x; + + if (tid < arrayLength(&input) && tid < arrayLength(&indices) && tid < arrayLength(&output)) { + let vec = input[tid]; + let index = indices[tid] % 4u; // Ensure index is within bounds + + // Test various extract and insert operations + var result: vec4; + switch (tid % 8u) { + case 0u: { + // Extract a component dynamically + let extracted = vec[index]; + result = vec4(extracted, extracted, extracted, extracted); + } + case 1u: { + // Insert a new value at dynamic index + result = vec; + result[index] = 42.0; + } + case 2u: { + // Extract and double the value, then insert back + result = vec; + result[index] = vec[index] * 2.0; + } + case 3u: { + // Swap two components using extract/insert + let idx1 = index; + let idx2 = (index + 1u) % 4u; + result = vec; + let temp = result[idx1]; + result[idx1] = result[idx2]; + result[idx2] = temp; + } + case 4u: { + // Set all components to the value at dynamic index + let val = vec[index]; + result = vec4(val, val, val, val); + } + case 5u: { + // Rotate components based on index + for (var i = 0u; i < 4u; i = i + 1u) { + let src_idx = (i + index) % 4u; + result[i] = vec[src_idx]; + } + } + case 6u: { + // Insert sum of all components at dynamic index + let sum = vec.x + vec.y + vec.z + vec.w; + result = vec; + result[index] = sum; + } + case 7u: { + // Extract from one index, insert at another + let src_idx = index; + let dst_idx = (index + 2u) % 4u; + result = vec; + result[dst_idx] = vec[src_idx]; + } + default: { + result = vec; + } + } + + output[tid] = result; + } +} \ No newline at end of file diff --git a/tests/difftests/tests/arch/vector_extract_insert/vector_extract_insert-wgsl/src/main.rs b/tests/difftests/tests/arch/vector_extract_insert/vector_extract_insert-wgsl/src/main.rs new file mode 100644 index 0000000000..f98ea16c7e --- /dev/null +++ b/tests/difftests/tests/arch/vector_extract_insert/vector_extract_insert-wgsl/src/main.rs @@ -0,0 +1,16 @@ +use difftest::config::Config; +use difftest::scaffold::compute::{WgpuComputeTestMultiBuffer, WgslComputeShader}; + +fn main() { + let config = Config::from_path(std::env::args().nth(1).unwrap()).unwrap(); + + let buffer_size = 1024; + let test = + WgpuComputeTestMultiBuffer::new_with_sizes(WgslComputeShader::default(), [64, 1, 1], &[ + buffer_size, + buffer_size, + buffer_size, + ]); + + test.run_test(&config).unwrap(); +} diff --git a/tests/difftests/tests/arch/workgroup_memory/workgroup_memory-rust/Cargo.toml b/tests/difftests/tests/arch/workgroup_memory/workgroup_memory-rust/Cargo.toml new file mode 100644 index 0000000000..0091d02acf --- /dev/null +++ b/tests/difftests/tests/arch/workgroup_memory/workgroup_memory-rust/Cargo.toml @@ -0,0 +1,20 @@ +[package] +name = "workgroup_memory-rust" +edition.workspace = true + +[lints] +workspace = true + +[lib] +crate-type = ["dylib"] + +# Common deps +[dependencies] + +# GPU deps +spirv-std.workspace = true + +# CPU deps +[target.'cfg(not(target_arch = "spirv"))'.dependencies] +difftest.workspace = true +bytemuck.workspace = true \ No newline at end of file diff --git a/tests/difftests/tests/arch/workgroup_memory/workgroup_memory-rust/src/lib.rs b/tests/difftests/tests/arch/workgroup_memory/workgroup_memory-rust/src/lib.rs new file mode 100644 index 0000000000..dc706b2477 --- /dev/null +++ b/tests/difftests/tests/arch/workgroup_memory/workgroup_memory-rust/src/lib.rs @@ -0,0 +1,73 @@ +#![no_std] + +use spirv_std::arch::workgroup_memory_barrier_with_group_sync; +use spirv_std::spirv; + +#[spirv(compute(threads(64)))] +pub fn main_cs( + #[spirv(storage_buffer, descriptor_set = 0, binding = 0)] input: &[u32], + #[spirv(storage_buffer, descriptor_set = 0, binding = 1)] output: &mut [u32], + #[spirv(local_invocation_id)] local_id: spirv_std::glam::UVec3, + #[spirv(workgroup)] shared: &mut [u32; 64], +) { + let lid = local_id.x as usize; + + // Load data into shared memory + shared[lid] = input[lid]; + + // Synchronize to ensure all threads have loaded + unsafe { + workgroup_memory_barrier_with_group_sync(); + } + + // Each thread sums its value with its neighbor (reduction step) + if lid < 32 { + shared[lid] += shared[lid + 32]; + } + + // Synchronize again + unsafe { + workgroup_memory_barrier_with_group_sync(); + } + + if lid < 16 { + shared[lid] += shared[lid + 16]; + } + + unsafe { + workgroup_memory_barrier_with_group_sync(); + } + + if lid < 8 { + shared[lid] += shared[lid + 8]; + } + + unsafe { + workgroup_memory_barrier_with_group_sync(); + } + + if lid < 4 { + shared[lid] += shared[lid + 4]; + } + + unsafe { + workgroup_memory_barrier_with_group_sync(); + } + + if lid < 2 { + shared[lid] += shared[lid + 2]; + } + + unsafe { + workgroup_memory_barrier_with_group_sync(); + } + + if lid < 1 { + shared[lid] += shared[lid + 1]; + } + + // Write final result + if lid == 0 { + output[0] = shared[0]; + } +} diff --git a/tests/difftests/tests/arch/workgroup_memory/workgroup_memory-rust/src/main.rs b/tests/difftests/tests/arch/workgroup_memory/workgroup_memory-rust/src/main.rs new file mode 100644 index 0000000000..cd20ab4765 --- /dev/null +++ b/tests/difftests/tests/arch/workgroup_memory/workgroup_memory-rust/src/main.rs @@ -0,0 +1,37 @@ +#[cfg(not(target_arch = "spirv"))] +fn main() { + use difftest::config::Config; + use difftest::scaffold::compute::{ + BufferConfig, BufferUsage, RustComputeShader, WgpuComputeTestMultiBuffer, + }; + + let config = Config::from_path(std::env::args().nth(1).unwrap()).unwrap(); + + // Initialize input buffer with values to sum + let input_data: Vec = (1..=64).collect(); + let input_bytes = bytemuck::cast_slice(&input_data).to_vec(); + + let buffers = vec![ + BufferConfig { + size: 256, // 64 u32 values + usage: BufferUsage::StorageReadOnly, + initial_data: Some(input_bytes), + }, + BufferConfig { + size: 4, // 1 u32 value for output + usage: BufferUsage::Storage, + initial_data: None, + }, + ]; + + let test = WgpuComputeTestMultiBuffer::new( + RustComputeShader::default(), + [1, 1, 1], // Single workgroup with 64 threads + buffers, + ); + + test.run_test(&config).unwrap(); +} + +#[cfg(target_arch = "spirv")] +fn main() {} diff --git a/tests/difftests/tests/arch/workgroup_memory/workgroup_memory-wgsl/Cargo.toml b/tests/difftests/tests/arch/workgroup_memory/workgroup_memory-wgsl/Cargo.toml new file mode 100644 index 0000000000..9491f5381b --- /dev/null +++ b/tests/difftests/tests/arch/workgroup_memory/workgroup_memory-wgsl/Cargo.toml @@ -0,0 +1,10 @@ +[package] +name = "workgroup_memory-wgsl" +edition.workspace = true + +[lints] +workspace = true + +[dependencies] +difftest.workspace = true +bytemuck.workspace = true \ No newline at end of file diff --git a/tests/difftests/tests/arch/workgroup_memory/workgroup_memory-wgsl/shader.wgsl b/tests/difftests/tests/arch/workgroup_memory/workgroup_memory-wgsl/shader.wgsl new file mode 100644 index 0000000000..7f12d99439 --- /dev/null +++ b/tests/difftests/tests/arch/workgroup_memory/workgroup_memory-wgsl/shader.wgsl @@ -0,0 +1,59 @@ +@group(0) @binding(0) +var input: array; + +@group(0) @binding(1) +var output: array; + +var shared_data: array; + +@compute @workgroup_size(64, 1, 1) +fn main_cs(@builtin(local_invocation_id) local_id: vec3) { + let lid = local_id.x; + + // Load data into shared memory + shared_data[lid] = input[lid]; + + // Synchronize to ensure all threads have loaded + workgroupBarrier(); + + // Each thread sums its value with its neighbor (reduction step) + if (lid < 32u) { + shared_data[lid] += shared_data[lid + 32u]; + } + + // Synchronize again + workgroupBarrier(); + + if (lid < 16u) { + shared_data[lid] += shared_data[lid + 16u]; + } + + workgroupBarrier(); + + if (lid < 8u) { + shared_data[lid] += shared_data[lid + 8u]; + } + + workgroupBarrier(); + + if (lid < 4u) { + shared_data[lid] += shared_data[lid + 4u]; + } + + workgroupBarrier(); + + if (lid < 2u) { + shared_data[lid] += shared_data[lid + 2u]; + } + + workgroupBarrier(); + + if (lid < 1u) { + shared_data[lid] += shared_data[lid + 1u]; + } + + // Write final result + if (lid == 0u) { + output[0] = shared_data[0]; + } +} \ No newline at end of file diff --git a/tests/difftests/tests/arch/workgroup_memory/workgroup_memory-wgsl/src/main.rs b/tests/difftests/tests/arch/workgroup_memory/workgroup_memory-wgsl/src/main.rs new file mode 100644 index 0000000000..1da92981a9 --- /dev/null +++ b/tests/difftests/tests/arch/workgroup_memory/workgroup_memory-wgsl/src/main.rs @@ -0,0 +1,33 @@ +use difftest::config::Config; +use difftest::scaffold::compute::{ + BufferConfig, BufferUsage, WgpuComputeTestMultiBuffer, WgslComputeShader, +}; + +fn main() { + let config = Config::from_path(std::env::args().nth(1).unwrap()).unwrap(); + + // Initialize input buffer with values to sum + let input_data: Vec = (1..=64).collect(); + let input_bytes = bytemuck::cast_slice(&input_data).to_vec(); + + let buffers = vec![ + BufferConfig { + size: 256, // 64 u32 values + usage: BufferUsage::StorageReadOnly, + initial_data: Some(input_bytes), + }, + BufferConfig { + size: 4, // 1 u32 value for output + usage: BufferUsage::Storage, + initial_data: None, + }, + ]; + + let test = WgpuComputeTestMultiBuffer::new( + WgslComputeShader::default(), + [1, 1, 1], // Single workgroup with 64 threads + buffers, + ); + + test.run_test(&config).unwrap(); +} diff --git a/tests/difftests/tests/lang/control_flow/control_flow-rust/Cargo.toml b/tests/difftests/tests/lang/control_flow/control_flow-rust/Cargo.toml new file mode 100644 index 0000000000..89b14e484b --- /dev/null +++ b/tests/difftests/tests/lang/control_flow/control_flow-rust/Cargo.toml @@ -0,0 +1,20 @@ +[package] +name = "control_flow-rust" +edition.workspace = true + +[lints] +workspace = true + +[lib] +crate-type = ["dylib"] + +# Common deps +[dependencies] + +# GPU deps +spirv-std.workspace = true + +# CPU deps +[target.'cfg(not(target_arch = "spirv"))'.dependencies] +difftest.workspace = true +bytemuck.workspace = true \ No newline at end of file diff --git a/tests/difftests/tests/lang/control_flow/control_flow-rust/src/lib.rs b/tests/difftests/tests/lang/control_flow/control_flow-rust/src/lib.rs new file mode 100644 index 0000000000..ab6bd4ef33 --- /dev/null +++ b/tests/difftests/tests/lang/control_flow/control_flow-rust/src/lib.rs @@ -0,0 +1,103 @@ +#![no_std] + +use spirv_std::spirv; + +#[spirv(compute(threads(64)))] +pub fn main_cs( + #[spirv(storage_buffer, descriptor_set = 0, binding = 0)] input: &[u32], + #[spirv(storage_buffer, descriptor_set = 0, binding = 1)] output: &mut [u32], + #[spirv(global_invocation_id)] global_id: spirv_std::glam::UVec3, +) { + let tid = global_id.x as usize; + + if tid >= output.len() { + return; // Early return test + } + + let val = if tid < input.len() { input[tid] } else { 0 }; + + // Test 1: Simple if/else + let result1 = if val % 2 == 0 { val * 2 } else { val * 3 }; + + // Test 2: Nested if/else + let result2 = if val < 10 { + if val < 5 { val + 100 } else { val + 200 } + } else if val < 20 { + val + 300 + } else { + val + 400 + }; + + // Test 3: Loop with break + let mut sum = 0u32; + let mut i = 0; + loop { + if i >= val || i >= 10 { + break; + } + sum += i; + i += 1; + } + + // Test 4: While loop + let mut product = 1u32; + let mut j = 1; + while j <= 5 && j <= val { + product *= j; + j += 1; + } + + // Test 5: For loop with continue + let mut count = 0u32; + for k in 0..20 { + if k % 3 == 0 { + continue; + } + if k > val { + break; + } + count += 1; + } + + // Test 6: Match expression + let match_result = match val % 4 { + 0 => 1000, + 1 => 2000, + 2 => 3000, + _ => 4000, + }; + + // Test 7: Complex control flow with early returns + let complex_result = complex_function(val); + + // Combine all results + output[tid] = result1 + .wrapping_add(result2) + .wrapping_add(sum) + .wrapping_add(product) + .wrapping_add(count) + .wrapping_add(match_result) + .wrapping_add(complex_result); +} + +fn complex_function(x: u32) -> u32 { + if x == 0 { + return 999; + } + + for i in 0..x { + if i > 5 { + return i * 100; + } + } + + let mut result = x; + while result > 10 { + result /= 2; + if result == 15 { + return 777; + } + } + + result +} diff --git a/tests/difftests/tests/lang/control_flow/control_flow-rust/src/main.rs b/tests/difftests/tests/lang/control_flow/control_flow-rust/src/main.rs new file mode 100644 index 0000000000..3258728384 --- /dev/null +++ b/tests/difftests/tests/lang/control_flow/control_flow-rust/src/main.rs @@ -0,0 +1,37 @@ +#[cfg(not(target_arch = "spirv"))] +fn main() { + use difftest::config::Config; + use difftest::scaffold::compute::{ + BufferConfig, BufferUsage, RustComputeShader, WgpuComputeTestMultiBuffer, + }; + + let config = Config::from_path(std::env::args().nth(1).unwrap()).unwrap(); + + // Create input data with various values to test different control flow paths + let input_data: Vec = (0..64).map(|i| i as u32).collect(); + let input_bytes = bytemuck::cast_slice(&input_data).to_vec(); + + let buffers = vec![ + BufferConfig { + size: 256, // 64 u32 values + usage: BufferUsage::StorageReadOnly, + initial_data: Some(input_bytes), + }, + BufferConfig { + size: 256, // 64 u32 values for output + usage: BufferUsage::Storage, + initial_data: None, + }, + ]; + + let test = WgpuComputeTestMultiBuffer::new( + RustComputeShader::default(), + [1, 1, 1], // Single workgroup with 64 threads + buffers, + ); + + test.run_test(&config).unwrap(); +} + +#[cfg(target_arch = "spirv")] +fn main() {} diff --git a/tests/difftests/tests/lang/control_flow/control_flow-wgsl/Cargo.toml b/tests/difftests/tests/lang/control_flow/control_flow-wgsl/Cargo.toml new file mode 100644 index 0000000000..8224d583b9 --- /dev/null +++ b/tests/difftests/tests/lang/control_flow/control_flow-wgsl/Cargo.toml @@ -0,0 +1,10 @@ +[package] +name = "control_flow-wgsl" +edition.workspace = true + +[lints] +workspace = true + +[dependencies] +difftest.workspace = true +bytemuck.workspace = true \ No newline at end of file diff --git a/tests/difftests/tests/lang/control_flow/control_flow-wgsl/shader.wgsl b/tests/difftests/tests/lang/control_flow/control_flow-wgsl/shader.wgsl new file mode 100644 index 0000000000..87305d0fc3 --- /dev/null +++ b/tests/difftests/tests/lang/control_flow/control_flow-wgsl/shader.wgsl @@ -0,0 +1,115 @@ +@group(0) @binding(0) +var input: array; + +@group(0) @binding(1) +var output: array; + +fn complex_function(x: u32) -> u32 { + if (x == 0u) { + return 999u; + } + + for (var i = 0u; i < x; i = i + 1u) { + if (i > 5u) { + return i * 100u; + } + } + + var result = x; + while (result > 10u) { + result = result / 2u; + if (result == 15u) { + return 777u; + } + } + + return result; +} + +@compute @workgroup_size(64, 1, 1) +fn main_cs(@builtin(global_invocation_id) global_id: vec3) { + let tid = global_id.x; + let output_len = arrayLength(&output); + + if (tid >= output_len) { + return; // Early return test + } + + let val = select(0u, input[tid], tid < arrayLength(&input)); + + // Test 1: Simple if/else + var result1: u32; + if (val % 2u == 0u) { + result1 = val * 2u; + } else { + result1 = val * 3u; + } + + // Test 2: Nested if/else + var result2: u32; + if (val < 10u) { + if (val < 5u) { + result2 = val + 100u; + } else { + result2 = val + 200u; + } + } else if (val < 20u) { + result2 = val + 300u; + } else { + result2 = val + 400u; + } + + // Test 3: Loop with break + var sum = 0u; + var i = 0u; + loop { + if (i >= val || i >= 10u) { + break; + } + sum = sum + i; + i = i + 1u; + } + + // Test 4: While loop + var product = 1u; + var j = 1u; + while (j <= 5u && j <= val) { + product = product * j; + j = j + 1u; + } + + // Test 5: For loop with continue + var count = 0u; + for (var k = 0u; k < 20u; k = k + 1u) { + if (k % 3u == 0u) { + continue; + } + if (k > val) { + break; + } + count = count + 1u; + } + + // Test 6: Switch expression (WGSL's equivalent of match) + var match_result: u32; + switch (val % 4u) { + case 0u: { + match_result = 1000u; + } + case 1u: { + match_result = 2000u; + } + case 2u: { + match_result = 3000u; + } + default: { + match_result = 4000u; + } + } + + // Test 7: Complex control flow with early returns + let complex_result = complex_function(val); + + // Combine all results (using addition with overflow wrapping) + output[tid] = (result1 + result2 + sum + product + count + match_result + complex_result) & 0xFFFFFFFFu; +} \ No newline at end of file diff --git a/tests/difftests/tests/lang/control_flow/control_flow-wgsl/src/main.rs b/tests/difftests/tests/lang/control_flow/control_flow-wgsl/src/main.rs new file mode 100644 index 0000000000..fc6e9feef5 --- /dev/null +++ b/tests/difftests/tests/lang/control_flow/control_flow-wgsl/src/main.rs @@ -0,0 +1,33 @@ +use difftest::config::Config; +use difftest::scaffold::compute::{ + BufferConfig, BufferUsage, WgpuComputeTestMultiBuffer, WgslComputeShader, +}; + +fn main() { + let config = Config::from_path(std::env::args().nth(1).unwrap()).unwrap(); + + // Create input data with various values to test different control flow paths + let input_data: Vec = (0..64).map(|i| i as u32).collect(); + let input_bytes = bytemuck::cast_slice(&input_data).to_vec(); + + let buffers = vec![ + BufferConfig { + size: 256, // 64 u32 values + usage: BufferUsage::StorageReadOnly, + initial_data: Some(input_bytes), + }, + BufferConfig { + size: 256, // 64 u32 values for output + usage: BufferUsage::Storage, + initial_data: None, + }, + ]; + + let test = WgpuComputeTestMultiBuffer::new( + WgslComputeShader::default(), + [1, 1, 1], // Single workgroup with 64 threads + buffers, + ); + + test.run_test(&config).unwrap(); +} diff --git a/tests/difftests/tests/lang/control_flow_complex/control_flow_complex-rust/Cargo.toml b/tests/difftests/tests/lang/control_flow_complex/control_flow_complex-rust/Cargo.toml new file mode 100644 index 0000000000..5cf19236ff --- /dev/null +++ b/tests/difftests/tests/lang/control_flow_complex/control_flow_complex-rust/Cargo.toml @@ -0,0 +1,17 @@ +[package] +name = "control_flow_complex-rust" +edition.workspace = true + +[lints] +workspace = true + +[lib] +crate-type = ["dylib"] + +# GPU deps +[dependencies] +spirv-std.workspace = true + +# CPU deps (for the test harness) +[target.'cfg(not(target_arch = "spirv"))'.dependencies] +difftest.workspace = true \ No newline at end of file diff --git a/tests/difftests/tests/lang/control_flow_complex/control_flow_complex-rust/src/lib.rs b/tests/difftests/tests/lang/control_flow_complex/control_flow_complex-rust/src/lib.rs new file mode 100644 index 0000000000..77dd696239 --- /dev/null +++ b/tests/difftests/tests/lang/control_flow_complex/control_flow_complex-rust/src/lib.rs @@ -0,0 +1,111 @@ +#![no_std] +#![cfg_attr(target_arch = "spirv", feature(asm_experimental_arch))] + +use spirv_std::spirv; + +#[spirv(compute(threads(64)))] +pub fn main_cs( + #[spirv(storage_buffer, descriptor_set = 0, binding = 0)] input: &[u32], + #[spirv(storage_buffer, descriptor_set = 0, binding = 1)] output: &mut [u32], + #[spirv(global_invocation_id)] global_id: spirv_std::glam::UVec3, +) { + let tid = global_id.x as usize; + + if tid < input.len() && tid < output.len() { + let value = input[tid]; + + // Complex control flow with nested loops, early returns, and match expressions + let result = process_value(value, tid as u32); + + output[tid] = result; + } +} + +fn process_value(value: u32, tid: u32) -> u32 { + // Early return for special cases + if value == 0 { + return 0; + } + + if value > 1000 { + return value - 1000; + } + + // Complex match expression + let base = match value % 10 { + 0 => 10, + 1 | 2 => value * 2, + 3..=5 => { + // Nested computation + let mut sum = 0; + for i in 0..3 { + sum += value + i; + } + sum + } + 6 => { + // Early return from match arm + if tid % 2 == 0 { + return value * 3; + } + value * 4 + } + 7 | 8 => { + // Nested loop with break + let mut result = 0; + 'outer: for i in 0..5 { + for j in 0..3 { + result += i * j; + if result > value { + break 'outer; + } + } + } + result + } + 9 => { + // Complex nested condition + if tid < 10 { + if value < 50 { value + tid } else { value - tid } + } else { + // Loop with continue + let mut sum = 0; + for i in 0..value { + if i % 3 == 0 { + continue; + } + sum += i; + if sum > 100 { + break; + } + } + sum + } + } + _ => value, // This should never be reached due to modulo 10 + }; + + // Final transformation with nested conditions + if base > 50 { + if tid % 3 == 0 { + base / 2 + } else if tid % 3 == 1 { + base * 2 + } else { + base + } + } else { + // Another loop with multiple exit conditions + let mut result = base; + let mut counter = 0; + loop { + result = (result * 3 + 1) / 2; + counter += 1; + + if result == 1 || counter >= 10 || result > 1000 { + break; + } + } + result + counter + } +} diff --git a/tests/difftests/tests/lang/control_flow_complex/control_flow_complex-rust/src/main.rs b/tests/difftests/tests/lang/control_flow_complex/control_flow_complex-rust/src/main.rs new file mode 100644 index 0000000000..eb96343bcf --- /dev/null +++ b/tests/difftests/tests/lang/control_flow_complex/control_flow_complex-rust/src/main.rs @@ -0,0 +1,15 @@ +use difftest::config::Config; +use difftest::scaffold::compute::{RustComputeShader, WgpuComputeTestMultiBuffer}; + +fn main() { + let config = Config::from_path(std::env::args().nth(1).unwrap()).unwrap(); + + let buffer_size = 1024; + let test = + WgpuComputeTestMultiBuffer::new_with_sizes(RustComputeShader::default(), [64, 1, 1], &[ + buffer_size, + buffer_size, + ]); + + test.run_test(&config).unwrap(); +} diff --git a/tests/difftests/tests/lang/control_flow_complex/control_flow_complex-wgsl/Cargo.toml b/tests/difftests/tests/lang/control_flow_complex/control_flow_complex-wgsl/Cargo.toml new file mode 100644 index 0000000000..07a260e26b --- /dev/null +++ b/tests/difftests/tests/lang/control_flow_complex/control_flow_complex-wgsl/Cargo.toml @@ -0,0 +1,9 @@ +[package] +name = "control_flow_complex-wgsl" +edition.workspace = true + +[lints] +workspace = true + +[dependencies] +difftest.workspace = true \ No newline at end of file diff --git a/tests/difftests/tests/lang/control_flow_complex/control_flow_complex-wgsl/shader.wgsl b/tests/difftests/tests/lang/control_flow_complex/control_flow_complex-wgsl/shader.wgsl new file mode 100644 index 0000000000..c04bcb3880 --- /dev/null +++ b/tests/difftests/tests/lang/control_flow_complex/control_flow_complex-wgsl/shader.wgsl @@ -0,0 +1,120 @@ +@group(0) @binding(0) var input: array; +@group(0) @binding(1) var output: array; + +@compute @workgroup_size(64) +fn main(@builtin(global_invocation_id) global_id: vec3) { + let tid = global_id.x; + + if (tid < arrayLength(&input) && tid < arrayLength(&output)) { + let value = input[tid]; + + // Complex control flow with nested loops, early returns, and switch expressions + let result = process_value(value, tid); + + output[tid] = result; + } +} + +fn process_value(value: u32, tid: u32) -> u32 { + // Early return for special cases + if (value == 0u) { + return 0u; + } + + if (value > 1000u) { + return select(0u, value - 1000u, value >= 1000u); // saturating_sub equivalent + } + + // Complex switch expression + var base: u32; + switch (value % 10u) { + case 0u: { + base = 10u; + } + case 1u, 2u: { + base = value * 2u; + } + case 3u, 4u, 5u: { + // Nested computation + var sum = 0u; + for (var i = 0u; i < 3u; i = i + 1u) { + sum = sum + value + i; + } + base = sum; + } + case 6u: { + // Early return from switch case + if (tid % 2u == 0u) { + return value * 3u; + } + base = value * 4u; + } + case 7u, 8u: { + // Nested loop with break + var result = 0u; + for (var i = 0u; i < 5u; i = i + 1u) { + for (var j = 0u; j < 3u; j = j + 1u) { + result = result + i * j; + if (result > value) { + i = 5u; // Force outer loop to exit + break; + } + } + if (i >= 5u) { + break; + } + } + base = result; + } + case 9u: { + // Complex nested condition + if (tid < 10u) { + if (value < 50u) { + base = value + tid; + } else { + base = value - tid; + } + } else { + // Loop with continue + var sum = 0u; + for (var i = 0u; i < value; i = i + 1u) { + if (i % 3u == 0u) { + continue; + } + sum = sum + i; + if (sum > 100u) { + break; + } + } + base = sum; + } + } + default: { + base = value; + } + } + + // Final transformation with nested conditions + if (base > 50u) { + if (tid % 3u == 0u) { + return base / 2u; + } else if (tid % 3u == 1u) { + return base * 2u; + } else { + return base; + } + } else { + // Another loop with multiple exit conditions + var result = base; + var counter = 0u; + loop { + result = (result * 3u + 1u) / 2u; + counter = counter + 1u; + + if (result == 1u || counter >= 10u || result > 1000u) { + break; + } + } + return result + counter; + } +} \ No newline at end of file diff --git a/tests/difftests/tests/lang/control_flow_complex/control_flow_complex-wgsl/src/main.rs b/tests/difftests/tests/lang/control_flow_complex/control_flow_complex-wgsl/src/main.rs new file mode 100644 index 0000000000..f21cdecfbf --- /dev/null +++ b/tests/difftests/tests/lang/control_flow_complex/control_flow_complex-wgsl/src/main.rs @@ -0,0 +1,15 @@ +use difftest::config::Config; +use difftest::scaffold::compute::{WgpuComputeTestMultiBuffer, WgslComputeShader}; + +fn main() { + let config = Config::from_path(std::env::args().nth(1).unwrap()).unwrap(); + + let buffer_size = 1024; + let test = + WgpuComputeTestMultiBuffer::new_with_sizes(WgslComputeShader::default(), [64, 1, 1], &[ + buffer_size, + buffer_size, + ]); + + test.run_test(&config).unwrap(); +} diff --git a/tests/difftests/tests/lang/core/ops/bitwise_ops/bitwise_ops-rust/Cargo.toml b/tests/difftests/tests/lang/core/ops/bitwise_ops/bitwise_ops-rust/Cargo.toml new file mode 100644 index 0000000000..2392c4f90d --- /dev/null +++ b/tests/difftests/tests/lang/core/ops/bitwise_ops/bitwise_ops-rust/Cargo.toml @@ -0,0 +1,17 @@ +[package] +name = "bitwise_ops-rust" +edition.workspace = true + +[lints] +workspace = true + +[lib] +crate-type = ["dylib"] + +# GPU deps +[dependencies] +spirv-std.workspace = true + +# CPU deps (for the test harness) +[target.'cfg(not(target_arch = "spirv"))'.dependencies] +difftest.workspace = true \ No newline at end of file diff --git a/tests/difftests/tests/lang/core/ops/bitwise_ops/bitwise_ops-rust/src/lib.rs b/tests/difftests/tests/lang/core/ops/bitwise_ops/bitwise_ops-rust/src/lib.rs new file mode 100644 index 0000000000..92ebc38790 --- /dev/null +++ b/tests/difftests/tests/lang/core/ops/bitwise_ops/bitwise_ops-rust/src/lib.rs @@ -0,0 +1,38 @@ +#![no_std] +#![cfg_attr(target_arch = "spirv", feature(asm_experimental_arch))] + +use spirv_std::spirv; + +#[spirv(compute(threads(64)))] +pub fn main_cs( + #[spirv(storage_buffer, descriptor_set = 0, binding = 0)] input_a: &[u32], + #[spirv(storage_buffer, descriptor_set = 0, binding = 1)] input_b: &[u32], + #[spirv(storage_buffer, descriptor_set = 0, binding = 2)] output: &mut [u32], + #[spirv(global_invocation_id)] global_id: spirv_std::glam::UVec3, +) { + let tid = global_id.x as usize; + + if tid < input_a.len() && tid < input_b.len() && tid < output.len() { + let a = input_a[tid]; + let b = input_b[tid]; + + // Test various bitwise operations + let result = match tid % 12 { + 0 => a & b, // AND + 1 => a | b, // OR + 2 => a ^ b, // XOR + 3 => !a, // NOT + 4 => a << (b % 32), // Left shift (avoid UB with modulo) + 5 => a >> (b % 32), // Right shift (avoid UB with modulo) + 6 => a.rotate_left(b % 32), // Rotate left + 7 => a.rotate_right(b % 32), // Rotate right + 8 => a.count_ones(), // Population count + 9 => a.leading_zeros(), // Leading zeros + 10 => a.trailing_zeros(), // Trailing zeros + 11 => a.reverse_bits(), // Bit reversal + _ => 0, + }; + + output[tid] = result; + } +} diff --git a/tests/difftests/tests/lang/core/ops/bitwise_ops/bitwise_ops-rust/src/main.rs b/tests/difftests/tests/lang/core/ops/bitwise_ops/bitwise_ops-rust/src/main.rs new file mode 100644 index 0000000000..89dc766576 --- /dev/null +++ b/tests/difftests/tests/lang/core/ops/bitwise_ops/bitwise_ops-rust/src/main.rs @@ -0,0 +1,16 @@ +use difftest::config::Config; +use difftest::scaffold::compute::{RustComputeShader, WgpuComputeTestMultiBuffer}; + +fn main() { + let config = Config::from_path(std::env::args().nth(1).unwrap()).unwrap(); + + let buffer_size = 1024; + let test = + WgpuComputeTestMultiBuffer::new_with_sizes(RustComputeShader::default(), [64, 1, 1], &[ + buffer_size, + buffer_size, + buffer_size, + ]); + + test.run_test(&config).unwrap(); +} diff --git a/tests/difftests/tests/lang/core/ops/bitwise_ops/bitwise_ops-wgsl/Cargo.toml b/tests/difftests/tests/lang/core/ops/bitwise_ops/bitwise_ops-wgsl/Cargo.toml new file mode 100644 index 0000000000..a5c6161d83 --- /dev/null +++ b/tests/difftests/tests/lang/core/ops/bitwise_ops/bitwise_ops-wgsl/Cargo.toml @@ -0,0 +1,9 @@ +[package] +name = "bitwise_ops-wgsl" +edition.workspace = true + +[lints] +workspace = true + +[dependencies] +difftest.workspace = true \ No newline at end of file diff --git a/tests/difftests/tests/lang/core/ops/bitwise_ops/bitwise_ops-wgsl/shader.wgsl b/tests/difftests/tests/lang/core/ops/bitwise_ops/bitwise_ops-wgsl/shader.wgsl new file mode 100644 index 0000000000..9950f1fa10 --- /dev/null +++ b/tests/difftests/tests/lang/core/ops/bitwise_ops/bitwise_ops-wgsl/shader.wgsl @@ -0,0 +1,41 @@ +@group(0) @binding(0) var input_a: array; +@group(0) @binding(1) var input_b: array; +@group(0) @binding(2) var output: array; + +@compute @workgroup_size(64) +fn main(@builtin(global_invocation_id) global_id: vec3) { + let tid = global_id.x; + + if (tid < arrayLength(&input_a) && tid < arrayLength(&input_b) && tid < arrayLength(&output)) { + let a = input_a[tid]; + let b = input_b[tid]; + + // Test various bitwise operations + var result: u32; + switch (tid % 12u) { + case 0u: { result = a & b; } // AND + case 1u: { result = a | b; } // OR + case 2u: { result = a ^ b; } // XOR + case 3u: { result = ~a; } // NOT + case 4u: { result = a << (b % 32u); } // Left shift (avoid UB with modulo) + case 5u: { result = a >> (b % 32u); } // Right shift (avoid UB with modulo) + case 6u: { + // Rotate left + let shift = b % 32u; + result = (a << shift) | (a >> (32u - shift)); + } + case 7u: { + // Rotate right + let shift = b % 32u; + result = (a >> shift) | (a << (32u - shift)); + } + case 8u: { result = countOneBits(a); } // Population count + case 9u: { result = countLeadingZeros(a); } // Leading zeros + case 10u: { result = countTrailingZeros(a); } // Trailing zeros + case 11u: { result = reverseBits(a); } // Bit reversal + default: { result = 0u; } + } + + output[tid] = result; + } +} \ No newline at end of file diff --git a/tests/difftests/tests/lang/core/ops/bitwise_ops/bitwise_ops-wgsl/src/main.rs b/tests/difftests/tests/lang/core/ops/bitwise_ops/bitwise_ops-wgsl/src/main.rs new file mode 100644 index 0000000000..f98ea16c7e --- /dev/null +++ b/tests/difftests/tests/lang/core/ops/bitwise_ops/bitwise_ops-wgsl/src/main.rs @@ -0,0 +1,16 @@ +use difftest::config::Config; +use difftest::scaffold::compute::{WgpuComputeTestMultiBuffer, WgslComputeShader}; + +fn main() { + let config = Config::from_path(std::env::args().nth(1).unwrap()).unwrap(); + + let buffer_size = 1024; + let test = + WgpuComputeTestMultiBuffer::new_with_sizes(WgslComputeShader::default(), [64, 1, 1], &[ + buffer_size, + buffer_size, + buffer_size, + ]); + + test.run_test(&config).unwrap(); +} diff --git a/tests/difftests/tests/lang/core/ops/math_ops/math_ops-rust/Cargo.toml b/tests/difftests/tests/lang/core/ops/math_ops/math_ops-rust/Cargo.toml new file mode 100644 index 0000000000..6b46627ae3 --- /dev/null +++ b/tests/difftests/tests/lang/core/ops/math_ops/math_ops-rust/Cargo.toml @@ -0,0 +1,20 @@ +[package] +name = "math_ops-rust" +edition.workspace = true + +[lints] +workspace = true + +[lib] +crate-type = ["dylib"] + +# Common deps +[dependencies] + +# GPU deps +spirv-std.workspace = true + +# CPU deps +[target.'cfg(not(target_arch = "spirv"))'.dependencies] +difftest.workspace = true +bytemuck.workspace = true \ No newline at end of file diff --git a/tests/difftests/tests/lang/core/ops/math_ops/math_ops-rust/src/lib.rs b/tests/difftests/tests/lang/core/ops/math_ops/math_ops-rust/src/lib.rs new file mode 100644 index 0000000000..252f8b7b33 --- /dev/null +++ b/tests/difftests/tests/lang/core/ops/math_ops/math_ops-rust/src/lib.rs @@ -0,0 +1,55 @@ +#![no_std] + +#[allow(unused_imports)] +use spirv_std::num_traits::Float; +use spirv_std::spirv; + +#[spirv(compute(threads(32)))] +pub fn main_cs( + #[spirv(storage_buffer, descriptor_set = 0, binding = 0)] input: &[f32], + #[spirv(storage_buffer, descriptor_set = 0, binding = 1)] output: &mut [f32], + #[spirv(global_invocation_id)] global_id: spirv_std::glam::UVec3, +) { + let tid = global_id.x as usize; + + if tid >= 32 || tid >= input.len() { + return; + } + + let x = input[tid]; + let base_offset = tid * 21; + + if base_offset + 20 >= output.len() { + return; + } + + // Basic arithmetic + output[base_offset + 0] = x + 1.5; + output[base_offset + 1] = x - 0.5; + output[base_offset + 2] = x * 2.0; + output[base_offset + 3] = x / 2.0; + output[base_offset + 4] = x % 3.0; + + // Trigonometric functions (simplified for consistent results) + output[base_offset + 5] = x.sin(); + output[base_offset + 6] = x.cos(); + output[base_offset + 7] = x.tan().clamp(-10.0, 10.0); + output[base_offset + 8] = 0.0; + output[base_offset + 9] = 0.0; + output[base_offset + 10] = x.atan(); + + // Exponential and logarithmic (simplified) + output[base_offset + 11] = x.exp().min(1e6); + output[base_offset + 12] = if x > 0.0 { x.ln() } else { -10.0 }; + output[base_offset + 13] = x.abs().sqrt(); + output[base_offset + 14] = x.abs().powf(2.0); + output[base_offset + 15] = if x > 0.0 { x.log2() } else { -10.0 }; + output[base_offset + 16] = x.exp2().min(1e6); + output[base_offset + 17] = x.floor(); + output[base_offset + 18] = x.ceil(); + output[base_offset + 19] = x.round(); + + // Special values and conversions + let int_val = x as i32; + output[base_offset + 20] = int_val as f32; +} diff --git a/tests/difftests/tests/lang/core/ops/math_ops/math_ops-rust/src/main.rs b/tests/difftests/tests/lang/core/ops/math_ops/math_ops-rust/src/main.rs new file mode 100644 index 0000000000..2034b36d85 --- /dev/null +++ b/tests/difftests/tests/lang/core/ops/math_ops/math_ops-rust/src/main.rs @@ -0,0 +1,58 @@ +#[cfg(not(target_arch = "spirv"))] +fn main() { + use difftest::config::Config; + use difftest::scaffold::compute::{ + BufferConfig, BufferUsage, RustComputeShader, WgpuComputeTestMultiBuffer, + }; + + let config = Config::from_path(std::env::args().nth(1).unwrap()).unwrap(); + + // Create input data with various float values + let input_data: Vec = (0..32) + .map(|i| match i { + 0 => 0.0, + 1 => 1.0, + 2 => -1.0, + 3 => 0.5, + 4 => -0.5, + 5 => 2.0, + 6 => -2.0, + 7 => std::f32::consts::PI, + 8 => std::f32::consts::E, + 9 => 10.0, + 10 => -10.0, + 11 => 0.1, + 12 => -0.1, + 13 => 100.0, + 14 => -100.0, + 15 => 3.14159, + _ => (i as f32) * 0.7 - 5.0, + }) + .collect(); + + let input_bytes = bytemuck::cast_slice(&input_data).to_vec(); + + let buffers = vec![ + BufferConfig { + size: 128, // 32 f32 values + usage: BufferUsage::StorageReadOnly, + initial_data: Some(input_bytes), + }, + BufferConfig { + size: 2688, // 672 f32 values (32 threads * 21 outputs each) + usage: BufferUsage::Storage, + initial_data: None, + }, + ]; + + let test = WgpuComputeTestMultiBuffer::new( + RustComputeShader::default(), + [1, 1, 1], // Single workgroup with 32 threads + buffers, + ); + + test.run_test(&config).unwrap(); +} + +#[cfg(target_arch = "spirv")] +fn main() {} diff --git a/tests/difftests/tests/lang/core/ops/math_ops/math_ops-wgsl/Cargo.toml b/tests/difftests/tests/lang/core/ops/math_ops/math_ops-wgsl/Cargo.toml new file mode 100644 index 0000000000..f9cb7c8629 --- /dev/null +++ b/tests/difftests/tests/lang/core/ops/math_ops/math_ops-wgsl/Cargo.toml @@ -0,0 +1,10 @@ +[package] +name = "math_ops-wgsl" +edition.workspace = true + +[lints] +workspace = true + +[dependencies] +difftest.workspace = true +bytemuck.workspace = true \ No newline at end of file diff --git a/tests/difftests/tests/lang/core/ops/math_ops/math_ops-wgsl/shader.wgsl b/tests/difftests/tests/lang/core/ops/math_ops/math_ops-wgsl/shader.wgsl new file mode 100644 index 0000000000..c16fa21137 --- /dev/null +++ b/tests/difftests/tests/lang/core/ops/math_ops/math_ops-wgsl/shader.wgsl @@ -0,0 +1,51 @@ +@group(0) @binding(0) +var input: array; + +@group(0) @binding(1) +var output: array; + +@compute @workgroup_size(32, 1, 1) +fn main_cs(@builtin(global_invocation_id) global_id: vec3) { + let tid = global_id.x; + + if (tid >= 32u || tid >= arrayLength(&input)) { + return; + } + + let x = input[tid]; + let base_offset = tid * 21u; + + if (base_offset + 20u >= arrayLength(&output)) { + return; + } + + // Basic arithmetic + output[base_offset + 0u] = x + 1.5; + output[base_offset + 1u] = x - 0.5; + output[base_offset + 2u] = x * 2.0; + output[base_offset + 3u] = x / 2.0; + output[base_offset + 4u] = x % 3.0; + + // Trigonometric functions (simplified for consistent results) + output[base_offset + 5u] = sin(x); + output[base_offset + 6u] = cos(x); + output[base_offset + 7u] = clamp(tan(x), -10.0, 10.0); + output[base_offset + 8u] = 0.0; + output[base_offset + 9u] = 0.0; + output[base_offset + 10u] = atan(x); + + // Exponential and logarithmic (simplified) + output[base_offset + 11u] = min(exp(x), 1e6); + output[base_offset + 12u] = select(-10.0, log(x), x > 0.0); + output[base_offset + 13u] = sqrt(abs(x)); + output[base_offset + 14u] = pow(abs(x), 2.0); + output[base_offset + 15u] = select(-10.0, log2(x), x > 0.0); + output[base_offset + 16u] = min(exp2(x), 1e6); + output[base_offset + 17u] = floor(x); + output[base_offset + 18u] = ceil(x); + output[base_offset + 19u] = round(x); + + // Special values and conversions + let int_val = i32(x); + output[base_offset + 20u] = f32(int_val); +} \ No newline at end of file diff --git a/tests/difftests/tests/lang/core/ops/math_ops/math_ops-wgsl/src/main.rs b/tests/difftests/tests/lang/core/ops/math_ops/math_ops-wgsl/src/main.rs new file mode 100644 index 0000000000..6e8bd6273d --- /dev/null +++ b/tests/difftests/tests/lang/core/ops/math_ops/math_ops-wgsl/src/main.rs @@ -0,0 +1,54 @@ +use difftest::config::Config; +use difftest::scaffold::compute::{ + BufferConfig, BufferUsage, WgpuComputeTestMultiBuffer, WgslComputeShader, +}; + +fn main() { + let config = Config::from_path(std::env::args().nth(1).unwrap()).unwrap(); + + // Create input data with various float values + let input_data: Vec = (0..32) + .map(|i| match i { + 0 => 0.0, + 1 => 1.0, + 2 => -1.0, + 3 => 0.5, + 4 => -0.5, + 5 => 2.0, + 6 => -2.0, + 7 => std::f32::consts::PI, + 8 => std::f32::consts::E, + 9 => 10.0, + 10 => -10.0, + 11 => 0.1, + 12 => -0.1, + 13 => 100.0, + 14 => -100.0, + 15 => 3.14159, + _ => (i as f32) * 0.7 - 5.0, + }) + .collect(); + + let input_bytes = bytemuck::cast_slice(&input_data).to_vec(); + + let buffers = vec![ + BufferConfig { + size: 128, // 32 f32 values + usage: BufferUsage::StorageReadOnly, + initial_data: Some(input_bytes), + }, + BufferConfig { + size: 2688, // 672 f32 values (32 threads * 21 outputs each) + usage: BufferUsage::Storage, + initial_data: None, + }, + ]; + + let test = WgpuComputeTestMultiBuffer::new( + WgslComputeShader::default(), + [1, 1, 1], // Single workgroup with 32 threads + buffers, + ); + + test.run_test(&config).unwrap(); +} diff --git a/tests/difftests/tests/lang/core/ops/matrix_ops/matrix_ops-rust/Cargo.toml b/tests/difftests/tests/lang/core/ops/matrix_ops/matrix_ops-rust/Cargo.toml new file mode 100644 index 0000000000..64ab199838 --- /dev/null +++ b/tests/difftests/tests/lang/core/ops/matrix_ops/matrix_ops-rust/Cargo.toml @@ -0,0 +1,20 @@ +[package] +name = "matrix_ops-rust" +edition.workspace = true + +[lints] +workspace = true + +[lib] +crate-type = ["dylib"] + +# Common deps +[dependencies] + +# GPU deps +spirv-std.workspace = true + +# CPU deps +[target.'cfg(not(target_arch = "spirv"))'.dependencies] +difftest.workspace = true +bytemuck.workspace = true \ No newline at end of file diff --git a/tests/difftests/tests/lang/core/ops/matrix_ops/matrix_ops-rust/src/lib.rs b/tests/difftests/tests/lang/core/ops/matrix_ops/matrix_ops-rust/src/lib.rs new file mode 100644 index 0000000000..76083c6de1 --- /dev/null +++ b/tests/difftests/tests/lang/core/ops/matrix_ops/matrix_ops-rust/src/lib.rs @@ -0,0 +1,156 @@ +#![no_std] + +use spirv_std::glam::{Mat2, Mat3, Mat4, UVec3, Vec2, Vec3, Vec4}; +#[allow(unused_imports)] +use spirv_std::num_traits::Float; +use spirv_std::spirv; + +#[spirv(compute(threads(32)))] +pub fn main_cs( + #[spirv(storage_buffer, descriptor_set = 0, binding = 0)] input: &[f32], + #[spirv(storage_buffer, descriptor_set = 0, binding = 1)] output: &mut [f32], + #[spirv(global_invocation_id)] global_id: UVec3, +) { + let tid = global_id.x as usize; + + if tid >= 32 || tid * 4 + 3 >= input.len() { + return; + } + + // Read input values + let a = input[tid * 4]; + let b = input[tid * 4 + 1]; + let c = input[tid * 4 + 2]; + let d = input[tid * 4 + 3]; + + let base_offset = tid * 50; + + if base_offset + 49 >= output.len() { + return; + } + + // Mat2 operations + let m2a = Mat2::from_cols(Vec2::new(a, b), Vec2::new(c, d)); + let m2b = Mat2::from_cols(Vec2::new(d, c), Vec2::new(b, a)); + + // Mat2 multiplication + let m2_mul = m2a * m2b; + output[base_offset + 0] = m2_mul.col(0).x; + output[base_offset + 1] = m2_mul.col(0).y; + output[base_offset + 2] = m2_mul.col(1).x; + output[base_offset + 3] = m2_mul.col(1).y; + + // Mat2 transpose + let m2_transpose = m2a.transpose(); + output[base_offset + 4] = m2_transpose.col(0).x; + output[base_offset + 5] = m2_transpose.col(0).y; + output[base_offset + 6] = m2_transpose.col(1).x; + output[base_offset + 7] = m2_transpose.col(1).y; + + // Mat2 determinant (with rounding for consistency) + output[base_offset + 8] = (m2a.determinant() * 1000.0).round() / 1000.0; + + // Mat2 * Vec2 + let v2 = Vec2::new(1.0, 2.0); + let m2_v2 = m2a * v2; + output[base_offset + 9] = m2_v2.x; + output[base_offset + 10] = m2_v2.y; + + // Mat3 operations + let m3a = Mat3::from_cols(Vec3::new(a, b, c), Vec3::new(b, c, d), Vec3::new(c, d, a)); + let m3b = Mat3::from_cols(Vec3::new(d, c, b), Vec3::new(c, b, a), Vec3::new(b, a, d)); + + // Mat3 multiplication + let m3_mul = m3a * m3b; + output[base_offset + 11] = m3_mul.col(0).x; + output[base_offset + 12] = m3_mul.col(0).y; + output[base_offset + 13] = m3_mul.col(0).z; + output[base_offset + 14] = m3_mul.col(1).x; + output[base_offset + 15] = m3_mul.col(1).y; + output[base_offset + 16] = m3_mul.col(1).z; + output[base_offset + 17] = m3_mul.col(2).x; + output[base_offset + 18] = m3_mul.col(2).y; + output[base_offset + 19] = m3_mul.col(2).z; + + // Mat3 transpose - store just diagonal elements + let m3_transpose = m3a.transpose(); + output[base_offset + 20] = m3_transpose.col(0).x; + output[base_offset + 21] = m3_transpose.col(1).y; + output[base_offset + 22] = m3_transpose.col(2).z; + + // Mat3 determinant (with rounding for consistency) + output[base_offset + 23] = (m3a.determinant() * 1000.0).round() / 1000.0; + + // Mat3 * Vec3 (with rounding for consistency) + let v3 = Vec3::new(1.0, 2.0, 3.0); + let m3_v3 = m3a * v3; + output[base_offset + 24] = (m3_v3.x * 10000.0).round() / 10000.0; + output[base_offset + 25] = (m3_v3.y * 10000.0).round() / 10000.0; + output[base_offset + 26] = (m3_v3.z * 10000.0).round() / 10000.0; + + // Mat4 operations + let m4a = Mat4::from_cols( + Vec4::new(a, b, c, d), + Vec4::new(b, c, d, a), + Vec4::new(c, d, a, b), + Vec4::new(d, a, b, c), + ); + let m4b = Mat4::from_cols( + Vec4::new(d, c, b, a), + Vec4::new(c, b, a, d), + Vec4::new(b, a, d, c), + Vec4::new(a, d, c, b), + ); + + // Mat4 multiplication (just store diagonal for brevity) + let m4_mul = m4a * m4b; + output[base_offset + 27] = m4_mul.col(0).x; + output[base_offset + 28] = m4_mul.col(1).y; + output[base_offset + 29] = m4_mul.col(2).z; + output[base_offset + 30] = m4_mul.col(3).w; + + // Mat4 transpose (just store diagonal) + let m4_transpose = m4a.transpose(); + output[base_offset + 31] = m4_transpose.col(0).x; + output[base_offset + 32] = m4_transpose.col(1).y; + output[base_offset + 33] = m4_transpose.col(2).z; + output[base_offset + 34] = m4_transpose.col(3).w; + + // Mat4 determinant (with rounding for consistency) + output[base_offset + 35] = (m4a.determinant() * 1000.0).round() / 1000.0; + + // Mat4 * Vec4 (with rounding for consistency) + let v4 = Vec4::new(1.0, 2.0, 3.0, 4.0); + let m4_v4 = m4a * v4; + output[base_offset + 36] = (m4_v4.x * 10000.0).round() / 10000.0; + output[base_offset + 37] = (m4_v4.y * 10000.0).round() / 10000.0; + output[base_offset + 38] = (m4_v4.z * 10000.0).round() / 10000.0; + output[base_offset + 39] = (m4_v4.w * 10000.0).round() / 10000.0; + + // Identity matrices + output[base_offset + 40] = Mat2::IDENTITY.col(0).x; + output[base_offset + 41] = Mat3::IDENTITY.col(1).y; + output[base_offset + 42] = Mat4::IDENTITY.col(2).z; + + // Matrix inverse + if m2a.determinant().abs() > 0.0001 { + let m2_inv = m2a.inverse(); + output[base_offset + 43] = m2_inv.col(0).x; + output[base_offset + 44] = m2_inv.col(1).y; + } else { + output[base_offset + 43] = 0.0; + output[base_offset + 44] = 0.0; + } + + // Matrix addition + let m2_add = m2a + m2b; + output[base_offset + 45] = m2_add.col(0).x; + output[base_offset + 46] = m2_add.col(0).y; + + // Matrix scalar multiplication + let m2_scale = m2a * 2.0; + output[base_offset + 47] = m2_scale.col(0).x; + output[base_offset + 48] = m2_scale.col(0).y; + + output[base_offset + 49] = 1.0; // Padding +} diff --git a/tests/difftests/tests/lang/core/ops/matrix_ops/matrix_ops-rust/src/main.rs b/tests/difftests/tests/lang/core/ops/matrix_ops/matrix_ops-rust/src/main.rs new file mode 100644 index 0000000000..c6fd3a2aa7 --- /dev/null +++ b/tests/difftests/tests/lang/core/ops/matrix_ops/matrix_ops-rust/src/main.rs @@ -0,0 +1,58 @@ +#[cfg(not(target_arch = "spirv"))] +fn main() { + use difftest::config::Config; + use difftest::scaffold::compute::{ + BufferConfig, BufferUsage, RustComputeShader, WgpuComputeTestMultiBuffer, + }; + + let config = Config::from_path(std::env::args().nth(1).unwrap()).unwrap(); + + // Create input data with various float values + let input_data: Vec = (0..128) + .map(|i| match i % 16 { + 0 => 1.0, + 1 => 2.0, + 2 => 3.0, + 3 => 4.0, + 4 => 0.5, + 5 => -0.5, + 6 => 2.0, + 7 => -2.0, + 8 => 0.0, + 9 => 1.0, + 10 => -1.0, + 11 => 0.1, + 12 => 3.14, + 13 => 2.71, + 14 => 0.25, + 15 => -0.25, + _ => unreachable!(), + }) + .collect(); + + let input_bytes = bytemuck::cast_slice(&input_data).to_vec(); + + let buffers = vec![ + BufferConfig { + size: 512, // 128 f32 values + usage: BufferUsage::StorageReadOnly, + initial_data: Some(input_bytes), + }, + BufferConfig { + size: 6400, // 1600 f32 values (32 threads * 50 outputs each) + usage: BufferUsage::Storage, + initial_data: None, + }, + ]; + + let test = WgpuComputeTestMultiBuffer::new( + RustComputeShader::default(), + [1, 1, 1], // Single workgroup with 32 threads + buffers, + ); + + test.run_test(&config).unwrap(); +} + +#[cfg(target_arch = "spirv")] +fn main() {} diff --git a/tests/difftests/tests/lang/core/ops/matrix_ops/matrix_ops-wgsl/Cargo.toml b/tests/difftests/tests/lang/core/ops/matrix_ops/matrix_ops-wgsl/Cargo.toml new file mode 100644 index 0000000000..4fd25cf9ca --- /dev/null +++ b/tests/difftests/tests/lang/core/ops/matrix_ops/matrix_ops-wgsl/Cargo.toml @@ -0,0 +1,10 @@ +[package] +name = "matrix_ops-wgsl" +edition.workspace = true + +[lints] +workspace = true + +[dependencies] +difftest.workspace = true +bytemuck.workspace = true \ No newline at end of file diff --git a/tests/difftests/tests/lang/core/ops/matrix_ops/matrix_ops-wgsl/shader.wgsl b/tests/difftests/tests/lang/core/ops/matrix_ops/matrix_ops-wgsl/shader.wgsl new file mode 100644 index 0000000000..7dcdadaa45 --- /dev/null +++ b/tests/difftests/tests/lang/core/ops/matrix_ops/matrix_ops-wgsl/shader.wgsl @@ -0,0 +1,177 @@ +@group(0) @binding(0) +var input: array; + +@group(0) @binding(1) +var output: array; + +@compute @workgroup_size(32, 1, 1) +fn main_cs(@builtin(global_invocation_id) global_id: vec3) { + let tid = global_id.x; + + if (tid >= 32u || tid * 4u + 3u >= arrayLength(&input)) { + return; + } + + // Read input values + let a = input[tid * 4u]; + let b = input[tid * 4u + 1u]; + let c = input[tid * 4u + 2u]; + let d = input[tid * 4u + 3u]; + + let base_offset = tid * 50u; + + if (base_offset + 49u >= arrayLength(&output)) { + return; + } + + // Mat2 operations + let m2a = mat2x2( + vec2(a, b), + vec2(c, d) + ); + let m2b = mat2x2( + vec2(d, c), + vec2(b, a) + ); + + // Mat2 multiplication + let m2_mul = m2a * m2b; + output[base_offset + 0u] = m2_mul[0].x; + output[base_offset + 1u] = m2_mul[0].y; + output[base_offset + 2u] = m2_mul[1].x; + output[base_offset + 3u] = m2_mul[1].y; + + // Mat2 transpose + let m2_transpose = transpose(m2a); + output[base_offset + 4u] = m2_transpose[0].x; + output[base_offset + 5u] = m2_transpose[0].y; + output[base_offset + 6u] = m2_transpose[1].x; + output[base_offset + 7u] = m2_transpose[1].y; + + // Mat2 determinant (with rounding for consistency) + output[base_offset + 8u] = round(determinant(m2a) * 1000.0) / 1000.0; + + // Mat2 * Vec2 + let v2 = vec2(1.0, 2.0); + let m2_v2 = m2a * v2; + output[base_offset + 9u] = m2_v2.x; + output[base_offset + 10u] = m2_v2.y; + + // Mat3 operations + let m3a = mat3x3( + vec3(a, b, c), + vec3(b, c, d), + vec3(c, d, a) + ); + let m3b = mat3x3( + vec3(d, c, b), + vec3(c, b, a), + vec3(b, a, d) + ); + + // Mat3 multiplication + let m3_mul = m3a * m3b; + output[base_offset + 11u] = m3_mul[0].x; + output[base_offset + 12u] = m3_mul[0].y; + output[base_offset + 13u] = m3_mul[0].z; + output[base_offset + 14u] = m3_mul[1].x; + output[base_offset + 15u] = m3_mul[1].y; + output[base_offset + 16u] = m3_mul[1].z; + output[base_offset + 17u] = m3_mul[2].x; + output[base_offset + 18u] = m3_mul[2].y; + output[base_offset + 19u] = m3_mul[2].z; + + // Mat3 transpose + let m3_transpose = transpose(m3a); + output[base_offset + 20u] = m3_transpose[0].x; + output[base_offset + 21u] = m3_transpose[1].y; + output[base_offset + 22u] = m3_transpose[2].z; + + // Mat3 determinant (with rounding for consistency) + output[base_offset + 23u] = round(determinant(m3a) * 1000.0) / 1000.0; + + // Mat3 * Vec3 (with rounding for consistency) + let v3 = vec3(1.0, 2.0, 3.0); + let m3_v3 = m3a * v3; + output[base_offset + 24u] = round(m3_v3.x * 10000.0) / 10000.0; + output[base_offset + 25u] = round(m3_v3.y * 10000.0) / 10000.0; + output[base_offset + 26u] = round(m3_v3.z * 10000.0) / 10000.0; + + // Mat4 operations + let m4a = mat4x4( + vec4(a, b, c, d), + vec4(b, c, d, a), + vec4(c, d, a, b), + vec4(d, a, b, c) + ); + let m4b = mat4x4( + vec4(d, c, b, a), + vec4(c, b, a, d), + vec4(b, a, d, c), + vec4(a, d, c, b) + ); + + // Mat4 multiplication (just store diagonal for brevity) + let m4_mul = m4a * m4b; + output[base_offset + 27u] = m4_mul[0].x; + output[base_offset + 28u] = m4_mul[1].y; + output[base_offset + 29u] = m4_mul[2].z; + output[base_offset + 30u] = m4_mul[3].w; + + // Mat4 transpose (just store diagonal) + let m4_transpose = transpose(m4a); + output[base_offset + 31u] = m4_transpose[0].x; + output[base_offset + 32u] = m4_transpose[1].y; + output[base_offset + 33u] = m4_transpose[2].z; + output[base_offset + 34u] = m4_transpose[3].w; + + // Mat4 determinant (with rounding for consistency) + output[base_offset + 35u] = round(determinant(m4a) * 1000.0) / 1000.0; + + // Mat4 * Vec4 (with rounding for consistency) + let v4 = vec4(1.0, 2.0, 3.0, 4.0); + let m4_v4 = m4a * v4; + output[base_offset + 36u] = round(m4_v4.x * 10000.0) / 10000.0; + output[base_offset + 37u] = round(m4_v4.y * 10000.0) / 10000.0; + output[base_offset + 38u] = round(m4_v4.z * 10000.0) / 10000.0; + output[base_offset + 39u] = round(m4_v4.w * 10000.0) / 10000.0; + + // Identity matrices + output[base_offset + 40u] = mat2x2(vec2(1.0, 0.0), vec2(0.0, 1.0))[0].x; + output[base_offset + 41u] = mat3x3(vec3(1.0, 0.0, 0.0), vec3(0.0, 1.0, 0.0), vec3(0.0, 0.0, 1.0))[1].y; + output[base_offset + 42u] = mat4x4(vec4(1.0, 0.0, 0.0, 0.0), vec4(0.0, 1.0, 0.0, 0.0), vec4(0.0, 0.0, 1.0, 0.0), vec4(0.0, 0.0, 0.0, 1.0))[2].z; + + // Matrix inverse (WGSL doesn't have try_inverse, so we check determinant) + let det = determinant(m2a); + if (abs(det) > 0.0001) { + // Manual 2x2 inverse calculation + let inv_det = 1.0 / det; + let m2_inv = mat2x2( + vec2(m2a[1].y * inv_det, -m2a[0].y * inv_det), + vec2(-m2a[1].x * inv_det, m2a[0].x * inv_det) + ); + output[base_offset + 43u] = m2_inv[0].x; + output[base_offset + 44u] = m2_inv[1].y; + } else { + output[base_offset + 43u] = 0.0; + output[base_offset + 44u] = 0.0; + } + + // Matrix addition + let m2_add = mat2x2( + m2a[0] + m2b[0], + m2a[1] + m2b[1] + ); + output[base_offset + 45u] = m2_add[0].x; + output[base_offset + 46u] = m2_add[0].y; + + // Matrix scalar multiplication + let m2_scale = mat2x2( + m2a[0] * 2.0, + m2a[1] * 2.0 + ); + output[base_offset + 47u] = m2_scale[0].x; + output[base_offset + 48u] = m2_scale[0].y; + + output[base_offset + 49u] = 1.0; // Padding +} \ No newline at end of file diff --git a/tests/difftests/tests/lang/core/ops/matrix_ops/matrix_ops-wgsl/src/main.rs b/tests/difftests/tests/lang/core/ops/matrix_ops/matrix_ops-wgsl/src/main.rs new file mode 100644 index 0000000000..7488803fdb --- /dev/null +++ b/tests/difftests/tests/lang/core/ops/matrix_ops/matrix_ops-wgsl/src/main.rs @@ -0,0 +1,54 @@ +use difftest::config::Config; +use difftest::scaffold::compute::{ + BufferConfig, BufferUsage, WgpuComputeTestMultiBuffer, WgslComputeShader, +}; + +fn main() { + let config = Config::from_path(std::env::args().nth(1).unwrap()).unwrap(); + + // Create input data with various float values + let input_data: Vec = (0..128) + .map(|i| match i % 16 { + 0 => 1.0, + 1 => 2.0, + 2 => 3.0, + 3 => 4.0, + 4 => 0.5, + 5 => -0.5, + 6 => 2.0, + 7 => -2.0, + 8 => 0.0, + 9 => 1.0, + 10 => -1.0, + 11 => 0.1, + 12 => 3.14, + 13 => 2.71, + 14 => 0.25, + 15 => -0.25, + _ => unreachable!(), + }) + .collect(); + + let input_bytes = bytemuck::cast_slice(&input_data).to_vec(); + + let buffers = vec![ + BufferConfig { + size: 512, // 128 f32 values + usage: BufferUsage::StorageReadOnly, + initial_data: Some(input_bytes), + }, + BufferConfig { + size: 6400, // 1600 f32 values (32 threads * 50 outputs each) + usage: BufferUsage::Storage, + initial_data: None, + }, + ]; + + let test = WgpuComputeTestMultiBuffer::new( + WgslComputeShader::default(), + [1, 1, 1], // Single workgroup with 32 threads + buffers, + ); + + test.run_test(&config).unwrap(); +} diff --git a/tests/difftests/tests/lang/core/ops/matrix_ops/mod.rs b/tests/difftests/tests/lang/core/ops/matrix_ops/mod.rs new file mode 100644 index 0000000000..c507227bc2 --- /dev/null +++ b/tests/difftests/tests/lang/core/ops/matrix_ops/mod.rs @@ -0,0 +1,16 @@ +use crate::test_framework::*; + +pub const MODULE: &str = "lang/core/ops/matrix_ops"; + +pub fn matrix_ops() -> TestCase { + TestCase::new( + "matrix_ops", + MODULE, + WgpuComputeTest::new( + |value: u32| [value as f32, (value + 1) as f32, (value + 2) as f32, (value + 3) as f32], + |_input, _output| { + // Default comparison will check exact equality + }, + ), + ) +} \ No newline at end of file diff --git a/tests/difftests/tests/lang/core/ops/trig_ops/trig_ops-rust/Cargo.toml b/tests/difftests/tests/lang/core/ops/trig_ops/trig_ops-rust/Cargo.toml new file mode 100644 index 0000000000..229e9f2459 --- /dev/null +++ b/tests/difftests/tests/lang/core/ops/trig_ops/trig_ops-rust/Cargo.toml @@ -0,0 +1,17 @@ +[package] +name = "trig_ops-rust" +edition.workspace = true + +[lints] +workspace = true + +[lib] +crate-type = ["dylib"] + +# GPU deps +[dependencies] +spirv-std.workspace = true + +# CPU deps (for the test harness) +[target.'cfg(not(target_arch = "spirv"))'.dependencies] +difftest.workspace = true \ No newline at end of file diff --git a/tests/difftests/tests/lang/core/ops/trig_ops/trig_ops-rust/src/lib.rs b/tests/difftests/tests/lang/core/ops/trig_ops/trig_ops-rust/src/lib.rs new file mode 100644 index 0000000000..b663576fdc --- /dev/null +++ b/tests/difftests/tests/lang/core/ops/trig_ops/trig_ops-rust/src/lib.rs @@ -0,0 +1,51 @@ +#![no_std] +#![cfg_attr(target_arch = "spirv", feature(asm_experimental_arch))] + +#[allow(unused_imports)] +use spirv_std::num_traits::Float; +use spirv_std::spirv; + +#[spirv(compute(threads(64)))] +pub fn main_cs( + #[spirv(storage_buffer, descriptor_set = 0, binding = 0)] input: &[f32], + #[spirv(storage_buffer, descriptor_set = 0, binding = 1)] output: &mut [f32], + #[spirv(global_invocation_id)] global_id: spirv_std::glam::UVec3, +) { + let tid = global_id.x as usize; + + if tid < input.len() && tid < output.len() { + let x = input[tid]; + + // Test various trigonometric functions + let result = match tid % 14 { + 0 => x.sin(), + 1 => x.cos(), + 2 => x.tan(), + 3 => x.asin().clamp(-1.0, 1.0), // Clamp to avoid NaN for values outside [-1, 1] + 4 => x.acos().clamp(-1.0, 1.0), // Clamp to avoid NaN for values outside [-1, 1] + 5 => x.atan(), + 6 => x.sinh(), + 7 => x.cosh(), + 8 => x.tanh(), + 9 => { + // atan2 - use two consecutive values + let y = if tid + 1 < input.len() { + input[tid + 1] + } else { + 1.0 + }; + x.atan2(y) + } + 10 => (x * x + 1.0).sqrt(), // hypot equivalent: sqrt(x^2 + 1^2) + 11 => x.to_radians(), + 12 => x.to_degrees(), + 13 => { + // sincos - return sin for even indices, cos for odd + if tid % 2 == 0 { x.sin() } else { x.cos() } + } + _ => 0.0, + }; + + output[tid] = result; + } +} diff --git a/tests/difftests/tests/lang/core/ops/trig_ops/trig_ops-rust/src/main.rs b/tests/difftests/tests/lang/core/ops/trig_ops/trig_ops-rust/src/main.rs new file mode 100644 index 0000000000..eb96343bcf --- /dev/null +++ b/tests/difftests/tests/lang/core/ops/trig_ops/trig_ops-rust/src/main.rs @@ -0,0 +1,15 @@ +use difftest::config::Config; +use difftest::scaffold::compute::{RustComputeShader, WgpuComputeTestMultiBuffer}; + +fn main() { + let config = Config::from_path(std::env::args().nth(1).unwrap()).unwrap(); + + let buffer_size = 1024; + let test = + WgpuComputeTestMultiBuffer::new_with_sizes(RustComputeShader::default(), [64, 1, 1], &[ + buffer_size, + buffer_size, + ]); + + test.run_test(&config).unwrap(); +} diff --git a/tests/difftests/tests/lang/core/ops/trig_ops/trig_ops-wgsl/Cargo.toml b/tests/difftests/tests/lang/core/ops/trig_ops/trig_ops-wgsl/Cargo.toml new file mode 100644 index 0000000000..b053cfaae0 --- /dev/null +++ b/tests/difftests/tests/lang/core/ops/trig_ops/trig_ops-wgsl/Cargo.toml @@ -0,0 +1,9 @@ +[package] +name = "trig_ops-wgsl" +edition.workspace = true + +[lints] +workspace = true + +[dependencies] +difftest.workspace = true \ No newline at end of file diff --git a/tests/difftests/tests/lang/core/ops/trig_ops/trig_ops-wgsl/shader.wgsl b/tests/difftests/tests/lang/core/ops/trig_ops/trig_ops-wgsl/shader.wgsl new file mode 100644 index 0000000000..2a43f9a3b1 --- /dev/null +++ b/tests/difftests/tests/lang/core/ops/trig_ops/trig_ops-wgsl/shader.wgsl @@ -0,0 +1,47 @@ +@group(0) @binding(0) var input: array; +@group(0) @binding(1) var output: array; + +@compute @workgroup_size(64) +fn main(@builtin(global_invocation_id) global_id: vec3) { + let tid = global_id.x; + + if (tid < arrayLength(&input) && tid < arrayLength(&output)) { + let x = input[tid]; + + // Test various trigonometric functions + var result: f32; + switch (tid % 14u) { + case 0u: { result = sin(x); } + case 1u: { result = cos(x); } + case 2u: { result = tan(x); } + case 3u: { result = asin(clamp(x, -1.0, 1.0)); } // Clamp to avoid NaN + case 4u: { result = acos(clamp(x, -1.0, 1.0)); } // Clamp to avoid NaN + case 5u: { result = atan(x); } + case 6u: { result = sinh(x); } + case 7u: { result = cosh(x); } + case 8u: { result = tanh(x); } + case 9u: { + // atan2 - use two consecutive values + var y: f32 = 1.0; + if (tid + 1u < arrayLength(&input)) { + y = input[tid + 1u]; + } + result = atan2(x, y); + } + case 10u: { result = sqrt(x * x + 1.0); } // hypot equivalent + case 11u: { result = radians(x); } + case 12u: { result = degrees(x); } + case 13u: { + // sincos - return sin for even indices, cos for odd + if (tid % 2u == 0u) { + result = sin(x); + } else { + result = cos(x); + } + } + default: { result = 0.0; } + } + + output[tid] = result; + } +} \ No newline at end of file diff --git a/tests/difftests/tests/lang/core/ops/trig_ops/trig_ops-wgsl/src/main.rs b/tests/difftests/tests/lang/core/ops/trig_ops/trig_ops-wgsl/src/main.rs new file mode 100644 index 0000000000..f21cdecfbf --- /dev/null +++ b/tests/difftests/tests/lang/core/ops/trig_ops/trig_ops-wgsl/src/main.rs @@ -0,0 +1,15 @@ +use difftest::config::Config; +use difftest::scaffold::compute::{WgpuComputeTestMultiBuffer, WgslComputeShader}; + +fn main() { + let config = Config::from_path(std::env::args().nth(1).unwrap()).unwrap(); + + let buffer_size = 1024; + let test = + WgpuComputeTestMultiBuffer::new_with_sizes(WgslComputeShader::default(), [64, 1, 1], &[ + buffer_size, + buffer_size, + ]); + + test.run_test(&config).unwrap(); +} diff --git a/tests/difftests/tests/lang/core/ops/vector_ops/vector_ops-rust/Cargo.toml b/tests/difftests/tests/lang/core/ops/vector_ops/vector_ops-rust/Cargo.toml new file mode 100644 index 0000000000..d0cb1246a2 --- /dev/null +++ b/tests/difftests/tests/lang/core/ops/vector_ops/vector_ops-rust/Cargo.toml @@ -0,0 +1,20 @@ +[package] +name = "vector_ops-rust" +edition.workspace = true + +[lints] +workspace = true + +[lib] +crate-type = ["dylib"] + +# Common deps +[dependencies] + +# GPU deps +spirv-std.workspace = true + +# CPU deps +[target.'cfg(not(target_arch = "spirv"))'.dependencies] +difftest.workspace = true +bytemuck.workspace = true \ No newline at end of file diff --git a/tests/difftests/tests/lang/core/ops/vector_ops/vector_ops-rust/src/lib.rs b/tests/difftests/tests/lang/core/ops/vector_ops/vector_ops-rust/src/lib.rs new file mode 100644 index 0000000000..6a3c2a4e5d --- /dev/null +++ b/tests/difftests/tests/lang/core/ops/vector_ops/vector_ops-rust/src/lib.rs @@ -0,0 +1,141 @@ +#![no_std] + +use spirv_std::glam::{UVec2, UVec3, UVec4, Vec2, Vec3, Vec4, Vec4Swizzles}; +#[allow(unused_imports)] +use spirv_std::num_traits::Float; +use spirv_std::spirv; + +#[spirv(compute(threads(32)))] +pub fn main_cs( + #[spirv(storage_buffer, descriptor_set = 0, binding = 0)] input: &[f32], + #[spirv(storage_buffer, descriptor_set = 0, binding = 1)] output: &mut [f32], + #[spirv(global_invocation_id)] global_id: UVec3, +) { + let tid = global_id.x as usize; + + if tid >= 32 || tid * 4 + 3 >= input.len() { + return; + } + + // Read 4 floats from input + let a = input[tid * 4]; + let b = input[tid * 4 + 1]; + let c = input[tid * 4 + 2]; + let d = input[tid * 4 + 3]; + + let base_offset = tid * 50; + + if base_offset + 49 >= output.len() { + return; + } + + // Vec2 operations + let v2a = Vec2::new(a, b); + let v2b = Vec2::new(c, d); + + output[base_offset + 0] = (v2a.dot(v2b) * 1000.0).round() / 1000.0; + output[base_offset + 1] = (v2a.length() * 1000.0).round() / 1000.0; + output[base_offset + 2] = (v2a.distance(v2b) * 1000.0).round() / 1000.0; + + let v2_add = v2a + v2b; + output[base_offset + 3] = v2_add.x; + output[base_offset + 4] = v2_add.y; + + let v2_mul = v2a * 2.0; + output[base_offset + 5] = v2_mul.x; + output[base_offset + 6] = v2_mul.y; + + // Vec3 operations + let v3a = Vec3::new(a, b, c); + let v3b = Vec3::new(b, c, d); + + output[base_offset + 7] = (v3a.dot(v3b) * 1000.0).round() / 1000.0; + output[base_offset + 8] = (v3a.length() * 1000.0).round() / 1000.0; + + let v3_cross = v3a.cross(v3b); + output[base_offset + 9] = v3_cross.x; + output[base_offset + 10] = v3_cross.y; + output[base_offset + 11] = v3_cross.z; + + let v3_norm = v3a.normalize(); + output[base_offset + 12] = (v3_norm.x * 1000.0).round() / 1000.0; + output[base_offset + 13] = (v3_norm.y * 1000.0).round() / 1000.0; + output[base_offset + 14] = (v3_norm.z * 1000.0).round() / 1000.0; + + // Vec4 operations + let v4a = Vec4::new(a, b, c, d); + let v4b = Vec4::new(d, c, b, a); + + output[base_offset + 15] = (v4a.dot(v4b) * 1000.0).round() / 1000.0; + output[base_offset + 16] = (v4a.length() * 1000.0).round() / 1000.0; + + let v4_sub = v4a - v4b; + output[base_offset + 17] = v4_sub.x; + output[base_offset + 18] = v4_sub.y; + output[base_offset + 19] = v4_sub.z; + output[base_offset + 20] = v4_sub.w; + + // Swizzling + output[base_offset + 21] = v4a.x; + output[base_offset + 22] = v4a.y; + output[base_offset + 23] = v4a.z; + output[base_offset + 24] = v4a.w; + + let v4_swizzle = v4a.wzyx(); + output[base_offset + 25] = v4_swizzle.x; + output[base_offset + 26] = v4_swizzle.y; + output[base_offset + 27] = v4_swizzle.z; + output[base_offset + 28] = v4_swizzle.w; + + // Component-wise operations + let v4_min = v4a.min(v4b); + output[base_offset + 29] = v4_min.x; + output[base_offset + 30] = v4_min.y; + output[base_offset + 31] = v4_min.z; + + // UVec operations + let ua = a.abs() as u32; + let ub = b.abs() as u32; + let uc = c.abs() as u32; + let ud = d.abs() as u32; + + let uv2 = UVec2::new(ua, ub); + let uv3 = UVec3::new(ua, ub, uc); + let uv4 = UVec4::new(ua, ub, uc, ud); + + output[base_offset + 32] = (uv2.x + uv2.y) as f32; + output[base_offset + 33] = (uv3.x + uv3.y + uv3.z) as f32; + output[base_offset + 34] = (uv4.x + uv4.y + uv4.z + uv4.w) as f32; + + // UVec min/max operations commented out due to Int8 requirement + // See: https://github.com/Rust-GPU/rust-gpu/issues/314 + // let uv4_min = uv4.min(UVec4::new(ud, uc, ub, ua)); + // let uv4_max = uv4.max(UVec4::new(1, 2, 3, 4)); + output[base_offset + 35] = ua as f32; // Should be: uv4_min.x as f32 + output[base_offset + 36] = ub as f32; // Should be: uv4_min.y as f32 + output[base_offset + 37] = ua as f32; // Should be: uv4_max.x as f32 + output[base_offset + 38] = ub as f32; // Should be: uv4_max.y as f32 + output[base_offset + 39] = uc as f32; // Should be: uv4_max.z as f32 + + // Vector constants + output[base_offset + 40] = Vec2::ZERO.x; + output[base_offset + 41] = Vec3::ONE.x; + output[base_offset + 42] = Vec4::X.x; + output[base_offset + 43] = Vec4::Y.y; + + // Converting between sizes + let v2_from_v3 = v3a.truncate(); + output[base_offset + 44] = v2_from_v3.x; + output[base_offset + 45] = v2_from_v3.y; + + let v3_from_v4 = v4a.truncate(); + output[base_offset + 46] = v3_from_v4.x; + + let v3_extended = v2a.extend(1.0); + output[base_offset + 47] = v3_extended.z; + + let v4_extended = v3a.extend(2.0); + output[base_offset + 48] = v4_extended.w; + + output[base_offset + 49] = UVec4::ONE.x as f32; +} diff --git a/tests/difftests/tests/lang/core/ops/vector_ops/vector_ops-rust/src/main.rs b/tests/difftests/tests/lang/core/ops/vector_ops/vector_ops-rust/src/main.rs new file mode 100644 index 0000000000..d78d11bf17 --- /dev/null +++ b/tests/difftests/tests/lang/core/ops/vector_ops/vector_ops-rust/src/main.rs @@ -0,0 +1,58 @@ +#[cfg(not(target_arch = "spirv"))] +fn main() { + use difftest::config::Config; + use difftest::scaffold::compute::{ + BufferConfig, BufferUsage, RustComputeShader, WgpuComputeTestMultiBuffer, + }; + + let config = Config::from_path(std::env::args().nth(1).unwrap()).unwrap(); + + // Create input data with various float values + let input_data: Vec = (0..128) + .map(|i| match i % 16 { + 0 => 0.0, + 1 => 1.0, + 2 => -1.0, + 3 => 0.5, + 4 => -0.5, + 5 => 2.0, + 6 => -2.0, + 7 => 3.0, + 8 => std::f32::consts::PI, + 9 => std::f32::consts::E, + 10 => 0.1, + 11 => -0.1, + 12 => 4.0, + 13 => -4.0, + 14 => 0.25, + 15 => -0.25, + _ => unreachable!(), + }) + .collect(); + + let input_bytes = bytemuck::cast_slice(&input_data).to_vec(); + + let buffers = vec![ + BufferConfig { + size: 512, // 128 f32 values + usage: BufferUsage::StorageReadOnly, + initial_data: Some(input_bytes), + }, + BufferConfig { + size: 6400, // 1600 f32 values (32 threads * 50 outputs each) + usage: BufferUsage::Storage, + initial_data: None, + }, + ]; + + let test = WgpuComputeTestMultiBuffer::new( + RustComputeShader::default(), + [1, 1, 1], // Single workgroup with 32 threads + buffers, + ); + + test.run_test(&config).unwrap(); +} + +#[cfg(target_arch = "spirv")] +fn main() {} diff --git a/tests/difftests/tests/lang/core/ops/vector_ops/vector_ops-wgsl/Cargo.toml b/tests/difftests/tests/lang/core/ops/vector_ops/vector_ops-wgsl/Cargo.toml new file mode 100644 index 0000000000..727c6f9913 --- /dev/null +++ b/tests/difftests/tests/lang/core/ops/vector_ops/vector_ops-wgsl/Cargo.toml @@ -0,0 +1,10 @@ +[package] +name = "vector_ops-wgsl" +edition.workspace = true + +[lints] +workspace = true + +[dependencies] +difftest.workspace = true +bytemuck.workspace = true \ No newline at end of file diff --git a/tests/difftests/tests/lang/core/ops/vector_ops/vector_ops-wgsl/shader.wgsl b/tests/difftests/tests/lang/core/ops/vector_ops/vector_ops-wgsl/shader.wgsl new file mode 100644 index 0000000000..9393759023 --- /dev/null +++ b/tests/difftests/tests/lang/core/ops/vector_ops/vector_ops-wgsl/shader.wgsl @@ -0,0 +1,136 @@ +@group(0) @binding(0) +var input: array; + +@group(0) @binding(1) +var output: array; + +@compute @workgroup_size(32, 1, 1) +fn main_cs(@builtin(global_invocation_id) global_id: vec3) { + let tid = global_id.x; + + if (tid >= 32u || tid * 4u + 3u >= arrayLength(&input)) { + return; + } + + // Read 4 floats from input + let a = input[tid * 4u]; + let b = input[tid * 4u + 1u]; + let c = input[tid * 4u + 2u]; + let d = input[tid * 4u + 3u]; + + let base_offset = tid * 50u; + + if (base_offset + 49u >= arrayLength(&output)) { + return; + } + + // Vec2 operations + let v2a = vec2(a, b); + let v2b = vec2(c, d); + + output[base_offset + 0u] = round(dot(v2a, v2b) * 1000.0) / 1000.0; + output[base_offset + 1u] = round(length(v2a) * 1000.0) / 1000.0; + output[base_offset + 2u] = round(distance(v2a, v2b) * 1000.0) / 1000.0; + + let v2_add = v2a + v2b; + output[base_offset + 3u] = v2_add.x; + output[base_offset + 4u] = v2_add.y; + + let v2_mul = v2a * 2.0; + output[base_offset + 5u] = v2_mul.x; + output[base_offset + 6u] = v2_mul.y; + + // Vec3 operations + let v3a = vec3(a, b, c); + let v3b = vec3(b, c, d); + + output[base_offset + 7u] = round(dot(v3a, v3b) * 1000.0) / 1000.0; + output[base_offset + 8u] = round(length(v3a) * 1000.0) / 1000.0; + + let v3_cross = cross(v3a, v3b); + output[base_offset + 9u] = v3_cross.x; + output[base_offset + 10u] = v3_cross.y; + output[base_offset + 11u] = v3_cross.z; + + let v3_norm = normalize(v3a); + output[base_offset + 12u] = round(v3_norm.x * 1000.0) / 1000.0; + output[base_offset + 13u] = round(v3_norm.y * 1000.0) / 1000.0; + output[base_offset + 14u] = round(v3_norm.z * 1000.0) / 1000.0; + + // Vec4 operations + let v4a = vec4(a, b, c, d); + let v4b = vec4(d, c, b, a); + + output[base_offset + 15u] = round(dot(v4a, v4b) * 1000.0) / 1000.0; + output[base_offset + 16u] = round(length(v4a) * 1000.0) / 1000.0; + + let v4_sub = v4a - v4b; + output[base_offset + 17u] = v4_sub.x; + output[base_offset + 18u] = v4_sub.y; + output[base_offset + 19u] = v4_sub.z; + output[base_offset + 20u] = v4_sub.w; + + // Swizzling + output[base_offset + 21u] = v4a.x; + output[base_offset + 22u] = v4a.y; + output[base_offset + 23u] = v4a.z; + output[base_offset + 24u] = v4a.w; + + let v4_swizzle = v4a.wzyx; + output[base_offset + 25u] = v4_swizzle.x; + output[base_offset + 26u] = v4_swizzle.y; + output[base_offset + 27u] = v4_swizzle.z; + output[base_offset + 28u] = v4_swizzle.w; + + // Component-wise operations + let v4_min = min(v4a, v4b); + output[base_offset + 29u] = v4_min.x; + output[base_offset + 30u] = v4_min.y; + output[base_offset + 31u] = v4_min.z; + + // UVec operations + let ua = u32(abs(a)); + let ub = u32(abs(b)); + let uc = u32(abs(c)); + let ud = u32(abs(d)); + + let uv2 = vec2(ua, ub); + let uv3 = vec3(ua, ub, uc); + let uv4 = vec4(ua, ub, uc, ud); + + output[base_offset + 32u] = f32(uv2.x + uv2.y); + output[base_offset + 33u] = f32(uv3.x + uv3.y + uv3.z); + output[base_offset + 34u] = f32(uv4.x + uv4.y + uv4.z + uv4.w); + + // UVec min/max operations commented out to match Rust side + // See: https://github.com/Rust-GPU/rust-gpu/issues/314 + // let uv4_min = min(uv4, vec4(ud, uc, ub, ua)); + // let uv4_max = max(uv4, vec4(1u, 2u, 3u, 4u)); + output[base_offset + 35u] = f32(ua); // Should be: f32(uv4_min.x) + output[base_offset + 36u] = f32(ub); // Should be: f32(uv4_min.y) + output[base_offset + 37u] = f32(ua); // Should be: f32(uv4_max.x) + output[base_offset + 38u] = f32(ub); // Should be: f32(uv4_max.y) + output[base_offset + 39u] = f32(uc); // Should be: f32(uv4_max.z) + + // Vector constants + output[base_offset + 40u] = vec2(0.0, 0.0).x; + output[base_offset + 41u] = vec3(1.0, 1.0, 1.0).x; + output[base_offset + 42u] = vec4(1.0, 0.0, 0.0, 0.0).x; + output[base_offset + 43u] = vec4(0.0, 1.0, 0.0, 0.0).y; + + // Converting between sizes + let v2_from_v3 = vec2(v3a.xy); + output[base_offset + 44u] = v2_from_v3.x; + output[base_offset + 45u] = v2_from_v3.y; + + let v3_from_v4 = vec3(v4a.xyz); + output[base_offset + 46u] = v3_from_v4.x; + + let v3_extended = vec3(v2a, 1.0); + output[base_offset + 47u] = v3_extended.z; + + let v4_extended = vec4(v3a, 2.0); + output[base_offset + 48u] = v4_extended.w; + + output[base_offset + 49u] = f32(vec4(1u, 1u, 1u, 1u).x); +} diff --git a/tests/difftests/tests/lang/core/ops/vector_ops/vector_ops-wgsl/src/main.rs b/tests/difftests/tests/lang/core/ops/vector_ops/vector_ops-wgsl/src/main.rs new file mode 100644 index 0000000000..fa33c6170c --- /dev/null +++ b/tests/difftests/tests/lang/core/ops/vector_ops/vector_ops-wgsl/src/main.rs @@ -0,0 +1,54 @@ +use difftest::config::Config; +use difftest::scaffold::compute::{ + BufferConfig, BufferUsage, WgpuComputeTestMultiBuffer, WgslComputeShader, +}; + +fn main() { + let config = Config::from_path(std::env::args().nth(1).unwrap()).unwrap(); + + // Create input data with various float values + let input_data: Vec = (0..128) + .map(|i| match i % 16 { + 0 => 0.0, + 1 => 1.0, + 2 => -1.0, + 3 => 0.5, + 4 => -0.5, + 5 => 2.0, + 6 => -2.0, + 7 => 3.0, + 8 => std::f32::consts::PI, + 9 => std::f32::consts::E, + 10 => 0.1, + 11 => -0.1, + 12 => 4.0, + 13 => -4.0, + 14 => 0.25, + 15 => -0.25, + _ => unreachable!(), + }) + .collect(); + + let input_bytes = bytemuck::cast_slice(&input_data).to_vec(); + + let buffers = vec![ + BufferConfig { + size: 512, // 128 f32 values + usage: BufferUsage::StorageReadOnly, + initial_data: Some(input_bytes), + }, + BufferConfig { + size: 6400, // 1600 f32 values (32 threads * 50 outputs each) + usage: BufferUsage::Storage, + initial_data: None, + }, + ]; + + let test = WgpuComputeTestMultiBuffer::new( + WgslComputeShader::default(), + [1, 1, 1], // Single workgroup with 32 threads + buffers, + ); + + test.run_test(&config).unwrap(); +} diff --git a/tests/difftests/tests/lang/core/ops/vector_swizzle/vector_swizzle-rust/Cargo.toml b/tests/difftests/tests/lang/core/ops/vector_swizzle/vector_swizzle-rust/Cargo.toml new file mode 100644 index 0000000000..6cf5f5e2a3 --- /dev/null +++ b/tests/difftests/tests/lang/core/ops/vector_swizzle/vector_swizzle-rust/Cargo.toml @@ -0,0 +1,18 @@ +[package] +name = "vector_swizzle-rust" +edition.workspace = true + +[lints] +workspace = true + +[lib] +crate-type = ["dylib"] + +# GPU deps +[dependencies] +spirv-std.workspace = true +glam.workspace = true + +# CPU deps (for the test harness) +[target.'cfg(not(target_arch = "spirv"))'.dependencies] +difftest.workspace = true \ No newline at end of file diff --git a/tests/difftests/tests/lang/core/ops/vector_swizzle/vector_swizzle-rust/src/lib.rs b/tests/difftests/tests/lang/core/ops/vector_swizzle/vector_swizzle-rust/src/lib.rs new file mode 100644 index 0000000000..5412934294 --- /dev/null +++ b/tests/difftests/tests/lang/core/ops/vector_swizzle/vector_swizzle-rust/src/lib.rs @@ -0,0 +1,71 @@ +#![no_std] +#![cfg_attr(target_arch = "spirv", feature(asm_experimental_arch))] + +use glam::{Vec3Swizzles, Vec4, Vec4Swizzles}; +use spirv_std::spirv; + +#[spirv(compute(threads(64)))] +pub fn main_cs( + #[spirv(storage_buffer, descriptor_set = 0, binding = 0)] input: &[[f32; 4]], + #[spirv(storage_buffer, descriptor_set = 0, binding = 1)] output: &mut [[f32; 4]], + #[spirv(global_invocation_id)] global_id: spirv_std::glam::UVec3, +) { + let tid = global_id.x as usize; + + if tid < input.len() && tid < output.len() { + let data = input[tid]; + let vec4 = Vec4::from_array(data); + + // Test various swizzle operations + let result = match tid % 16 { + // Vec4 swizzles + 0 => vec4.xyzw(), // Identity + 1 => vec4.wzyx(), // Reverse + 2 => vec4.xxxx(), // Broadcast + 3 => vec4.xzyw(), // Swap middle + + // Vec3 from Vec4 + 4 => vec4.xyz().extend(0.0), + 5 => vec4.xyw().extend(0.0), + 6 => vec4.yzw().extend(0.0), + + // Vec2 from Vec4 + 7 => { + let v2 = vec4.xy(); + Vec4::new(v2.x, v2.y, 0.0, 0.0) + } + 8 => { + let v2 = vec4.zw(); + Vec4::new(v2.x, v2.y, 0.0, 0.0) + } + 9 => { + let v2 = vec4.xw(); + Vec4::new(v2.x, v2.y, 0.0, 0.0) + } + + // Complex swizzles + 10 => { + let v3 = vec4.zyx(); + v3.extend(vec4.w) + } + 11 => { + let v2 = vec4.yx(); + Vec4::new(v2.x, v2.y, vec4.z, vec4.w) + } + + // Component access + 12 => Vec4::new(vec4.x, vec4.x, vec4.y, vec4.y), + 13 => Vec4::new(vec4.z, vec4.z, vec4.w, vec4.w), + 14 => Vec4::new(vec4.w, vec4.z, vec4.y, vec4.x), + 15 => { + // Nested swizzles + let v3 = vec4.xyz(); + let v2 = v3.xy(); + Vec4::new(v2.y, v2.x, v3.z, vec4.w) + } + _ => Vec4::ZERO, + }; + + output[tid] = result.to_array(); + } +} diff --git a/tests/difftests/tests/lang/core/ops/vector_swizzle/vector_swizzle-rust/src/main.rs b/tests/difftests/tests/lang/core/ops/vector_swizzle/vector_swizzle-rust/src/main.rs new file mode 100644 index 0000000000..eb96343bcf --- /dev/null +++ b/tests/difftests/tests/lang/core/ops/vector_swizzle/vector_swizzle-rust/src/main.rs @@ -0,0 +1,15 @@ +use difftest::config::Config; +use difftest::scaffold::compute::{RustComputeShader, WgpuComputeTestMultiBuffer}; + +fn main() { + let config = Config::from_path(std::env::args().nth(1).unwrap()).unwrap(); + + let buffer_size = 1024; + let test = + WgpuComputeTestMultiBuffer::new_with_sizes(RustComputeShader::default(), [64, 1, 1], &[ + buffer_size, + buffer_size, + ]); + + test.run_test(&config).unwrap(); +} diff --git a/tests/difftests/tests/lang/core/ops/vector_swizzle/vector_swizzle-wgsl/Cargo.toml b/tests/difftests/tests/lang/core/ops/vector_swizzle/vector_swizzle-wgsl/Cargo.toml new file mode 100644 index 0000000000..9adbc06344 --- /dev/null +++ b/tests/difftests/tests/lang/core/ops/vector_swizzle/vector_swizzle-wgsl/Cargo.toml @@ -0,0 +1,9 @@ +[package] +name = "vector_swizzle-wgsl" +edition.workspace = true + +[lints] +workspace = true + +[dependencies] +difftest.workspace = true \ No newline at end of file diff --git a/tests/difftests/tests/lang/core/ops/vector_swizzle/vector_swizzle-wgsl/shader.wgsl b/tests/difftests/tests/lang/core/ops/vector_swizzle/vector_swizzle-wgsl/shader.wgsl new file mode 100644 index 0000000000..e811918b23 --- /dev/null +++ b/tests/difftests/tests/lang/core/ops/vector_swizzle/vector_swizzle-wgsl/shader.wgsl @@ -0,0 +1,64 @@ +@group(0) @binding(0) var input: array>; +@group(0) @binding(1) var output: array>; + +@compute @workgroup_size(64) +fn main(@builtin(global_invocation_id) global_id: vec3) { + let tid = global_id.x; + + if (tid < arrayLength(&input) && tid < arrayLength(&output)) { + let vec4_val = input[tid]; + + // Test various swizzle operations + var result: vec4; + switch (tid % 16u) { + // Vec4 swizzles + case 0u: { result = vec4_val.xyzw; } // Identity + case 1u: { result = vec4_val.wzyx; } // Reverse + case 2u: { result = vec4_val.xxxx; } // Broadcast + case 3u: { result = vec4_val.xzyw; } // Swap middle + + // Vec3 from Vec4 + case 4u: { result = vec4(vec4_val.xyz, 0.0); } + case 5u: { result = vec4(vec4_val.xyw, 0.0); } + case 6u: { result = vec4(vec4_val.yzw, 0.0); } + + // Vec2 from Vec4 + case 7u: { + let v2 = vec4_val.xy; + result = vec4(v2.x, v2.y, 0.0, 0.0); + } + case 8u: { + let v2 = vec4_val.zw; + result = vec4(v2.x, v2.y, 0.0, 0.0); + } + case 9u: { + let v2 = vec4_val.xw; + result = vec4(v2.x, v2.y, 0.0, 0.0); + } + + // Complex swizzles + case 10u: { + let v3 = vec4_val.zyx; + result = vec4(v3, vec4_val.w); + } + case 11u: { + let v2 = vec4_val.yx; + result = vec4(v2.x, v2.y, vec4_val.z, vec4_val.w); + } + + // Component access + case 12u: { result = vec4(vec4_val.x, vec4_val.x, vec4_val.y, vec4_val.y); } + case 13u: { result = vec4(vec4_val.z, vec4_val.z, vec4_val.w, vec4_val.w); } + case 14u: { result = vec4(vec4_val.w, vec4_val.z, vec4_val.y, vec4_val.x); } + case 15u: { + // Nested swizzles + let v3 = vec4_val.xyz; + let v2 = v3.xy; + result = vec4(v2.y, v2.x, v3.z, vec4_val.w); + } + default: { result = vec4(0.0); } + } + + output[tid] = result; + } +} \ No newline at end of file diff --git a/tests/difftests/tests/lang/core/ops/vector_swizzle/vector_swizzle-wgsl/src/main.rs b/tests/difftests/tests/lang/core/ops/vector_swizzle/vector_swizzle-wgsl/src/main.rs new file mode 100644 index 0000000000..f21cdecfbf --- /dev/null +++ b/tests/difftests/tests/lang/core/ops/vector_swizzle/vector_swizzle-wgsl/src/main.rs @@ -0,0 +1,15 @@ +use difftest::config::Config; +use difftest::scaffold::compute::{WgpuComputeTestMultiBuffer, WgslComputeShader}; + +fn main() { + let config = Config::from_path(std::env::args().nth(1).unwrap()).unwrap(); + + let buffer_size = 1024; + let test = + WgpuComputeTestMultiBuffer::new_with_sizes(WgslComputeShader::default(), [64, 1, 1], &[ + buffer_size, + buffer_size, + ]); + + test.run_test(&config).unwrap(); +} diff --git a/tests/difftests/tests/storage_class/array_access/array_access-rust/Cargo.toml b/tests/difftests/tests/storage_class/array_access/array_access-rust/Cargo.toml new file mode 100644 index 0000000000..fb811e545d --- /dev/null +++ b/tests/difftests/tests/storage_class/array_access/array_access-rust/Cargo.toml @@ -0,0 +1,20 @@ +[package] +name = "array_access-rust" +edition.workspace = true + +[lints] +workspace = true + +[lib] +crate-type = ["dylib"] + +# Common deps +[dependencies] + +# GPU deps +spirv-std.workspace = true + +# CPU deps +[target.'cfg(not(target_arch = "spirv"))'.dependencies] +difftest.workspace = true +bytemuck.workspace = true \ No newline at end of file diff --git a/tests/difftests/tests/storage_class/array_access/array_access-rust/src/lib.rs b/tests/difftests/tests/storage_class/array_access/array_access-rust/src/lib.rs new file mode 100644 index 0000000000..aba8cd9294 --- /dev/null +++ b/tests/difftests/tests/storage_class/array_access/array_access-rust/src/lib.rs @@ -0,0 +1,48 @@ +#![no_std] + +use spirv_std::spirv; + +#[spirv(compute(threads(64)))] +pub fn main_cs( + #[spirv(storage_buffer, descriptor_set = 0, binding = 0)] input: &[u32], + #[spirv(storage_buffer, descriptor_set = 0, binding = 1)] output: &mut [u32], + #[spirv(global_invocation_id)] global_id: spirv_std::glam::UVec3, +) { + let tid = global_id.x as usize; + + // Test various array access patterns + + // 1. Direct indexing + if tid < input.len() { + output[tid] = input[tid] * 2; + } + + // 2. Strided access (every 4th element) + let stride_idx = tid * 4; + if stride_idx < input.len() && tid + 64 < output.len() { + output[tid + 64] = input[stride_idx]; + } + + // 3. Reverse indexing + if tid < input.len() && tid + 128 < output.len() { + let reverse_idx = input.len() - 1 - tid; + output[tid + 128] = input[reverse_idx]; + } + + // 4. Gather operation (indirect indexing) + if tid < 16 && tid + 192 < output.len() { + // Use first 16 values as indices + let index = input[tid] as usize; + if index < input.len() { + output[tid + 192] = input[index]; + } + } + + // 5. Write pattern to test bounds + if tid == 0 { + // Write sentinel values at specific positions + output[208] = 0xDEADBEEF; + output[209] = 0xCAFEBABE; + output[210] = input.len() as u32; + } +} diff --git a/tests/difftests/tests/storage_class/array_access/array_access-rust/src/main.rs b/tests/difftests/tests/storage_class/array_access/array_access-rust/src/main.rs new file mode 100644 index 0000000000..9248d1802e --- /dev/null +++ b/tests/difftests/tests/storage_class/array_access/array_access-rust/src/main.rs @@ -0,0 +1,46 @@ +#[cfg(not(target_arch = "spirv"))] +fn main() { + use difftest::config::Config; + use difftest::scaffold::compute::{ + BufferConfig, BufferUsage, RustComputeShader, WgpuComputeTestMultiBuffer, + }; + + let config = Config::from_path(std::env::args().nth(1).unwrap()).unwrap(); + + // Create input data with specific patterns + let mut input_data = vec![0u32; 256]; + for i in 0..256 { + input_data[i] = i as u32; + } + // Set some specific values for indirect indexing test + input_data[0] = 5; + input_data[1] = 10; + input_data[2] = 15; + input_data[3] = 20; + + let input_bytes = bytemuck::cast_slice(&input_data).to_vec(); + + let buffers = vec![ + BufferConfig { + size: 1024, // 256 u32 values + usage: BufferUsage::StorageReadOnly, + initial_data: Some(input_bytes), + }, + BufferConfig { + size: 1024, // 256 u32 values for output + usage: BufferUsage::Storage, + initial_data: None, + }, + ]; + + let test = WgpuComputeTestMultiBuffer::new( + RustComputeShader::default(), + [1, 1, 1], // Single workgroup with 64 threads + buffers, + ); + + test.run_test(&config).unwrap(); +} + +#[cfg(target_arch = "spirv")] +fn main() {} diff --git a/tests/difftests/tests/storage_class/array_access/array_access-wgsl/Cargo.toml b/tests/difftests/tests/storage_class/array_access/array_access-wgsl/Cargo.toml new file mode 100644 index 0000000000..abba1fcaad --- /dev/null +++ b/tests/difftests/tests/storage_class/array_access/array_access-wgsl/Cargo.toml @@ -0,0 +1,10 @@ +[package] +name = "array_access-wgsl" +edition.workspace = true + +[lints] +workspace = true + +[dependencies] +difftest.workspace = true +bytemuck.workspace = true \ No newline at end of file diff --git a/tests/difftests/tests/storage_class/array_access/array_access-wgsl/shader.wgsl b/tests/difftests/tests/storage_class/array_access/array_access-wgsl/shader.wgsl new file mode 100644 index 0000000000..dafec91157 --- /dev/null +++ b/tests/difftests/tests/storage_class/array_access/array_access-wgsl/shader.wgsl @@ -0,0 +1,47 @@ +@group(0) @binding(0) +var input: array; + +@group(0) @binding(1) +var output: array; + +@compute @workgroup_size(64, 1, 1) +fn main_cs(@builtin(global_invocation_id) global_id: vec3) { + let tid = global_id.x; + let input_len = arrayLength(&input); + + // Test various array access patterns + + // 1. Direct indexing + if (tid < input_len) { + output[tid] = input[tid] * 2u; + } + + // 2. Strided access (every 4th element) + let stride_idx = tid * 4u; + if (stride_idx < input_len && tid + 64u < arrayLength(&output)) { + output[tid + 64u] = input[stride_idx]; + } + + // 3. Reverse indexing + if (tid < input_len && tid + 128u < arrayLength(&output)) { + let reverse_idx = input_len - 1u - tid; + output[tid + 128u] = input[reverse_idx]; + } + + // 4. Gather operation (indirect indexing) + if (tid < 16u && tid + 192u < arrayLength(&output)) { + // Use first 16 values as indices + let index = input[tid]; + if (index < input_len) { + output[tid + 192u] = input[index]; + } + } + + // 5. Write pattern to test bounds + if (tid == 0u) { + // Write sentinel values at specific positions + output[208] = 0xDEADBEEFu; + output[209] = 0xCAFEBABEu; + output[210] = input_len; + } +} \ No newline at end of file diff --git a/tests/difftests/tests/storage_class/array_access/array_access-wgsl/src/main.rs b/tests/difftests/tests/storage_class/array_access/array_access-wgsl/src/main.rs new file mode 100644 index 0000000000..3abc311375 --- /dev/null +++ b/tests/difftests/tests/storage_class/array_access/array_access-wgsl/src/main.rs @@ -0,0 +1,42 @@ +use difftest::config::Config; +use difftest::scaffold::compute::{ + BufferConfig, BufferUsage, WgpuComputeTestMultiBuffer, WgslComputeShader, +}; + +fn main() { + let config = Config::from_path(std::env::args().nth(1).unwrap()).unwrap(); + + // Create input data with specific patterns + let mut input_data = vec![0u32; 256]; + for i in 0..256 { + input_data[i] = i as u32; + } + // Set some specific values for indirect indexing test + input_data[0] = 5; + input_data[1] = 10; + input_data[2] = 15; + input_data[3] = 20; + + let input_bytes = bytemuck::cast_slice(&input_data).to_vec(); + + let buffers = vec![ + BufferConfig { + size: 1024, // 256 u32 values + usage: BufferUsage::StorageReadOnly, + initial_data: Some(input_bytes), + }, + BufferConfig { + size: 1024, // 256 u32 values for output + usage: BufferUsage::Storage, + initial_data: None, + }, + ]; + + let test = WgpuComputeTestMultiBuffer::new( + WgslComputeShader::default(), + [1, 1, 1], // Single workgroup with 64 threads + buffers, + ); + + test.run_test(&config).unwrap(); +} From 879d93f80a6d2ef265640ac8add17f8e2df5acb0 Mon Sep 17 00:00:00 2001 From: Christian Legnitto Date: Fri, 4 Jul 2025 20:40:08 +0200 Subject: [PATCH 02/24] Round floats to prevent deltas across platforms --- tests/difftests/lib/Cargo.toml | 2 +- tests/difftests/lib/src/lib.rs | 15 +++++ .../ops/math_ops/math_ops-rust/Cargo.toml | 6 +- .../ops/math_ops/math_ops-rust/src/lib.rs | 33 ++++----- .../ops/math_ops/math_ops-wgsl/shader.wgsl | 37 +++++----- .../ops/matrix_ops/matrix_ops-rust/Cargo.toml | 6 +- .../ops/matrix_ops/matrix_ops-rust/src/lib.rs | 63 ++++++++--------- .../matrix_ops/matrix_ops-wgsl/shader.wgsl | 67 ++++++++++--------- .../ops/vector_ops/vector_ops-rust/Cargo.toml | 7 +- .../ops/vector_ops/vector_ops-rust/src/lib.rs | 27 ++++---- .../vector_ops/vector_ops-wgsl/shader.wgsl | 31 +++++---- 11 files changed, 160 insertions(+), 134 deletions(-) diff --git a/tests/difftests/lib/Cargo.toml b/tests/difftests/lib/Cargo.toml index afd3e21a49..fb6cd75f14 100644 --- a/tests/difftests/lib/Cargo.toml +++ b/tests/difftests/lib/Cargo.toml @@ -15,7 +15,7 @@ use-compiled-tools = [ "spirv-builder/use-compiled-tools" ] -[dependencies] +[target.'cfg(not(target_arch = "spirv"))'.dependencies] spirv-builder.workspace = true serde = { version = "1.0", features = ["derive"] } serde_json = "1.0" diff --git a/tests/difftests/lib/src/lib.rs b/tests/difftests/lib/src/lib.rs index bb840e7954..1cab2cc3b4 100644 --- a/tests/difftests/lib/src/lib.rs +++ b/tests/difftests/lib/src/lib.rs @@ -1,6 +1,21 @@ +#![cfg_attr(target_arch = "spirv", no_std)] + +#[cfg(not(target_arch = "spirv"))] pub mod config; +#[cfg(not(target_arch = "spirv"))] pub mod scaffold; +/// Macro to round a f32 value to 6 decimal places for cross-platform consistency +/// in floating-point operations. This helps ensure difftest results are consistent +/// across different platforms (Linux, Mac, Windows) which may have slight differences +/// in floating-point implementations. +#[macro_export] +macro_rules! round6 { + ($v:expr) => { + (($v) * 1_000_000.0).round() / 1_000_000.0 + }; +} + #[cfg(test)] mod tests { use super::config::Config; diff --git a/tests/difftests/tests/lang/core/ops/math_ops/math_ops-rust/Cargo.toml b/tests/difftests/tests/lang/core/ops/math_ops/math_ops-rust/Cargo.toml index 6b46627ae3..86d05c7de4 100644 --- a/tests/difftests/tests/lang/core/ops/math_ops/math_ops-rust/Cargo.toml +++ b/tests/difftests/tests/lang/core/ops/math_ops/math_ops-rust/Cargo.toml @@ -10,11 +10,9 @@ crate-type = ["dylib"] # Common deps [dependencies] - -# GPU deps spirv-std.workspace = true +difftest.workspace = true # CPU deps [target.'cfg(not(target_arch = "spirv"))'.dependencies] -difftest.workspace = true -bytemuck.workspace = true \ No newline at end of file +bytemuck.workspace = true diff --git a/tests/difftests/tests/lang/core/ops/math_ops/math_ops-rust/src/lib.rs b/tests/difftests/tests/lang/core/ops/math_ops/math_ops-rust/src/lib.rs index 252f8b7b33..0af50c2eec 100644 --- a/tests/difftests/tests/lang/core/ops/math_ops/math_ops-rust/src/lib.rs +++ b/tests/difftests/tests/lang/core/ops/math_ops/math_ops-rust/src/lib.rs @@ -1,5 +1,6 @@ #![no_std] +use difftest::round6; #[allow(unused_imports)] use spirv_std::num_traits::Float; use spirv_std::spirv; @@ -24,28 +25,28 @@ pub fn main_cs( } // Basic arithmetic - output[base_offset + 0] = x + 1.5; - output[base_offset + 1] = x - 0.5; - output[base_offset + 2] = x * 2.0; - output[base_offset + 3] = x / 2.0; - output[base_offset + 4] = x % 3.0; + output[base_offset + 0] = round6!(x + 1.5); + output[base_offset + 1] = round6!(x - 0.5); + output[base_offset + 2] = round6!(x * 2.0); + output[base_offset + 3] = round6!(x / 2.0); + output[base_offset + 4] = round6!(x % 3.0); // Trigonometric functions (simplified for consistent results) - output[base_offset + 5] = x.sin(); - output[base_offset + 6] = x.cos(); - output[base_offset + 7] = x.tan().clamp(-10.0, 10.0); + output[base_offset + 5] = round6!(x.sin()); + output[base_offset + 6] = round6!(x.cos()); + output[base_offset + 7] = round6!(x.tan().clamp(-10.0, 10.0)); output[base_offset + 8] = 0.0; output[base_offset + 9] = 0.0; - output[base_offset + 10] = x.atan(); + output[base_offset + 10] = round6!(x.atan()); // Exponential and logarithmic (simplified) - output[base_offset + 11] = x.exp().min(1e6); - output[base_offset + 12] = if x > 0.0 { x.ln() } else { -10.0 }; - output[base_offset + 13] = x.abs().sqrt(); - output[base_offset + 14] = x.abs().powf(2.0); - output[base_offset + 15] = if x > 0.0 { x.log2() } else { -10.0 }; - output[base_offset + 16] = x.exp2().min(1e6); - output[base_offset + 17] = x.floor(); + output[base_offset + 11] = round6!(x.exp().min(1e6)); + output[base_offset + 12] = round6!(if x > 0.0 { x.ln() } else { -10.0 }); + output[base_offset + 13] = round6!(x.abs().sqrt()); + output[base_offset + 14] = round6!(x.abs() * x.abs()); // Use multiplication instead of powf + output[base_offset + 15] = round6!(if x > 0.0 { x.log2() } else { -10.0 }); + output[base_offset + 16] = round6!(x.exp2().min(1e6)); + output[base_offset + 17] = x.floor(); // floor/ceil/round are exact output[base_offset + 18] = x.ceil(); output[base_offset + 19] = x.round(); diff --git a/tests/difftests/tests/lang/core/ops/math_ops/math_ops-wgsl/shader.wgsl b/tests/difftests/tests/lang/core/ops/math_ops/math_ops-wgsl/shader.wgsl index c16fa21137..2d7c2a4b98 100644 --- a/tests/difftests/tests/lang/core/ops/math_ops/math_ops-wgsl/shader.wgsl +++ b/tests/difftests/tests/lang/core/ops/math_ops/math_ops-wgsl/shader.wgsl @@ -4,6 +4,11 @@ var input: array; @group(0) @binding(1) var output: array; +// Helper function to round to 6 decimal places for cross-platform consistency +fn round6(v: f32) -> f32 { + return round(v * 1000000.0) / 1000000.0; +} + @compute @workgroup_size(32, 1, 1) fn main_cs(@builtin(global_invocation_id) global_id: vec3) { let tid = global_id.x; @@ -20,28 +25,28 @@ fn main_cs(@builtin(global_invocation_id) global_id: vec3) { } // Basic arithmetic - output[base_offset + 0u] = x + 1.5; - output[base_offset + 1u] = x - 0.5; - output[base_offset + 2u] = x * 2.0; - output[base_offset + 3u] = x / 2.0; - output[base_offset + 4u] = x % 3.0; + output[base_offset + 0u] = round6(x + 1.5); + output[base_offset + 1u] = round6(x - 0.5); + output[base_offset + 2u] = round6(x * 2.0); + output[base_offset + 3u] = round6(x / 2.0); + output[base_offset + 4u] = round6(x % 3.0); // Trigonometric functions (simplified for consistent results) - output[base_offset + 5u] = sin(x); - output[base_offset + 6u] = cos(x); - output[base_offset + 7u] = clamp(tan(x), -10.0, 10.0); + output[base_offset + 5u] = round6(sin(x)); + output[base_offset + 6u] = round6(cos(x)); + output[base_offset + 7u] = round6(clamp(tan(x), -10.0, 10.0)); output[base_offset + 8u] = 0.0; output[base_offset + 9u] = 0.0; - output[base_offset + 10u] = atan(x); + output[base_offset + 10u] = round6(atan(x)); // Exponential and logarithmic (simplified) - output[base_offset + 11u] = min(exp(x), 1e6); - output[base_offset + 12u] = select(-10.0, log(x), x > 0.0); - output[base_offset + 13u] = sqrt(abs(x)); - output[base_offset + 14u] = pow(abs(x), 2.0); - output[base_offset + 15u] = select(-10.0, log2(x), x > 0.0); - output[base_offset + 16u] = min(exp2(x), 1e6); - output[base_offset + 17u] = floor(x); + output[base_offset + 11u] = round6(min(exp(x), 1e6)); + output[base_offset + 12u] = round6(select(-10.0, log(x), x > 0.0)); + output[base_offset + 13u] = round6(sqrt(abs(x))); + output[base_offset + 14u] = round6(abs(x) * abs(x)); // Use multiplication instead of pow + output[base_offset + 15u] = round6(select(-10.0, log2(x), x > 0.0)); + output[base_offset + 16u] = round6(min(exp2(x), 1e6)); + output[base_offset + 17u] = floor(x); // floor/ceil/round are exact output[base_offset + 18u] = ceil(x); output[base_offset + 19u] = round(x); diff --git a/tests/difftests/tests/lang/core/ops/matrix_ops/matrix_ops-rust/Cargo.toml b/tests/difftests/tests/lang/core/ops/matrix_ops/matrix_ops-rust/Cargo.toml index 64ab199838..b4d32b8fcb 100644 --- a/tests/difftests/tests/lang/core/ops/matrix_ops/matrix_ops-rust/Cargo.toml +++ b/tests/difftests/tests/lang/core/ops/matrix_ops/matrix_ops-rust/Cargo.toml @@ -10,11 +10,9 @@ crate-type = ["dylib"] # Common deps [dependencies] - -# GPU deps spirv-std.workspace = true +difftest.workspace = true # CPU deps [target.'cfg(not(target_arch = "spirv"))'.dependencies] -difftest.workspace = true -bytemuck.workspace = true \ No newline at end of file +bytemuck.workspace = true diff --git a/tests/difftests/tests/lang/core/ops/matrix_ops/matrix_ops-rust/src/lib.rs b/tests/difftests/tests/lang/core/ops/matrix_ops/matrix_ops-rust/src/lib.rs index 76083c6de1..3450c7b4d7 100644 --- a/tests/difftests/tests/lang/core/ops/matrix_ops/matrix_ops-rust/src/lib.rs +++ b/tests/difftests/tests/lang/core/ops/matrix_ops/matrix_ops-rust/src/lib.rs @@ -1,5 +1,6 @@ #![no_std] +use difftest::round6; use spirv_std::glam::{Mat2, Mat3, Mat4, UVec3, Vec2, Vec3, Vec4}; #[allow(unused_imports)] use spirv_std::num_traits::Float; @@ -35,10 +36,10 @@ pub fn main_cs( // Mat2 multiplication let m2_mul = m2a * m2b; - output[base_offset + 0] = m2_mul.col(0).x; - output[base_offset + 1] = m2_mul.col(0).y; - output[base_offset + 2] = m2_mul.col(1).x; - output[base_offset + 3] = m2_mul.col(1).y; + output[base_offset + 0] = round6!(m2_mul.col(0).x); + output[base_offset + 1] = round6!(m2_mul.col(0).y); + output[base_offset + 2] = round6!(m2_mul.col(1).x); + output[base_offset + 3] = round6!(m2_mul.col(1).y); // Mat2 transpose let m2_transpose = m2a.transpose(); @@ -48,13 +49,13 @@ pub fn main_cs( output[base_offset + 7] = m2_transpose.col(1).y; // Mat2 determinant (with rounding for consistency) - output[base_offset + 8] = (m2a.determinant() * 1000.0).round() / 1000.0; + output[base_offset + 8] = round6!(m2a.determinant()); // Mat2 * Vec2 let v2 = Vec2::new(1.0, 2.0); let m2_v2 = m2a * v2; - output[base_offset + 9] = m2_v2.x; - output[base_offset + 10] = m2_v2.y; + output[base_offset + 9] = round6!(m2_v2.x); + output[base_offset + 10] = round6!(m2_v2.y); // Mat3 operations let m3a = Mat3::from_cols(Vec3::new(a, b, c), Vec3::new(b, c, d), Vec3::new(c, d, a)); @@ -62,15 +63,15 @@ pub fn main_cs( // Mat3 multiplication let m3_mul = m3a * m3b; - output[base_offset + 11] = m3_mul.col(0).x; - output[base_offset + 12] = m3_mul.col(0).y; - output[base_offset + 13] = m3_mul.col(0).z; - output[base_offset + 14] = m3_mul.col(1).x; - output[base_offset + 15] = m3_mul.col(1).y; - output[base_offset + 16] = m3_mul.col(1).z; - output[base_offset + 17] = m3_mul.col(2).x; - output[base_offset + 18] = m3_mul.col(2).y; - output[base_offset + 19] = m3_mul.col(2).z; + output[base_offset + 11] = round6!(m3_mul.col(0).x); + output[base_offset + 12] = round6!(m3_mul.col(0).y); + output[base_offset + 13] = round6!(m3_mul.col(0).z); + output[base_offset + 14] = round6!(m3_mul.col(1).x); + output[base_offset + 15] = round6!(m3_mul.col(1).y); + output[base_offset + 16] = round6!(m3_mul.col(1).z); + output[base_offset + 17] = round6!(m3_mul.col(2).x); + output[base_offset + 18] = round6!(m3_mul.col(2).y); + output[base_offset + 19] = round6!(m3_mul.col(2).z); // Mat3 transpose - store just diagonal elements let m3_transpose = m3a.transpose(); @@ -79,14 +80,14 @@ pub fn main_cs( output[base_offset + 22] = m3_transpose.col(2).z; // Mat3 determinant (with rounding for consistency) - output[base_offset + 23] = (m3a.determinant() * 1000.0).round() / 1000.0; + output[base_offset + 23] = round6!(m3a.determinant()); // Mat3 * Vec3 (with rounding for consistency) let v3 = Vec3::new(1.0, 2.0, 3.0); let m3_v3 = m3a * v3; - output[base_offset + 24] = (m3_v3.x * 10000.0).round() / 10000.0; - output[base_offset + 25] = (m3_v3.y * 10000.0).round() / 10000.0; - output[base_offset + 26] = (m3_v3.z * 10000.0).round() / 10000.0; + output[base_offset + 24] = round6!(m3_v3.x); + output[base_offset + 25] = round6!(m3_v3.y); + output[base_offset + 26] = round6!(m3_v3.z); // Mat4 operations let m4a = Mat4::from_cols( @@ -104,10 +105,10 @@ pub fn main_cs( // Mat4 multiplication (just store diagonal for brevity) let m4_mul = m4a * m4b; - output[base_offset + 27] = m4_mul.col(0).x; - output[base_offset + 28] = m4_mul.col(1).y; - output[base_offset + 29] = m4_mul.col(2).z; - output[base_offset + 30] = m4_mul.col(3).w; + output[base_offset + 27] = round6!(m4_mul.col(0).x); + output[base_offset + 28] = round6!(m4_mul.col(1).y); + output[base_offset + 29] = round6!(m4_mul.col(2).z); + output[base_offset + 30] = round6!(m4_mul.col(3).w); // Mat4 transpose (just store diagonal) let m4_transpose = m4a.transpose(); @@ -117,15 +118,15 @@ pub fn main_cs( output[base_offset + 34] = m4_transpose.col(3).w; // Mat4 determinant (with rounding for consistency) - output[base_offset + 35] = (m4a.determinant() * 1000.0).round() / 1000.0; + output[base_offset + 35] = round6!(m4a.determinant()); // Mat4 * Vec4 (with rounding for consistency) let v4 = Vec4::new(1.0, 2.0, 3.0, 4.0); let m4_v4 = m4a * v4; - output[base_offset + 36] = (m4_v4.x * 10000.0).round() / 10000.0; - output[base_offset + 37] = (m4_v4.y * 10000.0).round() / 10000.0; - output[base_offset + 38] = (m4_v4.z * 10000.0).round() / 10000.0; - output[base_offset + 39] = (m4_v4.w * 10000.0).round() / 10000.0; + output[base_offset + 36] = round6!(m4_v4.x); + output[base_offset + 37] = round6!(m4_v4.y); + output[base_offset + 38] = round6!(m4_v4.z); + output[base_offset + 39] = round6!(m4_v4.w); // Identity matrices output[base_offset + 40] = Mat2::IDENTITY.col(0).x; @@ -135,8 +136,8 @@ pub fn main_cs( // Matrix inverse if m2a.determinant().abs() > 0.0001 { let m2_inv = m2a.inverse(); - output[base_offset + 43] = m2_inv.col(0).x; - output[base_offset + 44] = m2_inv.col(1).y; + output[base_offset + 43] = round6!(m2_inv.col(0).x); + output[base_offset + 44] = round6!(m2_inv.col(1).y); } else { output[base_offset + 43] = 0.0; output[base_offset + 44] = 0.0; diff --git a/tests/difftests/tests/lang/core/ops/matrix_ops/matrix_ops-wgsl/shader.wgsl b/tests/difftests/tests/lang/core/ops/matrix_ops/matrix_ops-wgsl/shader.wgsl index 7dcdadaa45..dc7bf997ee 100644 --- a/tests/difftests/tests/lang/core/ops/matrix_ops/matrix_ops-wgsl/shader.wgsl +++ b/tests/difftests/tests/lang/core/ops/matrix_ops/matrix_ops-wgsl/shader.wgsl @@ -4,6 +4,11 @@ var input: array; @group(0) @binding(1) var output: array; +// Helper function to round to 6 decimal places for cross-platform consistency +fn round6(v: f32) -> f32 { + return round(v * 1000000.0) / 1000000.0; +} + @compute @workgroup_size(32, 1, 1) fn main_cs(@builtin(global_invocation_id) global_id: vec3) { let tid = global_id.x; @@ -36,10 +41,10 @@ fn main_cs(@builtin(global_invocation_id) global_id: vec3) { // Mat2 multiplication let m2_mul = m2a * m2b; - output[base_offset + 0u] = m2_mul[0].x; - output[base_offset + 1u] = m2_mul[0].y; - output[base_offset + 2u] = m2_mul[1].x; - output[base_offset + 3u] = m2_mul[1].y; + output[base_offset + 0u] = round6(m2_mul[0].x); + output[base_offset + 1u] = round6(m2_mul[0].y); + output[base_offset + 2u] = round6(m2_mul[1].x); + output[base_offset + 3u] = round6(m2_mul[1].y); // Mat2 transpose let m2_transpose = transpose(m2a); @@ -49,13 +54,13 @@ fn main_cs(@builtin(global_invocation_id) global_id: vec3) { output[base_offset + 7u] = m2_transpose[1].y; // Mat2 determinant (with rounding for consistency) - output[base_offset + 8u] = round(determinant(m2a) * 1000.0) / 1000.0; + output[base_offset + 8u] = round6(determinant(m2a)); // Mat2 * Vec2 let v2 = vec2(1.0, 2.0); let m2_v2 = m2a * v2; - output[base_offset + 9u] = m2_v2.x; - output[base_offset + 10u] = m2_v2.y; + output[base_offset + 9u] = round6(m2_v2.x); + output[base_offset + 10u] = round6(m2_v2.y); // Mat3 operations let m3a = mat3x3( @@ -71,15 +76,15 @@ fn main_cs(@builtin(global_invocation_id) global_id: vec3) { // Mat3 multiplication let m3_mul = m3a * m3b; - output[base_offset + 11u] = m3_mul[0].x; - output[base_offset + 12u] = m3_mul[0].y; - output[base_offset + 13u] = m3_mul[0].z; - output[base_offset + 14u] = m3_mul[1].x; - output[base_offset + 15u] = m3_mul[1].y; - output[base_offset + 16u] = m3_mul[1].z; - output[base_offset + 17u] = m3_mul[2].x; - output[base_offset + 18u] = m3_mul[2].y; - output[base_offset + 19u] = m3_mul[2].z; + output[base_offset + 11u] = round6(m3_mul[0].x); + output[base_offset + 12u] = round6(m3_mul[0].y); + output[base_offset + 13u] = round6(m3_mul[0].z); + output[base_offset + 14u] = round6(m3_mul[1].x); + output[base_offset + 15u] = round6(m3_mul[1].y); + output[base_offset + 16u] = round6(m3_mul[1].z); + output[base_offset + 17u] = round6(m3_mul[2].x); + output[base_offset + 18u] = round6(m3_mul[2].y); + output[base_offset + 19u] = round6(m3_mul[2].z); // Mat3 transpose let m3_transpose = transpose(m3a); @@ -88,14 +93,14 @@ fn main_cs(@builtin(global_invocation_id) global_id: vec3) { output[base_offset + 22u] = m3_transpose[2].z; // Mat3 determinant (with rounding for consistency) - output[base_offset + 23u] = round(determinant(m3a) * 1000.0) / 1000.0; + output[base_offset + 23u] = round6(determinant(m3a)); // Mat3 * Vec3 (with rounding for consistency) let v3 = vec3(1.0, 2.0, 3.0); let m3_v3 = m3a * v3; - output[base_offset + 24u] = round(m3_v3.x * 10000.0) / 10000.0; - output[base_offset + 25u] = round(m3_v3.y * 10000.0) / 10000.0; - output[base_offset + 26u] = round(m3_v3.z * 10000.0) / 10000.0; + output[base_offset + 24u] = round6(m3_v3.x); + output[base_offset + 25u] = round6(m3_v3.y); + output[base_offset + 26u] = round6(m3_v3.z); // Mat4 operations let m4a = mat4x4( @@ -113,10 +118,10 @@ fn main_cs(@builtin(global_invocation_id) global_id: vec3) { // Mat4 multiplication (just store diagonal for brevity) let m4_mul = m4a * m4b; - output[base_offset + 27u] = m4_mul[0].x; - output[base_offset + 28u] = m4_mul[1].y; - output[base_offset + 29u] = m4_mul[2].z; - output[base_offset + 30u] = m4_mul[3].w; + output[base_offset + 27u] = round6(m4_mul[0].x); + output[base_offset + 28u] = round6(m4_mul[1].y); + output[base_offset + 29u] = round6(m4_mul[2].z); + output[base_offset + 30u] = round6(m4_mul[3].w); // Mat4 transpose (just store diagonal) let m4_transpose = transpose(m4a); @@ -126,15 +131,15 @@ fn main_cs(@builtin(global_invocation_id) global_id: vec3) { output[base_offset + 34u] = m4_transpose[3].w; // Mat4 determinant (with rounding for consistency) - output[base_offset + 35u] = round(determinant(m4a) * 1000.0) / 1000.0; + output[base_offset + 35u] = round6(determinant(m4a)); // Mat4 * Vec4 (with rounding for consistency) let v4 = vec4(1.0, 2.0, 3.0, 4.0); let m4_v4 = m4a * v4; - output[base_offset + 36u] = round(m4_v4.x * 10000.0) / 10000.0; - output[base_offset + 37u] = round(m4_v4.y * 10000.0) / 10000.0; - output[base_offset + 38u] = round(m4_v4.z * 10000.0) / 10000.0; - output[base_offset + 39u] = round(m4_v4.w * 10000.0) / 10000.0; + output[base_offset + 36u] = round6(m4_v4.x); + output[base_offset + 37u] = round6(m4_v4.y); + output[base_offset + 38u] = round6(m4_v4.z); + output[base_offset + 39u] = round6(m4_v4.w); // Identity matrices output[base_offset + 40u] = mat2x2(vec2(1.0, 0.0), vec2(0.0, 1.0))[0].x; @@ -150,8 +155,8 @@ fn main_cs(@builtin(global_invocation_id) global_id: vec3) { vec2(m2a[1].y * inv_det, -m2a[0].y * inv_det), vec2(-m2a[1].x * inv_det, m2a[0].x * inv_det) ); - output[base_offset + 43u] = m2_inv[0].x; - output[base_offset + 44u] = m2_inv[1].y; + output[base_offset + 43u] = round6(m2_inv[0].x); + output[base_offset + 44u] = round6(m2_inv[1].y); } else { output[base_offset + 43u] = 0.0; output[base_offset + 44u] = 0.0; diff --git a/tests/difftests/tests/lang/core/ops/vector_ops/vector_ops-rust/Cargo.toml b/tests/difftests/tests/lang/core/ops/vector_ops/vector_ops-rust/Cargo.toml index d0cb1246a2..19e11d7a03 100644 --- a/tests/difftests/tests/lang/core/ops/vector_ops/vector_ops-rust/Cargo.toml +++ b/tests/difftests/tests/lang/core/ops/vector_ops/vector_ops-rust/Cargo.toml @@ -8,13 +8,10 @@ workspace = true [lib] crate-type = ["dylib"] -# Common deps [dependencies] - -# GPU deps spirv-std.workspace = true +difftest.workspace = true # CPU deps [target.'cfg(not(target_arch = "spirv"))'.dependencies] -difftest.workspace = true -bytemuck.workspace = true \ No newline at end of file +bytemuck.workspace = true diff --git a/tests/difftests/tests/lang/core/ops/vector_ops/vector_ops-rust/src/lib.rs b/tests/difftests/tests/lang/core/ops/vector_ops/vector_ops-rust/src/lib.rs index 6a3c2a4e5d..2a53a8f760 100644 --- a/tests/difftests/tests/lang/core/ops/vector_ops/vector_ops-rust/src/lib.rs +++ b/tests/difftests/tests/lang/core/ops/vector_ops/vector_ops-rust/src/lib.rs @@ -1,5 +1,6 @@ #![no_std] +use difftest::round6; use spirv_std::glam::{UVec2, UVec3, UVec4, Vec2, Vec3, Vec4, Vec4Swizzles}; #[allow(unused_imports)] use spirv_std::num_traits::Float; @@ -33,9 +34,9 @@ pub fn main_cs( let v2a = Vec2::new(a, b); let v2b = Vec2::new(c, d); - output[base_offset + 0] = (v2a.dot(v2b) * 1000.0).round() / 1000.0; - output[base_offset + 1] = (v2a.length() * 1000.0).round() / 1000.0; - output[base_offset + 2] = (v2a.distance(v2b) * 1000.0).round() / 1000.0; + output[base_offset + 0] = round6!(v2a.dot(v2b)); + output[base_offset + 1] = round6!(v2a.length()); + output[base_offset + 2] = round6!(v2a.distance(v2b)); let v2_add = v2a + v2b; output[base_offset + 3] = v2_add.x; @@ -49,25 +50,25 @@ pub fn main_cs( let v3a = Vec3::new(a, b, c); let v3b = Vec3::new(b, c, d); - output[base_offset + 7] = (v3a.dot(v3b) * 1000.0).round() / 1000.0; - output[base_offset + 8] = (v3a.length() * 1000.0).round() / 1000.0; + output[base_offset + 7] = round6!(v3a.dot(v3b)); + output[base_offset + 8] = round6!(v3a.length()); let v3_cross = v3a.cross(v3b); - output[base_offset + 9] = v3_cross.x; - output[base_offset + 10] = v3_cross.y; - output[base_offset + 11] = v3_cross.z; + output[base_offset + 9] = round6!(v3_cross.x); + output[base_offset + 10] = round6!(v3_cross.y); + output[base_offset + 11] = round6!(v3_cross.z); let v3_norm = v3a.normalize(); - output[base_offset + 12] = (v3_norm.x * 1000.0).round() / 1000.0; - output[base_offset + 13] = (v3_norm.y * 1000.0).round() / 1000.0; - output[base_offset + 14] = (v3_norm.z * 1000.0).round() / 1000.0; + output[base_offset + 12] = round6!(v3_norm.x); + output[base_offset + 13] = round6!(v3_norm.y); + output[base_offset + 14] = round6!(v3_norm.z); // Vec4 operations let v4a = Vec4::new(a, b, c, d); let v4b = Vec4::new(d, c, b, a); - output[base_offset + 15] = (v4a.dot(v4b) * 1000.0).round() / 1000.0; - output[base_offset + 16] = (v4a.length() * 1000.0).round() / 1000.0; + output[base_offset + 15] = round6!(v4a.dot(v4b)); + output[base_offset + 16] = round6!(v4a.length()); let v4_sub = v4a - v4b; output[base_offset + 17] = v4_sub.x; diff --git a/tests/difftests/tests/lang/core/ops/vector_ops/vector_ops-wgsl/shader.wgsl b/tests/difftests/tests/lang/core/ops/vector_ops/vector_ops-wgsl/shader.wgsl index 9393759023..809ea276d1 100644 --- a/tests/difftests/tests/lang/core/ops/vector_ops/vector_ops-wgsl/shader.wgsl +++ b/tests/difftests/tests/lang/core/ops/vector_ops/vector_ops-wgsl/shader.wgsl @@ -4,6 +4,11 @@ var input: array; @group(0) @binding(1) var output: array; +// Helper function to round to 6 decimal places for cross-platform consistency +fn round6(v: f32) -> f32 { + return round(v * 1000000.0) / 1000000.0; +} + @compute @workgroup_size(32, 1, 1) fn main_cs(@builtin(global_invocation_id) global_id: vec3) { let tid = global_id.x; @@ -28,9 +33,9 @@ fn main_cs(@builtin(global_invocation_id) global_id: vec3) { let v2a = vec2(a, b); let v2b = vec2(c, d); - output[base_offset + 0u] = round(dot(v2a, v2b) * 1000.0) / 1000.0; - output[base_offset + 1u] = round(length(v2a) * 1000.0) / 1000.0; - output[base_offset + 2u] = round(distance(v2a, v2b) * 1000.0) / 1000.0; + output[base_offset + 0u] = round6(dot(v2a, v2b)); + output[base_offset + 1u] = round6(length(v2a)); + output[base_offset + 2u] = round6(distance(v2a, v2b)); let v2_add = v2a + v2b; output[base_offset + 3u] = v2_add.x; @@ -44,25 +49,25 @@ fn main_cs(@builtin(global_invocation_id) global_id: vec3) { let v3a = vec3(a, b, c); let v3b = vec3(b, c, d); - output[base_offset + 7u] = round(dot(v3a, v3b) * 1000.0) / 1000.0; - output[base_offset + 8u] = round(length(v3a) * 1000.0) / 1000.0; + output[base_offset + 7u] = round6(dot(v3a, v3b)); + output[base_offset + 8u] = round6(length(v3a)); let v3_cross = cross(v3a, v3b); - output[base_offset + 9u] = v3_cross.x; - output[base_offset + 10u] = v3_cross.y; - output[base_offset + 11u] = v3_cross.z; + output[base_offset + 9u] = round6(v3_cross.x); + output[base_offset + 10u] = round6(v3_cross.y); + output[base_offset + 11u] = round6(v3_cross.z); let v3_norm = normalize(v3a); - output[base_offset + 12u] = round(v3_norm.x * 1000.0) / 1000.0; - output[base_offset + 13u] = round(v3_norm.y * 1000.0) / 1000.0; - output[base_offset + 14u] = round(v3_norm.z * 1000.0) / 1000.0; + output[base_offset + 12u] = round6(v3_norm.x); + output[base_offset + 13u] = round6(v3_norm.y); + output[base_offset + 14u] = round6(v3_norm.z); // Vec4 operations let v4a = vec4(a, b, c, d); let v4b = vec4(d, c, b, a); - output[base_offset + 15u] = round(dot(v4a, v4b) * 1000.0) / 1000.0; - output[base_offset + 16u] = round(length(v4a) * 1000.0) / 1000.0; + output[base_offset + 15u] = round6(dot(v4a, v4b)); + output[base_offset + 16u] = round6(length(v4a)); let v4_sub = v4a - v4b; output[base_offset + 17u] = v4_sub.x; From 79d310af287108fffed4a23fd89f8486fcdc96ae Mon Sep 17 00:00:00 2001 From: Christian Legnitto Date: Sun, 6 Jul 2025 17:11:00 +0200 Subject: [PATCH 03/24] Add epsilon support and improve difftest framework - Add a round macro for cross-plat compat - Add epsilon-based floating point comparison - Add human-readable output for float/int data - Add test coverage for new stuff - Update documentation --- Cargo.lock | 131 +++- tests/difftests/README.md | 124 +++- tests/difftests/bin/Cargo.toml | 3 + tests/difftests/bin/src/differ.rs | 271 ++++++++ tests/difftests/bin/src/main.rs | 1 + tests/difftests/bin/src/runner.rs | 577 +++++++++++++++++- tests/difftests/lib/src/config.rs | 61 +- tests/difftests/lib/src/lib.rs | 19 +- .../ops/math_ops/math_ops-rust/src/lib.rs | 32 +- .../ops/math_ops/math_ops-rust/src/main.rs | 17 +- .../ops/math_ops/math_ops-wgsl/shader.wgsl | 36 +- .../ops/math_ops/math_ops-wgsl/src/main.rs | 17 +- .../ops/matrix_ops/matrix_ops-rust/src/lib.rs | 72 +-- .../matrix_ops/matrix_ops-rust/src/main.rs | 7 + .../matrix_ops/matrix_ops-wgsl/shader.wgsl | 76 +-- .../matrix_ops/matrix_ops-wgsl/src/main.rs | 7 + .../ops/vector_ops/vector_ops-rust/src/lib.rs | 44 +- .../vector_ops/vector_ops-rust/src/main.rs | 7 + .../vector_ops/vector_ops-wgsl/shader.wgsl | 48 +- .../vector_ops/vector_ops-wgsl/src/main.rs | 7 + 20 files changed, 1302 insertions(+), 255 deletions(-) create mode 100644 tests/difftests/bin/src/differ.rs diff --git a/Cargo.lock b/Cargo.lock index e28edb2bc5..0002a7c444 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -285,6 +285,12 @@ version = "3.16.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "79296716171880943b8470b5f8d03aa55eb2e645a4874bdbb28adb49162e012c" +[[package]] +name = "bytecount" +version = "0.6.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "175812e0be2bccb6abe50bb8d566126198344f707e304f45c648fd8f2cc0365e" + [[package]] name = "bytemuck" version = "1.23.1" @@ -302,7 +308,7 @@ checksum = "3fa76293b4f7bb636ab88fd78228235b5248b4d05cc589aed610f954af5d7c7a" dependencies = [ "proc-macro2", "quote", - "syn", + "syn 2.0.94", ] [[package]] @@ -432,10 +438,10 @@ version = "4.5.32" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "09176aae279615badda0765c0c0b3f6ed53f4709118af73cf4655d85d1530cd7" dependencies = [ - "heck", + "heck 0.5.0", "proc-macro2", "quote", - "syn", + "syn 2.0.94", ] [[package]] @@ -695,7 +701,7 @@ dependencies = [ "proc-macro2", "quote", "rustc_version", - "syn", + "syn 2.0.94", ] [[package]] @@ -723,9 +729,12 @@ name = "difftests" version = "0.9.0" dependencies = [ "anyhow", + "bytemuck", "bytesize", + "difftest", "serde", "serde_json", + "tabled", "tempfile", "tester", "thiserror 1.0.69", @@ -939,6 +948,12 @@ dependencies = [ "miniz_oxide", ] +[[package]] +name = "fnv" +version = "1.0.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3f9eec918d3f24069decb9af1554cad7c880e2da24a9afd88aca000531ab82c1" + [[package]] name = "foldhash" version = "0.1.4" @@ -963,7 +978,7 @@ checksum = "1a5c6c585bc94aaf2c7b51dd4c2ba22680844aba4c687be581871a6f518c5742" dependencies = [ "proc-macro2", "quote", - "syn", + "syn 2.0.94", ] [[package]] @@ -1037,7 +1052,7 @@ checksum = "162ee34ebcb7c64a8abebc059ce0fee27c2262618d7b60ed8faf72fef13c3650" dependencies = [ "proc-macro2", "quote", - "syn", + "syn 2.0.94", ] [[package]] @@ -1236,6 +1251,12 @@ dependencies = [ "foldhash", ] +[[package]] +name = "heck" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "95505c38b4572b2d910cecb0281560f54b440a19336cbbcb27bf6ce6adc6f5a8" + [[package]] name = "heck" version = "0.5.0" @@ -1783,7 +1804,7 @@ dependencies = [ "proc-macro-crate", "proc-macro2", "quote", - "syn", + "syn 2.0.94", ] [[package]] @@ -2055,6 +2076,17 @@ dependencies = [ "ttf-parser", ] +[[package]] +name = "papergrid" +version = "0.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9ad43c07024ef767f9160710b3a6773976194758c7919b17e63b863db0bdf7fb" +dependencies = [ + "bytecount", + "fnv", + "unicode-width", +] + [[package]] name = "parking_lot" version = "0.12.3" @@ -2119,7 +2151,7 @@ checksum = "3c0f5fad0874fc7abcd4d750e76917eaebbecaa2c20bde22e1dbeeba8beb758c" dependencies = [ "proc-macro2", "quote", - "syn", + "syn 2.0.94", ] [[package]] @@ -2205,6 +2237,30 @@ dependencies = [ "toml_edit", ] +[[package]] +name = "proc-macro-error" +version = "1.0.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "da25490ff9892aab3fcf7c36f08cfb902dd3e71ca0f9f9517bea02a73a5ce38c" +dependencies = [ + "proc-macro-error-attr", + "proc-macro2", + "quote", + "syn 1.0.109", + "version_check", +] + +[[package]] +name = "proc-macro-error-attr" +version = "1.0.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a1be40180e52ecc98ad80b184934baf3d0d29f979574e439af5a55274b35f869" +dependencies = [ + "proc-macro2", + "quote", + "version_check", +] + [[package]] name = "proc-macro2" version = "1.0.92" @@ -2615,7 +2671,7 @@ checksum = "5a9bf7cf98d04a2b28aead066b7496853d4779c9cc183c440dbac457641e19a0" dependencies = [ "proc-macro2", "quote", - "syn", + "syn 2.0.94", ] [[package]] @@ -2808,7 +2864,7 @@ dependencies = [ "proc-macro2", "quote", "spirv-std-types", - "syn", + "syn 2.0.94", ] [[package]] @@ -2874,11 +2930,22 @@ version = "0.26.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "4c6bee85a5a24955dc440386795aa378cd9cf82acd5f764469152d2270e581be" dependencies = [ - "heck", + "heck 0.5.0", "proc-macro2", "quote", "rustversion", - "syn", + "syn 2.0.94", +] + +[[package]] +name = "syn" +version = "1.0.109" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "72b64191b275b66ffe2469e8af2c1cfe3bafa67b529ead792a6d0160888b4237" +dependencies = [ + "proc-macro2", + "quote", + "unicode-ident", ] [[package]] @@ -2892,6 +2959,30 @@ dependencies = [ "unicode-ident", ] +[[package]] +name = "tabled" +version = "0.15.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4c998b0c8b921495196a48aabaf1901ff28be0760136e31604f7967b0792050e" +dependencies = [ + "papergrid", + "tabled_derive", + "unicode-width", +] + +[[package]] +name = "tabled_derive" +version = "0.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4c138f99377e5d653a371cdad263615634cfc8467685dfe8e73e2b8e98f44b17" +dependencies = [ + "heck 0.4.1", + "proc-macro-error", + "proc-macro2", + "quote", + "syn 1.0.109", +] + [[package]] name = "tempfile" version = "3.14.0" @@ -2964,7 +3055,7 @@ checksum = "4fee6c4efc90059e10f81e6d42c60a18f76588c3d74cb83a0b242a2b6c7504c1" dependencies = [ "proc-macro2", "quote", - "syn", + "syn 2.0.94", ] [[package]] @@ -2975,7 +3066,7 @@ checksum = "7f7cf42b4507d8ea322120659672cf1b9dbb93f8f2d4ecfd6e51350ff5b17a1d" dependencies = [ "proc-macro2", "quote", - "syn", + "syn 2.0.94", ] [[package]] @@ -3116,7 +3207,7 @@ checksum = "395ae124c09f9e6918a2310af6038fba074bcf474ac352496d5910dd59a2226d" dependencies = [ "proc-macro2", "quote", - "syn", + "syn 2.0.94", ] [[package]] @@ -3278,7 +3369,7 @@ dependencies = [ "log", "proc-macro2", "quote", - "syn", + "syn 2.0.94", "wasm-bindgen-shared", ] @@ -3313,7 +3404,7 @@ checksum = "30d7a95b763d3c45903ed6c81f156801839e5ee968bb07e534c44df0fcd330c2" dependencies = [ "proc-macro2", "quote", - "syn", + "syn 2.0.94", "wasm-bindgen-backend", "wasm-bindgen-shared", ] @@ -3745,7 +3836,7 @@ checksum = "2bbd5b46c938e506ecbce286b6628a02171d56153ba733b6c741fc627ec9579b" dependencies = [ "proc-macro2", "quote", - "syn", + "syn 2.0.94", ] [[package]] @@ -3756,7 +3847,7 @@ checksum = "053c4c462dc91d3b1504c6fe5a726dd15e216ba718e84a0e46a88fbe5ded3515" dependencies = [ "proc-macro2", "quote", - "syn", + "syn 2.0.94", ] [[package]] @@ -4140,5 +4231,5 @@ checksum = "fa4f8080344d4671fb4e831a13ad1e68092748387dfc4f55e356242fae12ce3e" dependencies = [ "proc-macro2", "quote", - "syn", + "syn 2.0.94", ] diff --git a/tests/difftests/README.md b/tests/difftests/README.md index d9d5a436dc..9e282bbb89 100644 --- a/tests/difftests/README.md +++ b/tests/difftests/README.md @@ -23,7 +23,9 @@ discrepancies across implementations. 3. **Output Comparison** - The harness reads outputs as opaque bytes. - - If outputs differ, the test fails. + - If outputs differ, the test fails with detailed error reporting. + - Tests can specify metadata to enable smarter epsilon-based comparisons and human + display of data. Because the difftest harness merely runs Rust binaries in a directory, it supports testing various setups. For example, you can: @@ -54,10 +56,13 @@ Each test binary must: 3. Load the config using `difftest::Config::from_path`. 4. Write its computed output to `output_path`. +The test binary can _optionally_ write test metadata to `metadata_path` for custom +comparison behavior. + For example: ```rust -use difftest::config::Config; +use difftest::config::{Config, TestMetadata, OutputType}; use std::{env, fs, io::Write}; fn main() { @@ -70,6 +75,13 @@ fn main() { let mut file = fs::File::create(&config.output_path) .expect("Failed to create output file"); file.write_all(&output).expect("Failed to write output"); + + // Optional: Write metadata for floating-point comparison + let metadata = TestMetadata { + epsilon: Some(0.00001), // Allow differences up to 1e-5 + output_type: OutputType::F32, // Interpret output as f32 array + }; + config.write_metadata(&metadata).expect("Failed to write metadata"); } ``` @@ -79,39 +91,70 @@ Of course, many test will have common host and GPU needs. Rather than require ev binary to reimplement functionality, we have created some common tests with reasonable defaults in the `difftest` library. -For example, this will handle compiling the current crate as a Rust compute shader, -running it via `wgpu`, and writing the output to the appropriate place: +The library provides helper types for common test patterns: -```rust -fn main() { - // Load the config from the harness. - let config = Config::from_path(std::env::args().nth(1).unwrap()).unwrap(); +**Test types:** - // Define test parameters, loading the rust shader from the current crate. - let test = WgpuComputeTest::new(RustComputeShader::default(), [1, 1, 1], 1024); +- `WgpuComputeTest` - Single buffer compute shader test +- `WgpuComputeTestMultiBuffer` - Multi-buffer compute shader test with input/output + separation +- `WgpuComputeTestPushConstant` - Compute shader test with push constants support - // Run the test and write the output to a file. - test.run_test(&config).unwrap(); -} -``` +**Shader source types:** -and this will handle loading a shader named `shader.wgsl` or `compute.wgsl` in the root -of the current crate, running it via `wgpu`, and writing the output to the appropriate -place: +- `RustComputeShader` - Compiles the current crate as a Rust GPU shader +- `WgslComputeShader` - Loads WGSL shader from file (shader.wgsl or compute.wgsl) -```rust -fn main() { - // Load the config from the harness. - let config = Config::from_path(std::env::args().nth(1).unwrap()).unwrap(); +For examples, see: - // Define test parameters, loading the wgsl shader from the crate directory. - let test = WgpuComputeTest::new(WgslComputeShader::default(), [1, 1, 1], 1024); +- [`tests/lang/core/ops/math_ops/`](tests/lang/core/ops/math_ops/) - Multi-buffer test + with floating-point metadata +- [`tests/storage_class/push_constant/`](tests/storage_class/push_constant/) - Push + constants usage +- [`tests/arch/workgroup_memory/`](tests/arch/workgroup_memory/) - Workgroup memory + usage - // Run the test and write the output to a file. - test.run_test(&config).unwrap(); -} +### Test Metadata + +Tests producing floating-point outputs can specify comparison metadata to handle +platform-specific precision differences. The metadata controls how the harness compares +outputs: + +```rust +use difftest::config::{TestMetadata, OutputType}; + +// Write metadata before or after writing output +let metadata = TestMetadata { + epsilon: Some(0.00001), // Maximum allowed difference (default: None) + output_type: OutputType::F32, // How to interpret output data (default: Raw) +}; +config.write_metadata(&metadata)?; + +// Alternative: Use the helper method for common cases +let metadata = TestMetadata::with_epsilon(0.00001); // Sets epsilon, keeps default output_type +config.write_metadata(&metadata)?; ``` +**Metadata fields:** + +- `epsilon`: Optional maximum allowed absolute difference between values. When `None` + (default), exact byte-for-byte comparison is used. When `Some(value)`, floating-point + values are compared with the specified tolerance. +- `output_type`: Specifies how to interpret output data: + - `Raw`: Exact byte comparison (default) + - `F32`: Interpret as array of 32-bit floats, enables epsilon comparison + - `F64`: Interpret as array of 64-bit floats, enables epsilon comparison + - `U32`/`I32`: Interpret as 32-bit integers (epsilon ignored) + +**Important notes:** + +- If no metadata file is written or the file is empty, the harness uses exact byte + comparison. +- All test packages must have consistent metadata. If packages specify different + `output_type` values, the test will fail with an error. +- Invalid JSON in metadata files will cause the test to fail immediately. +- The `epsilon` field is only used when `output_type` is `F32` or `F64`. + ## Running Tests ### Run all difftests: @@ -137,13 +180,32 @@ cargo difftest --nocapture ## Debugging Failing Tests -If outputs differ, the error message lists: +When outputs differ, the harness provides detailed error reporting: + +### For raw byte differences + +- Shows which packages produced different outputs +- Lists output file paths for manual inspection +- Groups packages by their output values + +Inspect the output files with your preferred tools to determine the root cause. + +### For floating-point differences (with `output_type: F32/F64`) + +Reports all of the above, plus: + +- Actual floating-point values in a comparison table +- Shows the maximum difference found +- Indicates the epsilon threshold (if specified) +- Highlights specific values that exceed the tolerance + +### Additional output files -- Binary package names -- Their directories -- Output file paths +The harness automatically writes human-readable `.txt` files alongside binary outputs. +For floating-point data (F32/F64), these show the array values in decimal format. For +raw/integer data, these show the values as hex bytes or integers -Inspect the output files with your preferred tools to determine the differences. +## Harness logs If you suspect a bug in the test harness, you can view detailed test harness logs: diff --git a/tests/difftests/bin/Cargo.toml b/tests/difftests/bin/Cargo.toml index 82b1fc2d93..e62cc51049 100644 --- a/tests/difftests/bin/Cargo.toml +++ b/tests/difftests/bin/Cargo.toml @@ -23,6 +23,9 @@ serde_json = "1.0" thiserror = "1.0" toml = { version = "0.8.20", default-features = false, features = ["parse"] } bytesize = "2.0.1" +bytemuck = "1.21.0" +difftest = { path = "../lib" } +tabled = "0.15" [lints] workspace = true diff --git a/tests/difftests/bin/src/differ.rs b/tests/difftests/bin/src/differ.rs new file mode 100644 index 0000000000..5e74ffe642 --- /dev/null +++ b/tests/difftests/bin/src/differ.rs @@ -0,0 +1,271 @@ +use difftest::config::OutputType; + +/// Trait for comparing two outputs and producing differences +pub trait OutputDiffer { + /// Compare two outputs and return a list of differences + fn compare(&self, output1: &[u8], output2: &[u8], epsilon: Option) -> Vec; + + /// Get a human-readable name for this differ + fn name(&self) -> &'static str; +} + +/// Trait for displaying differences in various formats +pub trait DifferenceDisplay { + /// Format differences as a table + fn format_table(&self, diffs: &[Difference], pkg1: &str, pkg2: &str) -> String; + + /// Format differences as a detailed report + fn format_report( + &self, + diffs: &[Difference], + pkg1: &str, + pkg2: &str, + epsilon: Option, + ) -> String; + + /// Write human-readable output to a file + fn write_human_readable(&self, output: &[u8], path: &std::path::Path) -> std::io::Result<()>; +} + +/// A single difference between two values +#[derive(Debug, Clone)] +pub struct Difference { + pub index: usize, + pub value1: String, + pub value2: String, + pub absolute_diff: f64, + pub relative_diff: Option, +} + +/// Differ for raw byte comparison +pub struct RawDiffer; + +impl OutputDiffer for RawDiffer { + fn compare(&self, output1: &[u8], output2: &[u8], _epsilon: Option) -> Vec { + if output1 == output2 { + vec![] + } else { + // For raw comparison, we just note that they differ + vec![Difference { + index: 0, + value1: format!("{} bytes", output1.len()), + value2: format!("{} bytes", output2.len()), + absolute_diff: 0.0, + relative_diff: None, + }] + } + } + + fn name(&self) -> &'static str { + "Raw Binary" + } +} + +impl DifferenceDisplay for RawDiffer { + fn format_table(&self, _diffs: &[Difference], _pkg1: &str, _pkg2: &str) -> String { + "Binary files differ".to_string() + } + + fn format_report( + &self, + _diffs: &[Difference], + pkg1: &str, + pkg2: &str, + _epsilon: Option, + ) -> String { + format!("Binary outputs from {} and {} differ", pkg1, pkg2) + } + + fn write_human_readable(&self, output: &[u8], path: &std::path::Path) -> std::io::Result<()> { + // For raw binary, write hex dump + use std::io::Write; + let mut file = std::fs::File::create(path)?; + for (i, chunk) in output.chunks(16).enumerate() { + write!(file, "{:08x}: ", i * 16)?; + for byte in chunk { + write!(file, "{:02x} ", byte)?; + } + writeln!(file)?; + } + Ok(()) + } +} + +/// Differ for f32 arrays +pub struct F32Differ; + +impl OutputDiffer for F32Differ { + fn compare(&self, output1: &[u8], output2: &[u8], epsilon: Option) -> Vec { + if output1.len() != output2.len() { + return vec![Difference { + index: 0, + value1: format!("{} bytes", output1.len()), + value2: format!("{} bytes", output2.len()), + absolute_diff: 0.0, + relative_diff: None, + }]; + } + + let floats1: Vec = output1 + .chunks_exact(4) + .map(|chunk| f32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]])) + .collect(); + let floats2: Vec = output2 + .chunks_exact(4) + .map(|chunk| f32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]])) + .collect(); + + let mut differences = Vec::new(); + + for (i, (&v1, &v2)) in floats1.iter().zip(floats2.iter()).enumerate() { + let diff = (v1 - v2).abs(); + // If no epsilon specified, report all differences + let threshold = epsilon.unwrap_or(0.0); + if diff > threshold { + let rel_diff = if v1.abs() > 1e-10 || v2.abs() > 1e-10 { + Some((diff as f64) / f64::max(v1.abs() as f64, v2.abs() as f64)) + } else { + None + }; + + differences.push(Difference { + index: i, + value1: format!("{:.9}", v1), + value2: format!("{:.9}", v2), + absolute_diff: diff as f64, + relative_diff: rel_diff, + }); + } + } + + differences + } + + fn name(&self) -> &'static str { + "32-bit Float" + } +} + +impl DifferenceDisplay for F32Differ { + fn format_table(&self, diffs: &[Difference], pkg1: &str, pkg2: &str) -> String { + use tabled::settings::{Alignment, Modify, Style, object::Rows}; + + if diffs.is_empty() { + return String::new(); + } + + // Extract suffix from package names (e.g., "math_ops-wgsl" -> "WGSL") + let extract_suffix = |pkg: &str| -> String { + if let Some(pos) = pkg.rfind('-') { + let suffix = &pkg[pos + 1..]; + suffix.to_uppercase() + } else { + pkg.to_string() + } + }; + + let col1_name = extract_suffix(pkg1); + let col2_name = extract_suffix(pkg2); + + // Build rows for the table + let rows: Vec> = diffs + .iter() + .take(10) + .map(|d| { + vec![ + d.index.to_string(), + d.value1.clone(), + d.value2.clone(), + format!("{:.3e}", d.absolute_diff), + d.relative_diff + .map(|r| format!("{:.2}%", r * 100.0)) + .unwrap_or_else(|| "N/A".to_string()), + ] + }) + .collect(); + + // Create a table with custom headers + let mut builder = tabled::builder::Builder::default(); + + // Add header row + builder.push_record(vec!["#", &col1_name, &col2_name, "Δ abs", "Δ %"]); + + // Add data rows + for row in &rows { + builder.push_record(row); + } + + let mut table = builder.build(); + table + .with(Style::modern()) + .with(Modify::new(Rows::first()).with(Alignment::center())); + + let mut result = table.to_string(); + + if diffs.len() > 10 { + // Get the width of the last line to properly align the text + let last_line_width = result + .lines() + .last() + .map(|l| l.chars().count()) + .unwrap_or(0); + result.push_str(&format!( + "\n{:>width$}", + format!("... {} more differences", diffs.len() - 10), + width = last_line_width + )); + } + + result + } + + fn format_report( + &self, + diffs: &[Difference], + pkg1: &str, + pkg2: &str, + _epsilon: Option, + ) -> String { + self.format_table(diffs, pkg1, pkg2) + } + + fn write_human_readable(&self, output: &[u8], path: &std::path::Path) -> std::io::Result<()> { + use std::io::Write; + let mut file = std::fs::File::create(path)?; + + let floats: Vec = output + .chunks_exact(4) + .map(|chunk| f32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]])) + .collect(); + + for (i, &value) in floats.iter().enumerate() { + writeln!(file, "{}: {:.9}", i, value)?; + } + + Ok(()) + } +} + +impl From for Box { + fn from(output_type: OutputType) -> Self { + match output_type { + OutputType::Raw => Box::new(RawDiffer), + OutputType::F32 => Box::new(F32Differ), + OutputType::F64 => todo!("F64Differ not implemented yet"), + OutputType::U32 => todo!("U32Differ not implemented yet"), + OutputType::I32 => todo!("I32Differ not implemented yet"), + } + } +} + +impl From for Box { + fn from(output_type: OutputType) -> Self { + match output_type { + OutputType::Raw => Box::new(RawDiffer), + OutputType::F32 => Box::new(F32Differ), + OutputType::F64 => todo!("F64Display not implemented yet"), + OutputType::U32 => todo!("U32Display not implemented yet"), + OutputType::I32 => todo!("I32Display not implemented yet"), + } + } +} diff --git a/tests/difftests/bin/src/main.rs b/tests/difftests/bin/src/main.rs index 1ede9444b8..58614d5a8b 100644 --- a/tests/difftests/bin/src/main.rs +++ b/tests/difftests/bin/src/main.rs @@ -12,6 +12,7 @@ use tester::{ }; use tracing_subscriber::FmtSubscriber; +mod differ; mod runner; use runner::Runner; diff --git a/tests/difftests/bin/src/runner.rs b/tests/difftests/bin/src/runner.rs index fa56325172..e3990ff304 100644 --- a/tests/difftests/bin/src/runner.rs +++ b/tests/difftests/bin/src/runner.rs @@ -1,4 +1,5 @@ use bytesize::ByteSize; +use difftest::config::{OutputType, TestMetadata}; use serde::{Deserialize, Serialize}; use std::{ collections::{HashMap, HashSet}, @@ -11,6 +12,8 @@ use tempfile::NamedTempFile; use thiserror::Error; use tracing::{debug, error, info, trace}; +use crate::differ::{Difference, DifferenceDisplay, OutputDiffer}; + #[derive(Debug, Error)] pub enum RunnerError { #[error("I/O error: {source}")] @@ -54,6 +57,7 @@ pub type RunnerResult = std::result::Result; #[derive(Debug, Serialize)] pub struct HarnessConfig { pub output_path: PathBuf, + pub metadata_path: PathBuf, } #[derive(Deserialize)] @@ -73,6 +77,129 @@ struct PackageOutput { temp_path: PathBuf, } +struct TestInfo { + name: String, + path: String, +} + +struct ErrorReport { + lines: Vec, + test_info: TestInfo, + summary_parts: Vec, +} + +impl ErrorReport { + fn new(test_info: TestInfo, differ_name: &'static str, epsilon: Option) -> Self { + let epsilon_str = match epsilon { + Some(e) => format!(", ε={}", e), + None => String::new(), + }; + Self { + lines: Vec::new(), + test_info, + summary_parts: vec![format!("{}{}", differ_name, epsilon_str)], + } + } + + fn set_summary_from_differences(&mut self, differences: &[Difference]) { + if !differences.is_empty() { + let (max_diff, max_rel) = self.calculate_max_differences(differences); + self.summary_parts + .push(format!("{} differences", differences.len())); + self.summary_parts + .push(format!("max: {:.3e} ({:.2}%)", max_diff, max_rel * 100.0)); + } + } + + fn set_distinct_outputs(&mut self, count: usize) { + self.summary_parts + .push(format!("{} distinct outputs", count)); + } + + fn add_output_files( + &mut self, + groups: &HashMap, Vec<&PackageOutput>>, + pkg_outputs: &[PackageOutput], + ) { + if groups.len() <= 5 { + for (output_bytes, group) in groups { + let names: Vec<&str> = group.iter().map(|po| po.pkg_name.as_str()).collect(); + self.lines.push(format!("▪ {}", names.join(", "))); + self.lines.push(format!( + " → Raw: {} ({:.1})", + group[0].temp_path.display(), + ByteSize::b(output_bytes.len() as u64) + )); + let text_path = group[0].temp_path.with_extension("txt"); + self.lines + .push(format!(" → Text: {}", text_path.display())); + self.lines.push("".to_string()); + } + } else { + for po in pkg_outputs { + self.lines.push(format!("▪ {}", po.pkg_name)); + self.lines.push(format!( + " → Raw: {} ({:.1})", + po.temp_path.display(), + ByteSize::b(po.output.len() as u64) + )); + let text_path = po.temp_path.with_extension("txt"); + self.lines + .push(format!(" → Text: {}", text_path.display())); + self.lines.push("".to_string()); + } + } + } + + fn add_comparison_table(&mut self, table: String) { + self.lines.push(table); + } + + fn add_summary_line(&mut self, differences: &[Difference]) { + if !differences.is_empty() { + let (max_diff, max_rel) = self.calculate_max_differences(differences); + self.lines.push(format!( + "• {} differences, max: {:.3e} ({:.2}%)", + differences.len(), + max_diff, + max_rel * 100.0 + )); + } + } + + fn calculate_max_differences(&self, differences: &[Difference]) -> (f64, f64) { + let max_diff = differences + .iter() + .map(|d| d.absolute_diff) + .max_by(|a, b| a.partial_cmp(b).unwrap()) + .unwrap_or(0.0); + let max_rel = differences + .iter() + .filter_map(|d| d.relative_diff) + .max_by(|a, b| a.partial_cmp(b).unwrap()) + .unwrap_or(0.0); + (max_diff, max_rel) + } + + fn build(self) -> String { + let mut result = Vec::new(); + + // Header + result.push(format!( + "\x1b[1m{} ({})\x1b[0m", + self.test_info.name, self.test_info.path + )); + result.push(self.summary_parts.join(" • ")); + result.push("─".repeat(65)); + result.push("".to_string()); + + // Body + result.extend(self.lines); + + result.join("\n") + } +} + #[derive(Clone)] pub struct Runner { pub base_dir: PathBuf, @@ -110,6 +237,8 @@ impl Runner { let mut temp_files: Vec = Vec::with_capacity(packages.len()); let mut pkg_outputs: Vec = Vec::with_capacity(packages.len()); + let mut epsilon: Option = None; + let mut output_type = None; for package in packages { trace!("Processing package at {}", package.display()); @@ -120,13 +249,23 @@ impl Runner { let output_file = NamedTempFile::new()?; let temp_output_path = output_file.path().to_path_buf(); temp_files.push(output_file); + + let metadata_file = NamedTempFile::new()?; + let temp_metadata_path = metadata_file.path().to_path_buf(); + temp_files.push(metadata_file); + trace!( "Temporary output file created at {}", temp_output_path.display() ); + trace!( + "Temporary metadata file created at {}", + temp_metadata_path.display() + ); let config = HarnessConfig { output_path: temp_output_path.clone(), + metadata_path: temp_metadata_path.clone(), }; let config_json = serde_json::to_string(&config) .map_err(|e| RunnerError::Config { msg: e.to_string() })?; @@ -173,12 +312,64 @@ impl Runner { }); } - let output_bytes = fs::read(temp_files.last().unwrap().path())?; + let output_bytes = fs::read(&temp_output_path)?; debug!( "Read {} bytes of output for package '{}'", output_bytes.len(), pkg_name ); + + // Try to read metadata file + if let Ok(metadata_content) = fs::read_to_string(&temp_metadata_path) { + if !metadata_content.trim().is_empty() { + match serde_json::from_str::(&metadata_content) { + Ok(metadata) => { + if let Some(meta_epsilon) = metadata.epsilon { + epsilon = match epsilon { + Some(e) => Some(e.max(meta_epsilon)), + None => Some(meta_epsilon), + }; + debug!( + "Found epsilon {} in metadata for package '{}'", + meta_epsilon, pkg_name + ); + } + + if output_type.is_none() { + output_type = Some(metadata.output_type); + } else if output_type != Some(metadata.output_type) { + error!("Inconsistent output types across packages"); + return Err(RunnerError::Config { + msg: format!( + "Package '{}' has output type {:?}, but previous packages have {:?}", + pkg_name, metadata.output_type, output_type + ), + }); + } + } + Err(e) => { + error!("Failed to parse metadata for package '{}'", pkg_name); + return Err(RunnerError::Config { + msg: format!( + "Failed to parse metadata for package '{}': {}", + pkg_name, e + ), + }); + } + } + } else { + debug!( + "Empty metadata file for package '{}', using defaults", + pkg_name + ); + } + } else { + debug!( + "No metadata file for package '{}', using defaults", + pkg_name + ); + } + pkg_outputs.push(PackageOutput { pkg_name, package_path: package, @@ -192,9 +383,25 @@ impl Runner { return Err(RunnerError::EmptyOutput); } - let groups = self.group_outputs(&pkg_outputs); + let output_type = output_type.unwrap_or(OutputType::Raw); + let groups = self.group_outputs(&pkg_outputs, epsilon, output_type); if groups.len() > 1 { - let details = self.format_group_details(&pkg_outputs); + let differ: Box = output_type.into(); + let display: Box = output_type.into(); + + // Write human-readable outputs + for po in &pkg_outputs { + let text_path = po.temp_path.with_extension("txt"); + if let Err(e) = display.write_human_readable(&po.output, &text_path) { + debug!("Failed to write human-readable output: {}", e); + } else { + info!("Wrote human-readable output to {}", text_path.display()); + } + } + + // Generate detailed error report + let details = + self.format_error(&pkg_outputs, epsilon, output_type, &*differ, &*display); self.keep_temp_files(&mut temp_files); return Err(RunnerError::DifferingOutput(details)); } @@ -223,42 +430,207 @@ impl Runner { fn group_outputs<'a>( &self, pkg_outputs: &'a [PackageOutput], + epsilon: Option, + output_type: OutputType, ) -> HashMap, Vec<&'a PackageOutput>> { let mut groups: HashMap, Vec<&'a PackageOutput>> = HashMap::new(); + + // If no epsilon specified or type is Raw with epsilon 0, use exact byte comparison + if epsilon.is_none() || (epsilon == Some(0.0) && output_type == OutputType::Raw) { + for po in pkg_outputs { + groups.entry(po.output.clone()).or_default().push(po); + } + return groups; + } + + // Otherwise, group outputs that are within epsilon of each other for po in pkg_outputs { - groups.entry(po.output.clone()).or_default().push(po); + let mut found_group = false; + + for (group_output, group) in groups.iter_mut() { + if self.outputs_match(&po.output, group_output, epsilon, output_type) { + group.push(po); + found_group = true; + break; + } + } + + if !found_group { + groups.insert(po.output.clone(), vec![po]); + } } + groups } - fn format_group_details(&self, pkg_outputs: &[PackageOutput]) -> String { - let groups = self.group_outputs(pkg_outputs); - const TOTAL_WIDTH: usize = 50; - let mut details = Vec::with_capacity(groups.len() * 4); - for (i, (output, group)) in groups.iter().enumerate() { - let group_index = i + 1; - let header = format!( - "╭─ Output {} ({}) ", - group_index, - ByteSize::b(output.len() as u64) - ); - let header = if header.len() < TOTAL_WIDTH { - format!("{}{}", header, "─".repeat(TOTAL_WIDTH - header.len())) - } else { - header - }; - details.push(header); - for po in group { - let p = po - .package_path - .strip_prefix(self.base_dir.parent().expect("base_dir is not root")) - .expect("base_path is not a prefix of package_path"); - details.push(format!("│ {} ({})", po.pkg_name, p.display())); + fn outputs_match( + &self, + output1: &[u8], + output2: &[u8], + epsilon: Option, + output_type: OutputType, + ) -> bool { + if output1.len() != output2.len() { + return false; + } + + match output_type { + OutputType::Raw => output1 == output2, + OutputType::F32 => { + if output1.len() % 4 != 0 { + return false; + } + + match epsilon { + None => output1 == output2, // Exact comparison if no epsilon + Some(eps) => { + let floats1: Vec = output1 + .chunks_exact(4) + .map(|chunk| { + f32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]]) + }) + .collect(); + + let floats2: Vec = output2 + .chunks_exact(4) + .map(|chunk| { + f32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]]) + }) + .collect(); + + floats1 + .iter() + .zip(floats2.iter()) + .all(|(a, b)| (a - b).abs() <= eps) + } + } + } + OutputType::F64 => { + if output1.len() % 8 != 0 { + return false; + } + + match epsilon { + None => output1 == output2, // Exact comparison if no epsilon + Some(eps) => { + let floats1: Vec = output1 + .chunks_exact(8) + .map(|chunk| { + f64::from_le_bytes([ + chunk[0], chunk[1], chunk[2], chunk[3], chunk[4], chunk[5], + chunk[6], chunk[7], + ]) + }) + .collect(); + + let floats2: Vec = output2 + .chunks_exact(8) + .map(|chunk| { + f64::from_le_bytes([ + chunk[0], chunk[1], chunk[2], chunk[3], chunk[4], chunk[5], + chunk[6], chunk[7], + ]) + }) + .collect(); + + floats1 + .iter() + .zip(floats2.iter()) + .all(|(a, b)| (a - b).abs() <= eps as f64) + } + } + } + OutputType::U32 | OutputType::I32 => { + // For integer types, epsilon doesn't make sense, so exact match + output1 == output2 } - let footer = format!("╰──▶ {} \n", group[0].temp_path.display()); - details.push(footer); } - details.join("\n") + } + + fn format_error( + &self, + pkg_outputs: &[PackageOutput], + epsilon: Option, + output_type: OutputType, + differ: &dyn OutputDiffer, + display: &dyn DifferenceDisplay, + ) -> String { + let test_info = self.extract_test_info(pkg_outputs); + let groups = self.group_outputs(pkg_outputs, epsilon, output_type); + + let mut report = ErrorReport::new(test_info, differ.name(), epsilon); + + // Analyze the differences + match groups.len() { + 0 => unreachable!("No output groups"), + 1 => unreachable!("All outputs match - shouldn't be an error"), + 2 if pkg_outputs.len() == 2 => { + // Exactly 2 outputs that differ + let differences = + differ.compare(&pkg_outputs[0].output, &pkg_outputs[1].output, epsilon); + report.set_summary_from_differences(&differences); + } + 2 => { + // Multiple outputs, 2 distinct groups + let group_vec: Vec<_> = groups.values().collect(); + let differences = + differ.compare(&group_vec[0][0].output, &group_vec[1][0].output, epsilon); + report.set_summary_from_differences(&differences); + } + n => { + // Many distinct outputs + report.set_distinct_outputs(n); + } + } + + // Format output files + report.add_output_files(&groups, pkg_outputs); + + // Add detailed comparison if applicable + if groups.len() == 2 && pkg_outputs.len() == 2 { + let differences = + differ.compare(&pkg_outputs[0].output, &pkg_outputs[1].output, epsilon); + if !differences.is_empty() { + let table = display.format_report( + &differences, + &pkg_outputs[0].pkg_name, + &pkg_outputs[1].pkg_name, + epsilon, + ); + report.add_comparison_table(table); + } + } else if groups.len() == 2 && pkg_outputs.len() > 2 { + let group_vec: Vec<_> = groups.values().collect(); + let differences = + differ.compare(&group_vec[0][0].output, &group_vec[1][0].output, epsilon); + report.add_summary_line(&differences); + } + + report.build() + } + + fn extract_test_info(&self, pkg_outputs: &[PackageOutput]) -> TestInfo { + if pkg_outputs.is_empty() { + return TestInfo { + name: "unknown".to_string(), + path: "unknown".to_string(), + }; + } + + let test_path = pkg_outputs[0] + .package_path + .parent() + .unwrap_or(&pkg_outputs[0].package_path); + let relative_path = test_path.strip_prefix(&self.base_dir).unwrap_or(test_path); + + TestInfo { + name: test_path + .file_name() + .unwrap_or_else(|| std::ffi::OsStr::new("unknown")) + .to_string_lossy() + .to_string(), + path: relative_path.display().to_string(), + } } #[allow(clippy::unused_self)] @@ -343,6 +715,7 @@ pub fn forward_features(cmd: &mut Command) { #[cfg(test)] mod tests { use super::*; + use difftest::config::OutputType; use std::{fs, io::Write, path::Path, path::PathBuf}; use tempfile::{NamedTempFile, tempdir}; @@ -362,7 +735,7 @@ mod tests { let pkg3 = dummy_package_output("baz", "/path/to/baz", b"hello", "tmp3"); let outputs = vec![pkg1, pkg2, pkg3]; let runner = Runner::new(PathBuf::from("dummy_base")); - let groups = runner.group_outputs(&outputs); + let groups = runner.group_outputs(&outputs, None, OutputType::Raw); assert_eq!(groups.len(), 2); } @@ -462,4 +835,146 @@ mod tests { _ => panic!("Expected DuplicatePackageName error"), } } + + #[test] + fn test_outputs_match_no_epsilon() { + let runner = Runner::new(PathBuf::from("dummy_base")); + + // Exact match should work + assert!(runner.outputs_match(b"hello", b"hello", None, OutputType::Raw)); + + // Different content should not match + assert!(!runner.outputs_match(b"hello", b"world", None, OutputType::Raw)); + } + + #[test] + fn test_outputs_match_with_epsilon_f32() { + let runner = Runner::new(PathBuf::from("dummy_base")); + + // Prepare test data - two floats with small difference + let val1: f32 = 1.0; + let val2: f32 = 1.00001; + let arr1 = [val1]; + let arr2 = [val2]; + let bytes1 = bytemuck::cast_slice(&arr1); + let bytes2 = bytemuck::cast_slice(&arr2); + + // Should not match without epsilon + assert!(!runner.outputs_match(bytes1, bytes2, None, OutputType::F32)); + + // Should match with sufficient epsilon + assert!(runner.outputs_match(bytes1, bytes2, Some(0.0001), OutputType::F32)); + + // Should not match with too small epsilon + assert!(!runner.outputs_match(bytes1, bytes2, Some(0.000001), OutputType::F32)); + } + + #[test] + fn test_outputs_match_with_epsilon_f64() { + let runner = Runner::new(PathBuf::from("dummy_base")); + + // Prepare test data - two doubles with small difference + let val1: f64 = 1.0; + let val2: f64 = 1.00001; + let arr1 = [val1]; + let arr2 = [val2]; + let bytes1 = bytemuck::cast_slice(&arr1); + let bytes2 = bytemuck::cast_slice(&arr2); + + // Should not match without epsilon + assert!(!runner.outputs_match(bytes1, bytes2, None, OutputType::F64)); + + // Should match with sufficient epsilon + assert!(runner.outputs_match(bytes1, bytes2, Some(0.0001), OutputType::F64)); + + // Should not match with too small epsilon + assert!(!runner.outputs_match(bytes1, bytes2, Some(0.000001), OutputType::F64)); + } + + #[test] + fn test_group_outputs_with_epsilon() { + let runner = Runner::new(PathBuf::from("dummy_base")); + + // Create float outputs with small differences + let val1: f32 = 1.0; + let val2: f32 = 1.00001; + let val3: f32 = 2.0; + + let pkg1 = + dummy_package_output("foo", "/path/to/foo", bytemuck::cast_slice(&[val1]), "tmp1"); + let pkg2 = + dummy_package_output("bar", "/path/to/bar", bytemuck::cast_slice(&[val2]), "tmp2"); + let pkg3 = + dummy_package_output("baz", "/path/to/baz", bytemuck::cast_slice(&[val3]), "tmp3"); + + let outputs = vec![pkg1, pkg2, pkg3]; + + // Without epsilon, val1 and val2 should be in different groups + let groups = runner.group_outputs(&outputs, None, OutputType::F32); + assert_eq!(groups.len(), 3); + + // With epsilon, val1 and val2 should be in the same group + let groups_with_epsilon = runner.group_outputs(&outputs, Some(0.0001), OutputType::F32); + assert_eq!(groups_with_epsilon.len(), 2); + } + + #[test] + fn test_invalid_metadata_json() { + // Test that invalid JSON in metadata file causes proper error + let metadata_content = "{ invalid json }"; + let result: Result = + serde_json::from_str(metadata_content); + assert!(result.is_err()); + // Just check that it's an error, don't check the specific message + } + + #[test] + fn test_inconsistent_output_types() { + // This test verifies that when packages have different output types, + // the runner returns an error. This tests the code at line 340-347 + // where we check: if output_type != Some(metadata.output_type) + + // We can't easily test the full run_test_case flow without real binaries, + // but we can at least verify the error is constructed properly + + // Test that the error message is properly formatted + let error = RunnerError::Config { + msg: format!( + "Package '{}' has output type {:?}, but previous packages have {:?}", + "test_pkg", + OutputType::F32, + Some(OutputType::F64) + ), + }; + + match error { + RunnerError::Config { msg } => { + assert!(msg.contains("test_pkg")); + assert!(msg.contains("F32")); + assert!(msg.contains("F64")); + } + _ => panic!("Wrong error type"), + } + } + + #[test] + fn test_metadata_parsing_error_message() { + // Test that metadata parsing errors are formatted correctly + + let error = RunnerError::Config { + msg: format!( + "Failed to parse metadata for package '{}': {}", + "test_pkg", "invalid JSON" + ), + }; + + match error { + RunnerError::Config { msg } => { + assert!(msg.contains("Failed to parse metadata")); + assert!(msg.contains("test_pkg")); + assert!(msg.contains("invalid JSON")); + } + _ => panic!("Wrong error type"), + } + } } diff --git a/tests/difftests/lib/src/config.rs b/tests/difftests/lib/src/config.rs index ed717502f8..51fd61d9f5 100644 --- a/tests/difftests/lib/src/config.rs +++ b/tests/difftests/lib/src/config.rs @@ -1,9 +1,59 @@ -use serde::Deserialize; +use serde::{Deserialize, Serialize}; use std::{fs, path::Path}; -#[derive(Debug, Deserialize)] +#[derive(Debug, Deserialize, Serialize)] pub struct Config { pub output_path: std::path::PathBuf, + pub metadata_path: std::path::PathBuf, +} + +/// Test metadata that controls output comparison behavior +/// +/// This metadata is written alongside test output to specify how the test harness +/// should compare outputs between different implementations (e.g., Rust vs WGSL). +#[derive(Debug, Serialize, Deserialize, Default)] +pub struct TestMetadata { + /// Maximum allowed difference for floating-point comparisons. + /// + /// When None (default), exact byte-for-byte comparison is used. + /// For floating-point tests, a small epsilon (e.g., 1e-6) accounts for + /// platform-specific precision differences. + #[serde(default)] + pub epsilon: Option, + + /// Specifies how to interpret and compare output data. + /// + /// Defaults to `Raw` for exact byte comparison. Use typed variants + /// (F32, F64, etc.) to enable epsilon-based comparison for numeric types. + #[serde(default)] + pub output_type: OutputType, +} + +/// Specifies how test output data should be interpreted for comparison +#[derive(Debug, Serialize, Deserialize, Default, Clone, Copy, PartialEq)] +#[serde(rename_all = "snake_case")] +pub enum OutputType { + /// Exact byte-for-byte comparison (default) + #[default] + Raw, + /// Interpret as array of 32-bit floats, enables epsilon comparison + F32, + /// Interpret as array of 64-bit floats, enables epsilon comparison + F64, + /// Interpret as array of 32-bit unsigned integers + U32, + /// Interpret as array of 32-bit signed integers + I32, +} + +impl TestMetadata { + /// Create metadata with a specific epsilon value, keeping default output type + pub fn with_epsilon(epsilon: f32) -> Self { + Self { + epsilon: Some(epsilon), + ..Default::default() + } + } } impl Config { @@ -12,4 +62,11 @@ impl Config { let config = serde_json::from_str(&content)?; Ok(config) } + + /// Write test metadata to the configured metadata path + pub fn write_metadata(&self, metadata: &TestMetadata) -> anyhow::Result<()> { + let metadata_json = serde_json::to_string(metadata)?; + fs::write(&self.metadata_path, metadata_json)?; + Ok(()) + } } diff --git a/tests/difftests/lib/src/lib.rs b/tests/difftests/lib/src/lib.rs index 1cab2cc3b4..2ef8552858 100644 --- a/tests/difftests/lib/src/lib.rs +++ b/tests/difftests/lib/src/lib.rs @@ -5,14 +5,17 @@ pub mod config; #[cfg(not(target_arch = "spirv"))] pub mod scaffold; -/// Macro to round a f32 value to 6 decimal places for cross-platform consistency -/// in floating-point operations. This helps ensure difftest results are consistent -/// across different platforms (Linux, Mac, Windows) which may have slight differences -/// in floating-point implementations. +/// Macro to round a f32 value for cross-platform compatibility in floating-point +/// operations. This helps ensure difftest results are consistent across different +/// platforms (Linux, Mac, Windows) which may have slight differences in floating-point +/// implementations due to different FMA usage, operation ordering, etc. +/// +/// We round to 5 decimal places to handle differences that can appear in the 6th-7th +/// decimal places due to platform variations. #[macro_export] -macro_rules! round6 { +macro_rules! compat_round { ($v:expr) => { - (($v) * 1_000_000.0).round() / 1_000_000.0 + (($v) * 100_000.0).round() / 100_000.0 }; } @@ -25,9 +28,11 @@ mod tests { #[test] fn test_config_from_path() { let mut tmp = NamedTempFile::new().unwrap(); - let config_json = r#"{ "output_path": "/tmp/output.txt" }"#; + let config_json = + r#"{ "output_path": "/tmp/output.txt", "metadata_path": "/tmp/metadata.json" }"#; write!(tmp, "{config_json}").unwrap(); let config = Config::from_path(tmp.path()).unwrap(); assert_eq!(config.output_path.to_str().unwrap(), "/tmp/output.txt"); + assert_eq!(config.metadata_path.to_str().unwrap(), "/tmp/metadata.json"); } } diff --git a/tests/difftests/tests/lang/core/ops/math_ops/math_ops-rust/src/lib.rs b/tests/difftests/tests/lang/core/ops/math_ops/math_ops-rust/src/lib.rs index 0af50c2eec..1703e2a036 100644 --- a/tests/difftests/tests/lang/core/ops/math_ops/math_ops-rust/src/lib.rs +++ b/tests/difftests/tests/lang/core/ops/math_ops/math_ops-rust/src/lib.rs @@ -1,6 +1,6 @@ #![no_std] -use difftest::round6; +use difftest::compat_round; #[allow(unused_imports)] use spirv_std::num_traits::Float; use spirv_std::spirv; @@ -25,27 +25,27 @@ pub fn main_cs( } // Basic arithmetic - output[base_offset + 0] = round6!(x + 1.5); - output[base_offset + 1] = round6!(x - 0.5); - output[base_offset + 2] = round6!(x * 2.0); - output[base_offset + 3] = round6!(x / 2.0); - output[base_offset + 4] = round6!(x % 3.0); + output[base_offset + 0] = compat_round!(x + 1.5); + output[base_offset + 1] = compat_round!(x - 0.5); + output[base_offset + 2] = compat_round!(x * 2.0); + output[base_offset + 3] = compat_round!(x / 2.0); + output[base_offset + 4] = compat_round!(x % 3.0); // Trigonometric functions (simplified for consistent results) - output[base_offset + 5] = round6!(x.sin()); - output[base_offset + 6] = round6!(x.cos()); - output[base_offset + 7] = round6!(x.tan().clamp(-10.0, 10.0)); + output[base_offset + 5] = compat_round!(x.sin()); + output[base_offset + 6] = compat_round!(x.cos()); + output[base_offset + 7] = compat_round!(x.tan().clamp(-10.0, 10.0)); output[base_offset + 8] = 0.0; output[base_offset + 9] = 0.0; - output[base_offset + 10] = round6!(x.atan()); + output[base_offset + 10] = compat_round!(x.atan()); // Exponential and logarithmic (simplified) - output[base_offset + 11] = round6!(x.exp().min(1e6)); - output[base_offset + 12] = round6!(if x > 0.0 { x.ln() } else { -10.0 }); - output[base_offset + 13] = round6!(x.abs().sqrt()); - output[base_offset + 14] = round6!(x.abs() * x.abs()); // Use multiplication instead of powf - output[base_offset + 15] = round6!(if x > 0.0 { x.log2() } else { -10.0 }); - output[base_offset + 16] = round6!(x.exp2().min(1e6)); + output[base_offset + 11] = compat_round!(x.exp().min(1048576.0)); // 2^20 + output[base_offset + 12] = compat_round!(if x > 0.0 { x.ln() } else { -10.0 }); + output[base_offset + 13] = compat_round!(x.abs().sqrt()); + output[base_offset + 14] = compat_round!(x.abs() * x.abs()); // Use multiplication instead of powf + output[base_offset + 15] = compat_round!(if x > 0.0 { x.log2() } else { -10.0 }); + output[base_offset + 16] = compat_round!(x.exp2().min(1048576.0)); // 2^20 output[base_offset + 17] = x.floor(); // floor/ceil/round are exact output[base_offset + 18] = x.ceil(); output[base_offset + 19] = x.round(); diff --git a/tests/difftests/tests/lang/core/ops/math_ops/math_ops-rust/src/main.rs b/tests/difftests/tests/lang/core/ops/math_ops/math_ops-rust/src/main.rs index 2034b36d85..f7a3cb3958 100644 --- a/tests/difftests/tests/lang/core/ops/math_ops/math_ops-rust/src/main.rs +++ b/tests/difftests/tests/lang/core/ops/math_ops/math_ops-rust/src/main.rs @@ -19,14 +19,14 @@ fn main() { 6 => -2.0, 7 => std::f32::consts::PI, 8 => std::f32::consts::E, - 9 => 10.0, - 10 => -10.0, + 9 => 3.0, + 10 => -3.0, 11 => 0.1, 12 => -0.1, - 13 => 100.0, - 14 => -100.0, + 13 => 4.0, + 14 => -4.0, 15 => 3.14159, - _ => (i as f32) * 0.7 - 5.0, + _ => (i as f32) * 0.1 - 1.5, }) .collect(); @@ -51,6 +51,13 @@ fn main() { buffers, ); + // Write metadata file + let metadata = difftest::config::TestMetadata { + epsilon: Some(2e-6), // Small epsilon for last-bit differences + output_type: difftest::config::OutputType::F32, + }; + config.write_metadata(&metadata).unwrap(); + test.run_test(&config).unwrap(); } diff --git a/tests/difftests/tests/lang/core/ops/math_ops/math_ops-wgsl/shader.wgsl b/tests/difftests/tests/lang/core/ops/math_ops/math_ops-wgsl/shader.wgsl index 2d7c2a4b98..c143e8a9df 100644 --- a/tests/difftests/tests/lang/core/ops/math_ops/math_ops-wgsl/shader.wgsl +++ b/tests/difftests/tests/lang/core/ops/math_ops/math_ops-wgsl/shader.wgsl @@ -4,9 +4,9 @@ var input: array; @group(0) @binding(1) var output: array; -// Helper function to round to 6 decimal places for cross-platform consistency -fn round6(v: f32) -> f32 { - return round(v * 1000000.0) / 1000000.0; +// Helper function to round to 5 decimal places for cross-platform compatibility +fn compat_round(v: f32) -> f32 { + return round(v * 100000.0) / 100000.0; } @compute @workgroup_size(32, 1, 1) @@ -25,27 +25,27 @@ fn main_cs(@builtin(global_invocation_id) global_id: vec3) { } // Basic arithmetic - output[base_offset + 0u] = round6(x + 1.5); - output[base_offset + 1u] = round6(x - 0.5); - output[base_offset + 2u] = round6(x * 2.0); - output[base_offset + 3u] = round6(x / 2.0); - output[base_offset + 4u] = round6(x % 3.0); + output[base_offset + 0u] = compat_round(x + 1.5); + output[base_offset + 1u] = compat_round(x - 0.5); + output[base_offset + 2u] = compat_round(x * 2.0); + output[base_offset + 3u] = compat_round(x / 2.0); + output[base_offset + 4u] = compat_round(x % 3.0); // Trigonometric functions (simplified for consistent results) - output[base_offset + 5u] = round6(sin(x)); - output[base_offset + 6u] = round6(cos(x)); - output[base_offset + 7u] = round6(clamp(tan(x), -10.0, 10.0)); + output[base_offset + 5u] = compat_round(sin(x)); + output[base_offset + 6u] = compat_round(cos(x)); + output[base_offset + 7u] = compat_round(clamp(tan(x), -10.0, 10.0)); output[base_offset + 8u] = 0.0; output[base_offset + 9u] = 0.0; - output[base_offset + 10u] = round6(atan(x)); + output[base_offset + 10u] = compat_round(atan(x)); // Exponential and logarithmic (simplified) - output[base_offset + 11u] = round6(min(exp(x), 1e6)); - output[base_offset + 12u] = round6(select(-10.0, log(x), x > 0.0)); - output[base_offset + 13u] = round6(sqrt(abs(x))); - output[base_offset + 14u] = round6(abs(x) * abs(x)); // Use multiplication instead of pow - output[base_offset + 15u] = round6(select(-10.0, log2(x), x > 0.0)); - output[base_offset + 16u] = round6(min(exp2(x), 1e6)); + output[base_offset + 11u] = compat_round(min(exp(x), 1048576.0)); // 2^20 + output[base_offset + 12u] = compat_round(select(-10.0, log(x), x > 0.0)); + output[base_offset + 13u] = compat_round(sqrt(abs(x))); + output[base_offset + 14u] = compat_round(abs(x) * abs(x)); // Use multiplication instead of pow + output[base_offset + 15u] = compat_round(select(-10.0, log2(x), x > 0.0)); + output[base_offset + 16u] = compat_round(min(exp2(x), 1048576.0)); // 2^20 output[base_offset + 17u] = floor(x); // floor/ceil/round are exact output[base_offset + 18u] = ceil(x); output[base_offset + 19u] = round(x); diff --git a/tests/difftests/tests/lang/core/ops/math_ops/math_ops-wgsl/src/main.rs b/tests/difftests/tests/lang/core/ops/math_ops/math_ops-wgsl/src/main.rs index 6e8bd6273d..09ecdd545c 100644 --- a/tests/difftests/tests/lang/core/ops/math_ops/math_ops-wgsl/src/main.rs +++ b/tests/difftests/tests/lang/core/ops/math_ops/math_ops-wgsl/src/main.rs @@ -18,14 +18,14 @@ fn main() { 6 => -2.0, 7 => std::f32::consts::PI, 8 => std::f32::consts::E, - 9 => 10.0, - 10 => -10.0, + 9 => 3.0, + 10 => -3.0, 11 => 0.1, 12 => -0.1, - 13 => 100.0, - 14 => -100.0, + 13 => 4.0, + 14 => -4.0, 15 => 3.14159, - _ => (i as f32) * 0.7 - 5.0, + _ => (i as f32) * 0.1 - 1.5, }) .collect(); @@ -50,5 +50,12 @@ fn main() { buffers, ); + // Write metadata file + let metadata = difftest::config::TestMetadata { + epsilon: Some(2e-6), // Small epsilon for last-bit differences + output_type: difftest::config::OutputType::F32, + }; + config.write_metadata(&metadata).unwrap(); + test.run_test(&config).unwrap(); } diff --git a/tests/difftests/tests/lang/core/ops/matrix_ops/matrix_ops-rust/src/lib.rs b/tests/difftests/tests/lang/core/ops/matrix_ops/matrix_ops-rust/src/lib.rs index 3450c7b4d7..2f527c134c 100644 --- a/tests/difftests/tests/lang/core/ops/matrix_ops/matrix_ops-rust/src/lib.rs +++ b/tests/difftests/tests/lang/core/ops/matrix_ops/matrix_ops-rust/src/lib.rs @@ -1,6 +1,6 @@ #![no_std] -use difftest::round6; +use difftest::compat_round; use spirv_std::glam::{Mat2, Mat3, Mat4, UVec3, Vec2, Vec3, Vec4}; #[allow(unused_imports)] use spirv_std::num_traits::Float; @@ -36,10 +36,10 @@ pub fn main_cs( // Mat2 multiplication let m2_mul = m2a * m2b; - output[base_offset + 0] = round6!(m2_mul.col(0).x); - output[base_offset + 1] = round6!(m2_mul.col(0).y); - output[base_offset + 2] = round6!(m2_mul.col(1).x); - output[base_offset + 3] = round6!(m2_mul.col(1).y); + output[base_offset + 0] = compat_round!(m2_mul.col(0).x); + output[base_offset + 1] = compat_round!(m2_mul.col(0).y); + output[base_offset + 2] = compat_round!(m2_mul.col(1).x); + output[base_offset + 3] = compat_round!(m2_mul.col(1).y); // Mat2 transpose let m2_transpose = m2a.transpose(); @@ -49,13 +49,13 @@ pub fn main_cs( output[base_offset + 7] = m2_transpose.col(1).y; // Mat2 determinant (with rounding for consistency) - output[base_offset + 8] = round6!(m2a.determinant()); + output[base_offset + 8] = compat_round!(m2a.determinant()); // Mat2 * Vec2 let v2 = Vec2::new(1.0, 2.0); let m2_v2 = m2a * v2; - output[base_offset + 9] = round6!(m2_v2.x); - output[base_offset + 10] = round6!(m2_v2.y); + output[base_offset + 9] = compat_round!(m2_v2.x); + output[base_offset + 10] = compat_round!(m2_v2.y); // Mat3 operations let m3a = Mat3::from_cols(Vec3::new(a, b, c), Vec3::new(b, c, d), Vec3::new(c, d, a)); @@ -63,15 +63,15 @@ pub fn main_cs( // Mat3 multiplication let m3_mul = m3a * m3b; - output[base_offset + 11] = round6!(m3_mul.col(0).x); - output[base_offset + 12] = round6!(m3_mul.col(0).y); - output[base_offset + 13] = round6!(m3_mul.col(0).z); - output[base_offset + 14] = round6!(m3_mul.col(1).x); - output[base_offset + 15] = round6!(m3_mul.col(1).y); - output[base_offset + 16] = round6!(m3_mul.col(1).z); - output[base_offset + 17] = round6!(m3_mul.col(2).x); - output[base_offset + 18] = round6!(m3_mul.col(2).y); - output[base_offset + 19] = round6!(m3_mul.col(2).z); + output[base_offset + 11] = compat_round!(m3_mul.col(0).x); + output[base_offset + 12] = compat_round!(m3_mul.col(0).y); + output[base_offset + 13] = compat_round!(m3_mul.col(0).z); + output[base_offset + 14] = compat_round!(m3_mul.col(1).x); + output[base_offset + 15] = compat_round!(m3_mul.col(1).y); + output[base_offset + 16] = compat_round!(m3_mul.col(1).z); + output[base_offset + 17] = compat_round!(m3_mul.col(2).x); + output[base_offset + 18] = compat_round!(m3_mul.col(2).y); + output[base_offset + 19] = compat_round!(m3_mul.col(2).z); // Mat3 transpose - store just diagonal elements let m3_transpose = m3a.transpose(); @@ -80,14 +80,14 @@ pub fn main_cs( output[base_offset + 22] = m3_transpose.col(2).z; // Mat3 determinant (with rounding for consistency) - output[base_offset + 23] = round6!(m3a.determinant()); + output[base_offset + 23] = compat_round!(m3a.determinant()); // Mat3 * Vec3 (with rounding for consistency) let v3 = Vec3::new(1.0, 2.0, 3.0); let m3_v3 = m3a * v3; - output[base_offset + 24] = round6!(m3_v3.x); - output[base_offset + 25] = round6!(m3_v3.y); - output[base_offset + 26] = round6!(m3_v3.z); + output[base_offset + 24] = compat_round!(m3_v3.x); + output[base_offset + 25] = compat_round!(m3_v3.y); + output[base_offset + 26] = compat_round!(m3_v3.z); // Mat4 operations let m4a = Mat4::from_cols( @@ -105,10 +105,10 @@ pub fn main_cs( // Mat4 multiplication (just store diagonal for brevity) let m4_mul = m4a * m4b; - output[base_offset + 27] = round6!(m4_mul.col(0).x); - output[base_offset + 28] = round6!(m4_mul.col(1).y); - output[base_offset + 29] = round6!(m4_mul.col(2).z); - output[base_offset + 30] = round6!(m4_mul.col(3).w); + output[base_offset + 27] = compat_round!(m4_mul.col(0).x); + output[base_offset + 28] = compat_round!(m4_mul.col(1).y); + output[base_offset + 29] = compat_round!(m4_mul.col(2).z); + output[base_offset + 30] = compat_round!(m4_mul.col(3).w); // Mat4 transpose (just store diagonal) let m4_transpose = m4a.transpose(); @@ -118,15 +118,15 @@ pub fn main_cs( output[base_offset + 34] = m4_transpose.col(3).w; // Mat4 determinant (with rounding for consistency) - output[base_offset + 35] = round6!(m4a.determinant()); + output[base_offset + 35] = compat_round!(m4a.determinant()); // Mat4 * Vec4 (with rounding for consistency) let v4 = Vec4::new(1.0, 2.0, 3.0, 4.0); let m4_v4 = m4a * v4; - output[base_offset + 36] = round6!(m4_v4.x); - output[base_offset + 37] = round6!(m4_v4.y); - output[base_offset + 38] = round6!(m4_v4.z); - output[base_offset + 39] = round6!(m4_v4.w); + output[base_offset + 36] = compat_round!(m4_v4.x); + output[base_offset + 37] = compat_round!(m4_v4.y); + output[base_offset + 38] = compat_round!(m4_v4.z); + output[base_offset + 39] = compat_round!(m4_v4.w); // Identity matrices output[base_offset + 40] = Mat2::IDENTITY.col(0).x; @@ -136,8 +136,8 @@ pub fn main_cs( // Matrix inverse if m2a.determinant().abs() > 0.0001 { let m2_inv = m2a.inverse(); - output[base_offset + 43] = round6!(m2_inv.col(0).x); - output[base_offset + 44] = round6!(m2_inv.col(1).y); + output[base_offset + 43] = compat_round!(m2_inv.col(0).x); + output[base_offset + 44] = compat_round!(m2_inv.col(1).y); } else { output[base_offset + 43] = 0.0; output[base_offset + 44] = 0.0; @@ -145,13 +145,13 @@ pub fn main_cs( // Matrix addition let m2_add = m2a + m2b; - output[base_offset + 45] = m2_add.col(0).x; - output[base_offset + 46] = m2_add.col(0).y; + output[base_offset + 45] = compat_round!(m2_add.col(0).x); + output[base_offset + 46] = compat_round!(m2_add.col(0).y); // Matrix scalar multiplication let m2_scale = m2a * 2.0; - output[base_offset + 47] = m2_scale.col(0).x; - output[base_offset + 48] = m2_scale.col(0).y; + output[base_offset + 47] = compat_round!(m2_scale.col(0).x); + output[base_offset + 48] = compat_round!(m2_scale.col(0).y); output[base_offset + 49] = 1.0; // Padding } diff --git a/tests/difftests/tests/lang/core/ops/matrix_ops/matrix_ops-rust/src/main.rs b/tests/difftests/tests/lang/core/ops/matrix_ops/matrix_ops-rust/src/main.rs index c6fd3a2aa7..da6752004f 100644 --- a/tests/difftests/tests/lang/core/ops/matrix_ops/matrix_ops-rust/src/main.rs +++ b/tests/difftests/tests/lang/core/ops/matrix_ops/matrix_ops-rust/src/main.rs @@ -51,6 +51,13 @@ fn main() { buffers, ); + // Write metadata file + let metadata = difftest::config::TestMetadata { + epsilon: Some(1e-5), // 1e-5 - appropriate for 5 decimal place rounding + output_type: difftest::config::OutputType::F32, + }; + config.write_metadata(&metadata).unwrap(); + test.run_test(&config).unwrap(); } diff --git a/tests/difftests/tests/lang/core/ops/matrix_ops/matrix_ops-wgsl/shader.wgsl b/tests/difftests/tests/lang/core/ops/matrix_ops/matrix_ops-wgsl/shader.wgsl index dc7bf997ee..0dc01f70e3 100644 --- a/tests/difftests/tests/lang/core/ops/matrix_ops/matrix_ops-wgsl/shader.wgsl +++ b/tests/difftests/tests/lang/core/ops/matrix_ops/matrix_ops-wgsl/shader.wgsl @@ -4,9 +4,9 @@ var input: array; @group(0) @binding(1) var output: array; -// Helper function to round to 6 decimal places for cross-platform consistency -fn round6(v: f32) -> f32 { - return round(v * 1000000.0) / 1000000.0; +// Helper function to round to 5 decimal places for cross-platform compatibility +fn compat_round(v: f32) -> f32 { + return round(v * 100000.0) / 100000.0; } @compute @workgroup_size(32, 1, 1) @@ -41,10 +41,10 @@ fn main_cs(@builtin(global_invocation_id) global_id: vec3) { // Mat2 multiplication let m2_mul = m2a * m2b; - output[base_offset + 0u] = round6(m2_mul[0].x); - output[base_offset + 1u] = round6(m2_mul[0].y); - output[base_offset + 2u] = round6(m2_mul[1].x); - output[base_offset + 3u] = round6(m2_mul[1].y); + output[base_offset + 0u] = compat_round(m2_mul[0].x); + output[base_offset + 1u] = compat_round(m2_mul[0].y); + output[base_offset + 2u] = compat_round(m2_mul[1].x); + output[base_offset + 3u] = compat_round(m2_mul[1].y); // Mat2 transpose let m2_transpose = transpose(m2a); @@ -54,13 +54,13 @@ fn main_cs(@builtin(global_invocation_id) global_id: vec3) { output[base_offset + 7u] = m2_transpose[1].y; // Mat2 determinant (with rounding for consistency) - output[base_offset + 8u] = round6(determinant(m2a)); + output[base_offset + 8u] = compat_round(determinant(m2a)); // Mat2 * Vec2 let v2 = vec2(1.0, 2.0); let m2_v2 = m2a * v2; - output[base_offset + 9u] = round6(m2_v2.x); - output[base_offset + 10u] = round6(m2_v2.y); + output[base_offset + 9u] = compat_round(m2_v2.x); + output[base_offset + 10u] = compat_round(m2_v2.y); // Mat3 operations let m3a = mat3x3( @@ -76,15 +76,15 @@ fn main_cs(@builtin(global_invocation_id) global_id: vec3) { // Mat3 multiplication let m3_mul = m3a * m3b; - output[base_offset + 11u] = round6(m3_mul[0].x); - output[base_offset + 12u] = round6(m3_mul[0].y); - output[base_offset + 13u] = round6(m3_mul[0].z); - output[base_offset + 14u] = round6(m3_mul[1].x); - output[base_offset + 15u] = round6(m3_mul[1].y); - output[base_offset + 16u] = round6(m3_mul[1].z); - output[base_offset + 17u] = round6(m3_mul[2].x); - output[base_offset + 18u] = round6(m3_mul[2].y); - output[base_offset + 19u] = round6(m3_mul[2].z); + output[base_offset + 11u] = compat_round(m3_mul[0].x); + output[base_offset + 12u] = compat_round(m3_mul[0].y); + output[base_offset + 13u] = compat_round(m3_mul[0].z); + output[base_offset + 14u] = compat_round(m3_mul[1].x); + output[base_offset + 15u] = compat_round(m3_mul[1].y); + output[base_offset + 16u] = compat_round(m3_mul[1].z); + output[base_offset + 17u] = compat_round(m3_mul[2].x); + output[base_offset + 18u] = compat_round(m3_mul[2].y); + output[base_offset + 19u] = compat_round(m3_mul[2].z); // Mat3 transpose let m3_transpose = transpose(m3a); @@ -93,14 +93,14 @@ fn main_cs(@builtin(global_invocation_id) global_id: vec3) { output[base_offset + 22u] = m3_transpose[2].z; // Mat3 determinant (with rounding for consistency) - output[base_offset + 23u] = round6(determinant(m3a)); + output[base_offset + 23u] = compat_round(determinant(m3a)); // Mat3 * Vec3 (with rounding for consistency) let v3 = vec3(1.0, 2.0, 3.0); let m3_v3 = m3a * v3; - output[base_offset + 24u] = round6(m3_v3.x); - output[base_offset + 25u] = round6(m3_v3.y); - output[base_offset + 26u] = round6(m3_v3.z); + output[base_offset + 24u] = compat_round(m3_v3.x); + output[base_offset + 25u] = compat_round(m3_v3.y); + output[base_offset + 26u] = compat_round(m3_v3.z); // Mat4 operations let m4a = mat4x4( @@ -118,10 +118,10 @@ fn main_cs(@builtin(global_invocation_id) global_id: vec3) { // Mat4 multiplication (just store diagonal for brevity) let m4_mul = m4a * m4b; - output[base_offset + 27u] = round6(m4_mul[0].x); - output[base_offset + 28u] = round6(m4_mul[1].y); - output[base_offset + 29u] = round6(m4_mul[2].z); - output[base_offset + 30u] = round6(m4_mul[3].w); + output[base_offset + 27u] = compat_round(m4_mul[0].x); + output[base_offset + 28u] = compat_round(m4_mul[1].y); + output[base_offset + 29u] = compat_round(m4_mul[2].z); + output[base_offset + 30u] = compat_round(m4_mul[3].w); // Mat4 transpose (just store diagonal) let m4_transpose = transpose(m4a); @@ -131,15 +131,15 @@ fn main_cs(@builtin(global_invocation_id) global_id: vec3) { output[base_offset + 34u] = m4_transpose[3].w; // Mat4 determinant (with rounding for consistency) - output[base_offset + 35u] = round6(determinant(m4a)); + output[base_offset + 35u] = compat_round(determinant(m4a)); // Mat4 * Vec4 (with rounding for consistency) let v4 = vec4(1.0, 2.0, 3.0, 4.0); let m4_v4 = m4a * v4; - output[base_offset + 36u] = round6(m4_v4.x); - output[base_offset + 37u] = round6(m4_v4.y); - output[base_offset + 38u] = round6(m4_v4.z); - output[base_offset + 39u] = round6(m4_v4.w); + output[base_offset + 36u] = compat_round(m4_v4.x); + output[base_offset + 37u] = compat_round(m4_v4.y); + output[base_offset + 38u] = compat_round(m4_v4.z); + output[base_offset + 39u] = compat_round(m4_v4.w); // Identity matrices output[base_offset + 40u] = mat2x2(vec2(1.0, 0.0), vec2(0.0, 1.0))[0].x; @@ -155,8 +155,8 @@ fn main_cs(@builtin(global_invocation_id) global_id: vec3) { vec2(m2a[1].y * inv_det, -m2a[0].y * inv_det), vec2(-m2a[1].x * inv_det, m2a[0].x * inv_det) ); - output[base_offset + 43u] = round6(m2_inv[0].x); - output[base_offset + 44u] = round6(m2_inv[1].y); + output[base_offset + 43u] = compat_round(m2_inv[0].x); + output[base_offset + 44u] = compat_round(m2_inv[1].y); } else { output[base_offset + 43u] = 0.0; output[base_offset + 44u] = 0.0; @@ -167,16 +167,16 @@ fn main_cs(@builtin(global_invocation_id) global_id: vec3) { m2a[0] + m2b[0], m2a[1] + m2b[1] ); - output[base_offset + 45u] = m2_add[0].x; - output[base_offset + 46u] = m2_add[0].y; + output[base_offset + 45u] = compat_round(m2_add[0].x); + output[base_offset + 46u] = compat_round(m2_add[0].y); // Matrix scalar multiplication let m2_scale = mat2x2( m2a[0] * 2.0, m2a[1] * 2.0 ); - output[base_offset + 47u] = m2_scale[0].x; - output[base_offset + 48u] = m2_scale[0].y; + output[base_offset + 47u] = compat_round(m2_scale[0].x); + output[base_offset + 48u] = compat_round(m2_scale[0].y); output[base_offset + 49u] = 1.0; // Padding } \ No newline at end of file diff --git a/tests/difftests/tests/lang/core/ops/matrix_ops/matrix_ops-wgsl/src/main.rs b/tests/difftests/tests/lang/core/ops/matrix_ops/matrix_ops-wgsl/src/main.rs index 7488803fdb..244d119609 100644 --- a/tests/difftests/tests/lang/core/ops/matrix_ops/matrix_ops-wgsl/src/main.rs +++ b/tests/difftests/tests/lang/core/ops/matrix_ops/matrix_ops-wgsl/src/main.rs @@ -50,5 +50,12 @@ fn main() { buffers, ); + // Write metadata file + let metadata = difftest::config::TestMetadata { + epsilon: Some(1e-5), // 1e-5 - appropriate for 5 decimal place rounding + output_type: difftest::config::OutputType::F32, + }; + config.write_metadata(&metadata).unwrap(); + test.run_test(&config).unwrap(); } diff --git a/tests/difftests/tests/lang/core/ops/vector_ops/vector_ops-rust/src/lib.rs b/tests/difftests/tests/lang/core/ops/vector_ops/vector_ops-rust/src/lib.rs index 2a53a8f760..f2bcf336f9 100644 --- a/tests/difftests/tests/lang/core/ops/vector_ops/vector_ops-rust/src/lib.rs +++ b/tests/difftests/tests/lang/core/ops/vector_ops/vector_ops-rust/src/lib.rs @@ -1,6 +1,6 @@ #![no_std] -use difftest::round6; +use difftest::compat_round; use spirv_std::glam::{UVec2, UVec3, UVec4, Vec2, Vec3, Vec4, Vec4Swizzles}; #[allow(unused_imports)] use spirv_std::num_traits::Float; @@ -34,47 +34,47 @@ pub fn main_cs( let v2a = Vec2::new(a, b); let v2b = Vec2::new(c, d); - output[base_offset + 0] = round6!(v2a.dot(v2b)); - output[base_offset + 1] = round6!(v2a.length()); - output[base_offset + 2] = round6!(v2a.distance(v2b)); + output[base_offset + 0] = compat_round!(v2a.dot(v2b)); + output[base_offset + 1] = compat_round!(v2a.length()); + output[base_offset + 2] = compat_round!(v2a.distance(v2b)); let v2_add = v2a + v2b; - output[base_offset + 3] = v2_add.x; - output[base_offset + 4] = v2_add.y; + output[base_offset + 3] = compat_round!(v2_add.x); + output[base_offset + 4] = compat_round!(v2_add.y); let v2_mul = v2a * 2.0; - output[base_offset + 5] = v2_mul.x; - output[base_offset + 6] = v2_mul.y; + output[base_offset + 5] = compat_round!(v2_mul.x); + output[base_offset + 6] = compat_round!(v2_mul.y); // Vec3 operations let v3a = Vec3::new(a, b, c); let v3b = Vec3::new(b, c, d); - output[base_offset + 7] = round6!(v3a.dot(v3b)); - output[base_offset + 8] = round6!(v3a.length()); + output[base_offset + 7] = compat_round!(v3a.dot(v3b)); + output[base_offset + 8] = compat_round!(v3a.length()); let v3_cross = v3a.cross(v3b); - output[base_offset + 9] = round6!(v3_cross.x); - output[base_offset + 10] = round6!(v3_cross.y); - output[base_offset + 11] = round6!(v3_cross.z); + output[base_offset + 9] = compat_round!(v3_cross.x); + output[base_offset + 10] = compat_round!(v3_cross.y); + output[base_offset + 11] = compat_round!(v3_cross.z); let v3_norm = v3a.normalize(); - output[base_offset + 12] = round6!(v3_norm.x); - output[base_offset + 13] = round6!(v3_norm.y); - output[base_offset + 14] = round6!(v3_norm.z); + output[base_offset + 12] = compat_round!(v3_norm.x); + output[base_offset + 13] = compat_round!(v3_norm.y); + output[base_offset + 14] = compat_round!(v3_norm.z); // Vec4 operations let v4a = Vec4::new(a, b, c, d); let v4b = Vec4::new(d, c, b, a); - output[base_offset + 15] = round6!(v4a.dot(v4b)); - output[base_offset + 16] = round6!(v4a.length()); + output[base_offset + 15] = compat_round!(v4a.dot(v4b)); + output[base_offset + 16] = compat_round!(v4a.length()); let v4_sub = v4a - v4b; - output[base_offset + 17] = v4_sub.x; - output[base_offset + 18] = v4_sub.y; - output[base_offset + 19] = v4_sub.z; - output[base_offset + 20] = v4_sub.w; + output[base_offset + 17] = compat_round!(v4_sub.x); + output[base_offset + 18] = compat_round!(v4_sub.y); + output[base_offset + 19] = compat_round!(v4_sub.z); + output[base_offset + 20] = compat_round!(v4_sub.w); // Swizzling output[base_offset + 21] = v4a.x; diff --git a/tests/difftests/tests/lang/core/ops/vector_ops/vector_ops-rust/src/main.rs b/tests/difftests/tests/lang/core/ops/vector_ops/vector_ops-rust/src/main.rs index d78d11bf17..7a8d4128f2 100644 --- a/tests/difftests/tests/lang/core/ops/vector_ops/vector_ops-rust/src/main.rs +++ b/tests/difftests/tests/lang/core/ops/vector_ops/vector_ops-rust/src/main.rs @@ -51,6 +51,13 @@ fn main() { buffers, ); + // Write metadata file + let metadata = difftest::config::TestMetadata { + epsilon: Some(1e-5), // 1e-5 - appropriate for 5 decimal place rounding + output_type: difftest::config::OutputType::F32, + }; + config.write_metadata(&metadata).unwrap(); + test.run_test(&config).unwrap(); } diff --git a/tests/difftests/tests/lang/core/ops/vector_ops/vector_ops-wgsl/shader.wgsl b/tests/difftests/tests/lang/core/ops/vector_ops/vector_ops-wgsl/shader.wgsl index 809ea276d1..6a58e409c7 100644 --- a/tests/difftests/tests/lang/core/ops/vector_ops/vector_ops-wgsl/shader.wgsl +++ b/tests/difftests/tests/lang/core/ops/vector_ops/vector_ops-wgsl/shader.wgsl @@ -4,9 +4,9 @@ var input: array; @group(0) @binding(1) var output: array; -// Helper function to round to 6 decimal places for cross-platform consistency -fn round6(v: f32) -> f32 { - return round(v * 1000000.0) / 1000000.0; +// Helper function to round to 5 decimal places for cross-platform compatibility +fn compat_round(v: f32) -> f32 { + return round(v * 100000.0) / 100000.0; } @compute @workgroup_size(32, 1, 1) @@ -33,47 +33,47 @@ fn main_cs(@builtin(global_invocation_id) global_id: vec3) { let v2a = vec2(a, b); let v2b = vec2(c, d); - output[base_offset + 0u] = round6(dot(v2a, v2b)); - output[base_offset + 1u] = round6(length(v2a)); - output[base_offset + 2u] = round6(distance(v2a, v2b)); + output[base_offset + 0u] = compat_round(dot(v2a, v2b)); + output[base_offset + 1u] = compat_round(length(v2a)); + output[base_offset + 2u] = compat_round(distance(v2a, v2b)); let v2_add = v2a + v2b; - output[base_offset + 3u] = v2_add.x; - output[base_offset + 4u] = v2_add.y; + output[base_offset + 3u] = compat_round(v2_add.x); + output[base_offset + 4u] = compat_round(v2_add.y); let v2_mul = v2a * 2.0; - output[base_offset + 5u] = v2_mul.x; - output[base_offset + 6u] = v2_mul.y; + output[base_offset + 5u] = compat_round(v2_mul.x); + output[base_offset + 6u] = compat_round(v2_mul.y); // Vec3 operations let v3a = vec3(a, b, c); let v3b = vec3(b, c, d); - output[base_offset + 7u] = round6(dot(v3a, v3b)); - output[base_offset + 8u] = round6(length(v3a)); + output[base_offset + 7u] = compat_round(dot(v3a, v3b)); + output[base_offset + 8u] = compat_round(length(v3a)); let v3_cross = cross(v3a, v3b); - output[base_offset + 9u] = round6(v3_cross.x); - output[base_offset + 10u] = round6(v3_cross.y); - output[base_offset + 11u] = round6(v3_cross.z); + output[base_offset + 9u] = compat_round(v3_cross.x); + output[base_offset + 10u] = compat_round(v3_cross.y); + output[base_offset + 11u] = compat_round(v3_cross.z); let v3_norm = normalize(v3a); - output[base_offset + 12u] = round6(v3_norm.x); - output[base_offset + 13u] = round6(v3_norm.y); - output[base_offset + 14u] = round6(v3_norm.z); + output[base_offset + 12u] = compat_round(v3_norm.x); + output[base_offset + 13u] = compat_round(v3_norm.y); + output[base_offset + 14u] = compat_round(v3_norm.z); // Vec4 operations let v4a = vec4(a, b, c, d); let v4b = vec4(d, c, b, a); - output[base_offset + 15u] = round6(dot(v4a, v4b)); - output[base_offset + 16u] = round6(length(v4a)); + output[base_offset + 15u] = compat_round(dot(v4a, v4b)); + output[base_offset + 16u] = compat_round(length(v4a)); let v4_sub = v4a - v4b; - output[base_offset + 17u] = v4_sub.x; - output[base_offset + 18u] = v4_sub.y; - output[base_offset + 19u] = v4_sub.z; - output[base_offset + 20u] = v4_sub.w; + output[base_offset + 17u] = compat_round(v4_sub.x); + output[base_offset + 18u] = compat_round(v4_sub.y); + output[base_offset + 19u] = compat_round(v4_sub.z); + output[base_offset + 20u] = compat_round(v4_sub.w); // Swizzling output[base_offset + 21u] = v4a.x; diff --git a/tests/difftests/tests/lang/core/ops/vector_ops/vector_ops-wgsl/src/main.rs b/tests/difftests/tests/lang/core/ops/vector_ops/vector_ops-wgsl/src/main.rs index fa33c6170c..b9a98fc170 100644 --- a/tests/difftests/tests/lang/core/ops/vector_ops/vector_ops-wgsl/src/main.rs +++ b/tests/difftests/tests/lang/core/ops/vector_ops/vector_ops-wgsl/src/main.rs @@ -50,5 +50,12 @@ fn main() { buffers, ); + // Write metadata file + let metadata = difftest::config::TestMetadata { + epsilon: Some(1e-5), // 1e-5 - appropriate for 5 decimal place rounding + output_type: difftest::config::OutputType::F32, + }; + config.write_metadata(&metadata).unwrap(); + test.run_test(&config).unwrap(); } From 91b402cce71a400a02afaeb4f29ea00d6059bc72 Mon Sep 17 00:00:00 2001 From: Christian Legnitto Date: Sun, 6 Jul 2025 17:26:40 +0200 Subject: [PATCH 04/24] Fix cargo deny by turning off unused tabled features --- Cargo.lock | 95 +++++++--------------------------- tests/difftests/bin/Cargo.toml | 2 +- 2 files changed, 21 insertions(+), 76 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 0002a7c444..090b7d431b 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -308,7 +308,7 @@ checksum = "3fa76293b4f7bb636ab88fd78228235b5248b4d05cc589aed610f954af5d7c7a" dependencies = [ "proc-macro2", "quote", - "syn 2.0.94", + "syn", ] [[package]] @@ -438,10 +438,10 @@ version = "4.5.32" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "09176aae279615badda0765c0c0b3f6ed53f4709118af73cf4655d85d1530cd7" dependencies = [ - "heck 0.5.0", + "heck", "proc-macro2", "quote", - "syn 2.0.94", + "syn", ] [[package]] @@ -701,7 +701,7 @@ dependencies = [ "proc-macro2", "quote", "rustc_version", - "syn 2.0.94", + "syn", ] [[package]] @@ -978,7 +978,7 @@ checksum = "1a5c6c585bc94aaf2c7b51dd4c2ba22680844aba4c687be581871a6f518c5742" dependencies = [ "proc-macro2", "quote", - "syn 2.0.94", + "syn", ] [[package]] @@ -1052,7 +1052,7 @@ checksum = "162ee34ebcb7c64a8abebc059ce0fee27c2262618d7b60ed8faf72fef13c3650" dependencies = [ "proc-macro2", "quote", - "syn 2.0.94", + "syn", ] [[package]] @@ -1251,12 +1251,6 @@ dependencies = [ "foldhash", ] -[[package]] -name = "heck" -version = "0.4.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "95505c38b4572b2d910cecb0281560f54b440a19336cbbcb27bf6ce6adc6f5a8" - [[package]] name = "heck" version = "0.5.0" @@ -1804,7 +1798,7 @@ dependencies = [ "proc-macro-crate", "proc-macro2", "quote", - "syn 2.0.94", + "syn", ] [[package]] @@ -2151,7 +2145,7 @@ checksum = "3c0f5fad0874fc7abcd4d750e76917eaebbecaa2c20bde22e1dbeeba8beb758c" dependencies = [ "proc-macro2", "quote", - "syn 2.0.94", + "syn", ] [[package]] @@ -2237,30 +2231,6 @@ dependencies = [ "toml_edit", ] -[[package]] -name = "proc-macro-error" -version = "1.0.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "da25490ff9892aab3fcf7c36f08cfb902dd3e71ca0f9f9517bea02a73a5ce38c" -dependencies = [ - "proc-macro-error-attr", - "proc-macro2", - "quote", - "syn 1.0.109", - "version_check", -] - -[[package]] -name = "proc-macro-error-attr" -version = "1.0.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a1be40180e52ecc98ad80b184934baf3d0d29f979574e439af5a55274b35f869" -dependencies = [ - "proc-macro2", - "quote", - "version_check", -] - [[package]] name = "proc-macro2" version = "1.0.92" @@ -2671,7 +2641,7 @@ checksum = "5a9bf7cf98d04a2b28aead066b7496853d4779c9cc183c440dbac457641e19a0" dependencies = [ "proc-macro2", "quote", - "syn 2.0.94", + "syn", ] [[package]] @@ -2864,7 +2834,7 @@ dependencies = [ "proc-macro2", "quote", "spirv-std-types", - "syn 2.0.94", + "syn", ] [[package]] @@ -2930,22 +2900,11 @@ version = "0.26.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "4c6bee85a5a24955dc440386795aa378cd9cf82acd5f764469152d2270e581be" dependencies = [ - "heck 0.5.0", + "heck", "proc-macro2", "quote", "rustversion", - "syn 2.0.94", -] - -[[package]] -name = "syn" -version = "1.0.109" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "72b64191b275b66ffe2469e8af2c1cfe3bafa67b529ead792a6d0160888b4237" -dependencies = [ - "proc-macro2", - "quote", - "unicode-ident", + "syn", ] [[package]] @@ -2966,23 +2925,9 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "4c998b0c8b921495196a48aabaf1901ff28be0760136e31604f7967b0792050e" dependencies = [ "papergrid", - "tabled_derive", "unicode-width", ] -[[package]] -name = "tabled_derive" -version = "0.7.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4c138f99377e5d653a371cdad263615634cfc8467685dfe8e73e2b8e98f44b17" -dependencies = [ - "heck 0.4.1", - "proc-macro-error", - "proc-macro2", - "quote", - "syn 1.0.109", -] - [[package]] name = "tempfile" version = "3.14.0" @@ -3055,7 +3000,7 @@ checksum = "4fee6c4efc90059e10f81e6d42c60a18f76588c3d74cb83a0b242a2b6c7504c1" dependencies = [ "proc-macro2", "quote", - "syn 2.0.94", + "syn", ] [[package]] @@ -3066,7 +3011,7 @@ checksum = "7f7cf42b4507d8ea322120659672cf1b9dbb93f8f2d4ecfd6e51350ff5b17a1d" dependencies = [ "proc-macro2", "quote", - "syn 2.0.94", + "syn", ] [[package]] @@ -3207,7 +3152,7 @@ checksum = "395ae124c09f9e6918a2310af6038fba074bcf474ac352496d5910dd59a2226d" dependencies = [ "proc-macro2", "quote", - "syn 2.0.94", + "syn", ] [[package]] @@ -3369,7 +3314,7 @@ dependencies = [ "log", "proc-macro2", "quote", - "syn 2.0.94", + "syn", "wasm-bindgen-shared", ] @@ -3404,7 +3349,7 @@ checksum = "30d7a95b763d3c45903ed6c81f156801839e5ee968bb07e534c44df0fcd330c2" dependencies = [ "proc-macro2", "quote", - "syn 2.0.94", + "syn", "wasm-bindgen-backend", "wasm-bindgen-shared", ] @@ -3836,7 +3781,7 @@ checksum = "2bbd5b46c938e506ecbce286b6628a02171d56153ba733b6c741fc627ec9579b" dependencies = [ "proc-macro2", "quote", - "syn 2.0.94", + "syn", ] [[package]] @@ -3847,7 +3792,7 @@ checksum = "053c4c462dc91d3b1504c6fe5a726dd15e216ba718e84a0e46a88fbe5ded3515" dependencies = [ "proc-macro2", "quote", - "syn 2.0.94", + "syn", ] [[package]] @@ -4231,5 +4176,5 @@ checksum = "fa4f8080344d4671fb4e831a13ad1e68092748387dfc4f55e356242fae12ce3e" dependencies = [ "proc-macro2", "quote", - "syn 2.0.94", + "syn", ] diff --git a/tests/difftests/bin/Cargo.toml b/tests/difftests/bin/Cargo.toml index e62cc51049..a46a09e561 100644 --- a/tests/difftests/bin/Cargo.toml +++ b/tests/difftests/bin/Cargo.toml @@ -25,7 +25,7 @@ toml = { version = "0.8.20", default-features = false, features = ["parse"] } bytesize = "2.0.1" bytemuck = "1.21.0" difftest = { path = "../lib" } -tabled = "0.15" +tabled = { version = "0.15", default-features = false, features = ["std"] } [lints] workspace = true From 277d764ebc84cbb15d958c918ba4ea6934aa5a3f Mon Sep 17 00:00:00 2001 From: Christian Legnitto Date: Sun, 6 Jul 2025 17:31:33 +0200 Subject: [PATCH 05/24] Use Default impl --- tests/difftests/bin/src/runner.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/difftests/bin/src/runner.rs b/tests/difftests/bin/src/runner.rs index e3990ff304..9e74e6b46c 100644 --- a/tests/difftests/bin/src/runner.rs +++ b/tests/difftests/bin/src/runner.rs @@ -383,7 +383,7 @@ impl Runner { return Err(RunnerError::EmptyOutput); } - let output_type = output_type.unwrap_or(OutputType::Raw); + let output_type = output_type.unwrap_or_default(); let groups = self.group_outputs(&pkg_outputs, epsilon, output_type); if groups.len() > 1 { let differ: Box = output_type.into(); From ded0c0f8bc9d75e1cc2409179843eaede16c763a Mon Sep 17 00:00:00 2001 From: Christian Legnitto Date: Sun, 6 Jul 2025 18:33:49 +0200 Subject: [PATCH 06/24] Switch to static DXC --- Cargo.lock | 7 +++++++ tests/difftests/lib/Cargo.toml | 2 +- tests/difftests/lib/src/scaffold/compute/wgpu.rs | 9 ++++++++- tests/difftests/tests/Cargo.lock | 7 +++++++ 4 files changed, 23 insertions(+), 2 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 090b7d431b..12c6f88263 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1501,6 +1501,12 @@ version = "0.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b3bd0dd2cd90571056fdb71f6275fada10131182f84899f4b2a916e565d81d86" +[[package]] +name = "mach-dxcompiler-rs" +version = "0.1.4+2024.11.22-df583a3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0e3cd67e8ea2ba061339150970542cf1c60ba44c6d17e31279cbc133a4b018f8" + [[package]] name = "malloc_buf" version = "0.0.6" @@ -3685,6 +3691,7 @@ dependencies = [ "libc", "libloading", "log", + "mach-dxcompiler-rs", "metal", "naga", "ndk-sys 0.5.0+25.2.9519653", diff --git a/tests/difftests/lib/Cargo.toml b/tests/difftests/lib/Cargo.toml index fb6cd75f14..31d12d00c7 100644 --- a/tests/difftests/lib/Cargo.toml +++ b/tests/difftests/lib/Cargo.toml @@ -19,7 +19,7 @@ use-compiled-tools = [ spirv-builder.workspace = true serde = { version = "1.0", features = ["derive"] } serde_json = "1.0" -wgpu = { version = "25.0.2", features = ["spirv", "vulkan-portability"] } +wgpu = { version = "25.0.2", features = ["spirv", "vulkan-portability", "static-dxc"] } tempfile = "3.5" futures = "0.3.31" bytemuck = "1.21.0" diff --git a/tests/difftests/lib/src/scaffold/compute/wgpu.rs b/tests/difftests/lib/src/scaffold/compute/wgpu.rs index 20b3491a44..7d0046066b 100644 --- a/tests/difftests/lib/src/scaffold/compute/wgpu.rs +++ b/tests/difftests/lib/src/scaffold/compute/wgpu.rs @@ -163,7 +163,14 @@ where #[cfg(not(target_os = "linux"))] backends: wgpu::Backends::PRIMARY, flags: Default::default(), - backend_options: Default::default(), + backend_options: wgpu::BackendOptions { + #[cfg(target_os = "windows")] + dx12: wgpu::Dx12BackendOptions { + shader_compiler: wgpu::Dx12Compiler::StaticDxc, + ..Default::default() + }, + ..Default::default() + }, }); let adapter = instance .request_adapter(&wgpu::RequestAdapterOptions { diff --git a/tests/difftests/tests/Cargo.lock b/tests/difftests/tests/Cargo.lock index c1cc81b10f..9db57ba794 100644 --- a/tests/difftests/tests/Cargo.lock +++ b/tests/difftests/tests/Cargo.lock @@ -685,6 +685,12 @@ version = "0.4.26" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "30bde2b3dc3671ae49d8e2e9f044c7c005836e7a023ee57cffa25ab82764bb9e" +[[package]] +name = "mach-dxcompiler-rs" +version = "0.1.4+2024.11.22-df583a3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0e3cd67e8ea2ba061339150970542cf1c60ba44c6d17e31279cbc133a4b018f8" + [[package]] name = "malloc_buf" version = "0.0.6" @@ -1552,6 +1558,7 @@ dependencies = [ "libc", "libloading", "log", + "mach-dxcompiler-rs", "metal", "naga", "ndk-sys", From 7cfcc2eec3c27197ffb52a86f7fb4fed94e46e65 Mon Sep 17 00:00:00 2001 From: Christian Legnitto Date: Sun, 6 Jul 2025 22:07:20 +0200 Subject: [PATCH 07/24] Add support for U32 Also don't show stats/counts when delaing with raw diffs --- tests/difftests/bin/src/differ.rs | 359 ++++++++++++++---- tests/difftests/bin/src/runner.rs | 46 ++- .../workgroup_memory-rust/src/main.rs | 7 + .../workgroup_memory-wgsl/src/main.rs | 7 + 4 files changed, 328 insertions(+), 91 deletions(-) diff --git a/tests/difftests/bin/src/differ.rs b/tests/difftests/bin/src/differ.rs index 5e74ffe642..38b39eb250 100644 --- a/tests/difftests/bin/src/differ.rs +++ b/tests/difftests/bin/src/differ.rs @@ -1,4 +1,14 @@ use difftest::config::OutputType; +use std::marker::PhantomData; + +/// Represents the magnitude of a difference between two values +#[derive(Debug, Clone)] +pub enum DiffMagnitude { + /// A numeric difference that can be measured + Numeric(f64), + /// A difference that cannot be measured numerically (e.g., raw bytes) + Incomparable, +} /// Trait for comparing two outputs and producing differences pub trait OutputDiffer { @@ -33,8 +43,8 @@ pub struct Difference { pub index: usize, pub value1: String, pub value2: String, - pub absolute_diff: f64, - pub relative_diff: Option, + pub absolute_diff: DiffMagnitude, + pub relative_diff: DiffMagnitude, } /// Differ for raw byte comparison @@ -50,8 +60,8 @@ impl OutputDiffer for RawDiffer { index: 0, value1: format!("{} bytes", output1.len()), value2: format!("{} bytes", output2.len()), - absolute_diff: 0.0, - relative_diff: None, + absolute_diff: DiffMagnitude::Incomparable, + relative_diff: DiffMagnitude::Incomparable, }] } } @@ -91,50 +101,129 @@ impl DifferenceDisplay for RawDiffer { } } -/// Differ for f32 arrays -pub struct F32Differ; +/// Trait for numeric types that can be diffed +pub trait NumericType: Copy + PartialEq + std::fmt::Display + Send + Sync + 'static { + fn from_bytes(bytes: &[u8]) -> Self; + fn abs_diff(a: Self, b: Self) -> f64; + fn type_name() -> &'static str; + fn format_value(value: Self) -> String; + fn can_have_relative_diff() -> bool; + fn as_f64(value: Self) -> f64; +} -impl OutputDiffer for F32Differ { +impl NumericType for f32 { + fn from_bytes(bytes: &[u8]) -> Self { + f32::from_le_bytes([bytes[0], bytes[1], bytes[2], bytes[3]]) + } + fn abs_diff(a: Self, b: Self) -> f64 { + (a - b).abs() as f64 + } + fn type_name() -> &'static str { + "F32" + } + fn format_value(value: Self) -> String { + format!("{:.9}", value) + } + fn can_have_relative_diff() -> bool { + true + } + fn as_f64(value: Self) -> f64 { + value as f64 + } +} + +impl NumericType for u32 { + fn from_bytes(bytes: &[u8]) -> Self { + u32::from_le_bytes([bytes[0], bytes[1], bytes[2], bytes[3]]) + } + fn abs_diff(a: Self, b: Self) -> f64 { + if a > b { + (a - b) as f64 + } else { + (b - a) as f64 + } + } + fn type_name() -> &'static str { + "U32" + } + fn format_value(value: Self) -> String { + format!("{}", value) + } + fn can_have_relative_diff() -> bool { + false + } + fn as_f64(value: Self) -> f64 { + value as f64 + } +} + +/// Generic differ for numeric types +pub struct NumericDiffer { + _phantom: PhantomData, +} + +impl Default for NumericDiffer { + fn default() -> Self { + Self { + _phantom: PhantomData, + } + } +} + +impl OutputDiffer for NumericDiffer { fn compare(&self, output1: &[u8], output2: &[u8], epsilon: Option) -> Vec { if output1.len() != output2.len() { return vec![Difference { index: 0, value1: format!("{} bytes", output1.len()), value2: format!("{} bytes", output2.len()), - absolute_diff: 0.0, - relative_diff: None, + absolute_diff: DiffMagnitude::Numeric(0.0), + relative_diff: DiffMagnitude::Incomparable, }]; } - let floats1: Vec = output1 - .chunks_exact(4) - .map(|chunk| f32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]])) + let values1: Vec = output1 + .chunks_exact(std::mem::size_of::()) + .map(T::from_bytes) .collect(); - let floats2: Vec = output2 - .chunks_exact(4) - .map(|chunk| f32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]])) + let values2: Vec = output2 + .chunks_exact(std::mem::size_of::()) + .map(T::from_bytes) .collect(); let mut differences = Vec::new(); - - for (i, (&v1, &v2)) in floats1.iter().zip(floats2.iter()).enumerate() { - let diff = (v1 - v2).abs(); - // If no epsilon specified, report all differences - let threshold = epsilon.unwrap_or(0.0); - if diff > threshold { - let rel_diff = if v1.abs() > 1e-10 || v2.abs() > 1e-10 { - Some((diff as f64) / f64::max(v1.abs() as f64, v2.abs() as f64)) + for (i, (&v1, &v2)) in values1.iter().zip(values2.iter()).enumerate() { + if v1 != v2 { + let diff = T::abs_diff(v1, v2); + // For floating point types, check epsilon threshold + let threshold = if T::can_have_relative_diff() { + epsilon.unwrap_or(0.0) as f64 } else { - None + 0.0 }; - differences.push(Difference { - index: i, - value1: format!("{:.9}", v1), - value2: format!("{:.9}", v2), - absolute_diff: diff as f64, - relative_diff: rel_diff, - }); + if diff > threshold { + let rel_diff = if T::can_have_relative_diff() { + let v1_abs = T::as_f64(v1).abs(); + let v2_abs = T::as_f64(v2).abs(); + let max_abs = f64::max(v1_abs, v2_abs); + if max_abs > 1e-10 { + DiffMagnitude::Numeric(diff / max_abs) + } else { + DiffMagnitude::Incomparable + } + } else { + DiffMagnitude::Incomparable + }; + + differences.push(Difference { + index: i, + value1: T::format_value(v1), + value2: T::format_value(v2), + absolute_diff: DiffMagnitude::Numeric(diff), + relative_diff: rel_diff, + }); + } } } @@ -142,55 +231,61 @@ impl OutputDiffer for F32Differ { } fn name(&self) -> &'static str { - "32-bit Float" + T::type_name() } } -impl DifferenceDisplay for F32Differ { +impl DifferenceDisplay for NumericDiffer { fn format_table(&self, diffs: &[Difference], pkg1: &str, pkg2: &str) -> String { use tabled::settings::{Alignment, Modify, Style, object::Rows}; - if diffs.is_empty() { - return String::new(); - } - - // Extract suffix from package names (e.g., "math_ops-wgsl" -> "WGSL") - let extract_suffix = |pkg: &str| -> String { - if let Some(pos) = pkg.rfind('-') { - let suffix = &pkg[pos + 1..]; - suffix.to_uppercase() - } else { - pkg.to_string() - } - }; - - let col1_name = extract_suffix(pkg1); - let col2_name = extract_suffix(pkg2); - - // Build rows for the table let rows: Vec> = diffs .iter() .take(10) .map(|d| { - vec![ - d.index.to_string(), - d.value1.clone(), - d.value2.clone(), - format!("{:.3e}", d.absolute_diff), - d.relative_diff - .map(|r| format!("{:.2}%", r * 100.0)) - .unwrap_or_else(|| "N/A".to_string()), - ] + let abs_str = match &d.absolute_diff { + DiffMagnitude::Numeric(val) => { + if T::can_have_relative_diff() { + format!("{:.3e}", val) + } else { + format!("{}", *val as u64) + } + } + DiffMagnitude::Incomparable => "N/A".to_string(), + }; + + let rel_str = match &d.relative_diff { + DiffMagnitude::Numeric(val) => format!("{:.2}%", val * 100.0), + DiffMagnitude::Incomparable => "N/A".to_string(), + }; + + if T::can_have_relative_diff() { + vec![ + d.index.to_string(), + d.value1.clone(), + d.value2.clone(), + abs_str, + rel_str, + ] + } else { + vec![ + d.index.to_string(), + d.value1.clone(), + d.value2.clone(), + abs_str, + ] + } }) .collect(); - // Create a table with custom headers let mut builder = tabled::builder::Builder::default(); - // Add header row - builder.push_record(vec!["#", &col1_name, &col2_name, "Δ abs", "Δ %"]); + if T::can_have_relative_diff() { + builder.push_record(vec!["#", pkg1, pkg2, "Δ abs", "Δ %"]); + } else { + builder.push_record(vec!["#", pkg1, pkg2, "Δ"]); + } - // Add data rows for row in &rows { builder.push_record(row); } @@ -203,7 +298,6 @@ impl DifferenceDisplay for F32Differ { let mut result = table.to_string(); if diffs.len() > 10 { - // Get the width of the last line to properly align the text let last_line_width = result .lines() .last() @@ -233,26 +327,30 @@ impl DifferenceDisplay for F32Differ { use std::io::Write; let mut file = std::fs::File::create(path)?; - let floats: Vec = output - .chunks_exact(4) - .map(|chunk| f32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]])) + let values: Vec = output + .chunks_exact(std::mem::size_of::()) + .map(T::from_bytes) .collect(); - for (i, &value) in floats.iter().enumerate() { - writeln!(file, "{}: {:.9}", i, value)?; + for (i, &value) in values.iter().enumerate() { + writeln!(file, "{}: {}", i, T::format_value(value))?; } Ok(()) } } +/// Type aliases for specific numeric differs +pub type F32Differ = NumericDiffer; +pub type U32Differ = NumericDiffer; + impl From for Box { fn from(output_type: OutputType) -> Self { match output_type { OutputType::Raw => Box::new(RawDiffer), - OutputType::F32 => Box::new(F32Differ), + OutputType::F32 => Box::new(F32Differ::default()), OutputType::F64 => todo!("F64Differ not implemented yet"), - OutputType::U32 => todo!("U32Differ not implemented yet"), + OutputType::U32 => Box::new(U32Differ::default()), OutputType::I32 => todo!("I32Differ not implemented yet"), } } @@ -262,10 +360,123 @@ impl From for Box { fn from(output_type: OutputType) -> Self { match output_type { OutputType::Raw => Box::new(RawDiffer), - OutputType::F32 => Box::new(F32Differ), + OutputType::F32 => Box::new(F32Differ::default()), OutputType::F64 => todo!("F64Display not implemented yet"), - OutputType::U32 => todo!("U32Display not implemented yet"), + OutputType::U32 => Box::new(U32Differ::default()), OutputType::I32 => todo!("I32Display not implemented yet"), } } } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_u32_differ_no_differences() { + let differ = U32Differ::default(); + let data1 = vec![1u32, 2, 3, 4]; + let bytes1 = bytemuck::cast_slice(&data1); + let bytes2 = bytemuck::cast_slice(&data1); + + let diffs = differ.compare(bytes1, bytes2, None); + assert!(diffs.is_empty()); + } + + #[test] + fn test_u32_differ_with_differences() { + let differ = U32Differ::default(); + let data1 = vec![1u32, 2, 3, 4]; + let data2 = vec![1u32, 5, 3, 7]; + let bytes1 = bytemuck::cast_slice(&data1); + let bytes2 = bytemuck::cast_slice(&data2); + + let diffs = differ.compare(bytes1, bytes2, None); + assert_eq!(diffs.len(), 2); + + // Check first difference (index 1: 2 vs 5) + assert_eq!(diffs[0].index, 1); + assert_eq!(diffs[0].value1, "2"); + assert_eq!(diffs[0].value2, "5"); + match &diffs[0].absolute_diff { + DiffMagnitude::Numeric(val) => assert_eq!(*val, 3.0), + _ => panic!("Expected numeric difference"), + } + match &diffs[0].relative_diff { + DiffMagnitude::Incomparable => {} + _ => panic!("Expected incomparable relative diff for U32"), + } + + // Check second difference (index 3: 4 vs 7) + assert_eq!(diffs[1].index, 3); + assert_eq!(diffs[1].value1, "4"); + assert_eq!(diffs[1].value2, "7"); + match &diffs[1].absolute_diff { + DiffMagnitude::Numeric(val) => assert_eq!(*val, 3.0), + _ => panic!("Expected numeric difference"), + } + } + + #[test] + fn test_u32_differ_different_lengths() { + let differ = U32Differ::default(); + let data1 = vec![1u32, 2]; + let data2 = vec![1u32, 2, 3, 4]; + let bytes1 = bytemuck::cast_slice(&data1); + let bytes2 = bytemuck::cast_slice(&data2); + + let diffs = differ.compare(bytes1, bytes2, None); + assert_eq!(diffs.len(), 1); + assert_eq!(diffs[0].value1, "8 bytes"); + assert_eq!(diffs[0].value2, "16 bytes"); + } + + #[test] + fn test_f32_differ_with_epsilon() { + let differ = F32Differ::default(); + let data1 = vec![1.0f32, 2.0, 3.0]; + let data2 = vec![1.0001f32, 2.0, 3.01]; + let bytes1 = bytemuck::cast_slice(&data1); + let bytes2 = bytemuck::cast_slice(&data2); + + // With epsilon = 0.001, only the third value should be reported + let diffs = differ.compare(bytes1, bytes2, Some(0.001)); + assert_eq!(diffs.len(), 1); + assert_eq!(diffs[0].index, 2); + + // Without epsilon, both differences should be reported + let diffs = differ.compare(bytes1, bytes2, None); + assert_eq!(diffs.len(), 2); + } + + #[test] + fn test_raw_differ() { + let differ = RawDiffer; + let bytes1 = b"hello"; + let bytes2 = b"world"; + + let diffs = differ.compare(bytes1, bytes2, None); + assert_eq!(diffs.len(), 1); + match &diffs[0].absolute_diff { + DiffMagnitude::Incomparable => {} + _ => panic!("Expected incomparable diff for raw bytes"), + } + } + + #[test] + fn test_diff_magnitude_enum() { + // Test that we can create and match on DiffMagnitude variants + let numeric = DiffMagnitude::Numeric(42.0); + let incomparable = DiffMagnitude::Incomparable; + + match numeric { + DiffMagnitude::Numeric(val) => assert_eq!(val, 42.0), + _ => panic!("Expected numeric"), + } + + match incomparable { + DiffMagnitude::Incomparable => {} + _ => panic!("Expected incomparable"), + } + } +} diff --git a/tests/difftests/bin/src/runner.rs b/tests/difftests/bin/src/runner.rs index 9e74e6b46c..b2cf6cd506 100644 --- a/tests/difftests/bin/src/runner.rs +++ b/tests/difftests/bin/src/runner.rs @@ -12,7 +12,7 @@ use tempfile::NamedTempFile; use thiserror::Error; use tracing::{debug, error, info, trace}; -use crate::differ::{Difference, DifferenceDisplay, OutputDiffer}; +use crate::differ::{DiffMagnitude, Difference, DifferenceDisplay, OutputDiffer}; #[derive(Debug, Error)] pub enum RunnerError { @@ -103,11 +103,13 @@ impl ErrorReport { fn set_summary_from_differences(&mut self, differences: &[Difference]) { if !differences.is_empty() { - let (max_diff, max_rel) = self.calculate_max_differences(differences); self.summary_parts .push(format!("{} differences", differences.len())); - self.summary_parts - .push(format!("max: {:.3e} ({:.2}%)", max_diff, max_rel * 100.0)); + + if let Some((max_diff, max_rel)) = self.calculate_max_differences(differences) { + self.summary_parts + .push(format!("max: {:.3e} ({:.2}%)", max_diff, max_rel * 100.0)); + } } } @@ -157,28 +159,38 @@ impl ErrorReport { fn add_summary_line(&mut self, differences: &[Difference]) { if !differences.is_empty() { - let (max_diff, max_rel) = self.calculate_max_differences(differences); - self.lines.push(format!( - "• {} differences, max: {:.3e} ({:.2}%)", - differences.len(), - max_diff, - max_rel * 100.0 - )); + if let Some((max_diff, max_rel)) = self.calculate_max_differences(differences) { + self.lines.push(format!( + "• {} differences, max: {:.3e} ({:.2}%)", + differences.len(), + max_diff, + max_rel * 100.0 + )); + } else { + self.lines + .push(format!("• {} differences", differences.len())); + } } } - fn calculate_max_differences(&self, differences: &[Difference]) -> (f64, f64) { + fn calculate_max_differences(&self, differences: &[Difference]) -> Option<(f64, f64)> { let max_diff = differences .iter() - .map(|d| d.absolute_diff) - .max_by(|a, b| a.partial_cmp(b).unwrap()) - .unwrap_or(0.0); + .filter_map(|d| match &d.absolute_diff { + DiffMagnitude::Numeric(val) => Some(*val), + DiffMagnitude::Incomparable => None, + }) + .max_by(|a, b| a.partial_cmp(b).unwrap()); let max_rel = differences .iter() - .filter_map(|d| d.relative_diff) + .filter_map(|d| match &d.relative_diff { + DiffMagnitude::Numeric(val) => Some(*val), + DiffMagnitude::Incomparable => None, + }) .max_by(|a, b| a.partial_cmp(b).unwrap()) .unwrap_or(0.0); - (max_diff, max_rel) + + max_diff.map(|diff| (diff, max_rel)) } fn build(self) -> String { diff --git a/tests/difftests/tests/arch/workgroup_memory/workgroup_memory-rust/src/main.rs b/tests/difftests/tests/arch/workgroup_memory/workgroup_memory-rust/src/main.rs index cd20ab4765..f96dbca335 100644 --- a/tests/difftests/tests/arch/workgroup_memory/workgroup_memory-rust/src/main.rs +++ b/tests/difftests/tests/arch/workgroup_memory/workgroup_memory-rust/src/main.rs @@ -30,6 +30,13 @@ fn main() { buffers, ); + // Write metadata for U32 comparison + let metadata = difftest::config::TestMetadata { + epsilon: None, + output_type: difftest::config::OutputType::U32, + }; + config.write_metadata(&metadata).unwrap(); + test.run_test(&config).unwrap(); } diff --git a/tests/difftests/tests/arch/workgroup_memory/workgroup_memory-wgsl/src/main.rs b/tests/difftests/tests/arch/workgroup_memory/workgroup_memory-wgsl/src/main.rs index 1da92981a9..49989b6668 100644 --- a/tests/difftests/tests/arch/workgroup_memory/workgroup_memory-wgsl/src/main.rs +++ b/tests/difftests/tests/arch/workgroup_memory/workgroup_memory-wgsl/src/main.rs @@ -29,5 +29,12 @@ fn main() { buffers, ); + // Write metadata for U32 comparison + let metadata = difftest::config::TestMetadata { + epsilon: None, + output_type: difftest::config::OutputType::U32, + }; + config.write_metadata(&metadata).unwrap(); + test.run_test(&config).unwrap(); } From 3a946faa7ab2ac29e6dfc73f68a20d00274eae78 Mon Sep 17 00:00:00 2001 From: Christian Legnitto Date: Sun, 6 Jul 2025 22:12:12 +0200 Subject: [PATCH 08/24] Add final barrier to (hopefully) fix windows --- .../arch/workgroup_memory/workgroup_memory-rust/src/lib.rs | 4 ++++ .../arch/workgroup_memory/workgroup_memory-wgsl/shader.wgsl | 2 ++ 2 files changed, 6 insertions(+) diff --git a/tests/difftests/tests/arch/workgroup_memory/workgroup_memory-rust/src/lib.rs b/tests/difftests/tests/arch/workgroup_memory/workgroup_memory-rust/src/lib.rs index dc706b2477..9b0b6aef92 100644 --- a/tests/difftests/tests/arch/workgroup_memory/workgroup_memory-rust/src/lib.rs +++ b/tests/difftests/tests/arch/workgroup_memory/workgroup_memory-rust/src/lib.rs @@ -66,6 +66,10 @@ pub fn main_cs( shared[lid] += shared[lid + 1]; } + unsafe { + workgroup_memory_barrier_with_group_sync(); + } + // Write final result if lid == 0 { output[0] = shared[0]; diff --git a/tests/difftests/tests/arch/workgroup_memory/workgroup_memory-wgsl/shader.wgsl b/tests/difftests/tests/arch/workgroup_memory/workgroup_memory-wgsl/shader.wgsl index 7f12d99439..4f6018ddf0 100644 --- a/tests/difftests/tests/arch/workgroup_memory/workgroup_memory-wgsl/shader.wgsl +++ b/tests/difftests/tests/arch/workgroup_memory/workgroup_memory-wgsl/shader.wgsl @@ -52,6 +52,8 @@ fn main_cs(@builtin(local_invocation_id) local_id: vec3) { shared_data[lid] += shared_data[lid + 1u]; } + workgroupBarrier(); + // Write final result if (lid == 0u) { output[0] = shared_data[0]; From 3f547c9015ac1509a846f632b021929f45b0c374 Mon Sep 17 00:00:00 2001 From: Christian Legnitto Date: Mon, 7 Jul 2025 00:28:07 +0200 Subject: [PATCH 09/24] Add Skip and add a vulkano test to try to debug windows --- Cargo.lock | 249 ++++++++++++++--- tests/difftests/README.md | 40 +++ tests/difftests/bin/src/runner.rs | 12 + tests/difftests/lib/Cargo.toml | 2 + tests/difftests/lib/src/config.rs | 4 + .../lib/src/scaffold/compute/backend.rs | 84 ++++++ .../difftests/lib/src/scaffold/compute/mod.rs | 7 +- .../lib/src/scaffold/compute/vulkano.rs | 250 ++++++++++++++++++ .../lib/src/scaffold/compute/wgpu.rs | 190 ++++++++++++- tests/difftests/lib/src/scaffold/mod.rs | 3 + tests/difftests/lib/src/scaffold/skip.rs | 25 ++ tests/difftests/tests/Cargo.lock | 204 +++++++++++++- tests/difftests/tests/Cargo.toml | 1 + .../workgroup_memory-rust/src/lib.rs | 76 +----- .../workgroup_memory-rust/src/main.rs | 1 + .../workgroup_memory-rust/src/shader.rs | 75 ++++++ .../workgroup_memory-vulkano/Cargo.toml | 22 ++ .../workgroup_memory-vulkano/src/lib.rs | 8 + .../workgroup_memory-vulkano/src/main.rs | 88 ++++++ .../workgroup_memory-wgsl/src/main.rs | 1 + .../ops/math_ops/math_ops-rust/src/main.rs | 1 + .../ops/math_ops/math_ops-wgsl/src/main.rs | 1 + .../matrix_ops/matrix_ops-rust/src/main.rs | 1 + .../matrix_ops/matrix_ops-wgsl/src/main.rs | 1 + .../vector_ops/vector_ops-rust/src/main.rs | 1 + .../vector_ops/vector_ops-wgsl/src/main.rs | 1 + 26 files changed, 1218 insertions(+), 130 deletions(-) create mode 100644 tests/difftests/lib/src/scaffold/compute/backend.rs create mode 100644 tests/difftests/lib/src/scaffold/compute/vulkano.rs create mode 100644 tests/difftests/lib/src/scaffold/skip.rs create mode 100644 tests/difftests/tests/arch/workgroup_memory/workgroup_memory-rust/src/shader.rs create mode 100644 tests/difftests/tests/arch/workgroup_memory/workgroup_memory-vulkano/Cargo.toml create mode 100644 tests/difftests/tests/arch/workgroup_memory/workgroup_memory-vulkano/src/lib.rs create mode 100644 tests/difftests/tests/arch/workgroup_memory/workgroup_memory-vulkano/src/main.rs diff --git a/Cargo.lock b/Cargo.lock index 12c6f88263..4e898aee13 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -213,7 +213,7 @@ checksum = "52bca67b61cb81e5553babde81b8211f713cb6db79766f80168f3e5f40ea6c82" dependencies = [ "ash", "raw-window-handle 0.6.2", - "raw-window-metal", + "raw-window-metal 0.4.0", ] [[package]] @@ -276,7 +276,7 @@ version = "0.5.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2c132eebf10f5cad5289222520a4a058514204aed6d791f1cf4fe8088b82d15f" dependencies = [ - "objc2", + "objc2 0.5.2", ] [[package]] @@ -438,7 +438,7 @@ version = "4.5.32" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "09176aae279615badda0765c0c0b3f6ed53f4709118af73cf4655d85d1530cd7" dependencies = [ - "heck", + "heck 0.5.0", "proc-macro2", "quote", "syn", @@ -658,6 +658,15 @@ dependencies = [ "crossbeam-utils", ] +[[package]] +name = "crossbeam-queue" +version = "0.3.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0f58bbc28f91df819d0aa2a2c00cd19754769c2fad90579b3592b1c9ba7a3115" +dependencies = [ + "crossbeam-utils", +] + [[package]] name = "crossbeam-utils" version = "0.8.21" @@ -717,10 +726,12 @@ dependencies = [ "anyhow", "bytemuck", "futures", + "naga", "serde", "serde_json", "spirv-builder", "tempfile", + "vulkano", "wgpu", ] @@ -770,6 +781,16 @@ version = "0.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "bd0c93bb4b0c6d9b77f4435b0ae98c24d17f1c45b2ff844c6151a07256ca923b" +[[package]] +name = "dispatch2" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "89a09f22a6c6069a18470eb92d2298acf25463f14256d24778e1230d789a2aec" +dependencies = [ + "bitflags 2.9.1", + "objc2 0.6.1", +] + [[package]] name = "dlib" version = "0.5.2" @@ -1225,6 +1246,7 @@ version = "2.6.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "459196ed295495a68f7d7fe1d84f6c4b7ff0e21fe3017b2f283c6fac3ad803c9" dependencies = [ + "bytemuck", "cfg-if", "crunchy", "num-traits", @@ -1251,6 +1273,12 @@ dependencies = [ "foldhash", ] +[[package]] +name = "heck" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "95505c38b4572b2d910cecb0281560f54b440a19336cbbcb27bf6ce6adc6f5a8" + [[package]] name = "heck" version = "0.5.0" @@ -1590,6 +1618,12 @@ dependencies = [ "x11-dl", ] +[[package]] +name = "minimal-lexical" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "68354c5c6bd36d73ff3feceb05efa59b6acb7626617f4962be322a825e61f79a" + [[package]] name = "miniz_oxide" version = "0.8.2" @@ -1654,6 +1688,7 @@ dependencies = [ "num-traits", "once_cell", "petgraph", + "pp-rs", "rustc-hash", "spirv", "strum", @@ -1713,6 +1748,16 @@ dependencies = [ "memoffset", ] +[[package]] +name = "nom" +version = "7.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d273983c5a657a70a3e8f2a01329822f3b8c8172b73826411a55751e404a0a4a" +dependencies = [ + "memchr", + "minimal-lexical", +] + [[package]] name = "notify" version = "7.0.0" @@ -1832,6 +1877,15 @@ dependencies = [ "objc2-encode", ] +[[package]] +name = "objc2" +version = "0.6.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "88c6597e14493ab2e44ce58f2fdecf095a51f12ca57bec060a11c57332520551" +dependencies = [ + "objc2-encode", +] + [[package]] name = "objc2-app-kit" version = "0.2.2" @@ -1841,11 +1895,11 @@ dependencies = [ "bitflags 2.9.1", "block2", "libc", - "objc2", + "objc2 0.5.2", "objc2-core-data", "objc2-core-image", - "objc2-foundation", - "objc2-quartz-core", + "objc2-foundation 0.2.2", + "objc2-quartz-core 0.2.2", ] [[package]] @@ -1856,9 +1910,9 @@ checksum = "74dd3b56391c7a0596a295029734d3c1c5e7e510a4cb30245f8221ccea96b009" dependencies = [ "bitflags 2.9.1", "block2", - "objc2", + "objc2 0.5.2", "objc2-core-location", - "objc2-foundation", + "objc2-foundation 0.2.2", ] [[package]] @@ -1868,8 +1922,8 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a5ff520e9c33812fd374d8deecef01d4a840e7b41862d849513de77e44aa4889" dependencies = [ "block2", - "objc2", - "objc2-foundation", + "objc2 0.5.2", + "objc2-foundation 0.2.2", ] [[package]] @@ -1880,8 +1934,19 @@ checksum = "617fbf49e071c178c0b24c080767db52958f716d9eabdf0890523aeae54773ef" dependencies = [ "bitflags 2.9.1", "block2", - "objc2", - "objc2-foundation", + "objc2 0.5.2", + "objc2-foundation 0.2.2", +] + +[[package]] +name = "objc2-core-foundation" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1c10c2894a6fed806ade6027bcd50662746363a9589d3ec9d9bef30a4e4bc166" +dependencies = [ + "bitflags 2.9.1", + "dispatch2", + "objc2 0.6.1", ] [[package]] @@ -1891,9 +1956,9 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "55260963a527c99f1819c4f8e3b47fe04f9650694ef348ffd2227e8196d34c80" dependencies = [ "block2", - "objc2", - "objc2-foundation", - "objc2-metal", + "objc2 0.5.2", + "objc2-foundation 0.2.2", + "objc2-metal 0.2.2", ] [[package]] @@ -1903,16 +1968,16 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "000cfee34e683244f284252ee206a27953279d370e309649dc3ee317b37e5781" dependencies = [ "block2", - "objc2", + "objc2 0.5.2", "objc2-contacts", - "objc2-foundation", + "objc2-foundation 0.2.2", ] [[package]] name = "objc2-encode" -version = "4.0.3" +version = "4.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7891e71393cd1f227313c9379a26a584ff3d7e6e7159e988851f0934c993f0f8" +checksum = "ef25abbcd74fb2609453eb695bd2f860d389e457f67dc17cafc8b8cbc89d0c33" [[package]] name = "objc2-foundation" @@ -1924,7 +1989,18 @@ dependencies = [ "block2", "dispatch", "libc", - "objc2", + "objc2 0.5.2", +] + +[[package]] +name = "objc2-foundation" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "900831247d2fe1a09a683278e5384cfb8c80c79fe6b166f9d14bfdde0ea1b03c" +dependencies = [ + "bitflags 2.9.1", + "objc2 0.6.1", + "objc2-core-foundation", ] [[package]] @@ -1934,9 +2010,9 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a1a1ae721c5e35be65f01a03b6d2ac13a54cb4fa70d8a5da293d7b0020261398" dependencies = [ "block2", - "objc2", + "objc2 0.5.2", "objc2-app-kit", - "objc2-foundation", + "objc2-foundation 0.2.2", ] [[package]] @@ -1947,8 +2023,19 @@ checksum = "dd0cba1276f6023976a406a14ffa85e1fdd19df6b0f737b063b95f6c8c7aadd6" dependencies = [ "bitflags 2.9.1", "block2", - "objc2", - "objc2-foundation", + "objc2 0.5.2", + "objc2-foundation 0.2.2", +] + +[[package]] +name = "objc2-metal" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7f246c183239540aab1782457b35ab2040d4259175bd1d0c58e46ada7b47a874" +dependencies = [ + "bitflags 2.9.1", + "objc2 0.6.1", + "objc2-foundation 0.3.1", ] [[package]] @@ -1959,9 +2046,22 @@ checksum = "e42bee7bff906b14b167da2bac5efe6b6a07e6f7c0a21a7308d40c960242dc7a" dependencies = [ "bitflags 2.9.1", "block2", - "objc2", - "objc2-foundation", - "objc2-metal", + "objc2 0.5.2", + "objc2-foundation 0.2.2", + "objc2-metal 0.2.2", +] + +[[package]] +name = "objc2-quartz-core" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "90ffb6a0cd5f182dc964334388560b12a57f7b74b3e2dec5e2722aa2dfb2ccd5" +dependencies = [ + "bitflags 2.9.1", + "objc2 0.6.1", + "objc2-core-foundation", + "objc2-foundation 0.3.1", + "objc2-metal 0.3.1", ] [[package]] @@ -1970,8 +2070,8 @@ version = "0.2.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0a684efe3dec1b305badae1a28f6555f6ddd3bb2c2267896782858d5a78404dc" dependencies = [ - "objc2", - "objc2-foundation", + "objc2 0.5.2", + "objc2-foundation 0.2.2", ] [[package]] @@ -1982,14 +2082,14 @@ checksum = "b8bb46798b20cd6b91cbd113524c490f1686f4c4e8f49502431415f3512e2b6f" dependencies = [ "bitflags 2.9.1", "block2", - "objc2", + "objc2 0.5.2", "objc2-cloud-kit", "objc2-core-data", "objc2-core-image", "objc2-core-location", - "objc2-foundation", + "objc2-foundation 0.2.2", "objc2-link-presentation", - "objc2-quartz-core", + "objc2-quartz-core 0.2.2", "objc2-symbols", "objc2-uniform-type-identifiers", "objc2-user-notifications", @@ -2002,8 +2102,8 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "44fa5f9748dbfe1ca6c0b79ad20725a11eca7c2218bceb4b005cb1be26273bfe" dependencies = [ "block2", - "objc2", - "objc2-foundation", + "objc2 0.5.2", + "objc2-foundation 0.2.2", ] [[package]] @@ -2014,9 +2114,9 @@ checksum = "76cfcbf642358e8689af64cee815d139339f3ed8ad05103ed5eaf73db8d84cb3" dependencies = [ "bitflags 2.9.1", "block2", - "objc2", + "objc2 0.5.2", "objc2-core-location", - "objc2-foundation", + "objc2-foundation 0.2.2", ] [[package]] @@ -2212,6 +2312,15 @@ version = "0.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "439ee305def115ba05938db6eb1644ff94165c5ab5e9420d1c1bcedbba909391" +[[package]] +name = "pp-rs" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bb458bb7f6e250e6eb79d5026badc10a3ebb8f9a15d1fff0f13d17c71f4d6dee" +dependencies = [ + "unicode-xid", +] + [[package]] name = "presser" version = "0.3.1" @@ -2324,6 +2433,18 @@ dependencies = [ "raw-window-handle 0.6.2", ] +[[package]] +name = "raw-window-metal" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "40d213455a5f1dc59214213c7330e074ddf8114c9a42411eb890c767357ce135" +dependencies = [ + "objc2 0.6.1", + "objc2-core-foundation", + "objc2-foundation 0.3.1", + "objc2-quartz-core 0.3.1", +] + [[package]] name = "rayon" version = "1.10.0" @@ -2719,6 +2840,12 @@ dependencies = [ "autocfg", ] +[[package]] +name = "slabbin" +version = "1.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9db491c0d4152a069911a0fbdaca959691bf0b9d7110d98a7ed1c8e59b79ab30" + [[package]] name = "slotmap" version = "1.0.7" @@ -2906,7 +3033,7 @@ version = "0.26.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "4c6bee85a5a24955dc440386795aa378cd9cf82acd5f764469152d2270e581be" dependencies = [ - "heck", + "heck 0.5.0", "proc-macro2", "quote", "rustversion", @@ -3259,6 +3386,12 @@ version = "0.1.14" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7dd6e30e90baa6f72411720665d41d89b9a3d039dc45b8faea1ddd07f617f6af" +[[package]] +name = "unicode-xid" +version = "0.2.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ebc1c04c71510c7f702b52b7c350734c9ff1295c464a03335b00bb84fc54f853" + [[package]] name = "utf8parse" version = "0.2.2" @@ -3283,6 +3416,44 @@ version = "0.9.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0b928f33d975fc6ad9f86c8f283853ad26bdd5b10b7f1542aa2fa15e2289105a" +[[package]] +name = "vk-parse" +version = "0.15.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3859da4d7b98bec73e68fb65815d47a263819c415c90eed42b80440a02cbce8c" +dependencies = [ + "xml-rs", +] + +[[package]] +name = "vulkano" +version = "0.35.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "08840c2b51759a6f88f26f5ea378bc8b5c199a5b4760ddda292304be087249c4" +dependencies = [ + "ash", + "bytemuck", + "crossbeam-queue", + "foldhash", + "half", + "heck 0.4.1", + "indexmap", + "libloading", + "nom", + "once_cell", + "parking_lot", + "proc-macro2", + "quote", + "raw-window-handle 0.6.2", + "raw-window-metal 1.1.0", + "serde", + "serde_json", + "slabbin", + "smallvec", + "thread_local", + "vk-parse", +] + [[package]] name = "walkdir" version = "2.5.0" @@ -4058,9 +4229,9 @@ dependencies = [ "libc", "memmap2", "ndk", - "objc2", + "objc2 0.5.2", "objc2-app-kit", - "objc2-foundation", + "objc2-foundation 0.2.2", "objc2-ui-kit", "orbclient", "percent-encoding", diff --git a/tests/difftests/README.md b/tests/difftests/README.md index 9e282bbb89..d825f43f8a 100644 --- a/tests/difftests/README.md +++ b/tests/difftests/README.md @@ -99,12 +99,18 @@ The library provides helper types for common test patterns: - `WgpuComputeTestMultiBuffer` - Multi-buffer compute shader test with input/output separation - `WgpuComputeTestPushConstant` - Compute shader test with push constants support +- `Skip` - Marks a test variant as skipped with a reason **Shader source types:** - `RustComputeShader` - Compiles the current crate as a Rust GPU shader - `WgslComputeShader` - Loads WGSL shader from file (shader.wgsl or compute.wgsl) +**Backend types:** + +- `WgpuBackend` - Default wgpu-based compute backend +- `VulkanoBackend` - Vulkano-based compute backend (useful for testing different GPU drivers) + For examples, see: - [`tests/lang/core/ops/math_ops/`](tests/lang/core/ops/math_ops/) - Multi-buffer test @@ -205,6 +211,40 @@ The harness automatically writes human-readable `.txt` files alongside binary ou For floating-point data (F32/F64), these show the array values in decimal format. For raw/integer data, these show the values as hex bytes or integers +## Skipping Tests on Specific Platforms + +Sometimes a test variant needs to be skipped on certain platforms (e.g., due to driver +issues or platform limitations). The difftest framework provides a clean way to handle +this using the `Skip` scaffolding type: + +```rust +use difftest::scaffold::Skip; + +fn main() { + let config = Config::from_path(std::env::args().nth(1).unwrap()).unwrap(); + + // Skip on macOS due to platform-specific issues + #[cfg(target_os = "macos")] + { + let skip = Skip::new("This test is not supported on macOS"); + skip.run_test(&config).unwrap(); + return; + } + + // Run the actual test on other platforms + #[cfg(not(target_os = "macos"))] + { + // ... normal test implementation ... + } +} +``` + +When a test is skipped: +- The skip reason is recorded in the test metadata +- The test runner logs the skip reason +- The test doesn't contribute to the output comparison +- If all variants are skipped, the test fails with an error + ## Harness logs If you suspect a bug in the test harness, you can view detailed test harness logs: diff --git a/tests/difftests/bin/src/runner.rs b/tests/difftests/bin/src/runner.rs index b2cf6cd506..0abba0cdac 100644 --- a/tests/difftests/bin/src/runner.rs +++ b/tests/difftests/bin/src/runner.rs @@ -336,6 +336,12 @@ impl Runner { if !metadata_content.trim().is_empty() { match serde_json::from_str::(&metadata_content) { Ok(metadata) => { + // Check if test was skipped + if let Some(skip_reason) = &metadata.skipped { + info!("Package '{}' was skipped: {}", pkg_name, skip_reason); + continue; + } + if let Some(meta_epsilon) = metadata.epsilon { epsilon = match epsilon { Some(e) => Some(e.max(meta_epsilon)), @@ -390,6 +396,12 @@ impl Runner { }); } + // Check if we have any valid outputs + if pkg_outputs.is_empty() { + error!("All packages were skipped. At least one package must produce output."); + return Err(RunnerError::EmptyOutput); + } + if pkg_outputs.iter().all(|po| po.output.is_empty()) { error!("All packages produced empty output."); return Err(RunnerError::EmptyOutput); diff --git a/tests/difftests/lib/Cargo.toml b/tests/difftests/lib/Cargo.toml index 31d12d00c7..e0cfbd8483 100644 --- a/tests/difftests/lib/Cargo.toml +++ b/tests/difftests/lib/Cargo.toml @@ -20,6 +20,8 @@ spirv-builder.workspace = true serde = { version = "1.0", features = ["derive"] } serde_json = "1.0" wgpu = { version = "25.0.2", features = ["spirv", "vulkan-portability", "static-dxc"] } +vulkano = { version = "0.35.1", default-features = false } +naga = { version = "25.0.1", features = ["glsl-in", "spv-out"] } tempfile = "3.5" futures = "0.3.31" bytemuck = "1.21.0" diff --git a/tests/difftests/lib/src/config.rs b/tests/difftests/lib/src/config.rs index 51fd61d9f5..297d64f0ed 100644 --- a/tests/difftests/lib/src/config.rs +++ b/tests/difftests/lib/src/config.rs @@ -27,6 +27,10 @@ pub struct TestMetadata { /// (F32, F64, etc.) to enable epsilon-based comparison for numeric types. #[serde(default)] pub output_type: OutputType, + + /// If present, indicates this test was skipped with the given reason + #[serde(default, skip_serializing_if = "Option::is_none")] + pub skipped: Option, } /// Specifies how test output data should be interpreted for comparison diff --git a/tests/difftests/lib/src/scaffold/compute/backend.rs b/tests/difftests/lib/src/scaffold/compute/backend.rs new file mode 100644 index 0000000000..5099290615 --- /dev/null +++ b/tests/difftests/lib/src/scaffold/compute/backend.rs @@ -0,0 +1,84 @@ +use crate::config::Config; +use anyhow::Result; + +/// Configuration for a GPU buffer +#[derive(Clone)] +pub struct BufferConfig { + pub size: u64, + pub usage: BufferUsage, + pub initial_data: Option>, +} + +/// Buffer usage type +#[derive(Clone, Copy, PartialEq)] +pub enum BufferUsage { + Storage, + StorageReadOnly, + Uniform, +} + +/// A generic trait for compute backends +pub trait ComputeBackend: Sized { + /// Initialize the backend + fn init() -> Result; + + /// Create and run a compute shader with multiple buffers + fn run_compute( + &self, + spirv_bytes: &[u8], + entry_point: &str, + dispatch: [u32; 3], + buffers: Vec, + ) -> Result>>; +} + +/// A compute test that can run on any backend +pub struct ComputeTest { + backend: B, + spirv_bytes: Vec, + entry_point: String, + dispatch: [u32; 3], + buffers: Vec, +} + +impl ComputeTest { + pub fn new( + spirv_bytes: Vec, + entry_point: String, + dispatch: [u32; 3], + buffers: Vec, + ) -> Result { + Ok(Self { + backend: B::init()?, + spirv_bytes, + entry_point, + dispatch, + buffers, + }) + } + + pub fn run(self) -> Result>> { + self.backend.run_compute( + &self.spirv_bytes, + &self.entry_point, + self.dispatch, + self.buffers, + ) + } + + pub fn run_test(self, config: &Config) -> Result<()> { + let buffers = self.buffers.clone(); + let outputs = self.run()?; + // Write the first storage buffer output to the file + for (output, buffer_config) in outputs.iter().zip(&buffers) { + if matches!(buffer_config.usage, BufferUsage::Storage) && !output.is_empty() { + use std::fs::File; + use std::io::Write; + let mut f = File::create(&config.output_path)?; + f.write_all(output)?; + return Ok(()); + } + } + anyhow::bail!("No storage buffer output found") + } +} diff --git a/tests/difftests/lib/src/scaffold/compute/mod.rs b/tests/difftests/lib/src/scaffold/compute/mod.rs index 47dc6cafbb..a60e43b931 100644 --- a/tests/difftests/lib/src/scaffold/compute/mod.rs +++ b/tests/difftests/lib/src/scaffold/compute/mod.rs @@ -1,5 +1,10 @@ +mod backend; +mod vulkano; mod wgpu; + +pub use backend::{BufferConfig, BufferUsage, ComputeBackend, ComputeTest}; +pub use vulkano::VulkanoBackend; pub use wgpu::{ - BufferConfig, BufferUsage, RustComputeShader, WgpuComputeTest, WgpuComputeTestMultiBuffer, + RustComputeShader, WgpuBackend, WgpuComputeTest, WgpuComputeTestMultiBuffer, WgpuComputeTestPushConstants, WgslComputeShader, }; diff --git a/tests/difftests/lib/src/scaffold/compute/vulkano.rs b/tests/difftests/lib/src/scaffold/compute/vulkano.rs new file mode 100644 index 0000000000..974856bba3 --- /dev/null +++ b/tests/difftests/lib/src/scaffold/compute/vulkano.rs @@ -0,0 +1,250 @@ +use super::backend::{BufferConfig, BufferUsage, ComputeBackend}; +use anyhow::{Context, Result}; +use std::sync::Arc; +use vulkano::{ + VulkanLibrary, + buffer::{Buffer, BufferCreateInfo, BufferUsage as VkBufferUsage}, + command_buffer::{ + AutoCommandBufferBuilder, CommandBufferUsage, allocator::StandardCommandBufferAllocator, + }, + descriptor_set::{ + DescriptorSet, WriteDescriptorSet, allocator::StandardDescriptorSetAllocator, + }, + device::{ + Device, DeviceCreateInfo, DeviceFeatures, Queue, QueueCreateInfo, QueueFlags, + physical::PhysicalDeviceType, + }, + instance::{Instance, InstanceCreateFlags, InstanceCreateInfo}, + memory::allocator::{AllocationCreateInfo, MemoryTypeFilter, StandardMemoryAllocator}, + pipeline::{ + ComputePipeline, Pipeline, PipelineLayout, PipelineShaderStageCreateInfo, + compute::ComputePipelineCreateInfo, layout::PipelineDescriptorSetLayoutCreateInfo, + }, + shader::{ShaderModule, ShaderModuleCreateInfo}, + sync::{self, GpuFuture}, +}; + +pub struct VulkanoBackend { + device: Arc, + queue: Arc, + memory_allocator: Arc, + command_buffer_allocator: Arc, + descriptor_set_allocator: Arc, +} + +impl ComputeBackend for VulkanoBackend { + fn init() -> Result { + let library = VulkanLibrary::new()?; + + // Use the library's supported API version + let api_version = library.api_version(); + eprintln!( + "Vulkan library API version: {}.{}.{}", + api_version.major, api_version.minor, api_version.patch + ); + + let instance = Instance::new(library, InstanceCreateInfo { + flags: InstanceCreateFlags::ENUMERATE_PORTABILITY, + ..Default::default() + })?; + + // Pick a physical device + let physical_device = instance + .enumerate_physical_devices()? + .min_by_key(|p| match p.properties().device_type { + PhysicalDeviceType::DiscreteGpu => 0, + PhysicalDeviceType::IntegratedGpu => 1, + PhysicalDeviceType::VirtualGpu => 2, + PhysicalDeviceType::Cpu => 3, + PhysicalDeviceType::Other => 4, + _ => 5, + }) + .context("No suitable physical device found")?; + + // Find a compute queue + let queue_family_index = physical_device + .queue_family_properties() + .iter() + .enumerate() + .position(|(_, q)| q.queue_flags.intersects(QueueFlags::COMPUTE)) + .context("No compute queue family found")? as u32; + + // Check if vulkan_memory_model is supported + let supported_features = physical_device.supported_features(); + let mut enabled_features = DeviceFeatures::empty(); + if supported_features.vulkan_memory_model { + enabled_features.vulkan_memory_model = true; + } + + let (device, mut queues) = Device::new(physical_device, DeviceCreateInfo { + queue_create_infos: vec![QueueCreateInfo { + queue_family_index, + ..Default::default() + }], + enabled_features, + ..Default::default() + })?; + + let queue = queues.next().context("No queue returned")?; + + let memory_allocator = Arc::new(StandardMemoryAllocator::new_default(device.clone())); + let command_buffer_allocator = Arc::new(StandardCommandBufferAllocator::new( + device.clone(), + Default::default(), + )); + let descriptor_set_allocator = Arc::new(StandardDescriptorSetAllocator::new( + device.clone(), + Default::default(), + )); + + Ok(Self { + device, + queue, + memory_allocator, + command_buffer_allocator, + descriptor_set_allocator, + }) + } + + fn run_compute( + &self, + spirv_bytes: &[u8], + entry_point: &str, + dispatch: [u32; 3], + buffers: Vec, + ) -> Result>> { + // Convert bytes to u32 words + if spirv_bytes.len() % 4 != 0 { + anyhow::bail!("SPIR-V binary length is not a multiple of 4"); + } + let spirv_words: Vec = spirv_bytes + .chunks_exact(4) + .map(|chunk| u32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]])) + .collect(); + + // Create shader module + let shader = unsafe { + ShaderModule::new( + self.device.clone(), + ShaderModuleCreateInfo::new(&spirv_words), + )? + }; + + // Get the entry point + let entry_point = shader + .entry_point(entry_point) + .context("Entry point not found in shader module")?; + + // Create pipeline + let stage = PipelineShaderStageCreateInfo::new(entry_point); + let layout = PipelineLayout::new( + self.device.clone(), + PipelineDescriptorSetLayoutCreateInfo::from_stages([&stage]) + .into_pipeline_layout_create_info(self.device.clone())?, + )?; + + let compute_pipeline = ComputePipeline::new( + self.device.clone(), + None, + ComputePipelineCreateInfo::stage_layout(stage, layout), + )?; + + // Create buffers + let mut gpu_buffers = Vec::new(); + for buffer_config in buffers.iter() { + let usage = match buffer_config.usage { + BufferUsage::Storage => VkBufferUsage::STORAGE_BUFFER, + BufferUsage::StorageReadOnly => VkBufferUsage::STORAGE_BUFFER, + BufferUsage::Uniform => VkBufferUsage::UNIFORM_BUFFER, + }; + + let buffer = if let Some(initial_data) = &buffer_config.initial_data { + Buffer::from_iter( + self.memory_allocator.clone(), + BufferCreateInfo { + usage: usage | VkBufferUsage::TRANSFER_SRC | VkBufferUsage::TRANSFER_DST, + ..Default::default() + }, + AllocationCreateInfo { + memory_type_filter: MemoryTypeFilter::PREFER_DEVICE + | MemoryTypeFilter::HOST_SEQUENTIAL_WRITE, + ..Default::default() + }, + initial_data.iter().cloned(), + )? + } else { + // Zero initialize + let zeros = vec![0u8; buffer_config.size as usize]; + Buffer::from_iter( + self.memory_allocator.clone(), + BufferCreateInfo { + usage: usage | VkBufferUsage::TRANSFER_SRC | VkBufferUsage::TRANSFER_DST, + ..Default::default() + }, + AllocationCreateInfo { + memory_type_filter: MemoryTypeFilter::PREFER_DEVICE + | MemoryTypeFilter::HOST_SEQUENTIAL_WRITE, + ..Default::default() + }, + zeros, + )? + }; + + gpu_buffers.push(buffer); + } + + // Create descriptor set + let layout = compute_pipeline.layout().set_layouts()[0].clone(); + let mut writes = Vec::new(); + for (i, buffer) in gpu_buffers.iter().enumerate() { + writes.push(WriteDescriptorSet::buffer(i as u32, buffer.clone())); + } + + let descriptor_set = + DescriptorSet::new(self.descriptor_set_allocator.clone(), layout, writes, [])?; + + // Create command buffer + let mut builder = AutoCommandBufferBuilder::primary( + self.command_buffer_allocator.clone(), + self.queue.queue_family_index(), + CommandBufferUsage::OneTimeSubmit, + )?; + + unsafe { + builder + .bind_pipeline_compute(compute_pipeline.clone())? + .bind_descriptor_sets( + vulkano::pipeline::PipelineBindPoint::Compute, + compute_pipeline.layout().clone(), + 0, + descriptor_set, + )? + .dispatch(dispatch)?; + } + + let command_buffer = builder.build()?; + + // Execute + let future = sync::now(self.device.clone()) + .then_execute(self.queue.clone(), command_buffer)? + .then_signal_fence_and_flush()?; + + future.wait(None)?; + + // Read back results + let mut results = Vec::new(); + for (i, buffer_config) in buffers.iter().enumerate() { + if matches!( + buffer_config.usage, + BufferUsage::Storage | BufferUsage::StorageReadOnly + ) { + let content_guard = gpu_buffers[i].read()?; + results.push(content_guard.to_vec()); + } else { + results.push(Vec::new()); + } + } + + Ok(results) + } +} diff --git a/tests/difftests/lib/src/scaffold/compute/wgpu.rs b/tests/difftests/lib/src/scaffold/compute/wgpu.rs index 7d0046066b..b92f3a7fa7 100644 --- a/tests/difftests/lib/src/scaffold/compute/wgpu.rs +++ b/tests/difftests/lib/src/scaffold/compute/wgpu.rs @@ -9,9 +9,15 @@ use std::{ fs::{self, File}, io::Write, path::PathBuf, + sync::Arc, }; use wgpu::{PipelineCompilationOptions, util::DeviceExt}; +use super::backend::{self, ComputeBackend}; + +pub type BufferConfig = backend::BufferConfig; +pub type BufferUsage = backend::BufferUsage; + /// Trait that creates a shader module and provides its entry point. pub trait ComputeShader { fn create_module( @@ -125,20 +131,6 @@ pub struct WgpuComputeTestPushConstants { push_constants_data: Vec, } -#[derive(Clone)] -pub struct BufferConfig { - pub size: u64, - pub usage: BufferUsage, - pub initial_data: Option>, -} - -#[derive(Clone, Copy, PartialEq)] -pub enum BufferUsage { - Storage, - StorageReadOnly, - Uniform, -} - impl WgpuComputeTest where S: ComputeShader, @@ -338,6 +330,176 @@ where } } +/// wgpu backend implementation for the generic ComputeBackend trait +pub struct WgpuBackend { + device: Arc, + queue: Arc, +} + +impl ComputeBackend for WgpuBackend { + fn init() -> anyhow::Result { + let (device, queue) = WgpuComputeTest::::init()?; + Ok(Self { + device: Arc::new(device), + queue: Arc::new(queue), + }) + } + + fn run_compute( + &self, + spirv_bytes: &[u8], + entry_point: &str, + dispatch: [u32; 3], + buffers: Vec, + ) -> anyhow::Result>> { + // Convert bytes to u32 words + if spirv_bytes.len() % 4 != 0 { + anyhow::bail!("SPIR-V binary length is not a multiple of 4"); + } + let spirv_words: Vec = bytemuck::cast_slice(spirv_bytes).to_vec(); + + // Create shader module + let module = self + .device + .create_shader_module(wgpu::ShaderModuleDescriptor { + label: Some("Compute Shader"), + source: wgpu::ShaderSource::SpirV(Cow::Owned(spirv_words)), + }); + + // Create pipeline + let pipeline = self + .device + .create_compute_pipeline(&wgpu::ComputePipelineDescriptor { + label: Some("Compute Pipeline"), + layout: None, + module: &module, + entry_point: Some(entry_point), + compilation_options: PipelineCompilationOptions::default(), + cache: None, + }); + + // Create buffers + let mut gpu_buffers = Vec::new(); + for (i, buffer_config) in buffers.iter().enumerate() { + let usage = match buffer_config.usage { + BufferUsage::Storage => wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_SRC, + BufferUsage::StorageReadOnly => { + wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_SRC + } + BufferUsage::Uniform => wgpu::BufferUsages::UNIFORM, + }; + + let buffer = if let Some(initial_data) = &buffer_config.initial_data { + self.device + .create_buffer_init(&wgpu::util::BufferInitDescriptor { + label: Some(&format!("Buffer {}", i)), + contents: initial_data, + usage, + }) + } else { + let buffer = self.device.create_buffer(&wgpu::BufferDescriptor { + label: Some(&format!("Buffer {}", i)), + size: buffer_config.size, + usage, + mapped_at_creation: true, + }); + { + // Zero the buffer + let initial_data = vec![0u8; buffer_config.size as usize]; + let mut mapping = buffer.slice(..).get_mapped_range_mut(); + mapping.copy_from_slice(&initial_data); + } + buffer.unmap(); + buffer + }; + gpu_buffers.push(buffer); + } + + // Create bind entries + let bind_entries: Vec<_> = gpu_buffers + .iter() + .enumerate() + .map(|(i, buffer)| wgpu::BindGroupEntry { + binding: i as u32, + resource: buffer.as_entire_binding(), + }) + .collect(); + + let bind_group = self.device.create_bind_group(&wgpu::BindGroupDescriptor { + layout: &pipeline.get_bind_group_layout(0), + entries: &bind_entries, + label: Some("Compute Bind Group"), + }); + + let mut encoder = self + .device + .create_command_encoder(&wgpu::CommandEncoderDescriptor { + label: Some("Compute Encoder"), + }); + + { + let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor { + label: Some("Compute Pass"), + timestamp_writes: Default::default(), + }); + pass.set_pipeline(&pipeline); + pass.set_bind_group(0, &bind_group, &[]); + pass.dispatch_workgroups(dispatch[0], dispatch[1], dispatch[2]); + } + + // Create staging buffers and copy results + let mut staging_buffers = Vec::new(); + for (i, buffer_config) in buffers.iter().enumerate() { + if matches!( + buffer_config.usage, + BufferUsage::Storage | BufferUsage::StorageReadOnly + ) { + let staging_buffer = self.device.create_buffer(&wgpu::BufferDescriptor { + label: Some(&format!("Staging Buffer {}", i)), + size: buffer_config.size, + usage: wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST, + mapped_at_creation: false, + }); + encoder.copy_buffer_to_buffer( + &gpu_buffers[i], + 0, + &staging_buffer, + 0, + buffer_config.size, + ); + staging_buffers.push(Some(staging_buffer)); + } else { + staging_buffers.push(None); + } + } + + self.queue.submit(Some(encoder.finish())); + + // Read back results + let mut results = Vec::new(); + for staging_buffer in staging_buffers.into_iter() { + if let Some(buffer) = staging_buffer { + let buffer_slice = buffer.slice(..); + let (sender, receiver) = futures::channel::oneshot::channel(); + buffer_slice.map_async(wgpu::MapMode::Read, move |res| { + let _ = sender.send(res); + }); + self.device.poll(wgpu::PollType::Wait)?; + block_on(receiver) + .context("mapping canceled")? + .context("mapping failed")?; + let data = buffer_slice.get_mapped_range().to_vec(); + buffer.unmap(); + results.push(data); + } else { + results.push(Vec::new()); + } + } + + Ok(results) + } +} + /// For WGSL, the code checks for "shader.wgsl" then "compute.wgsl". impl Default for WgslComputeShader { fn default() -> Self { diff --git a/tests/difftests/lib/src/scaffold/mod.rs b/tests/difftests/lib/src/scaffold/mod.rs index 6ad3d09910..91edf95419 100644 --- a/tests/difftests/lib/src/scaffold/mod.rs +++ b/tests/difftests/lib/src/scaffold/mod.rs @@ -1 +1,4 @@ pub mod compute; +pub mod skip; + +pub use skip::Skip; diff --git a/tests/difftests/lib/src/scaffold/skip.rs b/tests/difftests/lib/src/scaffold/skip.rs new file mode 100644 index 0000000000..88f3ab6173 --- /dev/null +++ b/tests/difftests/lib/src/scaffold/skip.rs @@ -0,0 +1,25 @@ +use crate::config::{Config, TestMetadata}; + +/// A scaffolding type for tests that should be skipped on certain platforms +pub struct Skip { + message: &'static str, +} + +impl Skip { + /// Create a Skip test with a reason message + pub fn new(message: &'static str) -> Self { + Self { message } + } + + /// Run the skip test - writes metadata indicating the test should be skipped + pub fn run_test(&self, config: &Config) -> anyhow::Result<()> { + // Write metadata indicating this test was skipped + let metadata = TestMetadata { + skipped: Some(self.message.to_string()), + ..Default::default() + }; + config.write_metadata(&metadata)?; + + Ok(()) + } +} diff --git a/tests/difftests/tests/Cargo.lock b/tests/difftests/tests/Cargo.lock index 9db57ba794..a4a7baf652 100644 --- a/tests/difftests/tests/Cargo.lock +++ b/tests/difftests/tests/Cargo.lock @@ -269,6 +269,21 @@ dependencies = [ "libc", ] +[[package]] +name = "crossbeam-queue" +version = "0.3.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0f58bbc28f91df819d0aa2a2c00cd19754769c2fad90579b3592b1c9ba7a3115" +dependencies = [ + "crossbeam-utils", +] + +[[package]] +name = "crossbeam-utils" +version = "0.8.21" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d0a5c400df2834b80a4c3327b3aad3a4c4cd4de0629063962b03235697506a28" + [[package]] name = "crunchy" version = "0.2.3" @@ -282,13 +297,25 @@ dependencies = [ "anyhow", "bytemuck", "futures", + "naga", "serde", "serde_json", "spirv-builder", "tempfile", + "vulkano", "wgpu", ] +[[package]] +name = "dispatch2" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "89a09f22a6c6069a18470eb92d2298acf25463f14256d24778e1230d789a2aec" +dependencies = [ + "bitflags 2.9.0", + "objc2", +] + [[package]] name = "document-features" version = "0.2.11" @@ -558,6 +585,7 @@ version = "2.6.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "459196ed295495a68f7d7fe1d84f6c4b7ff0e21fe3017b2f283c6fac3ad803c9" dependencies = [ + "bytemuck", "cfg-if", "crunchy", "num-traits", @@ -574,6 +602,12 @@ dependencies = [ "foldhash", ] +[[package]] +name = "heck" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "95505c38b4572b2d910cecb0281560f54b440a19336cbbcb27bf6ce6adc6f5a8" + [[package]] name = "heck" version = "0.5.0" @@ -770,6 +804,12 @@ dependencies = [ "paste", ] +[[package]] +name = "minimal-lexical" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "68354c5c6bd36d73ff3feceb05efa59b6acb7626617f4962be322a825e61f79a" + [[package]] name = "naga" version = "25.0.1" @@ -789,6 +829,7 @@ dependencies = [ "num-traits", "once_cell", "petgraph", + "pp-rs", "rustc-hash", "spirv", "strum", @@ -805,6 +846,16 @@ dependencies = [ "jni-sys", ] +[[package]] +name = "nom" +version = "7.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d273983c5a657a70a3e8f2a01329822f3b8c8172b73826411a55751e404a0a4a" +dependencies = [ + "memchr", + "minimal-lexical", +] + [[package]] name = "num-traits" version = "0.2.19" @@ -824,6 +875,67 @@ dependencies = [ "malloc_buf", ] +[[package]] +name = "objc2" +version = "0.6.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "88c6597e14493ab2e44ce58f2fdecf095a51f12ca57bec060a11c57332520551" +dependencies = [ + "objc2-encode", +] + +[[package]] +name = "objc2-core-foundation" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1c10c2894a6fed806ade6027bcd50662746363a9589d3ec9d9bef30a4e4bc166" +dependencies = [ + "bitflags 2.9.0", + "dispatch2", + "objc2", +] + +[[package]] +name = "objc2-encode" +version = "4.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ef25abbcd74fb2609453eb695bd2f860d389e457f67dc17cafc8b8cbc89d0c33" + +[[package]] +name = "objc2-foundation" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "900831247d2fe1a09a683278e5384cfb8c80c79fe6b166f9d14bfdde0ea1b03c" +dependencies = [ + "bitflags 2.9.0", + "objc2", + "objc2-core-foundation", +] + +[[package]] +name = "objc2-metal" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7f246c183239540aab1782457b35ab2040d4259175bd1d0c58e46ada7b47a874" +dependencies = [ + "bitflags 2.9.0", + "objc2", + "objc2-foundation", +] + +[[package]] +name = "objc2-quartz-core" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "90ffb6a0cd5f182dc964334388560b12a57f7b74b3e2dec5e2722aa2dfb2ccd5" +dependencies = [ + "bitflags 2.9.0", + "objc2", + "objc2-core-foundation", + "objc2-foundation", + "objc2-metal", +] + [[package]] name = "once_cell" version = "1.21.3" @@ -904,6 +1016,15 @@ version = "1.11.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f84267b20a16ea918e43c6a88433c2d54fa145c92a811b5b047ccbe153674483" +[[package]] +name = "pp-rs" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bb458bb7f6e250e6eb79d5026badc10a3ebb8f9a15d1fff0f13d17c71f4d6dee" +dependencies = [ + "unicode-xid", +] + [[package]] name = "presser" version = "0.3.1" @@ -969,6 +1090,18 @@ version = "0.6.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "20675572f6f24e9e76ef639bc5552774ed45f1c30e2951e1e99c59888861c539" +[[package]] +name = "raw-window-metal" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "40d213455a5f1dc59214213c7330e074ddf8114c9a42411eb890c767357ce135" +dependencies = [ + "objc2", + "objc2-core-foundation", + "objc2-foundation", + "objc2-quartz-core", +] + [[package]] name = "redox_syscall" version = "0.5.10" @@ -1106,6 +1239,12 @@ dependencies = [ "autocfg", ] +[[package]] +name = "slabbin" +version = "1.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9db491c0d4152a069911a0fbdaca959691bf0b9d7110d98a7ed1c8e59b79ab30" + [[package]] name = "slotmap" version = "1.0.7" @@ -1192,7 +1331,7 @@ version = "0.26.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "4c6bee85a5a24955dc440386795aa378cd9cf82acd5f764469152d2270e581be" dependencies = [ - "heck", + "heck 0.5.0", "proc-macro2", "quote", "rustversion", @@ -1272,6 +1411,15 @@ dependencies = [ "syn", ] +[[package]] +name = "thread_local" +version = "1.1.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f60246a4944f24f6e018aa17cdeffb7818b76356965d03b07d6a9886e8962185" +dependencies = [ + "cfg-if", +] + [[package]] name = "trig_ops-rust" version = "0.0.0" @@ -1299,6 +1447,12 @@ version = "0.1.14" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7dd6e30e90baa6f72411720665d41d89b9a3d039dc45b8faea1ddd07f617f6af" +[[package]] +name = "unicode-xid" +version = "0.2.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ebc1c04c71510c7f702b52b7c350734c9ff1295c464a03335b00bb84fc54f853" + [[package]] name = "vector_extract_insert-rust" version = "0.0.0" @@ -1354,6 +1508,44 @@ version = "0.9.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0b928f33d975fc6ad9f86c8f283853ad26bdd5b10b7f1542aa2fa15e2289105a" +[[package]] +name = "vk-parse" +version = "0.15.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3859da4d7b98bec73e68fb65815d47a263819c415c90eed42b80440a02cbce8c" +dependencies = [ + "xml-rs", +] + +[[package]] +name = "vulkano" +version = "0.35.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "08840c2b51759a6f88f26f5ea378bc8b5c199a5b4760ddda292304be087249c4" +dependencies = [ + "ash", + "bytemuck", + "crossbeam-queue", + "foldhash", + "half", + "heck 0.4.1", + "indexmap", + "libloading", + "nom", + "once_cell", + "parking_lot", + "proc-macro2", + "quote", + "raw-window-handle", + "raw-window-metal", + "serde", + "serde_json", + "slabbin", + "smallvec", + "thread_local", + "vk-parse", +] + [[package]] name = "wasi" version = "0.13.3+wasi-0.2.2" @@ -1757,6 +1949,16 @@ dependencies = [ "spirv-std", ] +[[package]] +name = "workgroup_memory-vulkano" +version = "0.0.0" +dependencies = [ + "bytemuck", + "difftest", + "spirv-builder", + "spirv-std", +] + [[package]] name = "workgroup_memory-wgsl" version = "0.0.0" diff --git a/tests/difftests/tests/Cargo.toml b/tests/difftests/tests/Cargo.toml index d5f4b492da..e57c1fe71d 100644 --- a/tests/difftests/tests/Cargo.toml +++ b/tests/difftests/tests/Cargo.toml @@ -7,6 +7,7 @@ members = [ "arch/atomic_ops/atomic_ops-wgsl", "arch/workgroup_memory/workgroup_memory-rust", "arch/workgroup_memory/workgroup_memory-wgsl", + "arch/workgroup_memory/workgroup_memory-vulkano", "arch/memory_barriers/memory_barriers-rust", "arch/memory_barriers/memory_barriers-wgsl", "arch/vector_extract_insert/vector_extract_insert-rust", diff --git a/tests/difftests/tests/arch/workgroup_memory/workgroup_memory-rust/src/lib.rs b/tests/difftests/tests/arch/workgroup_memory/workgroup_memory-rust/src/lib.rs index 9b0b6aef92..d424f066b4 100644 --- a/tests/difftests/tests/arch/workgroup_memory/workgroup_memory-rust/src/lib.rs +++ b/tests/difftests/tests/arch/workgroup_memory/workgroup_memory-rust/src/lib.rs @@ -1,77 +1,3 @@ #![no_std] -use spirv_std::arch::workgroup_memory_barrier_with_group_sync; -use spirv_std::spirv; - -#[spirv(compute(threads(64)))] -pub fn main_cs( - #[spirv(storage_buffer, descriptor_set = 0, binding = 0)] input: &[u32], - #[spirv(storage_buffer, descriptor_set = 0, binding = 1)] output: &mut [u32], - #[spirv(local_invocation_id)] local_id: spirv_std::glam::UVec3, - #[spirv(workgroup)] shared: &mut [u32; 64], -) { - let lid = local_id.x as usize; - - // Load data into shared memory - shared[lid] = input[lid]; - - // Synchronize to ensure all threads have loaded - unsafe { - workgroup_memory_barrier_with_group_sync(); - } - - // Each thread sums its value with its neighbor (reduction step) - if lid < 32 { - shared[lid] += shared[lid + 32]; - } - - // Synchronize again - unsafe { - workgroup_memory_barrier_with_group_sync(); - } - - if lid < 16 { - shared[lid] += shared[lid + 16]; - } - - unsafe { - workgroup_memory_barrier_with_group_sync(); - } - - if lid < 8 { - shared[lid] += shared[lid + 8]; - } - - unsafe { - workgroup_memory_barrier_with_group_sync(); - } - - if lid < 4 { - shared[lid] += shared[lid + 4]; - } - - unsafe { - workgroup_memory_barrier_with_group_sync(); - } - - if lid < 2 { - shared[lid] += shared[lid + 2]; - } - - unsafe { - workgroup_memory_barrier_with_group_sync(); - } - - if lid < 1 { - shared[lid] += shared[lid + 1]; - } - - unsafe { - workgroup_memory_barrier_with_group_sync(); - } - - // Write final result - if lid == 0 { - output[0] = shared[0]; - } -} +include!("shader.rs"); diff --git a/tests/difftests/tests/arch/workgroup_memory/workgroup_memory-rust/src/main.rs b/tests/difftests/tests/arch/workgroup_memory/workgroup_memory-rust/src/main.rs index f96dbca335..6d80e70222 100644 --- a/tests/difftests/tests/arch/workgroup_memory/workgroup_memory-rust/src/main.rs +++ b/tests/difftests/tests/arch/workgroup_memory/workgroup_memory-rust/src/main.rs @@ -34,6 +34,7 @@ fn main() { let metadata = difftest::config::TestMetadata { epsilon: None, output_type: difftest::config::OutputType::U32, + ..Default::default() }; config.write_metadata(&metadata).unwrap(); diff --git a/tests/difftests/tests/arch/workgroup_memory/workgroup_memory-rust/src/shader.rs b/tests/difftests/tests/arch/workgroup_memory/workgroup_memory-rust/src/shader.rs new file mode 100644 index 0000000000..d743a1cea8 --- /dev/null +++ b/tests/difftests/tests/arch/workgroup_memory/workgroup_memory-rust/src/shader.rs @@ -0,0 +1,75 @@ +use spirv_std::arch::workgroup_memory_barrier_with_group_sync; +use spirv_std::spirv; + +#[spirv(compute(threads(64)))] +pub fn main_cs( + #[spirv(storage_buffer, descriptor_set = 0, binding = 0)] input: &[u32], + #[spirv(storage_buffer, descriptor_set = 0, binding = 1)] output: &mut [u32], + #[spirv(local_invocation_id)] local_id: spirv_std::glam::UVec3, + #[spirv(workgroup)] shared: &mut [u32; 64], +) { + let lid = local_id.x as usize; + + // Load data into shared memory + shared[lid] = input[lid]; + + // Synchronize to ensure all threads have loaded + unsafe { + workgroup_memory_barrier_with_group_sync(); + } + + // Each thread sums its value with its neighbor (reduction step) + if lid < 32 { + shared[lid] += shared[lid + 32]; + } + + // Synchronize again + unsafe { + workgroup_memory_barrier_with_group_sync(); + } + + if lid < 16 { + shared[lid] += shared[lid + 16]; + } + + unsafe { + workgroup_memory_barrier_with_group_sync(); + } + + if lid < 8 { + shared[lid] += shared[lid + 8]; + } + + unsafe { + workgroup_memory_barrier_with_group_sync(); + } + + if lid < 4 { + shared[lid] += shared[lid + 4]; + } + + unsafe { + workgroup_memory_barrier_with_group_sync(); + } + + if lid < 2 { + shared[lid] += shared[lid + 2]; + } + + unsafe { + workgroup_memory_barrier_with_group_sync(); + } + + if lid < 1 { + shared[lid] += shared[lid + 1]; + } + + unsafe { + workgroup_memory_barrier_with_group_sync(); + } + + // Write final result + if lid == 0 { + output[0] = shared[0]; + } +} diff --git a/tests/difftests/tests/arch/workgroup_memory/workgroup_memory-vulkano/Cargo.toml b/tests/difftests/tests/arch/workgroup_memory/workgroup_memory-vulkano/Cargo.toml new file mode 100644 index 0000000000..1968af9717 --- /dev/null +++ b/tests/difftests/tests/arch/workgroup_memory/workgroup_memory-vulkano/Cargo.toml @@ -0,0 +1,22 @@ +[package] +name = "workgroup_memory-vulkano" +edition.workspace = true + +[lints] +workspace = true + + +[lib] +crate-type = ["dylib"] + +# Common deps +[dependencies] + +# GPU deps +spirv-std.workspace = true + +# CPU deps +[target.'cfg(not(target_arch = "spirv"))'.dependencies] +spirv-builder.workspace = true +difftest.workspace = true +bytemuck.workspace = true \ No newline at end of file diff --git a/tests/difftests/tests/arch/workgroup_memory/workgroup_memory-vulkano/src/lib.rs b/tests/difftests/tests/arch/workgroup_memory/workgroup_memory-vulkano/src/lib.rs new file mode 100644 index 0000000000..18e58216a7 --- /dev/null +++ b/tests/difftests/tests/arch/workgroup_memory/workgroup_memory-vulkano/src/lib.rs @@ -0,0 +1,8 @@ +#![no_std] + +// Include the shader content from workgroup_memory-rust to keep them in sync +// Note: We need to skip the #![no_std] from the included file +include!(concat!( + env!("CARGO_MANIFEST_DIR"), + "/../workgroup_memory-rust/src/shader.rs" +)); diff --git a/tests/difftests/tests/arch/workgroup_memory/workgroup_memory-vulkano/src/main.rs b/tests/difftests/tests/arch/workgroup_memory/workgroup_memory-vulkano/src/main.rs new file mode 100644 index 0000000000..59668783c8 --- /dev/null +++ b/tests/difftests/tests/arch/workgroup_memory/workgroup_memory-vulkano/src/main.rs @@ -0,0 +1,88 @@ +#[cfg(not(target_arch = "spirv"))] +fn main() { + use difftest::config::Config; + use difftest::scaffold::Skip; + + let config = Config::from_path(std::env::args().nth(1).unwrap()).unwrap(); + + // Skip on macOS due to Vulkan/MoltenVK configuration issues + #[cfg(target_os = "macos")] + { + let skip = + Skip::new("Vulkano tests are skipped on macOS due to MoltenVK configuration issues"); + skip.run_test(&config).unwrap(); + return; + } + + // Run the actual test on other platforms + #[cfg(not(target_os = "macos"))] + { + use difftest::scaffold::compute::{BufferConfig, BufferUsage, ComputeTest, VulkanoBackend}; + use spirv_builder::{ModuleResult, SpirvBuilder}; + use std::fs; + + // Build the Rust shader to SPIR-V + let builder = SpirvBuilder::new(".", "spirv-unknown-vulkan1.2") + .print_metadata(spirv_builder::MetadataPrintout::None) + .release(true) + .multimodule(false) + .shader_panic_strategy(spirv_builder::ShaderPanicStrategy::SilentExit) + .preserve_bindings(true) + .capability(spirv_builder::Capability::VulkanMemoryModel); + + let artifact = builder.build().expect("Failed to build SPIR-V"); + + if artifact.entry_points.len() != 1 { + panic!( + "Expected exactly one entry point, found {}", + artifact.entry_points.len() + ); + } + let entry_point = artifact.entry_points.into_iter().next().unwrap(); + + let spirv_bytes = match artifact.module { + ModuleResult::SingleModule(path) => { + fs::read(&path).expect("Failed to read SPIR-V file") + } + ModuleResult::MultiModule(_) => panic!("Unexpected multi-module result"), + }; + + // Initialize input buffer with values to sum + let input_data: Vec = (1..=64).collect(); + let input_bytes = bytemuck::cast_slice(&input_data).to_vec(); + + let buffers = vec![ + BufferConfig { + size: 256, // 64 u32 values + usage: BufferUsage::StorageReadOnly, + initial_data: Some(input_bytes), + }, + BufferConfig { + size: 4, // 1 u32 value for output + usage: BufferUsage::Storage, + initial_data: None, + }, + ]; + + let test = ComputeTest::::new( + spirv_bytes, + entry_point, + [1, 1, 1], // Single workgroup with 64 threads + buffers, + ) + .unwrap(); + + // Write metadata for U32 comparison + let metadata = difftest::config::TestMetadata { + epsilon: None, + output_type: difftest::config::OutputType::U32, + ..Default::default() + }; + config.write_metadata(&metadata).unwrap(); + + test.run_test(&config).unwrap(); + } +} + +#[cfg(target_arch = "spirv")] +fn main() {} diff --git a/tests/difftests/tests/arch/workgroup_memory/workgroup_memory-wgsl/src/main.rs b/tests/difftests/tests/arch/workgroup_memory/workgroup_memory-wgsl/src/main.rs index 49989b6668..a6470b95ca 100644 --- a/tests/difftests/tests/arch/workgroup_memory/workgroup_memory-wgsl/src/main.rs +++ b/tests/difftests/tests/arch/workgroup_memory/workgroup_memory-wgsl/src/main.rs @@ -33,6 +33,7 @@ fn main() { let metadata = difftest::config::TestMetadata { epsilon: None, output_type: difftest::config::OutputType::U32, + ..Default::default() }; config.write_metadata(&metadata).unwrap(); diff --git a/tests/difftests/tests/lang/core/ops/math_ops/math_ops-rust/src/main.rs b/tests/difftests/tests/lang/core/ops/math_ops/math_ops-rust/src/main.rs index f7a3cb3958..23c0e027df 100644 --- a/tests/difftests/tests/lang/core/ops/math_ops/math_ops-rust/src/main.rs +++ b/tests/difftests/tests/lang/core/ops/math_ops/math_ops-rust/src/main.rs @@ -55,6 +55,7 @@ fn main() { let metadata = difftest::config::TestMetadata { epsilon: Some(2e-6), // Small epsilon for last-bit differences output_type: difftest::config::OutputType::F32, + ..Default::default() }; config.write_metadata(&metadata).unwrap(); diff --git a/tests/difftests/tests/lang/core/ops/math_ops/math_ops-wgsl/src/main.rs b/tests/difftests/tests/lang/core/ops/math_ops/math_ops-wgsl/src/main.rs index 09ecdd545c..b97a39a536 100644 --- a/tests/difftests/tests/lang/core/ops/math_ops/math_ops-wgsl/src/main.rs +++ b/tests/difftests/tests/lang/core/ops/math_ops/math_ops-wgsl/src/main.rs @@ -54,6 +54,7 @@ fn main() { let metadata = difftest::config::TestMetadata { epsilon: Some(2e-6), // Small epsilon for last-bit differences output_type: difftest::config::OutputType::F32, + ..Default::default() }; config.write_metadata(&metadata).unwrap(); diff --git a/tests/difftests/tests/lang/core/ops/matrix_ops/matrix_ops-rust/src/main.rs b/tests/difftests/tests/lang/core/ops/matrix_ops/matrix_ops-rust/src/main.rs index da6752004f..2c979bb0dd 100644 --- a/tests/difftests/tests/lang/core/ops/matrix_ops/matrix_ops-rust/src/main.rs +++ b/tests/difftests/tests/lang/core/ops/matrix_ops/matrix_ops-rust/src/main.rs @@ -55,6 +55,7 @@ fn main() { let metadata = difftest::config::TestMetadata { epsilon: Some(1e-5), // 1e-5 - appropriate for 5 decimal place rounding output_type: difftest::config::OutputType::F32, + ..Default::default() }; config.write_metadata(&metadata).unwrap(); diff --git a/tests/difftests/tests/lang/core/ops/matrix_ops/matrix_ops-wgsl/src/main.rs b/tests/difftests/tests/lang/core/ops/matrix_ops/matrix_ops-wgsl/src/main.rs index 244d119609..c3ee27b2db 100644 --- a/tests/difftests/tests/lang/core/ops/matrix_ops/matrix_ops-wgsl/src/main.rs +++ b/tests/difftests/tests/lang/core/ops/matrix_ops/matrix_ops-wgsl/src/main.rs @@ -54,6 +54,7 @@ fn main() { let metadata = difftest::config::TestMetadata { epsilon: Some(1e-5), // 1e-5 - appropriate for 5 decimal place rounding output_type: difftest::config::OutputType::F32, + ..Default::default() }; config.write_metadata(&metadata).unwrap(); diff --git a/tests/difftests/tests/lang/core/ops/vector_ops/vector_ops-rust/src/main.rs b/tests/difftests/tests/lang/core/ops/vector_ops/vector_ops-rust/src/main.rs index 7a8d4128f2..3ef378155b 100644 --- a/tests/difftests/tests/lang/core/ops/vector_ops/vector_ops-rust/src/main.rs +++ b/tests/difftests/tests/lang/core/ops/vector_ops/vector_ops-rust/src/main.rs @@ -55,6 +55,7 @@ fn main() { let metadata = difftest::config::TestMetadata { epsilon: Some(1e-5), // 1e-5 - appropriate for 5 decimal place rounding output_type: difftest::config::OutputType::F32, + ..Default::default() }; config.write_metadata(&metadata).unwrap(); diff --git a/tests/difftests/tests/lang/core/ops/vector_ops/vector_ops-wgsl/src/main.rs b/tests/difftests/tests/lang/core/ops/vector_ops/vector_ops-wgsl/src/main.rs index b9a98fc170..03f2e97b44 100644 --- a/tests/difftests/tests/lang/core/ops/vector_ops/vector_ops-wgsl/src/main.rs +++ b/tests/difftests/tests/lang/core/ops/vector_ops/vector_ops-wgsl/src/main.rs @@ -54,6 +54,7 @@ fn main() { let metadata = difftest::config::TestMetadata { epsilon: Some(1e-5), // 1e-5 - appropriate for 5 decimal place rounding output_type: difftest::config::OutputType::F32, + ..Default::default() }; config.write_metadata(&metadata).unwrap(); From d4997632c32b8b1625beb400c9bcca86813772df Mon Sep 17 00:00:00 2001 From: Christian Legnitto Date: Mon, 7 Jul 2025 09:21:05 +0200 Subject: [PATCH 10/24] Fix pdb/path issue on windows --- .../arch/workgroup_memory/workgroup_memory-rust/Cargo.toml | 1 + .../arch/workgroup_memory/workgroup_memory-vulkano/Cargo.toml | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/difftests/tests/arch/workgroup_memory/workgroup_memory-rust/Cargo.toml b/tests/difftests/tests/arch/workgroup_memory/workgroup_memory-rust/Cargo.toml index 0091d02acf..2e55e489b1 100644 --- a/tests/difftests/tests/arch/workgroup_memory/workgroup_memory-rust/Cargo.toml +++ b/tests/difftests/tests/arch/workgroup_memory/workgroup_memory-rust/Cargo.toml @@ -6,6 +6,7 @@ edition.workspace = true workspace = true [lib] +name = "workgroup_memory_rust_shader" crate-type = ["dylib"] # Common deps diff --git a/tests/difftests/tests/arch/workgroup_memory/workgroup_memory-vulkano/Cargo.toml b/tests/difftests/tests/arch/workgroup_memory/workgroup_memory-vulkano/Cargo.toml index 1968af9717..3a3dbd0ea7 100644 --- a/tests/difftests/tests/arch/workgroup_memory/workgroup_memory-vulkano/Cargo.toml +++ b/tests/difftests/tests/arch/workgroup_memory/workgroup_memory-vulkano/Cargo.toml @@ -5,8 +5,8 @@ edition.workspace = true [lints] workspace = true - [lib] +name = "workgroup_memory_vulkano_shader" crate-type = ["dylib"] # Common deps From 94a27625317f2090bdd405de7e358f707f98f7fb Mon Sep 17 00:00:00 2001 From: Christian Legnitto Date: Mon, 7 Jul 2025 22:20:13 +0200 Subject: [PATCH 11/24] Ash variant of a difftest --- Cargo.lock | 229 ++------- tests/difftests/README.md | 2 +- tests/difftests/lib/Cargo.toml | 2 +- .../difftests/lib/src/scaffold/compute/ash.rs | 436 ++++++++++++++++++ .../difftests/lib/src/scaffold/compute/mod.rs | 4 +- .../lib/src/scaffold/compute/vulkano.rs | 250 ---------- tests/difftests/tests/Cargo.lock | 183 +------- tests/difftests/tests/Cargo.toml | 2 +- .../Cargo.toml | 4 +- .../src/lib.rs | 0 .../src/main.rs | 9 +- .../workgroup_memory-rust/Cargo.toml | 1 + .../matrix_ops/matrix_ops-rust/src/main.rs | 2 +- .../matrix_ops/matrix_ops-wgsl/src/main.rs | 2 +- 14 files changed, 494 insertions(+), 632 deletions(-) create mode 100644 tests/difftests/lib/src/scaffold/compute/ash.rs delete mode 100644 tests/difftests/lib/src/scaffold/compute/vulkano.rs rename tests/difftests/tests/arch/workgroup_memory/{workgroup_memory-vulkano => workgroup_memory-ash}/Cargo.toml (80%) rename tests/difftests/tests/arch/workgroup_memory/{workgroup_memory-vulkano => workgroup_memory-ash}/src/lib.rs (100%) rename tests/difftests/tests/arch/workgroup_memory/{workgroup_memory-vulkano => workgroup_memory-ash}/src/main.rs (90%) diff --git a/Cargo.lock b/Cargo.lock index 4e898aee13..65974cb4d5 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -213,7 +213,7 @@ checksum = "52bca67b61cb81e5553babde81b8211f713cb6db79766f80168f3e5f40ea6c82" dependencies = [ "ash", "raw-window-handle 0.6.2", - "raw-window-metal 0.4.0", + "raw-window-metal", ] [[package]] @@ -276,7 +276,7 @@ version = "0.5.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2c132eebf10f5cad5289222520a4a058514204aed6d791f1cf4fe8088b82d15f" dependencies = [ - "objc2 0.5.2", + "objc2", ] [[package]] @@ -438,7 +438,7 @@ version = "4.5.32" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "09176aae279615badda0765c0c0b3f6ed53f4709118af73cf4655d85d1530cd7" dependencies = [ - "heck 0.5.0", + "heck", "proc-macro2", "quote", "syn", @@ -658,15 +658,6 @@ dependencies = [ "crossbeam-utils", ] -[[package]] -name = "crossbeam-queue" -version = "0.3.12" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0f58bbc28f91df819d0aa2a2c00cd19754769c2fad90579b3592b1c9ba7a3115" -dependencies = [ - "crossbeam-utils", -] - [[package]] name = "crossbeam-utils" version = "0.8.21" @@ -724,6 +715,7 @@ name = "difftest" version = "0.9.0" dependencies = [ "anyhow", + "ash", "bytemuck", "futures", "naga", @@ -731,7 +723,6 @@ dependencies = [ "serde_json", "spirv-builder", "tempfile", - "vulkano", "wgpu", ] @@ -781,16 +772,6 @@ version = "0.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "bd0c93bb4b0c6d9b77f4435b0ae98c24d17f1c45b2ff844c6151a07256ca923b" -[[package]] -name = "dispatch2" -version = "0.3.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "89a09f22a6c6069a18470eb92d2298acf25463f14256d24778e1230d789a2aec" -dependencies = [ - "bitflags 2.9.1", - "objc2 0.6.1", -] - [[package]] name = "dlib" version = "0.5.2" @@ -1246,7 +1227,6 @@ version = "2.6.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "459196ed295495a68f7d7fe1d84f6c4b7ff0e21fe3017b2f283c6fac3ad803c9" dependencies = [ - "bytemuck", "cfg-if", "crunchy", "num-traits", @@ -1273,12 +1253,6 @@ dependencies = [ "foldhash", ] -[[package]] -name = "heck" -version = "0.4.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "95505c38b4572b2d910cecb0281560f54b440a19336cbbcb27bf6ce6adc6f5a8" - [[package]] name = "heck" version = "0.5.0" @@ -1618,12 +1592,6 @@ dependencies = [ "x11-dl", ] -[[package]] -name = "minimal-lexical" -version = "0.2.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "68354c5c6bd36d73ff3feceb05efa59b6acb7626617f4962be322a825e61f79a" - [[package]] name = "miniz_oxide" version = "0.8.2" @@ -1748,16 +1716,6 @@ dependencies = [ "memoffset", ] -[[package]] -name = "nom" -version = "7.1.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d273983c5a657a70a3e8f2a01329822f3b8c8172b73826411a55751e404a0a4a" -dependencies = [ - "memchr", - "minimal-lexical", -] - [[package]] name = "notify" version = "7.0.0" @@ -1877,15 +1835,6 @@ dependencies = [ "objc2-encode", ] -[[package]] -name = "objc2" -version = "0.6.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "88c6597e14493ab2e44ce58f2fdecf095a51f12ca57bec060a11c57332520551" -dependencies = [ - "objc2-encode", -] - [[package]] name = "objc2-app-kit" version = "0.2.2" @@ -1895,11 +1844,11 @@ dependencies = [ "bitflags 2.9.1", "block2", "libc", - "objc2 0.5.2", + "objc2", "objc2-core-data", "objc2-core-image", - "objc2-foundation 0.2.2", - "objc2-quartz-core 0.2.2", + "objc2-foundation", + "objc2-quartz-core", ] [[package]] @@ -1910,9 +1859,9 @@ checksum = "74dd3b56391c7a0596a295029734d3c1c5e7e510a4cb30245f8221ccea96b009" dependencies = [ "bitflags 2.9.1", "block2", - "objc2 0.5.2", + "objc2", "objc2-core-location", - "objc2-foundation 0.2.2", + "objc2-foundation", ] [[package]] @@ -1922,8 +1871,8 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a5ff520e9c33812fd374d8deecef01d4a840e7b41862d849513de77e44aa4889" dependencies = [ "block2", - "objc2 0.5.2", - "objc2-foundation 0.2.2", + "objc2", + "objc2-foundation", ] [[package]] @@ -1934,19 +1883,8 @@ checksum = "617fbf49e071c178c0b24c080767db52958f716d9eabdf0890523aeae54773ef" dependencies = [ "bitflags 2.9.1", "block2", - "objc2 0.5.2", - "objc2-foundation 0.2.2", -] - -[[package]] -name = "objc2-core-foundation" -version = "0.3.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1c10c2894a6fed806ade6027bcd50662746363a9589d3ec9d9bef30a4e4bc166" -dependencies = [ - "bitflags 2.9.1", - "dispatch2", - "objc2 0.6.1", + "objc2", + "objc2-foundation", ] [[package]] @@ -1956,9 +1894,9 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "55260963a527c99f1819c4f8e3b47fe04f9650694ef348ffd2227e8196d34c80" dependencies = [ "block2", - "objc2 0.5.2", - "objc2-foundation 0.2.2", - "objc2-metal 0.2.2", + "objc2", + "objc2-foundation", + "objc2-metal", ] [[package]] @@ -1968,9 +1906,9 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "000cfee34e683244f284252ee206a27953279d370e309649dc3ee317b37e5781" dependencies = [ "block2", - "objc2 0.5.2", + "objc2", "objc2-contacts", - "objc2-foundation 0.2.2", + "objc2-foundation", ] [[package]] @@ -1989,18 +1927,7 @@ dependencies = [ "block2", "dispatch", "libc", - "objc2 0.5.2", -] - -[[package]] -name = "objc2-foundation" -version = "0.3.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "900831247d2fe1a09a683278e5384cfb8c80c79fe6b166f9d14bfdde0ea1b03c" -dependencies = [ - "bitflags 2.9.1", - "objc2 0.6.1", - "objc2-core-foundation", + "objc2", ] [[package]] @@ -2010,9 +1937,9 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a1a1ae721c5e35be65f01a03b6d2ac13a54cb4fa70d8a5da293d7b0020261398" dependencies = [ "block2", - "objc2 0.5.2", + "objc2", "objc2-app-kit", - "objc2-foundation 0.2.2", + "objc2-foundation", ] [[package]] @@ -2023,19 +1950,8 @@ checksum = "dd0cba1276f6023976a406a14ffa85e1fdd19df6b0f737b063b95f6c8c7aadd6" dependencies = [ "bitflags 2.9.1", "block2", - "objc2 0.5.2", - "objc2-foundation 0.2.2", -] - -[[package]] -name = "objc2-metal" -version = "0.3.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7f246c183239540aab1782457b35ab2040d4259175bd1d0c58e46ada7b47a874" -dependencies = [ - "bitflags 2.9.1", - "objc2 0.6.1", - "objc2-foundation 0.3.1", + "objc2", + "objc2-foundation", ] [[package]] @@ -2046,22 +1962,9 @@ checksum = "e42bee7bff906b14b167da2bac5efe6b6a07e6f7c0a21a7308d40c960242dc7a" dependencies = [ "bitflags 2.9.1", "block2", - "objc2 0.5.2", - "objc2-foundation 0.2.2", - "objc2-metal 0.2.2", -] - -[[package]] -name = "objc2-quartz-core" -version = "0.3.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "90ffb6a0cd5f182dc964334388560b12a57f7b74b3e2dec5e2722aa2dfb2ccd5" -dependencies = [ - "bitflags 2.9.1", - "objc2 0.6.1", - "objc2-core-foundation", - "objc2-foundation 0.3.1", - "objc2-metal 0.3.1", + "objc2", + "objc2-foundation", + "objc2-metal", ] [[package]] @@ -2070,8 +1973,8 @@ version = "0.2.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0a684efe3dec1b305badae1a28f6555f6ddd3bb2c2267896782858d5a78404dc" dependencies = [ - "objc2 0.5.2", - "objc2-foundation 0.2.2", + "objc2", + "objc2-foundation", ] [[package]] @@ -2082,14 +1985,14 @@ checksum = "b8bb46798b20cd6b91cbd113524c490f1686f4c4e8f49502431415f3512e2b6f" dependencies = [ "bitflags 2.9.1", "block2", - "objc2 0.5.2", + "objc2", "objc2-cloud-kit", "objc2-core-data", "objc2-core-image", "objc2-core-location", - "objc2-foundation 0.2.2", + "objc2-foundation", "objc2-link-presentation", - "objc2-quartz-core 0.2.2", + "objc2-quartz-core", "objc2-symbols", "objc2-uniform-type-identifiers", "objc2-user-notifications", @@ -2102,8 +2005,8 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "44fa5f9748dbfe1ca6c0b79ad20725a11eca7c2218bceb4b005cb1be26273bfe" dependencies = [ "block2", - "objc2 0.5.2", - "objc2-foundation 0.2.2", + "objc2", + "objc2-foundation", ] [[package]] @@ -2114,9 +2017,9 @@ checksum = "76cfcbf642358e8689af64cee815d139339f3ed8ad05103ed5eaf73db8d84cb3" dependencies = [ "bitflags 2.9.1", "block2", - "objc2 0.5.2", + "objc2", "objc2-core-location", - "objc2-foundation 0.2.2", + "objc2-foundation", ] [[package]] @@ -2433,18 +2336,6 @@ dependencies = [ "raw-window-handle 0.6.2", ] -[[package]] -name = "raw-window-metal" -version = "1.1.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "40d213455a5f1dc59214213c7330e074ddf8114c9a42411eb890c767357ce135" -dependencies = [ - "objc2 0.6.1", - "objc2-core-foundation", - "objc2-foundation 0.3.1", - "objc2-quartz-core 0.3.1", -] - [[package]] name = "rayon" version = "1.10.0" @@ -2840,12 +2731,6 @@ dependencies = [ "autocfg", ] -[[package]] -name = "slabbin" -version = "1.2.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9db491c0d4152a069911a0fbdaca959691bf0b9d7110d98a7ed1c8e59b79ab30" - [[package]] name = "slotmap" version = "1.0.7" @@ -3033,7 +2918,7 @@ version = "0.26.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "4c6bee85a5a24955dc440386795aa378cd9cf82acd5f764469152d2270e581be" dependencies = [ - "heck 0.5.0", + "heck", "proc-macro2", "quote", "rustversion", @@ -3416,44 +3301,6 @@ version = "0.9.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0b928f33d975fc6ad9f86c8f283853ad26bdd5b10b7f1542aa2fa15e2289105a" -[[package]] -name = "vk-parse" -version = "0.15.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3859da4d7b98bec73e68fb65815d47a263819c415c90eed42b80440a02cbce8c" -dependencies = [ - "xml-rs", -] - -[[package]] -name = "vulkano" -version = "0.35.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "08840c2b51759a6f88f26f5ea378bc8b5c199a5b4760ddda292304be087249c4" -dependencies = [ - "ash", - "bytemuck", - "crossbeam-queue", - "foldhash", - "half", - "heck 0.4.1", - "indexmap", - "libloading", - "nom", - "once_cell", - "parking_lot", - "proc-macro2", - "quote", - "raw-window-handle 0.6.2", - "raw-window-metal 1.1.0", - "serde", - "serde_json", - "slabbin", - "smallvec", - "thread_local", - "vk-parse", -] - [[package]] name = "walkdir" version = "2.5.0" @@ -4229,9 +4076,9 @@ dependencies = [ "libc", "memmap2", "ndk", - "objc2 0.5.2", + "objc2", "objc2-app-kit", - "objc2-foundation 0.2.2", + "objc2-foundation", "objc2-ui-kit", "orbclient", "percent-encoding", diff --git a/tests/difftests/README.md b/tests/difftests/README.md index d825f43f8a..0940f6a3bd 100644 --- a/tests/difftests/README.md +++ b/tests/difftests/README.md @@ -109,7 +109,7 @@ The library provides helper types for common test patterns: **Backend types:** - `WgpuBackend` - Default wgpu-based compute backend -- `VulkanoBackend` - Vulkano-based compute backend (useful for testing different GPU drivers) +- `AshBackend` - Ash-based compute backend (low-level Vulkan access for debugging driver issues) For examples, see: diff --git a/tests/difftests/lib/Cargo.toml b/tests/difftests/lib/Cargo.toml index e0cfbd8483..c014741df6 100644 --- a/tests/difftests/lib/Cargo.toml +++ b/tests/difftests/lib/Cargo.toml @@ -20,7 +20,7 @@ spirv-builder.workspace = true serde = { version = "1.0", features = ["derive"] } serde_json = "1.0" wgpu = { version = "25.0.2", features = ["spirv", "vulkan-portability", "static-dxc"] } -vulkano = { version = "0.35.1", default-features = false } +ash = { version = "0.38", default-features = false } naga = { version = "25.0.1", features = ["glsl-in", "spv-out"] } tempfile = "3.5" futures = "0.3.31" diff --git a/tests/difftests/lib/src/scaffold/compute/ash.rs b/tests/difftests/lib/src/scaffold/compute/ash.rs new file mode 100644 index 0000000000..c9670f52fa --- /dev/null +++ b/tests/difftests/lib/src/scaffold/compute/ash.rs @@ -0,0 +1,436 @@ +use super::backend::{BufferConfig, BufferUsage, ComputeBackend}; +use anyhow::{Context, Result, bail}; +use ash::vk; +use std::ffi::{CStr, CString}; + +pub struct AshBackend { + instance: ash::Instance, + device: ash::Device, + queue: vk::Queue, + command_pool: vk::CommandPool, + descriptor_pool: vk::DescriptorPool, + memory_properties: vk::PhysicalDeviceMemoryProperties, + _entry: ash::Entry, +} + +impl AshBackend { + fn find_memory_type( + &self, + type_filter: u32, + properties: vk::MemoryPropertyFlags, + ) -> Option { + for i in 0..self.memory_properties.memory_type_count { + if (type_filter & (1 << i)) != 0 + && self.memory_properties.memory_types[i as usize] + .property_flags + .contains(properties) + { + return Some(i); + } + } + None + } + + fn create_buffer(&self, config: &BufferConfig) -> Result<(vk::Buffer, vk::DeviceMemory)> { + unsafe { + let usage = match config.usage { + BufferUsage::Storage => vk::BufferUsageFlags::STORAGE_BUFFER, + BufferUsage::StorageReadOnly => vk::BufferUsageFlags::STORAGE_BUFFER, + BufferUsage::Uniform => vk::BufferUsageFlags::UNIFORM_BUFFER, + }; + + let buffer_create_info = vk::BufferCreateInfo::default() + .size(config.size) + .usage( + usage | vk::BufferUsageFlags::TRANSFER_SRC | vk::BufferUsageFlags::TRANSFER_DST, + ) + .sharing_mode(vk::SharingMode::EXCLUSIVE); + + let buffer = self + .device + .create_buffer(&buffer_create_info, None) + .context("Failed to create buffer")?; + + let memory_requirements = self.device.get_buffer_memory_requirements(buffer); + + let memory_type_index = self + .find_memory_type( + memory_requirements.memory_type_bits, + vk::MemoryPropertyFlags::HOST_VISIBLE | vk::MemoryPropertyFlags::HOST_COHERENT, + ) + .context("Failed to find suitable memory type")?; + + let allocate_info = vk::MemoryAllocateInfo::default() + .allocation_size(memory_requirements.size) + .memory_type_index(memory_type_index); + + let memory = self + .device + .allocate_memory(&allocate_info, None) + .context("Failed to allocate memory")?; + + self.device + .bind_buffer_memory(buffer, memory, 0) + .context("Failed to bind buffer memory")?; + + // Initialize buffer if initial data provided + if let Some(data) = &config.initial_data { + let mapped_ptr = self + .device + .map_memory(memory, 0, config.size, vk::MemoryMapFlags::empty()) + .context("Failed to map memory")?; + + std::ptr::copy_nonoverlapping(data.as_ptr(), mapped_ptr as *mut u8, data.len()); + + self.device.unmap_memory(memory); + } + + Ok((buffer, memory)) + } + } +} + +impl ComputeBackend for AshBackend { + fn init() -> Result { + unsafe { + let entry = ash::Entry::load().context("Failed to load Vulkan entry")?; + + // Create instance + let app_info = vk::ApplicationInfo::default() + .application_name(CStr::from_bytes_with_nul_unchecked(b"difftest\0")) + .application_version(vk::make_api_version(0, 1, 0, 0)) + .engine_name(CStr::from_bytes_with_nul_unchecked(b"difftest\0")) + .engine_version(vk::make_api_version(0, 1, 0, 0)) + .api_version(vk::API_VERSION_1_2); + + let instance_create_info = + vk::InstanceCreateInfo::default().application_info(&app_info); + + let instance = entry + .create_instance(&instance_create_info, None) + .context("Failed to create Vulkan instance")?; + + // Select physical device + let physical_devices = instance + .enumerate_physical_devices() + .context("Failed to enumerate physical devices")?; + + if physical_devices.is_empty() { + bail!("No Vulkan devices found"); + } + + let physical_device = physical_devices[0]; + let memory_properties = instance.get_physical_device_memory_properties(physical_device); + + // Find compute queue family + let queue_family_properties = + instance.get_physical_device_queue_family_properties(physical_device); + let queue_family_index = queue_family_properties + .iter() + .enumerate() + .find(|(_, props)| props.queue_flags.contains(vk::QueueFlags::COMPUTE)) + .map(|(index, _)| index as u32) + .context("No compute queue family found")?; + + // Create device + let priorities = [1.0]; + let queue_create_info = vk::DeviceQueueCreateInfo::default() + .queue_family_index(queue_family_index) + .queue_priorities(&priorities); + + let device_features = vk::PhysicalDeviceFeatures::default(); + + let queue_create_infos = [queue_create_info]; + let device_create_info = vk::DeviceCreateInfo::default() + .queue_create_infos(&queue_create_infos) + .enabled_features(&device_features); + + let device = instance + .create_device(physical_device, &device_create_info, None) + .context("Failed to create Vulkan device")?; + + let queue = device.get_device_queue(queue_family_index, 0); + + // Create command pool + let command_pool_create_info = vk::CommandPoolCreateInfo::default() + .queue_family_index(queue_family_index) + .flags(vk::CommandPoolCreateFlags::RESET_COMMAND_BUFFER); + + let command_pool = device + .create_command_pool(&command_pool_create_info, None) + .context("Failed to create command pool")?; + + // Create descriptor pool + let descriptor_pool_sizes = vec![ + vk::DescriptorPoolSize { + ty: vk::DescriptorType::STORAGE_BUFFER, + descriptor_count: 16, + }, + vk::DescriptorPoolSize { + ty: vk::DescriptorType::UNIFORM_BUFFER, + descriptor_count: 16, + }, + ]; + + let descriptor_pool_create_info = vk::DescriptorPoolCreateInfo::default() + .pool_sizes(&descriptor_pool_sizes) + .max_sets(16); + + let descriptor_pool = device + .create_descriptor_pool(&descriptor_pool_create_info, None) + .context("Failed to create descriptor pool")?; + + Ok(Self { + instance, + device, + queue, + command_pool, + descriptor_pool, + memory_properties, + _entry: entry, + }) + } + } + + fn run_compute( + &self, + spirv_bytes: &[u8], + entry_point: &str, + dispatch: [u32; 3], + buffers: Vec, + ) -> Result>> { + unsafe { + // Create shader module + let spirv_u32: Vec = spirv_bytes + .chunks_exact(4) + .map(|chunk| u32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]])) + .collect(); + + let shader_module_create_info = vk::ShaderModuleCreateInfo::default().code(&spirv_u32); + + let shader_module = self + .device + .create_shader_module(&shader_module_create_info, None) + .context("Failed to create shader module")?; + + // Create descriptor set layout + let mut layout_bindings = Vec::new(); + for (i, buffer) in buffers.iter().enumerate() { + let descriptor_type = match buffer.usage { + BufferUsage::Storage | BufferUsage::StorageReadOnly => { + vk::DescriptorType::STORAGE_BUFFER + } + BufferUsage::Uniform => vk::DescriptorType::UNIFORM_BUFFER, + }; + let binding = vk::DescriptorSetLayoutBinding::default() + .binding(i as u32) + .descriptor_type(descriptor_type) + .descriptor_count(1) + .stage_flags(vk::ShaderStageFlags::COMPUTE); + layout_bindings.push(binding); + } + + let descriptor_set_layout_create_info = + vk::DescriptorSetLayoutCreateInfo::default().bindings(&layout_bindings); + + let descriptor_set_layout = self + .device + .create_descriptor_set_layout(&descriptor_set_layout_create_info, None) + .context("Failed to create descriptor set layout")?; + + // Create pipeline layout + let set_layouts = [descriptor_set_layout]; + let pipeline_layout_create_info = + vk::PipelineLayoutCreateInfo::default().set_layouts(&set_layouts); + + let pipeline_layout = self + .device + .create_pipeline_layout(&pipeline_layout_create_info, None) + .context("Failed to create pipeline layout")?; + + // Create compute pipeline + let entry_point_cstring = CString::new(entry_point)?; + let stage_create_info = vk::PipelineShaderStageCreateInfo::default() + .stage(vk::ShaderStageFlags::COMPUTE) + .module(shader_module) + .name(&entry_point_cstring); + + let compute_pipeline_create_info = vk::ComputePipelineCreateInfo::default() + .stage(stage_create_info) + .layout(pipeline_layout); + + let pipelines = self + .device + .create_compute_pipelines( + vk::PipelineCache::null(), + &[compute_pipeline_create_info], + None, + ) + .map_err(|(_, e)| e) + .context("Failed to create compute pipeline")?; + + let pipeline = pipelines[0]; + + // Create buffers + let mut vk_buffers = Vec::new(); + let mut buffer_memories = Vec::new(); + + for buffer_config in &buffers { + let (buffer, memory) = self.create_buffer(buffer_config)?; + vk_buffers.push(buffer); + buffer_memories.push(memory); + } + + // Allocate descriptor set + let set_layouts = [descriptor_set_layout]; + let descriptor_set_allocate_info = vk::DescriptorSetAllocateInfo::default() + .descriptor_pool(self.descriptor_pool) + .set_layouts(&set_layouts); + + let descriptor_sets = self + .device + .allocate_descriptor_sets(&descriptor_set_allocate_info) + .context("Failed to allocate descriptor sets")?; + + let descriptor_set = descriptor_sets[0]; + + // Update descriptor sets + let buffer_infos: Vec = vk_buffers + .iter() + .zip(&buffers) + .map(|(buffer, config)| { + vk::DescriptorBufferInfo::default() + .buffer(*buffer) + .offset(0) + .range(config.size) + }) + .collect(); + + let descriptor_writes: Vec> = buffer_infos + .iter() + .zip(&buffers) + .enumerate() + .map(|(i, (buffer_info, config))| { + let descriptor_type = match config.usage { + BufferUsage::Storage | BufferUsage::StorageReadOnly => { + vk::DescriptorType::STORAGE_BUFFER + } + BufferUsage::Uniform => vk::DescriptorType::UNIFORM_BUFFER, + }; + + vk::WriteDescriptorSet::default() + .dst_set(descriptor_set) + .dst_binding(i as u32) + .descriptor_type(descriptor_type) + .buffer_info(std::slice::from_ref(buffer_info)) + }) + .collect(); + + self.device.update_descriptor_sets(&descriptor_writes, &[]); + + // Allocate command buffer + let command_buffer_allocate_info = vk::CommandBufferAllocateInfo::default() + .command_pool(self.command_pool) + .level(vk::CommandBufferLevel::PRIMARY) + .command_buffer_count(1); + + let command_buffers = self + .device + .allocate_command_buffers(&command_buffer_allocate_info) + .context("Failed to allocate command buffer")?; + + let command_buffer = command_buffers[0]; + + // Begin command buffer + let begin_info = vk::CommandBufferBeginInfo::default() + .flags(vk::CommandBufferUsageFlags::ONE_TIME_SUBMIT); + + self.device + .begin_command_buffer(command_buffer, &begin_info) + .context("Failed to begin command buffer")?; + + // Bind pipeline and descriptor set + self.device + .cmd_bind_pipeline(command_buffer, vk::PipelineBindPoint::COMPUTE, pipeline); + + self.device.cmd_bind_descriptor_sets( + command_buffer, + vk::PipelineBindPoint::COMPUTE, + pipeline_layout, + 0, + &[descriptor_set], + &[], + ); + + // Dispatch compute + self.device + .cmd_dispatch(command_buffer, dispatch[0], dispatch[1], dispatch[2]); + + // End command buffer + self.device + .end_command_buffer(command_buffer) + .context("Failed to end command buffer")?; + + // Submit command buffer + let command_buffers = [command_buffer]; + let submit_info = vk::SubmitInfo::default().command_buffers(&command_buffers); + + self.device + .queue_submit(self.queue, &[submit_info], vk::Fence::null()) + .context("Failed to submit queue")?; + + // Wait for completion + self.device + .queue_wait_idle(self.queue) + .context("Failed to wait for queue")?; + + // Read buffer results + let mut results = Vec::new(); + for (_i, (memory, config)) in buffer_memories.iter().zip(&buffers).enumerate() { + let mut data = vec![0u8; config.size as usize]; + + let mapped_ptr = self + .device + .map_memory(*memory, 0, config.size, vk::MemoryMapFlags::empty()) + .context("Failed to map memory for reading")?; + + std::ptr::copy_nonoverlapping( + mapped_ptr as *const u8, + data.as_mut_ptr(), + config.size as usize, + ); + + self.device.unmap_memory(*memory); + + results.push(data); + } + + // Clean up + self.device + .free_command_buffers(self.command_pool, &[command_buffer]); + for (buffer, memory) in vk_buffers.iter().zip(&buffer_memories) { + self.device.destroy_buffer(*buffer, None); + self.device.free_memory(*memory, None); + } + self.device.destroy_pipeline(pipeline, None); + self.device.destroy_pipeline_layout(pipeline_layout, None); + self.device + .destroy_descriptor_set_layout(descriptor_set_layout, None); + self.device.destroy_shader_module(shader_module, None); + + Ok(results) + } + } +} + +impl Drop for AshBackend { + fn drop(&mut self) { + unsafe { + self.device + .destroy_descriptor_pool(self.descriptor_pool, None); + self.device.destroy_command_pool(self.command_pool, None); + self.device.destroy_device(None); + self.instance.destroy_instance(None); + } + } +} diff --git a/tests/difftests/lib/src/scaffold/compute/mod.rs b/tests/difftests/lib/src/scaffold/compute/mod.rs index a60e43b931..16d6d63e51 100644 --- a/tests/difftests/lib/src/scaffold/compute/mod.rs +++ b/tests/difftests/lib/src/scaffold/compute/mod.rs @@ -1,9 +1,9 @@ +mod ash; mod backend; -mod vulkano; mod wgpu; +pub use ash::AshBackend; pub use backend::{BufferConfig, BufferUsage, ComputeBackend, ComputeTest}; -pub use vulkano::VulkanoBackend; pub use wgpu::{ RustComputeShader, WgpuBackend, WgpuComputeTest, WgpuComputeTestMultiBuffer, WgpuComputeTestPushConstants, WgslComputeShader, diff --git a/tests/difftests/lib/src/scaffold/compute/vulkano.rs b/tests/difftests/lib/src/scaffold/compute/vulkano.rs deleted file mode 100644 index 974856bba3..0000000000 --- a/tests/difftests/lib/src/scaffold/compute/vulkano.rs +++ /dev/null @@ -1,250 +0,0 @@ -use super::backend::{BufferConfig, BufferUsage, ComputeBackend}; -use anyhow::{Context, Result}; -use std::sync::Arc; -use vulkano::{ - VulkanLibrary, - buffer::{Buffer, BufferCreateInfo, BufferUsage as VkBufferUsage}, - command_buffer::{ - AutoCommandBufferBuilder, CommandBufferUsage, allocator::StandardCommandBufferAllocator, - }, - descriptor_set::{ - DescriptorSet, WriteDescriptorSet, allocator::StandardDescriptorSetAllocator, - }, - device::{ - Device, DeviceCreateInfo, DeviceFeatures, Queue, QueueCreateInfo, QueueFlags, - physical::PhysicalDeviceType, - }, - instance::{Instance, InstanceCreateFlags, InstanceCreateInfo}, - memory::allocator::{AllocationCreateInfo, MemoryTypeFilter, StandardMemoryAllocator}, - pipeline::{ - ComputePipeline, Pipeline, PipelineLayout, PipelineShaderStageCreateInfo, - compute::ComputePipelineCreateInfo, layout::PipelineDescriptorSetLayoutCreateInfo, - }, - shader::{ShaderModule, ShaderModuleCreateInfo}, - sync::{self, GpuFuture}, -}; - -pub struct VulkanoBackend { - device: Arc, - queue: Arc, - memory_allocator: Arc, - command_buffer_allocator: Arc, - descriptor_set_allocator: Arc, -} - -impl ComputeBackend for VulkanoBackend { - fn init() -> Result { - let library = VulkanLibrary::new()?; - - // Use the library's supported API version - let api_version = library.api_version(); - eprintln!( - "Vulkan library API version: {}.{}.{}", - api_version.major, api_version.minor, api_version.patch - ); - - let instance = Instance::new(library, InstanceCreateInfo { - flags: InstanceCreateFlags::ENUMERATE_PORTABILITY, - ..Default::default() - })?; - - // Pick a physical device - let physical_device = instance - .enumerate_physical_devices()? - .min_by_key(|p| match p.properties().device_type { - PhysicalDeviceType::DiscreteGpu => 0, - PhysicalDeviceType::IntegratedGpu => 1, - PhysicalDeviceType::VirtualGpu => 2, - PhysicalDeviceType::Cpu => 3, - PhysicalDeviceType::Other => 4, - _ => 5, - }) - .context("No suitable physical device found")?; - - // Find a compute queue - let queue_family_index = physical_device - .queue_family_properties() - .iter() - .enumerate() - .position(|(_, q)| q.queue_flags.intersects(QueueFlags::COMPUTE)) - .context("No compute queue family found")? as u32; - - // Check if vulkan_memory_model is supported - let supported_features = physical_device.supported_features(); - let mut enabled_features = DeviceFeatures::empty(); - if supported_features.vulkan_memory_model { - enabled_features.vulkan_memory_model = true; - } - - let (device, mut queues) = Device::new(physical_device, DeviceCreateInfo { - queue_create_infos: vec![QueueCreateInfo { - queue_family_index, - ..Default::default() - }], - enabled_features, - ..Default::default() - })?; - - let queue = queues.next().context("No queue returned")?; - - let memory_allocator = Arc::new(StandardMemoryAllocator::new_default(device.clone())); - let command_buffer_allocator = Arc::new(StandardCommandBufferAllocator::new( - device.clone(), - Default::default(), - )); - let descriptor_set_allocator = Arc::new(StandardDescriptorSetAllocator::new( - device.clone(), - Default::default(), - )); - - Ok(Self { - device, - queue, - memory_allocator, - command_buffer_allocator, - descriptor_set_allocator, - }) - } - - fn run_compute( - &self, - spirv_bytes: &[u8], - entry_point: &str, - dispatch: [u32; 3], - buffers: Vec, - ) -> Result>> { - // Convert bytes to u32 words - if spirv_bytes.len() % 4 != 0 { - anyhow::bail!("SPIR-V binary length is not a multiple of 4"); - } - let spirv_words: Vec = spirv_bytes - .chunks_exact(4) - .map(|chunk| u32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]])) - .collect(); - - // Create shader module - let shader = unsafe { - ShaderModule::new( - self.device.clone(), - ShaderModuleCreateInfo::new(&spirv_words), - )? - }; - - // Get the entry point - let entry_point = shader - .entry_point(entry_point) - .context("Entry point not found in shader module")?; - - // Create pipeline - let stage = PipelineShaderStageCreateInfo::new(entry_point); - let layout = PipelineLayout::new( - self.device.clone(), - PipelineDescriptorSetLayoutCreateInfo::from_stages([&stage]) - .into_pipeline_layout_create_info(self.device.clone())?, - )?; - - let compute_pipeline = ComputePipeline::new( - self.device.clone(), - None, - ComputePipelineCreateInfo::stage_layout(stage, layout), - )?; - - // Create buffers - let mut gpu_buffers = Vec::new(); - for buffer_config in buffers.iter() { - let usage = match buffer_config.usage { - BufferUsage::Storage => VkBufferUsage::STORAGE_BUFFER, - BufferUsage::StorageReadOnly => VkBufferUsage::STORAGE_BUFFER, - BufferUsage::Uniform => VkBufferUsage::UNIFORM_BUFFER, - }; - - let buffer = if let Some(initial_data) = &buffer_config.initial_data { - Buffer::from_iter( - self.memory_allocator.clone(), - BufferCreateInfo { - usage: usage | VkBufferUsage::TRANSFER_SRC | VkBufferUsage::TRANSFER_DST, - ..Default::default() - }, - AllocationCreateInfo { - memory_type_filter: MemoryTypeFilter::PREFER_DEVICE - | MemoryTypeFilter::HOST_SEQUENTIAL_WRITE, - ..Default::default() - }, - initial_data.iter().cloned(), - )? - } else { - // Zero initialize - let zeros = vec![0u8; buffer_config.size as usize]; - Buffer::from_iter( - self.memory_allocator.clone(), - BufferCreateInfo { - usage: usage | VkBufferUsage::TRANSFER_SRC | VkBufferUsage::TRANSFER_DST, - ..Default::default() - }, - AllocationCreateInfo { - memory_type_filter: MemoryTypeFilter::PREFER_DEVICE - | MemoryTypeFilter::HOST_SEQUENTIAL_WRITE, - ..Default::default() - }, - zeros, - )? - }; - - gpu_buffers.push(buffer); - } - - // Create descriptor set - let layout = compute_pipeline.layout().set_layouts()[0].clone(); - let mut writes = Vec::new(); - for (i, buffer) in gpu_buffers.iter().enumerate() { - writes.push(WriteDescriptorSet::buffer(i as u32, buffer.clone())); - } - - let descriptor_set = - DescriptorSet::new(self.descriptor_set_allocator.clone(), layout, writes, [])?; - - // Create command buffer - let mut builder = AutoCommandBufferBuilder::primary( - self.command_buffer_allocator.clone(), - self.queue.queue_family_index(), - CommandBufferUsage::OneTimeSubmit, - )?; - - unsafe { - builder - .bind_pipeline_compute(compute_pipeline.clone())? - .bind_descriptor_sets( - vulkano::pipeline::PipelineBindPoint::Compute, - compute_pipeline.layout().clone(), - 0, - descriptor_set, - )? - .dispatch(dispatch)?; - } - - let command_buffer = builder.build()?; - - // Execute - let future = sync::now(self.device.clone()) - .then_execute(self.queue.clone(), command_buffer)? - .then_signal_fence_and_flush()?; - - future.wait(None)?; - - // Read back results - let mut results = Vec::new(); - for (i, buffer_config) in buffers.iter().enumerate() { - if matches!( - buffer_config.usage, - BufferUsage::Storage | BufferUsage::StorageReadOnly - ) { - let content_guard = gpu_buffers[i].read()?; - results.push(content_guard.to_vec()); - } else { - results.push(Vec::new()); - } - } - - Ok(results) - } -} diff --git a/tests/difftests/tests/Cargo.lock b/tests/difftests/tests/Cargo.lock index a4a7baf652..f30889f006 100644 --- a/tests/difftests/tests/Cargo.lock +++ b/tests/difftests/tests/Cargo.lock @@ -269,21 +269,6 @@ dependencies = [ "libc", ] -[[package]] -name = "crossbeam-queue" -version = "0.3.12" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0f58bbc28f91df819d0aa2a2c00cd19754769c2fad90579b3592b1c9ba7a3115" -dependencies = [ - "crossbeam-utils", -] - -[[package]] -name = "crossbeam-utils" -version = "0.8.21" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d0a5c400df2834b80a4c3327b3aad3a4c4cd4de0629063962b03235697506a28" - [[package]] name = "crunchy" version = "0.2.3" @@ -295,6 +280,7 @@ name = "difftest" version = "0.9.0" dependencies = [ "anyhow", + "ash", "bytemuck", "futures", "naga", @@ -302,20 +288,9 @@ dependencies = [ "serde_json", "spirv-builder", "tempfile", - "vulkano", "wgpu", ] -[[package]] -name = "dispatch2" -version = "0.3.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "89a09f22a6c6069a18470eb92d2298acf25463f14256d24778e1230d789a2aec" -dependencies = [ - "bitflags 2.9.0", - "objc2", -] - [[package]] name = "document-features" version = "0.2.11" @@ -585,7 +560,6 @@ version = "2.6.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "459196ed295495a68f7d7fe1d84f6c4b7ff0e21fe3017b2f283c6fac3ad803c9" dependencies = [ - "bytemuck", "cfg-if", "crunchy", "num-traits", @@ -602,12 +576,6 @@ dependencies = [ "foldhash", ] -[[package]] -name = "heck" -version = "0.4.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "95505c38b4572b2d910cecb0281560f54b440a19336cbbcb27bf6ce6adc6f5a8" - [[package]] name = "heck" version = "0.5.0" @@ -804,12 +772,6 @@ dependencies = [ "paste", ] -[[package]] -name = "minimal-lexical" -version = "0.2.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "68354c5c6bd36d73ff3feceb05efa59b6acb7626617f4962be322a825e61f79a" - [[package]] name = "naga" version = "25.0.1" @@ -846,16 +808,6 @@ dependencies = [ "jni-sys", ] -[[package]] -name = "nom" -version = "7.1.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d273983c5a657a70a3e8f2a01329822f3b8c8172b73826411a55751e404a0a4a" -dependencies = [ - "memchr", - "minimal-lexical", -] - [[package]] name = "num-traits" version = "0.2.19" @@ -875,67 +827,6 @@ dependencies = [ "malloc_buf", ] -[[package]] -name = "objc2" -version = "0.6.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "88c6597e14493ab2e44ce58f2fdecf095a51f12ca57bec060a11c57332520551" -dependencies = [ - "objc2-encode", -] - -[[package]] -name = "objc2-core-foundation" -version = "0.3.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1c10c2894a6fed806ade6027bcd50662746363a9589d3ec9d9bef30a4e4bc166" -dependencies = [ - "bitflags 2.9.0", - "dispatch2", - "objc2", -] - -[[package]] -name = "objc2-encode" -version = "4.1.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ef25abbcd74fb2609453eb695bd2f860d389e457f67dc17cafc8b8cbc89d0c33" - -[[package]] -name = "objc2-foundation" -version = "0.3.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "900831247d2fe1a09a683278e5384cfb8c80c79fe6b166f9d14bfdde0ea1b03c" -dependencies = [ - "bitflags 2.9.0", - "objc2", - "objc2-core-foundation", -] - -[[package]] -name = "objc2-metal" -version = "0.3.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7f246c183239540aab1782457b35ab2040d4259175bd1d0c58e46ada7b47a874" -dependencies = [ - "bitflags 2.9.0", - "objc2", - "objc2-foundation", -] - -[[package]] -name = "objc2-quartz-core" -version = "0.3.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "90ffb6a0cd5f182dc964334388560b12a57f7b74b3e2dec5e2722aa2dfb2ccd5" -dependencies = [ - "bitflags 2.9.0", - "objc2", - "objc2-core-foundation", - "objc2-foundation", - "objc2-metal", -] - [[package]] name = "once_cell" version = "1.21.3" @@ -1090,18 +981,6 @@ version = "0.6.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "20675572f6f24e9e76ef639bc5552774ed45f1c30e2951e1e99c59888861c539" -[[package]] -name = "raw-window-metal" -version = "1.1.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "40d213455a5f1dc59214213c7330e074ddf8114c9a42411eb890c767357ce135" -dependencies = [ - "objc2", - "objc2-core-foundation", - "objc2-foundation", - "objc2-quartz-core", -] - [[package]] name = "redox_syscall" version = "0.5.10" @@ -1239,12 +1118,6 @@ dependencies = [ "autocfg", ] -[[package]] -name = "slabbin" -version = "1.2.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9db491c0d4152a069911a0fbdaca959691bf0b9d7110d98a7ed1c8e59b79ab30" - [[package]] name = "slotmap" version = "1.0.7" @@ -1331,7 +1204,7 @@ version = "0.26.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "4c6bee85a5a24955dc440386795aa378cd9cf82acd5f764469152d2270e581be" dependencies = [ - "heck 0.5.0", + "heck", "proc-macro2", "quote", "rustversion", @@ -1411,15 +1284,6 @@ dependencies = [ "syn", ] -[[package]] -name = "thread_local" -version = "1.1.9" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f60246a4944f24f6e018aa17cdeffb7818b76356965d03b07d6a9886e8962185" -dependencies = [ - "cfg-if", -] - [[package]] name = "trig_ops-rust" version = "0.0.0" @@ -1508,44 +1372,6 @@ version = "0.9.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0b928f33d975fc6ad9f86c8f283853ad26bdd5b10b7f1542aa2fa15e2289105a" -[[package]] -name = "vk-parse" -version = "0.15.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3859da4d7b98bec73e68fb65815d47a263819c415c90eed42b80440a02cbce8c" -dependencies = [ - "xml-rs", -] - -[[package]] -name = "vulkano" -version = "0.35.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "08840c2b51759a6f88f26f5ea378bc8b5c199a5b4760ddda292304be087249c4" -dependencies = [ - "ash", - "bytemuck", - "crossbeam-queue", - "foldhash", - "half", - "heck 0.4.1", - "indexmap", - "libloading", - "nom", - "once_cell", - "parking_lot", - "proc-macro2", - "quote", - "raw-window-handle", - "raw-window-metal", - "serde", - "serde_json", - "slabbin", - "smallvec", - "thread_local", - "vk-parse", -] - [[package]] name = "wasi" version = "0.13.3+wasi-0.2.2" @@ -1941,16 +1767,17 @@ dependencies = [ ] [[package]] -name = "workgroup_memory-rust" +name = "workgroup_memory-ash" version = "0.0.0" dependencies = [ "bytemuck", "difftest", + "spirv-builder", "spirv-std", ] [[package]] -name = "workgroup_memory-vulkano" +name = "workgroup_memory-rust" version = "0.0.0" dependencies = [ "bytemuck", diff --git a/tests/difftests/tests/Cargo.toml b/tests/difftests/tests/Cargo.toml index e57c1fe71d..f4b617a8aa 100644 --- a/tests/difftests/tests/Cargo.toml +++ b/tests/difftests/tests/Cargo.toml @@ -7,7 +7,7 @@ members = [ "arch/atomic_ops/atomic_ops-wgsl", "arch/workgroup_memory/workgroup_memory-rust", "arch/workgroup_memory/workgroup_memory-wgsl", - "arch/workgroup_memory/workgroup_memory-vulkano", + "arch/workgroup_memory/workgroup_memory-ash", "arch/memory_barriers/memory_barriers-rust", "arch/memory_barriers/memory_barriers-wgsl", "arch/vector_extract_insert/vector_extract_insert-rust", diff --git a/tests/difftests/tests/arch/workgroup_memory/workgroup_memory-vulkano/Cargo.toml b/tests/difftests/tests/arch/workgroup_memory/workgroup_memory-ash/Cargo.toml similarity index 80% rename from tests/difftests/tests/arch/workgroup_memory/workgroup_memory-vulkano/Cargo.toml rename to tests/difftests/tests/arch/workgroup_memory/workgroup_memory-ash/Cargo.toml index 3a3dbd0ea7..1d57d3f91f 100644 --- a/tests/difftests/tests/arch/workgroup_memory/workgroup_memory-vulkano/Cargo.toml +++ b/tests/difftests/tests/arch/workgroup_memory/workgroup_memory-ash/Cargo.toml @@ -1,12 +1,12 @@ [package] -name = "workgroup_memory-vulkano" +name = "workgroup_memory-ash" edition.workspace = true [lints] workspace = true [lib] -name = "workgroup_memory_vulkano_shader" +name = "workgroup_memory_ash_shader" crate-type = ["dylib"] # Common deps diff --git a/tests/difftests/tests/arch/workgroup_memory/workgroup_memory-vulkano/src/lib.rs b/tests/difftests/tests/arch/workgroup_memory/workgroup_memory-ash/src/lib.rs similarity index 100% rename from tests/difftests/tests/arch/workgroup_memory/workgroup_memory-vulkano/src/lib.rs rename to tests/difftests/tests/arch/workgroup_memory/workgroup_memory-ash/src/lib.rs diff --git a/tests/difftests/tests/arch/workgroup_memory/workgroup_memory-vulkano/src/main.rs b/tests/difftests/tests/arch/workgroup_memory/workgroup_memory-ash/src/main.rs similarity index 90% rename from tests/difftests/tests/arch/workgroup_memory/workgroup_memory-vulkano/src/main.rs rename to tests/difftests/tests/arch/workgroup_memory/workgroup_memory-ash/src/main.rs index 59668783c8..c439580b31 100644 --- a/tests/difftests/tests/arch/workgroup_memory/workgroup_memory-vulkano/src/main.rs +++ b/tests/difftests/tests/arch/workgroup_memory/workgroup_memory-ash/src/main.rs @@ -1,15 +1,16 @@ #[cfg(not(target_arch = "spirv"))] fn main() { use difftest::config::Config; - use difftest::scaffold::Skip; let config = Config::from_path(std::env::args().nth(1).unwrap()).unwrap(); // Skip on macOS due to Vulkan/MoltenVK configuration issues #[cfg(target_os = "macos")] { + use difftest::scaffold::Skip; + let skip = - Skip::new("Vulkano tests are skipped on macOS due to MoltenVK configuration issues"); + Skip::new("Ash tests are skipped on macOS due to MoltenVK configuration issues"); skip.run_test(&config).unwrap(); return; } @@ -17,7 +18,7 @@ fn main() { // Run the actual test on other platforms #[cfg(not(target_os = "macos"))] { - use difftest::scaffold::compute::{BufferConfig, BufferUsage, ComputeTest, VulkanoBackend}; + use difftest::scaffold::compute::{AshBackend, BufferConfig, BufferUsage, ComputeTest}; use spirv_builder::{ModuleResult, SpirvBuilder}; use std::fs; @@ -64,7 +65,7 @@ fn main() { }, ]; - let test = ComputeTest::::new( + let test = ComputeTest::::new( spirv_bytes, entry_point, [1, 1, 1], // Single workgroup with 64 threads diff --git a/tests/difftests/tests/arch/workgroup_memory/workgroup_memory-rust/Cargo.toml b/tests/difftests/tests/arch/workgroup_memory/workgroup_memory-rust/Cargo.toml index 2e55e489b1..2ddca6bedf 100644 --- a/tests/difftests/tests/arch/workgroup_memory/workgroup_memory-rust/Cargo.toml +++ b/tests/difftests/tests/arch/workgroup_memory/workgroup_memory-rust/Cargo.toml @@ -17,5 +17,6 @@ spirv-std.workspace = true # CPU deps [target.'cfg(not(target_arch = "spirv"))'.dependencies] +spirv-builder.workspace = true difftest.workspace = true bytemuck.workspace = true \ No newline at end of file diff --git a/tests/difftests/tests/lang/core/ops/matrix_ops/matrix_ops-rust/src/main.rs b/tests/difftests/tests/lang/core/ops/matrix_ops/matrix_ops-rust/src/main.rs index 2c979bb0dd..6597c253f7 100644 --- a/tests/difftests/tests/lang/core/ops/matrix_ops/matrix_ops-rust/src/main.rs +++ b/tests/difftests/tests/lang/core/ops/matrix_ops/matrix_ops-rust/src/main.rs @@ -53,7 +53,7 @@ fn main() { // Write metadata file let metadata = difftest::config::TestMetadata { - epsilon: Some(1e-5), // 1e-5 - appropriate for 5 decimal place rounding + epsilon: Some(2e-5), output_type: difftest::config::OutputType::F32, ..Default::default() }; diff --git a/tests/difftests/tests/lang/core/ops/matrix_ops/matrix_ops-wgsl/src/main.rs b/tests/difftests/tests/lang/core/ops/matrix_ops/matrix_ops-wgsl/src/main.rs index c3ee27b2db..c71c23872e 100644 --- a/tests/difftests/tests/lang/core/ops/matrix_ops/matrix_ops-wgsl/src/main.rs +++ b/tests/difftests/tests/lang/core/ops/matrix_ops/matrix_ops-wgsl/src/main.rs @@ -52,7 +52,7 @@ fn main() { // Write metadata file let metadata = difftest::config::TestMetadata { - epsilon: Some(1e-5), // 1e-5 - appropriate for 5 decimal place rounding + epsilon: Some(2e-5), output_type: difftest::config::OutputType::F32, ..Default::default() }; From a8eafaeb4ffe4f8e03bd8ca15c566568aa9f1f12 Mon Sep 17 00:00:00 2001 From: Christian Legnitto Date: Mon, 7 Jul 2025 21:28:23 +0200 Subject: [PATCH 12/24] Vulkan on windows via swiftshader --- .github/workflows/ci.yaml | 20 ++++++++++++++++++++ tests/difftests/bin/src/differ.rs | 6 +++--- 2 files changed, 23 insertions(+), 3 deletions(-) diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index 27597b1fd0..eddcc67d92 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -29,6 +29,11 @@ jobs: with: version: 1.4.309.0 cache: true + - if: ${{ runner.os == 'Windows' }} + name: Install Vulkan Runtime with SwiftShader (Windows) + uses: NcStudios/VulkanCI@v1.2 + with: + sdkVersion: 1.4.309.0 - if: ${{ runner.os == 'Linux' }} name: Linux - Install native dependencies run: sudo apt install libwayland-cursor0 libxkbcommon-dev libwayland-dev @@ -88,6 +93,11 @@ jobs: with: version: 1.4.309.0 cache: true + - if: ${{ runner.os == 'Windows' }} + name: Install Vulkan Runtime with SwiftShader (Windows) + uses: NcStudios/VulkanCI@v1.2 + with: + sdkVersion: 1.4.309.0 - name: install rust-toolchain run: cargo version - name: cargo fetch --locked @@ -129,6 +139,11 @@ jobs: with: version: 1.4.309.0 cache: true + - if: ${{ runner.os == 'Windows' }} + name: Install Vulkan Runtime with SwiftShader (Windows) + uses: NcStudios/VulkanCI@v1.2 + with: + sdkVersion: 1.4.309.0 - name: install rust-toolchain run: echo "TARGET=$(rustc --print host-tuple)" >> "$GITHUB_ENV" - name: cargo fetch --locked @@ -149,6 +164,11 @@ jobs: with: version: 1.4.309.0 cache: true + - if: ${{ runner.os == 'Windows' }} + name: Install Vulkan Runtime with SwiftShader (Windows) + uses: NcStudios/VulkanCI@v1.2 + with: + sdkVersion: 1.4.309.0 - if: ${{ runner.os == 'Linux' }} name: Linux - Install native dependencies run: sudo apt install libwayland-cursor0 libxkbcommon-dev libwayland-dev diff --git a/tests/difftests/bin/src/differ.rs b/tests/difftests/bin/src/differ.rs index 38b39eb250..aaf410247a 100644 --- a/tests/difftests/bin/src/differ.rs +++ b/tests/difftests/bin/src/differ.rs @@ -150,7 +150,7 @@ impl NumericType for u32 { format!("{}", value) } fn can_have_relative_diff() -> bool { - false + true } fn as_f64(value: Self) -> f64 { value as f64 @@ -403,8 +403,8 @@ mod tests { _ => panic!("Expected numeric difference"), } match &diffs[0].relative_diff { - DiffMagnitude::Incomparable => {} - _ => panic!("Expected incomparable relative diff for U32"), + DiffMagnitude::Numeric(val) => assert_eq!(*val, 3.0 / 5.0), // 3/5 = 0.6 + _ => panic!("Expected numeric relative diff for U32"), } // Check second difference (index 3: 4 vs 7) From ec302497b852c548efb9d29b9efb8a2eb0376af2 Mon Sep 17 00:00:00 2001 From: Christian Legnitto Date: Wed, 9 Jul 2025 16:45:56 +0200 Subject: [PATCH 13/24] Add support for diffing Raw output --- tests/difftests/bin/src/differ.rs | 228 +++++++++++++++++++++++++++--- 1 file changed, 211 insertions(+), 17 deletions(-) diff --git a/tests/difftests/bin/src/differ.rs b/tests/difftests/bin/src/differ.rs index aaf410247a..5c84b4ceee 100644 --- a/tests/difftests/bin/src/differ.rs +++ b/tests/difftests/bin/src/differ.rs @@ -55,14 +55,47 @@ impl OutputDiffer for RawDiffer { if output1 == output2 { vec![] } else { - // For raw comparison, we just note that they differ - vec![Difference { - index: 0, - value1: format!("{} bytes", output1.len()), - value2: format!("{} bytes", output2.len()), - absolute_diff: DiffMagnitude::Incomparable, - relative_diff: DiffMagnitude::Incomparable, - }] + let mut differences = Vec::new(); + let max_len = std::cmp::max(output1.len(), output2.len()); + + // Find byte-level differences + for i in 0..max_len { + let byte1 = output1.get(i); + let byte2 = output2.get(i); + + match (byte1, byte2) { + (Some(&b1), Some(&b2)) if b1 != b2 => { + differences.push(Difference { + index: i, + value1: format!("{}", b1), + value2: format!("{}", b2), + absolute_diff: DiffMagnitude::Incomparable, + relative_diff: DiffMagnitude::Incomparable, + }); + } + (Some(&b1), None) => { + differences.push(Difference { + index: i, + value1: format!("{}", b1), + value2: "".to_string(), + absolute_diff: DiffMagnitude::Incomparable, + relative_diff: DiffMagnitude::Incomparable, + }); + } + (None, Some(&b2)) => { + differences.push(Difference { + index: i, + value1: "".to_string(), + value2: format!("{}", b2), + absolute_diff: DiffMagnitude::Incomparable, + relative_diff: DiffMagnitude::Incomparable, + }); + } + _ => {} // bytes are equal + } + } + + differences } } @@ -72,18 +105,118 @@ impl OutputDiffer for RawDiffer { } impl DifferenceDisplay for RawDiffer { - fn format_table(&self, _diffs: &[Difference], _pkg1: &str, _pkg2: &str) -> String { - "Binary files differ".to_string() + fn format_table(&self, diffs: &[Difference], pkg1: &str, pkg2: &str) -> String { + use tabled::settings::{Alignment, Modify, Span, Style, object::Rows}; + + let rows: Vec> = diffs + .iter() + .take(10) + .map(|d| { + let (hex1, dec1, ascii1) = if d.value1.is_empty() { + ("--".to_string(), "--".to_string(), "--".to_string()) + } else { + let byte = d.value1.parse::().unwrap(); + let ascii = if byte.is_ascii_graphic() || byte == b' ' { + format!("{}", byte as char) + } else { + match byte { + b'\n' => "\\n".to_string(), + b'\r' => "\\r".to_string(), + b'\t' => "\\t".to_string(), + b'\0' => "\\0".to_string(), + _ => "".to_string(), // Empty for non-printable + } + }; + ( + format!("{:>3}", format!("{:02x}", byte)), + format!("{:3}", byte), + format!("{:^5}", ascii), + ) + }; + + let (hex2, dec2, ascii2) = if d.value2.is_empty() { + ("--".to_string(), "--".to_string(), "--".to_string()) + } else { + let byte = d.value2.parse::().unwrap(); + let ascii = if byte.is_ascii_graphic() || byte == b' ' { + format!("{}", byte as char) + } else { + match byte { + b'\n' => "\\n".to_string(), + b'\r' => "\\r".to_string(), + b'\t' => "\\t".to_string(), + b'\0' => "\\0".to_string(), + _ => "".to_string(), // Empty for non-printable + } + }; + ( + format!("{:>3}", format!("{:02x}", byte)), + format!("{:3}", byte), + format!("{:^5}", ascii), + ) + }; + + vec![ + format!("0x{:04x}", d.index), + hex1, + dec1, + ascii1, + hex2, + dec2, + ascii2, + ] + }) + .collect(); + + let mut builder = tabled::builder::Builder::default(); + + // Header rows + builder.push_record(vec!["Offset", pkg1, "", "", pkg2, "", ""]); + builder.push_record(vec!["", "Hex", "Dec", "ASCII", "Hex", "Dec", "ASCII"]); + + for row in &rows { + builder.push_record(row); + } + + let mut table = builder.build(); + table + .with(Style::modern()) + .with(Modify::new(Rows::new(0..)).with(Alignment::center())) + // Apply column spans to merge the package names across their columns + .modify((0, 1), Span::column(3)) + .modify((0, 4), Span::column(3)) + // Remove the borders between merged cells + .with(tabled::settings::style::BorderSpanCorrection); + + let mut result = table.to_string(); + + if diffs.len() > 10 { + let last_line_width = result + .lines() + .last() + .map(|l| l.chars().count()) + .unwrap_or(0); + result.push_str(&format!( + "\n{:>width$}", + format!("... {} more differences", diffs.len() - 10), + width = last_line_width + )); + } + + result } fn format_report( &self, - _diffs: &[Difference], + diffs: &[Difference], pkg1: &str, pkg2: &str, _epsilon: Option, ) -> String { - format!("Binary outputs from {} and {} differ", pkg1, pkg2) + let mut report = format!("Total differences: {} bytes\n\n", diffs.len()); + report.push_str(&self.format_table(diffs, pkg1, pkg2)); + + report } fn write_human_readable(&self, output: &[u8], path: &std::path::Path) -> std::io::Result<()> { @@ -456,11 +589,72 @@ mod tests { let bytes2 = b"world"; let diffs = differ.compare(bytes1, bytes2, None); - assert_eq!(diffs.len(), 1); - match &diffs[0].absolute_diff { - DiffMagnitude::Incomparable => {} - _ => panic!("Expected incomparable diff for raw bytes"), - } + assert_eq!(diffs.len(), 4); // 4 bytes differ (l at position 3 is same in both) + + // Check first difference (h vs w) + assert_eq!(diffs[0].index, 0); + assert_eq!(diffs[0].value1, "104"); // h = 104 + assert_eq!(diffs[0].value2, "119"); // w = 119 + + // Check second difference (e vs o) + assert_eq!(diffs[1].index, 1); + assert_eq!(diffs[1].value1, "101"); // 'e' = 101 + assert_eq!(diffs[1].value2, "111"); // 'o' = 111 + + // Check third difference (first l vs r) + assert_eq!(diffs[2].index, 2); + assert_eq!(diffs[2].value1, "108"); // 'l' = 108 + assert_eq!(diffs[2].value2, "114"); // 'r' = 114 + + // Check fourth difference (o vs d) + assert_eq!(diffs[3].index, 4); + assert_eq!(diffs[3].value1, "111"); // 'o' = 111 + assert_eq!(diffs[3].value2, "100"); // 'd' = 100 + } + + #[test] + fn test_raw_differ_partial_match() { + let differ = RawDiffer; + let bytes1 = b"hello world"; + let bytes2 = b"hello earth"; + + let diffs = differ.compare(bytes1, bytes2, None); + assert_eq!(diffs.len(), 4); // 4 bytes differ in "world" vs "earth" (r at position 8 is same) + + // First difference should be at index 6 (w vs e) + assert_eq!(diffs[0].index, 6); + assert_eq!(diffs[0].value1, "119"); // 'w' = 119 + assert_eq!(diffs[0].value2, "101"); // 'e' = 101 + + // Second difference at index 7 (o vs a) + assert_eq!(diffs[1].index, 7); + assert_eq!(diffs[1].value1, "111"); // 'o' = 111 + assert_eq!(diffs[1].value2, "97"); // 'a' = 97 + + // Third difference at index 9 (l vs t) + assert_eq!(diffs[2].index, 9); + assert_eq!(diffs[2].value1, "108"); // 'l' = 108 + assert_eq!(diffs[2].value2, "116"); // 't' = 116 + + // Fourth difference at index 10 (d vs h) + assert_eq!(diffs[3].index, 10); + assert_eq!(diffs[3].value1, "100"); // 'd' = 100 + assert_eq!(diffs[3].value2, "104"); // 'h' = 104 + } + + #[test] + fn test_raw_differ_different_lengths() { + let differ = RawDiffer; + let bytes1 = b"hello"; + let bytes2 = b"hello world"; + + let diffs = differ.compare(bytes1, bytes2, None); + assert_eq!(diffs.len(), 6); // " world" = 6 extra bytes + + // Check that missing bytes are shown as empty string + assert_eq!(diffs[0].index, 5); + assert_eq!(diffs[0].value1, ""); + assert_eq!(diffs[0].value2, "32"); // ' ' = 32 } #[test] From 736eb7b5e25a4136ae15dd203dfef1e278f70d2f Mon Sep 17 00:00:00 2001 From: Firestar99 Date: Wed, 9 Jul 2025 19:13:07 +0200 Subject: [PATCH 14/24] grab spirv-builder from difftest to pick up features --- tests/difftests/lib/src/lib.rs | 3 +++ tests/difftests/tests/Cargo.lock | 2 -- tests/difftests/tests/Cargo.toml | 1 - .../workgroup_memory/workgroup_memory-ash/Cargo.toml | 1 - .../workgroup_memory/workgroup_memory-ash/src/main.rs | 10 ++++++---- .../workgroup_memory/workgroup_memory-rust/Cargo.toml | 1 - 6 files changed, 9 insertions(+), 9 deletions(-) diff --git a/tests/difftests/lib/src/lib.rs b/tests/difftests/lib/src/lib.rs index 2ef8552858..2f66926766 100644 --- a/tests/difftests/lib/src/lib.rs +++ b/tests/difftests/lib/src/lib.rs @@ -5,6 +5,9 @@ pub mod config; #[cfg(not(target_arch = "spirv"))] pub mod scaffold; +#[cfg(not(target_arch = "spirv"))] +pub use spirv_builder; + /// Macro to round a f32 value for cross-platform compatibility in floating-point /// operations. This helps ensure difftest results are consistent across different /// platforms (Linux, Mac, Windows) which may have slight differences in floating-point diff --git a/tests/difftests/tests/Cargo.lock b/tests/difftests/tests/Cargo.lock index f30889f006..a8238f3fea 100644 --- a/tests/difftests/tests/Cargo.lock +++ b/tests/difftests/tests/Cargo.lock @@ -1772,7 +1772,6 @@ version = "0.0.0" dependencies = [ "bytemuck", "difftest", - "spirv-builder", "spirv-std", ] @@ -1782,7 +1781,6 @@ version = "0.0.0" dependencies = [ "bytemuck", "difftest", - "spirv-builder", "spirv-std", ] diff --git a/tests/difftests/tests/Cargo.toml b/tests/difftests/tests/Cargo.toml index f4b617a8aa..48449e6cc7 100644 --- a/tests/difftests/tests/Cargo.toml +++ b/tests/difftests/tests/Cargo.toml @@ -45,7 +45,6 @@ unexpected_cfgs = { level = "allow", check-cfg = [ ] } [workspace.dependencies] -spirv-builder = { path = "../../../crates/spirv-builder", version = "=0.9.0", default-features = false } spirv-std = { path = "../../../crates/spirv-std", version = "=0.9.0" } spirv-std-types = { path = "../../../crates/spirv-std/shared", version = "=0.9.0" } spirv-std-macros = { path = "../../../crates/spirv-std/macros", version = "=0.9.0" } diff --git a/tests/difftests/tests/arch/workgroup_memory/workgroup_memory-ash/Cargo.toml b/tests/difftests/tests/arch/workgroup_memory/workgroup_memory-ash/Cargo.toml index 1d57d3f91f..93652986fd 100644 --- a/tests/difftests/tests/arch/workgroup_memory/workgroup_memory-ash/Cargo.toml +++ b/tests/difftests/tests/arch/workgroup_memory/workgroup_memory-ash/Cargo.toml @@ -17,6 +17,5 @@ spirv-std.workspace = true # CPU deps [target.'cfg(not(target_arch = "spirv"))'.dependencies] -spirv-builder.workspace = true difftest.workspace = true bytemuck.workspace = true \ No newline at end of file diff --git a/tests/difftests/tests/arch/workgroup_memory/workgroup_memory-ash/src/main.rs b/tests/difftests/tests/arch/workgroup_memory/workgroup_memory-ash/src/main.rs index c439580b31..076f67709e 100644 --- a/tests/difftests/tests/arch/workgroup_memory/workgroup_memory-ash/src/main.rs +++ b/tests/difftests/tests/arch/workgroup_memory/workgroup_memory-ash/src/main.rs @@ -19,17 +19,19 @@ fn main() { #[cfg(not(target_os = "macos"))] { use difftest::scaffold::compute::{AshBackend, BufferConfig, BufferUsage, ComputeTest}; - use spirv_builder::{ModuleResult, SpirvBuilder}; + use difftest::spirv_builder::{ + Capability, MetadataPrintout, ModuleResult, ShaderPanicStrategy, SpirvBuilder, + }; use std::fs; // Build the Rust shader to SPIR-V let builder = SpirvBuilder::new(".", "spirv-unknown-vulkan1.2") - .print_metadata(spirv_builder::MetadataPrintout::None) + .print_metadata(MetadataPrintout::None) .release(true) .multimodule(false) - .shader_panic_strategy(spirv_builder::ShaderPanicStrategy::SilentExit) + .shader_panic_strategy(ShaderPanicStrategy::SilentExit) .preserve_bindings(true) - .capability(spirv_builder::Capability::VulkanMemoryModel); + .capability(Capability::VulkanMemoryModel); let artifact = builder.build().expect("Failed to build SPIR-V"); diff --git a/tests/difftests/tests/arch/workgroup_memory/workgroup_memory-rust/Cargo.toml b/tests/difftests/tests/arch/workgroup_memory/workgroup_memory-rust/Cargo.toml index 2ddca6bedf..2e55e489b1 100644 --- a/tests/difftests/tests/arch/workgroup_memory/workgroup_memory-rust/Cargo.toml +++ b/tests/difftests/tests/arch/workgroup_memory/workgroup_memory-rust/Cargo.toml @@ -17,6 +17,5 @@ spirv-std.workspace = true # CPU deps [target.'cfg(not(target_arch = "spirv"))'.dependencies] -spirv-builder.workspace = true difftest.workspace = true bytemuck.workspace = true \ No newline at end of file From d5ddbc9f520355929eea743d07bfbf4e54f164c1 Mon Sep 17 00:00:00 2001 From: Firestar99 Date: Wed, 9 Jul 2025 20:03:39 +0200 Subject: [PATCH 15/24] difftest ash: cleanup code --- .../difftests/lib/src/scaffold/compute/ash.rs | 322 ++++++++---------- 1 file changed, 145 insertions(+), 177 deletions(-) diff --git a/tests/difftests/lib/src/scaffold/compute/ash.rs b/tests/difftests/lib/src/scaffold/compute/ash.rs index c9670f52fa..2e5cdc29d7 100644 --- a/tests/difftests/lib/src/scaffold/compute/ash.rs +++ b/tests/difftests/lib/src/scaffold/compute/ash.rs @@ -1,6 +1,7 @@ use super::backend::{BufferConfig, BufferUsage, ComputeBackend}; -use anyhow::{Context, Result, bail}; +use anyhow::{Context, Result}; use ash::vk; +use ash::vk::DescriptorType; use std::ffi::{CStr, CString}; pub struct AshBackend { @@ -38,37 +39,37 @@ impl AshBackend { BufferUsage::StorageReadOnly => vk::BufferUsageFlags::STORAGE_BUFFER, BufferUsage::Uniform => vk::BufferUsageFlags::UNIFORM_BUFFER, }; - - let buffer_create_info = vk::BufferCreateInfo::default() - .size(config.size) - .usage( - usage | vk::BufferUsageFlags::TRANSFER_SRC | vk::BufferUsageFlags::TRANSFER_DST, - ) - .sharing_mode(vk::SharingMode::EXCLUSIVE); - let buffer = self .device - .create_buffer(&buffer_create_info, None) + .create_buffer( + &vk::BufferCreateInfo::default() + .size(config.size) + .usage( + usage + | vk::BufferUsageFlags::TRANSFER_SRC + | vk::BufferUsageFlags::TRANSFER_DST, + ) + .sharing_mode(vk::SharingMode::EXCLUSIVE), + None, + ) .context("Failed to create buffer")?; let memory_requirements = self.device.get_buffer_memory_requirements(buffer); - let memory_type_index = self .find_memory_type( memory_requirements.memory_type_bits, vk::MemoryPropertyFlags::HOST_VISIBLE | vk::MemoryPropertyFlags::HOST_COHERENT, ) .context("Failed to find suitable memory type")?; - - let allocate_info = vk::MemoryAllocateInfo::default() - .allocation_size(memory_requirements.size) - .memory_type_index(memory_type_index); - let memory = self .device - .allocate_memory(&allocate_info, None) + .allocate_memory( + &vk::MemoryAllocateInfo::default() + .allocation_size(memory_requirements.size) + .memory_type_index(memory_type_index), + None, + ) .context("Failed to allocate memory")?; - self.device .bind_buffer_memory(buffer, memory, 0) .context("Failed to bind buffer memory")?; @@ -96,30 +97,27 @@ impl ComputeBackend for AshBackend { let entry = ash::Entry::load().context("Failed to load Vulkan entry")?; // Create instance - let app_info = vk::ApplicationInfo::default() - .application_name(CStr::from_bytes_with_nul_unchecked(b"difftest\0")) - .application_version(vk::make_api_version(0, 1, 0, 0)) - .engine_name(CStr::from_bytes_with_nul_unchecked(b"difftest\0")) - .engine_version(vk::make_api_version(0, 1, 0, 0)) - .api_version(vk::API_VERSION_1_2); - - let instance_create_info = - vk::InstanceCreateInfo::default().application_info(&app_info); - let instance = entry - .create_instance(&instance_create_info, None) + .create_instance( + &vk::InstanceCreateInfo::default().application_info( + &vk::ApplicationInfo::default() + .application_name(CStr::from_bytes_with_nul_unchecked(b"difftest\0")) + .application_version(vk::make_api_version(0, 1, 0, 0)) + .engine_name(CStr::from_bytes_with_nul_unchecked(b"difftest\0")) + .engine_version(vk::make_api_version(0, 1, 0, 0)) + .api_version(vk::API_VERSION_1_2), + ), + None, + ) .context("Failed to create Vulkan instance")?; // Select physical device let physical_devices = instance .enumerate_physical_devices() .context("Failed to enumerate physical devices")?; - - if physical_devices.is_empty() { - bail!("No Vulkan devices found"); - } - - let physical_device = physical_devices[0]; + let physical_device = *physical_devices + .first() + .context("No Vulkan devices found")?; let memory_properties = instance.get_physical_device_memory_properties(physical_device); // Find compute queue family @@ -133,51 +131,44 @@ impl ComputeBackend for AshBackend { .context("No compute queue family found")?; // Create device - let priorities = [1.0]; - let queue_create_info = vk::DeviceQueueCreateInfo::default() - .queue_family_index(queue_family_index) - .queue_priorities(&priorities); - - let device_features = vk::PhysicalDeviceFeatures::default(); - - let queue_create_infos = [queue_create_info]; - let device_create_info = vk::DeviceCreateInfo::default() - .queue_create_infos(&queue_create_infos) - .enabled_features(&device_features); - let device = instance - .create_device(physical_device, &device_create_info, None) + .create_device( + physical_device, + &vk::DeviceCreateInfo::default() + .queue_create_infos(&[vk::DeviceQueueCreateInfo::default() + .queue_family_index(queue_family_index) + .queue_priorities(&[1.0])]) + .enabled_features(&vk::PhysicalDeviceFeatures::default()), + None, + ) .context("Failed to create Vulkan device")?; - let queue = device.get_device_queue(queue_family_index, 0); // Create command pool - let command_pool_create_info = vk::CommandPoolCreateInfo::default() - .queue_family_index(queue_family_index) - .flags(vk::CommandPoolCreateFlags::RESET_COMMAND_BUFFER); - let command_pool = device - .create_command_pool(&command_pool_create_info, None) + .create_command_pool( + &vk::CommandPoolCreateInfo::default().queue_family_index(queue_family_index), + None, + ) .context("Failed to create command pool")?; // Create descriptor pool - let descriptor_pool_sizes = vec![ - vk::DescriptorPoolSize { - ty: vk::DescriptorType::STORAGE_BUFFER, - descriptor_count: 16, - }, - vk::DescriptorPoolSize { - ty: vk::DescriptorType::UNIFORM_BUFFER, - descriptor_count: 16, - }, - ]; - - let descriptor_pool_create_info = vk::DescriptorPoolCreateInfo::default() - .pool_sizes(&descriptor_pool_sizes) - .max_sets(16); - let descriptor_pool = device - .create_descriptor_pool(&descriptor_pool_create_info, None) + .create_descriptor_pool( + &vk::DescriptorPoolCreateInfo::default() + .pool_sizes(&vec![ + vk::DescriptorPoolSize { + ty: DescriptorType::STORAGE_BUFFER, + descriptor_count: 16, + }, + vk::DescriptorPoolSize { + ty: DescriptorType::UNIFORM_BUFFER, + descriptor_count: 16, + }, + ]) + .max_sets(16), + None, + ) .context("Failed to create descriptor pool")?; Ok(Self { @@ -200,81 +191,71 @@ impl ComputeBackend for AshBackend { buffers: Vec, ) -> Result>> { unsafe { - // Create shader module - let spirv_u32: Vec = spirv_bytes - .chunks_exact(4) - .map(|chunk| u32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]])) - .collect(); - - let shader_module_create_info = vk::ShaderModuleCreateInfo::default().code(&spirv_u32); - - let shader_module = self - .device - .create_shader_module(&shader_module_create_info, None) - .context("Failed to create shader module")?; - // Create descriptor set layout - let mut layout_bindings = Vec::new(); - for (i, buffer) in buffers.iter().enumerate() { - let descriptor_type = match buffer.usage { - BufferUsage::Storage | BufferUsage::StorageReadOnly => { - vk::DescriptorType::STORAGE_BUFFER - } - BufferUsage::Uniform => vk::DescriptorType::UNIFORM_BUFFER, - }; - let binding = vk::DescriptorSetLayoutBinding::default() - .binding(i as u32) - .descriptor_type(descriptor_type) - .descriptor_count(1) - .stage_flags(vk::ShaderStageFlags::COMPUTE); - layout_bindings.push(binding); - } - - let descriptor_set_layout_create_info = - vk::DescriptorSetLayoutCreateInfo::default().bindings(&layout_bindings); - + let bindings = buffers + .iter() + .enumerate() + .map(|(i, buffer)| { + vk::DescriptorSetLayoutBinding::default() + .binding(i as u32) + .descriptor_type(buffer_usage_to_descriptor_type(buffer.usage)) + .descriptor_count(1) + .stage_flags(vk::ShaderStageFlags::COMPUTE) + }) + .collect::>(); let descriptor_set_layout = self .device - .create_descriptor_set_layout(&descriptor_set_layout_create_info, None) + .create_descriptor_set_layout( + &vk::DescriptorSetLayoutCreateInfo::default().bindings(&bindings), + None, + ) .context("Failed to create descriptor set layout")?; // Create pipeline layout - let set_layouts = [descriptor_set_layout]; - let pipeline_layout_create_info = - vk::PipelineLayoutCreateInfo::default().set_layouts(&set_layouts); - let pipeline_layout = self .device - .create_pipeline_layout(&pipeline_layout_create_info, None) - .context("Failed to create pipeline layout")?; - - // Create compute pipeline - let entry_point_cstring = CString::new(entry_point)?; - let stage_create_info = vk::PipelineShaderStageCreateInfo::default() - .stage(vk::ShaderStageFlags::COMPUTE) - .module(shader_module) - .name(&entry_point_cstring); - - let compute_pipeline_create_info = vk::ComputePipelineCreateInfo::default() - .stage(stage_create_info) - .layout(pipeline_layout); - - let pipelines = self - .device - .create_compute_pipelines( - vk::PipelineCache::null(), - &[compute_pipeline_create_info], + .create_pipeline_layout( + &vk::PipelineLayoutCreateInfo::default().set_layouts(&[descriptor_set_layout]), None, ) - .map_err(|(_, e)| e) - .context("Failed to create compute pipeline")?; + .context("Failed to create pipeline layout")?; - let pipeline = pipelines[0]; + // Create compute pipeline + let pipeline = { + let spirv_u32: Vec = spirv_bytes + .chunks_exact(4) + .map(|chunk| u32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]])) + .collect(); + let shader_module = self + .device + .create_shader_module( + &vk::ShaderModuleCreateInfo::default().code(&spirv_u32), + None, + ) + .context("Failed to create shader module")?; + let pipeline = self + .device + .create_compute_pipelines( + vk::PipelineCache::null(), + &[vk::ComputePipelineCreateInfo::default() + .stage( + vk::PipelineShaderStageCreateInfo::default() + .stage(vk::ShaderStageFlags::COMPUTE) + .module(shader_module) + .name(&CString::new(entry_point)?), + ) + .layout(pipeline_layout)], + None, + ) + .map_err(|(_, e)| e) + .context("Failed to create compute pipeline")?[0]; + self.device.destroy_shader_module(shader_module, None); + pipeline + }; // Create buffers let mut vk_buffers = Vec::new(); let mut buffer_memories = Vec::new(); - for buffer_config in &buffers { let (buffer, memory) = self.create_buffer(buffer_config)?; vk_buffers.push(buffer); @@ -282,17 +263,14 @@ impl ComputeBackend for AshBackend { } // Allocate descriptor set - let set_layouts = [descriptor_set_layout]; - let descriptor_set_allocate_info = vk::DescriptorSetAllocateInfo::default() - .descriptor_pool(self.descriptor_pool) - .set_layouts(&set_layouts); - - let descriptor_sets = self + let descriptor_set = self .device - .allocate_descriptor_sets(&descriptor_set_allocate_info) - .context("Failed to allocate descriptor sets")?; - - let descriptor_set = descriptor_sets[0]; + .allocate_descriptor_sets( + &vk::DescriptorSetAllocateInfo::default() + .descriptor_pool(self.descriptor_pool) + .set_layouts(&[descriptor_set_layout]), + ) + .context("Failed to allocate descriptor sets")?[0]; // Update descriptor sets let buffer_infos: Vec = vk_buffers @@ -305,54 +283,41 @@ impl ComputeBackend for AshBackend { .range(config.size) }) .collect(); - let descriptor_writes: Vec> = buffer_infos .iter() .zip(&buffers) .enumerate() .map(|(i, (buffer_info, config))| { - let descriptor_type = match config.usage { - BufferUsage::Storage | BufferUsage::StorageReadOnly => { - vk::DescriptorType::STORAGE_BUFFER - } - BufferUsage::Uniform => vk::DescriptorType::UNIFORM_BUFFER, - }; - vk::WriteDescriptorSet::default() .dst_set(descriptor_set) .dst_binding(i as u32) - .descriptor_type(descriptor_type) + .descriptor_type(buffer_usage_to_descriptor_type(config.usage)) .buffer_info(std::slice::from_ref(buffer_info)) }) .collect(); - self.device.update_descriptor_sets(&descriptor_writes, &[]); // Allocate command buffer - let command_buffer_allocate_info = vk::CommandBufferAllocateInfo::default() - .command_pool(self.command_pool) - .level(vk::CommandBufferLevel::PRIMARY) - .command_buffer_count(1); - - let command_buffers = self + let command_buffer = self .device - .allocate_command_buffers(&command_buffer_allocate_info) - .context("Failed to allocate command buffer")?; - - let command_buffer = command_buffers[0]; - - // Begin command buffer - let begin_info = vk::CommandBufferBeginInfo::default() - .flags(vk::CommandBufferUsageFlags::ONE_TIME_SUBMIT); + .allocate_command_buffers( + &vk::CommandBufferAllocateInfo::default() + .command_pool(self.command_pool) + .level(vk::CommandBufferLevel::PRIMARY) + .command_buffer_count(1), + ) + .context("Failed to allocate command buffer")?[0]; + // Record command buffer self.device - .begin_command_buffer(command_buffer, &begin_info) + .begin_command_buffer( + command_buffer, + &vk::CommandBufferBeginInfo::default() + .flags(vk::CommandBufferUsageFlags::ONE_TIME_SUBMIT), + ) .context("Failed to begin command buffer")?; - - // Bind pipeline and descriptor set self.device .cmd_bind_pipeline(command_buffer, vk::PipelineBindPoint::COMPUTE, pipeline); - self.device.cmd_bind_descriptor_sets( command_buffer, vk::PipelineBindPoint::COMPUTE, @@ -361,22 +326,19 @@ impl ComputeBackend for AshBackend { &[descriptor_set], &[], ); - - // Dispatch compute self.device .cmd_dispatch(command_buffer, dispatch[0], dispatch[1], dispatch[2]); - - // End command buffer self.device .end_command_buffer(command_buffer) .context("Failed to end command buffer")?; // Submit command buffer - let command_buffers = [command_buffer]; - let submit_info = vk::SubmitInfo::default().command_buffers(&command_buffers); - self.device - .queue_submit(self.queue, &[submit_info], vk::Fence::null()) + .queue_submit( + self.queue, + &[vk::SubmitInfo::default().command_buffers(&[command_buffer])], + vk::Fence::null(), + ) .context("Failed to submit queue")?; // Wait for completion @@ -416,7 +378,6 @@ impl ComputeBackend for AshBackend { self.device.destroy_pipeline_layout(pipeline_layout, None); self.device .destroy_descriptor_set_layout(descriptor_set_layout, None); - self.device.destroy_shader_module(shader_module, None); Ok(results) } @@ -434,3 +395,10 @@ impl Drop for AshBackend { } } } + +fn buffer_usage_to_descriptor_type(usage: BufferUsage) -> DescriptorType { + match usage { + BufferUsage::Storage | BufferUsage::StorageReadOnly => DescriptorType::STORAGE_BUFFER, + BufferUsage::Uniform => DescriptorType::UNIFORM_BUFFER, + } +} From 736a5933df3086124c05888e842e69a21bf5e472 Mon Sep 17 00:00:00 2001 From: Firestar99 Date: Wed, 9 Jul 2025 20:08:32 +0200 Subject: [PATCH 16/24] difftest ash: allocate descriptor pool just as required --- .../difftests/lib/src/scaffold/compute/ash.rs | 49 ++++++++++--------- 1 file changed, 25 insertions(+), 24 deletions(-) diff --git a/tests/difftests/lib/src/scaffold/compute/ash.rs b/tests/difftests/lib/src/scaffold/compute/ash.rs index 2e5cdc29d7..e2417aa5c8 100644 --- a/tests/difftests/lib/src/scaffold/compute/ash.rs +++ b/tests/difftests/lib/src/scaffold/compute/ash.rs @@ -9,7 +9,6 @@ pub struct AshBackend { device: ash::Device, queue: vk::Queue, command_pool: vk::CommandPool, - descriptor_pool: vk::DescriptorPool, memory_properties: vk::PhysicalDeviceMemoryProperties, _entry: ash::Entry, } @@ -152,31 +151,11 @@ impl ComputeBackend for AshBackend { ) .context("Failed to create command pool")?; - // Create descriptor pool - let descriptor_pool = device - .create_descriptor_pool( - &vk::DescriptorPoolCreateInfo::default() - .pool_sizes(&vec![ - vk::DescriptorPoolSize { - ty: DescriptorType::STORAGE_BUFFER, - descriptor_count: 16, - }, - vk::DescriptorPoolSize { - ty: DescriptorType::UNIFORM_BUFFER, - descriptor_count: 16, - }, - ]) - .max_sets(16), - None, - ) - .context("Failed to create descriptor pool")?; - Ok(Self { instance, device, queue, command_pool, - descriptor_pool, memory_properties, _entry: entry, }) @@ -263,11 +242,34 @@ impl ComputeBackend for AshBackend { } // Allocate descriptor set + let count_descriptor_types = |desc_type: DescriptorType| vk::DescriptorPoolSize { + ty: desc_type, + descriptor_count: buffers + .iter() + .filter(|buffer| buffer_usage_to_descriptor_type(buffer.usage) == desc_type) + .count() as u32, + }; + let pool_sizes = [ + count_descriptor_types(DescriptorType::STORAGE_BUFFER), + count_descriptor_types(DescriptorType::UNIFORM_BUFFER), + ] + .into_iter() + .filter(|a| a.descriptor_count != 0) + .collect::>(); + let descriptor_pool = self + .device + .create_descriptor_pool( + &vk::DescriptorPoolCreateInfo::default() + .pool_sizes(&pool_sizes) + .max_sets(1), + None, + ) + .context("Failed to create descriptor pool")?; let descriptor_set = self .device .allocate_descriptor_sets( &vk::DescriptorSetAllocateInfo::default() - .descriptor_pool(self.descriptor_pool) + .descriptor_pool(descriptor_pool) .set_layouts(&[descriptor_set_layout]), ) .context("Failed to allocate descriptor sets")?[0]; @@ -374,6 +376,7 @@ impl ComputeBackend for AshBackend { self.device.destroy_buffer(*buffer, None); self.device.free_memory(*memory, None); } + self.device.destroy_descriptor_pool(descriptor_pool, None); self.device.destroy_pipeline(pipeline, None); self.device.destroy_pipeline_layout(pipeline_layout, None); self.device @@ -387,8 +390,6 @@ impl ComputeBackend for AshBackend { impl Drop for AshBackend { fn drop(&mut self) { unsafe { - self.device - .destroy_descriptor_pool(self.descriptor_pool, None); self.device.destroy_command_pool(self.command_pool, None); self.device.destroy_device(None); self.instance.destroy_instance(None); From 581cf4a668615374edd582310977572945a121a8 Mon Sep 17 00:00:00 2001 From: Christian Legnitto Date: Thu, 10 Jul 2025 09:27:59 +0200 Subject: [PATCH 17/24] Rustfmt after rebasing on new rust version --- .../memory_barriers-rust/src/main.rs | 28 +++++++++++-------- .../memory_barriers-wgsl/src/main.rs | 28 +++++++++++-------- .../vector_extract_insert-rust/src/main.rs | 11 ++++---- .../vector_extract_insert-wgsl/src/main.rs | 11 ++++---- .../workgroup_memory-ash/src/main.rs | 5 ++-- .../control_flow_complex-rust/src/main.rs | 10 +++---- .../control_flow_complex-wgsl/src/main.rs | 10 +++---- .../bitwise_ops/bitwise_ops-rust/src/main.rs | 11 ++++---- .../bitwise_ops/bitwise_ops-wgsl/src/main.rs | 11 ++++---- .../ops/trig_ops/trig_ops-rust/src/main.rs | 10 +++---- .../ops/trig_ops/trig_ops-wgsl/src/main.rs | 10 +++---- .../vector_swizzle-rust/src/main.rs | 10 +++---- .../vector_swizzle-wgsl/src/main.rs | 10 +++---- 13 files changed, 84 insertions(+), 81 deletions(-) diff --git a/tests/difftests/tests/arch/memory_barriers/memory_barriers-rust/src/main.rs b/tests/difftests/tests/arch/memory_barriers/memory_barriers-rust/src/main.rs index 3857ec8a89..2be264b5bf 100644 --- a/tests/difftests/tests/arch/memory_barriers/memory_barriers-rust/src/main.rs +++ b/tests/difftests/tests/arch/memory_barriers/memory_barriers-rust/src/main.rs @@ -10,18 +10,22 @@ fn main() { let initial_data: Vec = (0..64).collect(); let initial_bytes: Vec = initial_data.iter().flat_map(|&x| x.to_ne_bytes()).collect(); - let test = WgpuComputeTestMultiBuffer::new(RustComputeShader::default(), [1, 1, 1], vec![ - BufferConfig { - size: buffer_size, - usage: BufferUsage::StorageReadOnly, - initial_data: Some(initial_bytes), - }, - BufferConfig { - size: buffer_size, - usage: BufferUsage::Storage, - initial_data: None, - }, - ]); + let test = WgpuComputeTestMultiBuffer::new( + RustComputeShader::default(), + [1, 1, 1], + vec![ + BufferConfig { + size: buffer_size, + usage: BufferUsage::StorageReadOnly, + initial_data: Some(initial_bytes), + }, + BufferConfig { + size: buffer_size, + usage: BufferUsage::Storage, + initial_data: None, + }, + ], + ); test.run_test(&config).unwrap(); } diff --git a/tests/difftests/tests/arch/memory_barriers/memory_barriers-wgsl/src/main.rs b/tests/difftests/tests/arch/memory_barriers/memory_barriers-wgsl/src/main.rs index d0ba32b0c4..a8a8ce08a2 100644 --- a/tests/difftests/tests/arch/memory_barriers/memory_barriers-wgsl/src/main.rs +++ b/tests/difftests/tests/arch/memory_barriers/memory_barriers-wgsl/src/main.rs @@ -10,18 +10,22 @@ fn main() { let initial_data: Vec = (0..64).collect(); let initial_bytes: Vec = initial_data.iter().flat_map(|&x| x.to_ne_bytes()).collect(); - let test = WgpuComputeTestMultiBuffer::new(WgslComputeShader::default(), [1, 1, 1], vec![ - BufferConfig { - size: buffer_size, - usage: BufferUsage::StorageReadOnly, - initial_data: Some(initial_bytes), - }, - BufferConfig { - size: buffer_size, - usage: BufferUsage::Storage, - initial_data: None, - }, - ]); + let test = WgpuComputeTestMultiBuffer::new( + WgslComputeShader::default(), + [1, 1, 1], + vec![ + BufferConfig { + size: buffer_size, + usage: BufferUsage::StorageReadOnly, + initial_data: Some(initial_bytes), + }, + BufferConfig { + size: buffer_size, + usage: BufferUsage::Storage, + initial_data: None, + }, + ], + ); test.run_test(&config).unwrap(); } diff --git a/tests/difftests/tests/arch/vector_extract_insert/vector_extract_insert-rust/src/main.rs b/tests/difftests/tests/arch/vector_extract_insert/vector_extract_insert-rust/src/main.rs index 89dc766576..452c579463 100644 --- a/tests/difftests/tests/arch/vector_extract_insert/vector_extract_insert-rust/src/main.rs +++ b/tests/difftests/tests/arch/vector_extract_insert/vector_extract_insert-rust/src/main.rs @@ -5,12 +5,11 @@ fn main() { let config = Config::from_path(std::env::args().nth(1).unwrap()).unwrap(); let buffer_size = 1024; - let test = - WgpuComputeTestMultiBuffer::new_with_sizes(RustComputeShader::default(), [64, 1, 1], &[ - buffer_size, - buffer_size, - buffer_size, - ]); + let test = WgpuComputeTestMultiBuffer::new_with_sizes( + RustComputeShader::default(), + [64, 1, 1], + &[buffer_size, buffer_size, buffer_size], + ); test.run_test(&config).unwrap(); } diff --git a/tests/difftests/tests/arch/vector_extract_insert/vector_extract_insert-wgsl/src/main.rs b/tests/difftests/tests/arch/vector_extract_insert/vector_extract_insert-wgsl/src/main.rs index f98ea16c7e..63c8babb8e 100644 --- a/tests/difftests/tests/arch/vector_extract_insert/vector_extract_insert-wgsl/src/main.rs +++ b/tests/difftests/tests/arch/vector_extract_insert/vector_extract_insert-wgsl/src/main.rs @@ -5,12 +5,11 @@ fn main() { let config = Config::from_path(std::env::args().nth(1).unwrap()).unwrap(); let buffer_size = 1024; - let test = - WgpuComputeTestMultiBuffer::new_with_sizes(WgslComputeShader::default(), [64, 1, 1], &[ - buffer_size, - buffer_size, - buffer_size, - ]); + let test = WgpuComputeTestMultiBuffer::new_with_sizes( + WgslComputeShader::default(), + [64, 1, 1], + &[buffer_size, buffer_size, buffer_size], + ); test.run_test(&config).unwrap(); } diff --git a/tests/difftests/tests/arch/workgroup_memory/workgroup_memory-ash/src/main.rs b/tests/difftests/tests/arch/workgroup_memory/workgroup_memory-ash/src/main.rs index 076f67709e..645e8c3e2f 100644 --- a/tests/difftests/tests/arch/workgroup_memory/workgroup_memory-ash/src/main.rs +++ b/tests/difftests/tests/arch/workgroup_memory/workgroup_memory-ash/src/main.rs @@ -8,9 +8,8 @@ fn main() { #[cfg(target_os = "macos")] { use difftest::scaffold::Skip; - - let skip = - Skip::new("Ash tests are skipped on macOS due to MoltenVK configuration issues"); + + let skip = Skip::new("Ash tests are skipped on macOS due to MoltenVK configuration issues"); skip.run_test(&config).unwrap(); return; } diff --git a/tests/difftests/tests/lang/control_flow_complex/control_flow_complex-rust/src/main.rs b/tests/difftests/tests/lang/control_flow_complex/control_flow_complex-rust/src/main.rs index eb96343bcf..83e8abf9bb 100644 --- a/tests/difftests/tests/lang/control_flow_complex/control_flow_complex-rust/src/main.rs +++ b/tests/difftests/tests/lang/control_flow_complex/control_flow_complex-rust/src/main.rs @@ -5,11 +5,11 @@ fn main() { let config = Config::from_path(std::env::args().nth(1).unwrap()).unwrap(); let buffer_size = 1024; - let test = - WgpuComputeTestMultiBuffer::new_with_sizes(RustComputeShader::default(), [64, 1, 1], &[ - buffer_size, - buffer_size, - ]); + let test = WgpuComputeTestMultiBuffer::new_with_sizes( + RustComputeShader::default(), + [64, 1, 1], + &[buffer_size, buffer_size], + ); test.run_test(&config).unwrap(); } diff --git a/tests/difftests/tests/lang/control_flow_complex/control_flow_complex-wgsl/src/main.rs b/tests/difftests/tests/lang/control_flow_complex/control_flow_complex-wgsl/src/main.rs index f21cdecfbf..37d570a504 100644 --- a/tests/difftests/tests/lang/control_flow_complex/control_flow_complex-wgsl/src/main.rs +++ b/tests/difftests/tests/lang/control_flow_complex/control_flow_complex-wgsl/src/main.rs @@ -5,11 +5,11 @@ fn main() { let config = Config::from_path(std::env::args().nth(1).unwrap()).unwrap(); let buffer_size = 1024; - let test = - WgpuComputeTestMultiBuffer::new_with_sizes(WgslComputeShader::default(), [64, 1, 1], &[ - buffer_size, - buffer_size, - ]); + let test = WgpuComputeTestMultiBuffer::new_with_sizes( + WgslComputeShader::default(), + [64, 1, 1], + &[buffer_size, buffer_size], + ); test.run_test(&config).unwrap(); } diff --git a/tests/difftests/tests/lang/core/ops/bitwise_ops/bitwise_ops-rust/src/main.rs b/tests/difftests/tests/lang/core/ops/bitwise_ops/bitwise_ops-rust/src/main.rs index 89dc766576..452c579463 100644 --- a/tests/difftests/tests/lang/core/ops/bitwise_ops/bitwise_ops-rust/src/main.rs +++ b/tests/difftests/tests/lang/core/ops/bitwise_ops/bitwise_ops-rust/src/main.rs @@ -5,12 +5,11 @@ fn main() { let config = Config::from_path(std::env::args().nth(1).unwrap()).unwrap(); let buffer_size = 1024; - let test = - WgpuComputeTestMultiBuffer::new_with_sizes(RustComputeShader::default(), [64, 1, 1], &[ - buffer_size, - buffer_size, - buffer_size, - ]); + let test = WgpuComputeTestMultiBuffer::new_with_sizes( + RustComputeShader::default(), + [64, 1, 1], + &[buffer_size, buffer_size, buffer_size], + ); test.run_test(&config).unwrap(); } diff --git a/tests/difftests/tests/lang/core/ops/bitwise_ops/bitwise_ops-wgsl/src/main.rs b/tests/difftests/tests/lang/core/ops/bitwise_ops/bitwise_ops-wgsl/src/main.rs index f98ea16c7e..63c8babb8e 100644 --- a/tests/difftests/tests/lang/core/ops/bitwise_ops/bitwise_ops-wgsl/src/main.rs +++ b/tests/difftests/tests/lang/core/ops/bitwise_ops/bitwise_ops-wgsl/src/main.rs @@ -5,12 +5,11 @@ fn main() { let config = Config::from_path(std::env::args().nth(1).unwrap()).unwrap(); let buffer_size = 1024; - let test = - WgpuComputeTestMultiBuffer::new_with_sizes(WgslComputeShader::default(), [64, 1, 1], &[ - buffer_size, - buffer_size, - buffer_size, - ]); + let test = WgpuComputeTestMultiBuffer::new_with_sizes( + WgslComputeShader::default(), + [64, 1, 1], + &[buffer_size, buffer_size, buffer_size], + ); test.run_test(&config).unwrap(); } diff --git a/tests/difftests/tests/lang/core/ops/trig_ops/trig_ops-rust/src/main.rs b/tests/difftests/tests/lang/core/ops/trig_ops/trig_ops-rust/src/main.rs index eb96343bcf..83e8abf9bb 100644 --- a/tests/difftests/tests/lang/core/ops/trig_ops/trig_ops-rust/src/main.rs +++ b/tests/difftests/tests/lang/core/ops/trig_ops/trig_ops-rust/src/main.rs @@ -5,11 +5,11 @@ fn main() { let config = Config::from_path(std::env::args().nth(1).unwrap()).unwrap(); let buffer_size = 1024; - let test = - WgpuComputeTestMultiBuffer::new_with_sizes(RustComputeShader::default(), [64, 1, 1], &[ - buffer_size, - buffer_size, - ]); + let test = WgpuComputeTestMultiBuffer::new_with_sizes( + RustComputeShader::default(), + [64, 1, 1], + &[buffer_size, buffer_size], + ); test.run_test(&config).unwrap(); } diff --git a/tests/difftests/tests/lang/core/ops/trig_ops/trig_ops-wgsl/src/main.rs b/tests/difftests/tests/lang/core/ops/trig_ops/trig_ops-wgsl/src/main.rs index f21cdecfbf..37d570a504 100644 --- a/tests/difftests/tests/lang/core/ops/trig_ops/trig_ops-wgsl/src/main.rs +++ b/tests/difftests/tests/lang/core/ops/trig_ops/trig_ops-wgsl/src/main.rs @@ -5,11 +5,11 @@ fn main() { let config = Config::from_path(std::env::args().nth(1).unwrap()).unwrap(); let buffer_size = 1024; - let test = - WgpuComputeTestMultiBuffer::new_with_sizes(WgslComputeShader::default(), [64, 1, 1], &[ - buffer_size, - buffer_size, - ]); + let test = WgpuComputeTestMultiBuffer::new_with_sizes( + WgslComputeShader::default(), + [64, 1, 1], + &[buffer_size, buffer_size], + ); test.run_test(&config).unwrap(); } diff --git a/tests/difftests/tests/lang/core/ops/vector_swizzle/vector_swizzle-rust/src/main.rs b/tests/difftests/tests/lang/core/ops/vector_swizzle/vector_swizzle-rust/src/main.rs index eb96343bcf..83e8abf9bb 100644 --- a/tests/difftests/tests/lang/core/ops/vector_swizzle/vector_swizzle-rust/src/main.rs +++ b/tests/difftests/tests/lang/core/ops/vector_swizzle/vector_swizzle-rust/src/main.rs @@ -5,11 +5,11 @@ fn main() { let config = Config::from_path(std::env::args().nth(1).unwrap()).unwrap(); let buffer_size = 1024; - let test = - WgpuComputeTestMultiBuffer::new_with_sizes(RustComputeShader::default(), [64, 1, 1], &[ - buffer_size, - buffer_size, - ]); + let test = WgpuComputeTestMultiBuffer::new_with_sizes( + RustComputeShader::default(), + [64, 1, 1], + &[buffer_size, buffer_size], + ); test.run_test(&config).unwrap(); } diff --git a/tests/difftests/tests/lang/core/ops/vector_swizzle/vector_swizzle-wgsl/src/main.rs b/tests/difftests/tests/lang/core/ops/vector_swizzle/vector_swizzle-wgsl/src/main.rs index f21cdecfbf..37d570a504 100644 --- a/tests/difftests/tests/lang/core/ops/vector_swizzle/vector_swizzle-wgsl/src/main.rs +++ b/tests/difftests/tests/lang/core/ops/vector_swizzle/vector_swizzle-wgsl/src/main.rs @@ -5,11 +5,11 @@ fn main() { let config = Config::from_path(std::env::args().nth(1).unwrap()).unwrap(); let buffer_size = 1024; - let test = - WgpuComputeTestMultiBuffer::new_with_sizes(WgslComputeShader::default(), [64, 1, 1], &[ - buffer_size, - buffer_size, - ]); + let test = WgpuComputeTestMultiBuffer::new_with_sizes( + WgslComputeShader::default(), + [64, 1, 1], + &[buffer_size, buffer_size], + ); test.run_test(&config).unwrap(); } From 0dcbc9c1aa63ec482c5fe9a5542bb8b97d444ac5 Mon Sep 17 00:00:00 2001 From: Christian Legnitto Date: Sat, 12 Jul 2025 17:02:44 +0200 Subject: [PATCH 18/24] Make ash backend use ComputeShader trait to reduce code duplication --- .../lib/src/scaffold/compute/backend.rs | 55 +++++++++++++++- .../difftests/lib/src/scaffold/compute/mod.rs | 6 +- .../lib/src/scaffold/compute/wgpu.rs | 65 +++++++++++++++---- .../workgroup_memory-ash/src/main.rs | 40 ++---------- 4 files changed, 116 insertions(+), 50 deletions(-) diff --git a/tests/difftests/lib/src/scaffold/compute/backend.rs b/tests/difftests/lib/src/scaffold/compute/backend.rs index 5099290615..5696f965b9 100644 --- a/tests/difftests/lib/src/scaffold/compute/backend.rs +++ b/tests/difftests/lib/src/scaffold/compute/backend.rs @@ -17,12 +17,14 @@ pub enum BufferUsage { Uniform, } +use super::SpirvShader; + /// A generic trait for compute backends pub trait ComputeBackend: Sized { /// Initialize the backend fn init() -> Result; - /// Create and run a compute shader with multiple buffers + /// Create and run a compute shader with multiple buffers from raw SPIRV bytes fn run_compute( &self, spirv_bytes: &[u8], @@ -30,6 +32,17 @@ pub trait ComputeBackend: Sized { dispatch: [u32; 3], buffers: Vec, ) -> Result>>; + + /// Create and run a compute shader with multiple buffers from a shader object + fn run_compute_shader( + &self, + shader: &S, + dispatch: [u32; 3], + buffers: Vec, + ) -> Result>> { + let (spirv_bytes, entry_point) = shader.spirv_bytes()?; + self.run_compute(&spirv_bytes, &entry_point, dispatch, buffers) + } } /// A compute test that can run on any backend @@ -82,3 +95,43 @@ impl ComputeTest { anyhow::bail!("No storage buffer output found") } } + +/// A compute test that can run on any backend using a shader object +pub struct ComputeShaderTest { + backend: B, + shader: S, + dispatch: [u32; 3], + buffers: Vec, +} + +impl ComputeShaderTest { + pub fn new(shader: S, dispatch: [u32; 3], buffers: Vec) -> Result { + Ok(Self { + backend: B::init()?, + shader, + dispatch, + buffers, + }) + } + + pub fn run(self) -> Result>> { + self.backend + .run_compute_shader(&self.shader, self.dispatch, self.buffers) + } + + pub fn run_test(self, config: &Config) -> Result<()> { + let buffers = self.buffers.clone(); + let outputs = self.run()?; + // Write the first storage buffer output to the file + for (output, buffer_config) in outputs.iter().zip(&buffers) { + if matches!(buffer_config.usage, BufferUsage::Storage) && !output.is_empty() { + use std::fs::File; + use std::io::Write; + let mut f = File::create(&config.output_path)?; + f.write_all(output)?; + return Ok(()); + } + } + anyhow::bail!("No storage buffer output found") + } +} diff --git a/tests/difftests/lib/src/scaffold/compute/mod.rs b/tests/difftests/lib/src/scaffold/compute/mod.rs index 16d6d63e51..edb0daad2f 100644 --- a/tests/difftests/lib/src/scaffold/compute/mod.rs +++ b/tests/difftests/lib/src/scaffold/compute/mod.rs @@ -3,8 +3,8 @@ mod backend; mod wgpu; pub use ash::AshBackend; -pub use backend::{BufferConfig, BufferUsage, ComputeBackend, ComputeTest}; +pub use backend::{BufferConfig, BufferUsage, ComputeBackend, ComputeShaderTest, ComputeTest}; pub use wgpu::{ - RustComputeShader, WgpuBackend, WgpuComputeTest, WgpuComputeTestMultiBuffer, - WgpuComputeTestPushConstants, WgslComputeShader, + RustComputeShader, SpirvShader, WgpuBackend, WgpuComputeTest, WgpuComputeTestMultiBuffer, + WgpuComputeTestPushConstants, WgpuShader, WgslComputeShader, }; diff --git a/tests/difftests/lib/src/scaffold/compute/wgpu.rs b/tests/difftests/lib/src/scaffold/compute/wgpu.rs index b92f3a7fa7..c16380fbb5 100644 --- a/tests/difftests/lib/src/scaffold/compute/wgpu.rs +++ b/tests/difftests/lib/src/scaffold/compute/wgpu.rs @@ -18,8 +18,15 @@ use super::backend::{self, ComputeBackend}; pub type BufferConfig = backend::BufferConfig; pub type BufferUsage = backend::BufferUsage; -/// Trait that creates a shader module and provides its entry point. -pub trait ComputeShader { +/// Trait for shaders that can provide SPIRV bytes. +pub trait SpirvShader { + /// Returns the SPIRV bytes and entry point name. + fn spirv_bytes(&self) -> anyhow::Result<(Vec, String)>; +} + +/// Trait for shaders that can create wgpu modules. +pub trait WgpuShader { + /// Creates a wgpu shader module. fn create_module( &self, device: &wgpu::Device, @@ -29,25 +36,46 @@ pub trait ComputeShader { /// A compute shader written in Rust compiled with spirv-builder. pub struct RustComputeShader { pub path: PathBuf, + pub target: String, + pub capabilities: Vec, } impl RustComputeShader { pub fn new>(path: P) -> Self { - Self { path: path.into() } + Self { + path: path.into(), + target: "spirv-unknown-vulkan1.1".to_string(), + capabilities: Vec::new(), + } + } + + pub fn with_target>(path: P, target: impl Into) -> Self { + Self { + path: path.into(), + target: target.into(), + capabilities: Vec::new(), + } + } + + pub fn with_capability(mut self, capability: spirv_builder::Capability) -> Self { + self.capabilities.push(capability); + self } } -impl ComputeShader for RustComputeShader { - fn create_module( - &self, - device: &wgpu::Device, - ) -> anyhow::Result<(wgpu::ShaderModule, Option)> { - let builder = SpirvBuilder::new(&self.path, "spirv-unknown-vulkan1.1") +impl SpirvShader for RustComputeShader { + fn spirv_bytes(&self) -> anyhow::Result<(Vec, String)> { + let mut builder = SpirvBuilder::new(&self.path, &self.target) .print_metadata(spirv_builder::MetadataPrintout::None) .release(true) .multimodule(false) .shader_panic_strategy(spirv_builder::ShaderPanicStrategy::SilentExit) .preserve_bindings(true); + + for capability in &self.capabilities { + builder = builder.capability(*capability); + } + let artifact = builder.build().context("SpirvBuilder::build() failed")?; if artifact.entry_points.len() != 1 { @@ -66,6 +94,17 @@ impl ComputeShader for RustComputeShader { } }; + Ok((shader_bytes, entry_point)) + } +} + +impl WgpuShader for RustComputeShader { + fn create_module( + &self, + device: &wgpu::Device, + ) -> anyhow::Result<(wgpu::ShaderModule, Option)> { + let (shader_bytes, entry_point) = self.spirv_bytes()?; + if shader_bytes.len() % 4 != 0 { anyhow::bail!("SPIR-V binary length is not a multiple of 4"); } @@ -93,7 +132,7 @@ impl WgslComputeShader { } } -impl ComputeShader for WgslComputeShader { +impl WgpuShader for WgslComputeShader { fn create_module( &self, device: &wgpu::Device, @@ -133,7 +172,7 @@ pub struct WgpuComputeTestPushConstants { impl WgpuComputeTest where - S: ComputeShader, + S: WgpuShader, { pub fn new(shader: S, dispatch: [u32; 3], output_bytes: u64) -> Self { Self { @@ -544,7 +583,7 @@ impl Default for RustComputeShader { impl WgpuComputeTestMultiBuffer where - S: ComputeShader, + S: WgpuShader, { pub fn new(shader: S, dispatch: [u32; 3], buffers: Vec) -> Self { Self { @@ -714,7 +753,7 @@ where impl WgpuComputeTestPushConstants where - S: ComputeShader, + S: WgpuShader, { pub fn new( shader: S, diff --git a/tests/difftests/tests/arch/workgroup_memory/workgroup_memory-ash/src/main.rs b/tests/difftests/tests/arch/workgroup_memory/workgroup_memory-ash/src/main.rs index 645e8c3e2f..bfbad19336 100644 --- a/tests/difftests/tests/arch/workgroup_memory/workgroup_memory-ash/src/main.rs +++ b/tests/difftests/tests/arch/workgroup_memory/workgroup_memory-ash/src/main.rs @@ -17,37 +17,10 @@ fn main() { // Run the actual test on other platforms #[cfg(not(target_os = "macos"))] { - use difftest::scaffold::compute::{AshBackend, BufferConfig, BufferUsage, ComputeTest}; - use difftest::spirv_builder::{ - Capability, MetadataPrintout, ModuleResult, ShaderPanicStrategy, SpirvBuilder, - }; - use std::fs; - - // Build the Rust shader to SPIR-V - let builder = SpirvBuilder::new(".", "spirv-unknown-vulkan1.2") - .print_metadata(MetadataPrintout::None) - .release(true) - .multimodule(false) - .shader_panic_strategy(ShaderPanicStrategy::SilentExit) - .preserve_bindings(true) - .capability(Capability::VulkanMemoryModel); - - let artifact = builder.build().expect("Failed to build SPIR-V"); - - if artifact.entry_points.len() != 1 { - panic!( - "Expected exactly one entry point, found {}", - artifact.entry_points.len() - ); - } - let entry_point = artifact.entry_points.into_iter().next().unwrap(); - - let spirv_bytes = match artifact.module { - ModuleResult::SingleModule(path) => { - fs::read(&path).expect("Failed to read SPIR-V file") - } - ModuleResult::MultiModule(_) => panic!("Unexpected multi-module result"), + use difftest::scaffold::compute::{ + AshBackend, BufferConfig, BufferUsage, ComputeShaderTest, RustComputeShader, }; + use difftest::spirv_builder::Capability; // Initialize input buffer with values to sum let input_data: Vec = (1..=64).collect(); @@ -66,9 +39,10 @@ fn main() { }, ]; - let test = ComputeTest::::new( - spirv_bytes, - entry_point, + let shader = RustComputeShader::with_target(".", "spirv-unknown-vulkan1.2") + .with_capability(Capability::VulkanMemoryModel); + let test = ComputeShaderTest::::new( + shader, [1, 1, 1], // Single workgroup with 64 threads buffers, ) From 78930dd4b4f0a6b2938dd528f7fc56858f4405c7 Mon Sep 17 00:00:00 2001 From: Christian Legnitto Date: Sat, 12 Jul 2025 23:08:12 +0200 Subject: [PATCH 19/24] Fix clippy errors after rebasing --- tests/difftests/bin/src/differ.rs | 56 ++++++------- tests/difftests/bin/src/runner.rs | 79 +++++++++++++------ .../difftests/lib/src/scaffold/compute/ash.rs | 25 +++--- .../lib/src/scaffold/compute/wgpu.rs | 28 +++---- 4 files changed, 105 insertions(+), 83 deletions(-) diff --git a/tests/difftests/bin/src/differ.rs b/tests/difftests/bin/src/differ.rs index 5c84b4ceee..6b34b91e8a 100644 --- a/tests/difftests/bin/src/differ.rs +++ b/tests/difftests/bin/src/differ.rs @@ -1,3 +1,5 @@ +#![allow(clippy::unimplemented)] + use difftest::config::OutputType; use std::marker::PhantomData; @@ -67,8 +69,8 @@ impl OutputDiffer for RawDiffer { (Some(&b1), Some(&b2)) if b1 != b2 => { differences.push(Difference { index: i, - value1: format!("{}", b1), - value2: format!("{}", b2), + value1: format!("{b1}"), + value2: format!("{b2}"), absolute_diff: DiffMagnitude::Incomparable, relative_diff: DiffMagnitude::Incomparable, }); @@ -76,7 +78,7 @@ impl OutputDiffer for RawDiffer { (Some(&b1), None) => { differences.push(Difference { index: i, - value1: format!("{}", b1), + value1: format!("{b1}"), value2: "".to_string(), absolute_diff: DiffMagnitude::Incomparable, relative_diff: DiffMagnitude::Incomparable, @@ -86,7 +88,7 @@ impl OutputDiffer for RawDiffer { differences.push(Difference { index: i, value1: "".to_string(), - value2: format!("{}", b2), + value2: format!("{b2}"), absolute_diff: DiffMagnitude::Incomparable, relative_diff: DiffMagnitude::Incomparable, }); @@ -129,8 +131,8 @@ impl DifferenceDisplay for RawDiffer { }; ( format!("{:>3}", format!("{:02x}", byte)), - format!("{:3}", byte), - format!("{:^5}", ascii), + format!("{byte:3}"), + format!("{ascii:^5}"), ) }; @@ -151,8 +153,8 @@ impl DifferenceDisplay for RawDiffer { }; ( format!("{:>3}", format!("{:02x}", byte)), - format!("{:3}", byte), - format!("{:^5}", ascii), + format!("{byte:3}"), + format!("{ascii:^5}"), ) }; @@ -191,11 +193,7 @@ impl DifferenceDisplay for RawDiffer { let mut result = table.to_string(); if diffs.len() > 10 { - let last_line_width = result - .lines() - .last() - .map(|l| l.chars().count()) - .unwrap_or(0); + let last_line_width = result.lines().last().map_or(0, |l| l.chars().count()); result.push_str(&format!( "\n{:>width$}", format!("... {} more differences", diffs.len() - 10), @@ -226,7 +224,7 @@ impl DifferenceDisplay for RawDiffer { for (i, chunk) in output.chunks(16).enumerate() { write!(file, "{:08x}: ", i * 16)?; for byte in chunk { - write!(file, "{:02x} ", byte)?; + write!(file, "{byte:02x} ")?; } writeln!(file)?; } @@ -255,7 +253,7 @@ impl NumericType for f32 { "F32" } fn format_value(value: Self) -> String { - format!("{:.9}", value) + format!("{value:.9}") } fn can_have_relative_diff() -> bool { true @@ -280,7 +278,7 @@ impl NumericType for u32 { "U32" } fn format_value(value: Self) -> String { - format!("{}", value) + format!("{value}") } fn can_have_relative_diff() -> bool { true @@ -379,7 +377,7 @@ impl DifferenceDisplay for NumericDiffer { let abs_str = match &d.absolute_diff { DiffMagnitude::Numeric(val) => { if T::can_have_relative_diff() { - format!("{:.3e}", val) + format!("{val:.3e}") } else { format!("{}", *val as u64) } @@ -431,11 +429,7 @@ impl DifferenceDisplay for NumericDiffer { let mut result = table.to_string(); if diffs.len() > 10 { - let last_line_width = result - .lines() - .last() - .map(|l| l.chars().count()) - .unwrap_or(0); + let last_line_width = result.lines().last().map_or(0, |l| l.chars().count()); result.push_str(&format!( "\n{:>width$}", format!("... {} more differences", diffs.len() - 10), @@ -482,9 +476,9 @@ impl From for Box { match output_type { OutputType::Raw => Box::new(RawDiffer), OutputType::F32 => Box::new(F32Differ::default()), - OutputType::F64 => todo!("F64Differ not implemented yet"), + OutputType::F64 => unimplemented!("F64Differ not implemented yet"), OutputType::U32 => Box::new(U32Differ::default()), - OutputType::I32 => todo!("I32Differ not implemented yet"), + OutputType::I32 => unimplemented!("I32Differ not implemented yet"), } } } @@ -494,9 +488,9 @@ impl From for Box { match output_type { OutputType::Raw => Box::new(RawDiffer), OutputType::F32 => Box::new(F32Differ::default()), - OutputType::F64 => todo!("F64Display not implemented yet"), + OutputType::F64 => unimplemented!("F64Differ not implemented yet"), OutputType::U32 => Box::new(U32Differ::default()), - OutputType::I32 => todo!("I32Display not implemented yet"), + OutputType::I32 => unimplemented!("I32Differ not implemented yet"), } } } @@ -533,11 +527,11 @@ mod tests { assert_eq!(diffs[0].value2, "5"); match &diffs[0].absolute_diff { DiffMagnitude::Numeric(val) => assert_eq!(*val, 3.0), - _ => panic!("Expected numeric difference"), + DiffMagnitude::Incomparable => panic!("Expected numeric difference"), } match &diffs[0].relative_diff { DiffMagnitude::Numeric(val) => assert_eq!(*val, 3.0 / 5.0), // 3/5 = 0.6 - _ => panic!("Expected numeric relative diff for U32"), + DiffMagnitude::Incomparable => panic!("Expected numeric relative diff for U32"), } // Check second difference (index 3: 4 vs 7) @@ -546,7 +540,7 @@ mod tests { assert_eq!(diffs[1].value2, "7"); match &diffs[1].absolute_diff { DiffMagnitude::Numeric(val) => assert_eq!(*val, 3.0), - _ => panic!("Expected numeric difference"), + DiffMagnitude::Incomparable => panic!("Expected numeric difference"), } } @@ -665,12 +659,12 @@ mod tests { match numeric { DiffMagnitude::Numeric(val) => assert_eq!(val, 42.0), - _ => panic!("Expected numeric"), + DiffMagnitude::Incomparable => panic!("Expected numeric"), } match incomparable { DiffMagnitude::Incomparable => {} - _ => panic!("Expected incomparable"), + DiffMagnitude::Numeric(_) => panic!("Expected incomparable"), } } } diff --git a/tests/difftests/bin/src/runner.rs b/tests/difftests/bin/src/runner.rs index 0abba0cdac..13738edcd2 100644 --- a/tests/difftests/bin/src/runner.rs +++ b/tests/difftests/bin/src/runner.rs @@ -91,7 +91,7 @@ struct ErrorReport { impl ErrorReport { fn new(test_info: TestInfo, differ_name: &'static str, epsilon: Option) -> Self { let epsilon_str = match epsilon { - Some(e) => format!(", ε={}", e), + Some(e) => format!(", ε={e}"), None => String::new(), }; Self { @@ -106,7 +106,7 @@ impl ErrorReport { self.summary_parts .push(format!("{} differences", differences.len())); - if let Some((max_diff, max_rel)) = self.calculate_max_differences(differences) { + if let Some((max_diff, max_rel)) = Self::calculate_max_differences(differences) { self.summary_parts .push(format!("max: {:.3e} ({:.2}%)", max_diff, max_rel * 100.0)); } @@ -114,8 +114,7 @@ impl ErrorReport { } fn set_distinct_outputs(&mut self, count: usize) { - self.summary_parts - .push(format!("{} distinct outputs", count)); + self.summary_parts.push(format!("{count} distinct outputs")); } fn add_output_files( @@ -159,7 +158,7 @@ impl ErrorReport { fn add_summary_line(&mut self, differences: &[Difference]) { if !differences.is_empty() { - if let Some((max_diff, max_rel)) = self.calculate_max_differences(differences) { + if let Some((max_diff, max_rel)) = Self::calculate_max_differences(differences) { self.lines.push(format!( "• {} differences, max: {:.3e} ({:.2}%)", differences.len(), @@ -173,7 +172,7 @@ impl ErrorReport { } } - fn calculate_max_differences(&self, differences: &[Difference]) -> Option<(f64, f64)> { + fn calculate_max_differences(differences: &[Difference]) -> Option<(f64, f64)> { let max_diff = differences .iter() .filter_map(|d| match &d.absolute_diff { @@ -369,8 +368,7 @@ impl Runner { error!("Failed to parse metadata for package '{}'", pkg_name); return Err(RunnerError::Config { msg: format!( - "Failed to parse metadata for package '{}': {}", - pkg_name, e + "Failed to parse metadata for package '{pkg_name}': {e}" ), }); } @@ -472,7 +470,7 @@ impl Runner { let mut found_group = false; for (group_output, group) in groups.iter_mut() { - if self.outputs_match(&po.output, group_output, epsilon, output_type) { + if Self::outputs_match(&po.output, group_output, epsilon, output_type) { group.push(po); found_group = true; break; @@ -488,7 +486,6 @@ impl Runner { } fn outputs_match( - &self, output1: &[u8], output2: &[u8], epsilon: Option, @@ -862,19 +859,25 @@ mod tests { #[test] fn test_outputs_match_no_epsilon() { - let runner = Runner::new(PathBuf::from("dummy_base")); - // Exact match should work - assert!(runner.outputs_match(b"hello", b"hello", None, OutputType::Raw)); + assert!(Runner::outputs_match( + b"hello", + b"hello", + None, + OutputType::Raw + )); // Different content should not match - assert!(!runner.outputs_match(b"hello", b"world", None, OutputType::Raw)); + assert!(!Runner::outputs_match( + b"hello", + b"world", + None, + OutputType::Raw + )); } #[test] fn test_outputs_match_with_epsilon_f32() { - let runner = Runner::new(PathBuf::from("dummy_base")); - // Prepare test data - two floats with small difference let val1: f32 = 1.0; let val2: f32 = 1.00001; @@ -884,19 +887,32 @@ mod tests { let bytes2 = bytemuck::cast_slice(&arr2); // Should not match without epsilon - assert!(!runner.outputs_match(bytes1, bytes2, None, OutputType::F32)); + assert!(!Runner::outputs_match( + bytes1, + bytes2, + None, + OutputType::F32 + )); // Should match with sufficient epsilon - assert!(runner.outputs_match(bytes1, bytes2, Some(0.0001), OutputType::F32)); + assert!(Runner::outputs_match( + bytes1, + bytes2, + Some(0.0001), + OutputType::F32 + )); // Should not match with too small epsilon - assert!(!runner.outputs_match(bytes1, bytes2, Some(0.000001), OutputType::F32)); + assert!(!Runner::outputs_match( + bytes1, + bytes2, + Some(0.000001), + OutputType::F32 + )); } #[test] fn test_outputs_match_with_epsilon_f64() { - let runner = Runner::new(PathBuf::from("dummy_base")); - // Prepare test data - two doubles with small difference let val1: f64 = 1.0; let val2: f64 = 1.00001; @@ -906,13 +922,28 @@ mod tests { let bytes2 = bytemuck::cast_slice(&arr2); // Should not match without epsilon - assert!(!runner.outputs_match(bytes1, bytes2, None, OutputType::F64)); + assert!(!Runner::outputs_match( + bytes1, + bytes2, + None, + OutputType::F64 + )); // Should match with sufficient epsilon - assert!(runner.outputs_match(bytes1, bytes2, Some(0.0001), OutputType::F64)); + assert!(Runner::outputs_match( + bytes1, + bytes2, + Some(0.0001), + OutputType::F64 + )); // Should not match with too small epsilon - assert!(!runner.outputs_match(bytes1, bytes2, Some(0.000001), OutputType::F64)); + assert!(!Runner::outputs_match( + bytes1, + bytes2, + Some(0.000001), + OutputType::F64 + )); } #[test] diff --git a/tests/difftests/lib/src/scaffold/compute/ash.rs b/tests/difftests/lib/src/scaffold/compute/ash.rs index e2417aa5c8..0977e98a24 100644 --- a/tests/difftests/lib/src/scaffold/compute/ash.rs +++ b/tests/difftests/lib/src/scaffold/compute/ash.rs @@ -2,7 +2,7 @@ use super::backend::{BufferConfig, BufferUsage, ComputeBackend}; use anyhow::{Context, Result}; use ash::vk; use ash::vk::DescriptorType; -use std::ffi::{CStr, CString}; +use std::ffi::CString; pub struct AshBackend { instance: ash::Instance, @@ -19,23 +19,20 @@ impl AshBackend { type_filter: u32, properties: vk::MemoryPropertyFlags, ) -> Option { - for i in 0..self.memory_properties.memory_type_count { - if (type_filter & (1 << i)) != 0 + (0..self.memory_properties.memory_type_count).find(|&i| { + (type_filter & (1 << i)) != 0 && self.memory_properties.memory_types[i as usize] .property_flags .contains(properties) - { - return Some(i); - } - } - None + }) } fn create_buffer(&self, config: &BufferConfig) -> Result<(vk::Buffer, vk::DeviceMemory)> { unsafe { let usage = match config.usage { - BufferUsage::Storage => vk::BufferUsageFlags::STORAGE_BUFFER, - BufferUsage::StorageReadOnly => vk::BufferUsageFlags::STORAGE_BUFFER, + BufferUsage::Storage | BufferUsage::StorageReadOnly => { + vk::BufferUsageFlags::STORAGE_BUFFER + } BufferUsage::Uniform => vk::BufferUsageFlags::UNIFORM_BUFFER, }; let buffer = self @@ -80,7 +77,7 @@ impl AshBackend { .map_memory(memory, 0, config.size, vk::MemoryMapFlags::empty()) .context("Failed to map memory")?; - std::ptr::copy_nonoverlapping(data.as_ptr(), mapped_ptr as *mut u8, data.len()); + std::ptr::copy_nonoverlapping(data.as_ptr(), mapped_ptr.cast::(), data.len()); self.device.unmap_memory(memory); } @@ -100,9 +97,9 @@ impl ComputeBackend for AshBackend { .create_instance( &vk::InstanceCreateInfo::default().application_info( &vk::ApplicationInfo::default() - .application_name(CStr::from_bytes_with_nul_unchecked(b"difftest\0")) + .application_name(c"difftest") .application_version(vk::make_api_version(0, 1, 0, 0)) - .engine_name(CStr::from_bytes_with_nul_unchecked(b"difftest\0")) + .engine_name(c"difftest") .engine_version(vk::make_api_version(0, 1, 0, 0)) .api_version(vk::API_VERSION_1_2), ), @@ -350,7 +347,7 @@ impl ComputeBackend for AshBackend { // Read buffer results let mut results = Vec::new(); - for (_i, (memory, config)) in buffer_memories.iter().zip(&buffers).enumerate() { + for (memory, config) in buffer_memories.iter().zip(&buffers) { let mut data = vec![0u8; config.size as usize]; let mapped_ptr = self diff --git a/tests/difftests/lib/src/scaffold/compute/wgpu.rs b/tests/difftests/lib/src/scaffold/compute/wgpu.rs index c16380fbb5..56a48d6e47 100644 --- a/tests/difftests/lib/src/scaffold/compute/wgpu.rs +++ b/tests/difftests/lib/src/scaffold/compute/wgpu.rs @@ -369,7 +369,7 @@ where } } -/// wgpu backend implementation for the generic ComputeBackend trait +/// wgpu backend implementation for the generic `ComputeBackend` trait pub struct WgpuBackend { device: Arc, queue: Arc, @@ -431,13 +431,13 @@ impl ComputeBackend for WgpuBackend { let buffer = if let Some(initial_data) = &buffer_config.initial_data { self.device .create_buffer_init(&wgpu::util::BufferInitDescriptor { - label: Some(&format!("Buffer {}", i)), + label: Some(&format!("Buffer {i}")), contents: initial_data, usage, }) } else { let buffer = self.device.create_buffer(&wgpu::BufferDescriptor { - label: Some(&format!("Buffer {}", i)), + label: Some(&format!("Buffer {i}")), size: buffer_config.size, usage, mapped_at_creation: true, @@ -494,7 +494,7 @@ impl ComputeBackend for WgpuBackend { BufferUsage::Storage | BufferUsage::StorageReadOnly ) { let staging_buffer = self.device.create_buffer(&wgpu::BufferDescriptor { - label: Some(&format!("Staging Buffer {}", i)), + label: Some(&format!("Staging Buffer {i}")), size: buffer_config.size, usage: wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST, mapped_at_creation: false, @@ -516,7 +516,7 @@ impl ComputeBackend for WgpuBackend { // Read back results let mut results = Vec::new(); - for staging_buffer in staging_buffers.into_iter() { + for staging_buffer in staging_buffers { if let Some(buffer) = staging_buffer { let buffer_slice = buffer.slice(..); let (sender, receiver) = futures::channel::oneshot::channel(); @@ -631,13 +631,13 @@ where let buffer = if let Some(initial_data) = &buffer_config.initial_data { device.create_buffer_init(&wgpu::util::BufferInitDescriptor { - label: Some(&format!("Buffer {}", i)), + label: Some(&format!("Buffer {i}")), contents: initial_data, usage, }) } else { let buffer = device.create_buffer(&wgpu::BufferDescriptor { - label: Some(&format!("Buffer {}", i)), + label: Some(&format!("Buffer {i}")), size: buffer_config.size, usage, mapped_at_creation: true, @@ -692,7 +692,7 @@ where BufferUsage::Storage | BufferUsage::StorageReadOnly ) { let staging_buffer = device.create_buffer(&wgpu::BufferDescriptor { - label: Some(&format!("Staging Buffer {}", i)), + label: Some(&format!("Staging Buffer {i}")), size: buffer_config.size, usage: wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST, mapped_at_creation: false, @@ -714,7 +714,7 @@ where // Read back results. let mut results = Vec::new(); - for staging_buffer in staging_buffers.into_iter() { + for staging_buffer in staging_buffers { if let Some(buffer) = staging_buffer { let buffer_slice = buffer.slice(..); let (sender, receiver) = futures::channel::oneshot::channel(); @@ -840,13 +840,13 @@ where let buffer = if let Some(initial_data) = &buffer_config.initial_data { device.create_buffer_init(&wgpu::util::BufferInitDescriptor { - label: Some(&format!("Buffer {}", i)), + label: Some(&format!("Buffer {i}")), contents: initial_data, usage, }) } else { let buffer = device.create_buffer(&wgpu::BufferDescriptor { - label: Some(&format!("Buffer {}", i)), + label: Some(&format!("Buffer {i}")), size: buffer_config.size, usage, mapped_at_creation: true, @@ -902,7 +902,7 @@ where BufferUsage::Storage | BufferUsage::StorageReadOnly ) { let staging_buffer = device.create_buffer(&wgpu::BufferDescriptor { - label: Some(&format!("Staging Buffer {}", i)), + label: Some(&format!("Staging Buffer {i}")), size: buffer_config.size, usage: wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST, mapped_at_creation: false, @@ -924,7 +924,7 @@ where // Read back results. let mut results = Vec::new(); - for staging_buffer in staging_buffers.into_iter() { + for staging_buffer in staging_buffers { if let Some(buffer) = staging_buffer { let buffer_slice = buffer.slice(..); let (sender, receiver) = futures::channel::oneshot::channel(); @@ -950,7 +950,7 @@ where let buffers = self.buffers.clone(); let results = self.run()?; // Write first storage buffer output to file. - for (_i, (data, buffer_config)) in results.iter().zip(&buffers).enumerate() { + for (data, buffer_config) in results.iter().zip(&buffers) { if buffer_config.usage == BufferUsage::Storage && !data.is_empty() { let mut f = File::create(&config.output_path)?; f.write_all(data)?; From 12ff9a36b57232a7f53a8f9f4f2d70d3830cecf1 Mon Sep 17 00:00:00 2001 From: Christian Legnitto Date: Sun, 13 Jul 2025 00:52:55 +0200 Subject: [PATCH 20/24] Cleanup / make more consistent --- .../atomic_ops/atomic_ops-rust/src/main.rs | 14 +++----- .../memory_barriers-rust/src/lib.rs | 35 +++++++++++-------- .../vector_swizzle-rust/src/lib.rs | 1 - 3 files changed, 25 insertions(+), 25 deletions(-) diff --git a/tests/difftests/tests/arch/atomic_ops/atomic_ops-rust/src/main.rs b/tests/difftests/tests/arch/atomic_ops/atomic_ops-rust/src/main.rs index bcef4ca050..934f54ea3c 100644 --- a/tests/difftests/tests/arch/atomic_ops/atomic_ops-rust/src/main.rs +++ b/tests/difftests/tests/arch/atomic_ops/atomic_ops-rust/src/main.rs @@ -1,10 +1,9 @@ -#[cfg(not(target_arch = "spirv"))] -fn main() { - use difftest::config::Config; - use difftest::scaffold::compute::{ - BufferConfig, BufferUsage, RustComputeShader, WgpuComputeTestMultiBuffer, - }; +use difftest::config::Config; +use difftest::scaffold::compute::{ + BufferConfig, BufferUsage, RustComputeShader, WgpuComputeTestMultiBuffer, +}; +fn main() { let config = Config::from_path(std::env::args().nth(1).unwrap()).unwrap(); // Initialize counter buffer with test values @@ -32,6 +31,3 @@ fn main() { test.run_test(&config).unwrap(); } - -#[cfg(target_arch = "spirv")] -fn main() {} diff --git a/tests/difftests/tests/arch/memory_barriers/memory_barriers-rust/src/lib.rs b/tests/difftests/tests/arch/memory_barriers/memory_barriers-rust/src/lib.rs index c1db9678d5..63d75c3aba 100644 --- a/tests/difftests/tests/arch/memory_barriers/memory_barriers-rust/src/lib.rs +++ b/tests/difftests/tests/arch/memory_barriers/memory_barriers-rust/src/lib.rs @@ -27,24 +27,29 @@ pub fn main_cs( let mut result = shared[lid]; // Different threads perform different operations - if lid % 4 == 0 { - // Read from neighboring thread's data - if lid + 1 < 64 { - result += shared[lid + 1]; + match lid % 4 { + 0 => { + // Read from neighboring thread's data + if lid + 1 < 64 { + result += shared[lid + 1]; + } } - } else if lid % 4 == 1 { - // Read from previous thread's data - if lid > 0 { - result += shared[lid - 1]; + 1 => { + // Read from previous thread's data + if lid > 0 { + result += shared[lid - 1]; + } } - } else if lid % 4 == 2 { - // Sum reduction within groups of 4 - if lid + 2 < 64 { - result = shared[lid] + shared[lid + 1] + shared[lid + 2]; + 2 => { + // Sum reduction within groups of 4 + if lid + 2 < 64 { + result = shared[lid] + shared[lid + 1] + shared[lid + 2]; + } + } + _ => { + // XOR with wrapped neighbor + result ^= shared[(lid + 32) % 64]; } - } else { - // XOR with wrapped neighbor - result ^= shared[(lid + 32) % 64]; } // Another barrier before writing back diff --git a/tests/difftests/tests/lang/core/ops/vector_swizzle/vector_swizzle-rust/src/lib.rs b/tests/difftests/tests/lang/core/ops/vector_swizzle/vector_swizzle-rust/src/lib.rs index 5412934294..acd1a7e9e8 100644 --- a/tests/difftests/tests/lang/core/ops/vector_swizzle/vector_swizzle-rust/src/lib.rs +++ b/tests/difftests/tests/lang/core/ops/vector_swizzle/vector_swizzle-rust/src/lib.rs @@ -1,5 +1,4 @@ #![no_std] -#![cfg_attr(target_arch = "spirv", feature(asm_experimental_arch))] use glam::{Vec3Swizzles, Vec4, Vec4Swizzles}; use spirv_std::spirv; From 77ea53c70b5939faa8c2b2a3c582de19fba41fd0 Mon Sep 17 00:00:00 2001 From: Christian Legnitto Date: Sun, 13 Jul 2025 00:57:18 +0200 Subject: [PATCH 21/24] Fix paths in README --- tests/difftests/README.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/difftests/README.md b/tests/difftests/README.md index 0940f6a3bd..8f3bc0ea2e 100644 --- a/tests/difftests/README.md +++ b/tests/difftests/README.md @@ -98,7 +98,7 @@ The library provides helper types for common test patterns: - `WgpuComputeTest` - Single buffer compute shader test - `WgpuComputeTestMultiBuffer` - Multi-buffer compute shader test with input/output separation -- `WgpuComputeTestPushConstant` - Compute shader test with push constants support +- `WgpuComputeTestPushConstants` - Compute shader test with push constants support - `Skip` - Marks a test variant as skipped with a reason **Shader source types:** @@ -115,7 +115,7 @@ For examples, see: - [`tests/lang/core/ops/math_ops/`](tests/lang/core/ops/math_ops/) - Multi-buffer test with floating-point metadata -- [`tests/storage_class/push_constant/`](tests/storage_class/push_constant/) - Push +- [`tests/arch/push_constants/`](tests/arch/push_constants/) - Push constants usage - [`tests/arch/workgroup_memory/`](tests/arch/workgroup_memory/) - Workgroup memory usage From 0e026d48dfbb73e204bcf43f5aa0caff897b7262 Mon Sep 17 00:00:00 2001 From: Firestar99 Date: Mon, 14 Jul 2025 13:42:25 +0200 Subject: [PATCH 22/24] proper `TestMetadata` constructor functions --- tests/difftests/README.md | 17 ++++----- tests/difftests/lib/src/config.rs | 38 ++++++++++++++++++- .../workgroup_memory-ash/src/main.rs | 9 ++--- .../workgroup_memory-rust/src/main.rs | 9 ++--- .../workgroup_memory-wgsl/src/main.rs | 9 ++--- .../ops/math_ops/math_ops-rust/src/main.rs | 9 ++--- .../ops/math_ops/math_ops-wgsl/src/main.rs | 9 ++--- .../matrix_ops/matrix_ops-rust/src/main.rs | 10 ++--- .../matrix_ops/matrix_ops-wgsl/src/main.rs | 9 ++--- .../vector_ops/vector_ops-rust/src/main.rs | 9 ++--- .../vector_ops/vector_ops-wgsl/src/main.rs | 9 ++--- 11 files changed, 71 insertions(+), 66 deletions(-) diff --git a/tests/difftests/README.md b/tests/difftests/README.md index 8f3bc0ea2e..8edb58f144 100644 --- a/tests/difftests/README.md +++ b/tests/difftests/README.md @@ -77,11 +77,7 @@ fn main() { file.write_all(&output).expect("Failed to write output"); // Optional: Write metadata for floating-point comparison - let metadata = TestMetadata { - epsilon: Some(0.00001), // Allow differences up to 1e-5 - output_type: OutputType::F32, // Interpret output as f32 array - }; - config.write_metadata(&metadata).expect("Failed to write metadata"); + config.write_metadata(&TestMetadata::f32(0.00001)).expect("Failed to write metadata"); } ``` @@ -130,15 +126,16 @@ outputs: use difftest::config::{TestMetadata, OutputType}; // Write metadata before or after writing output +let metadata = TestMetadata::f32(0.00001); // output is f32 with some epsilon +config.write_metadata(&metadata)?; + +// Alternative: Construct TestMetadata yourself let metadata = TestMetadata { - epsilon: Some(0.00001), // Maximum allowed difference (default: None) + epsilon: Some(0.00001), // Maximum allowed epsilon / difference (default: None) output_type: OutputType::F32, // How to interpret output data (default: Raw) + ..TestMetadata::default() }; config.write_metadata(&metadata)?; - -// Alternative: Use the helper method for common cases -let metadata = TestMetadata::with_epsilon(0.00001); // Sets epsilon, keeps default output_type -config.write_metadata(&metadata)?; ``` **Metadata fields:** diff --git a/tests/difftests/lib/src/config.rs b/tests/difftests/lib/src/config.rs index 297d64f0ed..e3a8709c41 100644 --- a/tests/difftests/lib/src/config.rs +++ b/tests/difftests/lib/src/config.rs @@ -51,13 +51,47 @@ pub enum OutputType { } impl TestMetadata { - /// Create metadata with a specific epsilon value, keeping default output type - pub fn with_epsilon(epsilon: f32) -> Self { + /// Create metadata for f32 with some epsilon value + pub fn f32(epsilon: f32) -> Self { Self { + output_type: OutputType::F32, epsilon: Some(epsilon), ..Default::default() } } + + /// Create metadata for f64 with some epsilon value + pub fn f64(epsilon: f32) -> Self { + Self { + output_type: OutputType::F64, + epsilon: Some(epsilon), + ..Default::default() + } + } + + /// Create metadata for u32 + pub fn u32() -> Self { + Self { + output_type: OutputType::U32, + ..Default::default() + } + } + + /// Create metadata for i32 + pub fn i32() -> Self { + Self { + output_type: OutputType::I32, + ..Default::default() + } + } + + /// Create metadata for raw hex values + pub fn raw() -> Self { + Self { + output_type: OutputType::Raw, + ..Default::default() + } + } } impl Config { diff --git a/tests/difftests/tests/arch/workgroup_memory/workgroup_memory-ash/src/main.rs b/tests/difftests/tests/arch/workgroup_memory/workgroup_memory-ash/src/main.rs index bfbad19336..839eff10c1 100644 --- a/tests/difftests/tests/arch/workgroup_memory/workgroup_memory-ash/src/main.rs +++ b/tests/difftests/tests/arch/workgroup_memory/workgroup_memory-ash/src/main.rs @@ -49,12 +49,9 @@ fn main() { .unwrap(); // Write metadata for U32 comparison - let metadata = difftest::config::TestMetadata { - epsilon: None, - output_type: difftest::config::OutputType::U32, - ..Default::default() - }; - config.write_metadata(&metadata).unwrap(); + config + .write_metadata(&difftest::config::TestMetadata::u32()) + .unwrap(); test.run_test(&config).unwrap(); } diff --git a/tests/difftests/tests/arch/workgroup_memory/workgroup_memory-rust/src/main.rs b/tests/difftests/tests/arch/workgroup_memory/workgroup_memory-rust/src/main.rs index 6d80e70222..9e58d86616 100644 --- a/tests/difftests/tests/arch/workgroup_memory/workgroup_memory-rust/src/main.rs +++ b/tests/difftests/tests/arch/workgroup_memory/workgroup_memory-rust/src/main.rs @@ -31,12 +31,9 @@ fn main() { ); // Write metadata for U32 comparison - let metadata = difftest::config::TestMetadata { - epsilon: None, - output_type: difftest::config::OutputType::U32, - ..Default::default() - }; - config.write_metadata(&metadata).unwrap(); + config + .write_metadata(&difftest::config::TestMetadata::u32()) + .unwrap(); test.run_test(&config).unwrap(); } diff --git a/tests/difftests/tests/arch/workgroup_memory/workgroup_memory-wgsl/src/main.rs b/tests/difftests/tests/arch/workgroup_memory/workgroup_memory-wgsl/src/main.rs index a6470b95ca..0d1e931df3 100644 --- a/tests/difftests/tests/arch/workgroup_memory/workgroup_memory-wgsl/src/main.rs +++ b/tests/difftests/tests/arch/workgroup_memory/workgroup_memory-wgsl/src/main.rs @@ -30,12 +30,9 @@ fn main() { ); // Write metadata for U32 comparison - let metadata = difftest::config::TestMetadata { - epsilon: None, - output_type: difftest::config::OutputType::U32, - ..Default::default() - }; - config.write_metadata(&metadata).unwrap(); + config + .write_metadata(&difftest::config::TestMetadata::u32()) + .unwrap(); test.run_test(&config).unwrap(); } diff --git a/tests/difftests/tests/lang/core/ops/math_ops/math_ops-rust/src/main.rs b/tests/difftests/tests/lang/core/ops/math_ops/math_ops-rust/src/main.rs index 23c0e027df..13e6ee0699 100644 --- a/tests/difftests/tests/lang/core/ops/math_ops/math_ops-rust/src/main.rs +++ b/tests/difftests/tests/lang/core/ops/math_ops/math_ops-rust/src/main.rs @@ -52,12 +52,9 @@ fn main() { ); // Write metadata file - let metadata = difftest::config::TestMetadata { - epsilon: Some(2e-6), // Small epsilon for last-bit differences - output_type: difftest::config::OutputType::F32, - ..Default::default() - }; - config.write_metadata(&metadata).unwrap(); + config + .write_metadata(&difftest::config::TestMetadata::f32(2e-6)) + .unwrap(); test.run_test(&config).unwrap(); } diff --git a/tests/difftests/tests/lang/core/ops/math_ops/math_ops-wgsl/src/main.rs b/tests/difftests/tests/lang/core/ops/math_ops/math_ops-wgsl/src/main.rs index b97a39a536..1c87d3fc9c 100644 --- a/tests/difftests/tests/lang/core/ops/math_ops/math_ops-wgsl/src/main.rs +++ b/tests/difftests/tests/lang/core/ops/math_ops/math_ops-wgsl/src/main.rs @@ -51,12 +51,9 @@ fn main() { ); // Write metadata file - let metadata = difftest::config::TestMetadata { - epsilon: Some(2e-6), // Small epsilon for last-bit differences - output_type: difftest::config::OutputType::F32, - ..Default::default() - }; - config.write_metadata(&metadata).unwrap(); + config + .write_metadata(&difftest::config::TestMetadata::f32(2e-6)) + .unwrap(); test.run_test(&config).unwrap(); } diff --git a/tests/difftests/tests/lang/core/ops/matrix_ops/matrix_ops-rust/src/main.rs b/tests/difftests/tests/lang/core/ops/matrix_ops/matrix_ops-rust/src/main.rs index 6597c253f7..ef1317bb8b 100644 --- a/tests/difftests/tests/lang/core/ops/matrix_ops/matrix_ops-rust/src/main.rs +++ b/tests/difftests/tests/lang/core/ops/matrix_ops/matrix_ops-rust/src/main.rs @@ -52,12 +52,10 @@ fn main() { ); // Write metadata file - let metadata = difftest::config::TestMetadata { - epsilon: Some(2e-5), - output_type: difftest::config::OutputType::F32, - ..Default::default() - }; - config.write_metadata(&metadata).unwrap(); + // this one requires higher epsilon, noticed on an RDNA2 680M (Ryzen iGPU) with radv + config + .write_metadata(&difftest::config::TestMetadata::f32(2e-5)) + .unwrap(); test.run_test(&config).unwrap(); } diff --git a/tests/difftests/tests/lang/core/ops/matrix_ops/matrix_ops-wgsl/src/main.rs b/tests/difftests/tests/lang/core/ops/matrix_ops/matrix_ops-wgsl/src/main.rs index c71c23872e..76d715f53b 100644 --- a/tests/difftests/tests/lang/core/ops/matrix_ops/matrix_ops-wgsl/src/main.rs +++ b/tests/difftests/tests/lang/core/ops/matrix_ops/matrix_ops-wgsl/src/main.rs @@ -51,12 +51,9 @@ fn main() { ); // Write metadata file - let metadata = difftest::config::TestMetadata { - epsilon: Some(2e-5), - output_type: difftest::config::OutputType::F32, - ..Default::default() - }; - config.write_metadata(&metadata).unwrap(); + config + .write_metadata(&difftest::config::TestMetadata::f32(2e-5)) + .unwrap(); test.run_test(&config).unwrap(); } diff --git a/tests/difftests/tests/lang/core/ops/vector_ops/vector_ops-rust/src/main.rs b/tests/difftests/tests/lang/core/ops/vector_ops/vector_ops-rust/src/main.rs index 3ef378155b..75364385ac 100644 --- a/tests/difftests/tests/lang/core/ops/vector_ops/vector_ops-rust/src/main.rs +++ b/tests/difftests/tests/lang/core/ops/vector_ops/vector_ops-rust/src/main.rs @@ -52,12 +52,9 @@ fn main() { ); // Write metadata file - let metadata = difftest::config::TestMetadata { - epsilon: Some(1e-5), // 1e-5 - appropriate for 5 decimal place rounding - output_type: difftest::config::OutputType::F32, - ..Default::default() - }; - config.write_metadata(&metadata).unwrap(); + config + .write_metadata(&difftest::config::TestMetadata::f32(1e-5)) + .unwrap(); test.run_test(&config).unwrap(); } diff --git a/tests/difftests/tests/lang/core/ops/vector_ops/vector_ops-wgsl/src/main.rs b/tests/difftests/tests/lang/core/ops/vector_ops/vector_ops-wgsl/src/main.rs index 03f2e97b44..866826b76c 100644 --- a/tests/difftests/tests/lang/core/ops/vector_ops/vector_ops-wgsl/src/main.rs +++ b/tests/difftests/tests/lang/core/ops/vector_ops/vector_ops-wgsl/src/main.rs @@ -51,12 +51,9 @@ fn main() { ); // Write metadata file - let metadata = difftest::config::TestMetadata { - epsilon: Some(1e-5), // 1e-5 - appropriate for 5 decimal place rounding - output_type: difftest::config::OutputType::F32, - ..Default::default() - }; - config.write_metadata(&metadata).unwrap(); + config + .write_metadata(&difftest::config::TestMetadata::f32(1e-5)) + .unwrap(); test.run_test(&config).unwrap(); } From ab34663570f1b706076e827c530555688a1703d7 Mon Sep 17 00:00:00 2001 From: Firestar99 Date: Mon, 14 Jul 2025 15:31:19 +0200 Subject: [PATCH 23/24] move shaders to separate mod --- .../lib/src/scaffold/compute/backend.rs | 22 ++- .../difftests/lib/src/scaffold/compute/mod.rs | 4 +- .../lib/src/scaffold/compute/wgpu.rs | 187 +----------------- tests/difftests/lib/src/scaffold/mod.rs | 1 + .../difftests/lib/src/scaffold/shader/mod.rs | 20 ++ .../src/scaffold/shader/rust_gpu_shader.rs | 98 +++++++++ .../lib/src/scaffold/shader/wgsl_shader.rs | 69 +++++++ 7 files changed, 215 insertions(+), 186 deletions(-) create mode 100644 tests/difftests/lib/src/scaffold/shader/mod.rs create mode 100644 tests/difftests/lib/src/scaffold/shader/rust_gpu_shader.rs create mode 100644 tests/difftests/lib/src/scaffold/shader/wgsl_shader.rs diff --git a/tests/difftests/lib/src/scaffold/compute/backend.rs b/tests/difftests/lib/src/scaffold/compute/backend.rs index 5696f965b9..7be5237a45 100644 --- a/tests/difftests/lib/src/scaffold/compute/backend.rs +++ b/tests/difftests/lib/src/scaffold/compute/backend.rs @@ -1,4 +1,5 @@ use crate::config::Config; +use crate::scaffold::shader::SpirvShader; use anyhow::Result; /// Configuration for a GPU buffer @@ -9,6 +10,25 @@ pub struct BufferConfig { pub initial_data: Option>, } +impl BufferConfig { + pub fn writeback(size: usize) -> Self { + Self { + size: size as u64, + usage: BufferUsage::Storage, + initial_data: None, + } + } + + pub fn read_only(slice: &[A]) -> Self { + let vec = bytemuck::cast_slice(slice).to_vec(); + Self { + size: vec.len() as u64, + usage: BufferUsage::StorageReadOnly, + initial_data: Some(vec), + } + } +} + /// Buffer usage type #[derive(Clone, Copy, PartialEq)] pub enum BufferUsage { @@ -17,8 +37,6 @@ pub enum BufferUsage { Uniform, } -use super::SpirvShader; - /// A generic trait for compute backends pub trait ComputeBackend: Sized { /// Initialize the backend diff --git a/tests/difftests/lib/src/scaffold/compute/mod.rs b/tests/difftests/lib/src/scaffold/compute/mod.rs index edb0daad2f..0c89c51837 100644 --- a/tests/difftests/lib/src/scaffold/compute/mod.rs +++ b/tests/difftests/lib/src/scaffold/compute/mod.rs @@ -2,9 +2,9 @@ mod ash; mod backend; mod wgpu; +pub use crate::scaffold::shader::*; pub use ash::AshBackend; pub use backend::{BufferConfig, BufferUsage, ComputeBackend, ComputeShaderTest, ComputeTest}; pub use wgpu::{ - RustComputeShader, SpirvShader, WgpuBackend, WgpuComputeTest, WgpuComputeTestMultiBuffer, - WgpuComputeTestPushConstants, WgpuShader, WgslComputeShader, + WgpuBackend, WgpuComputeTest, WgpuComputeTestMultiBuffer, WgpuComputeTestPushConstants, }; diff --git a/tests/difftests/lib/src/scaffold/compute/wgpu.rs b/tests/difftests/lib/src/scaffold/compute/wgpu.rs index 56a48d6e47..dc5c2d8c97 100644 --- a/tests/difftests/lib/src/scaffold/compute/wgpu.rs +++ b/tests/difftests/lib/src/scaffold/compute/wgpu.rs @@ -1,152 +1,17 @@ +use super::backend::{self, ComputeBackend}; use crate::config::Config; +use crate::scaffold::shader::RustComputeShader; +use crate::scaffold::shader::WgpuShader; +use crate::scaffold::shader::WgslComputeShader; use anyhow::Context; use bytemuck::Pod; use futures::executor::block_on; -use spirv_builder::{ModuleResult, SpirvBuilder}; -use std::{ - borrow::Cow, - env, - fs::{self, File}, - io::Write, - path::PathBuf, - sync::Arc, -}; +use std::{borrow::Cow, fs::File, io::Write, sync::Arc}; use wgpu::{PipelineCompilationOptions, util::DeviceExt}; -use super::backend::{self, ComputeBackend}; - pub type BufferConfig = backend::BufferConfig; pub type BufferUsage = backend::BufferUsage; -/// Trait for shaders that can provide SPIRV bytes. -pub trait SpirvShader { - /// Returns the SPIRV bytes and entry point name. - fn spirv_bytes(&self) -> anyhow::Result<(Vec, String)>; -} - -/// Trait for shaders that can create wgpu modules. -pub trait WgpuShader { - /// Creates a wgpu shader module. - fn create_module( - &self, - device: &wgpu::Device, - ) -> anyhow::Result<(wgpu::ShaderModule, Option)>; -} - -/// A compute shader written in Rust compiled with spirv-builder. -pub struct RustComputeShader { - pub path: PathBuf, - pub target: String, - pub capabilities: Vec, -} - -impl RustComputeShader { - pub fn new>(path: P) -> Self { - Self { - path: path.into(), - target: "spirv-unknown-vulkan1.1".to_string(), - capabilities: Vec::new(), - } - } - - pub fn with_target>(path: P, target: impl Into) -> Self { - Self { - path: path.into(), - target: target.into(), - capabilities: Vec::new(), - } - } - - pub fn with_capability(mut self, capability: spirv_builder::Capability) -> Self { - self.capabilities.push(capability); - self - } -} - -impl SpirvShader for RustComputeShader { - fn spirv_bytes(&self) -> anyhow::Result<(Vec, String)> { - let mut builder = SpirvBuilder::new(&self.path, &self.target) - .print_metadata(spirv_builder::MetadataPrintout::None) - .release(true) - .multimodule(false) - .shader_panic_strategy(spirv_builder::ShaderPanicStrategy::SilentExit) - .preserve_bindings(true); - - for capability in &self.capabilities { - builder = builder.capability(*capability); - } - - let artifact = builder.build().context("SpirvBuilder::build() failed")?; - - if artifact.entry_points.len() != 1 { - anyhow::bail!( - "Expected exactly one entry point, found {}", - artifact.entry_points.len() - ); - } - let entry_point = artifact.entry_points.into_iter().next().unwrap(); - - let shader_bytes = match artifact.module { - ModuleResult::SingleModule(path) => fs::read(&path) - .with_context(|| format!("reading spv file '{}' failed", path.display()))?, - ModuleResult::MultiModule(_modules) => { - anyhow::bail!("MultiModule modules produced"); - } - }; - - Ok((shader_bytes, entry_point)) - } -} - -impl WgpuShader for RustComputeShader { - fn create_module( - &self, - device: &wgpu::Device, - ) -> anyhow::Result<(wgpu::ShaderModule, Option)> { - let (shader_bytes, entry_point) = self.spirv_bytes()?; - - if shader_bytes.len() % 4 != 0 { - anyhow::bail!("SPIR-V binary length is not a multiple of 4"); - } - let shader_words: Vec = bytemuck::cast_slice(&shader_bytes).to_vec(); - let module = device.create_shader_module(wgpu::ShaderModuleDescriptor { - label: Some("Compute Shader"), - source: wgpu::ShaderSource::SpirV(Cow::Owned(shader_words)), - }); - Ok((module, Some(entry_point))) - } -} - -/// A WGSL compute shader. -pub struct WgslComputeShader { - pub path: PathBuf, - pub entry_point: Option, -} - -impl WgslComputeShader { - pub fn new>(path: P, entry_point: Option) -> Self { - Self { - path: path.into(), - entry_point, - } - } -} - -impl WgpuShader for WgslComputeShader { - fn create_module( - &self, - device: &wgpu::Device, - ) -> anyhow::Result<(wgpu::ShaderModule, Option)> { - let shader_source = fs::read_to_string(&self.path) - .with_context(|| format!("reading wgsl source file '{}'", &self.path.display()))?; - let module = device.create_shader_module(wgpu::ShaderModuleDescriptor { - label: Some("Compute Shader"), - source: wgpu::ShaderSource::Wgsl(Cow::Owned(shader_source)), - }); - Ok((module, self.entry_point.clone())) - } -} - /// Compute test that is generic over the shader type. pub struct WgpuComputeTest { shader: S, @@ -539,48 +404,6 @@ impl ComputeBackend for WgpuBackend { } } -/// For WGSL, the code checks for "shader.wgsl" then "compute.wgsl". -impl Default for WgslComputeShader { - fn default() -> Self { - let manifest_dir = env::var("CARGO_MANIFEST_DIR").expect("CARGO_MANIFEST_DIR not set"); - let manifest_path = PathBuf::from(manifest_dir); - let shader_path = manifest_path.join("shader.wgsl"); - let compute_path = manifest_path.join("compute.wgsl"); - - let (file, source) = if shader_path.exists() { - ( - shader_path.clone(), - fs::read_to_string(&shader_path).unwrap_or_default(), - ) - } else if compute_path.exists() { - ( - compute_path.clone(), - fs::read_to_string(&compute_path).unwrap_or_default(), - ) - } else { - panic!("No default WGSL shader found in manifest directory"); - }; - - let entry_point = if source.contains("fn main_cs(") { - Some("main_cs".to_string()) - } else if source.contains("fn main(") { - Some("main".to_string()) - } else { - None - }; - - Self::new(file, entry_point) - } -} - -/// For the SPIR-V shader, the manifest directory is used as the build path. -impl Default for RustComputeShader { - fn default() -> Self { - let manifest_dir = env::var("CARGO_MANIFEST_DIR").expect("CARGO_MANIFEST_DIR not set"); - Self::new(PathBuf::from(manifest_dir)) - } -} - impl WgpuComputeTestMultiBuffer where S: WgpuShader, diff --git a/tests/difftests/lib/src/scaffold/mod.rs b/tests/difftests/lib/src/scaffold/mod.rs index 91edf95419..74f5cb0167 100644 --- a/tests/difftests/lib/src/scaffold/mod.rs +++ b/tests/difftests/lib/src/scaffold/mod.rs @@ -1,4 +1,5 @@ pub mod compute; +pub mod shader; pub mod skip; pub use skip::Skip; diff --git a/tests/difftests/lib/src/scaffold/shader/mod.rs b/tests/difftests/lib/src/scaffold/shader/mod.rs new file mode 100644 index 0000000000..cef42c40fa --- /dev/null +++ b/tests/difftests/lib/src/scaffold/shader/mod.rs @@ -0,0 +1,20 @@ +mod rust_gpu_shader; +mod wgsl_shader; + +pub use rust_gpu_shader::RustComputeShader; +pub use wgsl_shader::WgslComputeShader; + +/// Trait for shaders that can provide SPIRV bytes. +pub trait SpirvShader { + /// Returns the SPIRV bytes and entry point name. + fn spirv_bytes(&self) -> anyhow::Result<(Vec, String)>; +} + +/// Trait for shaders that can create wgpu modules. +pub trait WgpuShader { + /// Creates a wgpu shader module. + fn create_module( + &self, + device: &wgpu::Device, + ) -> anyhow::Result<(wgpu::ShaderModule, Option)>; +} diff --git a/tests/difftests/lib/src/scaffold/shader/rust_gpu_shader.rs b/tests/difftests/lib/src/scaffold/shader/rust_gpu_shader.rs new file mode 100644 index 0000000000..a431e2ff41 --- /dev/null +++ b/tests/difftests/lib/src/scaffold/shader/rust_gpu_shader.rs @@ -0,0 +1,98 @@ +use crate::scaffold::shader::{SpirvShader, WgpuShader}; +use anyhow::Context; +use spirv_builder::{ModuleResult, SpirvBuilder}; +use std::borrow::Cow; +use std::path::PathBuf; +use std::{env, fs}; + +/// A compute shader written in Rust compiled with spirv-builder. +pub struct RustComputeShader { + pub path: PathBuf, + pub target: String, + pub capabilities: Vec, +} + +impl RustComputeShader { + pub fn new>(path: P) -> Self { + Self { + path: path.into(), + target: "spirv-unknown-vulkan1.1".to_string(), + capabilities: Vec::new(), + } + } + + pub fn with_target>(path: P, target: impl Into) -> Self { + Self { + path: path.into(), + target: target.into(), + capabilities: Vec::new(), + } + } + + pub fn with_capability(mut self, capability: spirv_builder::Capability) -> Self { + self.capabilities.push(capability); + self + } +} + +impl SpirvShader for RustComputeShader { + fn spirv_bytes(&self) -> anyhow::Result<(Vec, String)> { + let mut builder = SpirvBuilder::new(&self.path, &self.target) + .print_metadata(spirv_builder::MetadataPrintout::None) + .release(true) + .multimodule(false) + .shader_panic_strategy(spirv_builder::ShaderPanicStrategy::SilentExit) + .preserve_bindings(true); + + for capability in &self.capabilities { + builder = builder.capability(*capability); + } + + let artifact = builder.build().context("SpirvBuilder::build() failed")?; + + if artifact.entry_points.len() != 1 { + anyhow::bail!( + "Expected exactly one entry point, found {}", + artifact.entry_points.len() + ); + } + let entry_point = artifact.entry_points.into_iter().next().unwrap(); + + let shader_bytes = match artifact.module { + ModuleResult::SingleModule(path) => fs::read(&path) + .with_context(|| format!("reading spv file '{}' failed", path.display()))?, + ModuleResult::MultiModule(_modules) => { + anyhow::bail!("MultiModule modules produced"); + } + }; + + Ok((shader_bytes, entry_point)) + } +} + +impl WgpuShader for RustComputeShader { + fn create_module( + &self, + device: &wgpu::Device, + ) -> anyhow::Result<(wgpu::ShaderModule, Option)> { + let (shader_bytes, entry_point) = self.spirv_bytes()?; + + if shader_bytes.len() % 4 != 0 { + anyhow::bail!("SPIR-V binary length is not a multiple of 4"); + } + let shader_words: Vec = bytemuck::cast_slice(&shader_bytes).to_vec(); + let module = device.create_shader_module(wgpu::ShaderModuleDescriptor { + label: Some("Compute Shader"), + source: wgpu::ShaderSource::SpirV(Cow::Owned(shader_words)), + }); + Ok((module, Some(entry_point))) + } +} + +/// For the SPIR-V shader, the manifest directory is used as the build path. +impl Default for RustComputeShader { + fn default() -> Self { + let manifest_dir = env::var("CARGO_MANIFEST_DIR").expect("CARGO_MANIFEST_DIR not set"); + Self::new(PathBuf::from(manifest_dir)) + } +} diff --git a/tests/difftests/lib/src/scaffold/shader/wgsl_shader.rs b/tests/difftests/lib/src/scaffold/shader/wgsl_shader.rs new file mode 100644 index 0000000000..a7e3f354f9 --- /dev/null +++ b/tests/difftests/lib/src/scaffold/shader/wgsl_shader.rs @@ -0,0 +1,69 @@ +use crate::scaffold::shader::WgpuShader; +use anyhow::Context; +use std::borrow::Cow; +use std::path::PathBuf; +use std::{env, fs}; + +/// A WGSL compute shader. +pub struct WgslComputeShader { + pub path: PathBuf, + pub entry_point: Option, +} + +impl WgslComputeShader { + pub fn new>(path: P, entry_point: Option) -> Self { + Self { + path: path.into(), + entry_point, + } + } +} + +impl WgpuShader for WgslComputeShader { + fn create_module( + &self, + device: &wgpu::Device, + ) -> anyhow::Result<(wgpu::ShaderModule, Option)> { + let shader_source = fs::read_to_string(&self.path) + .with_context(|| format!("reading wgsl source file '{}'", &self.path.display()))?; + let module = device.create_shader_module(wgpu::ShaderModuleDescriptor { + label: Some("Compute Shader"), + source: wgpu::ShaderSource::Wgsl(Cow::Owned(shader_source)), + }); + Ok((module, self.entry_point.clone())) + } +} + +/// For WGSL, the code checks for "shader.wgsl" then "compute.wgsl". +impl Default for WgslComputeShader { + fn default() -> Self { + let manifest_dir = env::var("CARGO_MANIFEST_DIR").expect("CARGO_MANIFEST_DIR not set"); + let manifest_path = PathBuf::from(manifest_dir); + let shader_path = manifest_path.join("shader.wgsl"); + let compute_path = manifest_path.join("compute.wgsl"); + + let (file, source) = if shader_path.exists() { + ( + shader_path.clone(), + fs::read_to_string(&shader_path).unwrap_or_default(), + ) + } else if compute_path.exists() { + ( + compute_path.clone(), + fs::read_to_string(&compute_path).unwrap_or_default(), + ) + } else { + panic!("No default WGSL shader found in manifest directory"); + }; + + let entry_point = if source.contains("fn main_cs(") { + Some("main_cs".to_string()) + } else if source.contains("fn main(") { + Some("main".to_string()) + } else { + None + }; + + Self::new(file, entry_point) + } +} From 2420d873cfe5081f98cf898b0cf6a8fbea9aabec Mon Sep 17 00:00:00 2001 From: Firestar99 Date: Mon, 14 Jul 2025 15:48:25 +0200 Subject: [PATCH 24/24] unify result writing --- tests/difftests/lib/src/config.rs | 10 ++++++++++ .../difftests/lib/src/scaffold/compute/backend.rs | 5 +---- tests/difftests/lib/src/scaffold/compute/wgpu.rs | 14 +++++--------- 3 files changed, 16 insertions(+), 13 deletions(-) diff --git a/tests/difftests/lib/src/config.rs b/tests/difftests/lib/src/config.rs index e3a8709c41..8107104668 100644 --- a/tests/difftests/lib/src/config.rs +++ b/tests/difftests/lib/src/config.rs @@ -1,4 +1,6 @@ use serde::{Deserialize, Serialize}; +use std::fs::File; +use std::io::Write; use std::{fs, path::Path}; #[derive(Debug, Deserialize, Serialize)] @@ -7,6 +9,14 @@ pub struct Config { pub metadata_path: std::path::PathBuf, } +impl Config { + pub fn write_result(&self, output: &[A]) -> anyhow::Result<()> { + let mut f = File::create(&self.output_path)?; + f.write_all(bytemuck::cast_slice(output))?; + Ok(()) + } +} + /// Test metadata that controls output comparison behavior /// /// This metadata is written alongside test output to specify how the test harness diff --git a/tests/difftests/lib/src/scaffold/compute/backend.rs b/tests/difftests/lib/src/scaffold/compute/backend.rs index 7be5237a45..2eeb26eeda 100644 --- a/tests/difftests/lib/src/scaffold/compute/backend.rs +++ b/tests/difftests/lib/src/scaffold/compute/backend.rs @@ -103,10 +103,7 @@ impl ComputeTest { // Write the first storage buffer output to the file for (output, buffer_config) in outputs.iter().zip(&buffers) { if matches!(buffer_config.usage, BufferUsage::Storage) && !output.is_empty() { - use std::fs::File; - use std::io::Write; - let mut f = File::create(&config.output_path)?; - f.write_all(output)?; + config.write_result(output)?; return Ok(()); } } diff --git a/tests/difftests/lib/src/scaffold/compute/wgpu.rs b/tests/difftests/lib/src/scaffold/compute/wgpu.rs index dc5c2d8c97..c1f64d8c30 100644 --- a/tests/difftests/lib/src/scaffold/compute/wgpu.rs +++ b/tests/difftests/lib/src/scaffold/compute/wgpu.rs @@ -6,7 +6,7 @@ use crate::scaffold::shader::WgslComputeShader; use anyhow::Context; use bytemuck::Pod; use futures::executor::block_on; -use std::{borrow::Cow, fs::File, io::Write, sync::Arc}; +use std::{borrow::Cow, sync::Arc}; use wgpu::{PipelineCompilationOptions, util::DeviceExt}; pub type BufferConfig = backend::BufferConfig; @@ -217,8 +217,7 @@ where /// Runs the compute shader with no input and writes the output to a file. pub fn run_test(self, config: &Config) -> anyhow::Result<()> { let output = self.run()?; - let mut f = File::create(&config.output_path)?; - f.write_all(&output)?; + config.write_result(&output)?; Ok(()) } @@ -228,8 +227,7 @@ where I: Sized + Pod, { let output = self.run_with_input(input)?; - let mut f = File::create(&config.output_path)?; - f.write_all(&output)?; + config.write_result(&output)?; Ok(()) } } @@ -565,8 +563,7 @@ where // Write the first storage buffer output to the file. for (output, buffer_config) in outputs.iter().zip(&buffers) { if matches!(buffer_config.usage, BufferUsage::Storage) && !output.is_empty() { - let mut f = File::create(&config.output_path)?; - f.write_all(output)?; + config.write_result(output)?; return Ok(()); } } @@ -775,8 +772,7 @@ where // Write first storage buffer output to file. for (data, buffer_config) in results.iter().zip(&buffers) { if buffer_config.usage == BufferUsage::Storage && !data.is_empty() { - let mut f = File::create(&config.output_path)?; - f.write_all(data)?; + config.write_result(data)?; return Ok(()); } }