diff --git a/soa-derive-internal/src/iter.rs b/soa-derive-internal/src/iter.rs index b8984b1..035d406 100644 --- a/soa-derive-internal/src/iter.rs +++ b/soa-derive-internal/src/iter.rs @@ -25,6 +25,10 @@ pub fn derive(input: &Input) -> TokenStream { .map(|field| field.ident.clone().unwrap()) .collect::>(); + let fields_types = &input.fields.iter() + .map(|field| field.ty.clone()) + .collect::>(); + let iter_type = input.map_fields_nested_or( |_, field_type| quote! { <#field_type as soa_derive::SoAIter<'a>>::Iter }, |_, field_type| quote! { slice::Iter<'a, #field_type> }, @@ -289,6 +293,24 @@ pub fn derive(input: &Input) -> TokenStream { self.as_mut_slice().into_iter() } } + + impl Extend<#name> for #vec_name { + fn extend>(&mut self, iter: I) { + for item in iter { + self.push(item) + } + } + } + + impl<'a> Extend<#ref_name<'a>> for #vec_name + // only expose if all fields are Clone + // https://github.com/rust-lang/rust/issues/48214#issuecomment-1150463333 + where #( for<'b> #fields_types: Clone, )* + { + fn extend>>(&mut self, iter: I) { + self.extend(iter.into_iter().map(|item| item.to_owned())) + } + } }); } diff --git a/soa-derive-internal/src/refs.rs b/soa-derive-internal/src/refs.rs index 92bf924..a4dfc0e 100644 --- a/soa-derive-internal/src/refs.rs +++ b/soa-derive-internal/src/refs.rs @@ -122,7 +122,7 @@ pub fn derive(input: &Input) -> TokenStream { /// into an owned value. This is only available if all fields /// implement `Clone`. pub fn to_owned(&self) -> #name - // only expose to_owned is all fields are Clone + // only expose to_owned if all fields are Clone // https://github.com/rust-lang/rust/issues/48214#issuecomment-1150463333 where #( for<'b> #fields_types: Clone, )* { @@ -138,7 +138,7 @@ pub fn derive(input: &Input) -> TokenStream { /// into an owned value. This is only available if all fields /// implement `Clone`. pub fn to_owned(&self) -> #name - // only expose to_owned is all fields are Clone + // only expose to_owned if all fields are Clone // https://github.com/rust-lang/rust/issues/48214#issuecomment-1150463333 where #( for<'b> #fields_types: Clone, )* { diff --git a/tests/iter.rs b/tests/iter.rs index 2df9dc3..3ec1640 100644 --- a/tests/iter.rs +++ b/tests/iter.rs @@ -69,3 +69,19 @@ fn from_iter() { assert_eq!(particles, particles_from_iter) } + +#[test] +fn extend() { + let vec_with_particles = vec![ + Particle::new(String::from("Na"), 0.0), + Particle::new(String::from("Cl"), 0.0), + Particle::new(String::from("Zn"), 0.0), + ]; + + let particles_from_iter: ParticleVec = vec_with_particles.clone().into_iter().collect(); + + let mut particles = ParticleVec::new(); + particles.extend(vec_with_particles); + + assert_eq!(particles, particles_from_iter) +}