Skip to content

Commit

Permalink
OrderedSet: optimize intersect
Browse files Browse the repository at this point in the history
  • Loading branch information
GoPavel committed Nov 8, 2024
1 parent d745f59 commit a104c88
Show file tree
Hide file tree
Showing 2 changed files with 83 additions and 31 deletions.
93 changes: 66 additions & 27 deletions src/OrderedSet.mo
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,10 @@
/// * Stefan Kahrs, "Red-black trees with types", Journal of Functional Programming, 11(4): 425-432 (2001), [version 1 in web appendix](http://www.cs.ukc.ac.uk/people/staff/smk/redblack/rb.html).

import Debug "Debug";
import Buffer "Buffer";
import I "Iter";
import List "List";
import Nat "Nat";
import Option "Option";
import O "Order";

module {
Expand Down Expand Up @@ -269,32 +269,26 @@ module {
/// ```
///
/// Runtime: `O(m * log(n))`.
/// Space: `O(1)`, retained memory plus garbage, see the note below.
/// Space: `O(m)`, retained memory plus garbage, see the note below.
/// where `m` and `n` denote the number of elements in the sets, and `m <= n`.
///
/// Note: Creates `O(log(n))` temporary objects that will be collected as garbage.
/// Note: Creates `O(n * log(n))` temporary objects that will be collected as garbage.
public func intersect(s1 : Set<T>, s2 : Set<T>) : Set<T> {
if (size(s1) < size(s2)) {
foldLeft(s1, empty(),
func (elem : T, acc : Set<T>) : Set<T> {
if (Internal.contains(s2.root, compare, elem)) {
Internal.put(acc, compare, elem)
} else {
acc
}
let elems = Buffer.Buffer<T>(Nat.min(Nat.min(s1.size, s2.size), 100));
if (s1.size < s2.size) {
Internal.iterate(s1.root, func (x: T) {
if (Internal.contains(s2.root, compare, x)) {
elems.add(x)
}
)
});
} else {
foldLeft(s2, empty(),
func (elem : T, acc : Set<T>) : Set<T> {
if (Internal.contains(s1.root, compare, elem)) {
Internal.put(acc, compare, elem)
} else {
acc
}
Internal.iterate(s2.root, func (x: T) {
if (Internal.contains(s1.root, compare, x)) {
elems.add(x)
}
)
}
});
};
{ root = Internal.buildFromSorted(elems); size = elems.size() }
};

/// [Set difference](https://en.wikipedia.org/wiki/Difference_(set_theory)).
Expand All @@ -315,17 +309,20 @@ module {
/// ```
///
/// Runtime: `O(m * log(n))`.
/// Space: `O(1)`, retained memory plus garbage, see the note below.
/// Space: `O(m)`, retained memory plus garbage, see the note below.
/// where `m` and `n` denote the number of elements in the sets, and `m <= n`.
///
/// Note: Creates `O(log(n))` temporary objects that will be collected as garbage.
/// Note: Creates `O(m * log(n))` temporary objects that will be collected as garbage.
public func diff(s1 : Set<T>, s2 : Set<T>) : Set<T> {
if (size(s1) < size(s2)) {
foldLeft(s1, empty(),
func (elem : T, acc : Set<T>) : Set<T> {
if (Internal.contains(s2.root, compare, elem)) { acc } else { Internal.put(acc, compare, elem) }
let elems = Buffer.Buffer<T>(Nat.min(s1.size, 100));
Internal.iterate(s1.root, func (x : T) {
if (not Internal.contains(s2.root, compare, x)) {
elems.add(x)
}
}
)
);
{ root = Internal.buildFromSorted(elems); size = elems.size() }
}
else {
foldLeft(s2, s1,
Expand Down Expand Up @@ -782,6 +779,44 @@ module {
}
};

public func iterate<V>(m : Tree<V>, f : V -> ()) {
switch m {
case (#leaf) { };
case (#black(l, v, r)) { iterate(l, f); f(v); iterate(r, f) };
case (#red(l, v, r)) { iterate(l, f); f(v); iterate(r, f) }
}
};

// build tree from elements arr[l]..arr[r-1]
public func buildFromSorted<V>(buf : Buffer.Buffer<V>) : Tree<V> {
var maxDepth = 0;
var maxSize = 1;
while (maxSize < buf.size()) {
maxDepth += 1;
maxSize += maxSize + 1;
};
maxDepth := if (maxDepth == 0) {1} else {maxDepth}; // keep root black for 1 element tree
func buildFromSortedHelper(l : Nat, r : Nat, depth : Nat) : Tree<V> {
if (l + 1 == r) {
if (depth == maxDepth) {
return #red(#leaf, buf.get(l), #leaf);
} else {
return #black(#leaf, buf.get(l), #leaf);
}
};
if (l >= r) {
return #leaf;
};
let m = (l + r) / 2;
return #black(
buildFromSortedHelper(l, m, depth+1),
buf.get(m),
buildFromSortedHelper(m+1, r, depth+1)
)
};
buildFromSortedHelper(0, buf.size(), 0);
};

type IterRep<T> = List.List<{ #tr : Tree<T>; #x : T }>;

type SetTraverser<T> = (Tree<T>, T, Tree<T>, IterRep<T>) -> IterRep<T>;
Expand Down Expand Up @@ -1113,6 +1148,10 @@ module {

/// Test helpers
public module SetDebug {
public func buildFromSorted<T>(a : [T]) : Set<T> {
{ root = Internal.buildFromSorted(Buffer.fromArray<T>(a)); size = a.size()}
};

// check binary search tree order of elements and black depth invariant of the RB-tree
public func checkSetInvariants<T>(s : Set<T>, comp : (T, T) -> O.Order) {
ignore blackDepth(s.root, comp)
Expand Down
21 changes: 17 additions & 4 deletions test/OrderedSet.prop.test.mo
Original file line number Diff line number Diff line change
Expand Up @@ -212,10 +212,23 @@ func run_all_props(range: (Nat, Nat), size: Nat, set_samples: Nat, query_samples
}),
]),

prop("search tree invariant", func (s) {
Set.SetDebug.checkSetInvariants<Nat>(s, Nat.compare);
true
}),
suite(("Internal"), [
prop("search tree invariant", func (s) {
Set.SetDebug.checkSetInvariants<Nat>(s, Nat.compare);
true
}),
prop("buildFromSorted makes RB tree", func (s) {
let a = Iter.toArray(natSet.vals(s));
let t = Set.SetDebug.buildFromSorted(a);
Set.SetDebug.checkSetInvariants<Nat>(t, Nat.compare);
true
}),
prop("buildFromSorted(toArray(t)) == t", func (s) {
let a = Iter.toArray(natSet.vals(s));
let t = Set.SetDebug.buildFromSorted(a);
SetMatcher(s).matches(t)
})
]),

suite("mapFilter", [
prop_with_elem("not contains(mapFilter(s, (!=e)), e)", func (s, e) {
Expand Down

0 comments on commit a104c88

Please sign in to comment.