Skip to content

Commit

Permalink
Merge branch 'trunk' into ad/i64-atomics
Browse files Browse the repository at this point in the history
  • Loading branch information
atlv24 authored May 29, 2024
2 parents 163e695 + b5b39f6 commit 2e4623d
Show file tree
Hide file tree
Showing 4 changed files with 272 additions and 10 deletions.
42 changes: 41 additions & 1 deletion naga/src/valid/function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,8 @@ pub enum FunctionError {
WorkgroupUniformLoadInvalidPointer(Handle<crate::Expression>),
#[error("Subgroup operation is invalid")]
InvalidSubgroup(#[from] SubgroupError),
#[error("Emit statement should not cover \"result\" expressions like {0:?}")]
EmitResult(Handle<crate::Expression>),
}

bitflags::bitflags! {
Expand Down Expand Up @@ -649,7 +651,45 @@ impl super::Validator {
match *statement {
S::Emit(ref range) => {
for handle in range.clone() {
self.emit_expression(handle, context)?;
use crate::Expression as Ex;
match context.expressions[handle] {
Ex::Literal(_)
| Ex::Constant(_)
| Ex::Override(_)
| Ex::ZeroValue(_)
| Ex::Compose { .. }
| Ex::Access { .. }
| Ex::AccessIndex { .. }
| Ex::Splat { .. }
| Ex::Swizzle { .. }
| Ex::FunctionArgument(_)
| Ex::GlobalVariable(_)
| Ex::LocalVariable(_)
| Ex::Load { .. }
| Ex::ImageSample { .. }
| Ex::ImageLoad { .. }
| Ex::ImageQuery { .. }
| Ex::Unary { .. }
| Ex::Binary { .. }
| Ex::Select { .. }
| Ex::Derivative { .. }
| Ex::Relational { .. }
| Ex::Math { .. }
| Ex::As { .. }
| Ex::ArrayLength(_)
| Ex::RayQueryGetIntersection { .. } => {
self.emit_expression(handle, context)?
}
Ex::CallResult(_)
| Ex::AtomicResult { .. }
| Ex::WorkGroupUniformLoadResult { .. }
| Ex::RayQueryProceedResult
| Ex::SubgroupBallotResult
| Ex::SubgroupOperationResult { .. } => {
return Err(FunctionError::EmitResult(handle)
.with_span_handle(handle, context.expressions));
}
}
}
}
S::Block(ref block) => {
Expand Down
1 change: 1 addition & 0 deletions naga/tests/root.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
mod example_wgsl;
mod snapshots;
mod spirv_capabilities;
mod validation;
mod wgsl_errors;
230 changes: 230 additions & 0 deletions naga/tests/validation.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,230 @@
use naga::{valid, Expression, Function, Scalar};

#[test]
fn emit_atomic_result() {
use naga::{Module, Type, TypeInner};

// We want to ensure that the *only* problem with the code is the
// use of an `Emit` statement instead of an `Atomic` statement. So
// validate two versions of the module varying only in that
// aspect.
//
// Looking at uses of the `atomic` makes it easy to identify the
// differences between the two variants.
fn variant(
atomic: bool,
) -> Result<naga::valid::ModuleInfo, naga::WithSpan<naga::valid::ValidationError>> {
let span = naga::Span::default();
let mut module = Module::default();
let ty_u32 = module.types.insert(
Type {
name: Some("u32".into()),
inner: TypeInner::Scalar(Scalar::U32),
},
span,
);
let ty_atomic_u32 = module.types.insert(
Type {
name: Some("atomic<u32>".into()),
inner: TypeInner::Atomic(Scalar::U32),
},
span,
);
let var_atomic = module.global_variables.append(
naga::GlobalVariable {
name: Some("atomic_global".into()),
space: naga::AddressSpace::WorkGroup,
binding: None,
ty: ty_atomic_u32,
init: None,
},
span,
);

let mut fun = Function::default();
let ex_global = fun
.expressions
.append(Expression::GlobalVariable(var_atomic), span);
let ex_42 = fun
.expressions
.append(Expression::Literal(naga::Literal::U32(42)), span);
let ex_result = fun.expressions.append(
Expression::AtomicResult {
ty: ty_u32,
comparison: false,
},
span,
);

if atomic {
fun.body.push(
naga::Statement::Atomic {
pointer: ex_global,
fun: naga::AtomicFunction::Add,
value: ex_42,
result: ex_result,
},
span,
);
} else {
fun.body.push(
naga::Statement::Emit(naga::Range::new_from_bounds(ex_result, ex_result)),
span,
);
}

module.functions.append(fun, span);

valid::Validator::new(
valid::ValidationFlags::default(),
valid::Capabilities::all(),
)
.validate(&module)
}

variant(true).expect("module should validate");
assert!(variant(false).is_err());
}

#[test]
fn emit_call_result() {
use naga::{Module, Type, TypeInner};

// We want to ensure that the *only* problem with the code is the
// use of an `Emit` statement instead of a `Call` statement. So
// validate two versions of the module varying only in that
// aspect.
//
// Looking at uses of the `call` makes it easy to identify the
// differences between the two variants.
fn variant(
call: bool,
) -> Result<naga::valid::ModuleInfo, naga::WithSpan<naga::valid::ValidationError>> {
let span = naga::Span::default();
let mut module = Module::default();
let ty_u32 = module.types.insert(
Type {
name: Some("u32".into()),
inner: TypeInner::Scalar(Scalar::U32),
},
span,
);

let mut fun_callee = Function {
result: Some(naga::FunctionResult {
ty: ty_u32,
binding: None,
}),
..Function::default()
};
let ex_42 = fun_callee
.expressions
.append(Expression::Literal(naga::Literal::U32(42)), span);
fun_callee
.body
.push(naga::Statement::Return { value: Some(ex_42) }, span);
let fun_callee = module.functions.append(fun_callee, span);

let mut fun_caller = Function::default();
let ex_result = fun_caller
.expressions
.append(Expression::CallResult(fun_callee), span);

if call {
fun_caller.body.push(
naga::Statement::Call {
function: fun_callee,
arguments: vec![],
result: Some(ex_result),
},
span,
);
} else {
fun_caller.body.push(
naga::Statement::Emit(naga::Range::new_from_bounds(ex_result, ex_result)),
span,
);
}

module.functions.append(fun_caller, span);

valid::Validator::new(
valid::ValidationFlags::default(),
valid::Capabilities::all(),
)
.validate(&module)
}

variant(true).expect("should validate");
assert!(variant(false).is_err());
}

#[test]
fn emit_workgroup_uniform_load_result() {
use naga::{Module, Type, TypeInner};

// We want to ensure that the *only* problem with the code is the
// use of an `Emit` statement instead of an `Atomic` statement. So
// validate two versions of the module varying only in that
// aspect.
//
// Looking at uses of the `wg_load` makes it easy to identify the
// differences between the two variants.
fn variant(
wg_load: bool,
) -> Result<naga::valid::ModuleInfo, naga::WithSpan<naga::valid::ValidationError>> {
let span = naga::Span::default();
let mut module = Module::default();
let ty_u32 = module.types.insert(
Type {
name: Some("u32".into()),
inner: TypeInner::Scalar(Scalar::U32),
},
span,
);
let var_workgroup = module.global_variables.append(
naga::GlobalVariable {
name: Some("workgroup_global".into()),
space: naga::AddressSpace::WorkGroup,
binding: None,
ty: ty_u32,
init: None,
},
span,
);

let mut fun = Function::default();
let ex_global = fun
.expressions
.append(Expression::GlobalVariable(var_workgroup), span);
let ex_result = fun
.expressions
.append(Expression::WorkGroupUniformLoadResult { ty: ty_u32 }, span);

if wg_load {
fun.body.push(
naga::Statement::WorkGroupUniformLoad {
pointer: ex_global,
result: ex_result,
},
span,
);
} else {
fun.body.push(
naga::Statement::Emit(naga::Range::new_from_bounds(ex_result, ex_result)),
span,
);
}

module.functions.append(fun, span);

valid::Validator::new(
valid::ValidationFlags::default(),
valid::Capabilities::all(),
)
.validate(&module)
}

variant(true).expect("module should validate");
assert!(variant(false).is_err());
}
9 changes: 0 additions & 9 deletions wgpu-core/src/validation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -915,15 +915,6 @@ impl Interface {
class,
},
naga::TypeInner::Sampler { comparison } => ResourceType::Sampler { comparison },
naga::TypeInner::Array { stride, size, .. } => {
let size = match size {
naga::ArraySize::Constant(size) => size.get() * stride,
naga::ArraySize::Dynamic => stride,
};
ResourceType::Buffer {
size: wgt::BufferSize::new(size as u64).unwrap(),
}
}
ref other => ResourceType::Buffer {
size: wgt::BufferSize::new(other.size(module.to_ctx()) as u64).unwrap(),
},
Expand Down

0 comments on commit 2e4623d

Please sign in to comment.