Skip to content

Commit

Permalink
feature: [spv-front] Support for OpAtomicCompareExchange (#6590)
Browse files Browse the repository at this point in the history
Add support for parsing and executing OpAtomicCompareExchange in the SPIR-V frontend.
This concludes the work to support atomics in the SPIR-V frontend,
excluding test clean-up.

Fixes #6296.
Fixes #6590.

Connections:

- [naga spv-in] Support for OpAtomicCompareExchange #6296
- [spv-in] Atomics support #4489

Co-authored-by: Jim Blandy <[email protected]>
  • Loading branch information
schell and jimblandy authored Dec 9, 2024
1 parent 6e5d398 commit 234b6dd
Show file tree
Hide file tree
Showing 5 changed files with 201 additions and 10 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,7 @@ By @ErichDonGubler in [#6456](https://github.com/gfx-rs/wgpu/pull/6456), [#6148]
- Add support for GLSL `usampler*` and `isampler*`. By @DavidPeicho in [#6513](https://github.com/gfx-rs/wgpu/pull/6513).
- Expose Ray Query flags as constants in WGSL. Implement candidate intersections. By @kvark in [#5429](https://github.com/gfx-rs/wgpu/pull/5429)
- Allow for override-expressions in `workgroup_size`. By @KentSlaney in [#6635](https://github.com/gfx-rs/wgpu/pull/6635).
- Add support for OpAtomicCompareExchange in SPIR-V frontend. By @schell in [#6590](https://github.com/gfx-rs/wgpu/pull/6590).

#### General

Expand Down
2 changes: 2 additions & 0 deletions naga/src/front/atomic_upgrade.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,8 @@ pub enum Error {
GlobalInitUnsupported,
#[error("expected to find a global variable")]
GlobalVariableMissing,
#[error("atomic compare exchange requires a scalar base type")]
CompareExchangeNonScalarBaseType,
}

#[derive(Clone, Default)]
Expand Down
119 changes: 109 additions & 10 deletions naga/src/front/spv/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4273,6 +4273,102 @@ impl<I: Iterator<Item = u32>> Frontend<I> {
self.upgrade_atomics
.insert(ctx.get_contained_global_variable(p_exp_h)?);
}
Op::AtomicCompareExchange => {
inst.expect(9)?;

let start = self.data_offset;
let span = self.span_from_with_op(start);
let result_type_id = self.next()?;
let result_id = self.next()?;
let pointer_id = self.next()?;
let _memory_scope_id = self.next()?;
let _equal_memory_semantics_id = self.next()?;
let _unequal_memory_semantics_id = self.next()?;
let value_id = self.next()?;
let comparator_id = self.next()?;

let (p_exp_h, p_base_ty_h) = self.get_exp_and_base_ty_handles(
pointer_id,
ctx,
&mut emitter,
&mut block,
body_idx,
)?;

log::trace!("\t\t\tlooking up value expr {:?}", value_id);
let v_lexp_handle =
get_expr_handle!(value_id, self.lookup_expression.lookup(value_id)?);

log::trace!("\t\t\tlooking up comparator expr {:?}", value_id);
let c_lexp_handle = get_expr_handle!(
comparator_id,
self.lookup_expression.lookup(comparator_id)?
);

// We know from the SPIR-V spec that the result type must be an integer
// scalar, and we'll need the type itself to get a handle to the atomic
// result struct.
let crate::TypeInner::Scalar(scalar) = ctx.module.types[p_base_ty_h].inner
else {
return Err(
crate::front::atomic_upgrade::Error::CompareExchangeNonScalarBaseType
.into(),
);
};

// Get a handle to the atomic result struct type.
let atomic_result_struct_ty_h = ctx.module.generate_predeclared_type(
crate::PredeclaredType::AtomicCompareExchangeWeakResult(scalar),
);

block.extend(emitter.finish(ctx.expressions));

// Create an expression for our atomic result
let atomic_lexp_handle = {
let expr = crate::Expression::AtomicResult {
ty: atomic_result_struct_ty_h,
comparison: true,
};
ctx.expressions.append(expr, span)
};

// Create an dot accessor to extract the value from the
// result struct __atomic_compare_exchange_result<T> and use that
// as the expression for the result_id
{
let expr = crate::Expression::AccessIndex {
base: atomic_lexp_handle,
index: 0,
};
let handle = ctx.expressions.append(expr, span);
// Use this dot accessor as the result id's expression
let _ = self.lookup_expression.insert(
result_id,
LookupExpression {
handle,
type_id: result_type_id,
block_id,
},
);
}

emitter.start(ctx.expressions);

// Create a statement for the op itself
let stmt = crate::Statement::Atomic {
pointer: p_exp_h,
fun: crate::AtomicFunction::Exchange {
compare: Some(c_lexp_handle),
},
value: v_lexp_handle,
result: Some(atomic_lexp_handle),
};
block.push(stmt, span);

// Store any associated global variables so we can upgrade their types later
self.upgrade_atomics
.insert(ctx.get_contained_global_variable(p_exp_h)?);
}
Op::AtomicExchange
| Op::AtomicIAdd
| Op::AtomicISub
Expand Down Expand Up @@ -5969,17 +6065,18 @@ mod test_atomic {
let m = crate::front::spv::parse_u8_slice(bytes, &Default::default()).unwrap();

let mut wgsl = String::new();
let mut should_panic = false;

for vflags in [
crate::valid::ValidationFlags::all(),
crate::valid::ValidationFlags::empty(),
for (vflags, name) in [
(crate::valid::ValidationFlags::empty(), "empty"),
(crate::valid::ValidationFlags::all(), "all"),
] {
log::info!("validating with flags - {name}");
let mut validator = crate::valid::Validator::new(vflags, Default::default());
match validator.validate(&m) {
Err(e) => {
log::error!("SPIR-V validation {}", e.emit_to_string(""));
should_panic = true;
log::info!("types: {:#?}", m.types);
panic!("validation error");
}
Ok(i) => {
wgsl = crate::back::wgsl::write_string(
Expand All @@ -5989,15 +6086,10 @@ mod test_atomic {
)
.unwrap();
log::info!("wgsl-out:\n{wgsl}");
break;
}
};
}

if should_panic {
panic!("validation error");
}

let m = match crate::front::wgsl::parse_str(&wgsl) {
Ok(m) => m,
Err(e) => {
Expand Down Expand Up @@ -6032,6 +6124,13 @@ mod test_atomic {
atomic_test(include_bytes!("../../../tests/in/spv/atomic_exchange.spv"));
}

#[test]
fn atomic_compare_exchange() {
atomic_test(include_bytes!(
"../../../tests/in/spv/atomic_compare_exchange.spv"
));
}

#[test]
fn atomic_i_decrement() {
atomic_test(include_bytes!(
Expand Down
Binary file added naga/tests/in/spv/atomic_compare_exchange.spv
Binary file not shown.
89 changes: 89 additions & 0 deletions naga/tests/in/spv/atomic_compare_exchange.spvasm
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
; SPIR-V
; Version: 1.5
; Generator: Google rspirv; 0
; Bound: 65
; Schema: 0
OpCapability Shader
OpCapability VulkanMemoryModel
OpMemoryModel Logical Vulkan
OpEntryPoint GLCompute %1 "stage::test_atomic_compare_exchange" %2 %3
OpExecutionMode %1 LocalSize 32 1 1
OpMemberDecorate %_struct_9 0 Offset 0
OpMemberDecorate %_struct_9 1 Offset 4
OpDecorate %_struct_10 Block
OpMemberDecorate %_struct_10 0 Offset 0
OpDecorate %2 Binding 0
OpDecorate %2 DescriptorSet 0
OpDecorate %3 NonWritable
OpDecorate %3 Binding 1
OpDecorate %3 DescriptorSet 0
%uint = OpTypeInt 32 0
%void = OpTypeVoid
%13 = OpTypeFunction %void
%bool = OpTypeBool
%uint_0 = OpConstant %uint 0
%uint_2 = OpConstant %uint 2
%false = OpConstantFalse %bool
%_ptr_StorageBuffer_uint = OpTypePointer StorageBuffer %uint
%uint_1 = OpConstant %uint 1
%_struct_9 = OpTypeStruct %uint %uint
%20 = OpUndef %_struct_9
%uint_3 = OpConstant %uint 3
%int = OpTypeInt 32 1
%23 = OpUndef %bool
%true = OpConstantTrue %bool
%_struct_10 = OpTypeStruct %uint
%_ptr_StorageBuffer__struct_10 = OpTypePointer StorageBuffer %_struct_10
%2 = OpVariable %_ptr_StorageBuffer__struct_10 StorageBuffer
%3 = OpVariable %_ptr_StorageBuffer__struct_10 StorageBuffer
%uint_256 = OpConstant %uint 256
%1 = OpFunction %void None %13
%27 = OpLabel
%28 = OpInBoundsAccessChain %_ptr_StorageBuffer_uint %2 %uint_0
%29 = OpInBoundsAccessChain %_ptr_StorageBuffer_uint %3 %uint_0
%30 = OpLoad %uint %29
%31 = OpCompositeConstruct %_struct_9 %uint_0 %30
OpBranch %32
%32 = OpLabel
%33 = OpPhi %_struct_9 %31 %27 %34 %35
OpLoopMerge %36 %35 None
OpBranch %37
%37 = OpLabel
%38 = OpCompositeExtract %uint %33 0
%39 = OpCompositeExtract %uint %33 1
%40 = OpULessThan %bool %38 %39
OpSelectionMerge %41 None
OpBranchConditional %40 %42 %43
%42 = OpLabel
%45 = OpIAdd %uint %38 %uint_1
%46 = OpCompositeInsert %_struct_9 %45 %33 0
%47 = OpCompositeConstruct %_struct_9 %uint_1 %38
OpBranch %41
%43 = OpLabel
%48 = OpCompositeInsert %_struct_9 %uint_0 %20 0
OpBranch %41
%41 = OpLabel
%34 = OpPhi %_struct_9 %46 %42 %33 %43
%49 = OpPhi %_struct_9 %47 %42 %48 %43
%50 = OpCompositeExtract %uint %49 0
%51 = OpCompositeExtract %uint %49 1
%52 = OpBitcast %int %50
OpSelectionMerge %53 None
OpSwitch %52 %54 0 %55 1 %56
%54 = OpLabel
OpBranch %53
%55 = OpLabel
OpBranch %53
%56 = OpLabel
%57 = OpAtomicCompareExchange %uint %28 %uint_2 %uint_256 %uint_256 %51 %uint_3
%58 = OpIEqual %bool %57 %uint_3
%64 = OpSelect %bool %58 %false %true
OpBranch %53
%53 = OpLabel
%63 = OpPhi %bool %23 %54 %false %55 %64 %56
OpBranch %35
%35 = OpLabel
OpBranchConditional %63 %32 %36
%36 = OpLabel
OpReturn
OpFunctionEnd

0 comments on commit 234b6dd

Please sign in to comment.