From 7f708edd1f4b3698d658a0cfeb15c563bf7cad25 Mon Sep 17 00:00:00 2001 From: Teodor Tanasoaia <28601907+teoxoy@users.noreply.github.com> Date: Mon, 14 Oct 2024 15:02:01 +0200 Subject: [PATCH] Ensure safety of indirect dispatch (#5714) by injecting a compute shader that validates the content of the indirect buffer --- CHANGELOG.md | 1 + tests/tests/dispatch_workgroups_indirect.rs | 241 +++++++++++++ tests/tests/root.rs | 1 + wgpu-core/Cargo.toml | 4 + wgpu-core/src/command/bind.rs | 20 +- wgpu-core/src/command/compute.rs | 158 +++++++- wgpu-core/src/device/global.rs | 4 +- wgpu-core/src/device/mod.rs | 2 +- wgpu-core/src/device/resource.rs | 78 +++- wgpu-core/src/indirect_validation.rs | 378 ++++++++++++++++++++ wgpu-core/src/lib.rs | 2 + wgpu-core/src/pipeline.rs | 2 +- wgpu-core/src/resource.rs | 32 +- wgpu-core/src/snatch.rs | 6 + wgpu/Cargo.toml | 6 + 15 files changed, 913 insertions(+), 22 deletions(-) create mode 100644 tests/tests/dispatch_workgroups_indirect.rs create mode 100644 wgpu-core/src/indirect_validation.rs diff --git a/CHANGELOG.md b/CHANGELOG.md index 65ea26072d..4894b6a3ee 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -133,6 +133,7 @@ By @bradwerth [#6216](https://github.com/gfx-rs/wgpu/pull/6216). - Call `flush_mapped_ranges` when unmapping write-mapped buffers. By @teoxoy in [#6089](https://github.com/gfx-rs/wgpu/pull/6089). - When mapping buffers for reading, mark buffers as initialized only when they have `MAP_WRITE` usage. By @teoxoy in [#6178](https://github.com/gfx-rs/wgpu/pull/6178). - Add a separate pipeline constants error. By @teoxoy in [#6094](https://github.com/gfx-rs/wgpu/pull/6094). +- Ensure safety of indirect dispatch by injecting a compute shader that validates the content of the indirect buffer. By @teoxoy in [#5714](https://github.com/gfx-rs/wgpu/pull/5714) #### GLES / OpenGL diff --git a/tests/tests/dispatch_workgroups_indirect.rs b/tests/tests/dispatch_workgroups_indirect.rs new file mode 100644 index 0000000000..7ba3d7f638 --- /dev/null +++ b/tests/tests/dispatch_workgroups_indirect.rs @@ -0,0 +1,241 @@ +use wgpu_test::{gpu_test, FailureCase, GpuTestConfiguration, TestParameters, TestingContext}; + +/// Make sure that the num_workgroups builtin works properly (it requires a workaround on D3D12). +#[gpu_test] +static NUM_WORKGROUPS_BUILTIN: GpuTestConfiguration = GpuTestConfiguration::new() + .parameters( + TestParameters::default() + .features(wgpu::Features::PUSH_CONSTANTS) + .downlevel_flags( + wgpu::DownlevelFlags::COMPUTE_SHADERS | wgpu::DownlevelFlags::INDIRECT_EXECUTION, + ) + .limits(wgpu::Limits { + max_push_constant_size: 4, + ..wgpu::Limits::downlevel_defaults() + }) + .expect_fail(FailureCase::backend(wgt::Backends::DX12)), + ) + .run_async(|ctx| async move { + let num_workgroups = [1, 2, 3]; + let res = run_test(&ctx, &num_workgroups, false).await; + assert_eq!(res, num_workgroups); + }); + +/// Make sure that we discard (don't run) the dispatch if its size exceeds the device limit. +#[gpu_test] +static DISCARD_DISPATCH: GpuTestConfiguration = GpuTestConfiguration::new() + .parameters( + TestParameters::default() + .features(wgpu::Features::PUSH_CONSTANTS) + .downlevel_flags( + wgpu::DownlevelFlags::COMPUTE_SHADERS | wgpu::DownlevelFlags::INDIRECT_EXECUTION, + ) + .limits(wgpu::Limits { + max_compute_workgroups_per_dimension: 10, + max_push_constant_size: 4, + ..wgpu::Limits::downlevel_defaults() + }) + .expect_fail(FailureCase::backend(wgt::Backends::DX12)), + ) + .run_async(|ctx| async move { + let max = ctx.device.limits().max_compute_workgroups_per_dimension; + + let res = run_test(&ctx, &[max, max, max], false).await; + assert_eq!(res, [max; 3]); + + let res = run_test(&ctx, &[max + 1, 1, 1], false).await; + assert_eq!(res, [0; 3]); + + let res = run_test(&ctx, &[1, max + 1, 1], false).await; + assert_eq!(res, [0; 3]); + + let res = run_test(&ctx, &[1, 1, max + 1], false).await; + assert_eq!(res, [0; 3]); + }); + +/// Make sure that resetting the bind groups set by the validation code works properly. +#[gpu_test] +static RESET_BIND_GROUPS: GpuTestConfiguration = GpuTestConfiguration::new() + .parameters( + TestParameters::default() + .features(wgpu::Features::PUSH_CONSTANTS) + .downlevel_flags( + wgpu::DownlevelFlags::COMPUTE_SHADERS | wgpu::DownlevelFlags::INDIRECT_EXECUTION, + ) + .limits(wgpu::Limits { + max_push_constant_size: 4, + ..wgpu::Limits::downlevel_defaults() + }), + ) + .run_async(|ctx| async move { + ctx.device.push_error_scope(wgpu::ErrorFilter::Validation); + + let _ = run_test(&ctx, &[0, 0, 0], true).await; + + let error = pollster::block_on(ctx.device.pop_error_scope()); + assert!(error.map_or(false, |error| { + format!("{error}").contains("The current set ComputePipeline with '' label expects a BindGroup to be set at index 0") + })); + }); + +async fn run_test( + ctx: &TestingContext, + num_workgroups: &[u32; 3], + forget_to_set_bind_group: bool, +) -> [u32; 3] { + const SHADER_SRC: &str = " + struct TestOffsetPc { + inner: u32, + } + + // `test_offset.inner` should always be 0; we test that resetting the push constant set by the validation code works properly. + var test_offset: TestOffsetPc; + + @group(0) @binding(0) + var out: array; + + @compute @workgroup_size(1) + fn main(@builtin(num_workgroups) num_workgroups: vec3u, @builtin(workgroup_id) workgroup_id: vec3u) { + if (all(workgroup_id == vec3u())) { + out[0] = num_workgroups.x + test_offset.inner; + out[1] = num_workgroups.y + test_offset.inner; + out[2] = num_workgroups.z + test_offset.inner; + } + } + "; + + let module = ctx + .device + .create_shader_module(wgpu::ShaderModuleDescriptor { + label: None, + source: wgpu::ShaderSource::Wgsl(SHADER_SRC.into()), + }); + + let bgl = ctx + .device + .create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor { + label: None, + entries: &[wgpu::BindGroupLayoutEntry { + binding: 0, + visibility: wgt::ShaderStages::COMPUTE, + ty: wgpu::BindingType::Buffer { + ty: wgpu::BufferBindingType::Storage { read_only: false }, + has_dynamic_offset: false, + min_binding_size: None, + }, + count: None, + }], + }); + + let layout = ctx + .device + .create_pipeline_layout(&wgpu::PipelineLayoutDescriptor { + label: None, + bind_group_layouts: &[&bgl], + push_constant_ranges: &[wgt::PushConstantRange { + stages: wgt::ShaderStages::COMPUTE, + range: 0..4, + }], + }); + + let pipeline = ctx + .device + .create_compute_pipeline(&wgpu::ComputePipelineDescriptor { + label: None, + layout: Some(&layout), + module: &module, + entry_point: Some("main"), + compilation_options: Default::default(), + cache: None, + }); + + let out_buffer = ctx.device.create_buffer(&wgpu::BufferDescriptor { + label: None, + size: 12, + usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_SRC, + mapped_at_creation: false, + }); + + let readback_buffer = ctx.device.create_buffer(&wgpu::BufferDescriptor { + label: None, + size: 12, + usage: wgpu::BufferUsages::COPY_DST | wgpu::BufferUsages::MAP_READ, + mapped_at_creation: false, + }); + + let bind_group = ctx.device.create_bind_group(&wgpu::BindGroupDescriptor { + label: None, + layout: &pipeline.get_bind_group_layout(0), + entries: &[wgpu::BindGroupEntry { + binding: 0, + resource: out_buffer.as_entire_binding(), + }], + }); + + let mut res = None; + + for (indirect_offset, indirect_buffer_size) in [ + // internal src buffer binding size will be buffer.size + (0, 12), + (4, 4 + 12), + (4, 8 + 12), + (256 * 2 - 4 - 12, 256 * 2 - 4), + // internal src buffer binding size will be 256 * 2 + x + (0, 256 * 2 * 2 + 4), + (256, 256 * 2 * 2 + 8), + (256 + 4, 256 * 2 * 2 + 12), + (256 * 2 + 16, 256 * 2 * 2 + 16), + (256 * 2 * 2, 256 * 2 * 2 + 32), + (256 + 12, 256 * 2 * 2 + 64), + ] { + let indirect_buffer = ctx.device.create_buffer(&wgpu::BufferDescriptor { + label: None, + size: indirect_buffer_size, + usage: wgpu::BufferUsages::COPY_DST | wgpu::BufferUsages::INDIRECT, + mapped_at_creation: false, + }); + + ctx.queue.write_buffer( + &indirect_buffer, + indirect_offset, + bytemuck::bytes_of(num_workgroups), + ); + + let mut encoder = ctx.device.create_command_encoder(&Default::default()); + { + let mut compute_pass = encoder.begin_compute_pass(&Default::default()); + compute_pass.set_pipeline(&pipeline); + compute_pass.set_push_constants(0, &[0, 0, 0, 0]); + if !forget_to_set_bind_group { + compute_pass.set_bind_group(0, Some(&bind_group), &[]); + } + compute_pass.dispatch_workgroups_indirect(&indirect_buffer, indirect_offset); + } + + encoder.copy_buffer_to_buffer(&out_buffer, 0, &readback_buffer, 0, 12); + + ctx.queue.submit(Some(encoder.finish())); + + readback_buffer + .slice(..) + .map_async(wgpu::MapMode::Read, |_| {}); + + ctx.async_poll(wgpu::Maintain::wait()) + .await + .panic_on_timeout(); + + let view = readback_buffer.slice(..).get_mapped_range(); + + let current_res = *bytemuck::from_bytes(&view); + drop(view); + readback_buffer.unmap(); + + if let Some(past_res) = res { + assert_eq!(past_res, current_res); + } else { + res = Some(current_res); + } + } + + res.unwrap() +} diff --git a/tests/tests/root.rs b/tests/tests/root.rs index 3bb8e14a90..886f9da58b 100644 --- a/tests/tests/root.rs +++ b/tests/tests/root.rs @@ -19,6 +19,7 @@ mod clear_texture; mod compute_pass_ownership; mod create_surface_error; mod device; +mod dispatch_workgroups_indirect; mod encoder; mod external_texture; mod float32_filterable; diff --git a/wgpu-core/Cargo.toml b/wgpu-core/Cargo.toml index 1b9ce98488..60b4165a42 100644 --- a/wgpu-core/Cargo.toml +++ b/wgpu-core/Cargo.toml @@ -51,6 +51,10 @@ renderdoc = ["hal/renderdoc"] ## to the validation carried out at public APIs in all builds. strict_asserts = ["wgt/strict_asserts"] +## Validates indirect draw/dispatch calls. This will also enable naga's +## WGSL frontend since we use a WGSL compute shader to do the validation. +indirect-validation = ["naga/wgsl-in"] + ## Enables serialization via `serde` on common wgpu types. serde = ["dep:serde", "wgt/serde", "arrayvec/serde"] diff --git a/wgpu-core/src/command/bind.rs b/wgpu-core/src/command/bind.rs index 620027994f..22831c7a81 100644 --- a/wgpu-core/src/command/bind.rs +++ b/wgpu-core/src/command/bind.rs @@ -200,13 +200,17 @@ mod compat { entries: (0..hal::MAX_BIND_GROUPS).map(|_| Entry::empty()).collect(), } } - fn make_range(&self, start_index: usize) -> Range { + + pub fn num_valid_entries(&self) -> usize { // find first incompatible entry - let end = self - .entries + self.entries .iter() .position(|e| e.is_incompatible()) - .unwrap_or(self.entries.len()); + .unwrap_or(self.entries.len()) + } + + fn make_range(&self, start_index: usize) -> Range { + let end = self.num_valid_entries(); start_index..end.max(start_index) } @@ -406,6 +410,14 @@ impl Binder { .map(move |index| payloads[index].group.as_ref().unwrap()) } + #[cfg(feature = "indirect-validation")] + pub(super) fn list_valid<'a>(&'a self) -> impl Iterator + '_ { + self.payloads + .iter() + .take(self.manager.num_valid_entries()) + .enumerate() + } + pub(super) fn check_compatibility( &self, pipeline: &T, diff --git a/wgpu-core/src/command/compute.rs b/wgpu-core/src/command/compute.rs index 133a5af35a..07dcc2475a 100644 --- a/wgpu-core/src/command/compute.rs +++ b/wgpu-core/src/command/compute.rs @@ -18,7 +18,7 @@ use crate::{ pipeline::ComputePipeline, resource::{ self, Buffer, DestroyedResourceError, InvalidResourceError, Labeled, - MissingBufferUsageError, ParentDevice, Trackable, + MissingBufferUsageError, ParentDevice, }, snatch::SnatchGuard, track::{ResourceUsageCompatibilityError, Tracker, TrackerIndex, UsageScope}, @@ -216,6 +216,8 @@ struct State<'scope, 'snatch_guard, 'cmd_buf, 'raw_encoder> { string_offset: usize, active_query: Option<(Arc, u32)>, + push_constants: Vec, + intermediate_trackers: Tracker, /// Immediate texture inits required because of prior discards. Need to @@ -443,6 +445,8 @@ impl Global { string_offset: 0, active_query: None, + push_constants: Vec::new(), + intermediate_trackers: Tracker::new(), pending_discard_init_fixups: SurfacesInDiscardState::new(), @@ -746,6 +750,21 @@ fn set_pipeline( } } + // TODO: integrate this in the code below once we simplify push constants + state.push_constants.clear(); + // Note that can only be one range for each stage. See the `MoreThanOnePushConstantRangePerStage` error. + if let Some(push_constant_range) = + pipeline.layout.push_constant_ranges.iter().find_map(|pcr| { + pcr.stages + .contains(wgt::ShaderStages::COMPUTE) + .then_some(pcr.range.clone()) + }) + { + // Note that non-0 range start doesn't work anyway https://github.com/gfx-rs/wgpu/issues/4502 + let len = push_constant_range.len() / wgt::PUSH_CONSTANT_ALIGNMENT as usize; + state.push_constants.extend(core::iter::repeat(0).take(len)); + } + // Clear push constant ranges let non_overlapping = super::bind::compute_nonoverlapping_ranges(&pipeline.layout.push_constant_ranges); @@ -791,6 +810,10 @@ fn set_push_constant( end_offset_bytes, )?; + let offset_in_elements = (offset / wgt::PUSH_CONSTANT_ALIGNMENT) as usize; + let size_in_elements = (size_bytes / wgt::PUSH_CONSTANT_ALIGNMENT) as usize; + state.push_constants[offset_in_elements..][..size_in_elements].copy_from_slice(data_slice); + unsafe { state.raw_encoder.set_push_constants( pipeline_layout.raw(), @@ -841,10 +864,6 @@ fn dispatch_indirect( .device .require_downlevel_flags(wgt::DownlevelFlags::INDIRECT_EXECUTION)?; - state - .scope - .buffers - .merge_single(&buffer, hal::BufferUses::INDIRECT)?; buffer.check_usage(wgt::BufferUsages::INDIRECT)?; if offset % 4 != 0 { @@ -861,7 +880,6 @@ fn dispatch_indirect( } let stride = 3 * 4; // 3 integers, x/y/z group size - state .buffer_memory_init_actions .extend(buffer.initialization_status.read().create_action( @@ -870,12 +888,132 @@ fn dispatch_indirect( MemoryInitKind::NeedsInitializedMemory, )); - state.flush_states(Some(buffer.tracker_index()))?; + #[cfg(feature = "indirect-validation")] + { + let params = state.device.indirect_validation.as_ref().unwrap().params( + &state.device.limits, + offset, + buffer.size, + ); - let buf_raw = buffer.try_raw(&state.snatch_guard)?; - unsafe { - state.raw_encoder.dispatch_indirect(buf_raw, offset); + unsafe { + state.raw_encoder.set_compute_pipeline(params.pipeline); + } + + unsafe { + state.raw_encoder.set_push_constants( + params.pipeline_layout, + wgt::ShaderStages::COMPUTE, + 0, + &[params.offset_remainder as u32 / 4], + ); + } + + unsafe { + state.raw_encoder.set_bind_group( + params.pipeline_layout, + 0, + Some(params.dst_bind_group), + &[], + ); + } + unsafe { + state.raw_encoder.set_bind_group( + params.pipeline_layout, + 1, + Some( + buffer + .raw_indirect_validation_bind_group + .get(&state.snatch_guard) + .unwrap() + .as_ref(), + ), + &[params.aligned_offset as u32], + ); + } + + let src_transition = state + .intermediate_trackers + .buffers + .set_single(&buffer, hal::BufferUses::STORAGE_READ); + let src_barrier = + src_transition.map(|transition| transition.into_hal(&buffer, &state.snatch_guard)); + unsafe { + state.raw_encoder.transition_buffers(src_barrier.as_slice()); + } + + unsafe { + state.raw_encoder.transition_buffers(&[hal::BufferBarrier { + buffer: params.dst_buffer, + usage: hal::BufferUses::INDIRECT..hal::BufferUses::STORAGE_READ_WRITE, + }]); + } + + unsafe { + state.raw_encoder.dispatch([1, 1, 1]); + } + + // reset state + { + let pipeline = state.pipeline.as_ref().unwrap(); + + unsafe { + state.raw_encoder.set_compute_pipeline(pipeline.raw()); + } + + if !state.push_constants.is_empty() { + unsafe { + state.raw_encoder.set_push_constants( + pipeline.layout.raw(), + wgt::ShaderStages::COMPUTE, + 0, + &state.push_constants, + ); + } + } + + for (i, e) in state.binder.list_valid() { + let group = e.group.as_ref().unwrap(); + let raw_bg = group.try_raw(&state.snatch_guard)?; + unsafe { + state.raw_encoder.set_bind_group( + pipeline.layout.raw(), + i as u32, + Some(raw_bg), + &e.dynamic_offsets, + ); + } + } + } + + unsafe { + state.raw_encoder.transition_buffers(&[hal::BufferBarrier { + buffer: params.dst_buffer, + usage: hal::BufferUses::STORAGE_READ_WRITE..hal::BufferUses::INDIRECT, + }]); + } + + state.flush_states(None)?; + unsafe { + state.raw_encoder.dispatch_indirect(params.dst_buffer, 0); + } + }; + #[cfg(not(feature = "indirect-validation"))] + { + state + .scope + .buffers + .merge_single(&buffer, hal::BufferUses::INDIRECT)?; + + use crate::resource::Trackable; + state.flush_states(Some(buffer.tracker_index()))?; + + let buf_raw = buffer.try_raw(&state.snatch_guard)?; + unsafe { + state.raw_encoder.dispatch_indirect(buf_raw, offset); + } } + Ok(()) } diff --git a/wgpu-core/src/device/global.rs b/wgpu-core/src/device/global.rs index 10b82a73ae..583d3e03d3 100644 --- a/wgpu-core/src/device/global.rs +++ b/wgpu-core/src/device/global.rs @@ -406,12 +406,12 @@ impl Global { trace.add(trace::Action::CreateBuffer(fid.id(), desc.clone())); } - let buffer = device.create_buffer_from_hal(Box::new(hal_buffer), desc); + let (buffer, err) = device.create_buffer_from_hal(Box::new(hal_buffer), desc); let id = fid.assign(buffer); api_log!("Device::create_buffer -> {id:?}"); - (id, None) + (id, err) } pub fn texture_destroy(&self, texture_id: id::TextureId) -> Result<(), resource::DestroyError> { diff --git a/wgpu-core/src/device/mod.rs b/wgpu-core/src/device/mod.rs index 959f3cada7..a4c2d22277 100644 --- a/wgpu-core/src/device/mod.rs +++ b/wgpu-core/src/device/mod.rs @@ -36,7 +36,7 @@ pub(crate) const ZERO_BUFFER_SIZE: BufferAddress = 512 << 10; // See https://github.com/gfx-rs/wgpu/issues/4589. 60s to reduce the chances of this. const CLEANUP_WAIT_MS: u32 = 60000; -const ENTRYPOINT_FAILURE_ERROR: &str = "The given EntryPoint is Invalid"; +pub(crate) const ENTRYPOINT_FAILURE_ERROR: &str = "The given EntryPoint is Invalid"; pub type DeviceDescriptor<'a> = wgt::DeviceDescriptor>; diff --git a/wgpu-core/src/device/resource.rs b/wgpu-core/src/device/resource.rs index f84b2c5733..4ed9dff60a 100644 --- a/wgpu-core/src/device/resource.rs +++ b/wgpu-core/src/device/resource.rs @@ -31,7 +31,7 @@ use crate::{ UsageScopePool, }, validation::{self, validate_color_attachment_bytes_per_sample}, - FastHashMap, LabelHelpers as _, PreHashedKey, PreHashedMap, + FastHashMap, LabelHelpers, PreHashedKey, PreHashedMap, }; use arrayvec::ArrayVec; @@ -144,6 +144,9 @@ pub struct Device { #[cfg(feature = "trace")] pub(crate) trace: Mutex>, pub(crate) usage_scopes: UsageScopePool, + + #[cfg(feature = "indirect-validation")] + pub(crate) indirect_validation: Option, } pub(crate) enum DeferredDestroy { @@ -175,6 +178,11 @@ impl Drop for Device { let fence = unsafe { ManuallyDrop::take(&mut self.fence.write()) }; pending_writes.dispose(raw.as_ref()); self.command_allocator.dispose(raw.as_ref()); + #[cfg(feature = "indirect-validation")] + self.indirect_validation + .take() + .unwrap() + .dispose(raw.as_ref()); unsafe { raw.destroy_buffer(zero_buffer); raw.destroy_fence(fence); @@ -261,6 +269,25 @@ impl Device { let alignments = adapter.raw.capabilities.alignments.clone(); let downlevel = adapter.raw.capabilities.downlevel.clone(); + #[cfg(feature = "indirect-validation")] + let indirect_validation = if downlevel + .flags + .contains(wgt::DownlevelFlags::INDIRECT_EXECUTION) + { + match crate::indirect_validation::IndirectValidation::new( + raw_device.as_ref(), + &desc.required_limits, + ) { + Ok(indirect_validation) => Some(indirect_validation), + Err(e) => { + log::error!("indirect-validation error: {e:?}"); + return Err(DeviceError::Lost); + } + } + } else { + None + }; + Ok(Self { raw: ManuallyDrop::new(raw_device), adapter: adapter.clone(), @@ -306,6 +333,8 @@ impl Device { ), deferred_destroy: Mutex::new(rank::DEVICE_DEFERRED_DESTROY, Vec::new()), usage_scopes: Mutex::new(rank::DEVICE_USAGE_SCOPES, Default::default()), + #[cfg(feature = "indirect-validation")] + indirect_validation, }) } @@ -547,6 +576,13 @@ impl Device { let mut usage = conv::map_buffer_usage(desc.usage); + if desc.usage.contains(wgt::BufferUsages::INDIRECT) { + self.require_downlevel_flags(wgt::DownlevelFlags::INDIRECT_EXECUTION)?; + // We are going to be reading from it, internally; + // when validating the content of the buffer + usage |= hal::BufferUses::STORAGE_READ | hal::BufferUses::STORAGE_READ_WRITE; + } + if desc.mapped_at_creation { if desc.size % wgt::COPY_BUFFER_ALIGNMENT != 0 { return Err(resource::CreateBufferError::UnalignedSize); @@ -586,6 +622,10 @@ impl Device { let buffer = unsafe { self.raw().create_buffer(&hal_desc) }.map_err(|e| self.handle_hal_error(e))?; + #[cfg(feature = "indirect-validation")] + let raw_indirect_validation_bind_group = + self.create_indirect_validation_bind_group(buffer.as_ref(), desc.size, desc.usage)?; + let buffer = Buffer { raw: Snatchable::new(buffer), device: self.clone(), @@ -599,6 +639,8 @@ impl Device { label: desc.label.to_string(), tracking_data: TrackingData::new(self.tracker_indices.buffers.clone()), bind_groups: Mutex::new(rank::BUFFER_BIND_GROUPS, Vec::new()), + #[cfg(feature = "indirect-validation")] + raw_indirect_validation_bind_group, }; let buffer = Arc::new(buffer); @@ -686,7 +728,17 @@ impl Device { self: &Arc, hal_buffer: Box, desc: &resource::BufferDescriptor, - ) -> Fallible { + ) -> (Fallible, Option) { + #[cfg(feature = "indirect-validation")] + let raw_indirect_validation_bind_group = match self.create_indirect_validation_bind_group( + hal_buffer.as_ref(), + desc.size, + desc.usage, + ) { + Ok(ok) => ok, + Err(e) => return (Fallible::Invalid(Arc::new(desc.label.to_string())), Some(e)), + }; + unsafe { self.raw().add_raw_buffer(&*hal_buffer) }; let buffer = Buffer { @@ -702,6 +754,8 @@ impl Device { label: desc.label.to_string(), tracking_data: TrackingData::new(self.tracker_indices.buffers.clone()), bind_groups: Mutex::new(rank::BUFFER_BIND_GROUPS, Vec::new()), + #[cfg(feature = "indirect-validation")] + raw_indirect_validation_bind_group, }; let buffer = Arc::new(buffer); @@ -711,7 +765,25 @@ impl Device { .buffers .insert_single(&buffer, hal::BufferUses::empty()); - Fallible::Valid(buffer) + (Fallible::Valid(buffer), None) + } + + #[cfg(feature = "indirect-validation")] + fn create_indirect_validation_bind_group( + &self, + raw_buffer: &dyn hal::DynBuffer, + buffer_size: u64, + usage: wgt::BufferUsages, + ) -> Result>, resource::CreateBufferError> { + if usage.contains(wgt::BufferUsages::INDIRECT) { + let indirect_validation = self.indirect_validation.as_ref().unwrap(); + let bind_group = indirect_validation + .create_src_bind_group(self.raw(), &self.limits, buffer_size, raw_buffer) + .map_err(resource::CreateBufferError::IndirectValidationBindGroup)?; + Ok(Snatchable::new(bind_group)) + } else { + Ok(Snatchable::empty()) + } } pub(crate) fn create_texture( diff --git a/wgpu-core/src/indirect_validation.rs b/wgpu-core/src/indirect_validation.rs new file mode 100644 index 0000000000..ca73731465 --- /dev/null +++ b/wgpu-core/src/indirect_validation.rs @@ -0,0 +1,378 @@ +use thiserror::Error; + +use crate::{ + device::DeviceError, + pipeline::{CreateComputePipelineError, CreateShaderModuleError}, +}; + +#[derive(Clone, Debug, Error)] +#[non_exhaustive] +pub enum CreateDispatchIndirectValidationPipelineError { + #[error(transparent)] + DeviceError(#[from] DeviceError), + #[error(transparent)] + ShaderModule(#[from] CreateShaderModuleError), + #[error(transparent)] + ComputePipeline(#[from] CreateComputePipelineError), +} + +/// This machinery requires the following limits: +/// +/// - max_bind_groups: 2, +/// - max_dynamic_storage_buffers_per_pipeline_layout: 1, +/// - max_storage_buffers_per_shader_stage: 2, +/// - max_storage_buffer_binding_size: 3 * min_storage_buffer_offset_alignment, +/// - max_push_constant_size: 4, +/// - max_compute_invocations_per_workgroup 1 +/// +/// These are all indirectly satisfied by `DownlevelFlags::INDIRECT_EXECUTION`, which is also +/// required for this module's functionality to work. +#[derive(Debug)] +pub struct IndirectValidation { + module: Box, + dst_bind_group_layout: Box, + src_bind_group_layout: Box, + pipeline_layout: Box, + pipeline: Box, + dst_buffer: Box, + dst_bind_group: Box, +} + +pub struct Params<'a> { + pub pipeline_layout: &'a dyn hal::DynPipelineLayout, + pub pipeline: &'a dyn hal::DynComputePipeline, + pub dst_buffer: &'a dyn hal::DynBuffer, + pub dst_bind_group: &'a dyn hal::DynBindGroup, + pub aligned_offset: u64, + pub offset_remainder: u64, +} + +impl IndirectValidation { + pub fn new( + device: &dyn hal::DynDevice, + limits: &wgt::Limits, + ) -> Result { + let max_compute_workgroups_per_dimension = limits.max_compute_workgroups_per_dimension; + + let src = format!( + " + @group(0) @binding(0) + var dst: array; + @group(1) @binding(0) + var src: array; + struct OffsetPc {{ + inner: u32, + }} + var offset: OffsetPc; + + @compute @workgroup_size(1) + fn main() {{ + let src = vec3(src[offset.inner], src[offset.inner + 1], src[offset.inner + 2]); + let max_compute_workgroups_per_dimension = {max_compute_workgroups_per_dimension}u; + if ( + src.x > max_compute_workgroups_per_dimension || + src.y > max_compute_workgroups_per_dimension || + src.z > max_compute_workgroups_per_dimension + ) {{ + dst = array(0u, 0u, 0u); + }} else {{ + dst = array(src.x, src.y, src.z); + }} + }} + " + ); + + let module = naga::front::wgsl::parse_str(&src).map_err(|inner| { + CreateShaderModuleError::Parsing(naga::error::ShaderError { + source: src.clone(), + label: None, + inner: Box::new(inner), + }) + })?; + let info = crate::device::create_validator( + wgt::Features::PUSH_CONSTANTS, + wgt::DownlevelFlags::empty(), + naga::valid::ValidationFlags::all(), + ) + .validate(&module) + .map_err(|inner| { + CreateShaderModuleError::Validation(naga::error::ShaderError { + source: src, + label: None, + inner: Box::new(inner), + }) + })?; + let hal_shader = hal::ShaderInput::Naga(hal::NagaShader { + module: std::borrow::Cow::Owned(module), + info, + debug_source: None, + }); + let hal_desc = hal::ShaderModuleDescriptor { + label: None, + runtime_checks: false, + }; + let module = + unsafe { device.create_shader_module(&hal_desc, hal_shader) }.map_err(|error| { + match error { + hal::ShaderError::Device(error) => { + CreateShaderModuleError::Device(DeviceError::from_hal(error)) + } + hal::ShaderError::Compilation(ref msg) => { + log::error!("Shader error: {}", msg); + CreateShaderModuleError::Generation + } + } + })?; + + let dst_bind_group_layout_desc = hal::BindGroupLayoutDescriptor { + label: None, + flags: hal::BindGroupLayoutFlags::empty(), + entries: &[wgt::BindGroupLayoutEntry { + binding: 0, + visibility: wgt::ShaderStages::COMPUTE, + ty: wgt::BindingType::Buffer { + ty: wgt::BufferBindingType::Storage { read_only: false }, + has_dynamic_offset: false, + min_binding_size: Some(std::num::NonZeroU64::new(4 * 3).unwrap()), + }, + count: None, + }], + }; + let dst_bind_group_layout = unsafe { + device + .create_bind_group_layout(&dst_bind_group_layout_desc) + .map_err(DeviceError::from_hal)? + }; + + let src_bind_group_layout_desc = hal::BindGroupLayoutDescriptor { + label: None, + flags: hal::BindGroupLayoutFlags::empty(), + entries: &[wgt::BindGroupLayoutEntry { + binding: 0, + visibility: wgt::ShaderStages::COMPUTE, + ty: wgt::BindingType::Buffer { + ty: wgt::BufferBindingType::Storage { read_only: true }, + has_dynamic_offset: true, + min_binding_size: Some(std::num::NonZeroU64::new(4 * 3).unwrap()), + }, + count: None, + }], + }; + let src_bind_group_layout = unsafe { + device + .create_bind_group_layout(&src_bind_group_layout_desc) + .map_err(DeviceError::from_hal)? + }; + + let pipeline_layout_desc = hal::PipelineLayoutDescriptor { + label: None, + flags: hal::PipelineLayoutFlags::FIRST_VERTEX_INSTANCE, + bind_group_layouts: &[ + dst_bind_group_layout.as_ref(), + src_bind_group_layout.as_ref(), + ], + push_constant_ranges: &[wgt::PushConstantRange { + stages: wgt::ShaderStages::COMPUTE, + range: 0..4, + }], + }; + let pipeline_layout = unsafe { + device + .create_pipeline_layout(&pipeline_layout_desc) + .map_err(DeviceError::from_hal)? + }; + + let pipeline_desc = hal::ComputePipelineDescriptor { + label: None, + layout: pipeline_layout.as_ref(), + stage: hal::ProgrammableStage { + module: module.as_ref(), + entry_point: "main", + constants: &Default::default(), + zero_initialize_workgroup_memory: false, + }, + cache: None, + }; + let pipeline = + unsafe { device.create_compute_pipeline(&pipeline_desc) }.map_err(|err| match err { + hal::PipelineError::Device(error) => { + CreateComputePipelineError::Device(DeviceError::from_hal(error)) + } + hal::PipelineError::Linkage(_stages, msg) => { + CreateComputePipelineError::Internal(msg) + } + hal::PipelineError::EntryPoint(_stage) => CreateComputePipelineError::Internal( + crate::device::ENTRYPOINT_FAILURE_ERROR.to_string(), + ), + hal::PipelineError::PipelineConstants(_, error) => { + CreateComputePipelineError::PipelineConstants(error) + } + })?; + + let dst_buffer_desc = hal::BufferDescriptor { + label: None, + size: 4 * 3, + usage: hal::BufferUses::INDIRECT | hal::BufferUses::STORAGE_READ_WRITE, + memory_flags: hal::MemoryFlags::empty(), + }; + let dst_buffer = + unsafe { device.create_buffer(&dst_buffer_desc) }.map_err(DeviceError::from_hal)?; + + let dst_bind_group_desc = hal::BindGroupDescriptor { + label: None, + layout: dst_bind_group_layout.as_ref(), + entries: &[hal::BindGroupEntry { + binding: 0, + resource_index: 0, + count: 1, + }], + buffers: &[hal::BufferBinding { + buffer: dst_buffer.as_ref(), + offset: 0, + size: Some(std::num::NonZeroU64::new(4 * 3).unwrap()), + }], + samplers: &[], + textures: &[], + acceleration_structures: &[], + }; + let dst_bind_group = unsafe { + device + .create_bind_group(&dst_bind_group_desc) + .map_err(DeviceError::from_hal) + }?; + + Ok(Self { + module, + dst_bind_group_layout, + src_bind_group_layout, + pipeline_layout, + pipeline, + dst_buffer, + dst_bind_group, + }) + } + + pub fn create_src_bind_group( + &self, + device: &dyn hal::DynDevice, + limits: &wgt::Limits, + buffer_size: u64, + buffer: &dyn hal::DynBuffer, + ) -> Result, DeviceError> { + let binding_size = calculate_src_buffer_binding_size(buffer_size, limits); + let hal_desc = hal::BindGroupDescriptor { + label: None, + layout: self.src_bind_group_layout.as_ref(), + entries: &[hal::BindGroupEntry { + binding: 0, + resource_index: 0, + count: 1, + }], + buffers: &[hal::BufferBinding { + buffer, + offset: 0, + size: Some(std::num::NonZeroU64::new(binding_size).unwrap()), + }], + samplers: &[], + textures: &[], + acceleration_structures: &[], + }; + unsafe { + device + .create_bind_group(&hal_desc) + .map_err(DeviceError::from_hal) + } + } + + pub fn params<'a>(&'a self, limits: &wgt::Limits, offset: u64, buffer_size: u64) -> Params<'a> { + // The offset we receive is only required to be aligned to 4 bytes. + // + // Binding offsets and dynamic offsets are required to be aligned to + // min_storage_buffer_offset_alignment (256 bytes by default). + // + // So, we work around this limitation by calculating an aligned offset + // and pass the remainder through a push constant. + // + // We could bind the whole buffer and only have to pass the offset + // through a push constant but we might run into the + // max_storage_buffer_binding_size limit. + // + // See the inner docs of `calculate_src_buffer_binding_size` to + // see how we get the appropriate `binding_size`. + let alignment = limits.min_storage_buffer_offset_alignment as u64; + let binding_size = calculate_src_buffer_binding_size(buffer_size, limits); + let aligned_offset = offset - offset % alignment; + // This works because `binding_size` is either `buffer_size` or `alignment * 2 + buffer_size % alignment`. + let max_aligned_offset = buffer_size - binding_size; + let aligned_offset = aligned_offset.min(max_aligned_offset); + let offset_remainder = offset - aligned_offset; + + Params { + pipeline_layout: self.pipeline_layout.as_ref(), + pipeline: self.pipeline.as_ref(), + dst_buffer: self.dst_buffer.as_ref(), + dst_bind_group: self.dst_bind_group.as_ref(), + aligned_offset, + offset_remainder, + } + } + + pub fn dispose(self, device: &dyn hal::DynDevice) { + let IndirectValidation { + module, + dst_bind_group_layout, + src_bind_group_layout, + pipeline_layout, + pipeline, + dst_buffer, + dst_bind_group, + } = self; + + unsafe { + device.destroy_bind_group(dst_bind_group); + device.destroy_buffer(dst_buffer); + device.destroy_compute_pipeline(pipeline); + device.destroy_pipeline_layout(pipeline_layout); + device.destroy_bind_group_layout(src_bind_group_layout); + device.destroy_bind_group_layout(dst_bind_group_layout); + device.destroy_shader_module(module); + } + } +} + +fn calculate_src_buffer_binding_size(buffer_size: u64, limits: &wgt::Limits) -> u64 { + let alignment = limits.min_storage_buffer_offset_alignment as u64; + + // We need to choose a binding size that can address all possible sets of 12 contiguous bytes in the buffer taking + // into account that the dynamic offset needs to be a multiple of `min_storage_buffer_offset_alignment`. + + // Given the know variables: `offset`, `buffer_size`, `alignment` and the rule `offset + 12 <= buffer_size`. + + // Let `chunks = floor(buffer_size / alignment)`. + // Let `chunk` be the interval `[0, chunks]`. + // Let `offset = alignment * chunk + r` where `r` is the interval [0, alignment - 4]. + // Let `binding` be the interval `[offset, offset + 12]`. + // Let `aligned_offset = alignment * chunk`. + // Let `aligned_binding` be the interval `[aligned_offset, aligned_offset + r + 12]`. + // Let `aligned_binding_size = r + 12 = [12, alignment + 8]`. + // Let `min_aligned_binding_size = alignment + 8`. + + // `min_aligned_binding_size` is the minimum binding size required to address all 12 contiguous bytes in the buffer + // but the last aligned_offset + min_aligned_binding_size might overflow the buffer. In order to avoid this we must + // pick a larger `binding_size` that satisfies: `last_aligned_offset + binding_size = buffer_size` and + // `binding_size >= min_aligned_binding_size`. + + // Let `buffer_size = alignment * chunks + sr` where `sr` is the interval [0, alignment - 4]. + // Let `last_aligned_offset = alignment * (chunks - u)` where `u` is the interval [0, chunks]. + // => `binding_size = buffer_size - last_aligned_offset` + // => `binding_size = alignment * chunks + sr - alignment * (chunks - u)` + // => `binding_size = alignment * chunks + sr - alignment * chunks + alignment * u` + // => `binding_size = sr + alignment * u` + // => `min_aligned_binding_size <= sr + alignment * u` + // => `alignment + 8 <= sr + alignment * u` + // => `u` must be at least 2 + // => `binding_size = sr + alignment * 2` + + let binding_size = 2 * alignment + (buffer_size % alignment); + binding_size.min(buffer_size) +} diff --git a/wgpu-core/src/lib.rs b/wgpu-core/src/lib.rs index 521238a7d6..c85288f47a 100644 --- a/wgpu-core/src/lib.rs +++ b/wgpu-core/src/lib.rs @@ -67,6 +67,8 @@ mod hash_utils; pub mod hub; pub mod id; pub mod identity; +#[cfg(feature = "indirect-validation")] +mod indirect_validation; mod init_tracker; pub mod instance; mod lock; diff --git a/wgpu-core/src/pipeline.rs b/wgpu-core/src/pipeline.rs index 08e7167db6..01ceabf669 100644 --- a/wgpu-core/src/pipeline.rs +++ b/wgpu-core/src/pipeline.rs @@ -92,7 +92,7 @@ impl ShaderModule { #[derive(Clone, Debug, Error)] #[non_exhaustive] pub enum CreateShaderModuleError { - #[cfg(feature = "wgsl")] + #[cfg(any(feature = "wgsl", feature = "indirect-validation"))] #[error(transparent)] Parsing(#[from] ShaderError), #[cfg(feature = "glsl")] diff --git a/wgpu-core/src/resource.rs b/wgpu-core/src/resource.rs index 5007b45614..3cd5765bdd 100644 --- a/wgpu-core/src/resource.rs +++ b/wgpu-core/src/resource.rs @@ -475,10 +475,18 @@ pub struct Buffer { pub(crate) tracking_data: TrackingData, pub(crate) map_state: Mutex, pub(crate) bind_groups: Mutex>>, + #[cfg(feature = "indirect-validation")] + pub(crate) raw_indirect_validation_bind_group: Snatchable>, } impl Drop for Buffer { fn drop(&mut self) { + #[cfg(feature = "indirect-validation")] + if let Some(raw) = self.raw_indirect_validation_bind_group.take() { + unsafe { + self.device.raw().destroy_bind_group(raw); + } + } if let Some(raw) = self.raw.take() { resource_log!("Destroy raw {}", self.error_ident()); unsafe { @@ -737,13 +745,22 @@ impl Buffer { let device = &self.device; let temp = { - let raw = match self.raw.snatch(&mut device.snatchable_lock.write()) { + let mut snatch_guard = device.snatchable_lock.write(); + + let raw = match self.raw.snatch(&mut snatch_guard) { Some(raw) => raw, None => { return Err(DestroyError::AlreadyDestroyed); } }; + #[cfg(feature = "indirect-validation")] + let raw_indirect_validation_bind_group = self + .raw_indirect_validation_bind_group + .snatch(&mut snatch_guard); + + drop(snatch_guard); + let bind_groups = { let mut guard = self.bind_groups.lock(); mem::take(&mut *guard) @@ -754,6 +771,8 @@ impl Buffer { device: Arc::clone(&self.device), label: self.label().to_owned(), bind_groups, + #[cfg(feature = "indirect-validation")] + raw_indirect_validation_bind_group, }) }; @@ -789,6 +808,8 @@ pub enum CreateBufferError { MaxBufferSize { requested: u64, maximum: u64 }, #[error(transparent)] MissingDownlevelFlags(#[from] MissingDownlevelFlags), + #[error("Failed to create bind group for indirect buffer validation: {0}")] + IndirectValidationBindGroup(DeviceError), } crate::impl_resource_type!(Buffer); @@ -804,6 +825,8 @@ pub struct DestroyedBuffer { device: Arc, label: String, bind_groups: Vec>, + #[cfg(feature = "indirect-validation")] + raw_indirect_validation_bind_group: Option>, } impl DestroyedBuffer { @@ -820,6 +843,13 @@ impl Drop for DestroyedBuffer { } drop(deferred); + #[cfg(feature = "indirect-validation")] + if let Some(raw) = self.raw_indirect_validation_bind_group.take() { + unsafe { + self.device.raw().destroy_bind_group(raw); + } + } + resource_log!("Destroy raw Buffer (destroyed) {:?}", self.label()); // SAFETY: We are in the Drop impl and we don't use self.raw anymore after this point. let raw = unsafe { ManuallyDrop::take(&mut self.raw) }; diff --git a/wgpu-core/src/snatch.rs b/wgpu-core/src/snatch.rs index a817e2068c..0d57f41fcc 100644 --- a/wgpu-core/src/snatch.rs +++ b/wgpu-core/src/snatch.rs @@ -32,6 +32,12 @@ impl Snatchable { } } + pub fn empty() -> Self { + Snatchable { + value: UnsafeCell::new(None), + } + } + /// Get read access to the value. Requires a the snatchable lock's read guard. pub fn get<'a>(&'a self, _guard: &'a SnatchGuard) -> Option<&'a T> { unsafe { (*self.value.get()).as_ref() } diff --git a/wgpu/Cargo.toml b/wgpu/Cargo.toml index 9569281eec..ed630133e2 100644 --- a/wgpu/Cargo.toml +++ b/wgpu/Cargo.toml @@ -130,6 +130,12 @@ features = ["raw-window-handle"] workspace = true features = ["raw-window-handle"] +# If we are not targeting WebGL, enable indirect-validation. +# WebGL doesn't support indirect execution so this is not needed. +[target.'cfg(not(target_arch = "wasm32"))'.dependencies.wgc] +workspace = true +features = ["indirect-validation"] + # Enable `wgc` by default on macOS and iOS to allow the `metal` crate feature to # enable the Metal backend while being no-op on other targets. [target.'cfg(any(target_os = "macos", target_os = "ios"))'.dependencies.wgc]