From e26a85365484411f42bc5ddbc08c3a069bba2f83 Mon Sep 17 00:00:00 2001 From: Connor Fitzgerald Date: Tue, 2 Jan 2024 09:27:22 -0500 Subject: [PATCH] BGL Weak Pointer Deduplication Pool (#4927) * Ho boy * BGL pool finished * Remove id32 feature * Add BGL Test * Iteration * Working Dedupe * Tests * Iteration * Re-iteration * Hash Cleanup * Add large slew of tests * Whoops --- .config/nextest.toml | 4 + Cargo.lock | 2 + tests/tests/bind_group_layout_dedup.rs | 413 ++++++++++++++++++++----- wgpu-core/Cargo.toml | 2 + wgpu-core/src/binding_model.rs | 38 ++- wgpu-core/src/command/bind.rs | 64 ++-- wgpu-core/src/device/bgl.rs | 129 ++++++++ wgpu-core/src/device/global.rs | 70 +++-- wgpu-core/src/device/mod.rs | 1 + wgpu-core/src/device/resource.rs | 308 ++++++++---------- wgpu-core/src/hash_utils.rs | 86 +++++ wgpu-core/src/lib.rs | 11 +- wgpu-core/src/pipeline.rs | 3 +- wgpu-core/src/pool.rs | 312 +++++++++++++++++++ wgpu-core/src/registry.rs | 11 +- wgpu-core/src/storage.rs | 6 + wgpu-core/src/validation.rs | 119 ++++--- 17 files changed, 1207 insertions(+), 372 deletions(-) create mode 100644 .config/nextest.toml create mode 100644 wgpu-core/src/device/bgl.rs create mode 100644 wgpu-core/src/hash_utils.rs create mode 100644 wgpu-core/src/pool.rs diff --git a/.config/nextest.toml b/.config/nextest.toml new file mode 100644 index 0000000000..3d209d3ec9 --- /dev/null +++ b/.config/nextest.toml @@ -0,0 +1,4 @@ +# Use two threads for tests with "2_threads" in their name +[[profile.default.overrides]] +filter = 'test(~2_threads)' +threads-required = 2 diff --git a/Cargo.lock b/Cargo.lock index 737dd49e12..e3a75d3995 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -4019,8 +4019,10 @@ dependencies = [ "bit-vec", "bitflags 2.4.1", "codespan-reporting", + "indexmap", "log", "naga", + "once_cell", "parking_lot", "profiling", "raw-window-handle 0.6.0", diff --git a/tests/tests/bind_group_layout_dedup.rs b/tests/tests/bind_group_layout_dedup.rs index b3d99e781b..66ea687f2a 100644 --- a/tests/tests/bind_group_layout_dedup.rs +++ b/tests/tests/bind_group_layout_dedup.rs @@ -1,22 +1,42 @@ -use wgpu_test::{gpu_test, GpuTestConfiguration, TestingContext}; +use std::num::NonZeroU64; + +use wgpu_test::{ + fail, gpu_test, FailureCase, GpuTestConfiguration, TestParameters, TestingContext, +}; +use wgt::Backends; + +const SHADER_SRC: &str = " +@group(0) @binding(0) +var buffer : f32; + +@compute @workgroup_size(1, 1, 1) fn no_resources() {} +@compute @workgroup_size(1, 1, 1) fn resources() { + // Just need a static use. + let _value = buffer; +} +"; + +const ENTRY: wgpu::BindGroupLayoutEntry = wgpu::BindGroupLayoutEntry { + binding: 0, + visibility: wgpu::ShaderStages::COMPUTE, + ty: wgpu::BindingType::Buffer { + ty: wgpu::BufferBindingType::Uniform, + has_dynamic_offset: false, + // Should be Some(.unwrap()) but unwrap is not const. + min_binding_size: NonZeroU64::new(4), + }, + count: None, +}; #[gpu_test] -static BIND_GROUP_LAYOUT_DEDUPLICATION: GpuTestConfiguration = - GpuTestConfiguration::new().run_sync(bgl_dedupe); +static BIND_GROUP_LAYOUT_DEDUPLICATION: GpuTestConfiguration = GpuTestConfiguration::new() + .parameters(TestParameters::default().test_features_limits()) + .run_sync(bgl_dedupe); fn bgl_dedupe(ctx: TestingContext) { let entries_1 = &[]; - let entries_2 = &[wgpu::BindGroupLayoutEntry { - binding: 0, - visibility: wgpu::ShaderStages::VERTEX, - ty: wgpu::BindingType::Buffer { - ty: wgpu::BufferBindingType::Uniform, - has_dynamic_offset: false, - min_binding_size: None, - }, - count: None, - }]; + let entries_2 = &[ENTRY]; // Block so we can force all resource to die. { @@ -68,75 +88,30 @@ fn bgl_dedupe(ctx: TestingContext) { source: wgpu::ShaderSource::Wgsl(SHADER_SRC.into()), }); - let targets = &[Some(wgpu::ColorTargetState { - format: wgpu::TextureFormat::Rgba8Unorm, - blend: None, - write_mask: Default::default(), - })]; - - let desc = wgpu::RenderPipelineDescriptor { + let desc = wgpu::ComputePipelineDescriptor { label: None, layout: Some(&pipeline_layout), - vertex: wgpu::VertexState { - module: &module, - entry_point: "vs_main", - buffers: &[], - }, - fragment: Some(wgpu::FragmentState { - module: &module, - entry_point: "fs_main", - targets, - }), - primitive: wgpu::PrimitiveState::default(), - depth_stencil: None, - multiview: None, - multisample: wgpu::MultisampleState::default(), + module: &module, + entry_point: "no_resources", }; - let pipeline = ctx.device.create_render_pipeline(&desc); - - let texture = ctx.device.create_texture(&wgpu::TextureDescriptor { - label: None, - dimension: wgpu::TextureDimension::D2, - size: wgpu::Extent3d { - width: 32, - height: 32, - depth_or_array_layers: 1, - }, - sample_count: 1, - mip_level_count: 1, - format: wgpu::TextureFormat::Rgba8Unorm, - usage: wgpu::TextureUsages::RENDER_ATTACHMENT, - view_formats: &[], - }); - - let texture_view = texture.create_view(&Default::default()); + let pipeline = ctx.device.create_compute_pipeline(&desc); let mut encoder = ctx.device.create_command_encoder(&Default::default()); - { - let mut pass = encoder.begin_render_pass(&wgpu::RenderPassDescriptor { - label: None, - color_attachments: &[Some(wgpu::RenderPassColorAttachment { - view: &texture_view, - resolve_target: None, - ops: Default::default(), - })], - depth_stencil_attachment: None, - occlusion_query_set: None, - timestamp_writes: None, - }); - - pass.set_bind_group(0, &bg_1b, &[]); - - pass.set_pipeline(&pipeline); + let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor { + label: None, + timestamp_writes: None, + }); - pass.draw(0..6, 0..1); + pass.set_bind_group(0, &bg_1b, &[]); + pass.set_pipeline(&pipeline); + pass.dispatch_workgroups(1, 1, 1); - pass.set_bind_group(0, &bg_1a, &[]); + pass.set_bind_group(0, &bg_1a, &[]); + pass.dispatch_workgroups(1, 1, 1); - pass.draw(0..6, 0..1); - } + drop(pass); ctx.queue.submit(Some(encoder.finish())); @@ -177,7 +152,293 @@ fn bgl_dedupe(ctx: TestingContext) { } } -const SHADER_SRC: &str = " -@vertex fn vs_main() -> @builtin(position) vec4 { return vec4(1.0); } -@fragment fn fs_main() -> @location(0) vec4 { return vec4(1.0); } -"; +#[gpu_test] +static BIND_GROUP_LAYOUT_DEDUPLICATION_WITH_DROPPED_USER_HANDLE: GpuTestConfiguration = + GpuTestConfiguration::new() + .parameters(TestParameters::default().test_features_limits()) + .run_sync(bgl_dedupe_with_dropped_user_handle); + +// https://github.com/gfx-rs/wgpu/issues/4824 +fn bgl_dedupe_with_dropped_user_handle(ctx: TestingContext) { + let bgl_1 = ctx + .device + .create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor { + label: None, + entries: &[ENTRY], + }); + + let pipeline_layout = ctx + .device + .create_pipeline_layout(&wgpu::PipelineLayoutDescriptor { + label: None, + bind_group_layouts: &[&bgl_1], + push_constant_ranges: &[], + }); + + // We drop bgl_1 here. As bgl_1 is still alive, referenced by the pipeline layout, + // the deduplication should work as expected. Previously this did not work. + drop(bgl_1); + + let bgl_2 = ctx + .device + .create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor { + label: None, + entries: &[ENTRY], + }); + + let buffer = ctx.device.create_buffer(&wgpu::BufferDescriptor { + label: None, + size: 4, + usage: wgpu::BufferUsages::UNIFORM, + mapped_at_creation: false, + }); + + let bg = ctx.device.create_bind_group(&wgpu::BindGroupDescriptor { + label: None, + layout: &bgl_2, + entries: &[wgpu::BindGroupEntry { + binding: 0, + resource: buffer.as_entire_binding(), + }], + }); + + let module = ctx + .device + .create_shader_module(wgpu::ShaderModuleDescriptor { + label: None, + source: wgpu::ShaderSource::Wgsl(SHADER_SRC.into()), + }); + + let pipeline = ctx + .device + .create_compute_pipeline(&wgpu::ComputePipelineDescriptor { + label: None, + layout: Some(&pipeline_layout), + module: &module, + entry_point: "no_resources", + }); + + let mut encoder = ctx.device.create_command_encoder(&Default::default()); + + let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor { + label: None, + timestamp_writes: None, + }); + + pass.set_bind_group(0, &bg, &[]); + pass.set_pipeline(&pipeline); + pass.dispatch_workgroups(1, 1, 1); + + drop(pass); + + ctx.queue.submit(Some(encoder.finish())); +} + +#[gpu_test] +static BIND_GROUP_LAYOUT_DEDUPLICATION_DERIVED: GpuTestConfiguration = GpuTestConfiguration::new() + .parameters(TestParameters::default().test_features_limits()) + .run_sync(bgl_dedupe_derived); + +fn bgl_dedupe_derived(ctx: TestingContext) { + let buffer = ctx.device.create_buffer(&wgpu::BufferDescriptor { + label: None, + size: 4, + usage: wgpu::BufferUsages::UNIFORM, + mapped_at_creation: false, + }); + + let module = ctx + .device + .create_shader_module(wgpu::ShaderModuleDescriptor { + label: None, + source: wgpu::ShaderSource::Wgsl(SHADER_SRC.into()), + }); + + let pipeline = ctx + .device + .create_compute_pipeline(&wgpu::ComputePipelineDescriptor { + label: None, + layout: None, + module: &module, + entry_point: "resources", + }); + + // We create two bind groups, pulling the bind_group_layout from the pipeline each time. + // + // This ensures a derived BGLs are properly deduplicated despite multiple external + // references. + let bg1 = ctx.device.create_bind_group(&wgpu::BindGroupDescriptor { + label: None, + layout: &pipeline.get_bind_group_layout(0), + entries: &[wgpu::BindGroupEntry { + binding: 0, + resource: buffer.as_entire_binding(), + }], + }); + + let bg2 = ctx.device.create_bind_group(&wgpu::BindGroupDescriptor { + label: None, + layout: &pipeline.get_bind_group_layout(0), + entries: &[wgpu::BindGroupEntry { + binding: 0, + resource: buffer.as_entire_binding(), + }], + }); + + let mut encoder = ctx.device.create_command_encoder(&Default::default()); + + let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor { + label: None, + timestamp_writes: None, + }); + + pass.set_pipeline(&pipeline); + + pass.set_bind_group(0, &bg1, &[]); + pass.dispatch_workgroups(1, 1, 1); + + pass.set_bind_group(0, &bg2, &[]); + pass.dispatch_workgroups(1, 1, 1); + + drop(pass); + + ctx.queue.submit(Some(encoder.finish())); +} + +const DX12_VALIDATION_ERROR: &str = "The command allocator cannot be reset because a command list is currently being recorded with the allocator."; + +#[gpu_test] +static SEPARATE_PROGRAMS_HAVE_INCOMPATIBLE_DERIVED_BGLS: GpuTestConfiguration = + GpuTestConfiguration::new() + .parameters( + TestParameters::default() + .test_features_limits() + .expect_fail( + FailureCase::backend(Backends::DX12).validation_error(DX12_VALIDATION_ERROR), + ), + ) + .run_sync(separate_programs_have_incompatible_derived_bgls); + +fn separate_programs_have_incompatible_derived_bgls(ctx: TestingContext) { + let buffer = ctx.device.create_buffer(&wgpu::BufferDescriptor { + label: None, + size: 4, + usage: wgpu::BufferUsages::UNIFORM, + mapped_at_creation: false, + }); + + let module = ctx + .device + .create_shader_module(wgpu::ShaderModuleDescriptor { + label: None, + source: wgpu::ShaderSource::Wgsl(SHADER_SRC.into()), + }); + + let desc = wgpu::ComputePipelineDescriptor { + label: None, + layout: None, + module: &module, + entry_point: "resources", + }; + // Create two pipelines, creating a BG from the second. + let pipeline1 = ctx.device.create_compute_pipeline(&desc); + let pipeline2 = ctx.device.create_compute_pipeline(&desc); + + let bg2 = ctx.device.create_bind_group(&wgpu::BindGroupDescriptor { + label: None, + layout: &pipeline2.get_bind_group_layout(0), + entries: &[wgpu::BindGroupEntry { + binding: 0, + resource: buffer.as_entire_binding(), + }], + }); + + let mut encoder = ctx.device.create_command_encoder(&Default::default()); + + let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor { + label: None, + timestamp_writes: None, + }); + + pass.set_pipeline(&pipeline1); + + // We use the wrong bind group for this pipeline here. This should fail. + pass.set_bind_group(0, &bg2, &[]); + pass.dispatch_workgroups(1, 1, 1); + + fail(&ctx.device, || { + drop(pass); + }); +} + +#[gpu_test] +static DERIVED_BGLS_INCOMPATIBLE_WITH_REGULAR_BGLS: GpuTestConfiguration = + GpuTestConfiguration::new() + .parameters( + TestParameters::default() + .test_features_limits() + .expect_fail( + FailureCase::backend(Backends::DX12).validation_error(DX12_VALIDATION_ERROR), + ), + ) + .run_sync(derived_bgls_incompatible_with_regular_bgls); + +fn derived_bgls_incompatible_with_regular_bgls(ctx: TestingContext) { + let buffer = ctx.device.create_buffer(&wgpu::BufferDescriptor { + label: None, + size: 4, + usage: wgpu::BufferUsages::UNIFORM, + mapped_at_creation: false, + }); + + let module = ctx + .device + .create_shader_module(wgpu::ShaderModuleDescriptor { + label: None, + source: wgpu::ShaderSource::Wgsl(SHADER_SRC.into()), + }); + + // Create a pipeline. + let pipeline = ctx + .device + .create_compute_pipeline(&wgpu::ComputePipelineDescriptor { + label: None, + layout: None, + module: &module, + entry_point: "resources", + }); + + // Create a matching BGL + let bgl = ctx + .device + .create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor { + label: None, + entries: &[ENTRY], + }); + + // Create a bind group from the explicit BGL. This should be incompatible with the derived BGL used by the pipeline. + let bg = ctx.device.create_bind_group(&wgpu::BindGroupDescriptor { + label: None, + layout: &bgl, + entries: &[wgpu::BindGroupEntry { + binding: 0, + resource: buffer.as_entire_binding(), + }], + }); + + let mut encoder = ctx.device.create_command_encoder(&Default::default()); + + let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor { + label: None, + timestamp_writes: None, + }); + + pass.set_pipeline(&pipeline); + + pass.set_bind_group(0, &bg, &[]); + pass.dispatch_workgroups(1, 1, 1); + + fail(&ctx.device, || { + drop(pass); + }) +} diff --git a/wgpu-core/Cargo.toml b/wgpu-core/Cargo.toml index a9f6319873..49381aea49 100644 --- a/wgpu-core/Cargo.toml +++ b/wgpu-core/Cargo.toml @@ -94,7 +94,9 @@ arrayvec = "0.7" bit-vec = "0.6" bitflags = "2" codespan-reporting = "0.11" +indexmap = "2" log = "0.4" +once_cell = "1" # parking_lot 0.12 switches from `winapi` to `windows`; permit either parking_lot = ">=0.11,<0.13" profiling = { version = "1", default-features = false } diff --git a/wgpu-core/src/binding_model.rs b/wgpu-core/src/binding_model.rs index c5db9b02c4..5f4dfb434e 100644 --- a/wgpu-core/src/binding_model.rs +++ b/wgpu-core/src/binding_model.rs @@ -1,5 +1,7 @@ use crate::{ - device::{Device, DeviceError, MissingDownlevelFlags, MissingFeatures, SHADER_STAGE_COUNT}, + device::{ + bgl, Device, DeviceError, MissingDownlevelFlags, MissingFeatures, SHADER_STAGE_COUNT, + }, error::{ErrorFormatter, PrettyError}, hal_api::HalApi, id::{ @@ -12,7 +14,7 @@ use crate::{ snatch::SnatchGuard, track::{BindGroupStates, UsageConflict}, validation::{MissingBufferUsageError, MissingTextureUsageError}, - FastHashMap, Label, + Label, }; use arrayvec::ArrayVec; @@ -440,32 +442,32 @@ pub struct BindGroupLayoutDescriptor<'a> { pub entries: Cow<'a, [wgt::BindGroupLayoutEntry]>, } -pub(crate) type BindEntryMap = FastHashMap; - pub type BindGroupLayouts = crate::storage::Storage, BindGroupLayoutId>; /// Bind group layout. -/// -/// The lifetime of BGLs is a bit special. They are only referenced on CPU -/// without considering GPU operations. And on CPU they get manual -/// inc-refs and dec-refs. In particular, the following objects depend on them: -/// - produced bind groups -/// - produced pipeline layouts -/// - pipelines with implicit layouts #[derive(Debug)] pub struct BindGroupLayout { pub(crate) raw: Option, pub(crate) device: Arc>, - pub(crate) entries: BindEntryMap, + pub(crate) entries: bgl::EntryMap, + /// It is very important that we know if the bind group comes from the BGL pool. + /// + /// If it does, then we need to remove it from the pool when we drop it. + /// + /// We cannot unconditionally remove from the pool, as BGLs that don't come from the pool + /// (derived BGLs) must not be removed. + pub(crate) origin: bgl::Origin, #[allow(unused)] - pub(crate) dynamic_count: usize, - pub(crate) count_validator: BindingTypeMaxCountValidator, + pub(crate) binding_count_validator: BindingTypeMaxCountValidator, pub(crate) info: ResourceInfo, pub(crate) label: String, } impl Drop for BindGroupLayout { fn drop(&mut self) { + if matches!(self.origin, bgl::Origin::Pool) { + self.device.bgl_pool.remove(&self.entries); + } if let Some(raw) = self.raw.take() { resource_log!("Destroy raw BindGroupLayout {:?}", self.info.label()); unsafe { @@ -618,6 +620,14 @@ impl PipelineLayout { pub(crate) fn raw(&self) -> &A::PipelineLayout { self.raw.as_ref().unwrap() } + + pub(crate) fn get_binding_maps(&self) -> ArrayVec<&bgl::EntryMap, { hal::MAX_BIND_GROUPS }> { + self.bind_group_layouts + .iter() + .map(|bgl| &bgl.entries) + .collect() + } + /// Validate push constants match up with expected ranges. pub(crate) fn validate_push_constant_ranges( &self, diff --git a/wgpu-core/src/command/bind.rs b/wgpu-core/src/command/bind.rs index 6bf849a42a..7b2ac54552 100644 --- a/wgpu-core/src/command/bind.rs +++ b/wgpu-core/src/command/bind.rs @@ -16,7 +16,7 @@ type BindGroupMask = u8; mod compat { use arrayvec::ArrayVec; - use crate::{binding_model::BindGroupLayout, hal_api::HalApi, resource::Resource}; + use crate::{binding_model::BindGroupLayout, device::bgl, hal_api::HalApi, resource::Resource}; use std::{ops::Range, sync::Arc}; #[derive(Debug, Clone)] @@ -60,17 +60,35 @@ mod compat { let mut diff = Vec::new(); if let Some(expected_bgl) = self.expected.as_ref() { + let expected_bgl_type = match expected_bgl.origin { + bgl::Origin::Derived => "implicit", + bgl::Origin::Pool => "explicit", + }; + let expected_label = expected_bgl.label(); diff.push(format!( - "Should be compatible with bind group layout with label = `{}`", - expected_bgl.label() + "Should be compatible an with an {expected_bgl_type} bind group layout {}", + if expected_label.is_empty() { + "without label".to_string() + } else { + format!("with label = `{}`", expected_label) + } )); if let Some(assigned_bgl) = self.assigned.as_ref() { + let assigned_bgl_type = match assigned_bgl.origin { + bgl::Origin::Derived => "implicit", + bgl::Origin::Pool => "explicit", + }; + let assigned_label = assigned_bgl.label(); diff.push(format!( - "Assigned bind group layout with label = `{}`", - assigned_bgl.label() + "Assigned {assigned_bgl_type} bind group layout {}", + if assigned_label.is_empty() { + "without label".to_string() + } else { + format!("with label = `{}`", assigned_label) + } )); - for (id, e_entry) in &expected_bgl.entries { - if let Some(a_entry) = assigned_bgl.entries.get(id) { + for (id, e_entry) in expected_bgl.entries.iter() { + if let Some(a_entry) = assigned_bgl.entries.get(*id) { if a_entry.binding != e_entry.binding { diff.push(format!( "Entry {id} binding expected {}, got {}", @@ -96,32 +114,28 @@ mod compat { )); } } else { - diff.push(format!("Entry {id} not found in assigned bindgroup layout")) + diff.push(format!( + "Entry {id} not found in assigned bind group layout" + )) } } assigned_bgl.entries.iter().for_each(|(id, _e_entry)| { - if !expected_bgl.entries.contains_key(id) { - diff.push(format!("Entry {id} not found in expected bindgroup layout")) + if !expected_bgl.entries.contains_key(*id) { + diff.push(format!( + "Entry {id} not found in expected bind group layout" + )) } }); + + if expected_bgl.origin != assigned_bgl.origin { + diff.push(format!("Expected {expected_bgl_type} bind group layout, got {assigned_bgl_type}")) + } } else { - diff.push( - "Assigned bindgroup layout is implicit, expected explicit".to_owned(), - ); + diff.push("Assigned bind group layout not found (internal error)".to_owned()); } - } else if let Some(assigned_bgl) = self.assigned.as_ref() { - diff.push(format!( - "Assigned bind group layout = `{}`", - assigned_bgl.label() - )); - diff.push( - "Assigned bindgroup layout is not implicit, expected implicit".to_owned(), - ); - } - - if diff.is_empty() { - diff.push("But no differences found? (internal error)".to_owned()) + } else { + diff.push("Expected bind group layout not found (internal error)".to_owned()); } diff diff --git a/wgpu-core/src/device/bgl.rs b/wgpu-core/src/device/bgl.rs new file mode 100644 index 0000000000..b97f87b168 --- /dev/null +++ b/wgpu-core/src/device/bgl.rs @@ -0,0 +1,129 @@ +use std::hash::{Hash, Hasher}; + +use crate::{ + binding_model::{self}, + FastIndexMap, +}; + +/// Where a given BGL came from. +#[derive(Debug, Copy, Clone, PartialEq, Eq)] +pub enum Origin { + /// The bind group layout was created by the user and is present in the BGL resource pool. + Pool, + /// The bind group layout was derived and is not present in the BGL resource pool. + Derived, +} + +/// A HashMap-like structure that stores a BindGroupLayouts [`wgt::BindGroupLayoutEntry`]s. +/// +/// It is hashable, so bind group layouts can be deduplicated. +#[derive(Debug, Default, Clone, Eq)] +pub struct EntryMap { + /// We use a IndexMap here so that we can sort the entries by their binding index, + /// guarenteeing that the hash of equivilant layouts will be the same. + inner: FastIndexMap, + /// We keep track of whether the map is sorted or not, so that we can assert that + /// it is sorted, so that PartialEq and Hash will be stable. + /// + /// We only need sorted if it is used in a Hash or PartialEq, so we never need + /// to actively sort it. + sorted: bool, +} + +impl PartialEq for EntryMap { + fn eq(&self, other: &Self) -> bool { + self.assert_sorted(); + other.assert_sorted(); + + self.inner == other.inner + } +} + +impl Hash for EntryMap { + fn hash(&self, state: &mut H) { + self.assert_sorted(); + + // We don't need to hash the keys, since they are just extracted from the values. + // + // We know this is stable and will match the behavior of PartialEq as we ensure + // that the array is sorted. + for entry in self.inner.values() { + entry.hash(state); + } + } +} + +impl EntryMap { + fn assert_sorted(&self) { + assert!(self.sorted); + } + + /// Create a new [`BindGroupLayoutEntryMap`] from a slice of [`wgt::BindGroupLayoutEntry`]s. + /// + /// Errors if there are duplicate bindings or if any binding index is greater than + /// the device's limits. + pub fn from_entries( + device_limits: &wgt::Limits, + entries: &[wgt::BindGroupLayoutEntry], + ) -> Result { + let mut inner = FastIndexMap::with_capacity_and_hasher(entries.len(), Default::default()); + for entry in entries { + if entry.binding > device_limits.max_bindings_per_bind_group { + return Err( + binding_model::CreateBindGroupLayoutError::InvalidBindingIndex { + binding: entry.binding, + maximum: device_limits.max_bindings_per_bind_group, + }, + ); + } + if inner.insert(entry.binding, *entry).is_some() { + return Err(binding_model::CreateBindGroupLayoutError::ConflictBinding( + entry.binding, + )); + } + } + inner.sort_unstable_keys(); + + Ok(Self { + inner, + sorted: true, + }) + } + + /// Get the count of [`wgt::BindGroupLayoutEntry`]s in this map. + pub fn len(&self) -> usize { + self.inner.len() + } + + /// Get the [`wgt::BindGroupLayoutEntry`] for the given binding index. + pub fn get(&self, binding: u32) -> Option<&wgt::BindGroupLayoutEntry> { + self.inner.get(&binding) + } + + /// Iterator over all the binding indices in this map. + pub fn indices(&self) -> impl ExactSizeIterator + '_ { + self.inner.keys().copied() + } + + /// Iterator over all the [`wgt::BindGroupLayoutEntry`]s in this map. + pub fn values(&self) -> impl ExactSizeIterator + '_ { + self.inner.values() + } + + pub fn iter(&self) -> impl ExactSizeIterator + '_ { + self.inner.iter() + } + + pub fn is_empty(&self) -> bool { + self.inner.is_empty() + } + + pub fn contains_key(&self, key: u32) -> bool { + self.inner.contains_key(&key) + } + + pub fn entry(&mut self, key: u32) -> indexmap::map::Entry<'_, u32, wgt::BindGroupLayoutEntry> { + self.sorted = false; + self.inner.entry(key) + } +} diff --git a/wgpu-core/src/device/global.rs b/wgpu-core/src/device/global.rs index 9575979c8d..914d220c1a 100644 --- a/wgpu-core/src/device/global.rs +++ b/wgpu-core/src/device/global.rs @@ -3,8 +3,8 @@ use crate::device::trace; use crate::{ api_log, binding_model, command, conv, device::{ - life::WaitIdleError, map_buffer, queue, DeviceError, DeviceLostClosure, DeviceLostReason, - HostMap, IMPLICIT_BIND_GROUP_LAYOUT_ERROR_LABEL, + bgl, life::WaitIdleError, map_buffer, queue, DeviceError, DeviceLostClosure, + DeviceLostReason, HostMap, IMPLICIT_BIND_GROUP_LAYOUT_ERROR_LABEL, }, global::Global, hal_api::HalApi, @@ -16,7 +16,7 @@ use crate::{ resource::{self, BufferAccessResult}, resource::{BufferAccessError, BufferMapOperation, CreateBufferError, Resource}, validation::check_buffer_usage, - FastHashMap, Label, LabelHelpers as _, + Label, LabelHelpers as _, }; use arrayvec::ArrayVec; @@ -962,7 +962,7 @@ impl Global { let hub = A::hub(self); let fid = hub.bind_group_layouts.prepare::(id_in); - let error = 'outer: loop { + let error = loop { let device = match hub.devices.get(device_id) { Ok(device) => device, Err(_) => break DeviceError::Invalid.into(), @@ -976,38 +976,50 @@ impl Global { trace.add(trace::Action::CreateBindGroupLayout(fid.id(), desc.clone())); } - let mut entry_map = FastHashMap::default(); - for entry in desc.entries.iter() { - if entry.binding > device.limits.max_bindings_per_bind_group { - break 'outer binding_model::CreateBindGroupLayoutError::InvalidBindingIndex { - binding: entry.binding, - maximum: device.limits.max_bindings_per_bind_group, - }; - } - if entry_map.insert(entry.binding, *entry).is_some() { - break 'outer binding_model::CreateBindGroupLayoutError::ConflictBinding( - entry.binding, - ); - } - } + let entry_map = match bgl::EntryMap::from_entries(&device.limits, &desc.entries) { + Ok(map) => map, + Err(e) => break e, + }; - if let Some((id, layout)) = { - let bgl_guard = hub.bind_group_layouts.read(); - device.deduplicate_bind_group_layout(&entry_map, &*bgl_guard) - } { - api_log!("Reusing BindGroupLayout {layout:?} -> {:?}", id); - let id = fid.assign_existing(&layout); - return (id, None); - } + // Currently we make a distinction between fid.assign and fid.assign_existing. This distinction is incorrect, + // but see https://github.com/gfx-rs/wgpu/issues/4912. + // + // `assign` also registers the ID with the resource info, so it can be automatically reclaimed. This needs to + // happen with a mutable reference, which means it can only happen on creation. + // + // Because we need to call `assign` inside the closure (to get mut access), we need to "move" the future id into the closure. + // Rust cannot figure out at compile time that we only ever consume the ID once, so we need to move the check + // to runtime using an Option. + let mut fid = Some(fid); + + // The closure might get called, and it might give us an ID. Side channel it out of the closure. + let mut id = None; + + let bgl_result = device.bgl_pool.get_or_init(entry_map, |entry_map| { + let bgl = + device.create_bind_group_layout(&desc.label, entry_map, bgl::Origin::Pool)?; - let layout = match device.create_bind_group_layout(&desc.label, entry_map) { + let (id_inner, arc) = fid.take().unwrap().assign(bgl); + id = Some(id_inner); + + Ok(arc) + }); + + let layout = match bgl_result { Ok(layout) => layout, Err(e) => break e, }; - let (id, _layout) = fid.assign(layout); + // If the ID was not assigned, and we survived the above check, + // it means that the bind group layout already existed and we need to call `assign_existing`. + // + // Calling this function _will_ leak the ID. See https://github.com/gfx-rs/wgpu/issues/4912. + if id.is_none() { + id = Some(fid.take().unwrap().assign_existing(&layout)) + } + api_log!("Device::create_bind_group_layout -> {id:?}"); - return (id, None); + return (id.unwrap(), None); }; let fid = hub.bind_group_layouts.prepare::(id_in); diff --git a/wgpu-core/src/device/mod.rs b/wgpu-core/src/device/mod.rs index 0bfa90458a..bb0afedafc 100644 --- a/wgpu-core/src/device/mod.rs +++ b/wgpu-core/src/device/mod.rs @@ -19,6 +19,7 @@ use wgt::{BufferAddress, DeviceLostReason, TextureFormat}; use std::{iter, num::NonZeroU32, ptr}; pub mod any_device; +pub(crate) mod bgl; pub mod global; mod life; pub mod queue; diff --git a/wgpu-core/src/device/resource.rs b/wgpu-core/src/device/resource.rs index b89ad9abf0..3e18777eac 100644 --- a/wgpu-core/src/device/resource.rs +++ b/wgpu-core/src/device/resource.rs @@ -6,7 +6,7 @@ use crate::{ device::life::{LifetimeTracker, WaitIdleError}, device::queue::PendingWrites, device::{ - AttachmentData, CommandAllocator, DeviceLostInvocation, MissingDownlevelFlags, + bgl, AttachmentData, CommandAllocator, DeviceLostInvocation, MissingDownlevelFlags, MissingFeatures, RenderPassContext, CLEANUP_WAIT_MS, }, hal_api::HalApi, @@ -19,6 +19,7 @@ use crate::{ }, instance::Adapter, pipeline, + pool::ResourcePool, registry::Registry, resource::ResourceInfo, resource::{ @@ -120,6 +121,8 @@ pub struct Device { /// Temporary storage for resource management functions. Cleared at the end /// of every call (unless an error occurs). pub(crate) temp_suspected: Mutex>>, + /// Pool of bind group layouts, allowing deduplication. + pub(crate) bgl_pool: ResourcePool>, pub(crate) alignments: hal::Alignments, pub(crate) limits: wgt::Limits, pub(crate) features: wgt::Features, @@ -261,6 +264,7 @@ impl Device { trackers: Mutex::new(Tracker::new()), life_tracker: Mutex::new(life::LifetimeTracker::new()), temp_suspected: Mutex::new(Some(life::ResourceMaps::new())), + bgl_pool: ResourcePool::new(), #[cfg(feature = "trace")] trace: Mutex::new(trace_path.and_then(|path| match trace::Trace::new(path) { Ok(mut trace) => { @@ -1513,29 +1517,6 @@ impl Device { }) } - pub(crate) fn deduplicate_bind_group_layout<'a>( - self: &Arc, - entry_map: &'a binding_model::BindEntryMap, - guard: &'a Storage, id::BindGroupLayoutId>, - ) -> Option<(id::BindGroupLayoutId, Arc>)> { - guard - .iter(self.as_info().id().backend()) - .find(|&(_, bgl)| { - bgl.device.info.id() == self.as_info().id() && bgl.entries == *entry_map - }) - .map(|(id, resource)| (id, resource.clone())) - } - - pub(crate) fn get_introspection_bind_group_layouts<'a>( - pipeline_layout: &'a binding_model::PipelineLayout, - ) -> ArrayVec<&'a binding_model::BindEntryMap, { hal::MAX_BIND_GROUPS }> { - pipeline_layout - .bind_group_layouts - .iter() - .map(|layout| &layout.entries) - .collect() - } - /// Generate information about late-validated buffer bindings for pipelines. //TODO: should this be combined with `get_introspection_bind_group_layouts` in some way? pub(crate) fn make_late_sized_buffer_groups( @@ -1576,7 +1557,8 @@ impl Device { pub(crate) fn create_bind_group_layout( self: &Arc, label: &crate::Label, - entry_map: binding_model::BindEntryMap, + entry_map: bgl::EntryMap, + origin: bgl::Origin, ) -> Result, binding_model::CreateBindGroupLayoutError> { #[derive(PartialEq)] enum WritableStorage { @@ -1739,9 +1721,8 @@ impl Device { let bgl_flags = conv::bind_group_layout_flags(self.features); - let mut hal_bindings = entry_map.values().cloned().collect::>(); + let hal_bindings = entry_map.values().copied().collect::>(); let label = label.to_hal(self.instance_flags); - hal_bindings.sort_by_key(|b| b.binding); let hal_desc = hal::BindGroupLayoutDescriptor { label, flags: bgl_flags, @@ -1768,13 +1749,10 @@ impl Device { Ok(BindGroupLayout { raw: Some(raw), device: self.clone(), - info: ResourceInfo::new(label.unwrap_or("")), - dynamic_count: entry_map - .values() - .filter(|b| b.ty.has_dynamic_offset()) - .count(), - count_validator, entries: entry_map, + origin, + binding_count_validator: count_validator, + info: ResourceInfo::new(label.unwrap_or("")), label: label.unwrap_or_default().to_string(), }) } @@ -1996,7 +1974,7 @@ impl Device { // Find the corresponding declaration in the layout let decl = layout .entries - .get(&binding) + .get(binding) .ok_or(Error::MissingBindingDeclaration(binding))?; let (res_index, count) = match entry.resource { Br::Buffer(ref bb) => { @@ -2206,8 +2184,8 @@ impl Device { // collect in the order of BGL iteration late_buffer_binding_sizes: layout .entries - .keys() - .flat_map(|binding| late_buffer_binding_sizes.get(binding).cloned()) + .indices() + .flat_map(|binding| late_buffer_binding_sizes.get(&binding).cloned()) .collect(), }) } @@ -2442,7 +2420,7 @@ impl Device { return Err(DeviceError::WrongDevice.into()); } - count_validator.merge(&bind_group_layout.count_validator); + count_validator.merge(&bind_group_layout.binding_count_validator); } count_validator .validate(&self.limits) @@ -2486,10 +2464,10 @@ impl Device { pub(crate) fn derive_pipeline_layout( self: &Arc, implicit_context: Option, - mut derived_group_layouts: ArrayVec, + mut derived_group_layouts: ArrayVec, bgl_registry: &Registry>, pipeline_layout_registry: &Registry>, - ) -> Result { + ) -> Result>, pipeline::ImplicitLayoutError> { while derived_group_layouts .last() .map_or(false, |map| map.is_empty()) @@ -2508,16 +2486,8 @@ impl Device { } for (bgl_id, map) in ids.group_ids.iter_mut().zip(derived_group_layouts) { - let bgl = match self.deduplicate_bind_group_layout(&map, &bgl_registry.read()) { - Some((dedup_id, _)) => { - *bgl_id = dedup_id; - None - } - None => Some(self.create_bind_group_layout(&None, map)?), - }; - if let Some(bgl) = bgl { - bgl_registry.force_replace(*bgl_id, bgl); - } + let bgl = self.create_bind_group_layout(&None, map, bgl::Origin::Derived)?; + bgl_registry.force_replace(*bgl_id, bgl); } let layout_desc = binding_model::PipelineLayoutDescriptor { @@ -2527,7 +2497,7 @@ impl Device { }; let layout = self.create_pipeline_layout(&layout_desc, &bgl_registry.read())?; pipeline_layout_registry.force_replace(ids.root_id, layout); - Ok(ids.root_id) + Ok(pipeline_layout_registry.get(ids.root_id).unwrap()) } pub(crate) fn create_compute_pipeline( @@ -2549,12 +2519,6 @@ impl Device { self.require_downlevel_flags(wgt::DownlevelFlags::COMPUTE_SHADERS)?; - let mut derived_group_layouts = - ArrayVec::::new(); - let mut shader_binding_sizes = FastHashMap::default(); - - let io = validation::StageIo::default(); - let shader_module = hub .shader_modules .get(desc.stage.module) @@ -2564,59 +2528,66 @@ impl Device { return Err(DeviceError::WrongDevice.into()); } - { - let flag = wgt::ShaderStages::COMPUTE; - let pipeline_layout_guard = hub.pipeline_layouts.read(); - let provided_layouts = match desc.layout { - Some(pipeline_layout_id) => Some(Device::get_introspection_bind_group_layouts( - pipeline_layout_guard - .get(pipeline_layout_id) - .map_err(|_| pipeline::CreateComputePipelineError::InvalidLayout)?, - )), - None => { - for _ in 0..self.limits.max_bind_groups { - derived_group_layouts.push(binding_model::BindEntryMap::default()); - } - None + // Get the pipeline layout from the desc if it is provided. + let pipeline_layout = match desc.layout { + Some(pipeline_layout_id) => { + let pipeline_layout = hub + .pipeline_layouts + .get(pipeline_layout_id) + .map_err(|_| pipeline::CreateComputePipelineError::InvalidLayout)?; + + if pipeline_layout.device.as_info().id() != self.as_info().id() { + return Err(DeviceError::WrongDevice.into()); } - }; + + Some(pipeline_layout) + } + None => None, + }; + + let mut binding_layout_source = match pipeline_layout { + Some(ref pipeline_layout) => { + validation::BindingLayoutSource::Provided(pipeline_layout.get_binding_maps()) + } + None => validation::BindingLayoutSource::new_derived(&self.limits), + }; + let mut shader_binding_sizes = FastHashMap::default(); + let io = validation::StageIo::default(); + + { + let stage = wgt::ShaderStages::COMPUTE; + if let Some(ref interface) = shader_module.interface { let _ = interface.check_stage( - provided_layouts.as_ref().map(|p| p.as_slice()), - &mut derived_group_layouts, + &mut binding_layout_source, &mut shader_binding_sizes, &desc.stage.entry_point, - flag, + stage, io, None, )?; } } - let pipeline_layout_id = match desc.layout { - Some(id) => id, - None => self.derive_pipeline_layout( + let pipeline_layout = match binding_layout_source { + validation::BindingLayoutSource::Provided(_) => { + drop(binding_layout_source); + pipeline_layout.unwrap() + } + validation::BindingLayoutSource::Derived(entries) => self.derive_pipeline_layout( implicit_context, - derived_group_layouts, + entries, &hub.bind_group_layouts, &hub.pipeline_layouts, )?, }; - let pipeline_layout_guard = hub.pipeline_layouts.read(); - let layout = pipeline_layout_guard - .get(pipeline_layout_id) - .map_err(|_| pipeline::CreateComputePipelineError::InvalidLayout)?; - - if layout.device.as_info().id() != self.as_info().id() { - return Err(DeviceError::WrongDevice.into()); - } let late_sized_buffer_groups = - Device::make_late_sized_buffer_groups(&shader_binding_sizes, layout); + Device::make_late_sized_buffer_groups(&shader_binding_sizes, &pipeline_layout); let pipeline_desc = hal::ComputePipelineDescriptor { label: desc.label.to_hal(self.instance_flags), - layout: layout.raw(), + layout: pipeline_layout.raw(), stage: hal::ProgrammableStage { entry_point: desc.stage.entry_point.as_ref(), module: shader_module.raw(), @@ -2643,7 +2614,7 @@ impl Device { let pipeline = pipeline::ComputePipeline { raw: Some(raw), - layout: layout.clone(), + layout: pipeline_layout, device: self.clone(), _shader_module: shader_module, late_sized_buffer_groups, @@ -2661,8 +2632,6 @@ impl Device { ) -> Result, pipeline::CreateRenderPipelineError> { use wgt::TextureFormatFeatureFlags as Tfff; - let mut shader_modules = Vec::new(); - // This has to be done first, or otherwise the IDs may be pointing to entries // that are not even in the storage. if let Some(ref ids) = implicit_context { @@ -2675,8 +2644,6 @@ impl Device { } } - let mut derived_group_layouts = - ArrayVec::::new(); let mut shader_binding_sizes = FastHashMap::default(); let num_attachments = desc.fragment.as_ref().map(|f| f.targets.len()).unwrap_or(0); @@ -2944,11 +2911,29 @@ impl Device { } } - if desc.layout.is_none() { - for _ in 0..self.limits.max_bind_groups { - derived_group_layouts.push(binding_model::BindEntryMap::default()); + // Get the pipeline layout from the desc if it is provided. + let pipeline_layout = match desc.layout { + Some(pipeline_layout_id) => { + let pipeline_layout = hub + .pipeline_layouts + .get(pipeline_layout_id) + .map_err(|_| pipeline::CreateRenderPipelineError::InvalidLayout)?; + + if pipeline_layout.device.as_info().id() != self.as_info().id() { + return Err(DeviceError::WrongDevice.into()); + } + + Some(pipeline_layout) } - } + None => None, + }; + + let mut binding_layout_source = match pipeline_layout { + Some(ref pipeline_layout) => { + validation::BindingLayoutSource::Provided(pipeline_layout.get_binding_maps()) + } + None => validation::BindingLayoutSource::new_derived(&self.limits), + }; let samples = { let sc = desc.multisample.count; @@ -2958,122 +2943,86 @@ impl Device { sc }; - let shader_module_guard = hub.shader_modules.read(); - + let vertex_shader_module; let vertex_stage = { - let stage = &desc.vertex.stage; - let flag = wgt::ShaderStages::VERTEX; + let stage_desc = &desc.vertex.stage; + let stage = wgt::ShaderStages::VERTEX; - let shader_module = shader_module_guard.get(stage.module).map_err(|_| { + vertex_shader_module = hub.shader_modules.get(stage_desc.module).map_err(|_| { pipeline::CreateRenderPipelineError::Stage { - stage: flag, + stage, error: validation::StageError::InvalidModule, } })?; - if shader_module.device.as_info().id() != self.as_info().id() { + if vertex_shader_module.device.as_info().id() != self.as_info().id() { return Err(DeviceError::WrongDevice.into()); } - shader_modules.push(shader_module.clone()); - - let pipeline_layout_guard = hub.pipeline_layouts.read(); - - let provided_layouts = match desc.layout { - Some(pipeline_layout_id) => { - let pipeline_layout = pipeline_layout_guard - .get(pipeline_layout_id) - .map_err(|_| pipeline::CreateRenderPipelineError::InvalidLayout)?; - - if pipeline_layout.device.as_info().id() != self.as_info().id() { - return Err(DeviceError::WrongDevice.into()); - } - Some(Device::get_introspection_bind_group_layouts( - pipeline_layout, - )) - } - None => None, - }; - - if let Some(ref interface) = shader_module.interface { + if let Some(ref interface) = vertex_shader_module.interface { io = interface .check_stage( - provided_layouts.as_ref().map(|p| p.as_slice()), - &mut derived_group_layouts, + &mut binding_layout_source, &mut shader_binding_sizes, - &stage.entry_point, - flag, + &stage_desc.entry_point, + stage, io, desc.depth_stencil.as_ref().map(|d| d.depth_compare), ) - .map_err(|error| pipeline::CreateRenderPipelineError::Stage { - stage: flag, - error, - })?; - validated_stages |= flag; + .map_err(|error| pipeline::CreateRenderPipelineError::Stage { stage, error })?; + validated_stages |= stage; } hal::ProgrammableStage { - module: shader_module.raw(), - entry_point: stage.entry_point.as_ref(), + module: vertex_shader_module.raw(), + entry_point: stage_desc.entry_point.as_ref(), } }; + let mut fragment_shader_module = None; let fragment_stage = match desc.fragment { - Some(ref fragment) => { - let flag = wgt::ShaderStages::FRAGMENT; + Some(ref fragment_state) => { + let stage = wgt::ShaderStages::FRAGMENT; - let shader_module = - shader_module_guard - .get(fragment.stage.module) + let shader_module = fragment_shader_module.insert( + hub.shader_modules + .get(fragment_state.stage.module) .map_err(|_| pipeline::CreateRenderPipelineError::Stage { - stage: flag, + stage, error: validation::StageError::InvalidModule, - })?; - shader_modules.push(shader_module.clone()); - - let pipeline_layout_guard = hub.pipeline_layouts.read(); - let provided_layouts = match desc.layout { - Some(pipeline_layout_id) => Some(Device::get_introspection_bind_group_layouts( - pipeline_layout_guard - .get(pipeline_layout_id) - .as_ref() - .map_err(|_| pipeline::CreateRenderPipelineError::InvalidLayout)?, - )), - None => None, - }; + })?, + ); if validated_stages == wgt::ShaderStages::VERTEX { if let Some(ref interface) = shader_module.interface { io = interface .check_stage( - provided_layouts.as_ref().map(|p| p.as_slice()), - &mut derived_group_layouts, + &mut binding_layout_source, &mut shader_binding_sizes, - &fragment.stage.entry_point, - flag, + &fragment_state.stage.entry_point, + stage, io, desc.depth_stencil.as_ref().map(|d| d.depth_compare), ) .map_err(|error| pipeline::CreateRenderPipelineError::Stage { - stage: flag, + stage, error, })?; - validated_stages |= flag; + validated_stages |= stage; } } if let Some(ref interface) = shader_module.interface { shader_expects_dual_source_blending = interface - .fragment_uses_dual_source_blending(&fragment.stage.entry_point) + .fragment_uses_dual_source_blending(&fragment_state.stage.entry_point) .map_err(|error| pipeline::CreateRenderPipelineError::Stage { - stage: flag, + stage, error, })?; } Some(hal::ProgrammableStage { module: shader_module.raw(), - entry_point: fragment.stage.entry_point.as_ref(), + entry_point: fragment_state.stage.entry_point.as_ref(), }) } None => None, @@ -3126,22 +3075,18 @@ impl Device { return Err(pipeline::ImplicitLayoutError::ReflectionError(last_stage).into()); } - let pipeline_layout_id = match desc.layout { - Some(id) => id, - None => self.derive_pipeline_layout( + let pipeline_layout = match binding_layout_source { + validation::BindingLayoutSource::Provided(_) => { + drop(binding_layout_source); + pipeline_layout.unwrap() + } + validation::BindingLayoutSource::Derived(entries) => self.derive_pipeline_layout( implicit_context, - derived_group_layouts, + entries, &hub.bind_group_layouts, &hub.pipeline_layouts, )?, }; - let layout = { - let pipeline_layout_guard = hub.pipeline_layouts.read(); - pipeline_layout_guard - .get(pipeline_layout_id) - .map_err(|_| pipeline::CreateRenderPipelineError::InvalidLayout)? - .clone() - }; // Multiview is only supported if the feature is enabled if desc.multiview.is_some() { @@ -3165,11 +3110,11 @@ impl Device { } let late_sized_buffer_groups = - Device::make_late_sized_buffer_groups(&shader_binding_sizes, &layout); + Device::make_late_sized_buffer_groups(&shader_binding_sizes, &pipeline_layout); let pipeline_desc = hal::RenderPipelineDescriptor { label: desc.label.to_hal(self.instance_flags), - layout: layout.raw(), + layout: pipeline_layout.raw(), vertex_buffers: &vertex_buffers, vertex_stage, primitive: desc.primitive, @@ -3233,9 +3178,16 @@ impl Device { } } + let shader_modules = { + let mut shader_modules = ArrayVec::new(); + shader_modules.push(vertex_shader_module); + shader_modules.extend(fragment_shader_module); + shader_modules + }; + let pipeline = pipeline::RenderPipeline { raw: Some(raw), - layout: layout.clone(), + layout: pipeline_layout, device: self.clone(), pass_context, _shader_modules: shader_modules, diff --git a/wgpu-core/src/hash_utils.rs b/wgpu-core/src/hash_utils.rs new file mode 100644 index 0000000000..f44aad2f1a --- /dev/null +++ b/wgpu-core/src/hash_utils.rs @@ -0,0 +1,86 @@ +//! Module for hashing utilities. +//! +//! Named hash_utils to prevent clashing with the std::hash module. + +/// HashMap using a fast, non-cryptographic hash algorithm. +pub type FastHashMap = + std::collections::HashMap>; +/// HashSet using a fast, non-cryptographic hash algorithm. +pub type FastHashSet = + std::collections::HashSet>; + +/// IndexMap using a fast, non-cryptographic hash algorithm. +pub type FastIndexMap = + indexmap::IndexMap>; + +/// HashMap that uses pre-hashed keys and an identity hasher. +/// +/// This is useful when you only need the key to lookup the value, and don't need to store the key, +/// particularly when the key is large. +pub type PreHashedMap = + std::collections::HashMap, V, std::hash::BuildHasherDefault>; + +/// A pre-hashed key using FxHash which allows the hashing operation to be disconnected +/// from the storage in the map. +pub struct PreHashedKey(u64, std::marker::PhantomData K>); + +impl std::fmt::Debug for PreHashedKey { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_tuple("PreHashedKey").field(&self.0).finish() + } +} + +impl Copy for PreHashedKey {} + +impl Clone for PreHashedKey { + fn clone(&self) -> Self { + *self + } +} + +impl PartialEq for PreHashedKey { + fn eq(&self, other: &Self) -> bool { + self.0 == other.0 + } +} + +impl Eq for PreHashedKey {} + +impl std::hash::Hash for PreHashedKey { + fn hash(&self, state: &mut H) { + self.0.hash(state); + } +} + +impl PreHashedKey { + pub fn from_key(key: &K) -> Self { + use std::hash::Hasher; + + let mut hasher = rustc_hash::FxHasher::default(); + key.hash(&mut hasher); + Self(hasher.finish(), std::marker::PhantomData) + } +} + +/// A hasher which does nothing. Useful for when you want to use a map with pre-hashed keys. +/// +/// When hashing with this hasher, you must provide exactly 8 bytes. Multiple calls to `write` +/// will overwrite the previous value. +#[derive(Default)] +pub struct IdentityHasher { + hash: u64, +} + +impl std::hash::Hasher for IdentityHasher { + fn write(&mut self, bytes: &[u8]) { + self.hash = u64::from_ne_bytes( + bytes + .try_into() + .expect("identity hasher must be given exactly 8 bytes"), + ); + } + + fn finish(&self) -> u64 { + self.hash + } +} diff --git a/wgpu-core/src/lib.rs b/wgpu-core/src/lib.rs index b44248fdae..a35fcacec2 100644 --- a/wgpu-core/src/lib.rs +++ b/wgpu-core/src/lib.rs @@ -84,12 +84,14 @@ pub mod device; pub mod error; pub mod global; pub mod hal_api; +mod hash_utils; pub mod hub; pub mod id; pub mod identity; mod init_tracker; pub mod instance; pub mod pipeline; +mod pool; pub mod present; pub mod registry; pub mod resource; @@ -106,6 +108,8 @@ pub use hal::{api, MAX_BIND_GROUPS, MAX_COLOR_ATTACHMENTS, MAX_VERTEX_BUFFERS}; use std::{borrow::Cow, os::raw::c_char}; +pub(crate) use hash_utils::*; + /// The index of a queue submission. /// /// These are the values stored in `Device::fence`. @@ -335,13 +339,6 @@ macro_rules! resource_log { } pub(crate) use resource_log; -/// Fast hash map used internally. -type FastHashMap = - std::collections::HashMap>; -/// Fast hash set used internally. -type FastHashSet = - std::collections::HashSet>; - #[inline] pub(crate) fn get_lowest_common_denom(a: u32, b: u32) -> u32 { let gcd = if a >= b { diff --git a/wgpu-core/src/pipeline.rs b/wgpu-core/src/pipeline.rs index 27a58fd966..1d487a1bfc 100644 --- a/wgpu-core/src/pipeline.rs +++ b/wgpu-core/src/pipeline.rs @@ -479,7 +479,8 @@ pub struct RenderPipeline { pub(crate) raw: Option, pub(crate) device: Arc>, pub(crate) layout: Arc>, - pub(crate) _shader_modules: Vec>>, + pub(crate) _shader_modules: + ArrayVec>, { hal::MAX_CONCURRENT_SHADER_STAGES }>, pub(crate) pass_context: RenderPassContext, pub(crate) flags: PipelineFlags, pub(crate) strip_index_format: Option, diff --git a/wgpu-core/src/pool.rs b/wgpu-core/src/pool.rs new file mode 100644 index 0000000000..ddf45fbcb3 --- /dev/null +++ b/wgpu-core/src/pool.rs @@ -0,0 +1,312 @@ +use std::{ + collections::{hash_map::Entry, HashMap}, + hash::Hash, + sync::{Arc, Weak}, +}; + +use once_cell::sync::OnceCell; +use parking_lot::Mutex; + +use crate::{PreHashedKey, PreHashedMap}; + +type SlotInner = Weak; +type ResourcePoolSlot = Arc>>; + +pub struct ResourcePool { + // We use a pre-hashed map as we never actually need to read the keys. + // + // This additionally allows us to not need to hash more than once on get_or_init. + inner: Mutex>>, +} + +impl ResourcePool { + pub fn new() -> Self { + Self { + inner: Mutex::new(HashMap::default()), + } + } + + /// Get a resource from the pool with the given entry map, or create a new one if it doesn't exist using the given constructor. + /// + /// Behaves such that only one resource will be created for each unique entry map at any one time. + pub fn get_or_init(&self, key: K, constructor: F) -> Result, E> + where + F: FnOnce(K) -> Result, E>, + { + // Hash the key outside of the lock. + let hashed_key = PreHashedKey::from_key(&key); + + // We can't prove at compile time that these will only ever be consumed once, + // so we need to do the check at runtime. + let mut key = Some(key); + let mut constructor = Some(constructor); + + 'race: loop { + let mut map_guard = self.inner.lock(); + + let entry = match map_guard.entry(hashed_key) { + // An entry exists for this resource. + // + // We know that either: + // - The resource is still alive, and Weak::upgrade will succeed. + // - The resource is in the process of being dropped, and Weak::upgrade will fail. + // + // The entry will never be empty while the BGL is still alive. + Entry::Occupied(entry) => Arc::clone(entry.get()), + // No entry exists for this resource. + // + // We know that the resource is not alive, so we can create a new entry. + Entry::Vacant(entry) => Arc::clone(entry.insert(Arc::new(OnceCell::new()))), + }; + + drop(map_guard); + + // Some other thread may beat us to initializing the entry, but OnceCell guarentees that only one thread + // will actually initialize the entry. + // + // We pass the strong reference outside of the closure to keep it alive while we're the only one keeping a reference to it. + let mut strong = None; + let weak = entry.get_or_try_init(|| { + let strong_inner = constructor.take().unwrap()(key.take().unwrap())?; + let weak = Arc::downgrade(&strong_inner); + strong = Some(strong_inner); + Ok(weak) + })?; + + // If strong is Some, that means we just initialized the entry, so we can just return it. + if let Some(strong) = strong { + return Ok(strong); + } + + // The entry was already initialized by someone else, so we need to try to upgrade it. + if let Some(strong) = weak.upgrade() { + // We succeed, the resource is still alive, just return that. + return Ok(strong); + } + + // The resource is in the process of being dropped, because upgrade failed. The entry still exists in the map, but it points to nothing. + // + // We're in a race with the drop implementation of the resource, so lets just go around again. When we go around again: + // - If the entry exists, we might need to go around a few more times. + // - If the entry doesn't exist, we'll create a new one. + continue 'race; + } + } + + /// Remove the given entry map from the pool. + /// + /// Must *only* be called in the Drop impl of [`BindGroupLayout`]. + pub fn remove(&self, key: &K) { + let hashed_key = PreHashedKey::from_key(key); + + let mut map_guard = self.inner.lock(); + + // Weak::upgrade will be failing long before this code is called. All threads trying to access the resource will be spinning, + // waiting for the entry to be removed. It is safe to remove the entry from the map. + map_guard.remove(&hashed_key); + } +} + +#[cfg(test)] +mod tests { + use std::sync::{ + atomic::{AtomicU32, Ordering}, + Barrier, + }; + + use super::*; + + #[test] + fn deduplication() { + let pool = ResourcePool::::new(); + + let mut counter = 0_u32; + + let arc1 = pool + .get_or_init::<_, ()>(0, |key| { + counter += 1; + Ok(Arc::new(key)) + }) + .unwrap(); + + assert_eq!(*arc1, 0); + assert_eq!(counter, 1); + + let arc2 = pool + .get_or_init::<_, ()>(0, |key| { + counter += 1; + Ok(Arc::new(key)) + }) + .unwrap(); + + assert!(Arc::ptr_eq(&arc1, &arc2)); + assert_eq!(*arc2, 0); + assert_eq!(counter, 1); + + drop(arc1); + drop(arc2); + pool.remove(&0); + + let arc3 = pool + .get_or_init::<_, ()>(0, |key| { + counter += 1; + Ok(Arc::new(key)) + }) + .unwrap(); + + assert_eq!(*arc3, 0); + assert_eq!(counter, 2); + } + + // Test name has "2_threads" in the name so nextest reserves two threads for it. + #[test] + fn concurrent_creation_2_threads() { + struct Resources { + pool: ResourcePool, + counter: AtomicU32, + barrier: Barrier, + } + + let resources = Arc::new(Resources { + pool: ResourcePool::::new(), + counter: AtomicU32::new(0), + barrier: Barrier::new(2), + }); + + // Like all races, this is not inherently guaranteed to work, but in practice it should work fine. + // + // To validate the expected order of events, we've put print statements in the code, indicating when each thread is at a certain point. + // The output will look something like this if the test is working as expected: + // + // ``` + // 0: prewait + // 1: prewait + // 1: postwait + // 0: postwait + // 1: init + // 1: postget + // 0: postget + // ``` + fn thread_inner(idx: u8, resources: &Resources) -> Arc { + eprintln!("{idx}: prewait"); + + // Once this returns, both threads should hit get_or_init at about the same time, + // allowing us to actually test concurrent creation. + // + // Like all races, this is not inherently guaranteed to work, but in practice it should work fine. + resources.barrier.wait(); + + eprintln!("{idx}: postwait"); + + let ret = resources + .pool + .get_or_init::<_, ()>(0, |key| { + eprintln!("{idx}: init"); + + // Simulate long running constructor, ensuring that both threads will be in get_or_init. + std::thread::sleep(std::time::Duration::from_millis(250)); + + resources.counter.fetch_add(1, Ordering::SeqCst); + + Ok(Arc::new(key)) + }) + .unwrap(); + + eprintln!("{idx}: postget"); + + ret + } + + let thread1 = std::thread::spawn({ + let resource_clone = Arc::clone(&resources); + move || thread_inner(1, &resource_clone) + }); + + let arc0 = thread_inner(0, &resources); + + assert_eq!(resources.counter.load(Ordering::Acquire), 1); + + let arc1 = thread1.join().unwrap(); + + assert!(Arc::ptr_eq(&arc0, &arc1)); + } + + // Test name has "2_threads" in the name so nextest reserves two threads for it. + #[test] + fn create_while_drop_2_threads() { + struct Resources { + pool: ResourcePool, + barrier: Barrier, + } + + let resources = Arc::new(Resources { + pool: ResourcePool::::new(), + barrier: Barrier::new(2), + }); + + // Like all races, this is not inherently guaranteed to work, but in practice it should work fine. + // + // To validate the expected order of events, we've put print statements in the code, indicating when each thread is at a certain point. + // The output will look something like this if the test is working as expected: + // + // ``` + // 0: prewait + // 1: prewait + // 1: postwait + // 0: postwait + // 1: postsleep + // 1: removal + // 0: postget + // ``` + // + // The last two _may_ be flipped. + + let existing_entry = resources + .pool + .get_or_init::<_, ()>(0, |key| Ok(Arc::new(key))) + .unwrap(); + + // Drop the entry, but do _not_ remove it from the pool. + // This simulates the situation where the resource arc has been dropped, but the Drop implementation + // has not yet run, which calls remove. + drop(existing_entry); + + fn thread0_inner(resources: &Resources) { + eprintln!("0: prewait"); + resources.barrier.wait(); + + eprintln!("0: postwait"); + // We try to create a new entry, but the entry already exists. + // + // As Arc::upgrade is failing, we will just keep spinning until remove is called. + resources + .pool + .get_or_init::<_, ()>(0, |key| Ok(Arc::new(key))) + .unwrap(); + eprintln!("0: postget"); + } + + fn thread1_inner(resources: &Resources) { + eprintln!("1: prewait"); + resources.barrier.wait(); + + eprintln!("1: postwait"); + // We wait a little bit, making sure that thread0_inner has started spinning. + std::thread::sleep(std::time::Duration::from_millis(250)); + eprintln!("1: postsleep"); + + // We remove the entry from the pool, allowing thread0_inner to re-create. + resources.pool.remove(&0); + eprintln!("1: removal"); + } + + let thread1 = std::thread::spawn({ + let resource_clone = Arc::clone(&resources); + move || thread1_inner(&resource_clone) + }); + + thread0_inner(&resources); + + thread1.join().unwrap(); + } +} diff --git a/wgpu-core/src/registry.rs b/wgpu-core/src/registry.rs index 0fe3b7dd13..79d68921ac 100644 --- a/wgpu-core/src/registry.rs +++ b/wgpu-core/src/registry.rs @@ -80,12 +80,21 @@ impl> FutureId<'_, I, T> { Arc::new(value) } + /// Assign a new resource to this ID. + /// + /// Registers it with the registry, and fills out the resource info. pub fn assign(self, value: T) -> (I, Arc) { let mut data = self.data.write(); data.insert(self.id, self.init(value)); (self.id, data.get(self.id).unwrap().clone()) } + /// Assign an existing resource to a new ID. + /// + /// Registers it with the registry. + /// + /// This _will_ leak the ID, and it will not be recycled again. + /// See https://github.com/gfx-rs/wgpu/issues/4912. pub fn assign_existing(self, value: &Arc) -> I { let mut data = self.data.write(); debug_assert!(!data.contains(self.id)); @@ -125,7 +134,7 @@ impl> Registry { self.read().try_get(id).map(|o| o.cloned()) } pub(crate) fn get(&self, id: I) -> Result, InvalidId> { - self.read().get(id).map(|v| v.clone()) + self.read().get_owned(id) } pub(crate) fn read<'a>(&'a self) -> RwLockReadGuard<'a, Storage> { self.storage.read() diff --git a/wgpu-core/src/storage.rs b/wgpu-core/src/storage.rs index 891b7954e6..cf81e65eb8 100644 --- a/wgpu-core/src/storage.rs +++ b/wgpu-core/src/storage.rs @@ -131,6 +131,12 @@ where result } + /// Get an owned reference to an item behind a potentially invalid ID. + /// Panics if there is an epoch mismatch, or the entry is empty. + pub(crate) fn get_owned(&self, id: I) -> Result, InvalidId> { + Ok(Arc::clone(self.get(id)?)) + } + pub(crate) fn label_for_invalid_id(&self, id: I) -> &str { let (index, _, _) = id.unzip(); match self.map.get(index as usize) { diff --git a/wgpu-core/src/validation.rs b/wgpu-core/src/validation.rs index 1c05f47ec5..a0947ae83f 100644 --- a/wgpu-core/src/validation.rs +++ b/wgpu-core/src/validation.rs @@ -1,4 +1,5 @@ -use crate::{binding_model::BindEntryMap, FastHashMap, FastHashSet}; +use crate::{device::bgl, FastHashMap, FastHashSet}; +use arrayvec::ArrayVec; use std::{collections::hash_map::Entry, fmt}; use thiserror::Error; use wgt::{BindGroupLayoutEntry, BindingType}; @@ -774,6 +775,27 @@ pub fn check_texture_format( } } +pub enum BindingLayoutSource<'a> { + /// The binding layout is derived from the pipeline layout. + /// + /// This will be filled in by the shader binding validation, as it iterates the shader's interfaces. + Derived(ArrayVec), + /// The binding layout is provided by the user in BGLs. + /// + /// This will be validated against the shader's interfaces. + Provided(ArrayVec<&'a bgl::EntryMap, { hal::MAX_BIND_GROUPS }>), +} + +impl<'a> BindingLayoutSource<'a> { + pub fn new_derived(limits: &wgt::Limits) -> Self { + let mut array = ArrayVec::new(); + for _ in 0..limits.max_bind_groups { + array.push(Default::default()); + } + BindingLayoutSource::Derived(array) + } +} + pub type StageIo = FastHashMap; impl Interface { @@ -933,8 +955,7 @@ impl Interface { pub fn check_stage( &self, - given_layouts: Option<&[&BindEntryMap]>, - derived_layouts: &mut [BindEntryMap], + layouts: &mut BindingLayoutSource<'_>, shader_binding_sizes: &mut FastHashMap, entry_point_name: &str, stage_bit: wgt::ShaderStages, @@ -958,45 +979,53 @@ impl Interface { // check resources visibility for &handle in entry_point.resources.iter() { let res = &self.resources[handle]; - let result = match given_layouts { - Some(layouts) => { - // update the required binding size for this buffer - if let ResourceType::Buffer { size } = res.ty { - match shader_binding_sizes.entry(res.bind.clone()) { - Entry::Occupied(e) => { - *e.into_mut() = size.max(*e.get()); - } - Entry::Vacant(e) => { - e.insert(size); + let result = 'err: { + match layouts { + BindingLayoutSource::Provided(layouts) => { + // update the required binding size for this buffer + if let ResourceType::Buffer { size } = res.ty { + match shader_binding_sizes.entry(res.bind.clone()) { + Entry::Occupied(e) => { + *e.into_mut() = size.max(*e.get()); + } + Entry::Vacant(e) => { + e.insert(size); + } } } + + let Some(map) = layouts.get(res.bind.group as usize) else { + break 'err Err(BindingError::Missing); + }; + + let Some(entry) = map.get(res.bind.binding) else { + break 'err Err(BindingError::Missing); + }; + + if !entry.visibility.contains(stage_bit) { + break 'err Err(BindingError::Invisible); + } + + res.check_binding_use(entry) } - layouts - .get(res.bind.group as usize) - .and_then(|map| map.get(&res.bind.binding)) - .ok_or(BindingError::Missing) - .and_then(|entry| { - if entry.visibility.contains(stage_bit) { - Ok(entry) - } else { - Err(BindingError::Invisible) - } - }) - .and_then(|entry| res.check_binding_use(entry)) - } - None => derived_layouts - .get_mut(res.bind.group as usize) - .ok_or(BindingError::Missing) - .and_then(|set| { - let ty = res.derive_binding_type()?; - match set.entry(res.bind.binding) { - Entry::Occupied(e) if e.get().ty != ty => { - return Err(BindingError::InconsistentlyDerivedType) + BindingLayoutSource::Derived(layouts) => { + let Some(map) = layouts.get_mut(res.bind.group as usize) else { + break 'err Err(BindingError::Missing); + }; + + let ty = match res.derive_binding_type() { + Ok(ty) => ty, + Err(error) => break 'err Err(error), + }; + + match map.entry(res.bind.binding) { + indexmap::map::Entry::Occupied(e) if e.get().ty != ty => { + break 'err Err(BindingError::InconsistentlyDerivedType) } - Entry::Occupied(e) => { + indexmap::map::Entry::Occupied(e) => { e.into_mut().visibility |= stage_bit; } - Entry::Vacant(e) => { + indexmap::map::Entry::Vacant(e) => { e.insert(BindGroupLayoutEntry { binding: res.bind.binding, ty, @@ -1006,20 +1035,28 @@ impl Interface { } } Ok(()) - }), + } + } }; if let Err(error) = result { return Err(StageError::Binding(res.bind.clone(), error)); } } - // check the compatibility between textures and samplers - if let Some(layouts) = given_layouts { + // Check the compatibility between textures and samplers + // + // We only need to do this if the binding layout is provided by the user, as derived + // layouts will inherently be correctly tagged. + if let BindingLayoutSource::Provided(layouts) = layouts { for &(texture_handle, sampler_handle) in entry_point.sampling_pairs.iter() { let texture_bind = &self.resources[texture_handle].bind; let sampler_bind = &self.resources[sampler_handle].bind; - let texture_layout = &layouts[texture_bind.group as usize][&texture_bind.binding]; - let sampler_layout = &layouts[sampler_bind.group as usize][&sampler_bind.binding]; + let texture_layout = layouts[texture_bind.group as usize] + .get(texture_bind.binding) + .unwrap(); + let sampler_layout = layouts[sampler_bind.group as usize] + .get(sampler_bind.binding) + .unwrap(); assert!(texture_layout.visibility.contains(stage_bit)); assert!(sampler_layout.visibility.contains(stage_bit));