Skip to content

Commit

Permalink
Added required_subgroup_size to PipelineShaderStageCreateInfo (#2235)
Browse files Browse the repository at this point in the history
* Added required_subgroup_size to PipelineShaderStageCreateInfo

* Added validation errors.

* Fixed error msgs / vuids.

* ComputeShaderExecution for validating local_size.

* WorkgroupSizeId reflection.

* contains_enum

* Reworked ComputeShaderExecution.

* panic msgs.

* workgroup size validation

* unused import

* fixed test deprecated fn

* catch workgroup size overflow

* EntryPointInfo::local_size docs

* comments

* typo + error msg
  • Loading branch information
charles-r-earp authored Aug 18, 2023
1 parent 4133a3b commit ee4e308
Show file tree
Hide file tree
Showing 6 changed files with 676 additions and 18 deletions.
44 changes: 43 additions & 1 deletion vulkano-shaders/src/entry_point.rs
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,49 @@ fn write_shader_execution(execution: &ShaderExecution) -> TokenStream {
)
}
}
ShaderExecution::Compute => quote! { ::vulkano::shader::ShaderExecution::Compute },
ShaderExecution::Compute(execution) => {
use ::quote::ToTokens;
use ::vulkano::shader::{ComputeShaderExecution, LocalSize};

struct LocalSizeToTokens(LocalSize);

impl ToTokens for LocalSizeToTokens {
fn to_tokens(&self, tokens: &mut TokenStream) {
match self.0 {
LocalSize::Literal(literal) => quote! {
::vulkano::shader::LocalSize::Literal(#literal)
},
LocalSize::SpecId(id) => quote! {
::vulkano::shader::LocalSize::SpecId(#id)
},
}
.to_tokens(tokens);
}
}

match execution {
ComputeShaderExecution::LocalSize([x, y, z]) => {
let [x, y, z] = [
LocalSizeToTokens(*x),
LocalSizeToTokens(*y),
LocalSizeToTokens(*z),
];
quote! { ::vulkano::shader::ShaderExecution::Compute(
::vulkano::shader::ComputeShaderExecution::LocalSize([#x, #y, #z])
) }
}
ComputeShaderExecution::LocalSizeId([x, y, z]) => {
let [x, y, z] = [
LocalSizeToTokens(*x),
LocalSizeToTokens(*y),
LocalSizeToTokens(*z),
];
quote! { ::vulkano::shader::ShaderExecution::Compute(
::vulkano::shader::ComputeShaderExecution::LocalSizeId([#x, #y, #z])
) }
}
}
}
ShaderExecution::RayGeneration => {
quote! { ::vulkano::shader::ShaderExecution::RayGeneration }
}
Expand Down
151 changes: 148 additions & 3 deletions vulkano/src/pipeline/compute.rs
Original file line number Diff line number Diff line change
Expand Up @@ -74,11 +74,9 @@ impl ComputePipeline {
if let Some(cache) = &cache {
assert_eq!(device, cache.device().as_ref());
}

create_info
.validate(device)
.map_err(|err| err.add_context("create_info"))?;

Ok(())
}

Expand All @@ -100,12 +98,14 @@ impl ComputePipeline {
let specialization_info_vk;
let specialization_map_entries_vk: Vec<_>;
let mut specialization_data_vk: Vec<u8>;
let required_subgroup_size_create_info;

{
let &PipelineShaderStageCreateInfo {
flags,
ref entry_point,
ref specialization_info,
ref required_subgroup_size,
_ne: _,
} = stage;

Expand Down Expand Up @@ -135,7 +135,20 @@ impl ComputePipeline {
data_size: specialization_data_vk.len(),
p_data: specialization_data_vk.as_ptr() as *const _,
};
required_subgroup_size_create_info =
required_subgroup_size.map(|required_subgroup_size| {
ash::vk::PipelineShaderStageRequiredSubgroupSizeCreateInfo {
required_subgroup_size,
..Default::default()
}
});
stage_vk = ash::vk::PipelineShaderStageCreateInfo {
p_next: required_subgroup_size_create_info.as_ref().map_or(
ptr::null(),
|required_subgroup_size_create_info| {
required_subgroup_size_create_info as *const _ as _
},
),
flags: flags.into(),
stage: ShaderStage::from(&entry_point_info.execution).into(),
module: entry_point.module().handle(),
Expand Down Expand Up @@ -333,12 +346,13 @@ impl ComputePipelineCreateInfo {
flags: _,
ref entry_point,
specialization_info: _,
required_subgroup_size: _vk,
_ne: _,
} = &stage;

let entry_point_info = entry_point.info();

if !matches!(entry_point_info.execution, ShaderExecution::Compute) {
if !matches!(entry_point_info.execution, ShaderExecution::Compute(_)) {
return Err(Box::new(ValidationError {
context: "stage.entry_point".into(),
problem: "is not a `ShaderStage::Compute` entry point".into(),
Expand Down Expand Up @@ -514,4 +528,135 @@ mod tests {
let data_buffer_content = data_buffer.read().unwrap();
assert_eq!(*data_buffer_content, 0x12345678);
}

#[test]
fn required_subgroup_size() {
// This test checks whether required_subgroup_size works.
// It executes a single compute shader (one invocation) that writes the subgroup size
// to a buffer. The buffer content is then checked for the right value.

let (device, queue) = gfx_dev_and_queue!(subgroup_size_control);

let cs = unsafe {
/*
#version 450
#extension GL_KHR_shader_subgroup_basic: enable
layout(local_size_x = 128, local_size_y = 1, local_size_z = 1) in;
layout(set = 0, binding = 0) buffer Output {
uint write;
} write;
void main() {
if (gl_GlobalInvocationID.x == 0) {
write.write = gl_SubgroupSize;
}
}
*/
const MODULE: [u32; 246] = [
119734787, 65536, 851978, 30, 0, 131089, 1, 131089, 61, 393227, 1, 1280527431,
1685353262, 808793134, 0, 196622, 0, 1, 458767, 5, 4, 1852399981, 0, 9, 23, 393232,
4, 17, 128, 1, 1, 196611, 2, 450, 655364, 1197427783, 1279741775, 1885560645,
1953718128, 1600482425, 1701734764, 1919509599, 1769235301, 25974, 524292,
1197427783, 1279741775, 1852399429, 1685417059, 1768185701, 1952671090, 6649449,
589828, 1264536647, 1935626824, 1701077352, 1970495346, 1869768546, 1650421877,
1667855201, 0, 262149, 4, 1852399981, 0, 524293, 9, 1197436007, 1633841004,
1986939244, 1952539503, 1231974249, 68, 262149, 18, 1886680399, 29813, 327686, 18,
0, 1953067639, 101, 262149, 20, 1953067639, 101, 393221, 23, 1398762599,
1919378037, 1399879023, 6650473, 262215, 9, 11, 28, 327752, 18, 0, 35, 0, 196679,
18, 3, 262215, 20, 34, 0, 262215, 20, 33, 0, 196679, 23, 0, 262215, 23, 11, 36,
196679, 24, 0, 262215, 29, 11, 25, 131091, 2, 196641, 3, 2, 262165, 6, 32, 0,
262167, 7, 6, 3, 262176, 8, 1, 7, 262203, 8, 9, 1, 262187, 6, 10, 0, 262176, 11, 1,
6, 131092, 14, 196638, 18, 6, 262176, 19, 2, 18, 262203, 19, 20, 2, 262165, 21, 32,
1, 262187, 21, 22, 0, 262203, 11, 23, 1, 262176, 25, 2, 6, 262187, 6, 27, 128,
262187, 6, 28, 1, 393260, 7, 29, 27, 28, 28, 327734, 2, 4, 0, 3, 131320, 5, 327745,
11, 12, 9, 10, 262205, 6, 13, 12, 327850, 14, 15, 13, 10, 196855, 17, 0, 262394,
15, 16, 17, 131320, 16, 262205, 6, 24, 23, 327745, 25, 26, 20, 22, 196670, 26, 24,
131321, 17, 131320, 17, 65789, 65592,
];
let module =
ShaderModule::new(device.clone(), ShaderModuleCreateInfo::new(&MODULE)).unwrap();
module.entry_point("main").unwrap()
};

let properties = device.physical_device().properties();
let subgroup_size = properties.min_subgroup_size.unwrap_or(1);

let pipeline = {
let stage = PipelineShaderStageCreateInfo {
required_subgroup_size: Some(subgroup_size),
..PipelineShaderStageCreateInfo::new(cs)
};
let layout = PipelineLayout::new(
device.clone(),
PipelineDescriptorSetLayoutCreateInfo::from_stages([&stage])
.into_pipeline_layout_create_info(device.clone())
.unwrap(),
)
.unwrap();
ComputePipeline::new(
device.clone(),
None,
ComputePipelineCreateInfo::stage_layout(stage, layout),
)
.unwrap()
};

let memory_allocator = StandardMemoryAllocator::new_default(device.clone());
let data_buffer = Buffer::from_data(
&memory_allocator,
BufferCreateInfo {
usage: BufferUsage::STORAGE_BUFFER,
..Default::default()
},
AllocationCreateInfo {
memory_type_filter: MemoryTypeFilter::PREFER_DEVICE
| MemoryTypeFilter::HOST_RANDOM_ACCESS,
..Default::default()
},
0,
)
.unwrap();

let ds_allocator = StandardDescriptorSetAllocator::new(device.clone());
let set = PersistentDescriptorSet::new(
&ds_allocator,
pipeline.layout().set_layouts().get(0).unwrap().clone(),
[WriteDescriptorSet::buffer(0, data_buffer.clone())],
[],
)
.unwrap();

let cb_allocator = StandardCommandBufferAllocator::new(device.clone(), Default::default());
let mut cbb = AutoCommandBufferBuilder::primary(
&cb_allocator,
queue.queue_family_index(),
CommandBufferUsage::OneTimeSubmit,
)
.unwrap();
cbb.bind_pipeline_compute(pipeline.clone())
.unwrap()
.bind_descriptor_sets(
PipelineBindPoint::Compute,
pipeline.layout().clone(),
0,
set,
)
.unwrap()
.dispatch([128, 1, 1])
.unwrap();
let cb = cbb.build().unwrap();

let future = now(device)
.then_execute(queue, cb)
.unwrap()
.then_signal_fence_and_flush()
.unwrap();
future.wait(None).unwrap();

let data_buffer_content = data_buffer.read().unwrap();
assert_eq!(*data_buffer_content, subgroup_size);
}
}
21 changes: 19 additions & 2 deletions vulkano/src/pipeline/graphics/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,6 @@ impl GraphicsPipeline {
create_info
.validate(device)
.map_err(|err| err.add_context("create_info"))?;

Ok(())
}

Expand Down Expand Up @@ -204,6 +203,8 @@ impl GraphicsPipeline {
specialization_info_vk: ash::vk::SpecializationInfo,
specialization_map_entries_vk: Vec<ash::vk::SpecializationMapEntry>,
specialization_data_vk: Vec<u8>,
required_subgroup_size_create_info:
Option<ash::vk::PipelineShaderStageRequiredSubgroupSizeCreateInfo>,
}

let (mut stages_vk, mut per_stage_vk): (SmallVec<[_; 5]>, SmallVec<[_; 5]>) = stages
Expand All @@ -213,6 +214,7 @@ impl GraphicsPipeline {
flags,
ref entry_point,
ref specialization_info,
ref required_subgroup_size,
_ne: _,
} = stage;

Expand All @@ -235,7 +237,13 @@ impl GraphicsPipeline {
}
})
.collect();

let required_subgroup_size_create_info =
required_subgroup_size.map(|required_subgroup_size| {
ash::vk::PipelineShaderStageRequiredSubgroupSizeCreateInfo {
required_subgroup_size,
..Default::default()
}
});
(
ash::vk::PipelineShaderStageCreateInfo {
flags: flags.into(),
Expand All @@ -255,6 +263,7 @@ impl GraphicsPipeline {
},
specialization_map_entries_vk,
specialization_data_vk,
required_subgroup_size_create_info,
},
)
})
Expand All @@ -267,10 +276,17 @@ impl GraphicsPipeline {
specialization_info_vk,
specialization_map_entries_vk,
specialization_data_vk,
required_subgroup_size_create_info,
},
) in (stages_vk.iter_mut()).zip(per_stage_vk.iter_mut())
{
*stage_vk = ash::vk::PipelineShaderStageCreateInfo {
p_next: required_subgroup_size_create_info.as_ref().map_or(
ptr::null(),
|required_subgroup_size_create_info| {
required_subgroup_size_create_info as *const _ as _
},
),
p_name: name_vk.as_ptr(),
p_specialization_info: specialization_info_vk,
..*stage_vk
Expand Down Expand Up @@ -2420,6 +2436,7 @@ impl GraphicsPipelineCreateInfo {
flags: _,
ref entry_point,
specialization_info: _,
required_subgroup_size: _vk,
_ne: _,
} = stage;

Expand Down
Loading

0 comments on commit ee4e308

Please sign in to comment.