From 5ad6afe7e5b071281068f5fc4e21658494b488de Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mart=C3=AD=20Angelats=20i=20Ribera?= Date: Thu, 24 Sep 2020 20:33:51 +0200 Subject: [PATCH] Add access methods to the generated SoA vectors --- soa-derive-internal/src/index.rs | 340 +++++++++++++++++++++++++++++++ soa-derive-internal/src/lib.rs | 2 + soa-derive-internal/src/vec.rs | 60 ++++++ 3 files changed, 402 insertions(+) create mode 100644 soa-derive-internal/src/index.rs diff --git a/soa-derive-internal/src/index.rs b/soa-derive-internal/src/index.rs new file mode 100644 index 0000000..52e4de6 --- /dev/null +++ b/soa-derive-internal/src/index.rs @@ -0,0 +1,340 @@ +use proc_macro2::TokenStream; +use quote::quote; + +use crate::input::Input; + +pub fn derive(input: &Input) -> TokenStream { + let vec_name = &input.vec_name(); + let slice_name = &input.slice_name(); + let slice_mut_name = &input.slice_mut_name(); + let ref_name = &input.ref_name(); + let ref_mut_name = &input.ref_mut_name(); + let fields_names = input.fields.iter() + .map(|field| field.ident.clone().unwrap()) + .collect::>(); + let fields_names_1 = &fields_names; + let fields_names_2 = &fields_names; + + quote!{ + // usize + impl<'a> ::soa_derive::SoaIndex<&'a #vec_name> for usize { + type RefOutput = #ref_name<'a>; + + #[inline] + fn get(self, soa: &'a #vec_name) -> Option { + if self < soa.len() { + Some(unsafe { self.get_unchecked(soa) }) + } else { + None + } + } + + #[inline] + unsafe fn get_unchecked(self, soa: &'a #vec_name) -> Self::RefOutput { + #ref_name { + #(#fields_names_1: soa.#fields_names_2.get_unchecked(self),)* + } + } + + #[inline] + fn index(self, soa: &'a #vec_name) -> Self::RefOutput { + #ref_name { + #(#fields_names_1: & soa.#fields_names_2[self],)* + } + } + } + + impl<'a> ::soa_derive::SoaMutIndex<&'a mut #vec_name> for usize { + type MutOutput = #ref_mut_name<'a>; + + #[inline] + fn get_mut(self, soa: &'a mut #vec_name) -> Option { + if self < soa.len() { + Some(unsafe { self.get_unchecked_mut(soa) }) + } else { + None + } + } + + #[inline] + unsafe fn get_unchecked_mut(self, soa: &'a mut #vec_name) -> Self::MutOutput { + #ref_mut_name { + #(#fields_names_1: soa.#fields_names_2.get_unchecked_mut(self),)* + } + } + + #[inline] + fn index_mut(self, soa: &'a mut #vec_name) -> Self::MutOutput { + #ref_mut_name { + #(#fields_names_1: &mut soa.#fields_names_2[self],)* + } + } + } + + + + // Range + impl<'a> ::soa_derive::SoaIndex<&'a #vec_name> for ::std::ops::Range { + type RefOutput = #slice_name<'a>; + + #[inline] + fn get(self, soa: &'a #vec_name) -> Option { + if self.start <= self.end && self.end <= soa.len() { + None + } else { + unsafe { Some(self.get_unchecked(soa)) } + } + } + + #[inline] + unsafe fn get_unchecked(self, soa: &'a #vec_name) -> Self::RefOutput { + #slice_name { + #(#fields_names_1: soa.#fields_names_2.get_unchecked(self.clone()),)* + } + } + + #[inline] + fn index(self, soa: &'a #vec_name) -> Self::RefOutput { + #slice_name { + #(#fields_names_1: & soa.#fields_names_2[self.clone()],)* + } + } + } + + impl<'a> ::soa_derive::SoaMutIndex<&'a mut #vec_name> for ::std::ops::Range { + type MutOutput = #slice_mut_name<'a>; + + #[inline] + fn get_mut(self, soa: &'a mut #vec_name) -> Option { + if self.start <= self.end && self.end <= soa.len() { + None + } else { + unsafe { Some(self.get_unchecked_mut(soa)) } + } + } + + #[inline] + unsafe fn get_unchecked_mut(self, soa: &'a mut #vec_name) -> Self::MutOutput { + #slice_mut_name { + #(#fields_names_1: soa.#fields_names_2.get_unchecked_mut(self.clone()),)* + } + } + + #[inline] + fn index_mut(self, soa: &'a mut #vec_name) -> Self::MutOutput { + #slice_mut_name { + #(#fields_names_1: &mut soa.#fields_names_2[self.clone()],)* + } + } + } + + + + // RangeTo + impl<'a> ::soa_derive::SoaIndex<&'a #vec_name> for ::std::ops::RangeTo { + type RefOutput = #slice_name<'a>; + + #[inline] + fn get(self, soa: &'a #vec_name) -> Option { + (0..self.end).get(soa) + } + + #[inline] + unsafe fn get_unchecked(self, soa: &'a #vec_name) -> Self::RefOutput { + (0..self.end).get_unchecked(soa) + } + + #[inline] + fn index(self, soa: &'a #vec_name) -> Self::RefOutput { + (0..self.end).index(soa) + } + } + + impl<'a> ::soa_derive::SoaMutIndex<&'a mut #vec_name> for ::std::ops::RangeTo { + type MutOutput = #slice_mut_name<'a>; + + #[inline] + fn get_mut(self, soa: &'a mut #vec_name) -> Option { + (0..self.end).get_mut(soa) + } + + #[inline] + unsafe fn get_unchecked_mut(self, soa: &'a mut #vec_name) -> Self::MutOutput { + (0..self.end).get_unchecked_mut(soa) + } + + #[inline] + fn index_mut(self, soa: &'a mut #vec_name) -> Self::MutOutput { + (0..self.end).index_mut(soa) + } + } + + + // RangeFrom + impl<'a> ::soa_derive::SoaIndex<&'a #vec_name> for ::std::ops::RangeFrom { + type RefOutput = #slice_name<'a>; + + #[inline] + fn get(self, soa: &'a #vec_name) -> Option { + (self.start..soa.len()).get(soa) + } + + #[inline] + unsafe fn get_unchecked(self, soa: &'a #vec_name) -> Self::RefOutput { + (self.start..soa.len()).get_unchecked(soa) + } + + #[inline] + fn index(self, soa: &'a #vec_name) -> Self::RefOutput { + (self.start..soa.len()).index(soa) + } + } + + impl<'a> ::soa_derive::SoaMutIndex<&'a mut #vec_name> for ::std::ops::RangeFrom { + type MutOutput = #slice_mut_name<'a>; + + #[inline] + fn get_mut(self, soa: &'a mut #vec_name) -> Option { + (self.start..soa.len()).get_mut(soa) + } + + #[inline] + unsafe fn get_unchecked_mut(self, soa: &'a mut #vec_name) -> Self::MutOutput { + (self.start..soa.len()).get_unchecked_mut(soa) + } + + #[inline] + fn index_mut(self, soa: &'a mut #vec_name) -> Self::MutOutput { + (self.start..soa.len()).index_mut(soa) + } + } + + + // RangeFull + impl<'a> ::soa_derive::SoaIndex<&'a #vec_name> for ::std::ops::RangeFull { + type RefOutput = #slice_name<'a>; + + #[inline] + fn get(self, soa: &'a #vec_name) -> Option { + Some(soa.as_slice()) + } + + #[inline] + unsafe fn get_unchecked(self, soa: &'a #vec_name) -> Self::RefOutput { + soa.as_slice() + } + + #[inline] + fn index(self, soa: &'a #vec_name) -> Self::RefOutput { + soa.as_slice() + } + } + + impl<'a> ::soa_derive::SoaMutIndex<&'a mut #vec_name> for ::std::ops::RangeFull { + type MutOutput = #slice_mut_name<'a>; + + #[inline] + fn get_mut(self, soa: &'a mut #vec_name) -> Option { + Some(soa.as_mut_slice()) + } + + #[inline] + unsafe fn get_unchecked_mut(self, soa: &'a mut #vec_name) -> Self::MutOutput { + soa.as_mut_slice() + } + + #[inline] + fn index_mut(self, soa: &'a mut #vec_name) -> Self::MutOutput { + soa.as_mut_slice() + } + } + + + // RangeInclusive + impl<'a> ::soa_derive::SoaIndex<&'a #vec_name> for ::std::ops::RangeInclusive { + type RefOutput = #slice_name<'a>; + + #[inline] + fn get(self, soa: &'a #vec_name) -> Option { + if *self.end() == usize::MAX { + None + } else { + (*self.start()..self.end() + 1).get(soa) + } + } + + #[inline] + unsafe fn get_unchecked(self, soa: &'a #vec_name) -> Self::RefOutput { + (*self.start()..self.end() + 1).get_unchecked(soa) + } + + #[inline] + fn index(self, soa: &'a #vec_name) -> Self::RefOutput { + (*self.start()..self.end() + 1).index(soa) + } + } + + impl<'a> ::soa_derive::SoaMutIndex<&'a mut #vec_name> for ::std::ops::RangeInclusive { + type MutOutput = #slice_mut_name<'a>; + + #[inline] + fn get_mut(self, soa: &'a mut #vec_name) -> Option { + if *self.end() == usize::MAX { + None + } else { + (*self.start()..self.end() + 1).get_mut(soa) + } + } + + #[inline] + unsafe fn get_unchecked_mut(self, soa: &'a mut #vec_name) -> Self::MutOutput { + (*self.start()..self.end() + 1).get_unchecked_mut(soa) + } + + #[inline] + fn index_mut(self, soa: &'a mut #vec_name) -> Self::MutOutput { + (*self.start()..self.end() + 1).index_mut(soa) + } + } + + + // RangeToInclusive + impl<'a> ::soa_derive::SoaIndex<&'a #vec_name> for ::std::ops::RangeToInclusive { + type RefOutput = #slice_name<'a>; + + #[inline] + fn get(self, soa: &'a #vec_name) -> Option { + (0..=self.end).get(soa) + } + + #[inline] + unsafe fn get_unchecked(self, soa: &'a #vec_name) -> Self::RefOutput { + (0..=self.end).get_unchecked(soa) + } + + #[inline] + fn index(self, soa: &'a #vec_name) -> Self::RefOutput { + (0..=self.end).index(soa) + } + } + + impl<'a> ::soa_derive::SoaMutIndex<&'a mut #vec_name> for ::std::ops::RangeToInclusive { + type MutOutput = #slice_mut_name<'a>; + + #[inline] + fn get_mut(self, soa: &'a mut #vec_name) -> Option { + (0..=self.end).get_mut(soa) + } + + #[inline] + unsafe fn get_unchecked_mut(self, soa: &'a mut #vec_name) -> Self::MutOutput { + (0..=self.end).get_unchecked_mut(soa) + } + + #[inline] + fn index_mut(self, soa: &'a mut #vec_name) -> Self::MutOutput { + (0..=self.end).index_mut(soa) + } + } + } +} \ No newline at end of file diff --git a/soa-derive-internal/src/lib.rs b/soa-derive-internal/src/lib.rs index b2510c0..6adacd1 100644 --- a/soa-derive-internal/src/lib.rs +++ b/soa-derive-internal/src/lib.rs @@ -9,6 +9,7 @@ extern crate proc_macro; use proc_macro2::TokenStream; use quote::TokenStreamExt; +mod index; mod input; mod iter; mod ptr; @@ -27,6 +28,7 @@ pub fn soa_derive(input: proc_macro::TokenStream) -> proc_macro::TokenStream { generated.append_all(ptr::derive(&input)); generated.append_all(slice::derive(&input)); generated.append_all(slice::derive_mut(&input)); + generated.append_all(index::derive(&input)); generated.append_all(iter::derive(&input)); generated.append_all(derive_trait(&input)); generated.into() diff --git a/soa-derive-internal/src/vec.rs b/soa-derive-internal/src/vec.rs index e741c05..9c2dd21 100644 --- a/soa-derive-internal/src/vec.rs +++ b/soa-derive-internal/src/vec.rs @@ -259,6 +259,66 @@ pub fn derive(input: &Input) -> TokenStream { } } + /// Similar to [` + #[doc = #vec_name_str] + /// ::get()`](https://doc.rust-lang.org/std/vec/struct.Vec.html#method.get). + pub fn get<'a, I>(&'a self, index: I) -> Option + where + I: ::soa_derive::SoaIndex<&'a #vec_name> + { + index.get(self) + } + + /// Similar to [` + #[doc = #vec_name_str] + /// ::get_unchecked()`](https://doc.rust-lang.org/std/vec/struct.Vec.html#method.get_unchecked). + pub unsafe fn get_unchecked<'a, I>(&'a self, index: I) -> I::RefOutput + where + I: ::soa_derive::SoaIndex<&'a #vec_name> + { + index.get_unchecked(self) + } + + /// Similar to [` + #[doc = #vec_name_str] + /// ::index()`](https://doc.rust-lang.org/std/vec/struct.Vec.html#method.index). + pub fn index<'a, I>(&'a self, index: I) -> I::RefOutput + where + I: ::soa_derive::SoaIndex<&'a #vec_name> + { + index.index(self) + } + + /// Similar to [` + #[doc = #vec_name_str] + /// ::get_mut()`](https://doc.rust-lang.org/std/vec/struct.Vec.html#method.get_mut). + pub fn get_mut<'a, I>(&'a mut self, index: I) -> Option + where + I: ::soa_derive::SoaMutIndex<&'a mut #vec_name> + { + index.get_mut(self) + } + + /// Similar to [` + #[doc = #vec_name_str] + /// ::get_unchecked_mut()`](https://doc.rust-lang.org/std/vec/struct.Vec.html#method.get_unchecked_mut). + pub unsafe fn get_unchecked_mut<'a, I>(&'a mut self, index: I) -> I::MutOutput + where + I: ::soa_derive::SoaMutIndex<&'a mut #vec_name> + { + index.get_unchecked_mut(self) + } + + /// Similar to [` + #[doc = #vec_name_str] + /// ::index_mut()`](https://doc.rust-lang.org/std/vec/struct.Vec.html#method.index_mut). + pub fn index_mut<'a, I>(&'a mut self, index: I) -> I::MutOutput + where + I: ::soa_derive::SoaMutIndex<&'a mut #vec_name> + { + index.index_mut(self) + } + /// Similar to [` #[doc = #vec_name_str] /// ::as_ptr()`](https://doc.rust-lang.org/std/struct.Vec.html#method.as_ptr).