diff --git a/tests/tests/dispatch_workgroups_indirect.rs b/tests/tests/dispatch_workgroups_indirect.rs index 7ba3d7f638b..0e11186eccf 100644 --- a/tests/tests/dispatch_workgroups_indirect.rs +++ b/tests/tests/dispatch_workgroups_indirect.rs @@ -209,6 +209,9 @@ async fn run_test( if !forget_to_set_bind_group { compute_pass.set_bind_group(0, Some(&bind_group), &[]); } + // Issue multiple dispatches to test the internal destination buffer switching + compute_pass.dispatch_workgroups_indirect(&indirect_buffer, indirect_offset); + compute_pass.dispatch_workgroups_indirect(&indirect_buffer, indirect_offset); compute_pass.dispatch_workgroups_indirect(&indirect_buffer, indirect_offset); } diff --git a/wgpu-core/src/command/compute.rs b/wgpu-core/src/command/compute.rs index 07dcc2475a2..f9566dd6c0f 100644 --- a/wgpu-core/src/command/compute.rs +++ b/wgpu-core/src/command/compute.rs @@ -942,13 +942,6 @@ fn dispatch_indirect( state.raw_encoder.transition_buffers(src_barrier.as_slice()); } - unsafe { - state.raw_encoder.transition_buffers(&[hal::BufferBarrier { - buffer: params.dst_buffer, - usage: hal::BufferUses::INDIRECT..hal::BufferUses::STORAGE_READ_WRITE, - }]); - } - unsafe { state.raw_encoder.dispatch([1, 1, 1]); } @@ -987,10 +980,16 @@ fn dispatch_indirect( } unsafe { - state.raw_encoder.transition_buffers(&[hal::BufferBarrier { - buffer: params.dst_buffer, - usage: hal::BufferUses::STORAGE_READ_WRITE..hal::BufferUses::INDIRECT, - }]); + state.raw_encoder.transition_buffers(&[ + hal::BufferBarrier { + buffer: params.dst_buffer, + usage: hal::BufferUses::STORAGE_READ_WRITE..hal::BufferUses::INDIRECT, + }, + hal::BufferBarrier { + buffer: params.other_dst_buffer, + usage: hal::BufferUses::INDIRECT..hal::BufferUses::STORAGE_READ_WRITE, + }, + ]); } state.flush_states(None)?; diff --git a/wgpu-core/src/indirect_validation.rs b/wgpu-core/src/indirect_validation.rs index ca73731465f..8e498acf030 100644 --- a/wgpu-core/src/indirect_validation.rs +++ b/wgpu-core/src/indirect_validation.rs @@ -1,3 +1,5 @@ +use std::sync::atomic::AtomicBool; + use thiserror::Error; use crate::{ @@ -34,14 +36,18 @@ pub struct IndirectValidation { src_bind_group_layout: Box, pipeline_layout: Box, pipeline: Box, - dst_buffer: Box, - dst_bind_group: Box, + dst_buffer_0: Box, + dst_buffer_1: Box, + dst_bind_group_0: Box, + dst_bind_group_1: Box, + is_next_dst_0: AtomicBool, } pub struct Params<'a> { pub pipeline_layout: &'a dyn hal::DynPipelineLayout, pub pipeline: &'a dyn hal::DynComputePipeline, pub dst_buffer: &'a dyn hal::DynBuffer, + pub other_dst_buffer: &'a dyn hal::DynBuffer, pub dst_bind_group: &'a dyn hal::DynBindGroup, pub aligned_offset: u64, pub offset_remainder: u64, @@ -215,10 +221,12 @@ impl IndirectValidation { usage: hal::BufferUses::INDIRECT | hal::BufferUses::STORAGE_READ_WRITE, memory_flags: hal::MemoryFlags::empty(), }; - let dst_buffer = + let dst_buffer_0 = + unsafe { device.create_buffer(&dst_buffer_desc) }.map_err(DeviceError::from_hal)?; + let dst_buffer_1 = unsafe { device.create_buffer(&dst_buffer_desc) }.map_err(DeviceError::from_hal)?; - let dst_bind_group_desc = hal::BindGroupDescriptor { + let dst_bind_group_desc_0 = hal::BindGroupDescriptor { label: None, layout: dst_bind_group_layout.as_ref(), entries: &[hal::BindGroupEntry { @@ -227,7 +235,7 @@ impl IndirectValidation { count: 1, }], buffers: &[hal::BufferBinding { - buffer: dst_buffer.as_ref(), + buffer: dst_buffer_0.as_ref(), offset: 0, size: Some(std::num::NonZeroU64::new(4 * 3).unwrap()), }], @@ -235,9 +243,32 @@ impl IndirectValidation { textures: &[], acceleration_structures: &[], }; - let dst_bind_group = unsafe { + let dst_bind_group_0 = unsafe { device - .create_bind_group(&dst_bind_group_desc) + .create_bind_group(&dst_bind_group_desc_0) + .map_err(DeviceError::from_hal) + }?; + + let dst_bind_group_desc_1 = hal::BindGroupDescriptor { + label: None, + layout: dst_bind_group_layout.as_ref(), + entries: &[hal::BindGroupEntry { + binding: 0, + resource_index: 0, + count: 1, + }], + buffers: &[hal::BufferBinding { + buffer: dst_buffer_1.as_ref(), + offset: 0, + size: Some(std::num::NonZeroU64::new(4 * 3).unwrap()), + }], + samplers: &[], + textures: &[], + acceleration_structures: &[], + }; + let dst_bind_group_1 = unsafe { + device + .create_bind_group(&dst_bind_group_desc_1) .map_err(DeviceError::from_hal) }?; @@ -247,8 +278,11 @@ impl IndirectValidation { src_bind_group_layout, pipeline_layout, pipeline, - dst_buffer, - dst_bind_group, + dst_buffer_0, + dst_buffer_1, + dst_bind_group_0, + dst_bind_group_1, + is_next_dst_0: AtomicBool::new(false), }) } @@ -307,11 +341,29 @@ impl IndirectValidation { let aligned_offset = aligned_offset.min(max_aligned_offset); let offset_remainder = offset - aligned_offset; + let (dst_buffer, other_dst_buffer, dst_bind_group) = if self + .is_next_dst_0 + .fetch_xor(true, core::sync::atomic::Ordering::AcqRel) + { + ( + self.dst_buffer_0.as_ref(), + self.dst_buffer_1.as_ref(), + self.dst_bind_group_0.as_ref(), + ) + } else { + ( + self.dst_buffer_1.as_ref(), + self.dst_buffer_0.as_ref(), + self.dst_bind_group_1.as_ref(), + ) + }; + Params { pipeline_layout: self.pipeline_layout.as_ref(), pipeline: self.pipeline.as_ref(), - dst_buffer: self.dst_buffer.as_ref(), - dst_bind_group: self.dst_bind_group.as_ref(), + dst_buffer, + other_dst_buffer, + dst_bind_group, aligned_offset, offset_remainder, } @@ -324,13 +376,18 @@ impl IndirectValidation { src_bind_group_layout, pipeline_layout, pipeline, - dst_buffer, - dst_bind_group, + dst_buffer_0, + dst_buffer_1, + dst_bind_group_0, + dst_bind_group_1, + is_next_dst_0: _, } = self; unsafe { - device.destroy_bind_group(dst_bind_group); - device.destroy_buffer(dst_buffer); + device.destroy_bind_group(dst_bind_group_0); + device.destroy_bind_group(dst_bind_group_1); + device.destroy_buffer(dst_buffer_0); + device.destroy_buffer(dst_buffer_1); device.destroy_compute_pipeline(pipeline); device.destroy_pipeline_layout(pipeline_layout); device.destroy_bind_group_layout(src_bind_group_layout);