diff --git a/soa-derive-internal/src/index.rs b/soa-derive-internal/src/index.rs new file mode 100644 index 0000000..8db82c6 --- /dev/null +++ b/soa-derive-internal/src/index.rs @@ -0,0 +1,662 @@ +use proc_macro2::TokenStream; +use quote::quote; + +use crate::input::Input; + +pub fn derive(input: &Input) -> TokenStream { + let vec_name = &input.vec_name(); + let slice_name = &input.slice_name(); + let slice_mut_name = &input.slice_mut_name(); + let ref_name = &input.ref_name(); + let ref_mut_name = &input.ref_mut_name(); + let fields_names = input.fields.iter() + .map(|field| field.ident.clone().unwrap()) + .collect::>(); + let fields_names_1 = &fields_names; + let fields_names_2 = &fields_names; + let first_field_name = &fields_names[0]; + + quote!{ + // usize + impl<'a> ::soa_derive::SoAIndex<&'a #vec_name> for usize { + type RefOutput = #ref_name<'a>; + + #[inline] + fn get(self, soa: &'a #vec_name) -> Option { + if self < soa.len() { + Some(unsafe { self.get_unchecked(soa) }) + } else { + None + } + } + + #[inline] + unsafe fn get_unchecked(self, soa: &'a #vec_name) -> Self::RefOutput { + #ref_name { + #(#fields_names_1: soa.#fields_names_2.get_unchecked(self),)* + } + } + + #[inline] + fn index(self, soa: &'a #vec_name) -> Self::RefOutput { + #ref_name { + #(#fields_names_1: & soa.#fields_names_2[self],)* + } + } + } + + impl<'a> ::soa_derive::SoAIndexMut<&'a mut #vec_name> for usize { + type MutOutput = #ref_mut_name<'a>; + + #[inline] + fn get_mut(self, soa: &'a mut #vec_name) -> Option { + if self < soa.len() { + Some(unsafe { self.get_unchecked_mut(soa) }) + } else { + None + } + } + + #[inline] + unsafe fn get_unchecked_mut(self, soa: &'a mut #vec_name) -> Self::MutOutput { + #ref_mut_name { + #(#fields_names_1: soa.#fields_names_2.get_unchecked_mut(self),)* + } + } + + #[inline] + fn index_mut(self, soa: &'a mut #vec_name) -> Self::MutOutput { + #ref_mut_name { + #(#fields_names_1: &mut soa.#fields_names_2[self],)* + } + } + } + + + + // Range + impl<'a> ::soa_derive::SoAIndex<&'a #vec_name> for ::std::ops::Range { + type RefOutput = #slice_name<'a>; + + #[inline] + fn get(self, soa: &'a #vec_name) -> Option { + if self.start <= self.end && self.end <= soa.len() { + unsafe { Some(self.get_unchecked(soa)) } + } else { + None + } + } + + #[inline] + unsafe fn get_unchecked(self, soa: &'a #vec_name) -> Self::RefOutput { + #slice_name { + #(#fields_names_1: soa.#fields_names_2.get_unchecked(self.clone()),)* + } + } + + #[inline] + fn index(self, soa: &'a #vec_name) -> Self::RefOutput { + #slice_name { + #(#fields_names_1: & soa.#fields_names_2[self.clone()],)* + } + } + } + + impl<'a> ::soa_derive::SoAIndexMut<&'a mut #vec_name> for ::std::ops::Range { + type MutOutput = #slice_mut_name<'a>; + + #[inline] + fn get_mut(self, soa: &'a mut #vec_name) -> Option { + if self.start <= self.end && self.end <= soa.len() { + unsafe { Some(self.get_unchecked_mut(soa)) } + } else { + None + } + } + + #[inline] + unsafe fn get_unchecked_mut(self, soa: &'a mut #vec_name) -> Self::MutOutput { + #slice_mut_name { + #(#fields_names_1: soa.#fields_names_2.get_unchecked_mut(self.clone()),)* + } + } + + #[inline] + fn index_mut(self, soa: &'a mut #vec_name) -> Self::MutOutput { + #slice_mut_name { + #(#fields_names_1: &mut soa.#fields_names_2[self.clone()],)* + } + } + } + + + + // RangeTo + impl<'a> ::soa_derive::SoAIndex<&'a #vec_name> for ::std::ops::RangeTo { + type RefOutput = #slice_name<'a>; + + #[inline] + fn get(self, soa: &'a #vec_name) -> Option { + (0..self.end).get(soa) + } + + #[inline] + unsafe fn get_unchecked(self, soa: &'a #vec_name) -> Self::RefOutput { + (0..self.end).get_unchecked(soa) + } + + #[inline] + fn index(self, soa: &'a #vec_name) -> Self::RefOutput { + (0..self.end).index(soa) + } + } + + impl<'a> ::soa_derive::SoAIndexMut<&'a mut #vec_name> for ::std::ops::RangeTo { + type MutOutput = #slice_mut_name<'a>; + + #[inline] + fn get_mut(self, soa: &'a mut #vec_name) -> Option { + (0..self.end).get_mut(soa) + } + + #[inline] + unsafe fn get_unchecked_mut(self, soa: &'a mut #vec_name) -> Self::MutOutput { + (0..self.end).get_unchecked_mut(soa) + } + + #[inline] + fn index_mut(self, soa: &'a mut #vec_name) -> Self::MutOutput { + (0..self.end).index_mut(soa) + } + } + + + // RangeFrom + impl<'a> ::soa_derive::SoAIndex<&'a #vec_name> for ::std::ops::RangeFrom { + type RefOutput = #slice_name<'a>; + + #[inline] + fn get(self, soa: &'a #vec_name) -> Option { + (self.start..soa.len()).get(soa) + } + + #[inline] + unsafe fn get_unchecked(self, soa: &'a #vec_name) -> Self::RefOutput { + (self.start..soa.len()).get_unchecked(soa) + } + + #[inline] + fn index(self, soa: &'a #vec_name) -> Self::RefOutput { + (self.start..soa.len()).index(soa) + } + } + + impl<'a> ::soa_derive::SoAIndexMut<&'a mut #vec_name> for ::std::ops::RangeFrom { + type MutOutput = #slice_mut_name<'a>; + + #[inline] + fn get_mut(self, soa: &'a mut #vec_name) -> Option { + (self.start..soa.len()).get_mut(soa) + } + + #[inline] + unsafe fn get_unchecked_mut(self, soa: &'a mut #vec_name) -> Self::MutOutput { + (self.start..soa.len()).get_unchecked_mut(soa) + } + + #[inline] + fn index_mut(self, soa: &'a mut #vec_name) -> Self::MutOutput { + (self.start..soa.len()).index_mut(soa) + } + } + + + // RangeFull + impl<'a> ::soa_derive::SoAIndex<&'a #vec_name> for ::std::ops::RangeFull { + type RefOutput = #slice_name<'a>; + + #[inline] + fn get(self, soa: &'a #vec_name) -> Option { + Some(soa.as_slice()) + } + + #[inline] + unsafe fn get_unchecked(self, soa: &'a #vec_name) -> Self::RefOutput { + soa.as_slice() + } + + #[inline] + fn index(self, soa: &'a #vec_name) -> Self::RefOutput { + soa.as_slice() + } + } + + impl<'a> ::soa_derive::SoAIndexMut<&'a mut #vec_name> for ::std::ops::RangeFull { + type MutOutput = #slice_mut_name<'a>; + + #[inline] + fn get_mut(self, soa: &'a mut #vec_name) -> Option { + Some(soa.as_mut_slice()) + } + + #[inline] + unsafe fn get_unchecked_mut(self, soa: &'a mut #vec_name) -> Self::MutOutput { + soa.as_mut_slice() + } + + #[inline] + fn index_mut(self, soa: &'a mut #vec_name) -> Self::MutOutput { + soa.as_mut_slice() + } + } + + + // RangeInclusive + impl<'a> ::soa_derive::SoAIndex<&'a #vec_name> for ::std::ops::RangeInclusive { + type RefOutput = #slice_name<'a>; + + #[inline] + fn get(self, soa: &'a #vec_name) -> Option { + if *self.end() == usize::MAX { + None + } else { + (*self.start()..self.end() + 1).get(soa) + } + } + + #[inline] + unsafe fn get_unchecked(self, soa: &'a #vec_name) -> Self::RefOutput { + (*self.start()..self.end() + 1).get_unchecked(soa) + } + + #[inline] + fn index(self, soa: &'a #vec_name) -> Self::RefOutput { + (*self.start()..self.end() + 1).index(soa) + } + } + + impl<'a> ::soa_derive::SoAIndexMut<&'a mut #vec_name> for ::std::ops::RangeInclusive { + type MutOutput = #slice_mut_name<'a>; + + #[inline] + fn get_mut(self, soa: &'a mut #vec_name) -> Option { + if *self.end() == usize::MAX { + None + } else { + (*self.start()..self.end() + 1).get_mut(soa) + } + } + + #[inline] + unsafe fn get_unchecked_mut(self, soa: &'a mut #vec_name) -> Self::MutOutput { + (*self.start()..self.end() + 1).get_unchecked_mut(soa) + } + + #[inline] + fn index_mut(self, soa: &'a mut #vec_name) -> Self::MutOutput { + (*self.start()..self.end() + 1).index_mut(soa) + } + } + + + // RangeToInclusive + impl<'a> ::soa_derive::SoAIndex<&'a #vec_name> for ::std::ops::RangeToInclusive { + type RefOutput = #slice_name<'a>; + + #[inline] + fn get(self, soa: &'a #vec_name) -> Option { + (0..=self.end).get(soa) + } + + #[inline] + unsafe fn get_unchecked(self, soa: &'a #vec_name) -> Self::RefOutput { + (0..=self.end).get_unchecked(soa) + } + + #[inline] + fn index(self, soa: &'a #vec_name) -> Self::RefOutput { + (0..=self.end).index(soa) + } + } + + impl<'a> ::soa_derive::SoAIndexMut<&'a mut #vec_name> for ::std::ops::RangeToInclusive { + type MutOutput = #slice_mut_name<'a>; + + #[inline] + fn get_mut(self, soa: &'a mut #vec_name) -> Option { + (0..=self.end).get_mut(soa) + } + + #[inline] + unsafe fn get_unchecked_mut(self, soa: &'a mut #vec_name) -> Self::MutOutput { + (0..=self.end).get_unchecked_mut(soa) + } + + #[inline] + fn index_mut(self, soa: &'a mut #vec_name) -> Self::MutOutput { + (0..=self.end).index_mut(soa) + } + } + + // usize + impl<'a> ::soa_derive::SoAIndex<#slice_name<'a>> for usize { + type RefOutput = #ref_name<'a>; + + #[inline] + fn get(self, slice: #slice_name<'a>) -> Option { + if self < slice.#first_field_name.len() { + Some(unsafe { self.get_unchecked(slice) }) + } else { + None + } + } + + #[inline] + unsafe fn get_unchecked(self, slice: #slice_name<'a>) -> Self::RefOutput { + #ref_name { + #(#fields_names_1: slice.#fields_names_2.get_unchecked(self),)* + } + } + + #[inline] + fn index(self, slice: #slice_name<'a>) -> Self::RefOutput { + #ref_name { + #(#fields_names_1: & slice.#fields_names_2[self],)* + } + } + } + + impl<'a> ::soa_derive::SoAIndexMut<#slice_mut_name<'a>> for usize { + type MutOutput = #ref_mut_name<'a>; + + #[inline] + fn get_mut(self, slice: #slice_mut_name<'a>) -> Option { + if self < slice.len() { + Some(unsafe { self.get_unchecked_mut(slice) }) + } else { + None + } + } + + #[inline] + unsafe fn get_unchecked_mut(self, slice: #slice_mut_name<'a>) -> Self::MutOutput { + #ref_mut_name { + #(#fields_names_1: slice.#fields_names_2.get_unchecked_mut(self),)* + } + } + + #[inline] + fn index_mut(self, slice: #slice_mut_name<'a>) -> Self::MutOutput { + #ref_mut_name { + #(#fields_names_1: &mut slice.#fields_names_2[self],)* + } + } + } + + + + // Range + impl<'a> ::soa_derive::SoAIndex<#slice_name<'a>> for ::std::ops::Range { + type RefOutput = #slice_name<'a>; + + #[inline] + fn get(self, slice: #slice_name<'a>) -> Option { + if self.start <= self.end && self.end <= slice.#first_field_name.len() { + unsafe { Some(self.get_unchecked(slice)) } + } else { + None + } + } + + #[inline] + unsafe fn get_unchecked(self, slice: #slice_name<'a>) -> Self::RefOutput { + #slice_name { + #(#fields_names_1: slice.#fields_names_2.get_unchecked(self.clone()),)* + } + } + + #[inline] + fn index(self, slice: #slice_name<'a>) -> Self::RefOutput { + #slice_name { + #(#fields_names_1: & slice.#fields_names_2[self.clone()],)* + } + } + } + + impl<'a> ::soa_derive::SoAIndexMut<#slice_mut_name<'a>> for ::std::ops::Range { + type MutOutput = #slice_mut_name<'a>; + + #[inline] + fn get_mut(self, slice: #slice_mut_name<'a>) -> Option { + if self.start <= self.end && self.end <= slice.#first_field_name.len() { + unsafe { Some(self.get_unchecked_mut(slice)) } + } else { + None + } + } + + #[inline] + unsafe fn get_unchecked_mut(self, slice: #slice_mut_name<'a>) -> Self::MutOutput { + #slice_mut_name { + #(#fields_names_1: slice.#fields_names_2.get_unchecked_mut(self.clone()),)* + } + } + + #[inline] + fn index_mut(self, slice: #slice_mut_name<'a>) -> Self::MutOutput { + #slice_mut_name { + #(#fields_names_1: &mut slice.#fields_names_2[self.clone()],)* + } + } + } + + + + // RangeTo + impl<'a> ::soa_derive::SoAIndex<#slice_name<'a>> for ::std::ops::RangeTo { + type RefOutput = #slice_name<'a>; + + #[inline] + fn get(self, slice: #slice_name<'a>) -> Option { + (0..self.end).get(slice) + } + + #[inline] + unsafe fn get_unchecked(self, slice: #slice_name<'a>) -> Self::RefOutput { + (0..self.end).get_unchecked(slice) + } + + #[inline] + fn index(self, slice: #slice_name<'a>) -> Self::RefOutput { + (0..self.end).index(slice) + } + } + + impl<'a> ::soa_derive::SoAIndexMut<#slice_mut_name<'a>> for ::std::ops::RangeTo { + type MutOutput = #slice_mut_name<'a>; + + #[inline] + fn get_mut(self, slice: #slice_mut_name<'a>) -> Option { + (0..self.end).get_mut(slice) + } + + #[inline] + unsafe fn get_unchecked_mut(self, slice: #slice_mut_name<'a>) -> Self::MutOutput { + (0..self.end).get_unchecked_mut(slice) + } + + #[inline] + fn index_mut(self, slice: #slice_mut_name<'a>) -> Self::MutOutput { + (0..self.end).index_mut(slice) + } + } + + + // RangeFrom + impl<'a> ::soa_derive::SoAIndex<#slice_name<'a>> for ::std::ops::RangeFrom { + type RefOutput = #slice_name<'a>; + + #[inline] + fn get(self, slice: #slice_name<'a>) -> Option { + (self.start..slice.len()).get(slice) + } + + #[inline] + unsafe fn get_unchecked(self, slice: #slice_name<'a>) -> Self::RefOutput { + (self.start..slice.len()).get_unchecked(slice) + } + + #[inline] + fn index(self, slice: #slice_name<'a>) -> Self::RefOutput { + (self.start..slice.len()).index(slice) + } + } + + impl<'a> ::soa_derive::SoAIndexMut<#slice_mut_name<'a>> for ::std::ops::RangeFrom { + type MutOutput = #slice_mut_name<'a>; + + #[inline] + fn get_mut(self, slice: #slice_mut_name<'a>) -> Option { + (self.start..slice.len()).get_mut(slice) + } + + #[inline] + unsafe fn get_unchecked_mut(self, slice: #slice_mut_name<'a>) -> Self::MutOutput { + (self.start..slice.len()).get_unchecked_mut(slice) + } + + #[inline] + fn index_mut(self, slice: #slice_mut_name<'a>) -> Self::MutOutput { + (self.start..slice.len()).index_mut(slice) + } + } + + + // RangeFull + impl<'a> ::soa_derive::SoAIndex<#slice_name<'a>> for ::std::ops::RangeFull { + type RefOutput = #slice_name<'a>; + + #[inline] + fn get(self, slice: #slice_name<'a>) -> Option { + Some(slice) + } + + #[inline] + unsafe fn get_unchecked(self, slice: #slice_name<'a>) -> Self::RefOutput { + slice + } + + #[inline] + fn index(self, slice: #slice_name<'a>) -> Self::RefOutput { + slice + } + } + + impl<'a> ::soa_derive::SoAIndexMut<#slice_mut_name<'a>> for ::std::ops::RangeFull { + type MutOutput = #slice_mut_name<'a>; + + #[inline] + fn get_mut(self, slice: #slice_mut_name<'a>) -> Option { + Some(slice) + } + + #[inline] + unsafe fn get_unchecked_mut(self, slice: #slice_mut_name<'a>) -> Self::MutOutput { + slice + } + + #[inline] + fn index_mut(self, slice: #slice_mut_name<'a>) -> Self::MutOutput { + slice + } + } + + + // RangeInclusive + impl<'a> ::soa_derive::SoAIndex<#slice_name<'a>> for ::std::ops::RangeInclusive { + type RefOutput = #slice_name<'a>; + + #[inline] + fn get(self, slice: #slice_name<'a>) -> Option { + if *self.end() == usize::MAX { + None + } else { + (*self.start()..self.end() + 1).get(slice) + } + } + + #[inline] + unsafe fn get_unchecked(self, slice: #slice_name<'a>) -> Self::RefOutput { + (*self.start()..self.end() + 1).get_unchecked(slice) + } + + #[inline] + fn index(self, slice: #slice_name<'a>) -> Self::RefOutput { + (*self.start()..self.end() + 1).index(slice) + } + } + + impl<'a> ::soa_derive::SoAIndexMut<#slice_mut_name<'a>> for ::std::ops::RangeInclusive { + type MutOutput = #slice_mut_name<'a>; + + #[inline] + fn get_mut(self, slice: #slice_mut_name<'a>) -> Option { + if *self.end() == usize::MAX { + None + } else { + (*self.start()..self.end() + 1).get_mut(slice) + } + } + + #[inline] + unsafe fn get_unchecked_mut(self, slice: #slice_mut_name<'a>) -> Self::MutOutput { + (*self.start()..self.end() + 1).get_unchecked_mut(slice) + } + + #[inline] + fn index_mut(self, slice: #slice_mut_name<'a>) -> Self::MutOutput { + (*self.start()..self.end() + 1).index_mut(slice) + } + } + + + // RangeToInclusive + impl<'a> ::soa_derive::SoAIndex<#slice_name<'a>> for ::std::ops::RangeToInclusive { + type RefOutput = #slice_name<'a>; + + #[inline] + fn get(self, slice: #slice_name<'a>) -> Option { + (0..=self.end).get(slice) + } + + #[inline] + unsafe fn get_unchecked(self, slice: #slice_name<'a>) -> Self::RefOutput { + (0..=self.end).get_unchecked(slice) + } + + #[inline] + fn index(self, slice: #slice_name<'a>) -> Self::RefOutput { + (0..=self.end).index(slice) + } + } + + impl<'a> ::soa_derive::SoAIndexMut<#slice_mut_name<'a>> for ::std::ops::RangeToInclusive { + type MutOutput = #slice_mut_name<'a>; + + #[inline] + fn get_mut(self, slice: #slice_mut_name<'a>) -> Option { + (0..=self.end).get_mut(slice) + } + + #[inline] + unsafe fn get_unchecked_mut(self, slice: #slice_mut_name<'a>) -> Self::MutOutput { + (0..=self.end).get_unchecked_mut(slice) + } + + #[inline] + fn index_mut(self, slice: #slice_mut_name<'a>) -> Self::MutOutput { + (0..=self.end).index_mut(slice) + } + } + } +} diff --git a/soa-derive-internal/src/lib.rs b/soa-derive-internal/src/lib.rs index b2510c0..6adacd1 100644 --- a/soa-derive-internal/src/lib.rs +++ b/soa-derive-internal/src/lib.rs @@ -9,6 +9,7 @@ extern crate proc_macro; use proc_macro2::TokenStream; use quote::TokenStreamExt; +mod index; mod input; mod iter; mod ptr; @@ -27,6 +28,7 @@ pub fn soa_derive(input: proc_macro::TokenStream) -> proc_macro::TokenStream { generated.append_all(ptr::derive(&input)); generated.append_all(slice::derive(&input)); generated.append_all(slice::derive_mut(&input)); + generated.append_all(index::derive(&input)); generated.append_all(iter::derive(&input)); generated.append_all(derive_trait(&input)); generated.into() diff --git a/soa-derive-internal/src/slice.rs b/soa-derive-internal/src/slice.rs index f47c063..acd3e53 100644 --- a/soa-derive-internal/src/slice.rs +++ b/soa-derive-internal/src/slice.rs @@ -153,22 +153,47 @@ pub fn derive(input: &Input) -> TokenStream { /// Similar to [` #[doc = #slice_name_str] /// ::get()`](https://doc.rust-lang.org/std/primitive.slice.html#method.get). - pub fn get(&self, i: usize) -> Option<#ref_name> { - if self.is_empty() || i >= self.len() { - None - } else { - Some(#ref_name { - #(#fields_names_1: self.#fields_names_2.get(i).unwrap(),)* - }) - } + pub fn get<'b, I>(&'b self, index: I) -> Option + where + I: ::soa_derive::SoAIndex<#slice_name<'b>>, + 'a: 'b + { + let slice: #slice_name<'b> = self.reborrow(); + index.get(slice) } /// Similar to [` #[doc = #slice_name_str] /// ::get_unchecked()`](https://doc.rust-lang.org/std/primitive.slice.html#method.get_unchecked). - pub unsafe fn get_unchecked(&self, i: usize) -> #ref_name { - #ref_name { - #(#fields_names_1: self.#fields_names_2.get_unchecked(i),)* + pub unsafe fn get_unchecked<'b, I>(&'b self, index: I) -> I::RefOutput + where + I: ::soa_derive::SoAIndex<#slice_name<'b>>, + 'a: 'b + { + let slice: #slice_name<'b> = self.reborrow(); + index.get_unchecked(slice) + } + + /// Similar to [`std::ops::Index` trait](https://doc.rust-lang.org/std/ops/trait.Index.html) on + #[doc = #slice_name_str] + /// . + /// This is required because we cannot implement `std::ops::Index` directly since it requires returning a reference. + pub fn index<'b, I>(&'b self, index: I) -> I::RefOutput + where + I: ::soa_derive::SoAIndex<#slice_name<'b>>, + 'a: 'b + { + let slice: #slice_name<'b> = self.reborrow(); + index.index(slice) + } + + /// Reborrows the slices in a narrower lifetime + pub fn reborrow<'b>(&'b self) -> #slice_name<'b> + where + 'a: 'b + { + #slice_name { + #(#fields_names_1: &self.#fields_names_2,)* } } @@ -215,7 +240,6 @@ pub fn derive_mut(input: &Input) -> TokenStream { let slice_name = &input.slice_name(); let slice_mut_name = &input.slice_mut_name(); let vec_name = &input.vec_name(); - let ref_name = &input.ref_name(); let ref_mut_name = &input.ref_mut_name(); let ptr_name = &input.ptr_name(); let ptr_mut_name = &input.ptr_mut_name(); @@ -382,44 +406,95 @@ pub fn derive_mut(input: &Input) -> TokenStream { /// Similar to [` #[doc = #slice_name_str] /// ::get()`](https://doc.rust-lang.org/std/primitive.slice.html#method.get). - pub fn get(&self, i: usize) -> Option<#ref_name> { - if self.is_empty() || i >= self.len() { - None - } else { - Some(#ref_name { - #(#fields_names_1: self.#fields_names_2.get(i).unwrap(),)* - }) - } + pub fn get<'b, I>(&'b self, index: I) -> Option + where + I: ::soa_derive::SoAIndex<#slice_name<'b>>, + 'a: 'b + { + let slice: #slice_name<'b> = self.as_slice(); + index.get(slice) } /// Similar to [` #[doc = #slice_name_str] /// ::get_unchecked()`](https://doc.rust-lang.org/std/primitive.slice.html#method.get_unchecked). - pub unsafe fn get_unchecked(&self, i: usize) -> #ref_name { - #ref_name { - #(#fields_names_1: self.#fields_names_2.get_unchecked(i),)* - } + pub unsafe fn get_unchecked<'b, I>(&'b self, index: I) -> I::RefOutput + where + I: ::soa_derive::SoAIndex<#slice_name<'b>>, + 'a: 'b + { + let slice: #slice_name<'b> = self.as_slice(); + index.get_unchecked(slice) + } + + + /// Similar to [`std::ops::Index` trait](https://doc.rust-lang.org/std/ops/trait.Index.html) on + #[doc = #slice_name_str] + /// . + /// This is required because we cannot implement that trait. + pub fn index<'b, I>(&'b self, index: I) -> I::RefOutput + where + I: ::soa_derive::SoAIndex<#slice_name<'b>>, + 'a: 'b + { + let slice: #slice_name<'b> = self.as_slice(); + index.index(slice) } /// Similar to [` #[doc = #slice_name_str] /// ::get_mut()`](https://doc.rust-lang.org/std/primitive.slice.html#method.get_mut). - pub fn get_mut(&mut self, i: usize) -> Option<#ref_mut_name> { - if self.is_empty() || i >= self.len() { - None - } else { - Some(#ref_mut_name { - #(#fields_names_1: self.#fields_names_2.get_mut(i).unwrap(),)* - }) - } + pub fn get_mut<'b, I>(&'b mut self, index: I) -> Option + where + I: ::soa_derive::SoAIndexMut<#slice_mut_name<'b>>, + 'a: 'b + { + let slice: #slice_mut_name<'b> = self.reborrow(); + index.get_mut(slice) } /// Similar to [` #[doc = #slice_name_str] /// ::get_unchecked_mut()`](https://doc.rust-lang.org/std/primitive.slice.html#method.get_unchecked_mut). - pub unsafe fn get_unchecked_mut(&mut self, i: usize) -> #ref_mut_name { - #ref_mut_name { - #(#fields_names_1: self.#fields_names_2.get_unchecked_mut(i),)* + pub unsafe fn get_unchecked_mut<'b, I>(&'b mut self, index: I) -> I::MutOutput + where + I: ::soa_derive::SoAIndexMut<#slice_mut_name<'b>>, + 'a: 'b + { + let slice: #slice_mut_name<'b> = self.reborrow(); + index.get_unchecked_mut(slice) + } + + /// Similar to [`std::ops::IndexMut` trait](https://doc.rust-lang.org/std/ops/trait.IndexMut.html) on + #[doc = #slice_name_str] + /// . + /// This is required because we cannot implement `std::ops::IndexMut` directly since it requires returning a mutable reference. + pub fn index_mut<'b, I>(&'b mut self, index: I) -> I::MutOutput + where + I: ::soa_derive::SoAIndexMut<#slice_mut_name<'b>>, + 'a: 'b + { + let slice: #slice_mut_name<'b> = self.reborrow(); + index.index_mut(slice) + } + + /// Returns a non-mutable slice from this mutable slice. + pub fn as_slice<'b>(&'b self) -> #slice_name<'b> + where + 'a: 'b + { + #slice_name { + #(#fields_names_1: &self.#fields_names_2,)* + } + } + + /// Reborrows the slices in a narrower lifetime + pub fn reborrow<'b>(&'b mut self) -> #slice_mut_name<'b> + where + 'a: 'b + { + #slice_mut_name { + #(#fields_names_1: &mut *self.#fields_names_2,)* } } diff --git a/soa-derive-internal/src/vec.rs b/soa-derive-internal/src/vec.rs index e741c05..2f746e7 100644 --- a/soa-derive-internal/src/vec.rs +++ b/soa-derive-internal/src/vec.rs @@ -259,6 +259,66 @@ pub fn derive(input: &Input) -> TokenStream { } } + /// Similar to [` + #[doc = #vec_name_str] + /// ::get()`](https://doc.rust-lang.org/std/vec/struct.Vec.html#method.get). + pub fn get<'a, I>(&'a self, index: I) -> Option + where + I: ::soa_derive::SoAIndex<&'a #vec_name> + { + index.get(self) + } + + /// Similar to [` + #[doc = #vec_name_str] + /// ::get_unchecked()`](https://doc.rust-lang.org/std/vec/struct.Vec.html#method.get_unchecked). + pub unsafe fn get_unchecked<'a, I>(&'a self, index: I) -> I::RefOutput + where + I: ::soa_derive::SoAIndex<&'a #vec_name> + { + index.get_unchecked(self) + } + + /// Similar to [` + #[doc = #vec_name_str] + /// ::index()`](https://doc.rust-lang.org/std/vec/struct.Vec.html#method.index). + pub fn index<'a, I>(&'a self, index: I) -> I::RefOutput + where + I: ::soa_derive::SoAIndex<&'a #vec_name> + { + index.index(self) + } + + /// Similar to [` + #[doc = #vec_name_str] + /// ::get_mut()`](https://doc.rust-lang.org/std/vec/struct.Vec.html#method.get_mut). + pub fn get_mut<'a, I>(&'a mut self, index: I) -> Option + where + I: ::soa_derive::SoAIndexMut<&'a mut #vec_name> + { + index.get_mut(self) + } + + /// Similar to [` + #[doc = #vec_name_str] + /// ::get_unchecked_mut()`](https://doc.rust-lang.org/std/vec/struct.Vec.html#method.get_unchecked_mut). + pub unsafe fn get_unchecked_mut<'a, I>(&'a mut self, index: I) -> I::MutOutput + where + I: ::soa_derive::SoAIndexMut<&'a mut #vec_name> + { + index.get_unchecked_mut(self) + } + + /// Similar to [` + #[doc = #vec_name_str] + /// ::index_mut()`](https://doc.rust-lang.org/std/vec/struct.Vec.html#method.index_mut). + pub fn index_mut<'a, I>(&'a mut self, index: I) -> I::MutOutput + where + I: ::soa_derive::SoAIndexMut<&'a mut #vec_name> + { + index.index_mut(self) + } + /// Similar to [` #[doc = #vec_name_str] /// ::as_ptr()`](https://doc.rust-lang.org/std/struct.Vec.html#method.as_ptr). diff --git a/src/lib.rs b/src/lib.rs index eb4612f..ec2d0dd 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -168,6 +168,53 @@ pub trait StructOfArray { type Type; } + +mod private_soa_indexs { + // From [`std::slice::SliceIndex`](https://doc.rust-lang.org/std/slice/trait.SliceIndex.html) code. + // Limits the types that may implement the SoA index traits. + // It's also helpful to have the exaustive list of all accepted types. + + use ::std::ops; + + pub trait Sealed {} + + impl Sealed for usize {} // [a] + impl Sealed for ops::Range {} // [a..b] + impl Sealed for ops::RangeTo {} // [..b] + impl Sealed for ops::RangeFrom {} // [a..] + impl Sealed for ops::RangeFull {} // [..] + impl Sealed for ops::RangeInclusive {} // [a..=b] + impl Sealed for ops::RangeToInclusive {} // [..=b] +} + +/// Helper trait used for indexing operations. +/// Inspired by [`std::slice::SliceIndex`](https://doc.rust-lang.org/std/slice/trait.SliceIndex.html). +pub trait SoAIndex: private_soa_indexs::Sealed { + /// The output for the non-mutable functions + type RefOutput; + + /// Returns the reference output in this location if in bounds, `None` otherwise. + fn get(self, soa: T) -> Option; + /// Returns the reference output in this location without performing any bounds check. + unsafe fn get_unchecked(self, soa: T) -> Self::RefOutput; + /// Returns the reference output in this location. Panics if it is not in bounds. + fn index(self, soa: T) -> Self::RefOutput; +} + +/// Helper trait used for indexing operations returning mutable references. +/// Inspired by [`std::slice::SliceIndex`](https://doc.rust-lang.org/std/slice/trait.SliceIndex.html). +pub trait SoAIndexMut: private_soa_indexs::Sealed { + /// The output for the mutable functions + type MutOutput; + + /// Returns the mutable reference output in this location if in bounds, `None` otherwise. + fn get_mut(self, soa: T) -> Option; + /// Returns the mutable reference output in this location without performing any bounds check. + unsafe fn get_unchecked_mut(self, soa: T) -> Self::MutOutput; + /// Returns the mutable reference output in this location. Panics if it is not in bounds. + fn index_mut(self, soa: T) -> Self::MutOutput; +} + /// Create an iterator over multiple fields in a Struct of array style vector. /// /// This macro takes two main arguments: the array/slice container, and a list diff --git a/tests/index.rs b/tests/index.rs new file mode 100644 index 0000000..dc1f294 --- /dev/null +++ b/tests/index.rs @@ -0,0 +1,248 @@ +mod particles; +use self::particles::{Particle, ParticleVec, ParticleRef}; + + +/// Helper function to assert that two iterators (one of SoA and another of AoS) are equal. +fn eq_its<'a, I1, I2>(i1: I1, i2: I2) +where + I1: Iterator>, + I2: Iterator, +{ + for (p1, p2) in i1.zip(i2) { + assert_eq!(p1.name, &p2.name); + assert_eq!(*p1.mass, p2.mass); + } +} + +#[test] +fn index_vec_with_usize() { + let mut aos = Vec::new(); + let mut soa = ParticleVec::new(); + + let particle = Particle::new(String::from("Na"), 56.0); + aos.push(particle.clone()); + soa.push(particle.clone()); + + // SoAIndex + assert_eq!(soa.get(0).unwrap().name, &aos.get(0).unwrap().name); + assert_eq!(soa.get(0).unwrap().mass, &aos.get(0).unwrap().mass); + assert_eq!(aos.get(1), None); + assert_eq!(soa.get(1), None); + + unsafe { + assert_eq!(soa.get_unchecked(0).name, &aos.get_unchecked(0).name); + assert_eq!(soa.get_unchecked(0).mass, &aos.get_unchecked(0).mass); + } + + assert_eq!(soa.index(0).name, &aos[0].name); + assert_eq!(soa.index(0).mass, &aos[0].mass); + + + // SoaIndexMut + assert_eq!(soa.get_mut(0).unwrap().name, &aos.get_mut(0).unwrap().name); + assert_eq!(soa.get_mut(0).unwrap().mass, &aos.get_mut(0).unwrap().mass); + assert_eq!(aos.get_mut(1), None); + assert_eq!(soa.get_mut(1), None); + + unsafe { + assert_eq!(soa.get_unchecked_mut(0).name, &aos.get_unchecked_mut(0).name); + assert_eq!(soa.get_unchecked_mut(0).mass, &aos.get_unchecked_mut(0).mass); + } + + assert_eq!(soa.index_mut(0).name, &aos[0].name); + assert_eq!(soa.index_mut(0).mass, &aos[0].mass); + + + *soa.index_mut(0).mass -= 1.; + assert_eq!(soa.get(0).map(|p| *p.mass), Some(particle.mass - 1.)); + + *soa.get_mut(0).unwrap().mass += 1.; + assert_eq!(soa.get(0).map(|p| *p.mass), Some(particle.mass)); +} + +#[test] +fn index_vec_with_ranges() { + let mut particles = Vec::new(); + particles.push(Particle::new(String::from("Cl"), 1.0)); + particles.push(Particle::new(String::from("Na"), 2.0)); + particles.push(Particle::new(String::from("Br"), 3.0)); + particles.push(Particle::new(String::from("Zn"), 4.0)); + + let mut soa = ParticleVec::new(); + + for particle in particles.iter() { + soa.push(particle.clone()); + } + + eq_its(soa.iter(), particles.iter()); + + // All tests from here are the same only changing the range + + let range = 0..1; + eq_its(soa.get(range.clone()).unwrap().iter(), particles.get(range.clone()).unwrap().iter()); + unsafe { eq_its(soa.get_unchecked(range.clone()).iter(), particles.get_unchecked(range.clone()).iter()); } + eq_its(soa.index(range.clone()).iter(), particles[range.clone()].iter()); + eq_its(soa.get_mut(range.clone()).unwrap().iter(), particles.get(range.clone()).unwrap().iter()); + unsafe { eq_its(soa.get_unchecked_mut(range.clone()).iter(), particles.get_unchecked(range.clone()).iter()); } + eq_its(soa.index_mut(range.clone()).iter(), particles[range.clone()].iter()); + + let range = ..3; + eq_its(soa.get(range.clone()).unwrap().iter(), particles.get(range.clone()).unwrap().iter()); + unsafe { eq_its(soa.get_unchecked(range.clone()).iter(), particles.get_unchecked(range.clone()).iter()); } + eq_its(soa.index(range.clone()).iter(), particles[range.clone()].iter()); + eq_its(soa.get_mut(range.clone()).unwrap().iter(), particles.get(range.clone()).unwrap().iter()); + unsafe { eq_its(soa.get_unchecked_mut(range.clone()).iter(), particles.get_unchecked(range.clone()).iter()); } + eq_its(soa.index_mut(range.clone()).iter(), particles[range.clone()].iter()); + + let range = 1..; + eq_its(soa.get(range.clone()).unwrap().iter(), particles.get(range.clone()).unwrap().iter()); + unsafe { eq_its(soa.get_unchecked(range.clone()).iter(), particles.get_unchecked(range.clone()).iter()); } + eq_its(soa.index(range.clone()).iter(), particles[range.clone()].iter()); + eq_its(soa.get_mut(range.clone()).unwrap().iter(), particles.get(range.clone()).unwrap().iter()); + unsafe { eq_its(soa.get_unchecked_mut(range.clone()).iter(), particles.get_unchecked(range.clone()).iter()); } + eq_its(soa.index_mut(range.clone()).iter(), particles[range.clone()].iter()); + + let range = ..; + eq_its(soa.get(range.clone()).unwrap().iter(), particles.get(range.clone()).unwrap().iter()); + unsafe { eq_its(soa.get_unchecked(range.clone()).iter(), particles.get_unchecked(range.clone()).iter()); } + eq_its(soa.index(range.clone()).iter(), particles[range.clone()].iter()); + eq_its(soa.get_mut(range.clone()).unwrap().iter(), particles.get(range.clone()).unwrap().iter()); + unsafe { eq_its(soa.get_unchecked_mut(range.clone()).iter(), particles.get_unchecked(range.clone()).iter()); } + eq_its(soa.index_mut(range.clone()).iter(), particles[range.clone()].iter()); + + let range = 0..=1; + eq_its(soa.get(range.clone()).unwrap().iter(), particles.get(range.clone()).unwrap().iter()); + unsafe { eq_its(soa.get_unchecked(range.clone()).iter(), particles.get_unchecked(range.clone()).iter()); } + eq_its(soa.index(range.clone()).iter(), particles[range.clone()].iter()); + eq_its(soa.get_mut(range.clone()).unwrap().iter(), particles.get(range.clone()).unwrap().iter()); + unsafe { eq_its(soa.get_unchecked_mut(range.clone()).iter(), particles.get_unchecked(range.clone()).iter()); } + eq_its(soa.index_mut(range.clone()).iter(), particles[range.clone()].iter()); + + let range = ..=2; + eq_its(soa.get(range.clone()).unwrap().iter(), particles.get(range.clone()).unwrap().iter()); + unsafe { eq_its(soa.get_unchecked(range.clone()).iter(), particles.get_unchecked(range.clone()).iter()); } + eq_its(soa.index(range.clone()).iter(), particles[range.clone()].iter()); + eq_its(soa.get_mut(range.clone()).unwrap().iter(), particles.get(range.clone()).unwrap().iter()); + unsafe { eq_its(soa.get_unchecked_mut(range.clone()).iter(), particles.get_unchecked(range.clone()).iter()); } + eq_its(soa.index_mut(range.clone()).iter(), particles[range.clone()].iter()); +} + +#[test] +fn index_slice_with_usize() { + let mut aos = Vec::new(); + let mut soa = ParticleVec::new(); + + let particle = Particle::new(String::from("Na"), 56.0); + aos.push(particle.clone()); + soa.push(particle.clone()); + + // SoAIndex + let aos_slice = aos.as_slice(); + let soa_slice = soa.as_slice(); + + assert_eq!(soa_slice.get(0).unwrap().name, &aos_slice.get(0).unwrap().name); + assert_eq!(soa_slice.get(0).unwrap().mass, &aos_slice.get(0).unwrap().mass); + assert_eq!(aos_slice.get(1), None); + assert_eq!(soa_slice.get(1), None); + + unsafe { + assert_eq!(soa_slice.get_unchecked(0).name, &aos_slice.get_unchecked(0).name); + assert_eq!(soa_slice.get_unchecked(0).mass, &aos_slice.get_unchecked(0).mass); + } + + assert_eq!(soa_slice.index(0).name, &aos_slice[0].name); + assert_eq!(soa_slice.index(0).mass, &aos_slice[0].mass); + + + // SoaIndexMut + let aos_mut_slice = aos.as_mut_slice(); + let mut soa_mut_slice = soa.as_mut_slice(); + assert_eq!(soa_mut_slice.get_mut(0).unwrap().name, &aos_mut_slice.get_mut(0).unwrap().name); + assert_eq!(soa_mut_slice.get_mut(0).unwrap().mass, &aos_mut_slice.get_mut(0).unwrap().mass); + assert_eq!(soa_mut_slice.get_mut(0).unwrap().mass, &aos_mut_slice.get_mut(0).unwrap().mass); + assert_eq!(aos_mut_slice.get_mut(1), None); + assert_eq!(soa_mut_slice.get_mut(1), None); + + unsafe { + assert_eq!(soa_mut_slice.get_unchecked_mut(0).name, &aos_mut_slice.get_unchecked_mut(0).name); + assert_eq!(soa_mut_slice.get_unchecked_mut(0).mass, &aos_mut_slice.get_unchecked_mut(0).mass); + } + + assert_eq!(soa_mut_slice.index_mut(0).name, &aos_mut_slice[0].name); + assert_eq!(soa_mut_slice.index_mut(0).mass, &aos_mut_slice[0].mass); + + + *soa_mut_slice.index_mut(0).mass -= 1.; + assert_eq!(soa_mut_slice.get(0).map(|p| *p.mass), Some(particle.mass - 1.)); + + *soa_mut_slice.get_mut(0).unwrap().mass += 1.; + assert_eq!(soa_mut_slice.get(0).map(|p| *p.mass), Some(particle.mass)); +} + +#[test] +fn index_slice_with_ranges() { + let mut particles = Vec::new(); + particles.push(Particle::new(String::from("Cl"), 1.0)); + particles.push(Particle::new(String::from("Na"), 2.0)); + particles.push(Particle::new(String::from("Br"), 3.0)); + particles.push(Particle::new(String::from("Zn"), 4.0)); + + let mut soa = ParticleVec::new(); + + for particle in particles.iter() { + soa.push(particle.clone()); + } + + eq_its(soa.iter(), particles.iter()); + + let mut soa_mut_slice = soa.as_mut_slice(); + // All tests from here are the same only changing the range + + let range = 0..1; + eq_its(soa_mut_slice.get(range.clone()).unwrap().iter(), particles.get(range.clone()).unwrap().iter()); + unsafe { eq_its(soa_mut_slice.get_unchecked(range.clone()).iter(), particles.get_unchecked(range.clone()).iter()); } + eq_its(soa_mut_slice.index(range.clone()).iter(), particles[range.clone()].iter()); + eq_its(soa_mut_slice.get_mut(range.clone()).unwrap().iter(), particles.get(range.clone()).unwrap().iter()); + unsafe { eq_its(soa_mut_slice.get_unchecked_mut(range.clone()).iter(), particles.get_unchecked(range.clone()).iter()); } + eq_its(soa_mut_slice.index_mut(range.clone()).iter(), particles[range.clone()].iter()); + + let range = ..3; + eq_its(soa_mut_slice.get(range.clone()).unwrap().iter(), particles.get(range.clone()).unwrap().iter()); + unsafe { eq_its(soa_mut_slice.get_unchecked(range.clone()).iter(), particles.get_unchecked(range.clone()).iter()); } + eq_its(soa_mut_slice.index(range.clone()).iter(), particles[range.clone()].iter()); + eq_its(soa_mut_slice.get_mut(range.clone()).unwrap().iter(), particles.get(range.clone()).unwrap().iter()); + unsafe { eq_its(soa_mut_slice.get_unchecked_mut(range.clone()).iter(), particles.get_unchecked(range.clone()).iter()); } + eq_its(soa_mut_slice.index_mut(range.clone()).iter(), particles[range.clone()].iter()); + + let range = 1..; + eq_its(soa_mut_slice.get(range.clone()).unwrap().iter(), particles.get(range.clone()).unwrap().iter()); + unsafe { eq_its(soa_mut_slice.get_unchecked(range.clone()).iter(), particles.get_unchecked(range.clone()).iter()); } + eq_its(soa_mut_slice.index(range.clone()).iter(), particles[range.clone()].iter()); + eq_its(soa_mut_slice.get_mut(range.clone()).unwrap().iter(), particles.get(range.clone()).unwrap().iter()); + unsafe { eq_its(soa_mut_slice.get_unchecked_mut(range.clone()).iter(), particles.get_unchecked(range.clone()).iter()); } + eq_its(soa_mut_slice.index_mut(range.clone()).iter(), particles[range.clone()].iter()); + + let range = ..; + eq_its(soa_mut_slice.get(range.clone()).unwrap().iter(), particles.get(range.clone()).unwrap().iter()); + unsafe { eq_its(soa_mut_slice.get_unchecked(range.clone()).iter(), particles.get_unchecked(range.clone()).iter()); } + eq_its(soa_mut_slice.index(range.clone()).iter(), particles[range.clone()].iter()); + eq_its(soa_mut_slice.get_mut(range.clone()).unwrap().iter(), particles.get(range.clone()).unwrap().iter()); + unsafe { eq_its(soa_mut_slice.get_unchecked_mut(range.clone()).iter(), particles.get_unchecked(range.clone()).iter()); } + eq_its(soa_mut_slice.index_mut(range.clone()).iter(), particles[range.clone()].iter()); + + let range = 0..=1; + eq_its(soa_mut_slice.get(range.clone()).unwrap().iter(), particles.get(range.clone()).unwrap().iter()); + unsafe { eq_its(soa_mut_slice.get_unchecked(range.clone()).iter(), particles.get_unchecked(range.clone()).iter()); } + eq_its(soa_mut_slice.index(range.clone()).iter(), particles[range.clone()].iter()); + eq_its(soa_mut_slice.get_mut(range.clone()).unwrap().iter(), particles.get(range.clone()).unwrap().iter()); + unsafe { eq_its(soa_mut_slice.get_unchecked_mut(range.clone()).iter(), particles.get_unchecked(range.clone()).iter()); } + eq_its(soa_mut_slice.index_mut(range.clone()).iter(), particles[range.clone()].iter()); + + let range = ..=2; + eq_its(soa_mut_slice.get(range.clone()).unwrap().iter(), particles.get(range.clone()).unwrap().iter()); + unsafe { eq_its(soa_mut_slice.get_unchecked(range.clone()).iter(), particles.get_unchecked(range.clone()).iter()); } + eq_its(soa_mut_slice.index(range.clone()).iter(), particles[range.clone()].iter()); + eq_its(soa_mut_slice.get_mut(range.clone()).unwrap().iter(), particles.get(range.clone()).unwrap().iter()); + unsafe { eq_its(soa_mut_slice.get_unchecked_mut(range.clone()).iter(), particles.get_unchecked(range.clone()).iter()); } + eq_its(soa_mut_slice.index_mut(range.clone()).iter(), particles[range.clone()].iter()); +} diff --git a/tests/iter.rs b/tests/iter.rs index 9c7125e..2af51ce 100644 --- a/tests/iter.rs +++ b/tests/iter.rs @@ -35,9 +35,9 @@ fn iter_mut() { for particle in particles.iter_mut() { *particle.mass += 1.0; } - assert_eq!(particles.mass[0], 1.0); - assert_eq!(particles.mass[1], 1.0); - assert_eq!(particles.mass[2], 1.0); + assert_eq!(*particles.index(0).mass, 1.0); + assert_eq!(*particles.index(1).mass, 1.0); + assert_eq!(*particles.index(2).mass, 1.0); { let mut slice = particles.as_mut_slice(); @@ -46,9 +46,9 @@ fn iter_mut() { } } - assert_eq!(particles.mass[0], 2.0); - assert_eq!(particles.mass[1], 2.0); - assert_eq!(particles.mass[2], 2.0); + assert_eq!(*particles.index(0).mass, 2.0); + assert_eq!(*particles.index(1).mass, 2.0); + assert_eq!(*particles.index(2).mass, 2.0); } #[test] diff --git a/tests/vec.rs b/tests/vec.rs index 0416201..d4d7af5 100644 --- a/tests/vec.rs +++ b/tests/vec.rs @@ -78,9 +78,9 @@ fn swap_remove() { let particle = particles.swap_remove(1); assert_eq!(particle.name, "Na"); - assert_eq!(particles.name[0], "Cl"); - assert_eq!(particles.name[1], "Zn"); - assert_eq!(particles.name[2], "Br"); + assert_eq!(particles.index(0).name, "Cl"); + assert_eq!(particles.index(1).name, "Zn"); + assert_eq!(particles.index(2).name, "Br"); } #[test] @@ -90,9 +90,9 @@ fn insert() { particles.push(Particle::new(String::from("Na"), 0.0)); particles.insert(1, Particle::new(String::from("Zn"), 0.0)); - assert_eq!(particles.name[0], "Cl"); - assert_eq!(particles.name[1], "Zn"); - assert_eq!(particles.name[2], "Na"); + assert_eq!(particles.index(0).name, "Cl"); + assert_eq!(particles.index(1).name, "Zn"); + assert_eq!(particles.index(2).name, "Na"); } #[test] @@ -122,10 +122,10 @@ fn append() { others.push(Particle::new(String::from("Mg"), 0.0)); particles.append(&mut others); - assert_eq!(particles.name[0], "Cl"); - assert_eq!(particles.name[1], "Na"); - assert_eq!(particles.name[2], "Zn"); - assert_eq!(particles.name[3], "Mg"); + assert_eq!(particles.index(0).name, "Cl"); + assert_eq!(particles.index(1).name, "Na"); + assert_eq!(particles.index(2).name, "Zn"); + assert_eq!(particles.index(3).name, "Mg"); } #[test] @@ -140,10 +140,10 @@ fn split_off() { assert_eq!(particles.len(), 2); assert_eq!(other.len(), 2); - assert_eq!(particles.name[0], "Cl"); - assert_eq!(particles.name[1], "Na"); - assert_eq!(other.name[0], "Zn"); - assert_eq!(other.name[1], "Mg"); + assert_eq!(particles.index(0).name, "Cl"); + assert_eq!(particles.index(1).name, "Na"); + assert_eq!(other.index(0).name, "Zn"); + assert_eq!(other.index(1).name, "Mg"); } #[test] @@ -156,6 +156,6 @@ fn retain() { particles.retain(|particle| particle.name.starts_with("C")); assert_eq!(particles.len(), 2); - assert_eq!(particles.name[0], "Cl"); - assert_eq!(particles.name[1], "C"); + assert_eq!(particles.index(0).name, "Cl"); + assert_eq!(particles.index(1).name, "C"); }