Skip to content

Commit

Permalink
[d3d12] get num_workgroups builtin working for indirect dispatches
Browse files Browse the repository at this point in the history
  • Loading branch information
teoxoy committed Jul 24, 2024
1 parent 6b4d100 commit fb99ea9
Show file tree
Hide file tree
Showing 7 changed files with 79 additions and 17 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,10 @@ Bottom level categories:
- Print requested and supported usages on `UnsupportedUsage` error. By @VladasZ in [#6007](https://github.com/gfx-rs/wgpu/pull/6007)
- Ensure safety of indirect dispatch. By @teoxoy in [#5714](https://github.com/gfx-rs/wgpu/pull/5714)

#### D3D12

- Get `num_workgroups` builtin working for indirect dispatches. By @teoxoy in [#5730](https://github.com/gfx-rs/wgpu/pull/5730)

## 22.0.0 (2024-07-17)

### Overview
Expand Down
8 changes: 3 additions & 5 deletions tests/tests/dispatch_workgroups_indirect.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use wgpu_test::{gpu_test, FailureCase, GpuTestConfiguration, TestParameters, TestingContext};
use wgpu_test::{gpu_test, GpuTestConfiguration, TestParameters, TestingContext};

/// Make sure that the num_workgroups builtin works properly (it requires a workaround on D3D12).
#[gpu_test]
Expand All @@ -12,8 +12,7 @@ static NUM_WORKGROUPS_BUILTIN: GpuTestConfiguration = GpuTestConfiguration::new(
.limits(wgpu::Limits {
max_push_constant_size: 4,
..wgpu::Limits::downlevel_defaults()
})
.expect_fail(FailureCase::backend(wgt::Backends::DX12)),
}),
)
.run_async(|ctx| async move {
let num_workgroups = [1, 2, 3];
Expand All @@ -34,8 +33,7 @@ static DISCARD_DISPATCH: GpuTestConfiguration = GpuTestConfiguration::new()
max_compute_workgroups_per_dimension: 10,
max_push_constant_size: 4,
..wgpu::Limits::downlevel_defaults()
})
.expect_fail(FailureCase::backend(wgt::Backends::DX12)),
}),
)
.run_async(|ctx| async move {
let max = ctx.device.limits().max_compute_workgroups_per_dimension;
Expand Down
3 changes: 2 additions & 1 deletion wgpu-core/src/device/resource.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2622,7 +2622,8 @@ impl<A: HalApi> Device<A> {

let hal_desc = hal::PipelineLayoutDescriptor {
label: desc.label.to_hal(self.instance_flags),
flags: hal::PipelineLayoutFlags::FIRST_VERTEX_INSTANCE,
flags: hal::PipelineLayoutFlags::FIRST_VERTEX_INSTANCE
| hal::PipelineLayoutFlags::NUM_WORK_GROUPS,
bind_group_layouts: &raw_bind_group_layouts,
push_constant_ranges: desc.push_constant_ranges.as_ref(),
};
Expand Down
21 changes: 13 additions & 8 deletions wgpu-core/src/indirect_validation.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use std::sync::atomic::AtomicBool;
use std::{num::NonZeroU64, sync::atomic::AtomicBool};

use thiserror::Error;

Expand Down Expand Up @@ -63,7 +63,7 @@ impl<A: HalApi> IndirectValidation<A> {

let src = format!("
@group(0) @binding(0)
var<storage, read_write> dst: array<u32, 3>;
var<storage, read_write> dst: array<u32, 6>;
@group(1) @binding(0)
var<storage, read> src: array<u32>;
struct OffsetPc {{
Expand All @@ -78,6 +78,9 @@ impl<A: HalApi> IndirectValidation<A> {
dst[0] = res.x;
dst[1] = res.y;
dst[2] = res.z;
dst[3] = res.x;
dst[4] = res.y;
dst[5] = res.z;
}}
");

Expand Down Expand Up @@ -123,6 +126,8 @@ impl<A: HalApi> IndirectValidation<A> {
}
})?;

const DST_BUFFER_SIZE: NonZeroU64 = unsafe { NonZeroU64::new_unchecked(4 * 3 * 2) };

let dst_bind_group_layout_desc = hal::BindGroupLayoutDescriptor {
label: None,
flags: hal::BindGroupLayoutFlags::empty(),
Expand All @@ -132,7 +137,7 @@ impl<A: HalApi> IndirectValidation<A> {
ty: wgt::BindingType::Buffer {
ty: wgt::BufferBindingType::Storage { read_only: false },
has_dynamic_offset: false,
min_binding_size: Some(std::num::NonZeroU64::new(4 * 3).unwrap()),
min_binding_size: Some(DST_BUFFER_SIZE),
},
count: None,
}],
Expand All @@ -152,7 +157,7 @@ impl<A: HalApi> IndirectValidation<A> {
ty: wgt::BindingType::Buffer {
ty: wgt::BufferBindingType::Storage { read_only: true },
has_dynamic_offset: true,
min_binding_size: Some(std::num::NonZeroU64::new(4 * 3).unwrap()),
min_binding_size: Some(NonZeroU64::new(4 * 3).unwrap()),
},
count: None,
}],
Expand Down Expand Up @@ -204,7 +209,7 @@ impl<A: HalApi> IndirectValidation<A> {

let dst_buffer_desc = hal::BufferDescriptor {
label: None,
size: 4 * 3,
size: DST_BUFFER_SIZE.get(),
usage: hal::BufferUses::INDIRECT | hal::BufferUses::STORAGE_READ_WRITE,
memory_flags: hal::MemoryFlags::empty(),
};
Expand All @@ -224,7 +229,7 @@ impl<A: HalApi> IndirectValidation<A> {
buffers: &[hal::BufferBinding {
buffer: &dst_buffer_0,
offset: 0,
size: Some(std::num::NonZeroU64::new(4 * 3).unwrap()),
size: Some(DST_BUFFER_SIZE),
}],
samplers: &[],
textures: &[],
Expand All @@ -247,7 +252,7 @@ impl<A: HalApi> IndirectValidation<A> {
buffers: &[hal::BufferBinding {
buffer: &dst_buffer_1,
offset: 0,
size: Some(std::num::NonZeroU64::new(4 * 3).unwrap()),
size: Some(DST_BUFFER_SIZE),
}],
samplers: &[],
textures: &[],
Expand Down Expand Up @@ -292,7 +297,7 @@ impl<A: HalApi> IndirectValidation<A> {
buffers: &[hal::BufferBinding {
buffer,
offset: 0,
size: Some(std::num::NonZeroU64::new(binding_size).unwrap()),
size: Some(NonZeroU64::new(binding_size).unwrap()),
}],
samplers: &[],
textures: &[],
Expand Down
12 changes: 9 additions & 3 deletions wgpu-hal/src/dx12/command.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1193,11 +1193,17 @@ impl crate::CommandEncoder for super::CommandEncoder {
self.list.as_ref().unwrap().dispatch(count);
}
unsafe fn dispatch_indirect(&mut self, buffer: &super::Buffer, offset: wgt::BufferAddress) {
self.prepare_dispatch([0; 3]);
//TODO: update special constants indirectly
self.update_root_elements();
let cmd_signature = if let Some(cmd_signatures) =
self.pass.layout.special_constants_cmd_signatures.as_mut()
{
cmd_signatures.dispatch.as_mut_ptr()
} else {
self.shared.cmd_signatures.dispatch.as_mut_ptr()
};
unsafe {
self.list.as_ref().unwrap().ExecuteIndirect(
self.shared.cmd_signatures.dispatch.as_mut_ptr(),
cmd_signature,
1,
buffer.resource.as_mut_ptr(),
offset,
Expand Down
45 changes: 45 additions & 0 deletions wgpu-hal/src/dx12/device.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1105,6 +1105,50 @@ impl crate::Device for super::Device {
.create_root_signature(blob, 0)
.into_device_result("Root signature creation")?;

let special_constants_cmd_signatures =
if let Some(root_index) = special_constants_root_index {
Some(super::CommandSignatures {
draw: self
.raw
.create_command_signature(
raw.clone(),
&[
d3d12::IndirectArgument::constant(root_index, 0, 3),
d3d12::IndirectArgument::draw(),
],
12 + mem::size_of::<wgt::DrawIndirectArgs>() as u32,
0,
)
.into_device_result("Command (draw) signature creation")?,
draw_indexed: self
.raw
.create_command_signature(
raw.clone(),
&[
d3d12::IndirectArgument::constant(root_index, 0, 3),
d3d12::IndirectArgument::draw_indexed(),
],
12 + mem::size_of::<wgt::DrawIndexedIndirectArgs>() as u32,
0,
)
.into_device_result("Command (draw_indexed) signature creation")?,
dispatch: self
.raw
.create_command_signature(
raw.clone(),
&[
d3d12::IndirectArgument::constant(root_index, 0, 3),
d3d12::IndirectArgument::dispatch(),
],
12 + mem::size_of::<wgt::DispatchIndirectArgs>() as u32,
0,
)
.into_device_result("Command (dispatch) signature creation")?,
})
} else {
None
};

log::debug!("\traw = {:?}", raw);

if let Some(label) = desc.label {
Expand All @@ -1119,6 +1163,7 @@ impl crate::Device for super::Device {
signature: raw,
total_root_elements: parameters.len() as super::RootIndex,
special_constants_root_index,
special_constants_cmd_signatures,
root_constant_info,
},
bind_group_infos,
Expand Down
3 changes: 3 additions & 0 deletions wgpu-hal/src/dx12/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -227,6 +227,7 @@ struct Idler {
event: d3d12::Event,
}

#[derive(Debug, Clone)]
struct CommandSignatures {
draw: d3d12::CommandSignature,
draw_indexed: d3d12::CommandSignature,
Expand Down Expand Up @@ -345,6 +346,7 @@ impl PassState {
signature: d3d12::RootSignature::null(),
total_root_elements: 0,
special_constants_root_index: None,
special_constants_cmd_signatures: None,
root_constant_info: None,
},
root_elements: [RootElement::Empty; MAX_ROOT_ELEMENTS],
Expand Down Expand Up @@ -556,6 +558,7 @@ struct PipelineLayoutShared {
signature: d3d12::RootSignature,
total_root_elements: RootIndex,
special_constants_root_index: Option<RootIndex>,
special_constants_cmd_signatures: Option<CommandSignatures>,
root_constant_info: Option<RootConstantInfo>,
}

Expand Down

0 comments on commit fb99ea9

Please sign in to comment.