diff --git a/kad/trie/trie.go b/kad/trie/trie.go index 588d235..50f33ab 100644 --- a/kad/trie/trie.go +++ b/kad/trie/trie.go @@ -129,7 +129,7 @@ func (tr *Trie[K, D]) addAtDepth(depth int, kk K, data D) bool { } } -// Add adds the key to trie, returning a new trie. +// Add adds the key to trie, returning a new trie if the key was not already in the trie. // Add is immutable/non-destructive: the original trie remains unchanged. func Add[K kad.Key[K], D any](tr *Trie[K, D], kk K, data D) (*Trie[K, D], error) { return addAtDepth(0, tr, kk, data), nil @@ -148,10 +148,14 @@ func addAtDepth[K kad.Key[K], D any](depth int, tr *Trie[K, D], kk K, data D) *T default: dir := kk.Bit(depth) - s := &Trie[K, D]{} - s.branch[dir] = addAtDepth(depth+1, tr.branch[dir], kk, data) - s.branch[1-dir] = tr.branch[1-dir] - return s + b := addAtDepth(depth+1, tr.branch[dir], kk, data) + if b != tr.branch[dir] { + s := &Trie[K, D]{} + s.branch[dir] = b + s.branch[1-dir] = tr.branch[1-dir] + return s + } + return tr } } diff --git a/kad/trie/trie_test.go b/kad/trie/trie_test.go index d7acd65..91cff70 100644 --- a/kad/trie/trie_test.go +++ b/kad/trie/trie_test.go @@ -206,6 +206,28 @@ func TestImmutableAddIgnoresDuplicates(t *testing.T) { } } +func TestImmutableAddReturnsOriginalTrieForDuplicates(t *testing.T) { + tr := New[kadtest.Key32, any]() + var err error + for _, kk := range sampleKeySet.Keys { + tr, err = Add(tr, kk, nil) + require.NoError(t, err) + } + require.Equal(t, len(sampleKeySet.Keys), tr.Size()) + + for _, kk := range sampleKeySet.Keys { + next, err := Add(tr, kk, nil) + require.NoError(t, err) + // trie has not been changed + require.Same(t, tr, next) + } + require.Equal(t, len(sampleKeySet.Keys), tr.Size()) + + if d := CheckInvariant(tr); d != nil { + t.Fatalf("reordered trie invariant discrepancy: %v", d) + } +} + func TestAddWithData(t *testing.T) { tr := New[kadtest.Key32, int]() for i, kk := range sampleKeySet.Keys { diff --git a/kad/triert/table_test.go b/kad/triert/table_test.go index caf1f80..a2268a0 100644 --- a/kad/triert/table_test.go +++ b/kad/triert/table_test.go @@ -321,7 +321,8 @@ type nodeFilter struct { } func (f *nodeFilter) TryAdd(rt *TrieRT[kadtest.Key32, node[kadtest.Key32]], - n node[kadtest.Key32]) bool { + n node[kadtest.Key32], +) bool { if n == node2 { return false }