diff --git a/.gitignore b/.gitignore index 29126e4..d93db51 100644 --- a/.gitignore +++ b/.gitignore @@ -22,3 +22,6 @@ docs/site/ # committed for packages, but should be committed for applications that require a static # environment. Manifest.toml + +# vs code environment +.vscode \ No newline at end of file diff --git a/Project.toml b/Project.toml index c973f66..9e05105 100644 --- a/Project.toml +++ b/Project.toml @@ -3,12 +3,13 @@ uuid = "7a57a42e-76ec-4ea3-a279-07e840d6d9cf" keywords = ["probablistic programming"] license = "MIT" desc = "Common interfaces for probabilistic programming" -version = "0.4" +version = "0.4.0" [deps] AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001" DensityInterface = "b429d917-457f-4dbc-8f4c-0cc954292b1d" Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46" +SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" [compat] AbstractMCMC = "2, 3" diff --git a/src/AbstractPPL.jl b/src/AbstractPPL.jl index ee9dfab..ad8ebd0 100644 --- a/src/AbstractPPL.jl +++ b/src/AbstractPPL.jl @@ -7,7 +7,6 @@ export VarName, getsym, getlens, inspace, subsumes, varname, vsym, @varname, @vs # Abstract model functions export AbstractProbabilisticProgram, condition, decondition, logdensityof, densityof - # Abstract traces export AbstractModelTrace @@ -17,4 +16,9 @@ include("abstractmodeltrace.jl") include("abstractprobprog.jl") include("deprecations.jl") +# GraphInfo +module GraphPPL + include("graphinfo.jl") +end + end # module diff --git a/src/graphinfo.jl b/src/graphinfo.jl new file mode 100644 index 0000000..8019d7e --- /dev/null +++ b/src/graphinfo.jl @@ -0,0 +1,255 @@ +using AbstractPPL +import Base.getindex +using SparseArrays +using Setfield +using Setfield: PropertyLens, get + +""" + GraphInfo + +Record the state of the model as a struct of NamedTuples, all +sharing the same key values, namely, those of the model parameters. +`value` should store the initial/current value of the parameters. +`input` stores a tuple of inputs for a given node. `eval` are the +anonymous functions associated with each node. These might typically +be either deterministic values or some distribution, but could an +arbitrary julia program. `kind` is a tuple of symbols indicating +whether the node is a logical or stochastic node. Additionally, the +adjacency matrix and topologically ordered vertex list and stored. + +GraphInfo is instantiated using the `Model` constctor. +""" + +struct GraphInfo{T} <: AbstractModelTrace + input::NamedTuple{T} + value::NamedTuple{T} + eval::NamedTuple{T} + kind::NamedTuple{T} + A::SparseMatrixCSC + sorted_vertices::Vector{Symbol} +end + +""" + Model(;kwargs...) + +`Model` type constructor that takes in named arguments for +nodes and returns a `Model`. Nodes are pairs of variable names +and tuples containing default value, an eval function +and node type. The inputs of each node are inferred from +their anonymous functions. The returned object has a type +GraphInfo{(sorted_vertices...)}. + +# Examples +```jl-doctest +julia> using AbstractPPL + +julia> Model( + s2 = (0.0, () -> InverseGamma(2.0,3.0), :Stochastic), + μ = (1.0, () -> 1.0, :Logical), + y = (0.0, (μ, s2) -> MvNormal(μ, sqrt(s2)), :Stochastic) + ) +Nodes: +μ = (value = 1.0, input = (), eval = var"#6#9"(), kind = :Logical) +s2 = (value = 0.0, input = (), eval = var"#5#8"(), kind = :Stochastic) +y = (value = 0.0, input = (:μ, :s2), eval = var"#7#10"(), kind = :Stochastic) +``` +""" + +struct Model{T} <: AbstractProbabilisticProgram + g::GraphInfo{T} +end + +function Model(;kwargs...) + for (i, node) in enumerate(values(kwargs)) + @assert typeof(node) <: Tuple{Union{Array{Float64}, Float64}, Function, Symbol} "Check input order for node $(i) matches Tuple(value, function, kind)" + end + vals = getvals(NamedTuple(kwargs)) + args = [argnames(f) for f in vals[2]] + A, sorted_vertices = dag(NamedTuple{keys(kwargs)}(args)) + modelinputs = NamedTuple{Tuple(sorted_vertices)}.([Tuple.(args), vals...]) + Model(GraphInfo(modelinputs..., A, sorted_vertices)) +end + +""" + dag(inputs) + +Function taking in a NamedTuple containing the inputs to each node +and returns the implied adjacency matrix and topologically ordered +vertex list. +""" +function dag(inputs) + input_names = Symbol[keys(inputs)...] + A = adjacency_matrix(inputs) + sorted_vertices = topological_sort_by_dfs(A) + sorted_A = permute(A, collect(1:length(inputs)), sorted_vertices) + sorted_A, input_names[sorted_vertices] +end + +""" + getvals(nt::NamedTuple{T}) + +Takes in the arguments to Model(;kwargs...) as a NamedTuple and +reorders into a tuple of tuples each containing either of value, +input, eval and kind, as required by the GraphInfo type. +""" +@generated function getvals(nt::NamedTuple{T}) where T + values = [:(nt[$i][$j]) for i in 1:length(T), j in 1:3] + m = [:($(values[:,i]...), ) for i in 1:3] + return Expr(:tuple, m...) # :($(m...),) +end + +""" + argnames(f::Function) + +Returns a Vector{Symbol} of the inputs to an anonymous function `f`. +""" +argnames(f::Function) = Base.method_argnames(first(methods(f)))[2:end] + +""" + adjacency_matrix(inputs) + +For a NamedTuple{T} with vertices `T` paired with tuples of input nodes, +`adjacency_matrix` constructs the adjacency matrix using the order +of variables given by `T`. + +# Examples +```jl-doctest +julia> inputs = (a = (), b = (), c = (:a, :b)) +(a = (), b = (), c = (:a, :b)) + +julia> AbstractPPL.adjacency_matrix(inputs) +3×3 SparseMatrixCSC{Float64, Int64} with 2 stored entries: + ⋅ ⋅ ⋅ + ⋅ ⋅ ⋅ + 1.0 1.0 ⋅ +``` +""" +function adjacency_matrix(inputs::NamedTuple{nodes}) where {nodes} + N = length(inputs) + col_inds = NamedTuple{nodes}(ntuple(identity, N)) + A = spzeros(Bool, N, N) + for (row, node) in enumerate(nodes) + for input in inputs[node] + if input ∉ nodes + error("Parent node of $(input) not found in node set: $(nodes)") + end + col = col_inds[input] + A[row, col] = true + end + end + return A +end + +function outneighbors(A::SparseMatrixCSC, u::T) where T <: Int + #adapted from Graph.jl https://github.com/JuliaGraphs/Graphs.jl/blob/06669054ed470bcfe4b2ad90ed974f2e65c84bb6/src/interface.jl#L302 + inds, _ = findnz(A[:, u]) + inds +end + +function topological_sort_by_dfs(A) + # lifted from Graphs.jl https://github.com/JuliaGraphs/Graphs.jl/blob/06669054ed470bcfe4b2ad90ed974f2e65c84bb6/src/traversals/dfs.jl#L44 + # Depth first search implementation optimized from http://www.cs.nott.ac.uk/~psznza/G5BADS03/graphs2.pdf + n_verts = size(A)[1] + vcolor = zeros(UInt8, n_verts) + verts = Vector{Int64}() + for v in 1:n_verts + vcolor[v] != 0 && continue + S = Vector{Int64}([v]) + vcolor[v] = 1 + while !isempty(S) + u = S[end] + w = 0 + for n in outneighbors(A, u) + if vcolor[n] == 1 + error("The input graph contains at least one loop.") # TODO 0.7 should we use a different error? + elseif vcolor[n] == 0 + w = n + break + end + end + if w != 0 + vcolor[w] = 1 + push!(S, w) + else + vcolor[u] = 2 + push!(verts, u) + pop!(S) + end + end + end + return reverse(verts) +end + +""" + Base.getindex(m::Model, vn::VarName{p}) + +Index a Model with a `VarName{p}` lens. Retrieves the `value``, `input`, +`eval` and `kind` for node `p`. + +# Examples + +```jl-doctest +julia> using AbstractPPL + +julia> m = Model( s2 = (0.0, () -> InverseGamma(2.0,3.0), :Stochastic), + μ = (1.0, () -> 1.0, :Logical), + y = (0.0, (μ, s2) -> MvNormal(μ, sqrt(s2)), :Stochastic)) +(s2 = Symbol[], μ = Symbol[], y = [:μ, :s2]) +Nodes: +μ = (value = 0.0, input = (), eval = var"#43#46"(), kind = :Stochastic) +s2 = (value = 1.0, input = (), eval = var"#44#47"(), kind = :Logical) +y = (value = 0.0, input = (:μ, :s2), eval = var"#45#48"(), kind = :Stochastic) + + +julia> m[@varname y] +(value = 0.0, input = (:μ, :s2), eval = var"#45#48"(), kind = :Stochastic) +``` +""" +@generated function Base.getindex(g::GraphInfo, vn::VarName{p}) where {p} + fns = fieldnames(GraphInfo)[1:4] + name_lens = Setfield.PropertyLens{p}() + field_lenses = [Setfield.PropertyLens{f}() for f in fns] + values = [:(get(g, Setfield.compose($l, $name_lens, getlens(vn)))) for l in field_lenses] + return :(NamedTuple{$(fns)}(($(values...),))) +end + +function Base.getindex(m::Model, vn::VarName) + return m.g[vn] +end + +function Base.show(io::IO, m::Model) + print(io, "Nodes: \n") + for node in nodes(m) + print(io, "$node = ", m[VarName{node}()], "\n") + end +end + + +function Base.iterate(m::Model, state=1) + state > length(nodes(m)) ? nothing : (m[VarName{m.g.sorted_vertices[state]}()], state+1) +end + +Base.eltype(m::Model) = NamedTuple{fieldnames(GraphInfo)[1:4]} +Base.IteratorEltype(m::Model) = HasEltype() + +Base.keys(m::Model) = (VarName{n}() for n in m.g.sorted_vertices) +Base.values(m::Model) = Base.Generator(identity, m) +Base.length(m::Model) = length(nodes(m)) +Base.keytype(m::Model) = eltype(keys(m)) +Base.valtype(m::Model) = eltype(m) + + +""" + dag(m::Model) + +Returns the adjacency matrix of the model as a SparseArray. +""" +get_dag(m::Model) = m.g.A + +""" + nodes(m::Model) + +Returns a `Vector{Symbol}` containing the sorted vertices +of the DAG. +""" +nodes(m::Model) = m.g.sorted_vertices \ No newline at end of file diff --git a/test/Project.toml b/test/Project.toml index f0592fe..3c3efc0 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -1,6 +1,7 @@ [deps] Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46" +SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [compat] diff --git a/test/graphinfo.jl b/test/graphinfo.jl new file mode 100644 index 0000000..4e4ba18 --- /dev/null +++ b/test/graphinfo.jl @@ -0,0 +1,64 @@ +using AbstractPPL +import AbstractPPL.GraphPPL: GraphInfo, Model, get_dag +using SparseArrays +using Test +## Example taken from Mamba +line = Dict{Symbol, Any}( + :x => [1, 2, 3, 4, 5], + :y => [1, 3, 3, 3, 5] +) +line[:xmat] = [ones(5) line[:x]] + +# just making it a NamedTuple so that the values can be tested later. Constructor should be used as Model(;kwargs...). +model = ( + β = (zeros(2), () -> MvNormal(2, sqrt(1000)), :Stochastic), + xmat = (line[:xmat], () -> line[:xmat], :Logical), + s2 = (0.0, () -> InverseGamma(2.0,3.0), :Stochastic), + μ = (zeros(5), (xmat, β) -> xmat * β, :Logical), + y = (zeros(5), (μ, s2) -> MvNormal(μ, sqrt(s2)), :Stochastic) +) + +# construct the model! +m = Model(; zip(keys(model), values(model))...) # uses Model(; kwargs...) constructor + +# test the type of the model is correct +@test typeof(m) <: Model +@test typeof(m) == Model{(:s2, :xmat, :β, :μ, :y)} +@test typeof(m.g) <: GraphInfo <: AbstractModelTrace +@test typeof(m.g) == GraphInfo{(:s2, :xmat, :β, :μ, :y)} + +# test the dag is correct +A = sparse([0 0 0 0 0; 0 0 0 0 0; 0 0 0 0 0; 0 1 1 0 0; 1 0 0 1 0]) +@test get_dag(m) == A + +@test length(m) == 5 +@test eltype(m) == valtype(m) + +# check the values from the NamedTuple match the values in the fields of GraphInfo +vals = AbstractPPL.GraphPPL.getvals(model) +for (i, field) in enumerate([:value, :eval, :kind]) + @test eval( :( values(m.g.$field) == vals[$i] ) ) +end + +for node in m + @test typeof(node) <: NamedTuple{fieldnames(GraphInfo)[1:4]} +end + +# test the right inputs have been inferred +@test m.g.input == (s2 = (), xmat = (), β = (), μ = (:xmat, :β), y = (:μ, :s2)) + +# test keys are VarNames +for key in keys(m) + @test typeof(key) <: VarName +end + +# test Model constructor for model with single parent node +single_parent_m = Model(μ = (1.0, () -> 3, :Logical), y = (1.0, (μ) -> MvNormal(μ, sqrt(1)), :Stochastic)) +@test typeof(single_parent_m) == Model{(:μ, :y)} +@test typeof(single_parent_m.g) == GraphInfo{(:μ, :y)} + +# test ErrorException for parent node not found +@test_throws ErrorException Model( μ = (1.0, (β) -> 3, :Logical), y = (1.0, (μ) -> MvNormal(μ, sqrt(1)), :Stochastic)) + +# test AssertionError thrown for kwargs with the wrong order of inputs +@test_throws AssertionError Model( μ = ((β) -> 3, 1.0, :Logical), y = (1.0, (μ) -> MvNormal(μ, sqrt(1)), :Stochastic)) \ No newline at end of file diff --git a/test/runtests.jl b/test/runtests.jl index 9efca34..8af6d7b 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -12,7 +12,7 @@ using Test @testset "AbstractPPL.jl" begin include("deprecations.jl") - + include("graphinfo.jl") @testset "doctests" begin DocMeta.setdocmeta!( AbstractPPL, @@ -22,5 +22,4 @@ using Test ) doctest(AbstractPPL; manual=false) end -end - +end \ No newline at end of file