From b9b6827010b30609d3028c53bdd649f90f8c3321 Mon Sep 17 00:00:00 2001 From: Oren Laadan Date: Tue, 12 Sep 2023 20:53:14 +0800 Subject: [PATCH] feat: add slices helper Filter() Filter() takes a slice and returns a sub-slice based on an input filter function applied on each of the elements of the slice. Also rename the old Filter() to what it really is - Map(), which take a slice and returns a new slice with the result of applying the input map function on the elements of input slice. --- testutil/common/tester.go | 2 +- utils/slices/slices.go | 12 +++++++++++- utils/slices/slices_test.go | 15 +++++++++++---- x/pairing/keeper/pairing_test.go | 8 ++++---- 4 files changed, 27 insertions(+), 10 deletions(-) diff --git a/testutil/common/tester.go b/testutil/common/tester.go index cd4ef47ef0..2ce3a0c4cc 100644 --- a/testutil/common/tester.go +++ b/testutil/common/tester.go @@ -240,7 +240,7 @@ func NewCoin(amount int64) sdk.Coin { } func NewCoins(amount ...int64) []sdk.Coin { - return slices.Filter(amount, NewCoin) + return slices.Map(amount, NewCoin) } // keeper helpers diff --git a/utils/slices/slices.go b/utils/slices/slices.go index 966298ce43..c0c6d5efef 100644 --- a/utils/slices/slices.go +++ b/utils/slices/slices.go @@ -162,7 +162,7 @@ func UnionByFunc[T ComparableByFunc](arrays ...[]T) []T { return res } -func Filter[T, V any](slice []T, filter func(T) V) []V { +func Map[T, V any](slice []T, filter func(T) V) []V { values := make([]V, len(slice)) for i := range slice { values[i] = filter(slice[i]) @@ -170,6 +170,16 @@ func Filter[T, V any](slice []T, filter func(T) V) []V { return values } +func Filter[T any](slice []T, filter func(T) bool) []T { + values := make([]T, 0) + for _, v := range slice { + if filter(v) { + values = append(values, v) + } + } + return values +} + func UnorderedEqual[T comparable](slices ...[]T) bool { var length int diff --git a/utils/slices/slices_test.go b/utils/slices/slices_test.go index cc7a415d74..a989a88d05 100644 --- a/utils/slices/slices_test.go +++ b/utils/slices/slices_test.go @@ -267,9 +267,16 @@ func TestUnorderedEqual(t *testing.T) { } } +func TestMap(t *testing.T) { + mapFunc := func(_ int) int { return 10 } + require.Equal(t, []int{}, Map([]int{}, mapFunc)) + require.Equal(t, []int{10}, Map([]int{1}, mapFunc)) + require.Equal(t, []int{10, 10, 10}, Map([]int{1, 2, 3}, mapFunc)) +} + func TestFilter(t *testing.T) { - filter := func(_ int) int { return 10 } - require.Equal(t, Filter([]int{}, filter), []int{}) - require.Equal(t, Filter([]int{1}, filter), []int{10}) - require.Equal(t, Filter([]int{1, 2, 3}, filter), []int{10, 10, 10}) + filter := func(v int) bool { return v%2 == 0 } + require.Equal(t, []int{}, Filter([]int{}, filter)) + require.Equal(t, []int{}, Filter([]int{1}, filter)) + require.Equal(t, []int{2, 4}, Filter([]int{1, 2, 3, 4}, filter)) } diff --git a/x/pairing/keeper/pairing_test.go b/x/pairing/keeper/pairing_test.go index 7c05476e6f..c765d4f73a 100644 --- a/x/pairing/keeper/pairing_test.go +++ b/x/pairing/keeper/pairing_test.go @@ -46,10 +46,10 @@ func TestPairingUniqueness(t *testing.T) { pairing2, err := ts.QueryPairingGetPairing(ts.spec.Index, sub2Addr) require.Nil(t, err) - filter := func(p epochstoragetypes.StakeEntry) string { return p.Address } + mapFunc := func(p epochstoragetypes.StakeEntry) string { return p.Address } - providerAddrs1 := slices.Filter(pairing1.Providers, filter) - providerAddrs2 := slices.Filter(pairing2.Providers, filter) + providerAddrs1 := slices.Map(pairing1.Providers, mapFunc) + providerAddrs2 := slices.Map(pairing2.Providers, mapFunc) require.Equal(t, len(pairing1.Providers), len(pairing2.Providers)) require.False(t, slices.UnorderedEqual(providerAddrs1, providerAddrs2)) @@ -60,7 +60,7 @@ func TestPairingUniqueness(t *testing.T) { pairing11, err := ts.QueryPairingGetPairing(ts.spec.Index, sub1Addr) require.Nil(t, err) - providerAddrs11 := slices.Filter(pairing11.Providers, filter) + providerAddrs11 := slices.Map(pairing11.Providers, mapFunc) require.Equal(t, len(pairing1.Providers), len(pairing11.Providers)) require.False(t, slices.UnorderedEqual(providerAddrs1, providerAddrs11))