Skip to content

Commit

Permalink
fixup: add tests and docs for Trie methods
Browse files Browse the repository at this point in the history
  • Loading branch information
olynch committed Mar 13, 2024
1 parent a0ae6d1 commit c77920f
Show file tree
Hide file tree
Showing 2 changed files with 169 additions and 36 deletions.
123 changes: 99 additions & 24 deletions src/syntax/Tries.jl
Original file line number Diff line number Diff line change
@@ -1,16 +1,17 @@
module Tries
export Trie, NonEmptyTrie, AbstractTrie, PACKAGE_ROOT, ■, TrieVar, mapwithkey, traversewithkey
export Trie, NonEmptyTrie, AbstractTrie, PACKAGE_ROOT, ■, TrieVar, filtermap, mapwithkey, traversewithkey

using AbstractTrees
using OrderedCollections
using MLStyle
using StructEquality

"""
An internal node of a [`Trie`](@ref). Should not be used outside of this module.
Cannot be empty.
"""
struct Node_{X}
@struct_hash_equal struct Node_{X}
branches::OrderedDict{Symbol, X}
function Node_{X}(branches::OrderedDict{Symbol, X}) where {X}
if length(branches) == 0
Expand All @@ -31,12 +32,16 @@ end
"""
A leaf node of a [`Trie`](@ref). Should not be used outside of this module.
"""
@struct_hash_equal struct Leaf_{A}
    # The single value carried by this leaf.
    value::A
end

abstract type AbstractTrie{A} end

"""
    ==(t1::AbstractTrie, t2::AbstractTrie)

Compare two tries by their unwrapped `content`, so the wrapper type itself does
not take part in the comparison.
"""
Base.:(==)(t1::AbstractTrie, t2::AbstractTrie) = content(t1) == content(t2)

# `hash` must agree with `==`: tries that compare equal through `content` have to
# land in the same bucket when used as `Dict` keys or `Set` elements.
Base.hash(t::AbstractTrie, h::UInt) = hash(content(t), h)

"""
A non-empty trie.
Expand All @@ -60,17 +65,13 @@ content(t::NonEmptyTrie) = getfield(t, :content)
inner = content(t)
if inner isa Node_
Some(inner.branches)
else
nothing
end
end

# Active pattern: `Leaf(v)` matches when the content of `t` is a `Leaf_`, binding
# `v` to the stored value; any other content yields `nothing` (no match).
@active Leaf(t) begin
    c = content(t)
    c isa Leaf_ ? Some(c.value) : nothing
end

Expand All @@ -95,6 +96,13 @@ struct Trie{A} <: AbstractTrie{A}
content::Union{Nothing, NonEmptyTrie{A}}
end

"""
    content(p::Trie)

Unwrap `p` down to the content of its inner `NonEmptyTrie`. Returns `nothing`
when the `content` field of `p` is `nothing`.
"""
function content(p::Trie)
    inner = getfield(p, :content)
    return isnothing(inner) ? nothing : content(inner)
end

# Active pattern: `Empty()` matches exactly when `content(t)` is `nothing`.
@active Empty(t) begin
    content(t) === nothing
end
Expand All @@ -106,21 +114,10 @@ end
c = getfield(t, :content)
if !isnothing(c)
Some(c)
else
nothing
end
end
end

function content(p::Trie)
unwrapped = getfield(p, :content)
if !isnothing(unwrapped)
content(unwrapped)
else
nothing
end
end

"""
Construct a new Trie node.
"""
Expand Down Expand Up @@ -187,19 +184,27 @@ function Base.filter(f, t::AbstractTrie{A}) where {A}
end
end

"""
filtermap(f, return_type::Type, t::AbstractTrie)
Map the function `f : eltype(t) -> Union{Some{return_type}, Nothing}` over the
trie `t` to produce a trie of type `return_type`, filtering out the elements of
`t` on which `f` returns `nothing`. We pass `return_type` explicitly so that in
the case `t` is the empty trie this doesn't return `Trie{Any}`.
"""
function filtermap(f, return_type::Type, t::AbstractTrie)
@match t begin
Leaf(v) => begin
v′ = f(v)
if !isnothing(v′)
leaf(v′)
@match f(v) begin
Some(v′) => leaf(v′)
nothing => Trie{return_type}()
end
end
Node(bs) => begin
bs′ = OrderedDict{Symbol, NonEmptyTrie{return_type}}()
for (n, s) in bs
@match filtermap(f, return_type, s) begin
NonEmpty(net) => (bs′[n] = s′)
NonEmpty(net) => (bs′[n] = net)
Empty() => nothing
end
end
Expand All @@ -209,18 +214,34 @@ function filtermap(f, return_type::Type, t::AbstractTrie)
end
end

"""
zipwith(f, t1::AbstractTrie, t2::AbstractTrie)
Produces a new trie whose leaf node at a path `p` is given by `f(t1[p], t2[p])`.
Throws an error if `t1` and `t2` are not of the same shape: i.e. they don't
have the exact same set of paths.
"""
function zipwith(f, t1::AbstractTrie{A1}, t2::AbstractTrie{A2}) where {A1, A2}
    @match (t1, t2) begin
        (Leaf(v1), Leaf(v2)) => leaf(f(v1, v2))
        (Node(bs1), Node(bs2)) => begin
            keys(bs1) == keys(bs2) || error("cannot zip two tries not of the same shape")
            # Fetch the partner branch by name instead of relying on both
            # dictionaries iterating in the same order; the result follows the
            # branch order of `t1`.
            node(OrderedDict(n => zipwith(f, s1, bs2[n]) for (n, s1) in bs1))
        end
        # Written as `Empty()` like the other matches in this file, and returning
        # an instance rather than the bare `Trie{...}` type.
        (Empty(), Empty()) => Trie{Tuple{A1, A2}}()
        _ => error("cannot zip two tries not of the same shape")
    end
end

"""
zip(t1, t2)
Produces a new trie whose leaf node at a path `p` is given by `(t1[p], t2[p])`.
Throws an error if `t1` and `t2` are not of the same shape: i.e. they don't
have the exact same set of paths.
"""
# `tuple(a, b)` builds `(a, b)`, so this pairs up the leaves of the two tries.
Base.zip(t1::AbstractTrie, t2::AbstractTrie) = zipwith(tuple, t1, t2)

# Indexing a trie by a symbol is forwarded to property access: `p[:a]` is `p.a`.
function Base.getindex(p::Trie, n::Symbol)
    return getproperty(p, n)
end
Expand Down Expand Up @@ -258,6 +279,17 @@ Base.keys(t) = Base.propertynames(t)
# The value type of a trie is its `A` parameter, available from either an
# instance or the type itself.
Base.valtype(::AbstractTrie{A}) where {A} = A
Base.valtype(::Type{<:AbstractTrie{A}}) where {A} = A

"""
map(f, return_type::Type, t::AbstractTrie)
Produce a new trie of the same shape as `t` where the value at a path `p` is
given by `f(t[p])`. We pass in the return type explicitly so that in the case
that `t` is empty we don't get `Trie{Any}`.
There is a variant defined later where `return_type` is not passed in, and it
tries to use type inference from the Julia compiler to infer `return_type`: use
of this should be discouraged.
"""
function Base.map(f, return_type::Type, t::AbstractTrie)
@match t begin
Leaf(v) => leaf(f(v))
Expand All @@ -270,14 +302,27 @@ function Base.map(f, return_type::Type, t::AbstractTrie)
end
end

"""
map(f, t::AbstractTrie)
Variant of `map(f, return_type, t::AbstractTrie)` which attempts to infer the
return type of `f`.
"""
function Base.map(f, t::AbstractTrie{A}) where {A}
    # Ask the compiler what `f` returns for an argument of type `A`, then defer to
    # the method that takes the return type explicitly.
    return map(f, Core.Compiler.return_type(f, Tuple{A}), t)
end

"""
flatten(t::AbstractTrie{Trie{A}})
The monad operation for Tries. Works on NonEmptyTries and Tries.
"""
function flatten(t::AbstractTrie{Trie{A}}) where {A}
@match t begin
Leaf(v) => v
# Note that if flatten(v) is empty, the `node` constructor will
# automatically remove it from the built trie.
Node(bs) => node(n => flatten(v) for (n, v) in bs)
Empty() => Trie{A}()
end
Expand Down Expand Up @@ -319,6 +364,12 @@ function Base.:(*)(v1::TrieVar, v2::TrieVar)
end
end

"""
mapwithkey(f, return_type::Type, t::AbstractTrie)
Constructs a new trie with the same shape as `t` where the value at the path
`p` is `f(p, t[p])`.
"""
function mapwithkey(f, return_type::Type, t::AbstractTrie; prefix=PACKAGE_ROOT)
@match t begin
Leaf(v) => leaf(f(prefix, v))
Expand All @@ -328,6 +379,12 @@ function mapwithkey(f, return_type::Type, t::AbstractTrie; prefix=PACKAGE_ROOT)
end
end

"""
traversewithkey(f, t::AbstractTrie; prefix=PACKAGE_ROOT)
Similar to [`mapwithkey`](@ref) but just evaluates `f` for its side effects
instead of constructing a new trie.
"""
function traversewithkey(f, t::AbstractTrie; prefix=PACKAGE_ROOT)
@match t begin
Leaf(v) => (f(prefix, v); nothing)
Expand All @@ -340,6 +397,16 @@ function traversewithkey(f, t::AbstractTrie; prefix=PACKAGE_ROOT)
end
end

"""
fold(emptycase::A, leafcase, nodecase, t::AbstractTrie)::A
Fold over `t` to produce a single value.
Args:
- `emptycase::A`
- `leafcase::eltype(t) -> A`
- `nodecase::OrderedDict{Symbol, A} -> A`
"""
function fold(emptycase::A, leafcase, nodecase, t::AbstractTrie)::A where {A}
@match t begin
Empty() => emptycase
Expand All @@ -348,8 +415,16 @@ function fold(emptycase::A, leafcase, nodecase, t::AbstractTrie)::A where {A}
end
end

"""
all(f, t::AbstractTrie)
Checks if `f` returns `true` when applied to all of the elements of `t`.
Args:
- `f::eltype(t) -> Bool`
"""
function Base.all(f, t::AbstractTrie)
    # `fold` takes (emptycase, leafcase, nodecase, t): the empty trie yields `true`,
    # a leaf yields `f(value)`, and a node holds when every branch result holds.
    return fold(true, f, branch_results -> all(values(branch_results)), t)
end

# precondition: the union of the keys in t1 and t2 is prefix-free
Expand Down
82 changes: 70 additions & 12 deletions test/syntax/Tries.jl
Original file line number Diff line number Diff line change
@@ -1,27 +1,85 @@
module TestTries

using GATlab.Syntax.Tries
import .Tries: node, leaf, Node, Leaf, Empty, NonEmpty, zipwith, flatten, fold
using Test
using MLStyle

# Shared fixture: a two-level trie with leaves at a, b.a and b.c.
t1 = node(:a => leaf(1), :b => node(:a => leaf(2), :c => leaf(3)))

# Property access and dereferencing.
@test t1.a isa AbstractTrie
@test_throws Tries.TrieDerefError t1[]
@test t1.a[] == 1
@test t1.b.a isa NonEmptyTrie
@test t1.b.a[] == 2

# Display.
@test sprint(show, t1.a) == "leaf(1)::NonEmptyTrie{Int64}"
@test sprint(show, t1.b) == "NonEmptyTrie{Int64}\n├─ :a ⇒ 2\n└─ :c ⇒ 3\n"

# Trie variables rooted at `■`.
@test ■ == PACKAGE_ROOT
@test ■.a isa TrieVar
@test ■.a.b isa TrieVar
@test_throws Tries.TrieVarNotFound t1[■]

@test t1[■.a] == 1

@test t1[■.b.c] == 3

@test filter(x -> x % 2 == 0, t1) == node(:b => node(:a => leaf(2)))

# Returns `Some(root)` for perfect squares and `nothing` otherwise, without using
# exceptions for control flow.
function int_sqrt(x)
    x < 0 && return nothing
    root = isqrt(x)
    return root * root == x ? Some(root) : nothing
end

@test filtermap(int_sqrt, Int, t1) == node(:a => leaf(1))

# A `Trie` compares equal to the `NonEmptyTrie` it wraps.
@test t1 == @match t1 begin
    NonEmpty(net1) => net1
end

@test zipwith(+, t1, t1) == node(:a => leaf(2), :b => node(:a => leaf(4), :c => leaf(6)))
@test zip(t1, t1) ==
    node(:a => leaf((1, 1)), :b => node(:a => leaf((2, 2)), :c => leaf((3, 3))))

@test flatten(leaf(t1)) == t1
@test flatten(leaf(leaf(1))) == leaf(1)
@test flatten(node(:f => leaf(leaf(1)), :g => leaf(t1))) ==
    node(
        :f => leaf(1),
        :g => node(
            :a => leaf(1),
            :b => node(
                :a => leaf(2),
                :c => leaf(3),
            ),
        ),
    )

@test mapwithkey((k, _) -> k, TrieVar, t1) ==
    node(:a => leaf(■.a), :b => node(:a => leaf(■.b.a), :c => leaf(■.b.c)))

# Named `visited` rather than `keys` so the module does not shadow `Base.keys`.
visited = TrieVar[]

traversewithkey((k, _) -> push!(visited, k), t1)

@test visited == [■.a, ■.b.a, ■.b.c]

@test fold(0, identity, d -> sum(values(d)), t1) == 6
@test fold(0, identity, d -> sum(values(d)), Trie{Int}()) == 0

@test all(iseven, t1) == false
@test all(iseven, filter(iseven, t1))

@test merge(t1, node(:z => leaf(4))) ==
    node(:a => leaf(1), :b => node(:a => leaf(2), :c => leaf(3)), :z => leaf(4))

@test Trie(■.a => 1, ■.b.a => 2, ■.b.c => 3) == t1

# A node whose only branch is empty is expected to equal the empty trie.
t2 = node(:a => Trie{Int}())

@test t2 == Trie{Int}()

end

0 comments on commit c77920f

Please sign in to comment.