Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Override-expressions for Fixed Size Arrays #6654

Merged
merged 21 commits into from
Dec 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,7 @@ By @ErichDonGubler in [#6456](https://github.com/gfx-rs/wgpu/pull/6456), [#6148]
- Allow for override-expressions in `workgroup_size`. By @KentSlaney in [#6635](https://github.com/gfx-rs/wgpu/pull/6635).
- Add support for OpAtomicCompareExchange in SPIR-V frontend. By @schell in [#6590](https://github.com/gfx-rs/wgpu/pull/6590).
- Implement type inference for abstract arguments to user-defined functions. By @jamienicol in [#6577](https://github.com/gfx-rs/wgpu/pull/6577).
- Allow for override-expressions in array sizes. By @KentSlaney in [#6654](https://github.com/gfx-rs/wgpu/pull/6654).

#### General

Expand Down
2 changes: 2 additions & 0 deletions naga/src/back/glsl/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -978,6 +978,7 @@ impl<'a, W: Write> Writer<'a, W> {
crate::ArraySize::Constant(size) => {
write!(self.out, "{size}")?;
}
crate::ArraySize::Pending(_) => unreachable!(),
crate::ArraySize::Dynamic => (),
}

Expand Down Expand Up @@ -4459,6 +4460,7 @@ impl<'a, W: Write> Writer<'a, W> {
.expect("Bad array size")
{
proc::IndexableLength::Known(count) => count,
proc::IndexableLength::Pending => unreachable!(),
proc::IndexableLength::Dynamic => return Ok(()),
};
self.write_type(base)?;
Expand Down
1 change: 1 addition & 0 deletions naga/src/back/hlsl/conv.rs
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ impl crate::TypeInner {
let count = match size {
crate::ArraySize::Constant(size) => size.get(),
// A dynamically-sized array has to have at least one element
crate::ArraySize::Pending(_) => unreachable!(),
crate::ArraySize::Dynamic => 1,
};
let last_el_size = gctx.types[base].inner.size_hlsl(gctx);
Expand Down
2 changes: 2 additions & 0 deletions naga/src/back/hlsl/writer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -988,6 +988,7 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
crate::ArraySize::Constant(size) => {
write!(self.out, "{size}")?;
}
crate::ArraySize::Pending(_) => unreachable!(),
crate::ArraySize::Dynamic => unreachable!(),
}

Expand Down Expand Up @@ -2634,6 +2635,7 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
index::IndexableLength::Known(limit) => {
write!(self.out, "{}u", limit - 1)?;
}
index::IndexableLength::Pending => unreachable!(),
index::IndexableLength::Dynamic => unreachable!(),
}
write!(self.out, ")")?;
Expand Down
6 changes: 6 additions & 0 deletions naga/src/back/msl/writer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2407,6 +2407,7 @@ impl<W: Write> Writer<W> {
self.out.write_str(") < ")?;
match length {
index::IndexableLength::Known(value) => write!(self.out, "{value}")?,
index::IndexableLength::Pending => unreachable!(),
index::IndexableLength::Dynamic => {
let global =
context.function.originating_global(base).ok_or_else(|| {
Expand Down Expand Up @@ -2569,6 +2570,7 @@ impl<W: Write> Writer<W> {
index::IndexableLength::Known(limit) => {
write!(self.out, "{}u", limit - 1)?;
}
index::IndexableLength::Pending => unreachable!(),
index::IndexableLength::Dynamic => {
let global = context.function.originating_global(base).ok_or_else(|| {
Error::GenericValidation("Could not find originating global".into())
Expand Down Expand Up @@ -3740,6 +3742,9 @@ impl<W: Write> Writer<W> {
)?;
writeln!(self.out, "}};")?;
}
crate::ArraySize::Pending(_) => {
unreachable!()
}
crate::ArraySize::Dynamic => {
writeln!(self.out, "typedef {base_name} {name}[1];")?;
}
Expand Down Expand Up @@ -6008,6 +6013,7 @@ mod workgroup_mem_init {
let count = match size.to_indexable_length(module).expect("Bad array size")
{
proc::IndexableLength::Known(count) => count,
proc::IndexableLength::Pending => unreachable!(),
proc::IndexableLength::Dynamic => unreachable!(),
};

Expand Down
60 changes: 60 additions & 0 deletions naga/src/back/pipeline_constants.rs
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,8 @@ pub fn process_overrides<'a>(
}
module.entry_points = entry_points;

process_pending(&mut module, &override_map, &adjusted_global_expressions)?;

// Now that we've rewritten all the expressions, we need to
// recompute their types and other metadata. For the time being,
// do a full re-validation.
Expand All @@ -205,6 +207,64 @@ pub fn process_overrides<'a>(
Ok((Cow::Owned(module), Cow::Owned(module_info)))
}

fn process_pending(
module: &mut Module,
override_map: &HandleVec<Override, Handle<Constant>>,
adjusted_global_expressions: &HandleVec<Expression, Handle<Expression>>,
) -> Result<(), PipelineConstantError> {
for (handle, ty) in module.types.clone().iter() {
if let crate::TypeInner::Array {
base,
size: crate::ArraySize::Pending(size),
stride,
} = ty.inner
{
let expr = match size {
crate::PendingArraySize::Expression(size_expr) => {
adjusted_global_expressions[size_expr]
}
crate::PendingArraySize::Override(size_override) => {
module.constants[override_map[size_override]].init
}
};
let value = module
.to_ctx()
.eval_expr_to_u32(expr)
.map(|n| {
if n == 0 {
Err(PipelineConstantError::ValidationError(
WithSpan::new(ValidationError::ArraySizeError { handle: expr })
.with_span(
module.global_expressions.get_span(expr),
"evaluated to zero",
),
))
} else {
Ok(std::num::NonZeroU32::new(n).unwrap())
}
})
.map_err(|_| {
PipelineConstantError::ValidationError(
WithSpan::new(ValidationError::ArraySizeError { handle: expr })
.with_span(module.global_expressions.get_span(expr), "negative"),
)
})??;
module.types.replace(
handle,
crate::Type {
name: None,
inner: crate::TypeInner::Array {
base,
size: crate::ArraySize::Constant(value),
stride,
},
},
);
}
}
Ok(())
}

fn process_workgroup_size_override(
module: &mut Module,
adjusted_global_expressions: &HandleVec<Expression, Handle<Expression>>,
Expand Down
3 changes: 3 additions & 0 deletions naga/src/back/spv/index.rs
Original file line number Diff line number Diff line change
Expand Up @@ -271,6 +271,9 @@ impl BlockContext<'_> {
Ok(crate::proc::IndexableLength::Known(known_length)) => {
Ok(MaybeKnown::Known(known_length))
}
Ok(crate::proc::IndexableLength::Pending) => {
unreachable!()
}
Ok(crate::proc::IndexableLength::Dynamic) => {
let length_id = self.write_runtime_array_length(sequence, block)?;
Ok(MaybeKnown::Computed(length_id))
Expand Down
2 changes: 2 additions & 0 deletions naga/src/back/spv/writer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -971,6 +971,7 @@ impl Writer {
let length_id = self.get_index_constant(length.get());
Instruction::type_array(id, type_id, length_id)
}
crate::ArraySize::Pending(_) => unreachable!(),
crate::ArraySize::Dynamic => Instruction::type_runtime_array(id, type_id),
}
}
Expand All @@ -981,6 +982,7 @@ impl Writer {
let length_id = self.get_index_constant(length.get());
Instruction::type_array(id, type_id, length_id)
}
crate::ArraySize::Pending(_) => unreachable!(),
crate::ArraySize::Dynamic => Instruction::type_runtime_array(id, type_id),
}
}
Expand Down
6 changes: 6 additions & 0 deletions naga/src/back/wgsl/writer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -520,6 +520,9 @@ impl<W: Write> Writer<W> {
self.write_type(module, base)?;
write!(self.out, ", {len}")?;
}
crate::ArraySize::Pending(_) => {
unreachable!();
}
crate::ArraySize::Dynamic => {
self.write_type(module, base)?;
}
Expand All @@ -534,6 +537,9 @@ impl<W: Write> Writer<W> {
self.write_type(module, base)?;
write!(self.out, ", {len}")?;
}
crate::ArraySize::Pending(_) => {
unreachable!();
}
crate::ArraySize::Dynamic => {
self.write_type(module, base)?;
}
Expand Down
34 changes: 34 additions & 0 deletions naga/src/compact/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,16 @@ pub fn compact(module: &mut crate::Module) {
}
}

for (_, ty) in module.types.iter() {
if let crate::TypeInner::Array {
size: crate::ArraySize::Pending(crate::PendingArraySize::Expression(size_expr)),
..
} = ty.inner
{
module_tracer.global_expressions_used.insert(size_expr);
}
}

for e in module.entry_points.iter() {
if let Some(sizes) = e.workgroup_size_overrides {
for size in sizes.iter().filter_map(|x| *x) {
Expand Down Expand Up @@ -206,6 +216,30 @@ pub fn compact(module: &mut crate::Module) {
}
}

for (handle, ty) in module.types.clone().iter() {
if let crate::TypeInner::Array {
base,
size: crate::ArraySize::Pending(crate::PendingArraySize::Expression(mut size_expr)),
stride,
} = ty.inner
{
module_map.global_expressions.adjust(&mut size_expr);
module.types.replace(
handle,
crate::Type {
name: None,
inner: crate::TypeInner::Array {
base,
size: crate::ArraySize::Pending(crate::PendingArraySize::Expression(
size_expr,
)),
stride,
},
},
);
}
}

// Temporary storage to help us reuse allocations of existing
// named expression tables.
let mut reused_named_expressions = crate::NamedExpressions::default();
Expand Down
1 change: 1 addition & 0 deletions naga/src/front/glsl/offset.rs
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@ pub fn calculate_offset(

let span = match size {
crate::ArraySize::Constant(size) => size.get() * stride,
crate::ArraySize::Pending(_) => unreachable!(),
crate::ArraySize::Dynamic => stride,
};

Expand Down
3 changes: 3 additions & 0 deletions naga/src/front/spv/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1329,6 +1329,9 @@ impl<I: Iterator<Item = u32>> Frontend<I> {
crate::TypeInner::Array { size, .. } => {
let size = match size {
crate::ArraySize::Constant(size) => size.get(),
crate::ArraySize::Pending(_) => {
unreachable!();
}
// A runtime sized array is not a composite type
crate::ArraySize::Dynamic => {
return Err(Error::InvalidAccessType(root_type_id))
Expand Down
71 changes: 57 additions & 14 deletions naga/src/front/wgsl/lower/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3056,26 +3056,69 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
Ok(match size {
ast::ArraySize::Constant(expr) => {
let span = ctx.ast_expressions.get_span(expr);
let const_expr = self.expression(expr, &mut ctx.as_const())?;
let len =
ctx.module
.to_ctx()
.eval_expr_to_u32(const_expr)
.map_err(|err| match err {
crate::proc::U32EvalError::NonConst => {
Error::ExpectedConstExprConcreteIntegerScalar(span)
}
crate::proc::U32EvalError::Negative => {
Error::ExpectedPositiveArrayLength(span)
let const_expr = self.expression(expr, &mut ctx.as_const());
match const_expr {
Ok(value) => {
let len =
ctx.module.to_ctx().eval_expr_to_u32(value).map_err(
|err| match err {
crate::proc::U32EvalError::NonConst => {
Error::ExpectedConstExprConcreteIntegerScalar(span)
}
crate::proc::U32EvalError::Negative => {
Error::ExpectedPositiveArrayLength(span)
}
},
)?;
let size =
NonZeroU32::new(len).ok_or(Error::ExpectedPositiveArrayLength(span))?;
crate::ArraySize::Constant(size)
}
err => {
if let Err(Error::ConstantEvaluatorError(ref ty, _)) = err {
match **ty {
crate::proc::ConstantEvaluatorError::OverrideExpr => {
crate::ArraySize::Pending(self.array_size_override(
expr,
&mut ctx.as_override(),
span,
)?)
}
_ => {
err?;
unreachable!()
}
}
})?;
let size = NonZeroU32::new(len).ok_or(Error::ExpectedPositiveArrayLength(span))?;
crate::ArraySize::Constant(size)
} else {
err?;
unreachable!()
}
}
}
}
ast::ArraySize::Dynamic => crate::ArraySize::Dynamic,
})
}

fn array_size_override(
&mut self,
size_expr: Handle<ast::Expression<'source>>,
ctx: &mut ExpressionContext<'source, '_, '_>,
span: Span,
) -> Result<crate::PendingArraySize, Error<'source>> {
let expr = self.expression(size_expr, ctx)?;
match resolve_inner!(ctx, expr).scalar_kind().ok_or(0) {
Ok(crate::ScalarKind::Sint) | Ok(crate::ScalarKind::Uint) => Ok({
if let crate::Expression::Override(handle) = ctx.module.global_expressions[expr] {
crate::PendingArraySize::Override(handle)
} else {
crate::PendingArraySize::Expression(expr)
}
}),
_ => Err(Error::ExpectedConstExprConcreteIntegerScalar(span)),
}
}

/// Build the Naga equivalent of a named AST type.
///
/// Return a Naga `Handle<Type>` representing the front-end type
Expand Down
Loading
Loading