diff --git a/CHANGELOG.md b/CHANGELOG.md index b763d71ef0..9086d982b1 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -116,6 +116,7 @@ By @ErichDonGubler in [#6456](https://github.com/gfx-rs/wgpu/pull/6456), [#6148] - Allow for override-expressions in `workgroup_size`. By @KentSlaney in [#6635](https://github.com/gfx-rs/wgpu/pull/6635). - Add support for OpAtomicCompareExchange in SPIR-V frontend. By @schell in [#6590](https://github.com/gfx-rs/wgpu/pull/6590). - Implement type inference for abstract arguments to user-defined functions. By @jamienicol in [#6577](https://github.com/gfx-rs/wgpu/pull/6577). +- Allow for override-expressions in array sizes. By @KentSlaney in [#6654](https://github.com/gfx-rs/wgpu/pull/6654). #### General diff --git a/naga/src/back/glsl/mod.rs b/naga/src/back/glsl/mod.rs index e8a5a1d6ad..4cd60fc3cc 100644 --- a/naga/src/back/glsl/mod.rs +++ b/naga/src/back/glsl/mod.rs @@ -978,6 +978,7 @@ impl<'a, W: Write> Writer<'a, W> { crate::ArraySize::Constant(size) => { write!(self.out, "{size}")?; } + crate::ArraySize::Pending(_) => unreachable!(), crate::ArraySize::Dynamic => (), } @@ -4459,6 +4460,7 @@ impl<'a, W: Write> Writer<'a, W> { .expect("Bad array size") { proc::IndexableLength::Known(count) => count, + proc::IndexableLength::Pending => unreachable!(), proc::IndexableLength::Dynamic => return Ok(()), }; self.write_type(base)?; diff --git a/naga/src/back/hlsl/conv.rs b/naga/src/back/hlsl/conv.rs index 5345f0a4d6..83c7667eab 100644 --- a/naga/src/back/hlsl/conv.rs +++ b/naga/src/back/hlsl/conv.rs @@ -68,6 +68,7 @@ impl crate::TypeInner { let count = match size { crate::ArraySize::Constant(size) => size.get(), // A dynamically-sized array has to have at least one element + crate::ArraySize::Pending(_) => unreachable!(), crate::ArraySize::Dynamic => 1, }; let last_el_size = gctx.types[base].inner.size_hlsl(gctx); diff --git a/naga/src/back/hlsl/writer.rs b/naga/src/back/hlsl/writer.rs index 236a7bc796..bc6086d539 100644 --- a/naga/src/back/hlsl/writer.rs +++ b/naga/src/back/hlsl/writer.rs @@ -988,6 +988,7 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> { crate::ArraySize::Constant(size) => { write!(self.out, "{size}")?; } + crate::ArraySize::Pending(_) => unreachable!(), crate::ArraySize::Dynamic => unreachable!(), } @@ -2634,6 +2635,7 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> { index::IndexableLength::Known(limit) => { write!(self.out, "{}u", limit - 1)?; } + index::IndexableLength::Pending => unreachable!(), index::IndexableLength::Dynamic => unreachable!(), } write!(self.out, ")")?; diff --git a/naga/src/back/msl/writer.rs b/naga/src/back/msl/writer.rs index c0916dc796..c119823800 100644 --- a/naga/src/back/msl/writer.rs +++ b/naga/src/back/msl/writer.rs @@ -2407,6 +2407,7 @@ impl Writer { self.out.write_str(") < ")?; match length { index::IndexableLength::Known(value) => write!(self.out, "{value}")?, + index::IndexableLength::Pending => unreachable!(), index::IndexableLength::Dynamic => { let global = context.function.originating_global(base).ok_or_else(|| { @@ -2569,6 +2570,7 @@ impl Writer { index::IndexableLength::Known(limit) => { write!(self.out, "{}u", limit - 1)?; } + index::IndexableLength::Pending => unreachable!(), index::IndexableLength::Dynamic => { let global = context.function.originating_global(base).ok_or_else(|| { Error::GenericValidation("Could not find originating global".into()) @@ -3740,6 +3742,9 @@ impl Writer { )?; writeln!(self.out, "}};")?; } + crate::ArraySize::Pending(_) => { + unreachable!() + } crate::ArraySize::Dynamic => { writeln!(self.out, "typedef {base_name} {name}[1];")?; } @@ -6008,6 +6013,7 @@ mod workgroup_mem_init { let count = match size.to_indexable_length(module).expect("Bad array size") { proc::IndexableLength::Known(count) => count, + proc::IndexableLength::Pending => unreachable!(), proc::IndexableLength::Dynamic => unreachable!(), }; diff --git a/naga/src/back/pipeline_constants.rs b/naga/src/back/pipeline_constants.rs index b7a80cfbf7..eb01dd5feb 100644 --- a/naga/src/back/pipeline_constants.rs +++ b/naga/src/back/pipeline_constants.rs @@ -196,6 +196,8 @@ pub fn process_overrides<'a>( } module.entry_points = entry_points; + process_pending(&mut module, &override_map, &adjusted_global_expressions)?; + // Now that we've rewritten all the expressions, we need to // recompute their types and other metadata. For the time being, // do a full re-validation. @@ -205,6 +207,64 @@ pub fn process_overrides<'a>( Ok((Cow::Owned(module), Cow::Owned(module_info))) } +fn process_pending( + module: &mut Module, + override_map: &HandleVec>, + adjusted_global_expressions: &HandleVec>, +) -> Result<(), PipelineConstantError> { + for (handle, ty) in module.types.clone().iter() { + if let crate::TypeInner::Array { + base, + size: crate::ArraySize::Pending(size), + stride, + } = ty.inner + { + let expr = match size { + crate::PendingArraySize::Expression(size_expr) => { + adjusted_global_expressions[size_expr] + } + crate::PendingArraySize::Override(size_override) => { + module.constants[override_map[size_override]].init + } + }; + let value = module + .to_ctx() + .eval_expr_to_u32(expr) + .map(|n| { + if n == 0 { + Err(PipelineConstantError::ValidationError( + WithSpan::new(ValidationError::ArraySizeError { handle: expr }) + .with_span( + module.global_expressions.get_span(expr), + "evaluated to zero", + ), + )) + } else { + Ok(std::num::NonZeroU32::new(n).unwrap()) + } + }) + .map_err(|_| { + PipelineConstantError::ValidationError( + WithSpan::new(ValidationError::ArraySizeError { handle: expr }) + .with_span(module.global_expressions.get_span(expr), "negative"), + ) + })??; + module.types.replace( + handle, + crate::Type { + name: None, + inner: crate::TypeInner::Array { + base, + size: crate::ArraySize::Constant(value), + stride, + }, + }, + ); + } + } + Ok(()) +} + fn process_workgroup_size_override( module: &mut Module, adjusted_global_expressions: &HandleVec>, diff --git a/naga/src/back/spv/index.rs b/naga/src/back/spv/index.rs index bd91aa4025..15e5df3f10 100644 --- a/naga/src/back/spv/index.rs +++ b/naga/src/back/spv/index.rs @@ -271,6 +271,9 @@ impl BlockContext<'_> { Ok(crate::proc::IndexableLength::Known(known_length)) => { Ok(MaybeKnown::Known(known_length)) } + Ok(crate::proc::IndexableLength::Pending) => { + unreachable!() + } Ok(crate::proc::IndexableLength::Dynamic) => { let length_id = self.write_runtime_array_length(sequence, block)?; Ok(MaybeKnown::Computed(length_id)) diff --git a/naga/src/back/spv/writer.rs b/naga/src/back/spv/writer.rs index b3dd145321..47f3ec513b 100644 --- a/naga/src/back/spv/writer.rs +++ b/naga/src/back/spv/writer.rs @@ -971,6 +971,7 @@ impl Writer { let length_id = self.get_index_constant(length.get()); Instruction::type_array(id, type_id, length_id) } + crate::ArraySize::Pending(_) => unreachable!(), crate::ArraySize::Dynamic => Instruction::type_runtime_array(id, type_id), } } @@ -981,6 +982,7 @@ impl Writer { let length_id = self.get_index_constant(length.get()); Instruction::type_array(id, type_id, length_id) } + crate::ArraySize::Pending(_) => unreachable!(), crate::ArraySize::Dynamic => Instruction::type_runtime_array(id, type_id), } } diff --git a/naga/src/back/wgsl/writer.rs b/naga/src/back/wgsl/writer.rs index 6e7ac0bf5c..ed581c59e2 100644 --- a/naga/src/back/wgsl/writer.rs +++ b/naga/src/back/wgsl/writer.rs @@ -520,6 +520,9 @@ impl Writer { self.write_type(module, base)?; write!(self.out, ", {len}")?; } + crate::ArraySize::Pending(_) => { + unreachable!(); + } crate::ArraySize::Dynamic => { self.write_type(module, base)?; } @@ -534,6 +537,9 @@ impl Writer { self.write_type(module, base)?; write!(self.out, ", {len}")?; } + crate::ArraySize::Pending(_) => { + unreachable!(); + } crate::ArraySize::Dynamic => { self.write_type(module, base)?; } diff --git a/naga/src/compact/mod.rs b/naga/src/compact/mod.rs index fcd35de380..9dff4a6cc2 100644 --- a/naga/src/compact/mod.rs +++ b/naga/src/compact/mod.rs @@ -63,6 +63,16 @@ pub fn compact(module: &mut crate::Module) { } } + for (_, ty) in module.types.iter() { + if let crate::TypeInner::Array { + size: crate::ArraySize::Pending(crate::PendingArraySize::Expression(size_expr)), + .. + } = ty.inner + { + module_tracer.global_expressions_used.insert(size_expr); + } + } + for e in module.entry_points.iter() { if let Some(sizes) = e.workgroup_size_overrides { for size in sizes.iter().filter_map(|x| *x) { @@ -206,6 +216,30 @@ pub fn compact(module: &mut crate::Module) { } } + for (handle, ty) in module.types.clone().iter() { + if let crate::TypeInner::Array { + base, + size: crate::ArraySize::Pending(crate::PendingArraySize::Expression(mut size_expr)), + stride, + } = ty.inner + { + module_map.global_expressions.adjust(&mut size_expr); + module.types.replace( + handle, + crate::Type { + name: None, + inner: crate::TypeInner::Array { + base, + size: crate::ArraySize::Pending(crate::PendingArraySize::Expression( + size_expr, + )), + stride, + }, + }, + ); + } + } + // Temporary storage to help us reuse allocations of existing // named expression tables. let mut reused_named_expressions = crate::NamedExpressions::default(); diff --git a/naga/src/front/glsl/offset.rs b/naga/src/front/glsl/offset.rs index c88c46598d..6e8d5ada10 100644 --- a/naga/src/front/glsl/offset.rs +++ b/naga/src/front/glsl/offset.rs @@ -84,6 +84,7 @@ pub fn calculate_offset( let span = match size { crate::ArraySize::Constant(size) => size.get() * stride, + crate::ArraySize::Pending(_) => unreachable!(), crate::ArraySize::Dynamic => stride, }; diff --git a/naga/src/front/spv/mod.rs b/naga/src/front/spv/mod.rs index 63406d3220..91938d69fb 100644 --- a/naga/src/front/spv/mod.rs +++ b/naga/src/front/spv/mod.rs @@ -1329,6 +1329,9 @@ impl> Frontend { crate::TypeInner::Array { size, .. } => { let size = match size { crate::ArraySize::Constant(size) => size.get(), + crate::ArraySize::Pending(_) => { + unreachable!(); + } // A runtime sized array is not a composite type crate::ArraySize::Dynamic => { return Err(Error::InvalidAccessType(root_type_id)) diff --git a/naga/src/front/wgsl/lower/mod.rs b/naga/src/front/wgsl/lower/mod.rs index 95a4902d16..fc31e43ecf 100644 --- a/naga/src/front/wgsl/lower/mod.rs +++ b/naga/src/front/wgsl/lower/mod.rs @@ -3056,26 +3056,69 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { Ok(match size { ast::ArraySize::Constant(expr) => { let span = ctx.ast_expressions.get_span(expr); - let const_expr = self.expression(expr, &mut ctx.as_const())?; - let len = - ctx.module - .to_ctx() - .eval_expr_to_u32(const_expr) - .map_err(|err| match err { - crate::proc::U32EvalError::NonConst => { - Error::ExpectedConstExprConcreteIntegerScalar(span) - } - crate::proc::U32EvalError::Negative => { - Error::ExpectedPositiveArrayLength(span) + let const_expr = self.expression(expr, &mut ctx.as_const()); + match const_expr { + Ok(value) => { + let len = + ctx.module.to_ctx().eval_expr_to_u32(value).map_err( + |err| match err { + crate::proc::U32EvalError::NonConst => { + Error::ExpectedConstExprConcreteIntegerScalar(span) + } + crate::proc::U32EvalError::Negative => { + Error::ExpectedPositiveArrayLength(span) + } + }, + )?; + let size = + NonZeroU32::new(len).ok_or(Error::ExpectedPositiveArrayLength(span))?; + crate::ArraySize::Constant(size) + } + err => { + if let Err(Error::ConstantEvaluatorError(ref ty, _)) = err { + match **ty { + crate::proc::ConstantEvaluatorError::OverrideExpr => { + crate::ArraySize::Pending(self.array_size_override( + expr, + &mut ctx.as_override(), + span, + )?) + } + _ => { + err?; + unreachable!() + } } - })?; - let size = NonZeroU32::new(len).ok_or(Error::ExpectedPositiveArrayLength(span))?; - crate::ArraySize::Constant(size) + } else { + err?; + unreachable!() + } + } + } } ast::ArraySize::Dynamic => crate::ArraySize::Dynamic, }) } + fn array_size_override( + &mut self, + size_expr: Handle>, + ctx: &mut ExpressionContext<'source, '_, '_>, + span: Span, + ) -> Result> { + let expr = self.expression(size_expr, ctx)?; + match resolve_inner!(ctx, expr).scalar_kind().ok_or(0) { + Ok(crate::ScalarKind::Sint) | Ok(crate::ScalarKind::Uint) => Ok({ + if let crate::Expression::Override(handle) = ctx.module.global_expressions[expr] { + crate::PendingArraySize::Override(handle) + } else { + crate::PendingArraySize::Expression(expr) + } + }), + _ => Err(Error::ExpectedConstExprConcreteIntegerScalar(span)), + } + } + /// Build the Naga equivalent of a named AST type. /// /// Return a Naga `Handle` representing the front-end type diff --git a/naga/src/front/wgsl/parse/mod.rs b/naga/src/front/wgsl/parse/mod.rs index b9a08cc41e..0233347c36 100644 --- a/naga/src/front/wgsl/parse/mod.rs +++ b/naga/src/front/wgsl/parse/mod.rs @@ -134,7 +134,7 @@ impl<'a> ExpressionContext<'a, '_, '_> { /// This is used for error checking. `Parser` maintains a stack of /// these and (occasionally) checks that it is being pushed and popped /// as expected. -#[derive(Clone, Debug, PartialEq)] +#[derive(Copy, Clone, Debug, PartialEq)] enum Rule { Attribute, VariableDecl, @@ -147,6 +147,8 @@ enum Rule { UnaryExpr, GeneralExpr, Directive, + GenericExpr, + EnclosedExpr, } struct ParsedAttribute { @@ -284,6 +286,16 @@ impl Parser { lexer.span_from(initial) } + fn race_rules(&self, rule0: Rule, rule1: Rule) -> Option { + Some( + self.rules + .iter() + .rev() + .find(|&x| x.0 == rule0 || x.0 == rule1)? + .0, + ) + } + fn switch_value<'a>( &mut self, lexer: &mut Lexer<'a>, @@ -547,7 +559,7 @@ impl Parser { lexer.expect_generic_paren('<')?; let base = self.type_decl(lexer, ctx)?; let size = if lexer.skip(Token::Separator(',')) { - let expr = self.unary_expression(lexer, ctx)?; + let expr = self.const_generic_expression(lexer, ctx)?; ast::ArraySize::Constant(expr) } else { ast::ArraySize::Dynamic @@ -566,6 +578,7 @@ impl Parser { lexer: &mut Lexer<'a>, ctx: &mut ExpressionContext<'a, '_, '_>, ) -> Result>>, Error<'a>> { + self.push_rule_span(Rule::EnclosedExpr, lexer); lexer.open_arguments()?; let mut arguments = Vec::new(); loop { @@ -580,9 +593,21 @@ impl Parser { arguments.push(arg); } + self.pop_rule_span(lexer); Ok(arguments) } + fn enclosed_expression<'a>( + &mut self, + lexer: &mut Lexer<'a>, + ctx: &mut ExpressionContext<'a, '_, '_>, + ) -> Result>, Error<'a>> { + self.push_rule_span(Rule::EnclosedExpr, lexer); + let expr = self.general_expression(lexer, ctx)?; + self.pop_rule_span(lexer); + Ok(expr) + } + /// Expects [`Rule::PrimaryExpr`] or [`Rule::SingularExpr`] on top; does not pop it. /// Expects `name` to be consumed (not in lexer). fn function_call<'a>( @@ -667,7 +692,7 @@ impl Parser { let expr = match lexer.peek() { (Token::Paren('('), _) => { let _ = lexer.next(); - let expr = self.general_expression(lexer, ctx)?; + let expr = self.enclosed_expression(lexer, ctx)?; lexer.expect(Token::Paren(')'))?; self.pop_rule_span(lexer); return Ok(expr); @@ -803,7 +828,7 @@ impl Parser { } Token::Paren('[') => { let _ = lexer.next(); - let index = self.general_expression(lexer, ctx)?; + let index = self.enclosed_expression(lexer, ctx)?; lexer.expect(Token::Paren(']'))?; ast::Expression::Index { base: expr, index } @@ -818,6 +843,17 @@ impl Parser { Ok(expr) } + fn const_generic_expression<'a>( + &mut self, + lexer: &mut Lexer<'a>, + ctx: &mut ExpressionContext<'a, '_, '_>, + ) -> Result>, Error<'a>> { + self.push_rule_span(Rule::GenericExpr, lexer); + let expr = self.general_expression(lexer, ctx)?; + self.pop_rule_span(lexer); + Ok(expr) + } + /// Parse a `unary_expression`. fn unary_expression<'a>( &mut self, @@ -908,27 +944,47 @@ impl Parser { }, // relational_expression |lexer, context| { + let enclosing = self.race_rules(Rule::GenericExpr, Rule::EnclosedExpr); context.parse_binary_op( lexer, - |token| match token { - Token::Paren('<') => Some(crate::BinaryOperator::Less), - Token::Paren('>') => Some(crate::BinaryOperator::Greater), - Token::LogicalOperation('<') => Some(crate::BinaryOperator::LessEqual), - Token::LogicalOperation('>') => Some(crate::BinaryOperator::GreaterEqual), - _ => None, + match enclosing { + Some(Rule::GenericExpr) => |token| match token { + Token::LogicalOperation('<') => Some(crate::BinaryOperator::LessEqual), + Token::LogicalOperation('>') => { + Some(crate::BinaryOperator::GreaterEqual) + } + _ => None, + }, + _ => |token| match token { + Token::Paren('<') => Some(crate::BinaryOperator::Less), + Token::Paren('>') => Some(crate::BinaryOperator::Greater), + Token::LogicalOperation('<') => Some(crate::BinaryOperator::LessEqual), + Token::LogicalOperation('>') => { + Some(crate::BinaryOperator::GreaterEqual) + } + _ => None, + }, }, // shift_expression |lexer, context| { context.parse_binary_op( lexer, - |token| match token { - Token::ShiftOperation('<') => { - Some(crate::BinaryOperator::ShiftLeft) - } - Token::ShiftOperation('>') => { - Some(crate::BinaryOperator::ShiftRight) - } - _ => None, + match enclosing { + Some(Rule::GenericExpr) => |token| match token { + Token::ShiftOperation('<') => { + Some(crate::BinaryOperator::ShiftLeft) + } + _ => None, + }, + _ => |token| match token { + Token::ShiftOperation('<') => { + Some(crate::BinaryOperator::ShiftLeft) + } + Token::ShiftOperation('>') => { + Some(crate::BinaryOperator::ShiftRight) + } + _ => None, + }, }, // additive_expression |lexer, context| { @@ -1364,7 +1420,7 @@ impl Parser { lexer.expect_generic_paren('<')?; let base = self.type_decl(lexer, ctx)?; let size = if lexer.skip(Token::Separator(',')) { - let size = self.unary_expression(lexer, ctx)?; + let size = self.const_generic_expression(lexer, ctx)?; ast::ArraySize::Constant(size) } else { ast::ArraySize::Dynamic diff --git a/naga/src/front/wgsl/to_wgsl.rs b/naga/src/front/wgsl/to_wgsl.rs index 0884e0003b..4d401b0708 100644 --- a/naga/src/front/wgsl/to_wgsl.rs +++ b/naga/src/front/wgsl/to_wgsl.rs @@ -67,6 +67,7 @@ impl crate::TypeInner { let base = base.to_wgsl(gctx); match size { crate::ArraySize::Constant(size) => format!("array<{base}, {size}>"), + crate::ArraySize::Pending(_) => unreachable!(), crate::ArraySize::Dynamic => format!("array<{base}>"), } } @@ -123,6 +124,7 @@ impl crate::TypeInner { let base = member_type.name.as_deref().unwrap_or("unknown"); match size { crate::ArraySize::Constant(size) => format!("binding_array<{base}, {size}>"), + crate::ArraySize::Pending(_) => unreachable!(), crate::ArraySize::Dynamic => format!("binding_array<{base}>"), } } diff --git a/naga/src/lib.rs b/naga/src/lib.rs index cfed0c38fd..687dc5b441 100644 --- a/naga/src/lib.rs +++ b/naga/src/lib.rs @@ -487,6 +487,15 @@ pub struct Scalar { pub width: Bytes, } +#[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)] +#[cfg_attr(feature = "serialize", derive(Serialize))] +#[cfg_attr(feature = "deserialize", derive(Deserialize))] +#[cfg_attr(feature = "arbitrary", derive(Arbitrary))] +pub enum PendingArraySize { + Expression(Handle), + Override(Handle), +} + /// Size of an array. #[repr(u8)] #[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)] @@ -496,6 +505,8 @@ pub struct Scalar { pub enum ArraySize { /// The array size is constant. Constant(std::num::NonZeroU32), + /// The array size is an override-expression. + Pending(PendingArraySize), /// The array size can change at runtime. Dynamic, } diff --git a/naga/src/proc/constant_evaluator.rs b/naga/src/proc/constant_evaluator.rs index 585174384c..2baf918118 100644 --- a/naga/src/proc/constant_evaluator.rs +++ b/naga/src/proc/constant_evaluator.rs @@ -509,6 +509,8 @@ pub enum ConstantEvaluatorError { InvalidArrayLengthArg, #[error("Constants cannot get the array length of a dynamically sized array")] ArrayLengthDynamic, + #[error("Cannot call arrayLength on array sized by override-expression")] + ArrayLengthOverridden, #[error("Constants cannot call functions")] Call, #[error("Constants don't support workGroupUniformLoad")] @@ -1311,6 +1313,7 @@ impl<'a> ConstantEvaluator<'a> { let expr = Expression::Literal(Literal::U32(len.get())); self.register_evaluated_expr(expr, span) } + ArraySize::Pending(_) => Err(ConstantEvaluatorError::ArrayLengthOverridden), ArraySize::Dynamic => Err(ConstantEvaluatorError::ArrayLengthDynamic), }, _ => Err(ConstantEvaluatorError::InvalidArrayLengthArg), diff --git a/naga/src/proc/index.rs b/naga/src/proc/index.rs index 77fdf92d0e..ac2a6589c1 100644 --- a/naga/src/proc/index.rs +++ b/naga/src/proc/index.rs @@ -416,6 +416,8 @@ pub enum IndexableLength { /// Values of this type always have the given number of elements. Known(u32), + Pending, + /// The number of elements is determined at runtime. Dynamic, } @@ -427,6 +429,7 @@ impl crate::ArraySize { ) -> Result { Ok(match self { Self::Constant(length) => IndexableLength::Known(length.get()), + Self::Pending(_) => IndexableLength::Pending, Self::Dynamic => IndexableLength::Dynamic, }) } diff --git a/naga/src/proc/mod.rs b/naga/src/proc/mod.rs index b6f7a55ef9..76698fd102 100644 --- a/naga/src/proc/mod.rs +++ b/naga/src/proc/mod.rs @@ -297,6 +297,9 @@ impl super::TypeInner { } => { let count = match size { super::ArraySize::Constant(count) => count.get(), + // any struct member or array element needing a size at pipeline-creation time + // must have a creation-fixed footprint + super::ArraySize::Pending(_) => 0, // A dynamically-sized array has to have at least one element super::ArraySize::Dynamic => 1, }; diff --git a/naga/src/valid/interface.rs b/naga/src/valid/interface.rs index ce9bc12a00..335826d12c 100644 --- a/naga/src/valid/interface.rs +++ b/naga/src/valid/interface.rs @@ -497,7 +497,10 @@ impl super::Validator { if access == crate::StorageAccess::STORE { return Err(GlobalVariableError::StorageAddressSpaceWriteOnlyNotSupported); } - (TypeFlags::DATA | TypeFlags::HOST_SHAREABLE, true) + ( + TypeFlags::DATA | TypeFlags::HOST_SHAREABLE | TypeFlags::CREATION_RESOLVED, + true, + ) } crate::AddressSpace::Uniform => { if let Err((ty_handle, disalignment)) = type_info.uniform_layout { @@ -513,7 +516,8 @@ impl super::Validator { TypeFlags::DATA | TypeFlags::COPY | TypeFlags::SIZED - | TypeFlags::HOST_SHAREABLE, + | TypeFlags::HOST_SHAREABLE + | TypeFlags::CREATION_RESOLVED, true, ) } @@ -551,7 +555,10 @@ impl super::Validator { (TypeFlags::empty(), true) } - crate::AddressSpace::Private => (TypeFlags::CONSTRUCTIBLE, false), + crate::AddressSpace::Private => ( + TypeFlags::CONSTRUCTIBLE | TypeFlags::CREATION_RESOLVED, + false, + ), crate::AddressSpace::WorkGroup => (TypeFlags::DATA | TypeFlags::SIZED, false), crate::AddressSpace::PushConstant => { if !self.capabilities.contains(Capabilities::PUSH_CONSTANT) { diff --git a/naga/src/valid/mod.rs b/naga/src/valid/mod.rs index c314ec2ac8..6a81bd7c2d 100644 --- a/naga/src/valid/mod.rs +++ b/naga/src/valid/mod.rs @@ -332,6 +332,8 @@ pub enum ValidationError { handle: Handle, source: ConstExpressionError, }, + #[error("Array size expression {handle:?} is not strictly positive")] + ArraySizeError { handle: Handle }, #[error("Constant {handle:?} '{name}' is invalid")] Constant { handle: Handle, @@ -612,6 +614,20 @@ impl Validator { } .with_span_handle(handle, &module.types) })?; + if !self.allow_overrides { + if let crate::TypeInner::Array { + size: crate::ArraySize::Pending(_), + .. + } = ty.inner + { + return Err((ValidationError::Type { + handle, + name: ty.name.clone().unwrap_or_default(), + source: TypeError::UnresolvedOverride(handle), + }) + .with_span_handle(handle, &module.types)); + } + } mod_info.type_flags.push(ty_info.flags); self.types[handle.index()] = ty_info; } diff --git a/naga/src/valid/type.rs b/naga/src/valid/type.rs index c0c25dab79..35158b8013 100644 --- a/naga/src/valid/type.rs +++ b/naga/src/valid/type.rs @@ -54,6 +54,10 @@ bitflags::bitflags! { /// Can be used for host-shareable structures. const HOST_SHAREABLE = 0x10; + /// The set of types with a fixed size at shader-creation time (ie. everything + /// except arrays sized by an override-expression) + const CREATION_RESOLVED = 0x20; + /// This type can be passed as a function argument. const ARGUMENT = 0x40; @@ -142,6 +146,10 @@ pub enum TypeError { EmptyStruct, #[error(transparent)] WidthError(#[from] WidthError), + #[error( + "The base handle {0:?} has an override-expression that didn't get resolved to a constant" + )] + UnresolvedOverride(Handle), } #[derive(Clone, Debug, thiserror::Error)] @@ -319,6 +327,7 @@ impl super::Validator { | TypeFlags::COPY | TypeFlags::ARGUMENT | TypeFlags::CONSTRUCTIBLE + | TypeFlags::CREATION_RESOLVED | shareable, Alignment::from_width(scalar.width), ) @@ -336,6 +345,7 @@ impl super::Validator { | TypeFlags::COPY | TypeFlags::ARGUMENT | TypeFlags::CONSTRUCTIBLE + | TypeFlags::CREATION_RESOLVED | shareable, Alignment::from(size) * Alignment::from_width(scalar.width), ) @@ -355,7 +365,8 @@ impl super::Validator { | TypeFlags::COPY | TypeFlags::HOST_SHAREABLE | TypeFlags::ARGUMENT - | TypeFlags::CONSTRUCTIBLE, + | TypeFlags::CONSTRUCTIBLE + | TypeFlags::CREATION_RESOLVED, Alignment::from(rows) * Alignment::from_width(scalar.width), ) } @@ -383,7 +394,10 @@ impl super::Validator { } }; TypeInfo::new( - TypeFlags::DATA | TypeFlags::SIZED | TypeFlags::HOST_SHAREABLE, + TypeFlags::DATA + | TypeFlags::SIZED + | TypeFlags::HOST_SHAREABLE + | TypeFlags::CREATION_RESOLVED, Alignment::from_width(width), ) } @@ -424,7 +438,10 @@ impl super::Validator { // Pointers cannot be stored in variables, structure members, or // array elements, so we do not mark them as `DATA`. TypeInfo::new( - argument_flag | TypeFlags::SIZED | TypeFlags::COPY, + argument_flag + | TypeFlags::SIZED + | TypeFlags::COPY + | TypeFlags::CREATION_RESOLVED, Alignment::ONE, ) } @@ -451,13 +468,19 @@ impl super::Validator { // Pointers cannot be stored in variables, structure members, or // array elements, so we do not mark them as `DATA`. TypeInfo::new( - argument_flag | TypeFlags::SIZED | TypeFlags::COPY, + argument_flag + | TypeFlags::SIZED + | TypeFlags::COPY + | TypeFlags::CREATION_RESOLVED, Alignment::ONE, ) } Ti::Array { base, size, stride } => { let base_info = &self.types[base.index()]; - if !base_info.flags.contains(TypeFlags::DATA | TypeFlags::SIZED) { + if !base_info + .flags + .contains(TypeFlags::DATA | TypeFlags::SIZED | TypeFlags::CREATION_RESOLVED) + { return Err(TypeError::InvalidArrayBaseType(base)); } @@ -496,12 +519,23 @@ impl super::Validator { | TypeFlags::HOST_SHAREABLE | TypeFlags::ARGUMENT | TypeFlags::CONSTRUCTIBLE + | TypeFlags::CREATION_RESOLVED + } + crate::ArraySize::Pending(_) => { + TypeFlags::DATA + | TypeFlags::SIZED + | TypeFlags::COPY + | TypeFlags::HOST_SHAREABLE + | TypeFlags::ARGUMENT } crate::ArraySize::Dynamic => { // Non-SIZED types may only appear as the last element of a structure. // This is enforced by checks for SIZED-ness for all compound types, // and a special case for structs. - TypeFlags::DATA | TypeFlags::COPY | TypeFlags::HOST_SHAREABLE + TypeFlags::DATA + | TypeFlags::COPY + | TypeFlags::HOST_SHAREABLE + | TypeFlags::CREATION_RESOLVED } }; @@ -523,7 +557,8 @@ impl super::Validator { | TypeFlags::HOST_SHAREABLE | TypeFlags::IO_SHAREABLE | TypeFlags::ARGUMENT - | TypeFlags::CONSTRUCTIBLE, + | TypeFlags::CONSTRUCTIBLE + | TypeFlags::CREATION_RESOLVED, Alignment::ONE, ); ti.uniform_layout = Ok(Alignment::MIN_UNIFORM); @@ -533,7 +568,10 @@ impl super::Validator { for (i, member) in members.iter().enumerate() { let base_info = &self.types[member.ty.index()]; - if !base_info.flags.contains(TypeFlags::DATA) { + if !base_info + .flags + .contains(TypeFlags::DATA | TypeFlags::CREATION_RESOLVED) + { return Err(TypeError::InvalidData(member.ty)); } if !base_info.flags.contains(TypeFlags::HOST_SHAREABLE) { @@ -649,26 +687,41 @@ impl super::Validator { if arrayed && matches!(dim, crate::ImageDimension::Cube) { self.require_type_capability(Capabilities::CUBE_ARRAY_TEXTURES)?; } - TypeInfo::new(TypeFlags::ARGUMENT, Alignment::ONE) + TypeInfo::new( + TypeFlags::ARGUMENT | TypeFlags::CREATION_RESOLVED, + Alignment::ONE, + ) } - Ti::Sampler { .. } => TypeInfo::new(TypeFlags::ARGUMENT, Alignment::ONE), + Ti::Sampler { .. } => TypeInfo::new( + TypeFlags::ARGUMENT | TypeFlags::CREATION_RESOLVED, + Alignment::ONE, + ), Ti::AccelerationStructure => { self.require_type_capability(Capabilities::RAY_QUERY)?; - TypeInfo::new(TypeFlags::ARGUMENT, Alignment::ONE) + TypeInfo::new( + TypeFlags::ARGUMENT | TypeFlags::CREATION_RESOLVED, + Alignment::ONE, + ) } Ti::RayQuery => { self.require_type_capability(Capabilities::RAY_QUERY)?; TypeInfo::new( - TypeFlags::DATA | TypeFlags::CONSTRUCTIBLE | TypeFlags::SIZED, + TypeFlags::DATA + | TypeFlags::CONSTRUCTIBLE + | TypeFlags::SIZED + | TypeFlags::CREATION_RESOLVED, Alignment::ONE, ) } Ti::BindingArray { base, size } => { let type_info_mask = match size { - crate::ArraySize::Constant(_) => TypeFlags::SIZED | TypeFlags::HOST_SHAREABLE, + crate::ArraySize::Constant(_) => { + TypeFlags::SIZED | TypeFlags::HOST_SHAREABLE | TypeFlags::CREATION_RESOLVED + } + crate::ArraySize::Pending(_) => TypeFlags::SIZED | TypeFlags::HOST_SHAREABLE, crate::ArraySize::Dynamic => { // Final type is non-sized - TypeFlags::HOST_SHAREABLE + TypeFlags::HOST_SHAREABLE | TypeFlags::CREATION_RESOLVED } }; let base_info = &self.types[base.index()]; @@ -681,6 +734,10 @@ impl super::Validator { }; } + if !base_info.flags.contains(TypeFlags::CREATION_RESOLVED) { + return Err(TypeError::InvalidData(base)); + } + TypeInfo::new(base_info.flags & type_info_mask, Alignment::ONE) } }) diff --git a/naga/tests/out/analysis/access.info.ron b/naga/tests/out/analysis/access.info.ron index 886f0bf9ed..8948cb3a0a 100644 --- a/naga/tests/out/analysis/access.info.ron +++ b/naga/tests/out/analysis/access.info.ron @@ -1,38 +1,38 @@ ( type_flags: [ - ("DATA | SIZED | COPY | IO_SHAREABLE | HOST_SHAREABLE | ARGUMENT | CONSTRUCTIBLE"), - ("DATA | SIZED | COPY | IO_SHAREABLE | HOST_SHAREABLE | ARGUMENT | CONSTRUCTIBLE"), - ("DATA | SIZED | COPY | IO_SHAREABLE | HOST_SHAREABLE | ARGUMENT | CONSTRUCTIBLE"), - ("DATA | SIZED | COPY | IO_SHAREABLE | HOST_SHAREABLE | ARGUMENT | CONSTRUCTIBLE"), - ("DATA | SIZED | COPY | IO_SHAREABLE | HOST_SHAREABLE | ARGUMENT | CONSTRUCTIBLE"), - ("DATA | SIZED | COPY | IO_SHAREABLE | HOST_SHAREABLE | ARGUMENT | CONSTRUCTIBLE"), - ("DATA | SIZED | COPY | HOST_SHAREABLE | ARGUMENT | CONSTRUCTIBLE"), - ("DATA | SIZED | COPY | HOST_SHAREABLE | ARGUMENT | CONSTRUCTIBLE"), - ("DATA | SIZED | COPY | HOST_SHAREABLE | ARGUMENT | CONSTRUCTIBLE"), - ("DATA | SIZED | HOST_SHAREABLE"), - ("DATA | SIZED | HOST_SHAREABLE"), - ("DATA | SIZED | COPY | IO_SHAREABLE | HOST_SHAREABLE | ARGUMENT | CONSTRUCTIBLE"), - ("DATA | SIZED | COPY | HOST_SHAREABLE | ARGUMENT | CONSTRUCTIBLE"), - ("DATA | COPY | HOST_SHAREABLE"), - ("DATA | HOST_SHAREABLE"), - ("DATA | SIZED | COPY | HOST_SHAREABLE | ARGUMENT | CONSTRUCTIBLE"), - ("DATA | SIZED | COPY | HOST_SHAREABLE | ARGUMENT | CONSTRUCTIBLE"), - ("DATA | SIZED | COPY | IO_SHAREABLE | HOST_SHAREABLE | ARGUMENT | CONSTRUCTIBLE"), - ("DATA | SIZED | COPY | HOST_SHAREABLE | ARGUMENT | CONSTRUCTIBLE"), - ("DATA | SIZED | COPY | HOST_SHAREABLE | ARGUMENT | CONSTRUCTIBLE"), - ("DATA | SIZED | COPY | HOST_SHAREABLE | ARGUMENT | CONSTRUCTIBLE"), - ("SIZED | COPY | ARGUMENT"), - ("DATA | SIZED | COPY | HOST_SHAREABLE | ARGUMENT | CONSTRUCTIBLE"), - ("DATA | SIZED | COPY | HOST_SHAREABLE | ARGUMENT | CONSTRUCTIBLE"), - ("DATA | SIZED | COPY | IO_SHAREABLE | HOST_SHAREABLE | ARGUMENT | CONSTRUCTIBLE"), - ("DATA | SIZED | COPY | HOST_SHAREABLE | ARGUMENT | CONSTRUCTIBLE"), - ("SIZED | COPY | ARGUMENT"), - ("DATA | SIZED | COPY | HOST_SHAREABLE | ARGUMENT | CONSTRUCTIBLE"), - ("SIZED | COPY | ARGUMENT"), - ("DATA | SIZED | COPY | IO_SHAREABLE | HOST_SHAREABLE | ARGUMENT | CONSTRUCTIBLE"), - ("SIZED | COPY | ARGUMENT"), - ("DATA | SIZED | COPY | HOST_SHAREABLE | ARGUMENT | CONSTRUCTIBLE"), - ("SIZED | COPY | ARGUMENT"), + ("DATA | SIZED | COPY | IO_SHAREABLE | HOST_SHAREABLE | CREATION_RESOLVED | ARGUMENT | CONSTRUCTIBLE"), + ("DATA | SIZED | COPY | IO_SHAREABLE | HOST_SHAREABLE | CREATION_RESOLVED | ARGUMENT | CONSTRUCTIBLE"), + ("DATA | SIZED | COPY | IO_SHAREABLE | HOST_SHAREABLE | CREATION_RESOLVED | ARGUMENT | CONSTRUCTIBLE"), + ("DATA | SIZED | COPY | IO_SHAREABLE | HOST_SHAREABLE | CREATION_RESOLVED | ARGUMENT | CONSTRUCTIBLE"), + ("DATA | SIZED | COPY | IO_SHAREABLE | HOST_SHAREABLE | CREATION_RESOLVED | ARGUMENT | CONSTRUCTIBLE"), + ("DATA | SIZED | COPY | IO_SHAREABLE | HOST_SHAREABLE | CREATION_RESOLVED | ARGUMENT | CONSTRUCTIBLE"), + ("DATA | SIZED | COPY | HOST_SHAREABLE | CREATION_RESOLVED | ARGUMENT | CONSTRUCTIBLE"), + ("DATA | SIZED | COPY | HOST_SHAREABLE | CREATION_RESOLVED | ARGUMENT | CONSTRUCTIBLE"), + ("DATA | SIZED | COPY | HOST_SHAREABLE | CREATION_RESOLVED | ARGUMENT | CONSTRUCTIBLE"), + ("DATA | SIZED | HOST_SHAREABLE | CREATION_RESOLVED"), + ("DATA | SIZED | HOST_SHAREABLE | CREATION_RESOLVED"), + ("DATA | SIZED | COPY | IO_SHAREABLE | HOST_SHAREABLE | CREATION_RESOLVED | ARGUMENT | CONSTRUCTIBLE"), + ("DATA | SIZED | COPY | HOST_SHAREABLE | CREATION_RESOLVED | ARGUMENT | CONSTRUCTIBLE"), + ("DATA | COPY | HOST_SHAREABLE | CREATION_RESOLVED"), + ("DATA | HOST_SHAREABLE | CREATION_RESOLVED"), + ("DATA | SIZED | COPY | HOST_SHAREABLE | CREATION_RESOLVED | ARGUMENT | CONSTRUCTIBLE"), + ("DATA | SIZED | COPY | HOST_SHAREABLE | CREATION_RESOLVED | ARGUMENT | CONSTRUCTIBLE"), + ("DATA | SIZED | COPY | IO_SHAREABLE | HOST_SHAREABLE | CREATION_RESOLVED | ARGUMENT | CONSTRUCTIBLE"), + ("DATA | SIZED | COPY | HOST_SHAREABLE | CREATION_RESOLVED | ARGUMENT | CONSTRUCTIBLE"), + ("DATA | SIZED | COPY | HOST_SHAREABLE | CREATION_RESOLVED | ARGUMENT | CONSTRUCTIBLE"), + ("DATA | SIZED | COPY | HOST_SHAREABLE | CREATION_RESOLVED | ARGUMENT | CONSTRUCTIBLE"), + ("SIZED | COPY | CREATION_RESOLVED | ARGUMENT"), + ("DATA | SIZED | COPY | HOST_SHAREABLE | CREATION_RESOLVED | ARGUMENT | CONSTRUCTIBLE"), + ("DATA | SIZED | COPY | HOST_SHAREABLE | CREATION_RESOLVED | ARGUMENT | CONSTRUCTIBLE"), + ("DATA | SIZED | COPY | IO_SHAREABLE | HOST_SHAREABLE | CREATION_RESOLVED | ARGUMENT | CONSTRUCTIBLE"), + ("DATA | SIZED | COPY | HOST_SHAREABLE | CREATION_RESOLVED | ARGUMENT | CONSTRUCTIBLE"), + ("SIZED | COPY | CREATION_RESOLVED | ARGUMENT"), + ("DATA | SIZED | COPY | HOST_SHAREABLE | CREATION_RESOLVED | ARGUMENT | CONSTRUCTIBLE"), + ("SIZED | COPY | CREATION_RESOLVED | ARGUMENT"), + ("DATA | SIZED | COPY | IO_SHAREABLE | HOST_SHAREABLE | CREATION_RESOLVED | ARGUMENT | CONSTRUCTIBLE"), + ("SIZED | COPY | CREATION_RESOLVED | ARGUMENT"), + ("DATA | SIZED | COPY | HOST_SHAREABLE | CREATION_RESOLVED | ARGUMENT | CONSTRUCTIBLE"), + ("SIZED | COPY | CREATION_RESOLVED | ARGUMENT"), ], functions: [ ( diff --git a/naga/tests/out/analysis/collatz.info.ron b/naga/tests/out/analysis/collatz.info.ron index 6e7dd37bed..7ec5799d75 100644 --- a/naga/tests/out/analysis/collatz.info.ron +++ b/naga/tests/out/analysis/collatz.info.ron @@ -1,9 +1,9 @@ ( type_flags: [ - ("DATA | SIZED | COPY | IO_SHAREABLE | HOST_SHAREABLE | ARGUMENT | CONSTRUCTIBLE"), - ("DATA | COPY | HOST_SHAREABLE"), - ("DATA | COPY | HOST_SHAREABLE"), - ("DATA | SIZED | COPY | IO_SHAREABLE | HOST_SHAREABLE | ARGUMENT | CONSTRUCTIBLE"), + ("DATA | SIZED | COPY | IO_SHAREABLE | HOST_SHAREABLE | CREATION_RESOLVED | ARGUMENT | CONSTRUCTIBLE"), + ("DATA | COPY | HOST_SHAREABLE | CREATION_RESOLVED"), + ("DATA | COPY | HOST_SHAREABLE | CREATION_RESOLVED"), + ("DATA | SIZED | COPY | IO_SHAREABLE | HOST_SHAREABLE | CREATION_RESOLVED | ARGUMENT | CONSTRUCTIBLE"), ], functions: [ ( diff --git a/naga/tests/out/analysis/overrides.info.ron b/naga/tests/out/analysis/overrides.info.ron index 0bb10336c8..835525e52d 100644 --- a/naga/tests/out/analysis/overrides.info.ron +++ b/naga/tests/out/analysis/overrides.info.ron @@ -1,7 +1,7 @@ ( type_flags: [ - ("DATA | SIZED | COPY | ARGUMENT | CONSTRUCTIBLE"), - ("DATA | SIZED | COPY | IO_SHAREABLE | HOST_SHAREABLE | ARGUMENT | CONSTRUCTIBLE"), + ("DATA | SIZED | COPY | CREATION_RESOLVED | ARGUMENT | CONSTRUCTIBLE"), + ("DATA | SIZED | COPY | IO_SHAREABLE | HOST_SHAREABLE | CREATION_RESOLVED | ARGUMENT | CONSTRUCTIBLE"), ], functions: [], entry_points: [ diff --git a/naga/tests/out/analysis/shadow.info.ron b/naga/tests/out/analysis/shadow.info.ron index e7a122dc7a..3d6841fdd4 100644 --- a/naga/tests/out/analysis/shadow.info.ron +++ b/naga/tests/out/analysis/shadow.info.ron @@ -1,19 +1,19 @@ ( type_flags: [ - ("DATA | SIZED | COPY | IO_SHAREABLE | HOST_SHAREABLE | ARGUMENT | CONSTRUCTIBLE"), - ("DATA | SIZED | COPY | IO_SHAREABLE | HOST_SHAREABLE | ARGUMENT | CONSTRUCTIBLE"), - ("DATA | SIZED | COPY | IO_SHAREABLE | HOST_SHAREABLE | ARGUMENT | CONSTRUCTIBLE"), - ("DATA | SIZED | COPY | IO_SHAREABLE | HOST_SHAREABLE | ARGUMENT | CONSTRUCTIBLE"), - ("DATA | SIZED | COPY | IO_SHAREABLE | HOST_SHAREABLE | ARGUMENT | CONSTRUCTIBLE"), - ("ARGUMENT"), - ("DATA | SIZED | COPY | IO_SHAREABLE | HOST_SHAREABLE | ARGUMENT | CONSTRUCTIBLE"), - ("DATA | SIZED | COPY | IO_SHAREABLE | HOST_SHAREABLE | ARGUMENT | CONSTRUCTIBLE"), - ("DATA | SIZED | COPY | IO_SHAREABLE | HOST_SHAREABLE | ARGUMENT | CONSTRUCTIBLE"), - ("DATA | SIZED | COPY | HOST_SHAREABLE | ARGUMENT | CONSTRUCTIBLE"), - ("DATA | SIZED | COPY | HOST_SHAREABLE | ARGUMENT | CONSTRUCTIBLE"), - ("DATA | COPY | HOST_SHAREABLE"), - ("DATA | COPY | HOST_SHAREABLE"), - ("ARGUMENT"), + ("DATA | SIZED | COPY | IO_SHAREABLE | HOST_SHAREABLE | CREATION_RESOLVED | ARGUMENT | CONSTRUCTIBLE"), + ("DATA | SIZED | COPY | IO_SHAREABLE | HOST_SHAREABLE | CREATION_RESOLVED | ARGUMENT | CONSTRUCTIBLE"), + ("DATA | SIZED | COPY | IO_SHAREABLE | HOST_SHAREABLE | CREATION_RESOLVED | ARGUMENT | CONSTRUCTIBLE"), + ("DATA | SIZED | COPY | IO_SHAREABLE | HOST_SHAREABLE | CREATION_RESOLVED | ARGUMENT | CONSTRUCTIBLE"), + ("DATA | SIZED | COPY | IO_SHAREABLE | HOST_SHAREABLE | CREATION_RESOLVED | ARGUMENT | CONSTRUCTIBLE"), + ("CREATION_RESOLVED | ARGUMENT"), + ("DATA | SIZED | COPY | IO_SHAREABLE | HOST_SHAREABLE | CREATION_RESOLVED | ARGUMENT | CONSTRUCTIBLE"), + ("DATA | SIZED | COPY | IO_SHAREABLE | HOST_SHAREABLE | CREATION_RESOLVED | ARGUMENT | CONSTRUCTIBLE"), + ("DATA | SIZED | COPY | IO_SHAREABLE | HOST_SHAREABLE | CREATION_RESOLVED | ARGUMENT | CONSTRUCTIBLE"), + ("DATA | SIZED | COPY | HOST_SHAREABLE | CREATION_RESOLVED | ARGUMENT | CONSTRUCTIBLE"), + ("DATA | SIZED | COPY | HOST_SHAREABLE | CREATION_RESOLVED | ARGUMENT | CONSTRUCTIBLE"), + ("DATA | COPY | HOST_SHAREABLE | CREATION_RESOLVED"), + ("DATA | COPY | HOST_SHAREABLE | CREATION_RESOLVED"), + ("CREATION_RESOLVED | ARGUMENT"), ], functions: [ ( diff --git a/tests/tests/shader/array_size_overrides.rs b/tests/tests/shader/array_size_overrides.rs new file mode 100644 index 0000000000..7f1d324254 --- /dev/null +++ b/tests/tests/shader/array_size_overrides.rs @@ -0,0 +1,131 @@ +use std::mem::size_of_val; +use wgpu::util::DeviceExt; +use wgpu::{BufferDescriptor, BufferUsages, Maintain, MapMode}; +use wgpu_test::{fail_if, gpu_test, GpuTestConfiguration, TestParameters, TestingContext}; + +const SHADER: &str = r#" + override n = 8; + + var arr: array; + + @group(0) @binding(0) + var output: array; + + @compute @workgroup_size(1) fn main() { + // 1d spiral + for (var i = 0; i < n - 2; i++) { + arr[i] = u32(n - 2 - i); + if (i + 1 < (n + (n % 2)) / 2) { + arr[i] -= 1u; + } + } + var i = 0u; + var j = 1u; + while (i != j) { + // non-commutative + output[0] = output[0] * arr[i] + arr[i]; + j = i; + i = arr[i]; + } + } +"#; + +#[gpu_test] +static ARRAY_SIZE_OVERRIDES: GpuTestConfiguration = GpuTestConfiguration::new() + .parameters(TestParameters::default().limits(wgpu::Limits::default())) + .run_async(move |ctx| async move { + array_size_overrides(&ctx, None, &[534], false).await; + array_size_overrides(&ctx, Some(14), &[286480122], false).await; + array_size_overrides(&ctx, Some(1), &[0], true).await; + }); + +async fn array_size_overrides( + ctx: &TestingContext, + n: Option, + out: &[u32], + should_fail: bool, +) { + let module = ctx + .device + .create_shader_module(wgpu::ShaderModuleDescriptor { + label: None, + source: wgpu::ShaderSource::Wgsl(std::borrow::Cow::Borrowed(SHADER)), + }); + let pipeline_options = wgpu::PipelineCompilationOptions { + constants: &[("n".to_owned(), n.unwrap_or(0).into())].into(), + ..Default::default() + }; + let compute_pipeline = fail_if( + &ctx.device, + should_fail, + || { + ctx.device + .create_compute_pipeline(&wgpu::ComputePipelineDescriptor { + label: None, + layout: None, + module: &module, + entry_point: Some("main"), + compilation_options: if n.is_some() { + pipeline_options + } else { + wgpu::PipelineCompilationOptions::default() + }, + cache: None, + }) + }, + None, + ); + if should_fail { + return; + } + let init: &[u32] = &[0]; + let init_size: u64 = size_of_val(init).try_into().unwrap(); + let buffer = DeviceExt::create_buffer_init( + &ctx.device, + &wgpu::util::BufferInitDescriptor { + label: None, + contents: bytemuck::cast_slice(init), + usage: wgpu::BufferUsages::STORAGE + | wgpu::BufferUsages::COPY_DST + | wgpu::BufferUsages::COPY_SRC, + }, + ); + let mapping_buffer = ctx.device.create_buffer(&BufferDescriptor { + label: Some("mapping buffer"), + size: init_size, + usage: BufferUsages::COPY_DST | BufferUsages::MAP_READ, + mapped_at_creation: false, + }); + let mut encoder = ctx + .device + .create_command_encoder(&wgpu::CommandEncoderDescriptor { label: None }); + { + let mut cpass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor { + label: None, + timestamp_writes: None, + }); + cpass.set_pipeline(&compute_pipeline); + let bind_group_layout = compute_pipeline.get_bind_group_layout(0); + let bind_group_entries = [wgpu::BindGroupEntry { + binding: 0, + resource: buffer.as_entire_binding(), + }]; + let bind_group = ctx.device.create_bind_group(&wgpu::BindGroupDescriptor { + label: None, + layout: &bind_group_layout, + entries: &bind_group_entries, + }); + cpass.set_bind_group(0, &bind_group, &[]); + cpass.dispatch_workgroups(1, 1, 1); + } + encoder.copy_buffer_to_buffer(&buffer, 0, &mapping_buffer, 0, init_size); + ctx.queue.submit(Some(encoder.finish())); + + mapping_buffer.slice(..).map_async(MapMode::Read, |_| ()); + ctx.async_poll(Maintain::wait()).await.panic_on_timeout(); + + let mapped = mapping_buffer.slice(..).get_mapped_range(); + + let typed: &[u32] = bytemuck::cast_slice(&mapped); + assert_eq!(typed, out); +} diff --git a/tests/tests/shader/mod.rs b/tests/tests/shader/mod.rs index f05fbac25c..07c0fffb17 100644 --- a/tests/tests/shader/mod.rs +++ b/tests/tests/shader/mod.rs @@ -15,6 +15,7 @@ use wgpu::{ use wgpu_test::TestingContext; +pub mod array_size_overrides; pub mod compilation_messages; pub mod data_builtins; pub mod numeric_builtins;