Skip to content

Commit

Permalink
move command buffer resolving in Global's methods
Browse files Browse the repository at this point in the history
  • Loading branch information
teoxoy committed Jul 3, 2024
1 parent a9c74f4 commit e26d2d7
Show file tree
Hide file tree
Showing 7 changed files with 169 additions and 94 deletions.
24 changes: 20 additions & 4 deletions wgpu-core/src/command/clear.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use std::{ops::Range, sync::Arc};
use crate::device::trace::Command as TraceCommand;
use crate::{
api_log,
command::CommandBuffer,
command::CommandEncoderError,
device::DeviceError,
get_lowest_common_denom,
global::Global,
Expand Down Expand Up @@ -76,7 +76,7 @@ whereas subesource range specified start {subresource_base_array_layer} and coun
#[error(transparent)]
Device(#[from] DeviceError),
#[error(transparent)]
CommandEncoderError(#[from] super::CommandEncoderError),
CommandEncoderError(#[from] CommandEncoderError),
}

impl Global {
Expand All @@ -92,7 +92,15 @@ impl Global {

let hub = A::hub(self);

let cmd_buf = CommandBuffer::get_encoder(hub, command_encoder_id)?;
let cmd_buf = match hub
.command_buffers
.get(command_encoder_id.into_command_buffer_id())
{
Ok(cmd_buf) => cmd_buf,
Err(_) => return Err(CommandEncoderError::Invalid.into()),
};
cmd_buf.check_recording()?;

let mut cmd_buf_data = cmd_buf.data.lock();
let cmd_buf_data = cmd_buf_data.as_mut().unwrap();

Expand Down Expand Up @@ -176,7 +184,15 @@ impl Global {

let hub = A::hub(self);

let cmd_buf = CommandBuffer::get_encoder(hub, command_encoder_id)?;
let cmd_buf = match hub
.command_buffers
.get(command_encoder_id.into_command_buffer_id())
{
Ok(cmd_buf) => cmd_buf,
Err(_) => return Err(CommandEncoderError::Invalid.into()),
};
cmd_buf.check_recording()?;

let mut cmd_buf_data = cmd_buf.data.lock();
let cmd_buf_data = cmd_buf_data.as_mut().unwrap();

Expand Down
63 changes: 36 additions & 27 deletions wgpu-core/src/command/compute.rs
Original file line number Diff line number Diff line change
Expand Up @@ -301,35 +301,40 @@ impl Global {
timestamp_writes: None, // Handle only once we resolved the encoder.
};

match CommandBuffer::lock_encoder(hub, encoder_id) {
Ok(cmd_buf) => {
arc_desc.timestamp_writes = if let Some(tw) = desc.timestamp_writes {
let Ok(query_set) = hub.query_sets.get(tw.query_set) else {
return (
ComputePass::new(None, arc_desc),
Some(CommandEncoderError::InvalidTimestampWritesQuerySetId(
tw.query_set,
)),
);
};
let make_err = |e, arc_desc| (ComputePass::new(None, arc_desc), Some(e));

if let Err(e) = query_set.same_device_as(cmd_buf.as_ref()) {
return (ComputePass::new(None, arc_desc), Some(e.into()));
}
let cmd_buf = match hub.command_buffers.get(encoder_id.into_command_buffer_id()) {
Ok(cmd_buf) => cmd_buf,
Err(_) => return make_err(CommandEncoderError::Invalid, arc_desc),
};

Some(ArcPassTimestampWrites {
query_set,
beginning_of_pass_write_index: tw.beginning_of_pass_write_index,
end_of_pass_write_index: tw.end_of_pass_write_index,
})
} else {
None
};

(ComputePass::new(Some(cmd_buf), arc_desc), None)
match cmd_buf.lock_encoder() {
Ok(_) => {}
Err(e) => return make_err(e, arc_desc),
};

arc_desc.timestamp_writes = if let Some(tw) = desc.timestamp_writes {
let Ok(query_set) = hub.query_sets.get(tw.query_set) else {
return make_err(
CommandEncoderError::InvalidTimestampWritesQuerySetId(tw.query_set),
arc_desc,
);
};

if let Err(e) = query_set.same_device_as(cmd_buf.as_ref()) {
return make_err(e.into(), arc_desc);
}
Err(err) => (ComputePass::new(None, arc_desc), Some(err)),
}

Some(ArcPassTimestampWrites {
query_set,
beginning_of_pass_write_index: tw.beginning_of_pass_write_index,
end_of_pass_write_index: tw.end_of_pass_write_index,
})
} else {
None
};

(ComputePass::new(Some(cmd_buf), arc_desc), None)
}

/// Creates a type erased compute pass.
Expand Down Expand Up @@ -378,7 +383,11 @@ impl Global {
let hub = A::hub(self);
let scope = PassErrorScope::Pass;

let cmd_buf = CommandBuffer::get_encoder(hub, encoder_id).map_pass_err(scope)?;
let cmd_buf = match hub.command_buffers.get(encoder_id.into_command_buffer_id()) {
Ok(cmd_buf) => cmd_buf,
Err(_) => return Err(CommandEncoderError::Invalid).map_pass_err(scope),
};
cmd_buf.check_recording().map_pass_err(scope)?;

#[cfg(feature = "trace")]
{
Expand Down
94 changes: 42 additions & 52 deletions wgpu-core/src/command/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@ pub use timestamp_writes::PassTimestampWrites;
use self::memory_init::CommandBufferTextureMemoryActions;

use crate::device::{Device, DeviceError};
use crate::hub::Hub;
use crate::lock::{rank, Mutex};
use crate::snatch::SnatchGuard;

Expand Down Expand Up @@ -425,65 +424,41 @@ impl<A: HalApi> CommandBuffer<A> {
}

impl<A: HalApi> CommandBuffer<A> {
fn get_encoder_impl(
hub: &Hub<A>,
id: id::CommandEncoderId,
lock_on_acquire: bool,
) -> Result<Arc<Self>, CommandEncoderError> {
match hub.command_buffers.get(id.into_command_buffer_id()) {
Ok(cmd_buf) => {
let mut cmd_buf_data_guard = cmd_buf.data.lock();
let cmd_buf_data = cmd_buf_data_guard.as_mut().unwrap();
match cmd_buf_data.status {
CommandEncoderStatus::Recording => {
if lock_on_acquire {
cmd_buf_data.status = CommandEncoderStatus::Locked;
}
drop(cmd_buf_data_guard);
Ok(cmd_buf)
}
CommandEncoderStatus::Locked => {
// Any operation on a locked encoder is required to put it into the invalid/error state.
// See https://www.w3.org/TR/webgpu/#encoder-state-locked
cmd_buf_data.encoder.discard();
cmd_buf_data.status = CommandEncoderStatus::Error;
Err(CommandEncoderError::Locked)
}
CommandEncoderStatus::Finished => Err(CommandEncoderError::NotRecording),
CommandEncoderStatus::Error => Err(CommandEncoderError::Invalid),
fn lock_encoder_impl(&self, lock: bool) -> Result<(), CommandEncoderError> {
let mut cmd_buf_data_guard = self.data.lock();
let cmd_buf_data = cmd_buf_data_guard.as_mut().unwrap();
match cmd_buf_data.status {
CommandEncoderStatus::Recording => {
if lock {
cmd_buf_data.status = CommandEncoderStatus::Locked;
}
Ok(())
}
Err(_) => Err(CommandEncoderError::Invalid),
CommandEncoderStatus::Locked => {
// Any operation on a locked encoder is required to put it into the invalid/error state.
// See https://www.w3.org/TR/webgpu/#encoder-state-locked
cmd_buf_data.encoder.discard();
cmd_buf_data.status = CommandEncoderStatus::Error;
Err(CommandEncoderError::Locked)
}
CommandEncoderStatus::Finished => Err(CommandEncoderError::NotRecording),
CommandEncoderStatus::Error => Err(CommandEncoderError::Invalid),
}
}

/// Return the [`CommandBuffer`] for `id`, for recording new commands.
///
/// In `wgpu_core`, the [`CommandBuffer`] type serves both as encoder and
/// buffer, which is why this function takes an [`id::CommandEncoderId`] but
/// returns a [`CommandBuffer`]. The returned command buffer must be in the
/// "recording" state. Otherwise, an error is returned.
fn get_encoder(
hub: &Hub<A>,
id: id::CommandEncoderId,
) -> Result<Arc<Self>, CommandEncoderError> {
let lock_on_acquire = false;
Self::get_encoder_impl(hub, id, lock_on_acquire)
/// Checks that the encoder is in the [`CommandEncoderStatus::Recording`] state.
fn check_recording(&self) -> Result<(), CommandEncoderError> {
self.lock_encoder_impl(false)
}

/// Return the [`CommandBuffer`] for `id` and if successful puts it into the [`CommandEncoderStatus::Locked`] state.
/// Locks the encoder by putting it in the [`CommandEncoderStatus::Locked`] state.
///
/// See [`CommandBuffer::get_encoder`].
/// Call [`CommandBuffer::unlock_encoder`] to put the [`CommandBuffer`] back into the [`CommandEncoderStatus::Recording`] state.
fn lock_encoder(
hub: &Hub<A>,
id: id::CommandEncoderId,
) -> Result<Arc<Self>, CommandEncoderError> {
let lock_on_acquire = true;
Self::get_encoder_impl(hub, id, lock_on_acquire)
fn lock_encoder(&self) -> Result<(), CommandEncoderError> {
self.lock_encoder_impl(true)
}

/// Unlocks the [`CommandBuffer`] for `id` and puts it back into the [`CommandEncoderStatus::Recording`] state.
/// Unlocks the [`CommandBuffer`] and puts it back into the [`CommandEncoderStatus::Recording`] state.
///
/// This function is the counterpart to [`CommandBuffer::lock_encoder`].
/// It is only valid to call this function if the encoder is in the [`CommandEncoderStatus::Locked`] state.
Expand Down Expand Up @@ -661,7 +636,12 @@ impl Global {

let hub = A::hub(self);

let cmd_buf = CommandBuffer::get_encoder(hub, encoder_id)?;
let cmd_buf = match hub.command_buffers.get(encoder_id.into_command_buffer_id()) {
Ok(cmd_buf) => cmd_buf,
Err(_) => return Err(CommandEncoderError::Invalid),
};
cmd_buf.check_recording()?;

let mut cmd_buf_data = cmd_buf.data.lock();
let cmd_buf_data = cmd_buf_data.as_mut().unwrap();
#[cfg(feature = "trace")]
Expand Down Expand Up @@ -692,7 +672,12 @@ impl Global {

let hub = A::hub(self);

let cmd_buf = CommandBuffer::get_encoder(hub, encoder_id)?;
let cmd_buf = match hub.command_buffers.get(encoder_id.into_command_buffer_id()) {
Ok(cmd_buf) => cmd_buf,
Err(_) => return Err(CommandEncoderError::Invalid),
};
cmd_buf.check_recording()?;

let mut cmd_buf_data = cmd_buf.data.lock();
let cmd_buf_data = cmd_buf_data.as_mut().unwrap();

Expand Down Expand Up @@ -723,7 +708,12 @@ impl Global {

let hub = A::hub(self);

let cmd_buf = CommandBuffer::get_encoder(hub, encoder_id)?;
let cmd_buf = match hub.command_buffers.get(encoder_id.into_command_buffer_id()) {
Ok(cmd_buf) => cmd_buf,
Err(_) => return Err(CommandEncoderError::Invalid),
};
cmd_buf.check_recording()?;

let mut cmd_buf_data = cmd_buf.data.lock();
let cmd_buf_data = cmd_buf_data.as_mut().unwrap();

Expand Down
19 changes: 17 additions & 2 deletions wgpu-core/src/command/query.rs
Original file line number Diff line number Diff line change
Expand Up @@ -324,7 +324,14 @@ impl Global {
) -> Result<(), QueryError> {
let hub = A::hub(self);

let cmd_buf = CommandBuffer::get_encoder(hub, command_encoder_id)?;
let cmd_buf = match hub
.command_buffers
.get(command_encoder_id.into_command_buffer_id())
{
Ok(cmd_buf) => cmd_buf,
Err(_) => return Err(CommandEncoderError::Invalid.into()),
};
cmd_buf.check_recording()?;

cmd_buf
.device
Expand Down Expand Up @@ -369,7 +376,15 @@ impl Global {
) -> Result<(), QueryError> {
let hub = A::hub(self);

let cmd_buf = CommandBuffer::get_encoder(hub, command_encoder_id)?;
let cmd_buf = match hub
.command_buffers
.get(command_encoder_id.into_command_buffer_id())
{
Ok(cmd_buf) => cmd_buf,
Err(_) => return Err(CommandEncoderError::Invalid.into()),
};
cmd_buf.check_recording()?;

let mut cmd_buf_data = cmd_buf.data.lock();
let cmd_buf_data = cmd_buf_data.as_mut().unwrap();

Expand Down
18 changes: 14 additions & 4 deletions wgpu-core/src/command/render.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1432,9 +1432,16 @@ impl Global {
occlusion_query_set: None,
};

let cmd_buf = match CommandBuffer::lock_encoder(hub, encoder_id) {
let make_err = |e, arc_desc| (RenderPass::new(None, arc_desc), Some(e));

let cmd_buf = match hub.command_buffers.get(encoder_id.into_command_buffer_id()) {
Ok(cmd_buf) => cmd_buf,
Err(e) => return (RenderPass::new(None, arc_desc), Some(e)),
Err(_) => return make_err(CommandEncoderError::Invalid, arc_desc),
};

match cmd_buf.lock_encoder() {
Ok(_) => {}
Err(e) => return make_err(e, arc_desc),
};

let err = fill_arc_desc(hub, &cmd_buf.device, desc, &mut arc_desc).err();
Expand Down Expand Up @@ -1471,8 +1478,11 @@ impl Global {
#[cfg(feature = "trace")]
{
let hub = A::hub(self);
let cmd_buf: Arc<CommandBuffer<A>> =
CommandBuffer::get_encoder(hub, encoder_id).map_pass_err(pass_scope)?;

let cmd_buf = match hub.command_buffers.get(encoder_id.into_command_buffer_id()) {
Ok(cmd_buf) => cmd_buf,
Err(_) => return Err(CommandEncoderError::Invalid).map_pass_err(pass_scope)?,
};

let mut cmd_buf_data = cmd_buf.data.lock();
let cmd_buf_data = cmd_buf_data.as_mut().unwrap();
Expand Down
Loading

0 comments on commit e26d2d7

Please sign in to comment.