From 05ef1f8f6c7d352cd2a1272917ccb3caaef47f34 Mon Sep 17 00:00:00 2001 From: Kent Slaney Date: Sat, 30 Nov 2024 14:18:28 -0800 Subject: [PATCH 01/14] add workgroup_size_overrides --- naga/src/back/pipeline_constants.rs | 49 ++++++++++++++ naga/src/front/glsl/functions.rs | 1 + naga/src/front/spv/function.rs | 1 + naga/src/front/wgsl/lower/mod.rs | 66 +++++++++++++++++-- naga/src/lib.rs | 2 + naga/tests/out/ir/access.compact.ron | 4 ++ naga/tests/out/ir/access.ron | 4 ++ .../out/ir/atomic_i_increment.compact.ron | 1 + naga/tests/out/ir/atomic_i_increment.ron | 1 + naga/tests/out/ir/collatz.compact.ron | 1 + naga/tests/out/ir/collatz.ron | 1 + naga/tests/out/ir/fetch_depth.compact.ron | 1 + naga/tests/out/ir/fetch_depth.ron | 1 + naga/tests/out/ir/index-by-value.compact.ron | 1 + naga/tests/out/ir/index-by-value.ron | 1 + ...ides-atomicCompareExchangeWeak.compact.ron | 1 + .../overrides-atomicCompareExchangeWeak.ron | 1 + .../out/ir/overrides-ray-query.compact.ron | 1 + naga/tests/out/ir/overrides-ray-query.ron | 1 + naga/tests/out/ir/overrides.compact.ron | 1 + naga/tests/out/ir/overrides.ron | 1 + naga/tests/out/ir/shadow.compact.ron | 1 + naga/tests/out/ir/shadow.ron | 1 + naga/tests/out/ir/spec-constants.compact.ron | 1 + naga/tests/out/ir/spec-constants.ron | 1 + 25 files changed, 141 insertions(+), 4 deletions(-) diff --git a/naga/src/back/pipeline_constants.rs b/naga/src/back/pipeline_constants.rs index 0005cbcb0e..cf7b84a342 100644 --- a/naga/src/back/pipeline_constants.rs +++ b/naga/src/back/pipeline_constants.rs @@ -25,6 +25,8 @@ pub enum PipelineConstantError { ConstantEvaluatorError(#[from] ConstantEvaluatorError), #[error(transparent)] ValidationError(#[from] WithSpan), + #[error("workgroup_size was overridden to a negative value")] + NegativeWorkgroupSize, } /// Replace all overrides in `module` with constants. @@ -190,6 +192,7 @@ pub fn process_overrides<'a>( let mut entry_points = mem::take(&mut module.entry_points); for ep in entry_points.iter_mut() { process_function(&mut module, &override_map, &mut ep.function)?; + process_workgroup_size_override(&mut module, &override_map, ep)?; } module.entry_points = entry_points; @@ -202,6 +205,52 @@ pub fn process_overrides<'a>( Ok((Cow::Owned(module), Cow::Owned(module_info))) } +fn process_workgroup_size_override( + module: &mut Module, + override_map: &HandleVec>, + ep: &mut crate::EntryPoint +) -> Result<(), PipelineConstantError> { + match ep.workgroup_size_overrides { + None => {} + Some(overrides) => { + overrides.iter().enumerate().try_for_each( + |(i, overridden)| -> Result<(), PipelineConstantError> { + match overridden { + None => Ok(()), + Some(h) => { + let c = module.constants[override_map[*h]].init; + let n = &module.global_expressions[c]; + match n { + crate::Expression::Literal(literal) => { + ep.workgroup_size[i] = match literal { + crate::Literal::U32(m) => (*m).into(), + crate::Literal::I32(m) => { + if *m < 0 { + Err(PipelineConstantError::NegativeWorkgroupSize)?; + unreachable!(); + } else { + *m as u32 + } + } + _ => { + unreachable!(); + } + }; + } + _ => { + unreachable!(); + } + } + Ok(()) + } + } + } + )?; + } + } + Ok(()) +} + /// Add a [`Constant`] to `module` for the override `old_h`. /// /// Add the new `Constant` to `override_map` and `adjusted_constant_initializers`. diff --git a/naga/src/front/glsl/functions.rs b/naga/src/front/glsl/functions.rs index 658632e872..394be22eaa 100644 --- a/naga/src/front/glsl/functions.rs +++ b/naga/src/front/glsl/functions.rs @@ -1366,6 +1366,7 @@ impl Frontend { early_depth_test: Some(crate::EarlyDepthTest { conservative: None }) .filter(|_| self.meta.early_fragment_tests), workgroup_size: self.meta.workgroup_size, + workgroup_size_overrides: None, function: Function { arguments, expressions, diff --git a/naga/src/front/spv/function.rs b/naga/src/front/spv/function.rs index 7122e44837..271b96926b 100644 --- a/naga/src/front/spv/function.rs +++ b/naga/src/front/spv/function.rs @@ -569,6 +569,7 @@ impl> super::Frontend { stage: ep.stage, early_depth_test: ep.early_depth_test, workgroup_size: ep.workgroup_size, + workgroup_size_overrides: None, function, }); diff --git a/naga/src/front/wgsl/lower/mod.rs b/naga/src/front/wgsl/lower/mod.rs index 5173e73d79..8401df4ecd 100644 --- a/naga/src/front/wgsl/lower/mod.rs +++ b/naga/src/front/wgsl/lower/mod.rs @@ -1311,24 +1311,55 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { .collect(); if let Some(ref entry) = f.entry_point { - let workgroup_size = if let Some(workgroup_size) = entry.workgroup_size { + let workgroup_size_info = if let Some(workgroup_size) = entry.workgroup_size { // TODO: replace with try_map once stabilized let mut workgroup_size_out = [1; 3]; + let mut workgroup_size_overrides_out = [None; 3]; for (i, size) in workgroup_size.into_iter().enumerate() { if let Some(size_expr) = size { - workgroup_size_out[i] = self.const_u32(size_expr, &mut ctx.as_const())?.0; + match self.const_u32(size_expr, &mut ctx.as_const()) { + Ok(value) => { + workgroup_size_out[i] = value.0; + } + err => { + if let Err(Error::ConstantEvaluatorError(ref ty, _)) = err { + match **ty { + crate::proc::ConstantEvaluatorError::OverrideExpr => { + workgroup_size_overrides_out[i] = Some( + self.workgroup_size_override( + size_expr, + &mut ctx.as_override(), + i, + )? + ); + } + _ => { + err?; + } + } + } else { + err?; + } + } + } } } - workgroup_size_out + if workgroup_size_overrides_out.iter().all(|x| x.is_none()) { + (workgroup_size_out, None) + } else { + (workgroup_size_out, Some(workgroup_size_overrides_out)) + } } else { - [0; 3] + ([0; 3], None) }; + let (workgroup_size, workgroup_size_overrides) = workgroup_size_info; ctx.module.entry_points.push(crate::EntryPoint { name: f.name.name.to_string(), stage: entry.stage, early_depth_test: entry.early_depth_test, workgroup_size, + workgroup_size_overrides, function, }); Ok(LoweredGlobalDecl::EntryPoint) @@ -1338,6 +1369,33 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { } } + fn workgroup_size_override( + &mut self, + size_expr: Handle>, + ctx: &mut ExpressionContext<'source, '_, '_>, + i: usize, + ) -> Result, Error<'source>> { + let span = ctx.ast_expressions.get_span(size_expr); + let expr = self.expression(size_expr, ctx)?; + let ty = ctx.register_type(expr)?; + match ctx.module.types[ty].inner.scalar_kind().ok_or(0) { + Ok(crate::ScalarKind::Sint) | Ok(crate::ScalarKind::Uint) => Ok({ + ctx.module.overrides.append( + crate::Override { + name: Some(format!("__workgroup_size_{}", i)), + id: None, + ty, + init: Some(expr), + }, + span, + ) + }), + _ => { + Err(Error::ExpectedConstExprConcreteIntegerScalar(span)) + } + } + } + fn block( &mut self, b: &ast::Block<'source>, diff --git a/naga/src/lib.rs b/naga/src/lib.rs index 4afbfff9d7..1c1929efa2 100644 --- a/naga/src/lib.rs +++ b/naga/src/lib.rs @@ -2186,6 +2186,8 @@ pub struct EntryPoint { pub early_depth_test: Option, /// Workgroup size for compute stages pub workgroup_size: [u32; 3], + /// Override expressions for workgroup size + pub workgroup_size_overrides: Option<[Option>; 3]>, /// The entrance function. pub function: Function, } diff --git a/naga/tests/out/ir/access.compact.ron b/naga/tests/out/ir/access.compact.ron index 974080e998..e314078c2b 100644 --- a/naga/tests/out/ir/access.compact.ron +++ b/naga/tests/out/ir/access.compact.ron @@ -1854,6 +1854,7 @@ stage: Vertex, early_depth_test: None, workgroup_size: (0, 0, 0), + workgroup_size_overrides: None, function: ( name: Some("foo_vert"), arguments: [ @@ -2156,6 +2157,7 @@ stage: Fragment, early_depth_test: None, workgroup_size: (0, 0, 0), + workgroup_size_overrides: None, function: ( name: Some("foo_frag"), arguments: [], @@ -2348,6 +2350,7 @@ stage: Compute, early_depth_test: None, workgroup_size: (1, 1, 1), + workgroup_size_overrides: None, function: ( name: Some("assign_through_ptr"), arguments: [], @@ -2430,6 +2433,7 @@ stage: Compute, early_depth_test: None, workgroup_size: (1, 1, 1), + workgroup_size_overrides: None, function: ( name: Some("assign_to_ptr_components"), arguments: [], diff --git a/naga/tests/out/ir/access.ron b/naga/tests/out/ir/access.ron index 974080e998..e314078c2b 100644 --- a/naga/tests/out/ir/access.ron +++ b/naga/tests/out/ir/access.ron @@ -1854,6 +1854,7 @@ stage: Vertex, early_depth_test: None, workgroup_size: (0, 0, 0), + workgroup_size_overrides: None, function: ( name: Some("foo_vert"), arguments: [ @@ -2156,6 +2157,7 @@ stage: Fragment, early_depth_test: None, workgroup_size: (0, 0, 0), + workgroup_size_overrides: None, function: ( name: Some("foo_frag"), arguments: [], @@ -2348,6 +2350,7 @@ stage: Compute, early_depth_test: None, workgroup_size: (1, 1, 1), + workgroup_size_overrides: None, function: ( name: Some("assign_through_ptr"), arguments: [], @@ -2430,6 +2433,7 @@ stage: Compute, early_depth_test: None, workgroup_size: (1, 1, 1), + workgroup_size_overrides: None, function: ( name: Some("assign_to_ptr_components"), arguments: [], diff --git a/naga/tests/out/ir/atomic_i_increment.compact.ron b/naga/tests/out/ir/atomic_i_increment.compact.ron index 7d024f4e81..5bb6820258 100644 --- a/naga/tests/out/ir/atomic_i_increment.compact.ron +++ b/naga/tests/out/ir/atomic_i_increment.compact.ron @@ -263,6 +263,7 @@ stage: Compute, early_depth_test: None, workgroup_size: (32, 1, 1), + workgroup_size_overrides: None, function: ( name: Some("stage::test_atomic_i_increment_wrap"), arguments: [], diff --git a/naga/tests/out/ir/atomic_i_increment.ron b/naga/tests/out/ir/atomic_i_increment.ron index aab4c07206..ae14821330 100644 --- a/naga/tests/out/ir/atomic_i_increment.ron +++ b/naga/tests/out/ir/atomic_i_increment.ron @@ -288,6 +288,7 @@ stage: Compute, early_depth_test: None, workgroup_size: (32, 1, 1), + workgroup_size_overrides: None, function: ( name: Some("stage::test_atomic_i_increment_wrap"), arguments: [], diff --git a/naga/tests/out/ir/collatz.compact.ron b/naga/tests/out/ir/collatz.compact.ron index 48ce8e76bc..6a7aebe544 100644 --- a/naga/tests/out/ir/collatz.compact.ron +++ b/naga/tests/out/ir/collatz.compact.ron @@ -257,6 +257,7 @@ stage: Compute, early_depth_test: None, workgroup_size: (1, 1, 1), + workgroup_size_overrides: None, function: ( name: Some("main"), arguments: [ diff --git a/naga/tests/out/ir/collatz.ron b/naga/tests/out/ir/collatz.ron index 48ce8e76bc..6a7aebe544 100644 --- a/naga/tests/out/ir/collatz.ron +++ b/naga/tests/out/ir/collatz.ron @@ -257,6 +257,7 @@ stage: Compute, early_depth_test: None, workgroup_size: (1, 1, 1), + workgroup_size_overrides: None, function: ( name: Some("main"), arguments: [ diff --git a/naga/tests/out/ir/fetch_depth.compact.ron b/naga/tests/out/ir/fetch_depth.compact.ron index 0d998e205c..f10ccb94f7 100644 --- a/naga/tests/out/ir/fetch_depth.compact.ron +++ b/naga/tests/out/ir/fetch_depth.compact.ron @@ -176,6 +176,7 @@ stage: Compute, early_depth_test: None, workgroup_size: (32, 1, 1), + workgroup_size_overrides: None, function: ( name: Some("cull::fetch_depth_wrap"), arguments: [], diff --git a/naga/tests/out/ir/fetch_depth.ron b/naga/tests/out/ir/fetch_depth.ron index c66b7eb065..d25e046d57 100644 --- a/naga/tests/out/ir/fetch_depth.ron +++ b/naga/tests/out/ir/fetch_depth.ron @@ -246,6 +246,7 @@ stage: Compute, early_depth_test: None, workgroup_size: (32, 1, 1), + workgroup_size_overrides: None, function: ( name: Some("cull::fetch_depth_wrap"), arguments: [], diff --git a/naga/tests/out/ir/index-by-value.compact.ron b/naga/tests/out/ir/index-by-value.compact.ron index f0ea76f496..93a9821426 100644 --- a/naga/tests/out/ir/index-by-value.compact.ron +++ b/naga/tests/out/ir/index-by-value.compact.ron @@ -300,6 +300,7 @@ stage: Vertex, early_depth_test: None, workgroup_size: (0, 0, 0), + workgroup_size_overrides: None, function: ( name: Some("index_let_array_1d"), arguments: [ diff --git a/naga/tests/out/ir/index-by-value.ron b/naga/tests/out/ir/index-by-value.ron index f0ea76f496..93a9821426 100644 --- a/naga/tests/out/ir/index-by-value.ron +++ b/naga/tests/out/ir/index-by-value.ron @@ -300,6 +300,7 @@ stage: Vertex, early_depth_test: None, workgroup_size: (0, 0, 0), + workgroup_size_overrides: None, function: ( name: Some("index_let_array_1d"), arguments: [ diff --git a/naga/tests/out/ir/overrides-atomicCompareExchangeWeak.compact.ron b/naga/tests/out/ir/overrides-atomicCompareExchangeWeak.compact.ron index e762de0385..56be2f8ab6 100644 --- a/naga/tests/out/ir/overrides-atomicCompareExchangeWeak.compact.ron +++ b/naga/tests/out/ir/overrides-atomicCompareExchangeWeak.compact.ron @@ -85,6 +85,7 @@ stage: Compute, early_depth_test: None, workgroup_size: (1, 1, 1), + workgroup_size_overrides: None, function: ( name: Some("f"), arguments: [], diff --git a/naga/tests/out/ir/overrides-atomicCompareExchangeWeak.ron b/naga/tests/out/ir/overrides-atomicCompareExchangeWeak.ron index e762de0385..56be2f8ab6 100644 --- a/naga/tests/out/ir/overrides-atomicCompareExchangeWeak.ron +++ b/naga/tests/out/ir/overrides-atomicCompareExchangeWeak.ron @@ -85,6 +85,7 @@ stage: Compute, early_depth_test: None, workgroup_size: (1, 1, 1), + workgroup_size_overrides: None, function: ( name: Some("f"), arguments: [], diff --git a/naga/tests/out/ir/overrides-ray-query.compact.ron b/naga/tests/out/ir/overrides-ray-query.compact.ron index f7d05aa92f..10cad83538 100644 --- a/naga/tests/out/ir/overrides-ray-query.compact.ron +++ b/naga/tests/out/ir/overrides-ray-query.compact.ron @@ -111,6 +111,7 @@ stage: Compute, early_depth_test: None, workgroup_size: (1, 1, 1), + workgroup_size_overrides: None, function: ( name: Some("main"), arguments: [], diff --git a/naga/tests/out/ir/overrides-ray-query.ron b/naga/tests/out/ir/overrides-ray-query.ron index f7d05aa92f..10cad83538 100644 --- a/naga/tests/out/ir/overrides-ray-query.ron +++ b/naga/tests/out/ir/overrides-ray-query.ron @@ -111,6 +111,7 @@ stage: Compute, early_depth_test: None, workgroup_size: (1, 1, 1), + workgroup_size_overrides: None, function: ( name: Some("main"), arguments: [], diff --git a/naga/tests/out/ir/overrides.compact.ron b/naga/tests/out/ir/overrides.compact.ron index d2df01c0db..d99beb19c6 100644 --- a/naga/tests/out/ir/overrides.compact.ron +++ b/naga/tests/out/ir/overrides.compact.ron @@ -108,6 +108,7 @@ stage: Compute, early_depth_test: None, workgroup_size: (1, 1, 1), + workgroup_size_overrides: None, function: ( name: Some("main"), arguments: [], diff --git a/naga/tests/out/ir/overrides.ron b/naga/tests/out/ir/overrides.ron index d2df01c0db..d99beb19c6 100644 --- a/naga/tests/out/ir/overrides.ron +++ b/naga/tests/out/ir/overrides.ron @@ -108,6 +108,7 @@ stage: Compute, early_depth_test: None, workgroup_size: (1, 1, 1), + workgroup_size_overrides: None, function: ( name: Some("main"), arguments: [], diff --git a/naga/tests/out/ir/shadow.compact.ron b/naga/tests/out/ir/shadow.compact.ron index 39a25fd10b..24b4674515 100644 --- a/naga/tests/out/ir/shadow.compact.ron +++ b/naga/tests/out/ir/shadow.compact.ron @@ -958,6 +958,7 @@ stage: Fragment, early_depth_test: None, workgroup_size: (0, 0, 0), + workgroup_size_overrides: None, function: ( name: Some("fs_main_wrap"), arguments: [ diff --git a/naga/tests/out/ir/shadow.ron b/naga/tests/out/ir/shadow.ron index 196536d56b..386b9d36b0 100644 --- a/naga/tests/out/ir/shadow.ron +++ b/naga/tests/out/ir/shadow.ron @@ -1236,6 +1236,7 @@ stage: Fragment, early_depth_test: None, workgroup_size: (0, 0, 0), + workgroup_size_overrides: None, function: ( name: Some("fs_main_wrap"), arguments: [ diff --git a/naga/tests/out/ir/spec-constants.compact.ron b/naga/tests/out/ir/spec-constants.compact.ron index 9ea75cd468..cde3117225 100644 --- a/naga/tests/out/ir/spec-constants.compact.ron +++ b/naga/tests/out/ir/spec-constants.compact.ron @@ -495,6 +495,7 @@ stage: Vertex, early_depth_test: None, workgroup_size: (0, 0, 0), + workgroup_size_overrides: None, function: ( name: Some("main_wrap"), arguments: [ diff --git a/naga/tests/out/ir/spec-constants.ron b/naga/tests/out/ir/spec-constants.ron index 5d48e94efc..fa4139a1da 100644 --- a/naga/tests/out/ir/spec-constants.ron +++ b/naga/tests/out/ir/spec-constants.ron @@ -601,6 +601,7 @@ stage: Vertex, early_depth_test: None, workgroup_size: (0, 0, 0), + workgroup_size_overrides: None, function: ( name: Some("main_wrap"), arguments: [ From 6474295a5542e0da6415bb8fc88ebdd6c5e6c94a Mon Sep 17 00:00:00 2001 From: Kent Slaney Date: Sat, 30 Nov 2024 14:35:43 -0800 Subject: [PATCH 02/14] linting --- naga/src/back/pipeline_constants.rs | 16 ++++++++-------- naga/src/front/wgsl/lower/mod.rs | 11 ++++------- 2 files changed, 12 insertions(+), 15 deletions(-) diff --git a/naga/src/back/pipeline_constants.rs b/naga/src/back/pipeline_constants.rs index cf7b84a342..1a2dc17023 100644 --- a/naga/src/back/pipeline_constants.rs +++ b/naga/src/back/pipeline_constants.rs @@ -208,28 +208,28 @@ pub fn process_overrides<'a>( fn process_workgroup_size_override( module: &mut Module, override_map: &HandleVec>, - ep: &mut crate::EntryPoint + ep: &mut crate::EntryPoint, ) -> Result<(), PipelineConstantError> { match ep.workgroup_size_overrides { None => {} Some(overrides) => { overrides.iter().enumerate().try_for_each( |(i, overridden)| -> Result<(), PipelineConstantError> { - match overridden { + match *overridden { None => Ok(()), Some(h) => { - let c = module.constants[override_map[*h]].init; + let c = module.constants[override_map[h]].init; let n = &module.global_expressions[c]; - match n { + match *n { crate::Expression::Literal(literal) => { ep.workgroup_size[i] = match literal { - crate::Literal::U32(m) => (*m).into(), + crate::Literal::U32(m) => m, crate::Literal::I32(m) => { - if *m < 0 { + if m < 0 { Err(PipelineConstantError::NegativeWorkgroupSize)?; unreachable!(); } else { - *m as u32 + m as u32 } } _ => { @@ -244,7 +244,7 @@ fn process_workgroup_size_override( Ok(()) } } - } + }, )?; } } diff --git a/naga/src/front/wgsl/lower/mod.rs b/naga/src/front/wgsl/lower/mod.rs index 8401df4ecd..7de6fc945b 100644 --- a/naga/src/front/wgsl/lower/mod.rs +++ b/naga/src/front/wgsl/lower/mod.rs @@ -1325,13 +1325,12 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { if let Err(Error::ConstantEvaluatorError(ref ty, _)) = err { match **ty { crate::proc::ConstantEvaluatorError::OverrideExpr => { - workgroup_size_overrides_out[i] = Some( - self.workgroup_size_override( + workgroup_size_overrides_out[i] = + Some(self.workgroup_size_override( size_expr, &mut ctx.as_override(), i, - )? - ); + )?); } _ => { err?; @@ -1390,9 +1389,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { span, ) }), - _ => { - Err(Error::ExpectedConstExprConcreteIntegerScalar(span)) - } + _ => Err(Error::ExpectedConstExprConcreteIntegerScalar(span)), } } From 13f98184f1298c3d4967aba5b374b09b13abf3f1 Mon Sep 17 00:00:00 2001 From: Kent Slaney Date: Sat, 30 Nov 2024 14:45:22 -0800 Subject: [PATCH 03/14] changelog --- CHANGELOG.md | 1 + 1 file changed, 1 insertion(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index f60dff7eb1..ca7131118f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -102,6 +102,7 @@ By @ErichDonGubler in [#6456](https://github.com/gfx-rs/wgpu/pull/6456), [#6148] - Implement `quantizeToF16()` for WGSL frontend, and WGSL, SPIR-V, HLSL, MSL, and GLSL backends. By @jamienicol in [#6519](https://github.com/gfx-rs/wgpu/pull/6519). - Add support for GLSL `usampler*` and `isampler*`. By @DavidPeicho in [#6513](https://github.com/gfx-rs/wgpu/pull/6513). - Expose Ray Query flags as constants in WGSL. Implement candidate intersections. By @kvark in [#5429](https://github.com/gfx-rs/wgpu/pull/5429) +- Allow for override-expressions in `workgroup_size`. By @KentSlaney in [#6635](https://github.com/gfx-rs/wgpu/pull/6635). #### General From 59861d47d8e61a90b689338d9088b732cadb1f06 Mon Sep 17 00:00:00 2001 From: Kent Slaney Date: Sun, 1 Dec 2024 13:06:22 -0800 Subject: [PATCH 04/14] integration test --- tests/tests/shader/mod.rs | 1 + .../tests/shader/workgroup_size_overrides.rs | 103 ++++++++++++++++++ 2 files changed, 104 insertions(+) create mode 100644 tests/tests/shader/workgroup_size_overrides.rs diff --git a/tests/tests/shader/mod.rs b/tests/tests/shader/mod.rs index 7d6ed7aaaa..f05fbac25c 100644 --- a/tests/tests/shader/mod.rs +++ b/tests/tests/shader/mod.rs @@ -19,6 +19,7 @@ pub mod compilation_messages; pub mod data_builtins; pub mod numeric_builtins; pub mod struct_layout; +pub mod workgroup_size_overrides; pub mod zero_init_workgroup_mem; #[derive(Clone, Copy, PartialEq)] diff --git a/tests/tests/shader/workgroup_size_overrides.rs b/tests/tests/shader/workgroup_size_overrides.rs new file mode 100644 index 0000000000..90f319847d --- /dev/null +++ b/tests/tests/shader/workgroup_size_overrides.rs @@ -0,0 +1,103 @@ +use std::mem::size_of_val; +use wgpu::util::DeviceExt; +use wgpu::{BufferDescriptor, BufferUsages, Maintain, MapMode}; +use wgpu_test::{gpu_test, GpuTestConfiguration, TestParameters, TestingContext}; + +const SHADER: &str = r#" + override n = 3; + + @group(0) @binding(0) + var output: array; + + @compute @workgroup_size(n - 2) + fn main(@builtin(local_invocation_index) lii: u32) { + output[lii] = lii + 2; + } +"#; + +#[gpu_test] +static WORKGROUP_SIZE_OVERRIDES: GpuTestConfiguration = GpuTestConfiguration::new() + .parameters(TestParameters::default().limits(wgpu::Limits::default())) + .run_async(move |ctx| async move { + workgroup_size_overrides(&ctx, 0, &[2, 0, 0]).await; + workgroup_size_overrides(&ctx, 4, &[2, 3, 0]).await; + // Expected to fail during pipeline creation: + //workgroup_size_overrides(&ctx, 1, &[0, 0, 0]).await; + }); + +async fn workgroup_size_overrides(ctx: &TestingContext, n: u32, out: &[u32]) { + let module = ctx + .device + .create_shader_module(wgpu::ShaderModuleDescriptor { + label: None, + source: wgpu::ShaderSource::Wgsl(std::borrow::Cow::Borrowed(SHADER)), + }); + let pipeline_options = wgpu::PipelineCompilationOptions { + constants: &[("n".to_owned(), n.into())].into(), + ..Default::default() + }; + let compute_pipeline = ctx + .device + .create_compute_pipeline(&wgpu::ComputePipelineDescriptor { + label: None, + layout: None, + module: &module, + entry_point: Some("main"), + compilation_options: if n == 0 { + wgpu::PipelineCompilationOptions::default() + } else { + pipeline_options + }, + cache: None, + }); + let init: &[u32] = &[0, 0, 0]; + let init_size: u64 = size_of_val(init).try_into().unwrap(); + let buffer = DeviceExt::create_buffer_init( + &ctx.device, + &wgpu::util::BufferInitDescriptor { + label: None, + contents: bytemuck::cast_slice(init), + usage: wgpu::BufferUsages::STORAGE + | wgpu::BufferUsages::COPY_DST + | wgpu::BufferUsages::COPY_SRC, + }, + ); + let mapping_buffer = ctx.device.create_buffer(&BufferDescriptor { + label: Some("mapping buffer"), + size: init_size, + usage: BufferUsages::COPY_DST | BufferUsages::MAP_READ, + mapped_at_creation: false, + }); + let mut encoder = ctx + .device + .create_command_encoder(&wgpu::CommandEncoderDescriptor { label: None }); + { + let mut cpass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor { + label: None, + timestamp_writes: None, + }); + cpass.set_pipeline(&compute_pipeline); + let bind_group_layout = compute_pipeline.get_bind_group_layout(0); + let bind_group_entries = [wgpu::BindGroupEntry { + binding: 0, + resource: buffer.as_entire_binding(), + }]; + let bind_group = ctx.device.create_bind_group(&wgpu::BindGroupDescriptor { + label: None, + layout: &bind_group_layout, + entries: &bind_group_entries, + }); + cpass.set_bind_group(0, &bind_group, &[]); + cpass.dispatch_workgroups(1, 1, 1); + } + encoder.copy_buffer_to_buffer(&buffer, 0, &mapping_buffer, 0, init_size); + ctx.queue.submit(Some(encoder.finish())); + + mapping_buffer.slice(..).map_async(MapMode::Read, |_| ()); + ctx.async_poll(Maintain::wait()).await.panic_on_timeout(); + + let mapped = mapping_buffer.slice(..).get_mapped_range(); + + let typed: &[u32] = bytemuck::cast_slice(&mapped); + assert_eq!(typed, out); +} From 5fd355fc900faa9b6362c64de828e2201eace48c Mon Sep 17 00:00:00 2001 From: Kent Slaney Date: Sun, 1 Dec 2024 14:49:10 -0800 Subject: [PATCH 05/14] separate use_override from value and add should_fail --- .../tests/shader/workgroup_size_overrides.rs | 53 ++++++++++++------- 1 file changed, 33 insertions(+), 20 deletions(-) diff --git a/tests/tests/shader/workgroup_size_overrides.rs b/tests/tests/shader/workgroup_size_overrides.rs index 90f319847d..424e1201e0 100644 --- a/tests/tests/shader/workgroup_size_overrides.rs +++ b/tests/tests/shader/workgroup_size_overrides.rs @@ -1,7 +1,7 @@ use std::mem::size_of_val; use wgpu::util::DeviceExt; use wgpu::{BufferDescriptor, BufferUsages, Maintain, MapMode}; -use wgpu_test::{gpu_test, GpuTestConfiguration, TestParameters, TestingContext}; +use wgpu_test::{fail_if, gpu_test, GpuTestConfiguration, TestParameters, TestingContext}; const SHADER: &str = r#" override n = 3; @@ -19,13 +19,18 @@ const SHADER: &str = r#" static WORKGROUP_SIZE_OVERRIDES: GpuTestConfiguration = GpuTestConfiguration::new() .parameters(TestParameters::default().limits(wgpu::Limits::default())) .run_async(move |ctx| async move { - workgroup_size_overrides(&ctx, 0, &[2, 0, 0]).await; - workgroup_size_overrides(&ctx, 4, &[2, 3, 0]).await; - // Expected to fail during pipeline creation: - //workgroup_size_overrides(&ctx, 1, &[0, 0, 0]).await; + workgroup_size_overrides(&ctx, false, 0, &[2, 0, 0], false).await; + workgroup_size_overrides(&ctx, true, 4, &[2, 3, 0], false).await; + workgroup_size_overrides(&ctx, true, 1, &[0, 0, 0], true).await; }); -async fn workgroup_size_overrides(ctx: &TestingContext, n: u32, out: &[u32]) { +async fn workgroup_size_overrides( + ctx: &TestingContext, + use_override: bool, + n: u32, + out: &[u32], + should_fail: bool, +) { let module = ctx .device .create_shader_module(wgpu::ShaderModuleDescriptor { @@ -36,20 +41,28 @@ async fn workgroup_size_overrides(ctx: &TestingContext, n: u32, out: &[u32]) { constants: &[("n".to_owned(), n.into())].into(), ..Default::default() }; - let compute_pipeline = ctx - .device - .create_compute_pipeline(&wgpu::ComputePipelineDescriptor { - label: None, - layout: None, - module: &module, - entry_point: Some("main"), - compilation_options: if n == 0 { - wgpu::PipelineCompilationOptions::default() - } else { - pipeline_options - }, - cache: None, - }); + let compute_pipeline = fail_if( + &ctx.device, + should_fail, + || { + ctx.device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor { + label: None, + layout: None, + module: &module, + entry_point: Some("main"), + compilation_options: if use_override { + pipeline_options + } else { + wgpu::PipelineCompilationOptions::default() + }, + cache: None, + }) + }, + None + ); + if should_fail { + return; + } let init: &[u32] = &[0, 0, 0]; let init_size: u64 = size_of_val(init).try_into().unwrap(); let buffer = DeviceExt::create_buffer_init( From cfbda3197a2abff1d94ba20594f1b57effb0e95b Mon Sep 17 00:00:00 2001 From: Kent Slaney Date: Sun, 1 Dec 2024 14:50:25 -0800 Subject: [PATCH 06/14] linting --- .../tests/shader/workgroup_size_overrides.rs | 31 ++++++++++--------- 1 file changed, 16 insertions(+), 15 deletions(-) diff --git a/tests/tests/shader/workgroup_size_overrides.rs b/tests/tests/shader/workgroup_size_overrides.rs index 424e1201e0..038d95e02d 100644 --- a/tests/tests/shader/workgroup_size_overrides.rs +++ b/tests/tests/shader/workgroup_size_overrides.rs @@ -20,8 +20,8 @@ static WORKGROUP_SIZE_OVERRIDES: GpuTestConfiguration = GpuTestConfiguration::ne .parameters(TestParameters::default().limits(wgpu::Limits::default())) .run_async(move |ctx| async move { workgroup_size_overrides(&ctx, false, 0, &[2, 0, 0], false).await; - workgroup_size_overrides(&ctx, true, 4, &[2, 3, 0], false).await; - workgroup_size_overrides(&ctx, true, 1, &[0, 0, 0], true).await; + workgroup_size_overrides(&ctx, true, 4, &[2, 3, 0], false).await; + workgroup_size_overrides(&ctx, true, 1, &[0, 0, 0], true).await; }); async fn workgroup_size_overrides( @@ -45,20 +45,21 @@ async fn workgroup_size_overrides( &ctx.device, should_fail, || { - ctx.device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor { - label: None, - layout: None, - module: &module, - entry_point: Some("main"), - compilation_options: if use_override { - pipeline_options - } else { - wgpu::PipelineCompilationOptions::default() - }, - cache: None, - }) + ctx.device + .create_compute_pipeline(&wgpu::ComputePipelineDescriptor { + label: None, + layout: None, + module: &module, + entry_point: Some("main"), + compilation_options: if use_override { + pipeline_options + } else { + wgpu::PipelineCompilationOptions::default() + }, + cache: None, + }) }, - None + None, ); if should_fail { return; From 79c1a4917310ed84cecd28308033e1dd0ab56c16 Mon Sep 17 00:00:00 2001 From: Kent Slaney Date: Sun, 1 Dec 2024 15:27:22 -0800 Subject: [PATCH 07/14] n as an option instead of an option bool to use n --- tests/tests/shader/workgroup_size_overrides.rs | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/tests/tests/shader/workgroup_size_overrides.rs b/tests/tests/shader/workgroup_size_overrides.rs index 038d95e02d..458011079a 100644 --- a/tests/tests/shader/workgroup_size_overrides.rs +++ b/tests/tests/shader/workgroup_size_overrides.rs @@ -19,15 +19,14 @@ const SHADER: &str = r#" static WORKGROUP_SIZE_OVERRIDES: GpuTestConfiguration = GpuTestConfiguration::new() .parameters(TestParameters::default().limits(wgpu::Limits::default())) .run_async(move |ctx| async move { - workgroup_size_overrides(&ctx, false, 0, &[2, 0, 0], false).await; - workgroup_size_overrides(&ctx, true, 4, &[2, 3, 0], false).await; - workgroup_size_overrides(&ctx, true, 1, &[0, 0, 0], true).await; + workgroup_size_overrides(&ctx, None, &[2, 0, 0], false).await; + workgroup_size_overrides(&ctx, Some(4), &[2, 3, 0], false).await; + workgroup_size_overrides(&ctx, Some(1), &[0, 0, 0], true).await; }); async fn workgroup_size_overrides( ctx: &TestingContext, - use_override: bool, - n: u32, + n: Option, out: &[u32], should_fail: bool, ) { @@ -38,7 +37,7 @@ async fn workgroup_size_overrides( source: wgpu::ShaderSource::Wgsl(std::borrow::Cow::Borrowed(SHADER)), }); let pipeline_options = wgpu::PipelineCompilationOptions { - constants: &[("n".to_owned(), n.into())].into(), + constants: &[("n".to_owned(), n.unwrap_or(0).into())].into(), ..Default::default() }; let compute_pipeline = fail_if( @@ -51,7 +50,7 @@ async fn workgroup_size_overrides( layout: None, module: &module, entry_point: Some("main"), - compilation_options: if use_override { + compilation_options: if n.is_some() { pipeline_options } else { wgpu::PipelineCompilationOptions::default() From 4f5f404b78a24f7edc131122ad8a0c46a8daee6a Mon Sep 17 00:00:00 2001 From: Kent Slaney Date: Tue, 3 Dec 2024 09:16:35 -0800 Subject: [PATCH 08/14] resolve workgroup_size using eval_expr_to_u32 --- naga/src/back/pipeline_constants.rs | 36 ++++++++++------------------- 1 file changed, 12 insertions(+), 24 deletions(-) diff --git a/naga/src/back/pipeline_constants.rs b/naga/src/back/pipeline_constants.rs index 1a2dc17023..214c78f66c 100644 --- a/naga/src/back/pipeline_constants.rs +++ b/naga/src/back/pipeline_constants.rs @@ -25,7 +25,7 @@ pub enum PipelineConstantError { ConstantEvaluatorError(#[from] ConstantEvaluatorError), #[error(transparent)] ValidationError(#[from] WithSpan), - #[error("workgroup_size was overridden to a negative value")] + #[error("workgroup_size overridde isn't strictly positive")] NegativeWorkgroupSize, } @@ -218,29 +218,17 @@ fn process_workgroup_size_override( match *overridden { None => Ok(()), Some(h) => { - let c = module.constants[override_map[h]].init; - let n = &module.global_expressions[c]; - match *n { - crate::Expression::Literal(literal) => { - ep.workgroup_size[i] = match literal { - crate::Literal::U32(m) => m, - crate::Literal::I32(m) => { - if m < 0 { - Err(PipelineConstantError::NegativeWorkgroupSize)?; - unreachable!(); - } else { - m as u32 - } - } - _ => { - unreachable!(); - } - }; - } - _ => { - unreachable!(); - } - } + ep.workgroup_size[i] = module + .to_ctx() + .eval_expr_to_u32(module.constants[override_map[h]].init) + .map(|n| { + if n == 0 { + Err(PipelineConstantError::NegativeWorkgroupSize) + } else { + Ok(n) + } + }) + .map_err(|_| PipelineConstantError::NegativeWorkgroupSize)??; Ok(()) } } From 25fcfc2e2f7e74ee4bde4d7ab89b383f6a72c25f Mon Sep 17 00:00:00 2001 From: Kent Slaney Date: Tue, 3 Dec 2024 09:22:28 -0800 Subject: [PATCH 09/14] document workgroup_size_overrides' arena --- naga/src/lib.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/naga/src/lib.rs b/naga/src/lib.rs index 1c1929efa2..df93542648 100644 --- a/naga/src/lib.rs +++ b/naga/src/lib.rs @@ -2186,7 +2186,7 @@ pub struct EntryPoint { pub early_depth_test: Option, /// Workgroup size for compute stages pub workgroup_size: [u32; 3], - /// Override expressions for workgroup size + /// Override expressions in the global_expressions arena for workgroup size pub workgroup_size_overrides: Option<[Option>; 3]>, /// The entrance function. pub function: Function, From 0f9f900c21801bf5460ba9439d1d92d752003bd0 Mon Sep 17 00:00:00 2001 From: Kent Slaney Date: Tue, 3 Dec 2024 11:24:55 -0800 Subject: [PATCH 10/14] remove bad reference to drained override arena --- naga/src/back/pipeline_constants.rs | 3 ++- naga/src/valid/handles.rs | 5 +++++ tests/tests/shader/workgroup_size_overrides.rs | 3 ++- 3 files changed, 9 insertions(+), 2 deletions(-) diff --git a/naga/src/back/pipeline_constants.rs b/naga/src/back/pipeline_constants.rs index 214c78f66c..f99c27986c 100644 --- a/naga/src/back/pipeline_constants.rs +++ b/naga/src/back/pipeline_constants.rs @@ -25,7 +25,7 @@ pub enum PipelineConstantError { ConstantEvaluatorError(#[from] ConstantEvaluatorError), #[error(transparent)] ValidationError(#[from] WithSpan), - #[error("workgroup_size overridde isn't strictly positive")] + #[error("workgroup_size override isn't strictly positive")] NegativeWorkgroupSize, } @@ -234,6 +234,7 @@ fn process_workgroup_size_override( } }, )?; + ep.workgroup_size_overrides = None; } } Ok(()) diff --git a/naga/src/valid/handles.rs b/naga/src/valid/handles.rs index 680c2d3ba0..358f28ef5e 100644 --- a/naga/src/valid/handles.rs +++ b/naga/src/valid/handles.rs @@ -175,6 +175,11 @@ impl super::Validator { for entry_point in entry_points.iter() { validate_function(None, &entry_point.function)?; + if let Some(sizes) = entry_point.workgroup_size_overrides { + for size in sizes.iter().filter(|x| x.is_some()) { + Self::validate_override_handle(size.unwrap(), overrides)?; + } + } } for (function_handle, function) in functions.iter() { diff --git a/tests/tests/shader/workgroup_size_overrides.rs b/tests/tests/shader/workgroup_size_overrides.rs index 458011079a..c6fdcb1863 100644 --- a/tests/tests/shader/workgroup_size_overrides.rs +++ b/tests/tests/shader/workgroup_size_overrides.rs @@ -5,13 +5,14 @@ use wgpu_test::{fail_if, gpu_test, GpuTestConfiguration, TestParameters, Testing const SHADER: &str = r#" override n = 3; + const m = 2u; @group(0) @binding(0) var output: array; @compute @workgroup_size(n - 2) fn main(@builtin(local_invocation_index) lii: u32) { - output[lii] = lii + 2; + output[lii] = lii + m; } "#; From 3f73224cd6346201bb7291593c85a578237c6afc Mon Sep 17 00:00:00 2001 From: Kent Slaney Date: Tue, 3 Dec 2024 11:26:02 -0800 Subject: [PATCH 11/14] remove incorrect arena documentation --- naga/src/lib.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/naga/src/lib.rs b/naga/src/lib.rs index df93542648..1c1929efa2 100644 --- a/naga/src/lib.rs +++ b/naga/src/lib.rs @@ -2186,7 +2186,7 @@ pub struct EntryPoint { pub early_depth_test: Option, /// Workgroup size for compute stages pub workgroup_size: [u32; 3], - /// Override expressions in the global_expressions arena for workgroup size + /// Override expressions for workgroup size pub workgroup_size_overrides: Option<[Option>; 3]>, /// The entrance function. pub function: Function, From 0d4a88cd3246852bd00675cb37f4be7fe6bed22f Mon Sep 17 00:00:00 2001 From: Kent Slaney Date: Tue, 3 Dec 2024 11:35:01 -0800 Subject: [PATCH 12/14] remove test debugging statement --- tests/tests/shader/workgroup_size_overrides.rs | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tests/tests/shader/workgroup_size_overrides.rs b/tests/tests/shader/workgroup_size_overrides.rs index c6fdcb1863..458011079a 100644 --- a/tests/tests/shader/workgroup_size_overrides.rs +++ b/tests/tests/shader/workgroup_size_overrides.rs @@ -5,14 +5,13 @@ use wgpu_test::{fail_if, gpu_test, GpuTestConfiguration, TestParameters, Testing const SHADER: &str = r#" override n = 3; - const m = 2u; @group(0) @binding(0) var output: array; @compute @workgroup_size(n - 2) fn main(@builtin(local_invocation_index) lii: u32) { - output[lii] = lii + m; + output[lii] = lii + 2; } "#; From 9f8f1efab1bb3558d493f4c6b93d3bb8f72f3671 Mon Sep 17 00:00:00 2001 From: Kent Slaney Date: Thu, 5 Dec 2024 08:54:36 -0800 Subject: [PATCH 13/14] use global_expressions directly instead of creating a faux override --- naga/src/back/pipeline_constants.rs | 6 +++--- naga/src/compact/mod.rs | 20 ++++++++++++++++++++ naga/src/front/wgsl/lower/mod.rs | 19 +++---------------- naga/src/lib.rs | 2 +- naga/src/valid/handles.rs | 4 ++-- 5 files changed, 29 insertions(+), 22 deletions(-) diff --git a/naga/src/back/pipeline_constants.rs b/naga/src/back/pipeline_constants.rs index f99c27986c..b7a80cfbf7 100644 --- a/naga/src/back/pipeline_constants.rs +++ b/naga/src/back/pipeline_constants.rs @@ -192,7 +192,7 @@ pub fn process_overrides<'a>( let mut entry_points = mem::take(&mut module.entry_points); for ep in entry_points.iter_mut() { process_function(&mut module, &override_map, &mut ep.function)?; - process_workgroup_size_override(&mut module, &override_map, ep)?; + process_workgroup_size_override(&mut module, &adjusted_global_expressions, ep)?; } module.entry_points = entry_points; @@ -207,7 +207,7 @@ pub fn process_overrides<'a>( fn process_workgroup_size_override( module: &mut Module, - override_map: &HandleVec>, + adjusted_global_expressions: &HandleVec>, ep: &mut crate::EntryPoint, ) -> Result<(), PipelineConstantError> { match ep.workgroup_size_overrides { @@ -220,7 +220,7 @@ fn process_workgroup_size_override( Some(h) => { ep.workgroup_size[i] = module .to_ctx() - .eval_expr_to_u32(module.constants[override_map[h]].init) + .eval_expr_to_u32(adjusted_global_expressions[h]) .map(|n| { if n == 0 { Err(PipelineConstantError::NegativeWorkgroupSize) diff --git a/naga/src/compact/mod.rs b/naga/src/compact/mod.rs index a9fc7bc945..fcd35de380 100644 --- a/naga/src/compact/mod.rs +++ b/naga/src/compact/mod.rs @@ -63,6 +63,14 @@ pub fn compact(module: &mut crate::Module) { } } + for e in module.entry_points.iter() { + if let Some(sizes) = e.workgroup_size_overrides { + for size in sizes.iter().filter_map(|x| *x) { + module_tracer.global_expressions_used.insert(size); + } + } + } + // We assume that all functions are used. // // Observe which types, constant expressions, constants, and @@ -176,6 +184,18 @@ pub fn compact(module: &mut crate::Module) { } } + // Adjust workgroup_size_overrides + log::trace!("adjusting workgroup_size_overrides"); + for e in module.entry_points.iter_mut() { + if let Some(sizes) = e.workgroup_size_overrides.as_mut() { + for size in sizes.iter_mut() { + if let Some(expr) = size.as_mut() { + module_map.global_expressions.adjust(expr); + } + } + } + } + // Adjust global variables' types and initializers. log::trace!("adjusting global variables"); for (_, global) in module.global_variables.iter_mut() { diff --git a/naga/src/front/wgsl/lower/mod.rs b/naga/src/front/wgsl/lower/mod.rs index 7de6fc945b..f221ff97c6 100644 --- a/naga/src/front/wgsl/lower/mod.rs +++ b/naga/src/front/wgsl/lower/mod.rs @@ -1329,7 +1329,6 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { Some(self.workgroup_size_override( size_expr, &mut ctx.as_override(), - i, )?); } _ => { @@ -1372,23 +1371,11 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { &mut self, size_expr: Handle>, ctx: &mut ExpressionContext<'source, '_, '_>, - i: usize, - ) -> Result, Error<'source>> { + ) -> Result, Error<'source>> { let span = ctx.ast_expressions.get_span(size_expr); let expr = self.expression(size_expr, ctx)?; - let ty = ctx.register_type(expr)?; - match ctx.module.types[ty].inner.scalar_kind().ok_or(0) { - Ok(crate::ScalarKind::Sint) | Ok(crate::ScalarKind::Uint) => Ok({ - ctx.module.overrides.append( - crate::Override { - name: Some(format!("__workgroup_size_{}", i)), - id: None, - ty, - init: Some(expr), - }, - span, - ) - }), + match resolve_inner!(ctx, expr).scalar_kind().ok_or(0) { + Ok(crate::ScalarKind::Sint) | Ok(crate::ScalarKind::Uint) => Ok(expr), _ => Err(Error::ExpectedConstExprConcreteIntegerScalar(span)), } } diff --git a/naga/src/lib.rs b/naga/src/lib.rs index 1c1929efa2..0ddace7b48 100644 --- a/naga/src/lib.rs +++ b/naga/src/lib.rs @@ -2187,7 +2187,7 @@ pub struct EntryPoint { /// Workgroup size for compute stages pub workgroup_size: [u32; 3], /// Override expressions for workgroup size - pub workgroup_size_overrides: Option<[Option>; 3]>, + pub workgroup_size_overrides: Option<[Option>; 3]>, /// The entrance function. pub function: Function, } diff --git a/naga/src/valid/handles.rs b/naga/src/valid/handles.rs index 358f28ef5e..be4eb3dbac 100644 --- a/naga/src/valid/handles.rs +++ b/naga/src/valid/handles.rs @@ -176,8 +176,8 @@ impl super::Validator { for entry_point in entry_points.iter() { validate_function(None, &entry_point.function)?; if let Some(sizes) = entry_point.workgroup_size_overrides { - for size in sizes.iter().filter(|x| x.is_some()) { - Self::validate_override_handle(size.unwrap(), overrides)?; + for size in sizes.iter().filter_map(|x| *x) { + validate_const_expr(size)?; } } } From c3df8a1532e699bef792cd804db593edde377b13 Mon Sep 17 00:00:00 2001 From: Kent Slaney Date: Thu, 5 Dec 2024 08:56:54 -0800 Subject: [PATCH 14/14] document workgroup_size_overrides' arena --- naga/src/lib.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/naga/src/lib.rs b/naga/src/lib.rs index 0ddace7b48..6a7ab4bdff 100644 --- a/naga/src/lib.rs +++ b/naga/src/lib.rs @@ -2186,7 +2186,7 @@ pub struct EntryPoint { pub early_depth_test: Option, /// Workgroup size for compute stages pub workgroup_size: [u32; 3], - /// Override expressions for workgroup size + /// Override expressions for workgroup size in the global_expressions arena pub workgroup_size_overrides: Option<[Option>; 3]>, /// The entrance function. pub function: Function,