Skip to content

Commit

Permalink
Change from WeakPtr to ComPtr which has proper ownership semantics
Browse files Browse the repository at this point in the history
This makes it a lot easier to avoid making memory management
mistakes. It also is closer to the semantics that windows-rs
exposes for it's bindings.
  • Loading branch information
jrmuizel committed Jul 17, 2023
1 parent 7e8051e commit a6fa689
Show file tree
Hide file tree
Showing 15 changed files with 112 additions and 125 deletions.
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "d3d12"
version = "0.6.0"
version = "0.7.0"
authors = [
"gfx-rs developers",
]
Expand Down
82 changes: 39 additions & 43 deletions src/com.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,15 +8,18 @@ use std::{
use winapi::{ctypes::c_void, um::unknwnbase::IUnknown, Interface};

#[repr(transparent)]
pub struct WeakPtr<T>(*mut T);
pub struct ComPtr<T: Interface>(*mut T);

impl<T> WeakPtr<T> {
impl<T: Interface> ComPtr<T> {
pub fn null() -> Self {
WeakPtr(ptr::null_mut())
ComPtr(ptr::null_mut())
}

pub unsafe fn from_raw(raw: *mut T) -> Self {
WeakPtr(raw)
if !raw.is_null() {
(&*(raw as *mut IUnknown)).AddRef();
}
ComPtr(raw)
}

pub fn is_null(&self) -> bool {
Expand All @@ -40,66 +43,68 @@ impl<T> WeakPtr<T> {
}
}

impl<T: Interface> WeakPtr<T> {
impl<T: Interface> ComPtr<T> {
pub unsafe fn as_unknown(&self) -> &IUnknown {
debug_assert!(!self.is_null());
&*(self.0 as *mut IUnknown)
}

// Cast creates a new WeakPtr requiring explicit destroy call.
pub unsafe fn cast<U>(&self) -> D3DResult<WeakPtr<U>>
pub unsafe fn cast<U>(&self) -> D3DResult<ComPtr<U>>
where
U: Interface,
{
let mut obj = WeakPtr::<U>::null();
debug_assert!(!self.is_null());
let mut obj = ComPtr::<U>::null();
let hr = self
.as_unknown()
.QueryInterface(&U::uuidof(), obj.mut_void());
(obj, hr)
}

// Destroying one instance of the WeakPtr will invalidate all
// copies and clones.
pub unsafe fn destroy(&self) {
self.as_unknown().Release();
}
}

impl<T> Clone for WeakPtr<T> {
impl<T: Interface> Clone for ComPtr<T> {
fn clone(&self) -> Self {
WeakPtr(self.0)
debug_assert!(!self.is_null());
unsafe { self.as_unknown().AddRef(); }
ComPtr(self.0)
}
}

impl<T> Copy for WeakPtr<T> {}
impl<T: Interface> Drop for ComPtr<T> {
fn drop(&mut self) {
if !self.0.is_null() {
unsafe { self.as_unknown().Release(); }
}
}
}

impl<T> Deref for WeakPtr<T> {
impl<T: Interface> Deref for ComPtr<T> {
type Target = T;
fn deref(&self) -> &T {
debug_assert!(!self.is_null());
unsafe { &*self.0 }
}
}

impl<T> fmt::Debug for WeakPtr<T> {
impl<T: Interface> fmt::Debug for ComPtr<T> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(f, "WeakPtr( ptr: {:?} )", self.0)
write!(f, "ComPtr( ptr: {:?} )", self.0)
}
}

impl<T> PartialEq<*mut T> for WeakPtr<T> {
impl<T: Interface> PartialEq<*mut T> for ComPtr<T> {
fn eq(&self, other: &*mut T) -> bool {
self.0 == *other
}
}

impl<T> PartialEq for WeakPtr<T> {
impl<T: Interface> PartialEq for ComPtr<T> {
fn eq(&self, other: &Self) -> bool {
self.0 == other.0
}
}

impl<T> Hash for WeakPtr<T> {
impl<T: Interface> Hash for ComPtr<T> {
fn hash<H: Hasher>(&self, state: &mut H) {
self.0.hash(state);
}
Expand All @@ -110,9 +115,9 @@ impl<T> Hash for WeakPtr<T> {
/// Give the variants so that parents come before children. This often manifests as going up in order (1 -> 2 -> 3). This is vital for safety.
///
/// Three function names need to be attached to each variant. The examples are given for the MyComObject1 variant below:
/// - the from function (`WeakPtr<actual::ComObject1> -> Self`)
/// - the as function (`&self -> Option<WeakPtr<actual::ComObject1>>`)
/// - the unwrap function (`&self -> WeakPtr<actual::ComObject1>` panicing on failure to cast)
/// - the from function (`ComPtr<actual::ComObject1> -> Self`)
/// - the as function (`&self -> Option<ComPtr<actual::ComObject1>>`)
/// - the unwrap function (`&self -> ComPtr<actual::ComObject1>` panicing on failure to cast)
///
/// ```rust
/// # pub use d3d12::weak_com_inheritance_chain;
Expand Down Expand Up @@ -145,21 +150,12 @@ macro_rules! weak_com_inheritance_chain {
) => {
$(#[$meta])*
$vis enum $name {
$first_variant($crate::WeakPtr<$first_type>),
$first_variant($crate::ComPtr<$first_type>),
$(
$variant($crate::WeakPtr<$type>)
$variant($crate::ComPtr<$type>)
),+
}
impl $name {
$vis unsafe fn destroy(&self) {
match *self {
Self::$first_variant(v) => v.destroy(),
$(
Self::$variant(v) => v.destroy(),
)*
}
}

$crate::weak_com_inheritance_chain! {
@recursion_logic,
$vis,
Expand All @@ -170,7 +166,7 @@ macro_rules! weak_com_inheritance_chain {
}

impl std::ops::Deref for $name {
type Target = $crate::WeakPtr<$first_type>;
type Target = $crate::ComPtr<$first_type>;
fn deref(&self) -> &Self::Target {
self.$first_unwrap_name()
}
Expand Down Expand Up @@ -223,20 +219,20 @@ macro_rules! weak_com_inheritance_chain {
$($next_variant:ident),*;
) => {
// Construct this enum from weak pointer to this interface. For best usability, always use the highest constructor you can. This doesn't try to upcast.
$vis unsafe fn $from_name(value: $crate::WeakPtr<$type>) -> Self {
$vis unsafe fn $from_name(value: $crate::ComPtr<$type>) -> Self {
Self::$variant(value)
}

// Returns Some if the value implements the interface otherwise returns None.
$vis fn $as_name(&self) -> Option<&$crate::WeakPtr<$type>> {
$vis fn $as_name(&self) -> Option<&$crate::ComPtr<$type>> {
match *self {
$(
Self::$prev_variant(_) => None,
)*
Self::$variant(ref v) => Some(v),
$(
Self::$next_variant(ref v) => {
// v is &WeakPtr<NextType> and se cast to &WeakPtr<Type>
// v is &ComPtr<NextType> and se cast to &ComPtr<Type>
Some(unsafe { std::mem::transmute(v) })
}
)*
Expand All @@ -245,15 +241,15 @@ macro_rules! weak_com_inheritance_chain {

// Returns the interface if the value implements it, otherwise panics.
#[track_caller]
$vis fn $unwrap_name(&self) -> &$crate::WeakPtr<$type> {
$vis fn $unwrap_name(&self) -> &$crate::ComPtr<$type> {
match *self {
$(
Self::$prev_variant(_) => panic!(concat!("Tried to unwrap a ", stringify!($prev_variant), " as a ", stringify!($variant))),
)*
Self::$variant(ref v) => &*v,
$(
Self::$next_variant(ref v) => {
// v is &WeakPtr<NextType> and se cast to &WeakPtr<Type>
// v is &ComPtr<NextType> and se cast to &ComPtr<Type>
unsafe { std::mem::transmute(v) }
}
)*
Expand Down
4 changes: 2 additions & 2 deletions src/command_allocator.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
//! Command Allocator
use crate::com::WeakPtr;
use crate::com::ComPtr;
use winapi::um::d3d12;

pub type CommandAllocator = WeakPtr<d3d12::ID3D12CommandAllocator>;
pub type CommandAllocator = ComPtr<d3d12::ID3D12CommandAllocator>;

impl CommandAllocator {
pub fn reset(&self) {
Expand Down
16 changes: 8 additions & 8 deletions src/command_list.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
//! Graphics command list
use crate::{
com::WeakPtr, resource::DiscardRegion, CommandAllocator, CpuDescriptor, DescriptorHeap, Format,
com::ComPtr, resource::DiscardRegion, CommandAllocator, CpuDescriptor, DescriptorHeap, Format,
GpuAddress, GpuDescriptor, IndexCount, InstanceCount, PipelineState, Rect, Resource, RootIndex,
RootSignature, Subresource, VertexCount, VertexOffset, WorkGroupCount, HRESULT,
};
Expand Down Expand Up @@ -140,9 +140,9 @@ impl ResourceBarrier {
}
}

pub type CommandSignature = WeakPtr<d3d12::ID3D12CommandSignature>;
pub type CommandList = WeakPtr<d3d12::ID3D12CommandList>;
pub type GraphicsCommandList = WeakPtr<d3d12::ID3D12GraphicsCommandList>;
pub type CommandSignature = ComPtr<d3d12::ID3D12CommandSignature>;
pub type CommandList = ComPtr<d3d12::ID3D12CommandList>;
pub type GraphicsCommandList = ComPtr<d3d12::ID3D12GraphicsCommandList>;

impl GraphicsCommandList {
pub fn as_list(&self) -> CommandList {
Expand All @@ -153,7 +153,7 @@ impl GraphicsCommandList {
unsafe { self.Close() }
}

pub fn reset(&self, allocator: CommandAllocator, initial_pso: PipelineState) -> HRESULT {
pub fn reset(&self, allocator: &CommandAllocator, initial_pso: PipelineState) -> HRESULT {
unsafe { self.Reset(allocator.as_mut_ptr(), initial_pso.as_mut_ptr()) }
}

Expand Down Expand Up @@ -263,7 +263,7 @@ impl GraphicsCommandList {
}
}

pub fn set_pipeline_state(&self, pso: PipelineState) {
pub fn set_pipeline_state(&self, pso:&PipelineState) {
unsafe {
self.SetPipelineState(pso.as_mut_ptr());
}
Expand All @@ -284,13 +284,13 @@ impl GraphicsCommandList {
}
}

pub fn set_compute_root_signature(&self, signature: RootSignature) {
pub fn set_compute_root_signature(&self, signature: &RootSignature) {
unsafe {
self.SetComputeRootSignature(signature.as_mut_ptr());
}
}

pub fn set_graphics_root_signature(&self, signature: RootSignature) {
pub fn set_graphics_root_signature(&self, signature: &RootSignature) {
unsafe {
self.SetGraphicsRootSignature(signature.as_mut_ptr());
}
Expand Down
4 changes: 2 additions & 2 deletions src/debug.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
use crate::com::WeakPtr;
use crate::com::ComPtr;
use winapi::um::d3d12sdklayers;
#[cfg(any(feature = "libloading", feature = "implicit-link"))]
use winapi::Interface as _;

pub type Debug = WeakPtr<d3d12sdklayers::ID3D12Debug>;
pub type Debug = ComPtr<d3d12sdklayers::ID3D12Debug>;

#[cfg(feature = "libloading")]
impl crate::D3D12Lib {
Expand Down
6 changes: 3 additions & 3 deletions src/descriptor.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use crate::{com::WeakPtr, Blob, D3DResult, Error, TextureAddressMode};
use crate::{com::ComPtr, Blob, D3DResult, Error, TextureAddressMode};
use std::{fmt, mem, ops::Range};
use winapi::{shared::dxgiformat, um::d3d12};

Expand Down Expand Up @@ -27,7 +27,7 @@ bitflags! {
}
}

pub type DescriptorHeap = WeakPtr<d3d12::ID3D12DescriptorHeap>;
pub type DescriptorHeap = ComPtr<d3d12::ID3D12DescriptorHeap>;

impl DescriptorHeap {
pub fn start_cpu_descriptor(&self) -> CpuDescriptor {
Expand Down Expand Up @@ -265,7 +265,7 @@ bitflags! {
}
}

pub type RootSignature = WeakPtr<d3d12::ID3D12RootSignature>;
pub type RootSignature = ComPtr<d3d12::ID3D12RootSignature>;
pub type BlobResult = D3DResult<(Blob, Error)>;

#[cfg(feature = "libloading")]
Expand Down
14 changes: 7 additions & 7 deletions src/device.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
//! Device
use crate::{
com::WeakPtr,
com::ComPtr,
command_list::{CmdListType, CommandSignature, IndirectArgument},
descriptor::{CpuDescriptor, DescriptorHeapFlags, DescriptorHeapType, RenderTargetViewDesc},
heap::{Heap, HeapFlags, HeapProperties},
Expand All @@ -12,13 +12,13 @@ use crate::{
use std::ops::Range;
use winapi::{um::d3d12, Interface};

pub type Device = WeakPtr<d3d12::ID3D12Device>;
pub type Device = ComPtr<d3d12::ID3D12Device>;

#[cfg(feature = "libloading")]
impl crate::D3D12Lib {
pub fn create_device<I: Interface>(
pub fn create_device<I: Interface>(
&self,
adapter: WeakPtr<I>,
adapter: &ComPtr<I>,
feature_level: crate::FeatureLevel,
) -> Result<D3DResult<Device>, libloading::Error> {
type Fun = extern "system" fn(
Expand Down Expand Up @@ -46,7 +46,7 @@ impl crate::D3D12Lib {
impl Device {
#[cfg(feature = "implicit-link")]
pub fn create<I: Interface>(
adapter: WeakPtr<I>,
adapter: ComPtr<I>,
feature_level: crate::FeatureLevel,
) -> D3DResult<Self> {
let mut device = Device::null();
Expand Down Expand Up @@ -155,7 +155,7 @@ impl Device {
pub fn create_graphics_command_list(
&self,
list_type: CmdListType,
allocator: CommandAllocator,
allocator: &CommandAllocator,
initial: PipelineState,
node_mask: NodeMask,
) -> D3DResult<GraphicsCommandList> {
Expand Down Expand Up @@ -215,7 +215,7 @@ impl Device {

pub fn create_compute_pipeline_state(
&self,
root_signature: RootSignature,
root_signature: &RootSignature,
cs: Shader,
node_mask: NodeMask,
cached_pso: CachedPSO,
Expand Down
Loading

0 comments on commit a6fa689

Please sign in to comment.