Skip to content

Commit

Permalink
Generalise compare_fields to work with iterators (#4823)
Browse files Browse the repository at this point in the history
## Proposed Changes

Add `compare_fields(as_iter)` as a field attribute to `compare_fields_derive`. This allows any iterable type to be compared in the same as a slice (by index). 

This is forwards-compatible with tree-states types like `List` and `Vector` which can not be cast to slices.
  • Loading branch information
michaelsproul committed Oct 18, 2023
1 parent 1b4545c commit 463e62e
Show file tree
Hide file tree
Showing 4 changed files with 40 additions and 13 deletions.
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 3 additions & 0 deletions common/compare_fields/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@ version = "0.2.0"
authors = ["Paul Hauner <[email protected]>"]
edition = { workspace = true }

[dependencies]
itertools = { workspace = true }

[dev-dependencies]
compare_fields_derive = { workspace = true }

Expand Down
36 changes: 29 additions & 7 deletions common/compare_fields/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -81,11 +81,8 @@
//! }
//! ];
//! assert_eq!(bar_a.compare_fields(&bar_b), bar_a_b);
//!
//!
//!
//! // TODO:
//! ```
use itertools::{EitherOrBoth, Itertools};
use std::fmt::Debug;

#[derive(Debug, PartialEq, Clone)]
Expand All @@ -112,13 +109,38 @@ impl Comparison {
}

pub fn from_slice<T: Debug + PartialEq<T>>(field_name: String, a: &[T], b: &[T]) -> Self {
Self::from_iter(field_name, a.iter(), b.iter())
}

pub fn from_into_iter<'a, T: Debug + PartialEq + 'a>(
field_name: String,
a: impl IntoIterator<Item = &'a T>,
b: impl IntoIterator<Item = &'a T>,
) -> Self {
Self::from_iter(field_name, a.into_iter(), b.into_iter())
}

pub fn from_iter<'a, T: Debug + PartialEq + 'a>(
field_name: String,
a: impl Iterator<Item = &'a T>,
b: impl Iterator<Item = &'a T>,
) -> Self {
let mut children = vec![];
let mut all_equal = true;

for i in 0..std::cmp::max(a.len(), b.len()) {
children.push(FieldComparison::new(format!("{i}"), &a.get(i), &b.get(i)));
for (i, entry) in a.zip_longest(b).enumerate() {
let comparison = match entry {
EitherOrBoth::Both(x, y) => {
FieldComparison::new(format!("{i}"), &Some(x), &Some(y))
}
EitherOrBoth::Left(x) => FieldComparison::new(format!("{i}"), &Some(x), &None),
EitherOrBoth::Right(y) => FieldComparison::new(format!("{i}"), &None, &Some(y)),
};
all_equal = all_equal && comparison.equal();
children.push(comparison);
}

Self::parent(field_name, a == b, children)
Self::parent(field_name, all_equal, children)
}

pub fn retain_children<F>(&mut self, f: F)
Expand Down
13 changes: 7 additions & 6 deletions common/compare_fields_derive/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,11 @@ use proc_macro::TokenStream;
use quote::quote;
use syn::{parse_macro_input, DeriveInput};

fn is_slice(field: &syn::Field) -> bool {
fn is_iter(field: &syn::Field) -> bool {
field.attrs.iter().any(|attr| {
attr.path.is_ident("compare_fields")
&& attr.tokens.to_string().replace(' ', "") == "(as_slice)"
&& (attr.tokens.to_string().replace(' ', "") == "(as_slice)"
|| attr.tokens.to_string().replace(' ', "") == "(as_iter)")
})
}

Expand All @@ -34,13 +35,13 @@ pub fn compare_fields_derive(input: TokenStream) -> TokenStream {
let field_name = ident_a.to_string();
let ident_b = ident_a.clone();

let quote = if is_slice(field) {
let quote = if is_iter(field) {
quote! {
comparisons.push(compare_fields::Comparison::from_slice(
comparisons.push(compare_fields::Comparison::from_into_iter(
#field_name.to_string(),
&self.#ident_a,
&b.#ident_b)
);
&b.#ident_b
));
}
} else {
quote! {
Expand Down

0 comments on commit 463e62e

Please sign in to comment.