diff --git a/src/PersistentOrderedMap.mo b/src/PersistentOrderedMap.mo index 18506243..62f55b7a 100644 --- a/src/PersistentOrderedMap.mo +++ b/src/PersistentOrderedMap.mo @@ -38,14 +38,12 @@ import O "Order"; module { - /// Node color: Either red (`#R`) or black (`#B`). - public type Color = { #R; #B }; - /// Red-black tree of nodes with key-value entries, ordered by the keys. /// The keys have the generic type `K` and the values the generic type `V`. /// Leaves are considered implicitly black. public type Map = { - #node : (Color, Map, (K, V), Map); + #red : (Map, (K, V), Map); + #black : (Map, (K, V), Map); #leaf }; @@ -324,30 +322,43 @@ module { /// /// Note: Full map iteration creates `O(n)` temporary objects that will be collected as garbage. public func iter(rbMap : Map, direction : Direction) : I.Iter<(K, V)> { - object { - var trees : IterRep = ?(#tr(rbMap), null); - public func next() : ?(K, V) { - switch (direction, trees) { - case (_, null) { null }; - case (_, ?(#tr(#leaf), ts)) { + let turnLeftFirst : MapTraverser + = func (l, xy, r, ts) { ?(#tr(l), ?(#xy(xy), ?(#tr(r), ts))) }; + + let turnRightFirst : MapTraverser + = func (l, xy, r, ts) { ?(#tr(r), ?(#xy(xy), ?(#tr(l), ts))) }; + + switch direction { + case (#fwd) IterMap(rbMap, turnLeftFirst); + case (#bwd) IterMap(rbMap, turnRightFirst) + } + }; + + type MapTraverser = (Map, (K, V), Map, IterRep) -> IterRep; + + class IterMap(rbMap : Map, mapTraverser : MapTraverser) { + var trees : IterRep = ?(#tr(rbMap), null); + public func next() : ?(K, V) { + switch (trees) { + case (null) { null }; + case (?(#tr(#leaf), ts)) { trees := ts; next() }; - case (_, ?(#xy(xy), ts)) { + case (?(#xy(xy), ts)) { trees := ts; ?xy - }; // TODO: Let's float-out case on direction - case (#fwd, ?(#tr(#node(_, l, xy, r)), ts)) { - trees := ?(#tr(l), ?(#xy(xy), ?(#tr(r), ts))); + }; + case (?(#tr(#red(l, xy, r)), ts)) { + trees := mapTraverser(l, xy, r, ts); next() }; - case (#bwd, ?(#tr(#node(_, l, xy, r)), ts)) { - trees := ?(#tr(r), ?(#xy(xy), ?(#tr(l), ts))); + case (?(#tr(#black(l, xy, r)), ts)) { + trees := mapTraverser(l, xy, r, ts); next() } } } - } }; /// Returns an Iterator (`Iter`) over the key-value pairs in the map. @@ -458,8 +469,11 @@ module { func mapRec(m : Map) : Map { switch m { case (#leaf) { #leaf }; - case (#node(c, l, xy, r)) { - #node(c, mapRec l, (xy.0, f xy), mapRec r) // TODO: try destination-passing style to avoid non tail-call recursion + case (#red(l, xy, r)) { + #red(mapRec l, (xy.0, f xy), mapRec r) + }; + case (#black(l, xy, r)) { + #black(mapRec l, (xy.0, f xy), mapRec r) }; } }; @@ -488,7 +502,10 @@ module { public func size(t : Map) : Nat { switch t { case (#leaf) { 0 }; - case (#node(_, l, _, r)) { + case (#red(l, _, r)) { + size(l) + size(r) + 1 + }; + case (#black(l, _, r)) { size(l) + size(r) + 1 } } @@ -527,11 +544,19 @@ module { combine : (Key, Value, Accum) -> Accum ) : Accum { - var acc = base; - for(val in iter(rbMap, #fwd)){ - acc := combine(val.0, val.1, acc); - }; - acc + switch (rbMap) { + case (#leaf) { base }; + case (#red(l, (k, v), r)) { + let left = foldLeft(l, base, combine); + let middle = combine(k, v, left); + foldLeft(r, middle, combine) + }; + case (#black(l, (k, v), r)) { + let left = foldLeft(l, base, combine); + let middle = combine(k, v, left); + foldLeft(r, middle, combine) + } + } }; /// Collapses the elements in `rbMap` into a single value by starting with `base` @@ -567,15 +592,23 @@ module { combine : (Key, Value, Accum) -> Accum ) : Accum { - var acc = base; - for(val in iter(rbMap, #bwd)){ - acc := combine(val.0, val.1, acc); - }; - acc + switch (rbMap) { + case (#leaf) { base }; + case (#red(l, (k, v), r)) { + let right = foldRight(r, base, combine); + let middle = combine(k, v, right); + foldRight(l, middle, combine) + }; + case (#black(l, (k, v), r)) { + let right = foldRight(r, base, combine); + let middle = combine(k, v, right); + foldRight(l, middle, combine) + } + } }; - module Internal { + public module Internal { public func fromIter(i : I.Iter<(K,V)>, compare : (K, K) -> O.Order) : Map { @@ -587,25 +620,28 @@ module { }; public func mapFilter(t : Map, compare : (K, K) -> O.Order, f : (K, V1) -> ?V2) : Map{ - var map = #leaf : Map; - for(kv in iter(t, #fwd)) - { - switch(f kv){ - case null {}; - case (?v1) { - // The keys still are monotonic, so we can - // merge trees using `append` and avoid compare here - map := put(map, compare, kv.0, v1); + func combine(key : K, value1 : V1, acc : Map) : Map { + switch (f(key, value1)){ + case null { acc }; + case (?value2) { + put(acc, compare, key, value2) } } }; - map + foldLeft(t, #leaf, combine) }; public func get(t : Map, compare : (K, K) -> O.Order, x : K) : ?V { switch t { case (#leaf) { null }; - case (#node(_c, l, xy, r)) { + case (#red(l, xy, r)) { + switch (compare(x, xy.0)) { + case (#less) { get(l, compare, x) }; + case (#equal) { ?xy.1 }; + case (#greater) { get(r, compare, x) } + } + }; + case (#black(l, xy, r)) { switch (compare(x, xy.0)) { case (#less) { get(l, compare, x) }; case (#equal) { ?xy.1 }; @@ -617,8 +653,8 @@ module { func redden(t : Map) : Map { switch t { - case (#node (#B, l, xy, r)) { - (#node (#R, l, xy, r)) + case (#black (l, xy, r)) { + (#red (l, xy, r)) }; case _ { Debug.trap "RBTree.red" @@ -628,44 +664,40 @@ module { func lbalance(left : Map, xy : (K,V), right : Map) : Map { switch (left, right) { - case (#node(#R, #node(#R, l1, xy1, r1), xy2, r2), r) { - #node( - #R, - #node(#B, l1, xy1, r1), + case (#red(#red(l1, xy1, r1), xy2, r2), r) { + #red( + #black(l1, xy1, r1), xy2, - #node(#B, r2, xy, r)) + #black(r2, xy, r)) }; - case (#node(#R, l1, xy1, #node(#R, l2, xy2, r2)), r) { - #node( - #R, - #node(#B, l1, xy1, l2), + case (#red(l1, xy1, #red(l2, xy2, r2)), r) { + #red( + #black(l1, xy1, l2), xy2, - #node(#B, r2, xy, r)) + #black(r2, xy, r)) }; case _ { - #node(#B, left, xy, right) + #black(left, xy, right) } } }; func rbalance(left : Map, xy : (K,V), right : Map) : Map { switch (left, right) { - case (l, #node(#R, l1, xy1, #node(#R, l2, xy2, r2))) { - #node( - #R, - #node(#B, l, xy, l1), + case (l, #red(l1, xy1, #red(l2, xy2, r2))) { + #red( + #black(l, xy, l1), xy1, - #node(#B, l2, xy2, r2)) + #black(l2, xy2, r2)) }; - case (l, #node(#R, #node(#R, l1, xy1, r1), xy2, r2)) { - #node( - #R, - #node(#B, l, xy, l1), + case (l, #red(#red(l1, xy1, r1), xy2, r2)) { + #red( + #black(l, xy, l1), xy1, - #node(#B, r1, xy2, r2)) + #black(r1, xy2, r2)) }; case _ { - #node(#B, left, xy, right) + #black(left, xy, right) }; } }; @@ -683,9 +715,9 @@ module { func ins(tree : Map) : Map { switch tree { case (#leaf) { - #node(#R, #leaf, (key,val), #leaf) + #red(#leaf, (key,val), #leaf) }; - case (#node(#B, left, xy, right)) { + case (#black(left, xy, right)) { switch (compare (key, xy.0)) { case (#less) { lbalance(ins left, xy, right) @@ -695,29 +727,29 @@ module { }; case (#equal) { let newVal = onClash({ new = val; old = xy.1 }); - #node(#B, left, (key,newVal), right) + #black(left, (key,newVal), right) } } }; - case (#node(#R, left, xy, right)) { + case (#red(left, xy, right)) { switch (compare (key, xy.0)) { case (#less) { - #node(#R, ins left, xy, right) + #red(ins left, xy, right) }; case (#greater) { - #node(#R, left, xy, ins right) + #red(left, xy, ins right) }; case (#equal) { let newVal = onClash { new = val; old = xy.1 }; - #node(#R, left, (key,newVal), right) + #red(left, (key,newVal), right) } } } }; }; switch (ins m) { - case (#node(#R, left, xy, right)) { - #node(#B, left, xy, right); + case (#red(left, xy, right)) { + #black(left, xy, right); }; case other { other }; }; @@ -750,19 +782,18 @@ module { func balLeft(left : Map, xy : (K,V), right : Map) : Map { switch (left, right) { - case (#node(#R, l1, xy1, r1), r) { - #node( - #R, - #node(#B, l1, xy1, r1), + case (#red(l1, xy1, r1), r) { + #red( + #black(l1, xy1, r1), xy, r) }; - case (_, #node(#B, l2, xy2, r2)) { - rbalance(left, xy, #node(#R, l2, xy2, r2)) + case (_, #black(l2, xy2, r2)) { + rbalance(left, xy, #red(l2, xy2, r2)) }; - case (_, #node(#R, #node(#B, l2, xy2, r2), xy3, r3)) { - #node(#R, - #node(#B, left, xy, l2), + case (_, #red(#black(l2, xy2, r2), xy3, r3)) { + #red( + #black(left, xy, l2), xy2, rbalance(r2, xy3, redden r3)) }; @@ -772,20 +803,20 @@ module { func balRight(left : Map, xy : (K,V), right : Map) : Map { switch (left, right) { - case (l, #node(#R, l1, xy1, r1)) { - #node(#R, + case (l, #red(l1, xy1, r1)) { + #red( l, xy, - #node(#B, l1, xy1, r1)) + #black(l1, xy1, r1)) }; - case (#node(#B, l1, xy1, r1), r) { - lbalance(#node(#R, l1, xy1, r1), xy, r); + case (#black(l1, xy1, r1), r) { + lbalance(#red(l1, xy1, r1), xy, r); }; - case (#node(#R, l1, xy1, #node(#B, l2, xy2, r2)), r3) { - #node(#R, + case (#red(l1, xy1, #black(l2, xy2, r2)), r3) { + #red( lbalance(redden l1, xy1, l2), xy2, - #node(#B, r2, xy, r3)) + #black(r2, xy, r3)) }; case _ { Debug.trap "balRight" }; } @@ -795,40 +826,39 @@ module { switch (left, right) { case (#leaf, _) { right }; case (_, #leaf) { left }; - case (#node (#R, l1, xy1, r1), - #node (#R, l2, xy2, r2)) { + case (#red (l1, xy1, r1), + #red (l2, xy2, r2)) { switch (append (r1, l2)) { - case (#node (#R, l3, xy3, r3)) { - #node( - #R, - #node(#R, l1, xy1, l3), + case (#red (l3, xy3, r3)) { + #red( + #red(l1, xy1, l3), xy3, - #node(#R, r3, xy2, r2)) + #red(r3, xy2, r2)) }; case r1l2 { - #node(#R, l1, xy1, #node(#R, r1l2, xy2, r2)) + #red(l1, xy1, #red(r1l2, xy2, r2)) } } }; - case (t1, #node(#R, l2, xy2, r2)) { - #node(#R, append(t1, l2), xy2, r2) + case (t1, #red(l2, xy2, r2)) { + #red(append(t1, l2), xy2, r2) }; - case (#node(#R, l1, xy1, r1), t2) { - #node(#R, l1, xy1, append(r1, t2)) + case (#red(l1, xy1, r1), t2) { + #red(l1, xy1, append(r1, t2)) }; - case (#node(#B, l1, xy1, r1), #node (#B, l2, xy2, r2)) { + case (#black(l1, xy1, r1), #black (l2, xy2, r2)) { switch (append (r1, l2)) { - case (#node (#R, l3, xy3, r3)) { - #node(#R, - #node(#B, l1, xy1, l3), + case (#red (l3, xy3, r3)) { + #red( + #black(l1, xy1, l3), xy3, - #node(#B, r3, xy2, r2)) + #black(r3, xy2, r2)) }; case r1l2 { balLeft ( l1, xy1, - #node(#B, r1l2, xy2, r2) + #black(r1l2, xy2, r2) ) } } @@ -846,22 +876,22 @@ module { case (#less) { let newLeft = del left; switch left { - case (#node(#B, _, _, _)) { + case (#black(_, _, _)) { balLeft(newLeft, xy, right) }; case _ { - #node(#R, newLeft, xy, right) + #red(newLeft, xy, right) } } }; case (#greater) { let newRight = del right; switch right { - case (#node(#B, _, _, _)) { + case (#black(_, _, _)) { balRight(left, xy, newRight) }; case _ { - #node(#R, left, xy, newRight) + #red(left, xy, newRight) } } }; @@ -876,14 +906,17 @@ module { case (#leaf) { tree }; - case (#node(_, left, xy, right)) { + case (#red(left, xy, right)) { + delNode(left, xy, right) + }; + case (#black(left, xy, right)) { delNode(left, xy, right) } }; }; switch (del(tree)) { - case (#node(#R, left, xy, right)) { - (#node(#B, left, xy, right), y0); + case (#red(left, xy, right)) { + (#black(left, xy, right), y0); }; case other { (other, y0) }; }; diff --git a/test/PersistentOrderedMap.test.mo b/test/PersistentOrderedMap.test.mo index 58e7b41d..0e6794c6 100644 --- a/test/PersistentOrderedMap.test.mo +++ b/test/PersistentOrderedMap.test.mo @@ -31,24 +31,24 @@ func checkMap(rbMap : Map.Map) { }; func blackDepth(node : Map.Map) : Nat { + func checkNode(left : Map.Map, key : Nat, right : Map.Map) : Nat { + checkKey(left, func(x) { x < key }); + checkKey(right, func(x) { x > key }); + let leftBlacks = blackDepth(left); + let rightBlacks = blackDepth(right); + assert (leftBlacks == rightBlacks); + leftBlacks + }; switch node { case (#leaf) 0; - case (#node(color, left, (key, _), right)) { - checkKey(left, func(x) { x < key }); - checkKey(right, func(x) { x > key }); - let leftBlacks = blackDepth(left); - let rightBlacks = blackDepth(right); - assert (leftBlacks == rightBlacks); - switch color { - case (#R) { - assert (not isRed(left)); - assert (not isRed(right)); - leftBlacks - }; - case (#B) { - leftBlacks + 1 - } - } + case (#red(left, (key, _), right)) { + let leftBlacks = checkNode(left, key, right); + assert (not isRed(left)); + assert (not isRed(right)); + leftBlacks + }; + case (#black(left, (key, _), right)) { + checkNode(left, key, right) + 1 } } }; @@ -56,15 +56,18 @@ func blackDepth(node : Map.Map) : Nat { func isRed(node : Map.Map) : Bool { switch node { - case (#leaf) false; - case (#node(color, _, _, _)) color == #R + case (#red(_, _, _)) true; + case _ false } }; func checkKey(node : Map.Map, isValid : Nat -> Bool) { switch node { case (#leaf) {}; - case (#node(_, _, (key, _), _)) { + case (#red( _, (key, _), _)) { + assert (isValid(key)) + }; + case (#black( _, (key, _), _)) { assert (isValid(key)) } }