From e8330e65cae1b29b954573d7cb971f4ffe4e6751 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mart=C3=AD=20Angelats=20i=20Ribera?= Date: Thu, 24 Sep 2020 20:32:31 +0200 Subject: [PATCH 01/14] Add index traits --- src/lib.rs | 47 +++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 47 insertions(+) diff --git a/src/lib.rs b/src/lib.rs index eb4612f..dc6109a 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -168,6 +168,53 @@ pub trait StructOfArray { type Type; } + +mod private_soa_indexs { + // From [`std::slice::SliceIndex`](https://doc.rust-lang.org/std/slice/trait.SliceIndex.html) code. + // Limits the types that may implement the SoA index traits. + // It's also helpful to have the exaustive list of all accepted types. + + use ::std::ops; + + pub trait Sealed {} + + impl Sealed for usize {} // [a] + impl Sealed for ops::Range {} // [a..b] + impl Sealed for ops::RangeTo {} // [..b] + impl Sealed for ops::RangeFrom {} // [a..] + impl Sealed for ops::RangeFull {} // [..] + impl Sealed for ops::RangeInclusive {} // [a..=b] + impl Sealed for ops::RangeToInclusive {} // [..=b] +} + +/// Helper trait used for indexing operations. +/// Inspired by [`std::slice::SliceIndex`](https://doc.rust-lang.org/std/slice/trait.SliceIndex.html). +pub trait SoaIndex: private_soa_indexs::Sealed { + /// The output for the non-mutable functions + type RefOutput; + + /// Returns the reference output in this location if in bounds. None otherwise. + fn get(self, soa: T) -> Option; + /// Returns the reference output in this location withotu performing any bounds check. + unsafe fn get_unchecked(self, soa: T) -> Self::RefOutput; + /// Returns the reference output in this location. Panics if it is not in bounds. + fn index(self, soa: T) -> Self::RefOutput; +} + +/// Helper trait used for indexing operations returning mutable. +/// Inspired by [`std::slice::SliceIndex`](https://doc.rust-lang.org/std/slice/trait.SliceIndex.html). +pub trait SoaMutIndex: private_soa_indexs::Sealed { + /// The output for the mutable functions + type MutOutput; + + /// Returns the mutable reference output in this location if in bounds. None otherwise. + fn get_mut(self, soa: T) -> Option; + /// Returns the mutable reference output in this location withotu performing any bounds check. + unsafe fn get_unchecked_mut(self, soa: T) -> Self::MutOutput; + /// Returns the mutable reference output in this location. Panics if it is not in bounds. + fn index_mut(self, soa: T) -> Self::MutOutput; +} + /// Create an iterator over multiple fields in a Struct of array style vector. /// /// This macro takes two main arguments: the array/slice container, and a list From 1ce49e45d88fc3e6540462c62209c7e5b706af1c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mart=C3=AD=20Angelats=20i=20Ribera?= Date: Thu, 24 Sep 2020 20:33:51 +0200 Subject: [PATCH 02/14] Add access methods to the generated SoA vectors --- soa-derive-internal/src/index.rs | 340 +++++++++++++++++++++++++++++++ soa-derive-internal/src/lib.rs | 2 + soa-derive-internal/src/vec.rs | 60 ++++++ 3 files changed, 402 insertions(+) create mode 100644 soa-derive-internal/src/index.rs diff --git a/soa-derive-internal/src/index.rs b/soa-derive-internal/src/index.rs new file mode 100644 index 0000000..3bbb37a --- /dev/null +++ b/soa-derive-internal/src/index.rs @@ -0,0 +1,340 @@ +use proc_macro2::TokenStream; +use quote::quote; + +use crate::input::Input; + +pub fn derive(input: &Input) -> TokenStream { + let vec_name = &input.vec_name(); + let slice_name = &input.slice_name(); + let slice_mut_name = &input.slice_mut_name(); + let ref_name = &input.ref_name(); + let ref_mut_name = &input.ref_mut_name(); + let fields_names = input.fields.iter() + .map(|field| field.ident.clone().unwrap()) + .collect::>(); + let fields_names_1 = &fields_names; + let fields_names_2 = &fields_names; + + quote!{ + // usize + impl<'a> ::soa_derive::SoaIndex<&'a #vec_name> for usize { + type RefOutput = #ref_name<'a>; + + #[inline] + fn get(self, soa: &'a #vec_name) -> Option { + if self < soa.len() { + Some(unsafe { self.get_unchecked(soa) }) + } else { + None + } + } + + #[inline] + unsafe fn get_unchecked(self, soa: &'a #vec_name) -> Self::RefOutput { + #ref_name { + #(#fields_names_1: soa.#fields_names_2.get_unchecked(self),)* + } + } + + #[inline] + fn index(self, soa: &'a #vec_name) -> Self::RefOutput { + #ref_name { + #(#fields_names_1: & soa.#fields_names_2[self],)* + } + } + } + + impl<'a> ::soa_derive::SoaMutIndex<&'a mut #vec_name> for usize { + type MutOutput = #ref_mut_name<'a>; + + #[inline] + fn get_mut(self, soa: &'a mut #vec_name) -> Option { + if self < soa.len() { + Some(unsafe { self.get_unchecked_mut(soa) }) + } else { + None + } + } + + #[inline] + unsafe fn get_unchecked_mut(self, soa: &'a mut #vec_name) -> Self::MutOutput { + #ref_mut_name { + #(#fields_names_1: soa.#fields_names_2.get_unchecked_mut(self),)* + } + } + + #[inline] + fn index_mut(self, soa: &'a mut #vec_name) -> Self::MutOutput { + #ref_mut_name { + #(#fields_names_1: &mut soa.#fields_names_2[self],)* + } + } + } + + + + // Range + impl<'a> ::soa_derive::SoaIndex<&'a #vec_name> for ::std::ops::Range { + type RefOutput = #slice_name<'a>; + + #[inline] + fn get(self, soa: &'a #vec_name) -> Option { + if self.start <= self.end && self.end <= soa.len() { + unsafe { Some(self.get_unchecked(soa)) } + } else { + None + } + } + + #[inline] + unsafe fn get_unchecked(self, soa: &'a #vec_name) -> Self::RefOutput { + #slice_name { + #(#fields_names_1: soa.#fields_names_2.get_unchecked(self.clone()),)* + } + } + + #[inline] + fn index(self, soa: &'a #vec_name) -> Self::RefOutput { + #slice_name { + #(#fields_names_1: & soa.#fields_names_2[self.clone()],)* + } + } + } + + impl<'a> ::soa_derive::SoaMutIndex<&'a mut #vec_name> for ::std::ops::Range { + type MutOutput = #slice_mut_name<'a>; + + #[inline] + fn get_mut(self, soa: &'a mut #vec_name) -> Option { + if self.start <= self.end && self.end <= soa.len() { + unsafe { Some(self.get_unchecked_mut(soa)) } + } else { + None + } + } + + #[inline] + unsafe fn get_unchecked_mut(self, soa: &'a mut #vec_name) -> Self::MutOutput { + #slice_mut_name { + #(#fields_names_1: soa.#fields_names_2.get_unchecked_mut(self.clone()),)* + } + } + + #[inline] + fn index_mut(self, soa: &'a mut #vec_name) -> Self::MutOutput { + #slice_mut_name { + #(#fields_names_1: &mut soa.#fields_names_2[self.clone()],)* + } + } + } + + + + // RangeTo + impl<'a> ::soa_derive::SoaIndex<&'a #vec_name> for ::std::ops::RangeTo { + type RefOutput = #slice_name<'a>; + + #[inline] + fn get(self, soa: &'a #vec_name) -> Option { + (0..self.end).get(soa) + } + + #[inline] + unsafe fn get_unchecked(self, soa: &'a #vec_name) -> Self::RefOutput { + (0..self.end).get_unchecked(soa) + } + + #[inline] + fn index(self, soa: &'a #vec_name) -> Self::RefOutput { + (0..self.end).index(soa) + } + } + + impl<'a> ::soa_derive::SoaMutIndex<&'a mut #vec_name> for ::std::ops::RangeTo { + type MutOutput = #slice_mut_name<'a>; + + #[inline] + fn get_mut(self, soa: &'a mut #vec_name) -> Option { + (0..self.end).get_mut(soa) + } + + #[inline] + unsafe fn get_unchecked_mut(self, soa: &'a mut #vec_name) -> Self::MutOutput { + (0..self.end).get_unchecked_mut(soa) + } + + #[inline] + fn index_mut(self, soa: &'a mut #vec_name) -> Self::MutOutput { + (0..self.end).index_mut(soa) + } + } + + + // RangeFrom + impl<'a> ::soa_derive::SoaIndex<&'a #vec_name> for ::std::ops::RangeFrom { + type RefOutput = #slice_name<'a>; + + #[inline] + fn get(self, soa: &'a #vec_name) -> Option { + (self.start..soa.len()).get(soa) + } + + #[inline] + unsafe fn get_unchecked(self, soa: &'a #vec_name) -> Self::RefOutput { + (self.start..soa.len()).get_unchecked(soa) + } + + #[inline] + fn index(self, soa: &'a #vec_name) -> Self::RefOutput { + (self.start..soa.len()).index(soa) + } + } + + impl<'a> ::soa_derive::SoaMutIndex<&'a mut #vec_name> for ::std::ops::RangeFrom { + type MutOutput = #slice_mut_name<'a>; + + #[inline] + fn get_mut(self, soa: &'a mut #vec_name) -> Option { + (self.start..soa.len()).get_mut(soa) + } + + #[inline] + unsafe fn get_unchecked_mut(self, soa: &'a mut #vec_name) -> Self::MutOutput { + (self.start..soa.len()).get_unchecked_mut(soa) + } + + #[inline] + fn index_mut(self, soa: &'a mut #vec_name) -> Self::MutOutput { + (self.start..soa.len()).index_mut(soa) + } + } + + + // RangeFull + impl<'a> ::soa_derive::SoaIndex<&'a #vec_name> for ::std::ops::RangeFull { + type RefOutput = #slice_name<'a>; + + #[inline] + fn get(self, soa: &'a #vec_name) -> Option { + Some(soa.as_slice()) + } + + #[inline] + unsafe fn get_unchecked(self, soa: &'a #vec_name) -> Self::RefOutput { + soa.as_slice() + } + + #[inline] + fn index(self, soa: &'a #vec_name) -> Self::RefOutput { + soa.as_slice() + } + } + + impl<'a> ::soa_derive::SoaMutIndex<&'a mut #vec_name> for ::std::ops::RangeFull { + type MutOutput = #slice_mut_name<'a>; + + #[inline] + fn get_mut(self, soa: &'a mut #vec_name) -> Option { + Some(soa.as_mut_slice()) + } + + #[inline] + unsafe fn get_unchecked_mut(self, soa: &'a mut #vec_name) -> Self::MutOutput { + soa.as_mut_slice() + } + + #[inline] + fn index_mut(self, soa: &'a mut #vec_name) -> Self::MutOutput { + soa.as_mut_slice() + } + } + + + // RangeInclusive + impl<'a> ::soa_derive::SoaIndex<&'a #vec_name> for ::std::ops::RangeInclusive { + type RefOutput = #slice_name<'a>; + + #[inline] + fn get(self, soa: &'a #vec_name) -> Option { + if *self.end() == usize::MAX { + None + } else { + (*self.start()..self.end() + 1).get(soa) + } + } + + #[inline] + unsafe fn get_unchecked(self, soa: &'a #vec_name) -> Self::RefOutput { + (*self.start()..self.end() + 1).get_unchecked(soa) + } + + #[inline] + fn index(self, soa: &'a #vec_name) -> Self::RefOutput { + (*self.start()..self.end() + 1).index(soa) + } + } + + impl<'a> ::soa_derive::SoaMutIndex<&'a mut #vec_name> for ::std::ops::RangeInclusive { + type MutOutput = #slice_mut_name<'a>; + + #[inline] + fn get_mut(self, soa: &'a mut #vec_name) -> Option { + if *self.end() == usize::MAX { + None + } else { + (*self.start()..self.end() + 1).get_mut(soa) + } + } + + #[inline] + unsafe fn get_unchecked_mut(self, soa: &'a mut #vec_name) -> Self::MutOutput { + (*self.start()..self.end() + 1).get_unchecked_mut(soa) + } + + #[inline] + fn index_mut(self, soa: &'a mut #vec_name) -> Self::MutOutput { + (*self.start()..self.end() + 1).index_mut(soa) + } + } + + + // RangeToInclusive + impl<'a> ::soa_derive::SoaIndex<&'a #vec_name> for ::std::ops::RangeToInclusive { + type RefOutput = #slice_name<'a>; + + #[inline] + fn get(self, soa: &'a #vec_name) -> Option { + (0..=self.end).get(soa) + } + + #[inline] + unsafe fn get_unchecked(self, soa: &'a #vec_name) -> Self::RefOutput { + (0..=self.end).get_unchecked(soa) + } + + #[inline] + fn index(self, soa: &'a #vec_name) -> Self::RefOutput { + (0..=self.end).index(soa) + } + } + + impl<'a> ::soa_derive::SoaMutIndex<&'a mut #vec_name> for ::std::ops::RangeToInclusive { + type MutOutput = #slice_mut_name<'a>; + + #[inline] + fn get_mut(self, soa: &'a mut #vec_name) -> Option { + (0..=self.end).get_mut(soa) + } + + #[inline] + unsafe fn get_unchecked_mut(self, soa: &'a mut #vec_name) -> Self::MutOutput { + (0..=self.end).get_unchecked_mut(soa) + } + + #[inline] + fn index_mut(self, soa: &'a mut #vec_name) -> Self::MutOutput { + (0..=self.end).index_mut(soa) + } + } + } +} \ No newline at end of file diff --git a/soa-derive-internal/src/lib.rs b/soa-derive-internal/src/lib.rs index b2510c0..6adacd1 100644 --- a/soa-derive-internal/src/lib.rs +++ b/soa-derive-internal/src/lib.rs @@ -9,6 +9,7 @@ extern crate proc_macro; use proc_macro2::TokenStream; use quote::TokenStreamExt; +mod index; mod input; mod iter; mod ptr; @@ -27,6 +28,7 @@ pub fn soa_derive(input: proc_macro::TokenStream) -> proc_macro::TokenStream { generated.append_all(ptr::derive(&input)); generated.append_all(slice::derive(&input)); generated.append_all(slice::derive_mut(&input)); + generated.append_all(index::derive(&input)); generated.append_all(iter::derive(&input)); generated.append_all(derive_trait(&input)); generated.into() diff --git a/soa-derive-internal/src/vec.rs b/soa-derive-internal/src/vec.rs index e741c05..9c2dd21 100644 --- a/soa-derive-internal/src/vec.rs +++ b/soa-derive-internal/src/vec.rs @@ -259,6 +259,66 @@ pub fn derive(input: &Input) -> TokenStream { } } + /// Similar to [` + #[doc = #vec_name_str] + /// ::get()`](https://doc.rust-lang.org/std/vec/struct.Vec.html#method.get). + pub fn get<'a, I>(&'a self, index: I) -> Option + where + I: ::soa_derive::SoaIndex<&'a #vec_name> + { + index.get(self) + } + + /// Similar to [` + #[doc = #vec_name_str] + /// ::get_unchecked()`](https://doc.rust-lang.org/std/vec/struct.Vec.html#method.get_unchecked). + pub unsafe fn get_unchecked<'a, I>(&'a self, index: I) -> I::RefOutput + where + I: ::soa_derive::SoaIndex<&'a #vec_name> + { + index.get_unchecked(self) + } + + /// Similar to [` + #[doc = #vec_name_str] + /// ::index()`](https://doc.rust-lang.org/std/vec/struct.Vec.html#method.index). + pub fn index<'a, I>(&'a self, index: I) -> I::RefOutput + where + I: ::soa_derive::SoaIndex<&'a #vec_name> + { + index.index(self) + } + + /// Similar to [` + #[doc = #vec_name_str] + /// ::get_mut()`](https://doc.rust-lang.org/std/vec/struct.Vec.html#method.get_mut). + pub fn get_mut<'a, I>(&'a mut self, index: I) -> Option + where + I: ::soa_derive::SoaMutIndex<&'a mut #vec_name> + { + index.get_mut(self) + } + + /// Similar to [` + #[doc = #vec_name_str] + /// ::get_unchecked_mut()`](https://doc.rust-lang.org/std/vec/struct.Vec.html#method.get_unchecked_mut). + pub unsafe fn get_unchecked_mut<'a, I>(&'a mut self, index: I) -> I::MutOutput + where + I: ::soa_derive::SoaMutIndex<&'a mut #vec_name> + { + index.get_unchecked_mut(self) + } + + /// Similar to [` + #[doc = #vec_name_str] + /// ::index_mut()`](https://doc.rust-lang.org/std/vec/struct.Vec.html#method.index_mut). + pub fn index_mut<'a, I>(&'a mut self, index: I) -> I::MutOutput + where + I: ::soa_derive::SoaMutIndex<&'a mut #vec_name> + { + index.index_mut(self) + } + /// Similar to [` #[doc = #vec_name_str] /// ::as_ptr()`](https://doc.rust-lang.org/std/struct.Vec.html#method.as_ptr). From 176fe9e91c9b8318afaab36eb88f74dc50ea6a42 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mart=C3=AD=20Angelats=20i=20Ribera?= Date: Thu, 24 Sep 2020 23:58:03 +0200 Subject: [PATCH 03/14] Add tests for indexing the generated SoA --- tests/index.rs | 127 +++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 127 insertions(+) create mode 100644 tests/index.rs diff --git a/tests/index.rs b/tests/index.rs new file mode 100644 index 0000000..208d6fa --- /dev/null +++ b/tests/index.rs @@ -0,0 +1,127 @@ +mod particles; +use self::particles::{Particle, ParticleVec, ParticleRef}; + + +#[test] +fn test_usize() { + let mut aos = Vec::new(); + let mut soa = ParticleVec::new(); + + let particle = Particle::new(String::from("Na"), 56.0); + aos.push(particle.clone()); + soa.push(particle.clone()); + + // SoaIndex + assert_eq!(soa.get(0).unwrap().name, &aos.get(0).unwrap().name); + assert_eq!(soa.get(0).unwrap().mass, &aos.get(0).unwrap().mass); + assert_eq!(aos.get(1), None); + assert_eq!(soa.get(1), None); + + unsafe { + assert_eq!(soa.get_unchecked(0).name, &aos.get_unchecked(0).name); + assert_eq!(soa.get_unchecked(0).mass, &aos.get_unchecked(0).mass); + } + + assert_eq!(soa.index(0).name, &aos[0].name); + assert_eq!(soa.index(0).mass, &aos[0].mass); + + + // SoaIndexMut + assert_eq!(soa.get_mut(0).unwrap().name, &aos.get_mut(0).unwrap().name); + assert_eq!(soa.get_mut(0).unwrap().mass, &aos.get_mut(0).unwrap().mass); + assert_eq!(aos.get_mut(1), None); + assert_eq!(soa.get_mut(1), None); + + unsafe { + assert_eq!(soa.get_unchecked_mut(0).name, &aos.get_unchecked_mut(0).name); + assert_eq!(soa.get_unchecked_mut(0).mass, &aos.get_unchecked_mut(0).mass); + } + + assert_eq!(soa.index_mut(0).name, &aos[0].name); + assert_eq!(soa.index_mut(0).mass, &aos[0].mass); + + + *soa.index_mut(0).mass -= 1.; + assert_eq!(soa.get(0).map(|p| *p.mass), Some(particle.mass - 1.)); + + *soa.get_mut(0).unwrap().mass += 1.; + assert_eq!(soa.get(0).map(|p| *p.mass), Some(particle.mass)); +} + +fn eq_its<'a, I1, I2>(i1: I1, i2: I2) +where + I1: Iterator>, + I2: Iterator, +{ + for (p1, p2) in i1.zip(i2) { + assert_eq!(*p1.mass, p2.mass); + } +} + +#[test] +fn test_ranges() { + let mut particles = Vec::new(); + particles.push(Particle::new(String::from("Cl"), 1.0)); + particles.push(Particle::new(String::from("Na"), 2.0)); + particles.push(Particle::new(String::from("Br"), 3.0)); + particles.push(Particle::new(String::from("Zn"), 4.0)); + + let mut soa = ParticleVec::new(); + + for particle in particles.iter() { + soa.push(particle.clone()); + } + + eq_its(soa.iter(), particles.iter()); + + // All tests from here are the same only changing the range + + let range = 0..1; + eq_its(soa.get(range.clone()).unwrap().iter(), particles.get(range.clone()).unwrap().iter()); + unsafe { eq_its(soa.get_unchecked(range.clone()).iter(), particles.get_unchecked(range.clone()).iter()); } + eq_its(soa.index(range.clone()).iter(), particles[range.clone()].iter()); + eq_its(soa.get_mut(range.clone()).unwrap().iter(), particles.get(range.clone()).unwrap().iter()); + unsafe { eq_its(soa.get_unchecked_mut(range.clone()).iter(), particles.get_unchecked(range.clone()).iter()); } + eq_its(soa.index_mut(range.clone()).iter(), particles[range.clone()].iter()); + + let range = ..3; + eq_its(soa.get(range.clone()).unwrap().iter(), particles.get(range.clone()).unwrap().iter()); + unsafe { eq_its(soa.get_unchecked(range.clone()).iter(), particles.get_unchecked(range.clone()).iter()); } + eq_its(soa.index(range.clone()).iter(), particles[range.clone()].iter()); + eq_its(soa.get_mut(range.clone()).unwrap().iter(), particles.get(range.clone()).unwrap().iter()); + unsafe { eq_its(soa.get_unchecked_mut(range.clone()).iter(), particles.get_unchecked(range.clone()).iter()); } + eq_its(soa.index_mut(range.clone()).iter(), particles[range.clone()].iter()); + + let range = 1..; + eq_its(soa.get(range.clone()).unwrap().iter(), particles.get(range.clone()).unwrap().iter()); + unsafe { eq_its(soa.get_unchecked(range.clone()).iter(), particles.get_unchecked(range.clone()).iter()); } + eq_its(soa.index(range.clone()).iter(), particles[range.clone()].iter()); + eq_its(soa.get_mut(range.clone()).unwrap().iter(), particles.get(range.clone()).unwrap().iter()); + unsafe { eq_its(soa.get_unchecked_mut(range.clone()).iter(), particles.get_unchecked(range.clone()).iter()); } + eq_its(soa.index_mut(range.clone()).iter(), particles[range.clone()].iter()); + + let range = ..; + eq_its(soa.get(range.clone()).unwrap().iter(), particles.get(range.clone()).unwrap().iter()); + unsafe { eq_its(soa.get_unchecked(range.clone()).iter(), particles.get_unchecked(range.clone()).iter()); } + eq_its(soa.index(range.clone()).iter(), particles[range.clone()].iter()); + eq_its(soa.get_mut(range.clone()).unwrap().iter(), particles.get(range.clone()).unwrap().iter()); + unsafe { eq_its(soa.get_unchecked_mut(range.clone()).iter(), particles.get_unchecked(range.clone()).iter()); } + eq_its(soa.index_mut(range.clone()).iter(), particles[range.clone()].iter()); + + let range = 0..=1; + eq_its(soa.get(range.clone()).unwrap().iter(), particles.get(range.clone()).unwrap().iter()); + unsafe { eq_its(soa.get_unchecked(range.clone()).iter(), particles.get_unchecked(range.clone()).iter()); } + eq_its(soa.index(range.clone()).iter(), particles[range.clone()].iter()); + eq_its(soa.get_mut(range.clone()).unwrap().iter(), particles.get(range.clone()).unwrap().iter()); + unsafe { eq_its(soa.get_unchecked_mut(range.clone()).iter(), particles.get_unchecked(range.clone()).iter()); } + eq_its(soa.index_mut(range.clone()).iter(), particles[range.clone()].iter()); + + let range = ..=2; + eq_its(soa.get(range.clone()).unwrap().iter(), particles.get(range.clone()).unwrap().iter()); + unsafe { eq_its(soa.get_unchecked(range.clone()).iter(), particles.get_unchecked(range.clone()).iter()); } + eq_its(soa.index(range.clone()).iter(), particles[range.clone()].iter()); + eq_its(soa.get_mut(range.clone()).unwrap().iter(), particles.get(range.clone()).unwrap().iter()); + unsafe { eq_its(soa.get_unchecked_mut(range.clone()).iter(), particles.get_unchecked(range.clone()).iter()); } + eq_its(soa.index_mut(range.clone()).iter(), particles[range.clone()].iter()); + +} \ No newline at end of file From f60625318ca3967dbf8338381e01ab63a1b847ec Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mart=C3=AD=20Angelats=20i=20Ribera?= Date: Fri, 25 Sep 2020 02:31:15 +0200 Subject: [PATCH 04/14] Add acess methods to the generated SoA slices --- soa-derive-internal/src/index.rs | 322 +++++++++++++++++++++++++++++++ 1 file changed, 322 insertions(+) diff --git a/soa-derive-internal/src/index.rs b/soa-derive-internal/src/index.rs index 3bbb37a..6bccb3a 100644 --- a/soa-derive-internal/src/index.rs +++ b/soa-derive-internal/src/index.rs @@ -14,6 +14,7 @@ pub fn derive(input: &Input) -> TokenStream { .collect::>(); let fields_names_1 = &fields_names; let fields_names_2 = &fields_names; + let first_field_name = &fields_names[0]; quote!{ // usize @@ -336,5 +337,326 @@ pub fn derive(input: &Input) -> TokenStream { (0..=self.end).index_mut(soa) } } + + // usize + impl<'a> ::soa_derive::SoaIndex<#slice_name<'a>> for usize { + type RefOutput = #ref_name<'a>; + + #[inline] + fn get(self, slice: #slice_name<'a>) -> Option { + if self < slice.#first_field_name.len() { + Some(unsafe { self.get_unchecked(slice) }) + } else { + None + } + } + + #[inline] + unsafe fn get_unchecked(self, slice: #slice_name<'a>) -> Self::RefOutput { + #ref_name { + #(#fields_names_1: slice.#fields_names_2.get_unchecked(self),)* + } + } + + #[inline] + fn index(self, slice: #slice_name<'a>) -> Self::RefOutput { + #ref_name { + #(#fields_names_1: & slice.#fields_names_2[self],)* + } + } + } + + impl<'a> ::soa_derive::SoaMutIndex<#slice_mut_name<'a>> for usize { + type MutOutput = #ref_mut_name<'a>; + + #[inline] + fn get_mut(self, slice: #slice_mut_name<'a>) -> Option { + if self < slice.len() { + Some(unsafe { self.get_unchecked_mut(slice) }) + } else { + None + } + } + + #[inline] + unsafe fn get_unchecked_mut(self, slice: #slice_mut_name<'a>) -> Self::MutOutput { + #ref_mut_name { + #(#fields_names_1: slice.#fields_names_2.get_unchecked_mut(self),)* + } + } + + #[inline] + fn index_mut(self, slice: #slice_mut_name<'a>) -> Self::MutOutput { + #ref_mut_name { + #(#fields_names_1: &mut slice.#fields_names_2[self],)* + } + } + } + + + + // Range + impl<'a> ::soa_derive::SoaIndex<#slice_name<'a>> for ::std::ops::Range { + type RefOutput = #slice_name<'a>; + + #[inline] + fn get(self, slice: #slice_name<'a>) -> Option { + if self.start <= self.end && self.end <= slice.#first_field_name.len() { + unsafe { Some(self.get_unchecked(slice)) } + } else { + None + } + } + + #[inline] + unsafe fn get_unchecked(self, slice: #slice_name<'a>) -> Self::RefOutput { + #slice_name { + #(#fields_names_1: slice.#fields_names_2.get_unchecked(self.clone()),)* + } + } + + #[inline] + fn index(self, slice: #slice_name<'a>) -> Self::RefOutput { + #slice_name { + #(#fields_names_1: & slice.#fields_names_2[self.clone()],)* + } + } + } + + impl<'a> ::soa_derive::SoaMutIndex<#slice_mut_name<'a>> for ::std::ops::Range { + type MutOutput = #slice_mut_name<'a>; + + #[inline] + fn get_mut(self, slice: #slice_mut_name<'a>) -> Option { + if self.start <= self.end && self.end <= slice.#first_field_name.len() { + unsafe { Some(self.get_unchecked_mut(slice)) } + } else { + None + } + } + + #[inline] + unsafe fn get_unchecked_mut(self, slice: #slice_mut_name<'a>) -> Self::MutOutput { + #slice_mut_name { + #(#fields_names_1: slice.#fields_names_2.get_unchecked_mut(self.clone()),)* + } + } + + #[inline] + fn index_mut(self, slice: #slice_mut_name<'a>) -> Self::MutOutput { + #slice_mut_name { + #(#fields_names_1: &mut slice.#fields_names_2[self.clone()],)* + } + } + } + + + + // RangeTo + impl<'a> ::soa_derive::SoaIndex<#slice_name<'a>> for ::std::ops::RangeTo { + type RefOutput = #slice_name<'a>; + + #[inline] + fn get(self, slice: #slice_name<'a>) -> Option { + (0..self.end).get(slice) + } + + #[inline] + unsafe fn get_unchecked(self, slice: #slice_name<'a>) -> Self::RefOutput { + (0..self.end).get_unchecked(slice) + } + + #[inline] + fn index(self, slice: #slice_name<'a>) -> Self::RefOutput { + (0..self.end).index(slice) + } + } + + impl<'a> ::soa_derive::SoaMutIndex<#slice_mut_name<'a>> for ::std::ops::RangeTo { + type MutOutput = #slice_mut_name<'a>; + + #[inline] + fn get_mut(self, slice: #slice_mut_name<'a>) -> Option { + (0..self.end).get_mut(slice) + } + + #[inline] + unsafe fn get_unchecked_mut(self, slice: #slice_mut_name<'a>) -> Self::MutOutput { + (0..self.end).get_unchecked_mut(slice) + } + + #[inline] + fn index_mut(self, slice: #slice_mut_name<'a>) -> Self::MutOutput { + (0..self.end).index_mut(slice) + } + } + + + // RangeFrom + impl<'a> ::soa_derive::SoaIndex<#slice_name<'a>> for ::std::ops::RangeFrom { + type RefOutput = #slice_name<'a>; + + #[inline] + fn get(self, slice: #slice_name<'a>) -> Option { + (self.start..slice.len()).get(slice) + } + + #[inline] + unsafe fn get_unchecked(self, slice: #slice_name<'a>) -> Self::RefOutput { + (self.start..slice.len()).get_unchecked(slice) + } + + #[inline] + fn index(self, slice: #slice_name<'a>) -> Self::RefOutput { + (self.start..slice.len()).index(slice) + } + } + + impl<'a> ::soa_derive::SoaMutIndex<#slice_mut_name<'a>> for ::std::ops::RangeFrom { + type MutOutput = #slice_mut_name<'a>; + + #[inline] + fn get_mut(self, slice: #slice_mut_name<'a>) -> Option { + (self.start..slice.len()).get_mut(slice) + } + + #[inline] + unsafe fn get_unchecked_mut(self, slice: #slice_mut_name<'a>) -> Self::MutOutput { + (self.start..slice.len()).get_unchecked_mut(slice) + } + + #[inline] + fn index_mut(self, slice: #slice_mut_name<'a>) -> Self::MutOutput { + (self.start..slice.len()).index_mut(slice) + } + } + + + // RangeFull + impl<'a> ::soa_derive::SoaIndex<#slice_name<'a>> for ::std::ops::RangeFull { + type RefOutput = #slice_name<'a>; + + #[inline] + fn get(self, slice: #slice_name<'a>) -> Option { + Some(slice) + } + + #[inline] + unsafe fn get_unchecked(self, slice: #slice_name<'a>) -> Self::RefOutput { + slice + } + + #[inline] + fn index(self, slice: #slice_name<'a>) -> Self::RefOutput { + slice + } + } + + impl<'a> ::soa_derive::SoaMutIndex<#slice_mut_name<'a>> for ::std::ops::RangeFull { + type MutOutput = #slice_mut_name<'a>; + + #[inline] + fn get_mut(self, slice: #slice_mut_name<'a>) -> Option { + Some(slice) + } + + #[inline] + unsafe fn get_unchecked_mut(self, slice: #slice_mut_name<'a>) -> Self::MutOutput { + slice + } + + #[inline] + fn index_mut(self, slice: #slice_mut_name<'a>) -> Self::MutOutput { + slice + } + } + + + // RangeInclusive + impl<'a> ::soa_derive::SoaIndex<#slice_name<'a>> for ::std::ops::RangeInclusive { + type RefOutput = #slice_name<'a>; + + #[inline] + fn get(self, slice: #slice_name<'a>) -> Option { + if *self.end() == usize::MAX { + None + } else { + (*self.start()..self.end() + 1).get(slice) + } + } + + #[inline] + unsafe fn get_unchecked(self, slice: #slice_name<'a>) -> Self::RefOutput { + (*self.start()..self.end() + 1).get_unchecked(slice) + } + + #[inline] + fn index(self, slice: #slice_name<'a>) -> Self::RefOutput { + (*self.start()..self.end() + 1).index(slice) + } + } + + impl<'a> ::soa_derive::SoaMutIndex<#slice_mut_name<'a>> for ::std::ops::RangeInclusive { + type MutOutput = #slice_mut_name<'a>; + + #[inline] + fn get_mut(self, slice: #slice_mut_name<'a>) -> Option { + if *self.end() == usize::MAX { + None + } else { + (*self.start()..self.end() + 1).get_mut(slice) + } + } + + #[inline] + unsafe fn get_unchecked_mut(self, slice: #slice_mut_name<'a>) -> Self::MutOutput { + (*self.start()..self.end() + 1).get_unchecked_mut(slice) + } + + #[inline] + fn index_mut(self, slice: #slice_mut_name<'a>) -> Self::MutOutput { + (*self.start()..self.end() + 1).index_mut(slice) + } + } + + + // RangeToInclusive + impl<'a> ::soa_derive::SoaIndex<#slice_name<'a>> for ::std::ops::RangeToInclusive { + type RefOutput = #slice_name<'a>; + + #[inline] + fn get(self, slice: #slice_name<'a>) -> Option { + (0..=self.end).get(slice) + } + + #[inline] + unsafe fn get_unchecked(self, slice: #slice_name<'a>) -> Self::RefOutput { + (0..=self.end).get_unchecked(slice) + } + + #[inline] + fn index(self, slice: #slice_name<'a>) -> Self::RefOutput { + (0..=self.end).index(slice) + } + } + + impl<'a> ::soa_derive::SoaMutIndex<#slice_mut_name<'a>> for ::std::ops::RangeToInclusive { + type MutOutput = #slice_mut_name<'a>; + + #[inline] + fn get_mut(self, slice: #slice_mut_name<'a>) -> Option { + (0..=self.end).get_mut(slice) + } + + #[inline] + unsafe fn get_unchecked_mut(self, slice: #slice_mut_name<'a>) -> Self::MutOutput { + (0..=self.end).get_unchecked_mut(slice) + } + + #[inline] + fn index_mut(self, slice: #slice_mut_name<'a>) -> Self::MutOutput { + (0..=self.end).index_mut(slice) + } + } } } \ No newline at end of file From 47018280327a3dd54a1b9344881b7f57e05eabc5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mart=C3=AD=20Angelats=20i=20Ribera?= Date: Fri, 25 Sep 2020 02:37:14 +0200 Subject: [PATCH 05/14] Prepare tests to add slice indexing methods' tests --- tests/index.rs | 26 ++++++++++++++------------ 1 file changed, 14 insertions(+), 12 deletions(-) diff --git a/tests/index.rs b/tests/index.rs index 208d6fa..2093cfd 100644 --- a/tests/index.rs +++ b/tests/index.rs @@ -2,8 +2,20 @@ mod particles; use self::particles::{Particle, ParticleVec, ParticleRef}; +/// Helper function to assert that two iterators (one of SoA and another of AoS) are equal. +fn eq_its<'a, I1, I2>(i1: I1, i2: I2) +where + I1: Iterator>, + I2: Iterator, +{ + for (p1, p2) in i1.zip(i2) { + assert_eq!(p1.name, &p2.name); + assert_eq!(*p1.mass, p2.mass); + } +} + #[test] -fn test_usize() { +fn test_vec_usize() { let mut aos = Vec::new(); let mut soa = ParticleVec::new(); @@ -48,18 +60,8 @@ fn test_usize() { assert_eq!(soa.get(0).map(|p| *p.mass), Some(particle.mass)); } -fn eq_its<'a, I1, I2>(i1: I1, i2: I2) -where - I1: Iterator>, - I2: Iterator, -{ - for (p1, p2) in i1.zip(i2) { - assert_eq!(*p1.mass, p2.mass); - } -} - #[test] -fn test_ranges() { +fn test_vec_ranges() { let mut particles = Vec::new(); particles.push(Particle::new(String::from("Cl"), 1.0)); particles.push(Particle::new(String::from("Na"), 2.0)); From 3c2fc15a2338cc0f3d89ef0df01784a9ad3ef07f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mart=C3=AD=20Angelats=20i=20Ribera?= Date: Fri, 25 Sep 2020 10:32:40 +0200 Subject: [PATCH 06/14] f => actually add methods to the generated slices --- soa-derive-internal/src/slice.rs | 114 +++++++++++++++++++++---------- 1 file changed, 79 insertions(+), 35 deletions(-) diff --git a/soa-derive-internal/src/slice.rs b/soa-derive-internal/src/slice.rs index f47c063..0584741 100644 --- a/soa-derive-internal/src/slice.rs +++ b/soa-derive-internal/src/slice.rs @@ -153,22 +153,37 @@ pub fn derive(input: &Input) -> TokenStream { /// Similar to [` #[doc = #slice_name_str] /// ::get()`](https://doc.rust-lang.org/std/primitive.slice.html#method.get). - pub fn get(&self, i: usize) -> Option<#ref_name> { - if self.is_empty() || i >= self.len() { - None - } else { - Some(#ref_name { - #(#fields_names_1: self.#fields_names_2.get(i).unwrap(),)* - }) - } + pub fn get<'b: 'a, I>(&'b self, index: I) -> Option + where + I: ::soa_derive::SoaIndex<#slice_name<'b>> + { + let slice: #slice_name<'b> = self.reborrow(); + index.get(slice) } /// Similar to [` #[doc = #slice_name_str] /// ::get_unchecked()`](https://doc.rust-lang.org/std/primitive.slice.html#method.get_unchecked). - pub unsafe fn get_unchecked(&self, i: usize) -> #ref_name { - #ref_name { - #(#fields_names_1: self.#fields_names_2.get_unchecked(i),)* + pub unsafe fn get_unchecked<'b: 'a, I>(&'b self, index: I) -> I::RefOutput + where + I: ::soa_derive::SoaIndex<#slice_name<'b>> + { + let slice: #slice_name<'b> = self.reborrow(); + index.get_unchecked(slice) + } + + pub fn index<'b: 'a, I>(&'b self, index: I) -> I::RefOutput + where + I: ::soa_derive::SoaIndex<#slice_name<'b>> + { + let slice: #slice_name<'b> = self.reborrow(); + index.index(slice) + } + + /// Reborrows the slices in a more narrower lifetime + pub fn reborrow<'b: 'a>(&'b self) -> #slice_name<'b> { + #slice_name { + #(#fields_names_1: &self.#fields_names_2,)* } } @@ -215,7 +230,6 @@ pub fn derive_mut(input: &Input) -> TokenStream { let slice_name = &input.slice_name(); let slice_mut_name = &input.slice_mut_name(); let vec_name = &input.vec_name(); - let ref_name = &input.ref_name(); let ref_mut_name = &input.ref_mut_name(); let ptr_name = &input.ptr_name(); let ptr_mut_name = &input.ptr_mut_name(); @@ -382,44 +396,74 @@ pub fn derive_mut(input: &Input) -> TokenStream { /// Similar to [` #[doc = #slice_name_str] /// ::get()`](https://doc.rust-lang.org/std/primitive.slice.html#method.get). - pub fn get(&self, i: usize) -> Option<#ref_name> { - if self.is_empty() || i >= self.len() { - None - } else { - Some(#ref_name { - #(#fields_names_1: self.#fields_names_2.get(i).unwrap(),)* - }) - } + pub fn get<'b: 'a, I>(&'b self, index: I) -> Option + where + I: ::soa_derive::SoaIndex<#slice_name<'b>> + { + let slice: #slice_name<'b> = self.as_slice(); + index.get(slice) } /// Similar to [` #[doc = #slice_name_str] /// ::get_unchecked()`](https://doc.rust-lang.org/std/primitive.slice.html#method.get_unchecked). - pub unsafe fn get_unchecked(&self, i: usize) -> #ref_name { - #ref_name { - #(#fields_names_1: self.#fields_names_2.get_unchecked(i),)* - } + pub unsafe fn get_unchecked<'b: 'a, I>(&'b self, index: I) -> I::RefOutput + where + I: ::soa_derive::SoaIndex<#slice_name<'b>> + { + let slice: #slice_name<'b> = self.as_slice(); + index.get_unchecked(slice) + } + + pub fn index<'b: 'a, I>(&'b self, index: I) -> I::RefOutput + where + I: ::soa_derive::SoaIndex<#slice_name<'b>> + { + let slice: #slice_name<'b> = self.as_slice(); + index.index(slice) } /// Similar to [` #[doc = #slice_name_str] /// ::get_mut()`](https://doc.rust-lang.org/std/primitive.slice.html#method.get_mut). - pub fn get_mut(&mut self, i: usize) -> Option<#ref_mut_name> { - if self.is_empty() || i >= self.len() { - None - } else { - Some(#ref_mut_name { - #(#fields_names_1: self.#fields_names_2.get_mut(i).unwrap(),)* - }) - } + pub fn get_mut<'b: 'a, I>(&'b mut self, index: I) -> Option + where + I: ::soa_derive::SoaMutIndex<#slice_mut_name<'b>> + { + let slice: #slice_mut_name<'b> = self.reborrow(); + index.get_mut(slice) } /// Similar to [` #[doc = #slice_name_str] /// ::get_unchecked_mut()`](https://doc.rust-lang.org/std/primitive.slice.html#method.get_unchecked_mut). - pub unsafe fn get_unchecked_mut(&mut self, i: usize) -> #ref_mut_name { - #ref_mut_name { - #(#fields_names_1: self.#fields_names_2.get_unchecked_mut(i),)* + pub unsafe fn get_unchecked_mut<'b: 'a, I>(&'b mut self, index: I) -> I::MutOutput + where + I: ::soa_derive::SoaMutIndex<#slice_mut_name<'b>> + { + let slice: #slice_mut_name<'b> = self.reborrow(); + index.get_unchecked_mut(slice) + } + + pub fn index_mut<'b: 'a, I>(&'b mut self, index: I) -> I::MutOutput + where + I: ::soa_derive::SoaMutIndex<#slice_mut_name<'b>> + { + let slice: #slice_mut_name<'b> = self.reborrow(); + index.index_mut(slice) + } + + /// Returns a non-mutable slice from this mutable slice. + pub fn as_slice<'b: 'a>(&'b self) -> #slice_name<'b> { + #slice_name { + #(#fields_names_1: &self.#fields_names_2,)* + } + } + + /// Reborrows the slices in a more narrower lifetime + pub fn reborrow<'b: 'a>(&'b mut self) -> #slice_mut_name<'b> { + #slice_mut_name { + #(#fields_names_1: &mut self.#fields_names_2,)* } } From 27a5a2a72a498fc56816dee1927003d5bce2e377 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mart=C3=AD=20Angelats=20i=20Ribera?= Date: Fri, 25 Sep 2020 18:38:53 +0200 Subject: [PATCH 07/14] f => Fix slice index methods --- soa-derive-internal/src/index.rs | 2 +- soa-derive-internal/src/slice.rs | 79 ++++++++++++++++++++++---------- 2 files changed, 57 insertions(+), 24 deletions(-) diff --git a/soa-derive-internal/src/index.rs b/soa-derive-internal/src/index.rs index 6bccb3a..fe150ed 100644 --- a/soa-derive-internal/src/index.rs +++ b/soa-derive-internal/src/index.rs @@ -659,4 +659,4 @@ pub fn derive(input: &Input) -> TokenStream { } } } -} \ No newline at end of file +} diff --git a/soa-derive-internal/src/slice.rs b/soa-derive-internal/src/slice.rs index 0584741..a3c5637 100644 --- a/soa-derive-internal/src/slice.rs +++ b/soa-derive-internal/src/slice.rs @@ -153,9 +153,10 @@ pub fn derive(input: &Input) -> TokenStream { /// Similar to [` #[doc = #slice_name_str] /// ::get()`](https://doc.rust-lang.org/std/primitive.slice.html#method.get). - pub fn get<'b: 'a, I>(&'b self, index: I) -> Option + pub fn get<'b, I>(&'b self, index: I) -> Option where - I: ::soa_derive::SoaIndex<#slice_name<'b>> + I: ::soa_derive::SoaIndex<#slice_name<'b>>, + 'a: 'b { let slice: #slice_name<'b> = self.reborrow(); index.get(slice) @@ -164,24 +165,33 @@ pub fn derive(input: &Input) -> TokenStream { /// Similar to [` #[doc = #slice_name_str] /// ::get_unchecked()`](https://doc.rust-lang.org/std/primitive.slice.html#method.get_unchecked). - pub unsafe fn get_unchecked<'b: 'a, I>(&'b self, index: I) -> I::RefOutput + pub unsafe fn get_unchecked<'b, I>(&'b self, index: I) -> I::RefOutput where - I: ::soa_derive::SoaIndex<#slice_name<'b>> + I: ::soa_derive::SoaIndex<#slice_name<'b>>, + 'a: 'b { let slice: #slice_name<'b> = self.reborrow(); index.get_unchecked(slice) } - pub fn index<'b: 'a, I>(&'b self, index: I) -> I::RefOutput + /// Similar to [`std::ops::Index` trait](https://doc.rust-lang.org/std/ops/trait.Index.html) on + #[doc = #slice_name_str] + /// . + /// This is required because we cannot implement that trait. + pub fn index<'b, I>(&'b self, index: I) -> I::RefOutput where - I: ::soa_derive::SoaIndex<#slice_name<'b>> + I: ::soa_derive::SoaIndex<#slice_name<'b>>, + 'a: 'b { let slice: #slice_name<'b> = self.reborrow(); index.index(slice) } /// Reborrows the slices in a more narrower lifetime - pub fn reborrow<'b: 'a>(&'b self) -> #slice_name<'b> { + pub fn reborrow<'b>(&'b self) -> #slice_name<'b> + where + 'a: 'b + { #slice_name { #(#fields_names_1: &self.#fields_names_2,)* } @@ -396,9 +406,10 @@ pub fn derive_mut(input: &Input) -> TokenStream { /// Similar to [` #[doc = #slice_name_str] /// ::get()`](https://doc.rust-lang.org/std/primitive.slice.html#method.get). - pub fn get<'b: 'a, I>(&'b self, index: I) -> Option + pub fn get<'b, I>(&'b self, index: I) -> Option where - I: ::soa_derive::SoaIndex<#slice_name<'b>> + I: ::soa_derive::SoaIndex<#slice_name<'b>>, + 'a: 'b { let slice: #slice_name<'b> = self.as_slice(); index.get(slice) @@ -407,17 +418,24 @@ pub fn derive_mut(input: &Input) -> TokenStream { /// Similar to [` #[doc = #slice_name_str] /// ::get_unchecked()`](https://doc.rust-lang.org/std/primitive.slice.html#method.get_unchecked). - pub unsafe fn get_unchecked<'b: 'a, I>(&'b self, index: I) -> I::RefOutput + pub unsafe fn get_unchecked<'b, I>(&'b self, index: I) -> I::RefOutput where - I: ::soa_derive::SoaIndex<#slice_name<'b>> + I: ::soa_derive::SoaIndex<#slice_name<'b>>, + 'a: 'b { let slice: #slice_name<'b> = self.as_slice(); index.get_unchecked(slice) } - pub fn index<'b: 'a, I>(&'b self, index: I) -> I::RefOutput + + /// Similar to [`std::ops::Index` trait](https://doc.rust-lang.org/std/ops/trait.Index.html) on + #[doc = #slice_name_str] + /// . + /// This is required because we cannot implement that trait. + pub fn index<'b, I>(&'b self, index: I) -> I::RefOutput where - I: ::soa_derive::SoaIndex<#slice_name<'b>> + I: ::soa_derive::SoaIndex<#slice_name<'b>>, + 'a: 'b { let slice: #slice_name<'b> = self.as_slice(); index.index(slice) @@ -426,44 +444,59 @@ pub fn derive_mut(input: &Input) -> TokenStream { /// Similar to [` #[doc = #slice_name_str] /// ::get_mut()`](https://doc.rust-lang.org/std/primitive.slice.html#method.get_mut). - pub fn get_mut<'b: 'a, I>(&'b mut self, index: I) -> Option + pub fn get_mut<'b, I>(&'b mut self, index: I) -> Option where - I: ::soa_derive::SoaMutIndex<#slice_mut_name<'b>> + I: ::soa_derive::SoaMutIndex<#slice_mut_name<'b>>, + 'a: 'b { - let slice: #slice_mut_name<'b> = self.reborrow(); + let slice: #slice_mut_name<'b> = #slice_mut_name { + #(#fields_names_1: &mut *self.#fields_names_2,)* + }; index.get_mut(slice) } /// Similar to [` #[doc = #slice_name_str] /// ::get_unchecked_mut()`](https://doc.rust-lang.org/std/primitive.slice.html#method.get_unchecked_mut). - pub unsafe fn get_unchecked_mut<'b: 'a, I>(&'b mut self, index: I) -> I::MutOutput + pub unsafe fn get_unchecked_mut<'b, I>(&'b mut self, index: I) -> I::MutOutput where - I: ::soa_derive::SoaMutIndex<#slice_mut_name<'b>> + I: ::soa_derive::SoaMutIndex<#slice_mut_name<'b>>, + 'a: 'b { let slice: #slice_mut_name<'b> = self.reborrow(); index.get_unchecked_mut(slice) } - pub fn index_mut<'b: 'a, I>(&'b mut self, index: I) -> I::MutOutput + /// Similar to [`std::ops::IndexMut` trait](https://doc.rust-lang.org/std/ops/trait.IndexMut.html) on + #[doc = #slice_name_str] + /// . + /// This is required because we cannot implement that trait. + pub fn index_mut<'b, I>(&'b mut self, index: I) -> I::MutOutput where - I: ::soa_derive::SoaMutIndex<#slice_mut_name<'b>> + I: ::soa_derive::SoaMutIndex<#slice_mut_name<'b>>, + 'a: 'b { let slice: #slice_mut_name<'b> = self.reborrow(); index.index_mut(slice) } /// Returns a non-mutable slice from this mutable slice. - pub fn as_slice<'b: 'a>(&'b self) -> #slice_name<'b> { + pub fn as_slice<'b>(&'b self) -> #slice_name<'b> + where + 'a: 'b + { #slice_name { #(#fields_names_1: &self.#fields_names_2,)* } } /// Reborrows the slices in a more narrower lifetime - pub fn reborrow<'b: 'a>(&'b mut self) -> #slice_mut_name<'b> { + pub fn reborrow<'b>(&'b mut self) -> #slice_mut_name<'b> + where + 'a: 'b + { #slice_mut_name { - #(#fields_names_1: &mut self.#fields_names_2,)* + #(#fields_names_1: &mut *self.#fields_names_2,)* } } From 878499f343307e6903855121beee9efb2aaa1d7c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mart=C3=AD=20Angelats=20i=20Ribera?= Date: Fri, 25 Sep 2020 18:39:18 +0200 Subject: [PATCH 08/14] Add tests for slice index methods --- tests/index.rs | 121 ++++++++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 120 insertions(+), 1 deletion(-) diff --git a/tests/index.rs b/tests/index.rs index 2093cfd..1ba31e2 100644 --- a/tests/index.rs +++ b/tests/index.rs @@ -125,5 +125,124 @@ fn test_vec_ranges() { eq_its(soa.get_mut(range.clone()).unwrap().iter(), particles.get(range.clone()).unwrap().iter()); unsafe { eq_its(soa.get_unchecked_mut(range.clone()).iter(), particles.get_unchecked(range.clone()).iter()); } eq_its(soa.index_mut(range.clone()).iter(), particles[range.clone()].iter()); +} + +#[test] +fn test_slice_usize() { + let mut aos = Vec::new(); + let mut soa = ParticleVec::new(); + + let particle = Particle::new(String::from("Na"), 56.0); + aos.push(particle.clone()); + soa.push(particle.clone()); + + // SoaIndex + let aos_slice = aos.as_slice(); + let soa_slice = soa.as_slice(); + + assert_eq!(soa_slice.get(0).unwrap().name, &aos_slice.get(0).unwrap().name); + assert_eq!(soa_slice.get(0).unwrap().mass, &aos_slice.get(0).unwrap().mass); + assert_eq!(aos_slice.get(1), None); + assert_eq!(soa_slice.get(1), None); + + unsafe { + assert_eq!(soa_slice.get_unchecked(0).name, &aos_slice.get_unchecked(0).name); + assert_eq!(soa_slice.get_unchecked(0).mass, &aos_slice.get_unchecked(0).mass); + } + + assert_eq!(soa_slice.index(0).name, &aos_slice[0].name); + assert_eq!(soa_slice.index(0).mass, &aos_slice[0].mass); + + + // SoaIndexMut + let aos_mut_slice = aos.as_mut_slice(); + let mut soa_mut_slice = soa.as_mut_slice(); + assert_eq!(soa_mut_slice.get_mut(0).unwrap().name, &aos_mut_slice.get_mut(0).unwrap().name); + assert_eq!(soa_mut_slice.get_mut(0).unwrap().mass, &aos_mut_slice.get_mut(0).unwrap().mass); + assert_eq!(soa_mut_slice.get_mut(0).unwrap().mass, &aos_mut_slice.get_mut(0).unwrap().mass); + assert_eq!(aos_mut_slice.get_mut(1), None); + assert_eq!(soa_mut_slice.get_mut(1), None); + + unsafe { + assert_eq!(soa_mut_slice.get_unchecked_mut(0).name, &aos_mut_slice.get_unchecked_mut(0).name); + assert_eq!(soa_mut_slice.get_unchecked_mut(0).mass, &aos_mut_slice.get_unchecked_mut(0).mass); + } + + assert_eq!(soa_mut_slice.index_mut(0).name, &aos_mut_slice[0].name); + assert_eq!(soa_mut_slice.index_mut(0).mass, &aos_mut_slice[0].mass); + + + *soa_mut_slice.index_mut(0).mass -= 1.; + assert_eq!(soa_mut_slice.get(0).map(|p| *p.mass), Some(particle.mass - 1.)); + + *soa_mut_slice.get_mut(0).unwrap().mass += 1.; + assert_eq!(soa_mut_slice.get(0).map(|p| *p.mass), Some(particle.mass)); +} + +#[test] +fn test_slice_ranges() { + let mut particles = Vec::new(); + particles.push(Particle::new(String::from("Cl"), 1.0)); + particles.push(Particle::new(String::from("Na"), 2.0)); + particles.push(Particle::new(String::from("Br"), 3.0)); + particles.push(Particle::new(String::from("Zn"), 4.0)); + + let mut soa = ParticleVec::new(); + + for particle in particles.iter() { + soa.push(particle.clone()); + } + + eq_its(soa.iter(), particles.iter()); -} \ No newline at end of file + let mut soa_mut_slice = soa.as_mut_slice(); + // All tests from here are the same only changing the range + + let range = 0..1; + eq_its(soa_mut_slice.get(range.clone()).unwrap().iter(), particles.get(range.clone()).unwrap().iter()); + unsafe { eq_its(soa_mut_slice.get_unchecked(range.clone()).iter(), particles.get_unchecked(range.clone()).iter()); } + eq_its(soa_mut_slice.index(range.clone()).iter(), particles[range.clone()].iter()); + eq_its(soa_mut_slice.get_mut(range.clone()).unwrap().iter(), particles.get(range.clone()).unwrap().iter()); + unsafe { eq_its(soa_mut_slice.get_unchecked_mut(range.clone()).iter(), particles.get_unchecked(range.clone()).iter()); } + eq_its(soa_mut_slice.index_mut(range.clone()).iter(), particles[range.clone()].iter()); + + let range = ..3; + eq_its(soa_mut_slice.get(range.clone()).unwrap().iter(), particles.get(range.clone()).unwrap().iter()); + unsafe { eq_its(soa_mut_slice.get_unchecked(range.clone()).iter(), particles.get_unchecked(range.clone()).iter()); } + eq_its(soa_mut_slice.index(range.clone()).iter(), particles[range.clone()].iter()); + eq_its(soa_mut_slice.get_mut(range.clone()).unwrap().iter(), particles.get(range.clone()).unwrap().iter()); + unsafe { eq_its(soa_mut_slice.get_unchecked_mut(range.clone()).iter(), particles.get_unchecked(range.clone()).iter()); } + eq_its(soa_mut_slice.index_mut(range.clone()).iter(), particles[range.clone()].iter()); + + let range = 1..; + eq_its(soa_mut_slice.get(range.clone()).unwrap().iter(), particles.get(range.clone()).unwrap().iter()); + unsafe { eq_its(soa_mut_slice.get_unchecked(range.clone()).iter(), particles.get_unchecked(range.clone()).iter()); } + eq_its(soa_mut_slice.index(range.clone()).iter(), particles[range.clone()].iter()); + eq_its(soa_mut_slice.get_mut(range.clone()).unwrap().iter(), particles.get(range.clone()).unwrap().iter()); + unsafe { eq_its(soa_mut_slice.get_unchecked_mut(range.clone()).iter(), particles.get_unchecked(range.clone()).iter()); } + eq_its(soa_mut_slice.index_mut(range.clone()).iter(), particles[range.clone()].iter()); + + let range = ..; + eq_its(soa_mut_slice.get(range.clone()).unwrap().iter(), particles.get(range.clone()).unwrap().iter()); + unsafe { eq_its(soa_mut_slice.get_unchecked(range.clone()).iter(), particles.get_unchecked(range.clone()).iter()); } + eq_its(soa_mut_slice.index(range.clone()).iter(), particles[range.clone()].iter()); + eq_its(soa_mut_slice.get_mut(range.clone()).unwrap().iter(), particles.get(range.clone()).unwrap().iter()); + unsafe { eq_its(soa_mut_slice.get_unchecked_mut(range.clone()).iter(), particles.get_unchecked(range.clone()).iter()); } + eq_its(soa_mut_slice.index_mut(range.clone()).iter(), particles[range.clone()].iter()); + + let range = 0..=1; + eq_its(soa_mut_slice.get(range.clone()).unwrap().iter(), particles.get(range.clone()).unwrap().iter()); + unsafe { eq_its(soa_mut_slice.get_unchecked(range.clone()).iter(), particles.get_unchecked(range.clone()).iter()); } + eq_its(soa_mut_slice.index(range.clone()).iter(), particles[range.clone()].iter()); + eq_its(soa_mut_slice.get_mut(range.clone()).unwrap().iter(), particles.get(range.clone()).unwrap().iter()); + unsafe { eq_its(soa_mut_slice.get_unchecked_mut(range.clone()).iter(), particles.get_unchecked(range.clone()).iter()); } + eq_its(soa_mut_slice.index_mut(range.clone()).iter(), particles[range.clone()].iter()); + + let range = ..=2; + eq_its(soa_mut_slice.get(range.clone()).unwrap().iter(), particles.get(range.clone()).unwrap().iter()); + unsafe { eq_its(soa_mut_slice.get_unchecked(range.clone()).iter(), particles.get_unchecked(range.clone()).iter()); } + eq_its(soa_mut_slice.index(range.clone()).iter(), particles[range.clone()].iter()); + eq_its(soa_mut_slice.get_mut(range.clone()).unwrap().iter(), particles.get(range.clone()).unwrap().iter()); + unsafe { eq_its(soa_mut_slice.get_unchecked_mut(range.clone()).iter(), particles.get_unchecked(range.clone()).iter()); } + eq_its(soa_mut_slice.index_mut(range.clone()).iter(), particles[range.clone()].iter()); +} From f1ebf7e436c05537d76b75c1165576a128390c3a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mart=C3=AD=20Angelats=20i=20Ribera?= Date: Fri, 25 Sep 2020 18:43:29 +0200 Subject: [PATCH 09/14] Rename `SoaIndex` to `SoAIndex` and `SoaMutIndex` to `SoAIndexMut` --- soa-derive-internal/src/index.rs | 56 ++++++++++++++++---------------- soa-derive-internal/src/slice.rs | 18 +++++----- soa-derive-internal/src/vec.rs | 12 +++---- src/lib.rs | 4 +-- tests/index.rs | 4 +-- 5 files changed, 47 insertions(+), 47 deletions(-) diff --git a/soa-derive-internal/src/index.rs b/soa-derive-internal/src/index.rs index fe150ed..8db82c6 100644 --- a/soa-derive-internal/src/index.rs +++ b/soa-derive-internal/src/index.rs @@ -18,7 +18,7 @@ pub fn derive(input: &Input) -> TokenStream { quote!{ // usize - impl<'a> ::soa_derive::SoaIndex<&'a #vec_name> for usize { + impl<'a> ::soa_derive::SoAIndex<&'a #vec_name> for usize { type RefOutput = #ref_name<'a>; #[inline] @@ -45,7 +45,7 @@ pub fn derive(input: &Input) -> TokenStream { } } - impl<'a> ::soa_derive::SoaMutIndex<&'a mut #vec_name> for usize { + impl<'a> ::soa_derive::SoAIndexMut<&'a mut #vec_name> for usize { type MutOutput = #ref_mut_name<'a>; #[inline] @@ -75,7 +75,7 @@ pub fn derive(input: &Input) -> TokenStream { // Range - impl<'a> ::soa_derive::SoaIndex<&'a #vec_name> for ::std::ops::Range { + impl<'a> ::soa_derive::SoAIndex<&'a #vec_name> for ::std::ops::Range { type RefOutput = #slice_name<'a>; #[inline] @@ -102,7 +102,7 @@ pub fn derive(input: &Input) -> TokenStream { } } - impl<'a> ::soa_derive::SoaMutIndex<&'a mut #vec_name> for ::std::ops::Range { + impl<'a> ::soa_derive::SoAIndexMut<&'a mut #vec_name> for ::std::ops::Range { type MutOutput = #slice_mut_name<'a>; #[inline] @@ -132,7 +132,7 @@ pub fn derive(input: &Input) -> TokenStream { // RangeTo - impl<'a> ::soa_derive::SoaIndex<&'a #vec_name> for ::std::ops::RangeTo { + impl<'a> ::soa_derive::SoAIndex<&'a #vec_name> for ::std::ops::RangeTo { type RefOutput = #slice_name<'a>; #[inline] @@ -151,7 +151,7 @@ pub fn derive(input: &Input) -> TokenStream { } } - impl<'a> ::soa_derive::SoaMutIndex<&'a mut #vec_name> for ::std::ops::RangeTo { + impl<'a> ::soa_derive::SoAIndexMut<&'a mut #vec_name> for ::std::ops::RangeTo { type MutOutput = #slice_mut_name<'a>; #[inline] @@ -172,7 +172,7 @@ pub fn derive(input: &Input) -> TokenStream { // RangeFrom - impl<'a> ::soa_derive::SoaIndex<&'a #vec_name> for ::std::ops::RangeFrom { + impl<'a> ::soa_derive::SoAIndex<&'a #vec_name> for ::std::ops::RangeFrom { type RefOutput = #slice_name<'a>; #[inline] @@ -191,7 +191,7 @@ pub fn derive(input: &Input) -> TokenStream { } } - impl<'a> ::soa_derive::SoaMutIndex<&'a mut #vec_name> for ::std::ops::RangeFrom { + impl<'a> ::soa_derive::SoAIndexMut<&'a mut #vec_name> for ::std::ops::RangeFrom { type MutOutput = #slice_mut_name<'a>; #[inline] @@ -212,7 +212,7 @@ pub fn derive(input: &Input) -> TokenStream { // RangeFull - impl<'a> ::soa_derive::SoaIndex<&'a #vec_name> for ::std::ops::RangeFull { + impl<'a> ::soa_derive::SoAIndex<&'a #vec_name> for ::std::ops::RangeFull { type RefOutput = #slice_name<'a>; #[inline] @@ -231,7 +231,7 @@ pub fn derive(input: &Input) -> TokenStream { } } - impl<'a> ::soa_derive::SoaMutIndex<&'a mut #vec_name> for ::std::ops::RangeFull { + impl<'a> ::soa_derive::SoAIndexMut<&'a mut #vec_name> for ::std::ops::RangeFull { type MutOutput = #slice_mut_name<'a>; #[inline] @@ -252,7 +252,7 @@ pub fn derive(input: &Input) -> TokenStream { // RangeInclusive - impl<'a> ::soa_derive::SoaIndex<&'a #vec_name> for ::std::ops::RangeInclusive { + impl<'a> ::soa_derive::SoAIndex<&'a #vec_name> for ::std::ops::RangeInclusive { type RefOutput = #slice_name<'a>; #[inline] @@ -275,7 +275,7 @@ pub fn derive(input: &Input) -> TokenStream { } } - impl<'a> ::soa_derive::SoaMutIndex<&'a mut #vec_name> for ::std::ops::RangeInclusive { + impl<'a> ::soa_derive::SoAIndexMut<&'a mut #vec_name> for ::std::ops::RangeInclusive { type MutOutput = #slice_mut_name<'a>; #[inline] @@ -300,7 +300,7 @@ pub fn derive(input: &Input) -> TokenStream { // RangeToInclusive - impl<'a> ::soa_derive::SoaIndex<&'a #vec_name> for ::std::ops::RangeToInclusive { + impl<'a> ::soa_derive::SoAIndex<&'a #vec_name> for ::std::ops::RangeToInclusive { type RefOutput = #slice_name<'a>; #[inline] @@ -319,7 +319,7 @@ pub fn derive(input: &Input) -> TokenStream { } } - impl<'a> ::soa_derive::SoaMutIndex<&'a mut #vec_name> for ::std::ops::RangeToInclusive { + impl<'a> ::soa_derive::SoAIndexMut<&'a mut #vec_name> for ::std::ops::RangeToInclusive { type MutOutput = #slice_mut_name<'a>; #[inline] @@ -339,7 +339,7 @@ pub fn derive(input: &Input) -> TokenStream { } // usize - impl<'a> ::soa_derive::SoaIndex<#slice_name<'a>> for usize { + impl<'a> ::soa_derive::SoAIndex<#slice_name<'a>> for usize { type RefOutput = #ref_name<'a>; #[inline] @@ -366,7 +366,7 @@ pub fn derive(input: &Input) -> TokenStream { } } - impl<'a> ::soa_derive::SoaMutIndex<#slice_mut_name<'a>> for usize { + impl<'a> ::soa_derive::SoAIndexMut<#slice_mut_name<'a>> for usize { type MutOutput = #ref_mut_name<'a>; #[inline] @@ -396,7 +396,7 @@ pub fn derive(input: &Input) -> TokenStream { // Range - impl<'a> ::soa_derive::SoaIndex<#slice_name<'a>> for ::std::ops::Range { + impl<'a> ::soa_derive::SoAIndex<#slice_name<'a>> for ::std::ops::Range { type RefOutput = #slice_name<'a>; #[inline] @@ -423,7 +423,7 @@ pub fn derive(input: &Input) -> TokenStream { } } - impl<'a> ::soa_derive::SoaMutIndex<#slice_mut_name<'a>> for ::std::ops::Range { + impl<'a> ::soa_derive::SoAIndexMut<#slice_mut_name<'a>> for ::std::ops::Range { type MutOutput = #slice_mut_name<'a>; #[inline] @@ -453,7 +453,7 @@ pub fn derive(input: &Input) -> TokenStream { // RangeTo - impl<'a> ::soa_derive::SoaIndex<#slice_name<'a>> for ::std::ops::RangeTo { + impl<'a> ::soa_derive::SoAIndex<#slice_name<'a>> for ::std::ops::RangeTo { type RefOutput = #slice_name<'a>; #[inline] @@ -472,7 +472,7 @@ pub fn derive(input: &Input) -> TokenStream { } } - impl<'a> ::soa_derive::SoaMutIndex<#slice_mut_name<'a>> for ::std::ops::RangeTo { + impl<'a> ::soa_derive::SoAIndexMut<#slice_mut_name<'a>> for ::std::ops::RangeTo { type MutOutput = #slice_mut_name<'a>; #[inline] @@ -493,7 +493,7 @@ pub fn derive(input: &Input) -> TokenStream { // RangeFrom - impl<'a> ::soa_derive::SoaIndex<#slice_name<'a>> for ::std::ops::RangeFrom { + impl<'a> ::soa_derive::SoAIndex<#slice_name<'a>> for ::std::ops::RangeFrom { type RefOutput = #slice_name<'a>; #[inline] @@ -512,7 +512,7 @@ pub fn derive(input: &Input) -> TokenStream { } } - impl<'a> ::soa_derive::SoaMutIndex<#slice_mut_name<'a>> for ::std::ops::RangeFrom { + impl<'a> ::soa_derive::SoAIndexMut<#slice_mut_name<'a>> for ::std::ops::RangeFrom { type MutOutput = #slice_mut_name<'a>; #[inline] @@ -533,7 +533,7 @@ pub fn derive(input: &Input) -> TokenStream { // RangeFull - impl<'a> ::soa_derive::SoaIndex<#slice_name<'a>> for ::std::ops::RangeFull { + impl<'a> ::soa_derive::SoAIndex<#slice_name<'a>> for ::std::ops::RangeFull { type RefOutput = #slice_name<'a>; #[inline] @@ -552,7 +552,7 @@ pub fn derive(input: &Input) -> TokenStream { } } - impl<'a> ::soa_derive::SoaMutIndex<#slice_mut_name<'a>> for ::std::ops::RangeFull { + impl<'a> ::soa_derive::SoAIndexMut<#slice_mut_name<'a>> for ::std::ops::RangeFull { type MutOutput = #slice_mut_name<'a>; #[inline] @@ -573,7 +573,7 @@ pub fn derive(input: &Input) -> TokenStream { // RangeInclusive - impl<'a> ::soa_derive::SoaIndex<#slice_name<'a>> for ::std::ops::RangeInclusive { + impl<'a> ::soa_derive::SoAIndex<#slice_name<'a>> for ::std::ops::RangeInclusive { type RefOutput = #slice_name<'a>; #[inline] @@ -596,7 +596,7 @@ pub fn derive(input: &Input) -> TokenStream { } } - impl<'a> ::soa_derive::SoaMutIndex<#slice_mut_name<'a>> for ::std::ops::RangeInclusive { + impl<'a> ::soa_derive::SoAIndexMut<#slice_mut_name<'a>> for ::std::ops::RangeInclusive { type MutOutput = #slice_mut_name<'a>; #[inline] @@ -621,7 +621,7 @@ pub fn derive(input: &Input) -> TokenStream { // RangeToInclusive - impl<'a> ::soa_derive::SoaIndex<#slice_name<'a>> for ::std::ops::RangeToInclusive { + impl<'a> ::soa_derive::SoAIndex<#slice_name<'a>> for ::std::ops::RangeToInclusive { type RefOutput = #slice_name<'a>; #[inline] @@ -640,7 +640,7 @@ pub fn derive(input: &Input) -> TokenStream { } } - impl<'a> ::soa_derive::SoaMutIndex<#slice_mut_name<'a>> for ::std::ops::RangeToInclusive { + impl<'a> ::soa_derive::SoAIndexMut<#slice_mut_name<'a>> for ::std::ops::RangeToInclusive { type MutOutput = #slice_mut_name<'a>; #[inline] diff --git a/soa-derive-internal/src/slice.rs b/soa-derive-internal/src/slice.rs index a3c5637..b9c9e18 100644 --- a/soa-derive-internal/src/slice.rs +++ b/soa-derive-internal/src/slice.rs @@ -155,7 +155,7 @@ pub fn derive(input: &Input) -> TokenStream { /// ::get()`](https://doc.rust-lang.org/std/primitive.slice.html#method.get). pub fn get<'b, I>(&'b self, index: I) -> Option where - I: ::soa_derive::SoaIndex<#slice_name<'b>>, + I: ::soa_derive::SoAIndex<#slice_name<'b>>, 'a: 'b { let slice: #slice_name<'b> = self.reborrow(); @@ -167,7 +167,7 @@ pub fn derive(input: &Input) -> TokenStream { /// ::get_unchecked()`](https://doc.rust-lang.org/std/primitive.slice.html#method.get_unchecked). pub unsafe fn get_unchecked<'b, I>(&'b self, index: I) -> I::RefOutput where - I: ::soa_derive::SoaIndex<#slice_name<'b>>, + I: ::soa_derive::SoAIndex<#slice_name<'b>>, 'a: 'b { let slice: #slice_name<'b> = self.reborrow(); @@ -180,7 +180,7 @@ pub fn derive(input: &Input) -> TokenStream { /// This is required because we cannot implement that trait. pub fn index<'b, I>(&'b self, index: I) -> I::RefOutput where - I: ::soa_derive::SoaIndex<#slice_name<'b>>, + I: ::soa_derive::SoAIndex<#slice_name<'b>>, 'a: 'b { let slice: #slice_name<'b> = self.reborrow(); @@ -408,7 +408,7 @@ pub fn derive_mut(input: &Input) -> TokenStream { /// ::get()`](https://doc.rust-lang.org/std/primitive.slice.html#method.get). pub fn get<'b, I>(&'b self, index: I) -> Option where - I: ::soa_derive::SoaIndex<#slice_name<'b>>, + I: ::soa_derive::SoAIndex<#slice_name<'b>>, 'a: 'b { let slice: #slice_name<'b> = self.as_slice(); @@ -420,7 +420,7 @@ pub fn derive_mut(input: &Input) -> TokenStream { /// ::get_unchecked()`](https://doc.rust-lang.org/std/primitive.slice.html#method.get_unchecked). pub unsafe fn get_unchecked<'b, I>(&'b self, index: I) -> I::RefOutput where - I: ::soa_derive::SoaIndex<#slice_name<'b>>, + I: ::soa_derive::SoAIndex<#slice_name<'b>>, 'a: 'b { let slice: #slice_name<'b> = self.as_slice(); @@ -434,7 +434,7 @@ pub fn derive_mut(input: &Input) -> TokenStream { /// This is required because we cannot implement that trait. pub fn index<'b, I>(&'b self, index: I) -> I::RefOutput where - I: ::soa_derive::SoaIndex<#slice_name<'b>>, + I: ::soa_derive::SoAIndex<#slice_name<'b>>, 'a: 'b { let slice: #slice_name<'b> = self.as_slice(); @@ -446,7 +446,7 @@ pub fn derive_mut(input: &Input) -> TokenStream { /// ::get_mut()`](https://doc.rust-lang.org/std/primitive.slice.html#method.get_mut). pub fn get_mut<'b, I>(&'b mut self, index: I) -> Option where - I: ::soa_derive::SoaMutIndex<#slice_mut_name<'b>>, + I: ::soa_derive::SoAIndexMut<#slice_mut_name<'b>>, 'a: 'b { let slice: #slice_mut_name<'b> = #slice_mut_name { @@ -460,7 +460,7 @@ pub fn derive_mut(input: &Input) -> TokenStream { /// ::get_unchecked_mut()`](https://doc.rust-lang.org/std/primitive.slice.html#method.get_unchecked_mut). pub unsafe fn get_unchecked_mut<'b, I>(&'b mut self, index: I) -> I::MutOutput where - I: ::soa_derive::SoaMutIndex<#slice_mut_name<'b>>, + I: ::soa_derive::SoAIndexMut<#slice_mut_name<'b>>, 'a: 'b { let slice: #slice_mut_name<'b> = self.reborrow(); @@ -473,7 +473,7 @@ pub fn derive_mut(input: &Input) -> TokenStream { /// This is required because we cannot implement that trait. pub fn index_mut<'b, I>(&'b mut self, index: I) -> I::MutOutput where - I: ::soa_derive::SoaMutIndex<#slice_mut_name<'b>>, + I: ::soa_derive::SoAIndexMut<#slice_mut_name<'b>>, 'a: 'b { let slice: #slice_mut_name<'b> = self.reborrow(); diff --git a/soa-derive-internal/src/vec.rs b/soa-derive-internal/src/vec.rs index 9c2dd21..2f746e7 100644 --- a/soa-derive-internal/src/vec.rs +++ b/soa-derive-internal/src/vec.rs @@ -264,7 +264,7 @@ pub fn derive(input: &Input) -> TokenStream { /// ::get()`](https://doc.rust-lang.org/std/vec/struct.Vec.html#method.get). pub fn get<'a, I>(&'a self, index: I) -> Option where - I: ::soa_derive::SoaIndex<&'a #vec_name> + I: ::soa_derive::SoAIndex<&'a #vec_name> { index.get(self) } @@ -274,7 +274,7 @@ pub fn derive(input: &Input) -> TokenStream { /// ::get_unchecked()`](https://doc.rust-lang.org/std/vec/struct.Vec.html#method.get_unchecked). pub unsafe fn get_unchecked<'a, I>(&'a self, index: I) -> I::RefOutput where - I: ::soa_derive::SoaIndex<&'a #vec_name> + I: ::soa_derive::SoAIndex<&'a #vec_name> { index.get_unchecked(self) } @@ -284,7 +284,7 @@ pub fn derive(input: &Input) -> TokenStream { /// ::index()`](https://doc.rust-lang.org/std/vec/struct.Vec.html#method.index). pub fn index<'a, I>(&'a self, index: I) -> I::RefOutput where - I: ::soa_derive::SoaIndex<&'a #vec_name> + I: ::soa_derive::SoAIndex<&'a #vec_name> { index.index(self) } @@ -294,7 +294,7 @@ pub fn derive(input: &Input) -> TokenStream { /// ::get_mut()`](https://doc.rust-lang.org/std/vec/struct.Vec.html#method.get_mut). pub fn get_mut<'a, I>(&'a mut self, index: I) -> Option where - I: ::soa_derive::SoaMutIndex<&'a mut #vec_name> + I: ::soa_derive::SoAIndexMut<&'a mut #vec_name> { index.get_mut(self) } @@ -304,7 +304,7 @@ pub fn derive(input: &Input) -> TokenStream { /// ::get_unchecked_mut()`](https://doc.rust-lang.org/std/vec/struct.Vec.html#method.get_unchecked_mut). pub unsafe fn get_unchecked_mut<'a, I>(&'a mut self, index: I) -> I::MutOutput where - I: ::soa_derive::SoaMutIndex<&'a mut #vec_name> + I: ::soa_derive::SoAIndexMut<&'a mut #vec_name> { index.get_unchecked_mut(self) } @@ -314,7 +314,7 @@ pub fn derive(input: &Input) -> TokenStream { /// ::index_mut()`](https://doc.rust-lang.org/std/vec/struct.Vec.html#method.index_mut). pub fn index_mut<'a, I>(&'a mut self, index: I) -> I::MutOutput where - I: ::soa_derive::SoaMutIndex<&'a mut #vec_name> + I: ::soa_derive::SoAIndexMut<&'a mut #vec_name> { index.index_mut(self) } diff --git a/src/lib.rs b/src/lib.rs index dc6109a..d447690 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -189,7 +189,7 @@ mod private_soa_indexs { /// Helper trait used for indexing operations. /// Inspired by [`std::slice::SliceIndex`](https://doc.rust-lang.org/std/slice/trait.SliceIndex.html). -pub trait SoaIndex: private_soa_indexs::Sealed { +pub trait SoAIndex: private_soa_indexs::Sealed { /// The output for the non-mutable functions type RefOutput; @@ -203,7 +203,7 @@ pub trait SoaIndex: private_soa_indexs::Sealed { /// Helper trait used for indexing operations returning mutable. /// Inspired by [`std::slice::SliceIndex`](https://doc.rust-lang.org/std/slice/trait.SliceIndex.html). -pub trait SoaMutIndex: private_soa_indexs::Sealed { +pub trait SoAIndexMut: private_soa_indexs::Sealed { /// The output for the mutable functions type MutOutput; diff --git a/tests/index.rs b/tests/index.rs index 1ba31e2..1f66ac2 100644 --- a/tests/index.rs +++ b/tests/index.rs @@ -23,7 +23,7 @@ fn test_vec_usize() { aos.push(particle.clone()); soa.push(particle.clone()); - // SoaIndex + // SoAIndex assert_eq!(soa.get(0).unwrap().name, &aos.get(0).unwrap().name); assert_eq!(soa.get(0).unwrap().mass, &aos.get(0).unwrap().mass); assert_eq!(aos.get(1), None); @@ -136,7 +136,7 @@ fn test_slice_usize() { aos.push(particle.clone()); soa.push(particle.clone()); - // SoaIndex + // SoAIndex let aos_slice = aos.as_slice(); let soa_slice = soa.as_slice(); From c65e32a6d530e13ae9b7ec805b6c0c67ab73f48e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mart=C3=AD=20Angelats=20i=20Ribera?= Date: Fri, 25 Sep 2020 19:14:52 +0200 Subject: [PATCH 10/14] f => reverse debug change --- soa-derive-internal/src/slice.rs | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/soa-derive-internal/src/slice.rs b/soa-derive-internal/src/slice.rs index b9c9e18..c42e025 100644 --- a/soa-derive-internal/src/slice.rs +++ b/soa-derive-internal/src/slice.rs @@ -449,9 +449,7 @@ pub fn derive_mut(input: &Input) -> TokenStream { I: ::soa_derive::SoAIndexMut<#slice_mut_name<'b>>, 'a: 'b { - let slice: #slice_mut_name<'b> = #slice_mut_name { - #(#fields_names_1: &mut *self.#fields_names_2,)* - }; + let slice: #slice_mut_name<'b> = self.reborrow(); index.get_mut(slice) } From 1697674d725eba0f2791651bba5c3e691d06d472 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mart=C3=AD=20Angelats=20i=20Ribera?= Date: Sat, 26 Sep 2020 00:28:58 +0200 Subject: [PATCH 11/14] Use the new access methods in the tests This makes the tests not rely on the underlying structure of the data. --- tests/iter.rs | 12 ++++++------ tests/ptr.rs | 30 +++++++++++++++--------------- tests/vec.rs | 32 ++++++++++++++++---------------- 3 files changed, 37 insertions(+), 37 deletions(-) diff --git a/tests/iter.rs b/tests/iter.rs index 9c7125e..2af51ce 100644 --- a/tests/iter.rs +++ b/tests/iter.rs @@ -35,9 +35,9 @@ fn iter_mut() { for particle in particles.iter_mut() { *particle.mass += 1.0; } - assert_eq!(particles.mass[0], 1.0); - assert_eq!(particles.mass[1], 1.0); - assert_eq!(particles.mass[2], 1.0); + assert_eq!(*particles.index(0).mass, 1.0); + assert_eq!(*particles.index(1).mass, 1.0); + assert_eq!(*particles.index(2).mass, 1.0); { let mut slice = particles.as_mut_slice(); @@ -46,9 +46,9 @@ fn iter_mut() { } } - assert_eq!(particles.mass[0], 2.0); - assert_eq!(particles.mass[1], 2.0); - assert_eq!(particles.mass[2], 2.0); + assert_eq!(*particles.index(0).mass, 2.0); + assert_eq!(*particles.index(1).mass, 2.0); + assert_eq!(*particles.index(2).mass, 2.0); } #[test] diff --git a/tests/ptr.rs b/tests/ptr.rs index 1785c7a..cfa219a 100644 --- a/tests/ptr.rs +++ b/tests/ptr.rs @@ -53,11 +53,11 @@ fn slice() { unsafe { let slice = ParticleSlice::from_raw_parts(ptr, 2); assert_eq!(slice.len(), 2); - assert_eq!(slice.name[0], "Na"); - assert_eq!(slice.name[1], "Zn"); + assert_eq!(slice.index(0).name, "Na"); + assert_eq!(slice.index(1).name, "Zn"); - assert_eq!(slice.mass[0], 1.0); - assert_eq!(slice.mass[1], 2.0); + assert_eq!(*slice.index(0).mass, 1.0); + assert_eq!(*slice.index(1).mass, 2.0); } } @@ -77,8 +77,8 @@ fn slice_mut() { *ptr.as_mut().unwrap().mass = 42.0; } - assert_eq!(slice.name[0], "Fe"); - assert_eq!(slice.mass[0], 42.0); + assert_eq!(slice.index(0).name, "Fe"); + assert_eq!(*slice.index(0).mass, 42.0); unsafe { let slice = ParticleSliceMut::from_raw_parts_mut(ptr, 2); @@ -88,9 +88,9 @@ fn slice_mut() { } } - assert_eq!(slice.mass[0], -1.0); - assert_eq!(slice.mass[1], -1.0); - assert_eq!(slice.mass[2], 3.0); + assert_eq!(*slice.index(0).mass, -1.0); + assert_eq!(*slice.index(1).mass, -1.0); + assert_eq!(*slice.index(2).mass, 3.0); } #[test] @@ -119,12 +119,12 @@ fn vec() { assert_eq!(particles.len(), len); assert_eq!(particles.capacity(), capacity); - assert_eq!(particles.name[0], "Fe"); - assert_eq!(particles.mass[0], 42.0); + assert_eq!(particles.index(0).name, "Fe"); + assert_eq!(*particles.index(0).mass, 42.0); - assert_eq!(particles.name[1], "Zn"); - assert_eq!(particles.mass[1], 2.0); + assert_eq!(particles.index(1).name, "Zn"); + assert_eq!(*particles.index(1).mass, 2.0); - assert_eq!(particles.name[2], "Fe"); - assert_eq!(particles.mass[2], 3.0); + assert_eq!(particles.index(2).name, "Fe"); + assert_eq!(*particles.index(2).mass, 3.0); } diff --git a/tests/vec.rs b/tests/vec.rs index 0416201..d4d7af5 100644 --- a/tests/vec.rs +++ b/tests/vec.rs @@ -78,9 +78,9 @@ fn swap_remove() { let particle = particles.swap_remove(1); assert_eq!(particle.name, "Na"); - assert_eq!(particles.name[0], "Cl"); - assert_eq!(particles.name[1], "Zn"); - assert_eq!(particles.name[2], "Br"); + assert_eq!(particles.index(0).name, "Cl"); + assert_eq!(particles.index(1).name, "Zn"); + assert_eq!(particles.index(2).name, "Br"); } #[test] @@ -90,9 +90,9 @@ fn insert() { particles.push(Particle::new(String::from("Na"), 0.0)); particles.insert(1, Particle::new(String::from("Zn"), 0.0)); - assert_eq!(particles.name[0], "Cl"); - assert_eq!(particles.name[1], "Zn"); - assert_eq!(particles.name[2], "Na"); + assert_eq!(particles.index(0).name, "Cl"); + assert_eq!(particles.index(1).name, "Zn"); + assert_eq!(particles.index(2).name, "Na"); } #[test] @@ -122,10 +122,10 @@ fn append() { others.push(Particle::new(String::from("Mg"), 0.0)); particles.append(&mut others); - assert_eq!(particles.name[0], "Cl"); - assert_eq!(particles.name[1], "Na"); - assert_eq!(particles.name[2], "Zn"); - assert_eq!(particles.name[3], "Mg"); + assert_eq!(particles.index(0).name, "Cl"); + assert_eq!(particles.index(1).name, "Na"); + assert_eq!(particles.index(2).name, "Zn"); + assert_eq!(particles.index(3).name, "Mg"); } #[test] @@ -140,10 +140,10 @@ fn split_off() { assert_eq!(particles.len(), 2); assert_eq!(other.len(), 2); - assert_eq!(particles.name[0], "Cl"); - assert_eq!(particles.name[1], "Na"); - assert_eq!(other.name[0], "Zn"); - assert_eq!(other.name[1], "Mg"); + assert_eq!(particles.index(0).name, "Cl"); + assert_eq!(particles.index(1).name, "Na"); + assert_eq!(other.index(0).name, "Zn"); + assert_eq!(other.index(1).name, "Mg"); } #[test] @@ -156,6 +156,6 @@ fn retain() { particles.retain(|particle| particle.name.starts_with("C")); assert_eq!(particles.len(), 2); - assert_eq!(particles.name[0], "Cl"); - assert_eq!(particles.name[1], "C"); + assert_eq!(particles.index(0).name, "Cl"); + assert_eq!(particles.index(1).name, "C"); } From 93f715e3ceff73e7706f42dfc7777ea2ddb9a5eb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mart=C3=AD=20Angelats=20i=20Ribera?= Date: Fri, 2 Oct 2020 16:42:27 +0200 Subject: [PATCH 12/14] Fix and improve documentation --- soa-derive-internal/src/slice.rs | 8 ++++---- src/lib.rs | 10 +++++----- 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/soa-derive-internal/src/slice.rs b/soa-derive-internal/src/slice.rs index c42e025..acd3e53 100644 --- a/soa-derive-internal/src/slice.rs +++ b/soa-derive-internal/src/slice.rs @@ -177,7 +177,7 @@ pub fn derive(input: &Input) -> TokenStream { /// Similar to [`std::ops::Index` trait](https://doc.rust-lang.org/std/ops/trait.Index.html) on #[doc = #slice_name_str] /// . - /// This is required because we cannot implement that trait. + /// This is required because we cannot implement `std::ops::Index` directly since it requires returning a reference. pub fn index<'b, I>(&'b self, index: I) -> I::RefOutput where I: ::soa_derive::SoAIndex<#slice_name<'b>>, @@ -187,7 +187,7 @@ pub fn derive(input: &Input) -> TokenStream { index.index(slice) } - /// Reborrows the slices in a more narrower lifetime + /// Reborrows the slices in a narrower lifetime pub fn reborrow<'b>(&'b self) -> #slice_name<'b> where 'a: 'b @@ -468,7 +468,7 @@ pub fn derive_mut(input: &Input) -> TokenStream { /// Similar to [`std::ops::IndexMut` trait](https://doc.rust-lang.org/std/ops/trait.IndexMut.html) on #[doc = #slice_name_str] /// . - /// This is required because we cannot implement that trait. + /// This is required because we cannot implement `std::ops::IndexMut` directly since it requires returning a mutable reference. pub fn index_mut<'b, I>(&'b mut self, index: I) -> I::MutOutput where I: ::soa_derive::SoAIndexMut<#slice_mut_name<'b>>, @@ -488,7 +488,7 @@ pub fn derive_mut(input: &Input) -> TokenStream { } } - /// Reborrows the slices in a more narrower lifetime + /// Reborrows the slices in a narrower lifetime pub fn reborrow<'b>(&'b mut self) -> #slice_mut_name<'b> where 'a: 'b diff --git a/src/lib.rs b/src/lib.rs index d447690..ec2d0dd 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -193,23 +193,23 @@ pub trait SoAIndex: private_soa_indexs::Sealed { /// The output for the non-mutable functions type RefOutput; - /// Returns the reference output in this location if in bounds. None otherwise. + /// Returns the reference output in this location if in bounds, `None` otherwise. fn get(self, soa: T) -> Option; - /// Returns the reference output in this location withotu performing any bounds check. + /// Returns the reference output in this location without performing any bounds check. unsafe fn get_unchecked(self, soa: T) -> Self::RefOutput; /// Returns the reference output in this location. Panics if it is not in bounds. fn index(self, soa: T) -> Self::RefOutput; } -/// Helper trait used for indexing operations returning mutable. +/// Helper trait used for indexing operations returning mutable references. /// Inspired by [`std::slice::SliceIndex`](https://doc.rust-lang.org/std/slice/trait.SliceIndex.html). pub trait SoAIndexMut: private_soa_indexs::Sealed { /// The output for the mutable functions type MutOutput; - /// Returns the mutable reference output in this location if in bounds. None otherwise. + /// Returns the mutable reference output in this location if in bounds, `None` otherwise. fn get_mut(self, soa: T) -> Option; - /// Returns the mutable reference output in this location withotu performing any bounds check. + /// Returns the mutable reference output in this location without performing any bounds check. unsafe fn get_unchecked_mut(self, soa: T) -> Self::MutOutput; /// Returns the mutable reference output in this location. Panics if it is not in bounds. fn index_mut(self, soa: T) -> Self::MutOutput; From 43c8c8bc15915ad4f8f187d3a7ce91507776b34d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mart=C3=AD=20Angelats=20i=20Ribera?= Date: Fri, 2 Oct 2020 16:43:43 +0200 Subject: [PATCH 13/14] Improve test names --- tests/index.rs | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/index.rs b/tests/index.rs index 1f66ac2..dc1f294 100644 --- a/tests/index.rs +++ b/tests/index.rs @@ -15,7 +15,7 @@ where } #[test] -fn test_vec_usize() { +fn index_vec_with_usize() { let mut aos = Vec::new(); let mut soa = ParticleVec::new(); @@ -61,7 +61,7 @@ fn test_vec_usize() { } #[test] -fn test_vec_ranges() { +fn index_vec_with_ranges() { let mut particles = Vec::new(); particles.push(Particle::new(String::from("Cl"), 1.0)); particles.push(Particle::new(String::from("Na"), 2.0)); @@ -128,7 +128,7 @@ fn test_vec_ranges() { } #[test] -fn test_slice_usize() { +fn index_slice_with_usize() { let mut aos = Vec::new(); let mut soa = ParticleVec::new(); @@ -180,7 +180,7 @@ fn test_slice_usize() { } #[test] -fn test_slice_ranges() { +fn index_slice_with_ranges() { let mut particles = Vec::new(); particles.push(Particle::new(String::from("Cl"), 1.0)); particles.push(Particle::new(String::from("Na"), 2.0)); From c3c3f01fb0fe76e98eb6be967457e0d16b270990 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mart=C3=AD=20Angelats=20i=20Ribera?= Date: Fri, 2 Oct 2020 16:47:31 +0200 Subject: [PATCH 14/14] Revert use of soa slices' index in tests This partially reverts commit 1697674d725eba0f2791651bba5c3e691d06d472. --- tests/ptr.rs | 30 +++++++++++++++--------------- 1 file changed, 15 insertions(+), 15 deletions(-) diff --git a/tests/ptr.rs b/tests/ptr.rs index cfa219a..1785c7a 100644 --- a/tests/ptr.rs +++ b/tests/ptr.rs @@ -53,11 +53,11 @@ fn slice() { unsafe { let slice = ParticleSlice::from_raw_parts(ptr, 2); assert_eq!(slice.len(), 2); - assert_eq!(slice.index(0).name, "Na"); - assert_eq!(slice.index(1).name, "Zn"); + assert_eq!(slice.name[0], "Na"); + assert_eq!(slice.name[1], "Zn"); - assert_eq!(*slice.index(0).mass, 1.0); - assert_eq!(*slice.index(1).mass, 2.0); + assert_eq!(slice.mass[0], 1.0); + assert_eq!(slice.mass[1], 2.0); } } @@ -77,8 +77,8 @@ fn slice_mut() { *ptr.as_mut().unwrap().mass = 42.0; } - assert_eq!(slice.index(0).name, "Fe"); - assert_eq!(*slice.index(0).mass, 42.0); + assert_eq!(slice.name[0], "Fe"); + assert_eq!(slice.mass[0], 42.0); unsafe { let slice = ParticleSliceMut::from_raw_parts_mut(ptr, 2); @@ -88,9 +88,9 @@ fn slice_mut() { } } - assert_eq!(*slice.index(0).mass, -1.0); - assert_eq!(*slice.index(1).mass, -1.0); - assert_eq!(*slice.index(2).mass, 3.0); + assert_eq!(slice.mass[0], -1.0); + assert_eq!(slice.mass[1], -1.0); + assert_eq!(slice.mass[2], 3.0); } #[test] @@ -119,12 +119,12 @@ fn vec() { assert_eq!(particles.len(), len); assert_eq!(particles.capacity(), capacity); - assert_eq!(particles.index(0).name, "Fe"); - assert_eq!(*particles.index(0).mass, 42.0); + assert_eq!(particles.name[0], "Fe"); + assert_eq!(particles.mass[0], 42.0); - assert_eq!(particles.index(1).name, "Zn"); - assert_eq!(*particles.index(1).mass, 2.0); + assert_eq!(particles.name[1], "Zn"); + assert_eq!(particles.mass[1], 2.0); - assert_eq!(particles.index(2).name, "Fe"); - assert_eq!(*particles.index(2).mass, 3.0); + assert_eq!(particles.name[2], "Fe"); + assert_eq!(particles.mass[2], 3.0); }