Skip to content

Commit

Permalink
DAG Model interface (#47)
Browse files Browse the repository at this point in the history
This is a draft PR introducing a `Model` type that stores and makes use of the model graph. 

The main type introduced here is the `Model` struct which stores the `ModelState` and `DAG`, each of which are their own types. `ModelState` contains information about the node values, dependencies and eval functions and `DAG` contains the graph and topologically ordered vertex list. 

A model can be constructed in the following way: 

```julia
julia> nt = (
               s2 = (0.0, (), () -> InverseGamma(2.0,3.0), :Stochastic), 
               μ = (1.0, (), () -> 1.0, :Logical), 
               y = (0.0, (:μ, :s2), (μ, s2) -> MvNormal(μ, sqrt(s2)), :Stochastic)
           )
(s2 = (0.0, (), var"#33#36"(), :Stochastic), μ = (1.0, (), var"#34#37"(), :Logical), y = (0.0, (:μ, :s2), var"#35#38"(), :Stochastic))

julia> Model(nt)
Nodes: 
μ = (value = 1.0, input = (), eval = var"#16#19"(), kind = :Logical)
s2 = (value = 0.0, input = (), eval = var"#15#18"(), kind = :Stochastic)
y = (value = 0.0, input = (:μ, :s2), eval = var"#17#20"(), kind = :Stochastic)
DAG: 
3×3 SparseMatrixCSC{Float64, Int64} with 2 stored entries:
  ⋅    ⋅    ⋅ 
  ⋅    ⋅    ⋅ 
 1.0  1.0   ⋅ 
```

At present, only functions needed for the constructors are implemented, as well as indexing using `@varname`. I still need to complete the integration with the AbstractPPL api. TODO: 
~~- [ ] `condition`/`decondition`,~~
~~- [ ] `sample`~~
~~- [ ] `logdensityof`~~
- [x] pure functions for ordered dictionary, as outlined in [AbstractPPL](https://github.com/TuringLang/AbstractPPL.jl#property-interface)

Feedback on `Model` structure welcome whilst I implement the remaining features!
  • Loading branch information
PavanChaggar committed Feb 7, 2022
1 parent 9b64dd8 commit 4692526
Show file tree
Hide file tree
Showing 7 changed files with 332 additions and 5 deletions.
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -22,3 +22,6 @@ docs/site/
# committed for packages, but should be committed for applications that require a static
# environment.
Manifest.toml

# vs code environment
.vscode
3 changes: 2 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,13 @@ uuid = "7a57a42e-76ec-4ea3-a279-07e840d6d9cf"
keywords = ["probablistic programming"]
license = "MIT"
desc = "Common interfaces for probabilistic programming"
version = "0.4"
version = "0.4.0"

[deps]
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
DensityInterface = "b429d917-457f-4dbc-8f4c-0cc954292b1d"
Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"

[compat]
AbstractMCMC = "2, 3"
Expand Down
6 changes: 5 additions & 1 deletion src/AbstractPPL.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ export VarName, getsym, getlens, inspace, subsumes, varname, vsym, @varname, @vs
# Abstract model functions
export AbstractProbabilisticProgram, condition, decondition, logdensityof, densityof


# Abstract traces
export AbstractModelTrace

Expand All @@ -17,4 +16,9 @@ include("abstractmodeltrace.jl")
include("abstractprobprog.jl")
include("deprecations.jl")

# Graph-based model representation (`GraphInfo`, `Model`), kept in its own
# submodule; the definitions live in graphinfo.jl.
module GraphPPL
include("graphinfo.jl")
end

end # module
255 changes: 255 additions & 0 deletions src/graphinfo.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,255 @@
using AbstractPPL
import Base.getindex
using SparseArrays
using Setfield
using Setfield: PropertyLens, get

"""
    GraphInfo{T}

Record the state of a model as a struct of `NamedTuple`s that all share the
same keys `T`, namely the names of the model's nodes.

# Fields
- `input`: for each node, the tuple of names of the nodes it takes as input.
- `value`: for each node, its initial/current value.
- `eval`: for each node, the anonymous function associated with it. This might
  typically return a deterministic value or a distribution, but could be an
  arbitrary Julia program.
- `kind`: for each node, a symbol indicating whether it is a logical or a
  stochastic node.
- `A`: the sparse adjacency matrix of the model graph.
- `sorted_vertices`: the node names in topological order.

`GraphInfo` is instantiated using the `Model` constructor.
"""
struct GraphInfo{T} <: AbstractModelTrace
    input::NamedTuple{T}
    value::NamedTuple{T}
    eval::NamedTuple{T}
    kind::NamedTuple{T}
    A::SparseMatrixCSC
    sorted_vertices::Vector{Symbol}
end

"""
    Model(; kwargs...)

Construct a `Model` from named arguments, one per node. Each node is given as
`name = (value, eval, kind)`: a default value, an eval function and the node
kind. The inputs of each node are inferred from the argument names of its
anonymous function. The returned `Model{T}` wraps a `GraphInfo{T}`, where `T`
is the tuple of node names in topological order.

# Examples
```julia
julia> using AbstractPPL

julia> Model(
           s2 = (0.0, () -> InverseGamma(2.0,3.0), :Stochastic),
           μ = (1.0, () -> 1.0, :Logical),
           y = (0.0, (μ, s2) -> MvNormal(μ, sqrt(s2)), :Stochastic)
       )
Nodes:
μ = (value = 1.0, input = (), eval = var"#6#9"(), kind = :Logical)
s2 = (value = 0.0, input = (), eval = var"#5#8"(), kind = :Stochastic)
y = (value = 0.0, input = (:μ, :s2), eval = var"#7#10"(), kind = :Stochastic)
```
"""
struct Model{T} <: AbstractProbabilisticProgram
    g::GraphInfo{T}
end

# Keyword constructor: validates every node tuple, infers each node's inputs
# from the argument names of its eval function, builds the DAG, and stores all
# per-node data keyed by node name in topological order.
function Model(; kwargs...)
    for (i, node) in enumerate(values(kwargs))
        @assert typeof(node) <: Tuple{Union{Array{Float64}, Float64}, Function, Symbol} "Check input order for node $(i) matches Tuple(value, function, kind)"
    end
    node_keys = keys(kwargs)
    # vals == (values, evals, kinds), each in declaration (keyword) order.
    vals = getvals(NamedTuple(kwargs))
    args = [argnames(f) for f in vals[2]]
    A, sorted_vertices = dag(NamedTuple{node_keys}(args))
    sorted_keys = Tuple(sorted_vertices)
    # Label every per-node tuple with the declaration-order names first and only
    # then reorder by name. Attaching the sorted names directly to
    # declaration-ordered data would hand nodes each other's value/eval/kind
    # whenever the two orders differ.
    declared = (Tuple(Tuple.(args)), vals...)
    modelinputs = map(declared) do field
        NamedTuple{sorted_keys}(NamedTuple{node_keys}(field))
    end
    return Model(GraphInfo(modelinputs..., A, sorted_vertices))
end

"""
    dag(inputs)

Build the dependency graph implied by `inputs`, a `NamedTuple` that maps each
node name to the names of its input nodes.

Returns a tuple `(A, vertices)`: `vertices::Vector{Symbol}` lists the node
names in topological order, and `A` is the sparse adjacency matrix with both
its rows and its columns arranged in that same order, so that `A[i, j]` is set
when `vertices[j]` is an input of `vertices[i]`.
"""
function dag(inputs)
    input_names = Symbol[keys(inputs)...]
    A = adjacency_matrix(inputs)
    sorted_vertices = topological_sort_by_dfs(A)
    # Rows (nodes) and columns (their inputs) index the same vertex set, so both
    # must be permuted; permuting only the columns leaves the rows in
    # declaration order and the matrix inconsistent with the returned names.
    sorted_A = permute(A, sorted_vertices, sorted_vertices)
    return sorted_A, input_names[sorted_vertices]
end

"""
    getvals(nt::NamedTuple{T})

Take the arguments to `Model(; kwargs...)` as a `NamedTuple` of
`(value, eval, kind)` tuples and regroup them into a tuple of three tuples:
all values, all eval functions and all kinds, each in the order of the keys
`T`, as required by the `GraphInfo` type.
"""
@generated function getvals(nt::NamedTuple{T}) where T
    n_nodes = length(T)
    # One tuple expression per position (1 = value, 2 = eval, 3 = kind), each
    # collecting that position from every node.
    columns = [Expr(:tuple, [:(nt[$i][$j]) for i in 1:n_nodes]...) for j in 1:3]
    return Expr(:tuple, columns...)
end

"""
    argnames(f::Function)

Return a `Vector{Symbol}` holding the argument names of the first method of
`f`; intended for anonymous functions, which have a single method.
"""
function argnames(f::Function)
    method = first(methods(f))
    # The first entry names the function object itself, not an argument.
    names = Base.method_argnames(method)
    return names[2:end]
end

"""
    adjacency_matrix(inputs)

For a `NamedTuple{T}` with vertices `T` paired with collections of input
nodes, construct the sparse `Bool` adjacency matrix using the order of
variables given by `T`. Entry `[row, col]` is `true` when vertex `col` is an
input of vertex `row`.

Throws an `ErrorException` if a node lists an input that is not itself a
vertex in `T`.

# Examples
```julia
julia> inputs = (a = (), b = (), c = (:a, :b))
(a = (), b = (), c = (:a, :b))

julia> AbstractPPL.GraphPPL.adjacency_matrix(inputs)
3×3 SparseMatrixCSC{Bool, Int64} with 2 stored entries:
 ⋅  ⋅  ⋅
 ⋅  ⋅  ⋅
 1  1  ⋅
```
"""
function adjacency_matrix(inputs::NamedTuple{nodes}) where {nodes}
    N = length(inputs)
    # Lookup from node name to its column index.
    col_inds = NamedTuple{nodes}(ntuple(identity, N))
    A = spzeros(Bool, N, N)
    for (row, node) in enumerate(nodes)
        for input in inputs[node]
            if input ∉ nodes
                error("Parent node of $(input) not found in node set: $(nodes)")
            end
            col = col_inds[input]
            A[row, col] = true
        end
    end
    return A
end

# Adapted from Graphs.jl https://github.com/JuliaGraphs/Graphs.jl/blob/06669054ed470bcfe4b2ad90ed974f2e65c84bb6/src/interface.jl#L302
# Return the row indices of the stored entries in column `u` of `A`.
function outneighbors(A::SparseMatrixCSC, u::T) where T <: Int
    column = A[:, u]
    return first(findnz(column))
end

# Lifted from Graphs.jl https://github.com/JuliaGraphs/Graphs.jl/blob/06669054ed470bcfe4b2ad90ed974f2e65c84bb6/src/traversals/dfs.jl#L44
# Depth first search implementation optimized from http://www.cs.nott.ac.uk/~psznza/G5BADS03/graphs2.pdf
#
# Return the vertex indices of the graph with adjacency matrix `A` in
# topological order, using an iterative depth-first search. Throws an
# `ErrorException` when a vertex on the current search path is reached again.
function topological_sort_by_dfs(A)
    n_verts = size(A, 1)
    # Vertex colours: 0 = unvisited, 1 = on the current DFS path, 2 = finished.
    vcolor = fill(0x00, n_verts)
    finished = Int64[]
    for root in 1:n_verts
        vcolor[root] == 0 || continue
        stack = Int64[root]
        vcolor[root] = 1
        while !isempty(stack)
            current = last(stack)
            next_unvisited = 0
            for neighbor in outneighbors(A, current)
                color = vcolor[neighbor]
                # Reaching a vertex that is still on the path closes a cycle.
                color == 1 && error("The input graph contains at least one loop.") # TODO 0.7 should we use a different error?
                if color == 0
                    next_unvisited = neighbor
                    break
                end
            end
            if next_unvisited == 0
                # Every neighbour is finished, so this vertex is finished too.
                vcolor[current] = 2
                push!(finished, current)
                pop!(stack)
            else
                vcolor[next_unvisited] = 1
                push!(stack, next_unvisited)
            end
        end
    end
    # Vertices were recorded in finishing order; the reverse is topological.
    return reverse(finished)
end

"""
    Base.getindex(g::GraphInfo, vn::VarName{p})

Index a `GraphInfo` with a `VarName{p}`. Returns a `NamedTuple` holding the
`input`, `value`, `eval` and `kind` entries (the first four fields of
`GraphInfo`, in that order) for node `p`. The lens carried by `vn` is composed
on top of each entry's lookup.

Indexing a `Model` with a `VarName` forwards to this method.

# Examples
```julia
julia> using AbstractPPL

julia> m = Model( s2 = (0.0, () -> InverseGamma(2.0,3.0), :Stochastic),
                  μ = (1.0, () -> 1.0, :Logical),
                  y = (0.0, (μ, s2) -> MvNormal(μ, sqrt(s2)), :Stochastic));

julia> m[@varname y]
(input = (:μ, :s2), value = 0.0, eval = var"#45#48"(), kind = :Stochastic)
```
"""
@generated function Base.getindex(g::GraphInfo, vn::VarName{p}) where {p}
    # The per-node fields of GraphInfo: (:input, :value, :eval, :kind).
    fns = fieldnames(GraphInfo)[1:4]
    name_lens = Setfield.PropertyLens{p}()
    field_lenses = [Setfield.PropertyLens{f}() for f in fns]
    # For each field build `get(g, field ∘ node-name ∘ lens-of-vn)`.
    values = [:(get(g, Setfield.compose($l, $name_lens, getlens(vn)))) for l in field_lenses]
    return :(NamedTuple{$(fns)}(($(values...),)))
end

# Indexing a `Model` with a `VarName` forwards to the wrapped `GraphInfo`.
Base.getindex(m::Model, vn::VarName) = m.g[vn]

# Print a "Nodes: " header followed by one `name = data` line per node, in the
# order given by `nodes(m)`.
function Base.show(io::IO, m::Model)
    print(io, "Nodes: \n")
    foreach(nodes(m)) do node
        print(io, "$node = ", m[VarName{node}()], "\n")
    end
end


# Iterate over the per-node NamedTuples in the order of the sorted vertices;
# `state` is the index of the next vertex to visit.
function Base.iterate(m::Model, state=1)
    state > length(nodes(m)) && return nothing
    vertex = m.g.sorted_vertices[state]
    return (m[VarName{vertex}()], state + 1)
end

# Iteration traits are defined on the type, which is where Base's generic code
# (e.g. `collect`) looks them up; `eltype(m)` and `Base.IteratorEltype(m)` on an
# instance still work through Base's instance-to-type fallbacks.
# `HasEltype` is not exported from Base, so it must be qualified.
Base.eltype(::Type{<:Model}) = NamedTuple{fieldnames(GraphInfo)[1:4]}
Base.IteratorEltype(::Type{<:Model}) = Base.HasEltype()

# Dictionary-like accessors for `Model`.

# Lazily yields one `VarName` per node, following the sorted vertex order.
function Base.keys(m::Model)
    return (VarName{n}() for n in m.g.sorted_vertices)
end

# Lazily yields the per-node data by iterating the model itself.
function Base.values(m::Model)
    return Base.Generator(identity, m)
end

# Number of nodes in the model.
function Base.length(m::Model)
    return length(nodes(m))
end

function Base.keytype(m::Model)
    return eltype(keys(m))
end

function Base.valtype(m::Model)
    return eltype(m)
end


"""
    get_dag(m::Model)

Return the adjacency matrix stored in the model's `GraphInfo`, a
`SparseMatrixCSC`.
"""
get_dag(m::Model) = m.g.A

"""
    nodes(m::Model)

Return the `Vector{Symbol}` of vertex names stored in the model, i.e. the
sorted vertices of the DAG.
"""
function nodes(m::Model)
    return m.g.sorted_vertices
end
1 change: 1 addition & 0 deletions test/Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
[deps]
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[compat]
Expand Down
64 changes: 64 additions & 0 deletions test/graphinfo.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
using AbstractPPL
import AbstractPPL.GraphPPL: GraphInfo, Model, get_dag
using SparseArrays
using Test

## Example taken from Mamba
line = Dict{Symbol, Any}(
    :x => [1, 2, 3, 4, 5],
    :y => [1, 3, 3, 3, 5]
)
line[:xmat] = [ones(5) line[:x]]

# Kept as a NamedTuple so that the values can be tested later. The constructor
# should be used as Model(; kwargs...). The eval functions are never called in
# this file, so the distribution constructors they mention need not be loaded.
model = (
    β = (zeros(2), () -> MvNormal(2, sqrt(1000)), :Stochastic),
    xmat = (line[:xmat], () -> line[:xmat], :Logical),
    s2 = (0.0, () -> InverseGamma(2.0,3.0), :Stochastic),
    μ = (zeros(5), (xmat, β) -> xmat * β, :Logical),
    y = (zeros(5), (μ, s2) -> MvNormal(μ, sqrt(s2)), :Stochastic)
)

# construct the model!
m = Model(; zip(keys(model), values(model))...) # uses Model(; kwargs...) constructor

# test the type of the model is correct
@test typeof(m) <: Model
@test typeof(m) == Model{(:s2, :xmat, :β, :μ, :y)}
@test typeof(m.g) <: GraphInfo <: AbstractModelTrace
@test typeof(m.g) == GraphInfo{(:s2, :xmat, :β, :μ, :y)}

# test the dag is correct
A = sparse([0 0 0 0 0; 0 0 0 0 0; 0 0 0 0 0; 0 1 1 0 0; 1 0 0 1 0])
@test get_dag(m) == A

@test length(m) == 5
@test eltype(m) == valtype(m)

# check the values from the NamedTuple match the values in the fields of
# GraphInfo. `model` is written in declaration order, which is not the model's
# vertex order, so the comparison is made irrespective of ordering.
vals = AbstractPPL.GraphPPL.getvals(model)
for (i, field) in enumerate([:value, :eval, :kind])
    @test issetequal(values(getproperty(m.g, field)), vals[i])
end

for node in m
    @test typeof(node) <: NamedTuple{fieldnames(GraphInfo)[1:4]}
end

# test the right inputs have been inferred
@test m.g.input == (s2 = (), xmat = (), β = (), μ = (:xmat, :β), y = (:μ, :s2))

# test keys are VarNames
for key in keys(m)
    @test typeof(key) <: VarName
end

# test Model constructor for model with single parent node
single_parent_m = Model(μ = (1.0, () -> 3, :Logical), y = (1.0, (μ) -> MvNormal(μ, sqrt(1)), :Stochastic))
@test typeof(single_parent_m) == Model{(:μ, :y)}
@test typeof(single_parent_m.g) == GraphInfo{(:μ, :y)}

# test ErrorException for parent node not found
@test_throws ErrorException Model( μ = (1.0, (β) -> 3, :Logical), y = (1.0, (μ) -> MvNormal(μ, sqrt(1)), :Stochastic))

# test AssertionError thrown for kwargs with the wrong order of inputs
@test_throws AssertionError Model( μ = ((β) -> 3, 1.0, :Logical), y = (1.0, (μ) -> MvNormal(μ, sqrt(1)), :Stochastic))
5 changes: 2 additions & 3 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ using Test

@testset "AbstractPPL.jl" begin
include("deprecations.jl")

include("graphinfo.jl")
@testset "doctests" begin
DocMeta.setdocmeta!(
AbstractPPL,
Expand All @@ -22,5 +22,4 @@ using Test
)
doctest(AbstractPPL; manual=false)
end
end

end

0 comments on commit 4692526

Please sign in to comment.