Skip to content

Commit

Permalink
[d3d12] get num_workgroups builtin working for indirect dispatches
Browse files Browse the repository at this point in the history
  • Loading branch information
teoxoy committed Oct 14, 2024
1 parent bc94fab commit 03a396d
Show file tree
Hide file tree
Showing 7 changed files with 149 additions and 46 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,7 @@ By @bradwerth [#6216](https://github.com/gfx-rs/wgpu/pull/6216).
#### DX12

- Replace `winapi` code to use the `windows` crate. By @MarijnS95 in [#5956](https://github.com/gfx-rs/wgpu/pull/5956) and [#6173](https://github.com/gfx-rs/wgpu/pull/6173)
- Get `num_workgroups` builtin working for indirect dispatches. By @teoxoy in [#5730](https://github.com/gfx-rs/wgpu/pull/5730)

#### HAL

Expand Down
8 changes: 3 additions & 5 deletions tests/tests/dispatch_workgroups_indirect.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use wgpu_test::{gpu_test, FailureCase, GpuTestConfiguration, TestParameters, TestingContext};
use wgpu_test::{gpu_test, GpuTestConfiguration, TestParameters, TestingContext};

/// Make sure that the num_workgroups builtin works properly (it requires a workaround on D3D12).
#[gpu_test]
Expand All @@ -12,8 +12,7 @@ static NUM_WORKGROUPS_BUILTIN: GpuTestConfiguration = GpuTestConfiguration::new(
.limits(wgpu::Limits {
max_push_constant_size: 4,
..wgpu::Limits::downlevel_defaults()
})
.expect_fail(FailureCase::backend(wgt::Backends::DX12)),
}),
)
.run_async(|ctx| async move {
let num_workgroups = [1, 2, 3];
Expand All @@ -34,8 +33,7 @@ static DISCARD_DISPATCH: GpuTestConfiguration = GpuTestConfiguration::new()
max_compute_workgroups_per_dimension: 10,
max_push_constant_size: 4,
..wgpu::Limits::downlevel_defaults()
})
.expect_fail(FailureCase::backend(wgt::Backends::DX12)),
}),
)
.run_async(|ctx| async move {
let max = ctx.device.limits().max_compute_workgroups_per_dimension;
Expand Down
3 changes: 2 additions & 1 deletion wgpu-core/src/device/resource.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2638,7 +2638,8 @@ impl Device {

let hal_desc = hal::PipelineLayoutDescriptor {
label: desc.label.to_hal(self.instance_flags),
flags: hal::PipelineLayoutFlags::FIRST_VERTEX_INSTANCE,
flags: hal::PipelineLayoutFlags::FIRST_VERTEX_INSTANCE
| hal::PipelineLayoutFlags::NUM_WORK_GROUPS,
bind_group_layouts: &raw_bind_group_layouts,
push_constant_ranges: desc.push_constant_ranges.as_ref(),
};
Expand Down
31 changes: 21 additions & 10 deletions wgpu-core/src/indirect_validation.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use std::sync::atomic::AtomicBool;
use std::{mem::size_of, num::NonZeroU64, sync::atomic::AtomicBool};

use thiserror::Error;

Expand Down Expand Up @@ -63,7 +63,7 @@ impl IndirectValidation {
let src = format!(
"
@group(0) @binding(0)
var<storage, read_write> dst: array<u32, 3>;
var<storage, read_write> dst: array<u32, 6>;
@group(1) @binding(0)
var<storage, read> src: array<u32>;
struct OffsetPc {{
Expand All @@ -80,14 +80,25 @@ impl IndirectValidation {
src.y > max_compute_workgroups_per_dimension ||
src.z > max_compute_workgroups_per_dimension
) {{
dst = array(0u, 0u, 0u);
dst = array(0u, 0u, 0u, 0u, 0u, 0u);
}} else {{
dst = array(src.x, src.y, src.z);
dst = array(src.x, src.y, src.z, src.x, src.y, src.z);
}}
}}
"
);

// SAFETY: The value we are passing to `new_unchecked` is not zero, so this is safe.
const SRC_BUFFER_SIZE: NonZeroU64 =
unsafe { NonZeroU64::new_unchecked(size_of::<u32>() as u64 * 3) };

// SAFETY: The value we are passing to `new_unchecked` is not zero, so this is safe.
const DST_BUFFER_SIZE: NonZeroU64 = unsafe {
NonZeroU64::new_unchecked(
SRC_BUFFER_SIZE.get() * 2, // From above: `dst: array<u32, 6>`
)
};

let module = naga::front::wgsl::parse_str(&src).map_err(|inner| {
CreateShaderModuleError::Parsing(naga::error::ShaderError {
source: src.clone(),
Expand Down Expand Up @@ -139,7 +150,7 @@ impl IndirectValidation {
ty: wgt::BindingType::Buffer {
ty: wgt::BufferBindingType::Storage { read_only: false },
has_dynamic_offset: false,
min_binding_size: Some(std::num::NonZeroU64::new(4 * 3).unwrap()),
min_binding_size: Some(DST_BUFFER_SIZE),
},
count: None,
}],
Expand All @@ -159,7 +170,7 @@ impl IndirectValidation {
ty: wgt::BindingType::Buffer {
ty: wgt::BufferBindingType::Storage { read_only: true },
has_dynamic_offset: true,
min_binding_size: Some(std::num::NonZeroU64::new(4 * 3).unwrap()),
min_binding_size: Some(SRC_BUFFER_SIZE),
},
count: None,
}],
Expand Down Expand Up @@ -217,7 +228,7 @@ impl IndirectValidation {

let dst_buffer_desc = hal::BufferDescriptor {
label: None,
size: 4 * 3,
size: DST_BUFFER_SIZE.get(),
usage: hal::BufferUses::INDIRECT | hal::BufferUses::STORAGE_READ_WRITE,
memory_flags: hal::MemoryFlags::empty(),
};
Expand All @@ -237,7 +248,7 @@ impl IndirectValidation {
buffers: &[hal::BufferBinding {
buffer: dst_buffer_0.as_ref(),
offset: 0,
size: Some(std::num::NonZeroU64::new(4 * 3).unwrap()),
size: Some(DST_BUFFER_SIZE),
}],
samplers: &[],
textures: &[],
Expand All @@ -260,7 +271,7 @@ impl IndirectValidation {
buffers: &[hal::BufferBinding {
buffer: dst_buffer_1.as_ref(),
offset: 0,
size: Some(std::num::NonZeroU64::new(4 * 3).unwrap()),
size: Some(DST_BUFFER_SIZE),
}],
samplers: &[],
textures: &[],
Expand Down Expand Up @@ -305,7 +316,7 @@ impl IndirectValidation {
buffers: &[hal::BufferBinding {
buffer,
offset: 0,
size: Some(std::num::NonZeroU64::new(binding_size).unwrap()),
size: Some(NonZeroU64::new(binding_size).unwrap()),
}],
samplers: &[],
textures: &[],
Expand Down
12 changes: 9 additions & 3 deletions wgpu-hal/src/dx12/command.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1210,11 +1210,17 @@ impl crate::CommandEncoder for super::CommandEncoder {
}

unsafe fn dispatch_indirect(&mut self, buffer: &super::Buffer, offset: wgt::BufferAddress) {
self.prepare_dispatch([0; 3]);
//TODO: update special constants indirectly
self.update_root_elements();
let cmd_signature = &self
.pass
.layout
.special_constants_cmd_signatures
.as_ref()
.unwrap_or_else(|| &self.shared.cmd_signatures)
.dispatch;
unsafe {
self.list.as_ref().unwrap().ExecuteIndirect(
&self.shared.cmd_signatures.dispatch,
cmd_signature,
1,
&buffer.resource,
offset,
Expand Down
134 changes: 107 additions & 27 deletions wgpu-hal/src/dx12/device.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use std::{
ffi,
mem::{self, size_of},
mem::{self, size_of, size_of_val},
num::NonZeroU32,
ptr,
sync::Arc,
Expand Down Expand Up @@ -94,52 +94,32 @@ impl super::Device {
let capacity_views = limits.max_non_sampler_bindings as u64;
let capacity_samplers = 2_048;

fn create_command_signature(
raw: &Direct3D12::ID3D12Device,
byte_stride: usize,
arguments: &[Direct3D12::D3D12_INDIRECT_ARGUMENT_DESC],
node_mask: u32,
) -> Result<Direct3D12::ID3D12CommandSignature, crate::DeviceError> {
let mut signature = None;
unsafe {
raw.CreateCommandSignature(
&Direct3D12::D3D12_COMMAND_SIGNATURE_DESC {
ByteStride: byte_stride as u32,
NumArgumentDescs: arguments.len() as u32,
pArgumentDescs: arguments.as_ptr(),
NodeMask: node_mask,
},
None,
&mut signature,
)
}
.into_device_result("Command signature creation")?;
signature.ok_or(crate::DeviceError::Unexpected)
}

let shared = super::DeviceShared {
zero_buffer,
cmd_signatures: super::CommandSignatures {
draw: create_command_signature(
draw: Self::create_command_signature(
&raw,
None,
size_of::<wgt::DrawIndirectArgs>(),
&[Direct3D12::D3D12_INDIRECT_ARGUMENT_DESC {
Type: Direct3D12::D3D12_INDIRECT_ARGUMENT_TYPE_DRAW,
..Default::default()
}],
0,
)?,
draw_indexed: create_command_signature(
draw_indexed: Self::create_command_signature(
&raw,
None,
size_of::<wgt::DrawIndexedIndirectArgs>(),
&[Direct3D12::D3D12_INDIRECT_ARGUMENT_DESC {
Type: Direct3D12::D3D12_INDIRECT_ARGUMENT_TYPE_DRAW_INDEXED,
..Default::default()
}],
0,
)?,
dispatch: create_command_signature(
dispatch: Self::create_command_signature(
&raw,
None,
size_of::<wgt::DispatchIndirectArgs>(),
&[Direct3D12::D3D12_INDIRECT_ARGUMENT_DESC {
Type: Direct3D12::D3D12_INDIRECT_ARGUMENT_TYPE_DISPATCH,
Expand Down Expand Up @@ -214,6 +194,30 @@ impl super::Device {
})
}

fn create_command_signature(
raw: &Direct3D12::ID3D12Device,
root_signature: Option<&Direct3D12::ID3D12RootSignature>,
byte_stride: usize,
arguments: &[Direct3D12::D3D12_INDIRECT_ARGUMENT_DESC],
node_mask: u32,
) -> Result<Direct3D12::ID3D12CommandSignature, crate::DeviceError> {
let mut signature = None;
unsafe {
raw.CreateCommandSignature(
&Direct3D12::D3D12_COMMAND_SIGNATURE_DESC {
ByteStride: byte_stride as u32,
NumArgumentDescs: arguments.len() as u32,
pArgumentDescs: arguments.as_ptr(),
NodeMask: node_mask,
},
root_signature,
&mut signature,
)
}
.into_device_result("Command signature creation")?;
signature.ok_or(crate::DeviceError::Unexpected)
}

// Blocks until the dedicated present queue is finished with all of its work.
//
// Once this method completes, the surface is able to be resized or deleted.
Expand Down Expand Up @@ -1119,6 +1123,81 @@ impl crate::Device for super::Device {
}
.into_device_result("Root signature creation")?;

let special_constants_cmd_signatures = if let Some(root_index) =
special_constants_root_index
{
let constant_indirect_argument_desc = Direct3D12::D3D12_INDIRECT_ARGUMENT_DESC {
Type: Direct3D12::D3D12_INDIRECT_ARGUMENT_TYPE_CONSTANT,
Anonymous: Direct3D12::D3D12_INDIRECT_ARGUMENT_DESC_0 {
Constant: Direct3D12::D3D12_INDIRECT_ARGUMENT_DESC_0_1 {
RootParameterIndex: root_index,
DestOffsetIn32BitValues: 0,
Num32BitValuesToSet: 3,
},
},
};
let special_constant_buffer_args_len = {
// Hack: construct a dummy value of the special constants buffer value we need to
// fill, and calculate the size of each member.
let super::RootElement::SpecialConstantBuffer {
first_vertex,
first_instance,
other,
} = (super::RootElement::SpecialConstantBuffer {
first_vertex: 0,
first_instance: 0,
other: 0,
})
else {
unreachable!();
};
size_of_val(&first_vertex) + size_of_val(&first_instance) + size_of_val(&other)
};
Some(super::CommandSignatures {
draw: Self::create_command_signature(
&self.raw,
Some(&raw),
special_constant_buffer_args_len + size_of::<wgt::DrawIndirectArgs>(),
&[
constant_indirect_argument_desc,
Direct3D12::D3D12_INDIRECT_ARGUMENT_DESC {
Type: Direct3D12::D3D12_INDIRECT_ARGUMENT_TYPE_DRAW,
..Default::default()
},
],
0,
)?,
draw_indexed: Self::create_command_signature(
&self.raw,
Some(&raw),
special_constant_buffer_args_len + size_of::<wgt::DrawIndexedIndirectArgs>(),
&[
constant_indirect_argument_desc,
Direct3D12::D3D12_INDIRECT_ARGUMENT_DESC {
Type: Direct3D12::D3D12_INDIRECT_ARGUMENT_TYPE_DRAW_INDEXED,
..Default::default()
},
],
0,
)?,
dispatch: Self::create_command_signature(
&self.raw,
Some(&raw),
special_constant_buffer_args_len + size_of::<wgt::DispatchIndirectArgs>(),
&[
constant_indirect_argument_desc,
Direct3D12::D3D12_INDIRECT_ARGUMENT_DESC {
Type: Direct3D12::D3D12_INDIRECT_ARGUMENT_TYPE_DISPATCH,
..Default::default()
},
],
0,
)?,
})
} else {
None
};

if let Some(label) = desc.label {
unsafe { raw.SetName(&windows::core::HSTRING::from(label)) }
.into_device_result("SetName")?;
Expand All @@ -1131,6 +1210,7 @@ impl crate::Device for super::Device {
signature: Some(raw),
total_root_elements: parameters.len() as super::RootIndex,
special_constants_root_index,
special_constants_cmd_signatures,
root_constant_info,
},
bind_group_infos,
Expand Down
6 changes: 6 additions & 0 deletions wgpu-hal/src/dx12/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -564,6 +564,7 @@ struct Idler {
event: Event,
}

#[derive(Debug, Clone)]
struct CommandSignatures {
draw: Direct3D12::ID3D12CommandSignature,
draw_indexed: Direct3D12::ID3D12CommandSignature,
Expand Down Expand Up @@ -636,8 +637,11 @@ enum RootElement {
Empty,
Constant,
SpecialConstantBuffer {
/// The first vertex in an indirect draw call, _or_ the `x` of a compute dispatch.
first_vertex: i32,
/// The first instance in an indirect draw call, _or_ the `y` of a compute dispatch.
first_instance: u32,
/// Unused in an indirect draw call, _or_ the `z` of a compute dispatch.
other: u32,
},
/// Descriptor table.
Expand Down Expand Up @@ -682,6 +686,7 @@ impl PassState {
signature: None,
total_root_elements: 0,
special_constants_root_index: None,
special_constants_cmd_signatures: None,
root_constant_info: None,
},
root_elements: [RootElement::Empty; MAX_ROOT_ELEMENTS],
Expand Down Expand Up @@ -919,6 +924,7 @@ struct PipelineLayoutShared {
signature: Option<Direct3D12::ID3D12RootSignature>,
total_root_elements: RootIndex,
special_constants_root_index: Option<RootIndex>,
special_constants_cmd_signatures: Option<CommandSignatures>,
root_constant_info: Option<RootConstantInfo>,
}

Expand Down

0 comments on commit 03a396d

Please sign in to comment.