From 8cd2785bd2d58c67c1d14c20e9677b5c8c957051 Mon Sep 17 00:00:00 2001 From: Owen Lynch Date: Wed, 13 Mar 2024 11:51:19 -0400 Subject: [PATCH] fixup: add tests and docs for Trie methods --- src/syntax/Tries.jl | 123 ++++++++++++++++++++++++++++++++++--------- test/syntax/Tries.jl | 82 ++++++++++++++++++++++++----- 2 files changed, 169 insertions(+), 36 deletions(-) diff --git a/src/syntax/Tries.jl b/src/syntax/Tries.jl index e154ebd4..64200780 100644 --- a/src/syntax/Tries.jl +++ b/src/syntax/Tries.jl @@ -1,16 +1,17 @@ module Tries -export Trie, NonEmptyTrie, AbstractTrie, PACKAGE_ROOT, ■, TrieVar, mapwithkey, traversewithkey +export Trie, NonEmptyTrie, AbstractTrie, PACKAGE_ROOT, ■, TrieVar, filtermap, mapwithkey, traversewithkey using AbstractTrees using OrderedCollections using MLStyle +using StructEquality """ An internal node of a [`Trie`](@ref). Should not be used outside of this module. Cannot be empty. """ -struct Node_{X} +@struct_hash_equal struct Node_{X} branches::OrderedDict{Symbol, X} function Node_{X}(branches::OrderedDict{Symbol, X}) where {X} if length(branches) == 0 @@ -31,12 +32,16 @@ end """ A leaf node of a [`Trie`](@ref). Should not be used outside of this module. """ -struct Leaf_{A} +@struct_hash_equal struct Leaf_{A} value::A end abstract type AbstractTrie{A} end +function Base.:(==)(t1::AbstractTrie, t2::AbstractTrie) + content(t1) == content(t2) +end + """ A non-empty trie. @@ -60,8 +65,6 @@ content(t::NonEmptyTrie) = getfield(t, :content) inner = content(t) if inner isa Node_ Some(inner.branches) - else - nothing end end @@ -69,8 +72,6 @@ end inner = content(t) if inner isa Leaf_ Some(inner.value) - else - nothing end end @@ -95,6 +96,13 @@ struct Trie{A} <: AbstractTrie{A} content::Union{Nothing, NonEmptyTrie{A}} end +function content(p::Trie) + unwrapped = getfield(p, :content) + if !isnothing(unwrapped) + content(unwrapped) + end +end + @active Empty(t) begin isnothing(content(t)) end @@ -106,21 +114,10 @@ end c = getfield(t, :content) if !isnothing(c) Some(c) - else - nothing end end end -function content(p::Trie) - unwrapped = getfield(p, :content) - if !isnothing(unwrapped) - content(unwrapped) - else - nothing - end -end - """ Construct a new Trie node. """ @@ -187,19 +184,27 @@ function Base.filter(f, t::AbstractTrie{A}) where {A} end end +""" + filtermap(f, return_type::Type, t::AbstractTrie) + +Map the function `f : eltype(t) -> Union{Some{return_type}, Nothing}` over the +trie `t` to produce a trie of type `return_type`, filtering out the elements of +`t` on which `f` returns `nothing`. We pass `return_type` explicitly so that in +the case `t` is the empty trie this doesn't return `Trie{Any}`. +""" function filtermap(f, return_type::Type, t::AbstractTrie) @match t begin Leaf(v) => begin - v′ = f(v) - if !isnothing(v′) - leaf(v′) + @match f(v) begin + Some(v′) => leaf(v′) + nothing => Trie{return_type}() end end Node(bs) => begin bs′ = OrderedDict{Symbol, NonEmptyTrie{return_type}}() for (n, s) in bs @match filtermap(f, return_type, s) begin - NonEmpty(net) => (bs′[n] = s′) + NonEmpty(net) => (bs′[n] = net) Empty() => nothing end end @@ -209,18 +214,34 @@ function filtermap(f, return_type::Type, t::AbstractTrie) end end +""" + zipwith(f, t1::AbstractTrie, t2::AbstractTrie) + +Produces a new trie whose leaf node at a path `p` is given by `f(t1[p], t2[p])`. + +Throws an error if `t1` and `t2` are not of the same shape: i.e. they don't +have the exact same set of paths. +""" function zipwith(f, t1::AbstractTrie{A1}, t2::AbstractTrie{A2}) where {A1, A2} @match (t1, t2) begin (Leaf(v1), Leaf(v2)) => leaf(f(v1, v2)) (Node(bs1), Node(bs2)) => begin keys(bs1) == keys(bs2) || error("cannot zip two tries not of the same shape") - node(OrderedDict(n => zip(s1, s2) for ((n, s1), (_, s2)) in zip(bs1, bs2))) + node(OrderedDict(n => zipwith(f, s1, s2) for ((n, s1), (_, s2)) in zip(bs1, bs2))) end (Empty, Empty) => Trie{Tuple{A1, A2}} _ => error("cannot zip two tries not of the same shape") end end +""" + zip(t1, t2) + +Produces a new trie whose leaf node at a path `p` is given by `(t1[p], t2[p])`. + +Throws an error if `t1` and `t2` are not of the same shape: i.e. they don't +have the exact same set of paths. +""" Base.zip(t1::AbstractTrie, t2::AbstractTrie) = zipwith((a,b) -> (a,b), t1, t2) Base.getindex(p::Trie, n::Symbol) = getproperty(p, n) @@ -258,6 +279,17 @@ Base.keys(t) = Base.propertynames(t) Base.valtype(t::AbstractTrie{A}) where {A} = A Base.valtype(::Type{<:AbstractTrie{A}}) where {A} = A +""" + map(f, return_type::Type, t::AbstractTrie) + +Produce a new trie of the same shape as `t` where the value at a path `p` is +given by `f(t[p])`. We pass in the return type explicitly so that in the case +that `t` is empty we don't get `Trie{Any}`. + +There is a variant defined later where `return_type` is not passed in, and it +tries to use type inference from the Julia compiler to infer `return_type`: use +of this should be discouraged. +""" function Base.map(f, return_type::Type, t::AbstractTrie) @match t begin Leaf(v) => leaf(f(v)) @@ -270,14 +302,27 @@ function Base.map(f, return_type::Type, t::AbstractTrie) end end +""" + map(f, t::AbstractTrie) + +Variant of `map(f, return_type, t::AbstractTrie)` which attempts to infer the +return type of `f`. +""" function Base.map(f, t::AbstractTrie{A}) where {A} B = Core.Compiler.return_type(f, Tuple{A}) map(f, B, t) end +""" + flatten(t::AbstractTrie{Trie{A}}) + +The monad operation for Tries. Works on NonEmptyTries and Tries. +""" function flatten(t::AbstractTrie{Trie{A}}) where {A} @match t begin Leaf(v) => v + # Note that if flatten(v) is empty, the `node` constructor will + # automatically remove it from the built trie. Node(bs) => node(n => flatten(v) for (n, v) in bs) Empty() => Trie{A}() end @@ -319,6 +364,12 @@ function Base.:(*)(v1::TrieVar, v2::TrieVar) end end +""" + mapwithkey(f, return_type::Type, t::AbstractTrie) + +Constructs a new trie with the same shape as `t` where the value at the path +`p` is `f(p, t[p])`. +""" function mapwithkey(f, return_type::Type, t::AbstractTrie; prefix=PACKAGE_ROOT) @match t begin Leaf(v) => leaf(f(prefix, v)) @@ -328,6 +379,12 @@ function mapwithkey(f, return_type::Type, t::AbstractTrie; prefix=PACKAGE_ROOT) end end +""" + traversewithkey(f, t::AbstractTrie; prefix=PACKAGE_ROOT) + +Similar to [`mapwithkey`](@ref) but just evaluates `f` for its side effects +instead of constructing a new trie. +""" function traversewithkey(f, t::AbstractTrie; prefix=PACKAGE_ROOT) @match t begin Leaf(v) => (f(prefix, v); nothing) @@ -340,6 +397,16 @@ function traversewithkey(f, t::AbstractTrie; prefix=PACKAGE_ROOT) end end +""" + fold(emptycase::A, leafcase, nodecase, t::AbstractTrie)::A + +Fold over `t` to produce a single value. + +Args: +- `emptycase::A` +- `leafcase::eltype(t) -> A` +- `nodecase::OrderedDict{Symbol, A} -> A` +""" function fold(emptycase::A, leafcase, nodecase, t::AbstractTrie)::A where {A} @match t begin Empty() => emptycase @@ -348,8 +415,16 @@ function fold(emptycase::A, leafcase, nodecase, t::AbstractTrie)::A where {A} end end +""" + all(f, t::AbstractTrie) + +Checks if `f` returns `true` when applied to all of the elements of `t`. + +Args: +- `f::eltype(t) -> Bool` +""" function Base.all(f, t::AbstractTrie) - fold(f, d -> all(values(d)), t) + fold(true, f, d -> all(values(d)), t) end # precondition: the union of the keys in t1 and t2 is prefix-free diff --git a/test/syntax/Tries.jl b/test/syntax/Tries.jl index c1728b23..b6cc8d07 100644 --- a/test/syntax/Tries.jl +++ b/test/syntax/Tries.jl @@ -1,27 +1,85 @@ module TestTries using GATlab.Syntax.Tries -import .Tries: node, leaf +import .Tries: node, leaf, Node, Leaf, Empty, NonEmpty, zipwith, flatten, fold using Test -p1 = node(:a => leaf(1), :b => node(:a => leaf(2), :c => leaf(3))) +using MLStyle -@test p1.a isa AbstractTrie -@test_throws Tries.TrieDerefError p1[] -@test p1.a[] == 1 -@test p1.b.a isa NonEmptyTrie -@test p1.b.a[] == 2 +t1 = node(:a => leaf(1), :b => node(:a => leaf(2), :c => leaf(3))) -@test sprint(show, p1.a) == "leaf(1)::NonEmptyTrie{Int64}" -@test sprint(show, p1.b) == "NonEmptyTrie{Int64}\n├─ :a ⇒ 2\n└─ :c ⇒ 3\n" +@test t1.a isa AbstractTrie +@test_throws Tries.TrieDerefError t1[] +@test t1.a[] == 1 +@test t1.b.a isa NonEmptyTrie +@test t1.b.a[] == 2 + +@test sprint(show, t1.a) == "leaf(1)::NonEmptyTrie{Int64}" +@test sprint(show, t1.b) == "NonEmptyTrie{Int64}\n├─ :a ⇒ 2\n└─ :c ⇒ 3\n" @test ■ == PACKAGE_ROOT @test ■.a isa TrieVar @test ■.a.b isa TrieVar -@test_throws Tries.TrieVarNotFound p1[■] +@test_throws Tries.TrieVarNotFound t1[■] + +@test t1[■.a] == 1 + +@test t1[■.b.c] == 3 + +@test filter(x -> x % 2 == 0, t1) == node(:b => node(:a => leaf(2))) + +function int_sqrt(x) + try + Some(Int(sqrt(x))) + catch e + nothing + end +end + +@test filtermap(int_sqrt, Int, t1) == node(:a => leaf(1)) + +@test t1 == @match t1 begin + NonEmpty(net1) => net1 +end + +@test zipwith(+, t1, t1) == node(:a => leaf(2), :b => node(:a => leaf(4), :c => leaf(6))) +@test zip(t1, t1) == node(:a => leaf((1,1)), :b => node(:a => leaf((2,2)), :c => leaf((3,3)))) + + +@test flatten(leaf(t1)) == t1 +@test flatten(leaf(leaf(1))) == leaf(1) +@test flatten(node(:f => leaf(leaf(1)), :g => leaf(t1))) == + node( + :f => leaf(1) + , :g => node( + :a => leaf(1) + , :b => node( + :a => leaf(2) + , :c => leaf(3) + ) + ) + ) + +@test mapwithkey((k, _) -> k, TrieVar, t1) == node(:a => leaf(■.a), :b => node(:a => leaf(■.b.a), :c => leaf(■.b.c))) + +keys = TrieVar[] + +traversewithkey((k, _) -> push!(keys, k), t1) + +@test keys == [■.a, ■.b.a, ■.b.c] + +@test fold(0, identity, d -> sum(values(d)), t1) == 6 +@test fold(0, identity, d -> sum(values(d)), Trie{Int}()) == 0 + +@test all(iseven, t1) == false +@test all(iseven, filter(iseven, t1)) + +@test merge(t1, node(:z => leaf(4))) == node(:a => leaf(1), :b => node(:a => leaf(2), :c => leaf(3)), :z => leaf(4)) + +@test Trie(■.a => 1, ■.b.a => 2, ■.b.c => 3) == t1 -@test p1[■.a] == 1 +g2 = node(:a => Trie{Int}()) -@test p1[■.b.c] == 3 +@test t2 == Trie{Int}() end