-
Notifications
You must be signed in to change notification settings - Fork 8
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
This is a draft PR introducing a `Model` type that stores and makes use the model graph. The main type introduced here is the `Model` struct which stores the `ModelState` and `DAG`, each of which are their own types. `ModelState` contains information about the node values, dependencies and eval functions and `DAG` contains the graph and topologically ordered vertex list. A model can be constructed in the following way: ```julia julia> nt = ( s2 = (0.0, (), () -> InverseGamma(2.0,3.0), :Stochastic), μ = (1.0, (), () -> 1.0, :Logical), y = (0.0, (:μ, :s2), (μ, s2) -> MvNormal(μ, sqrt(s2)), :Stochastic) ) (s2 = (0.0, (), var"#33#36"(), :Stochastic), μ = (1.0, (), var"#34#37"(), :Logical), y = (0.0, (:μ, :s2), var"#35#38"(), :Stochastic)) julia> Model(nt) Nodes: μ = (value = 1.0, input = (), eval = var"#16#19"(), kind = :Logical) s2 = (value = 0.0, input = (), eval = var"#15#18"(), kind = :Stochastic) y = (value = 0.0, input = (:μ, :s2), eval = var"#17#20"(), kind = :Stochastic) DAG: 3×3 SparseMatrixCSC{Float64, Int64} with 2 stored entries: ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ 1.0 1.0 ⋅ ``` At present, only functions needed for the constructors are implemented, as well as indexing using `@varname`. I still need to complete the integration with the AbstractPPL api. TODO: ~~- [ ] `condition`/`decondition`,~~ ~~- [ ] `sample`~~ ~~- [ ] `logdensityof`~~ - [x] pure functions for ordered dictionary, as outlined in [AbstractPPL](https://github.com/TuringLang/AbstractPPL.jl#property-interface) Feedback on `Model` structure welcome whilst I implement the remaining features!
- Loading branch information
1 parent
9b64dd8
commit 4692526
Showing
7 changed files
with
332 additions
and
5 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,255 @@ | ||
using AbstractPPL | ||
import Base.getindex | ||
using SparseArrays | ||
using Setfield | ||
using Setfield: PropertyLens, get | ||
|
||
""" | ||
GraphInfo | ||
Record the state of the model as a struct of NamedTuples, all | ||
sharing the same key values, namely, those of the model parameters. | ||
`value` should store the initial/current value of the parameters. | ||
`input` stores a tuple of inputs for a given node. `eval` are the | ||
anonymous functions associated with each node. These might typically | ||
be either deterministic values or some distribution, but could an | ||
arbitrary julia program. `kind` is a tuple of symbols indicating | ||
whether the node is a logical or stochastic node. Additionally, the | ||
adjacency matrix and topologically ordered vertex list and stored. | ||
GraphInfo is instantiated using the `Model` constctor. | ||
""" | ||
|
||
struct GraphInfo{T} <: AbstractModelTrace | ||
input::NamedTuple{T} | ||
value::NamedTuple{T} | ||
eval::NamedTuple{T} | ||
kind::NamedTuple{T} | ||
A::SparseMatrixCSC | ||
sorted_vertices::Vector{Symbol} | ||
end | ||
|
||
""" | ||
Model(;kwargs...) | ||
`Model` type constructor that takes in named arguments for | ||
nodes and returns a `Model`. Nodes are pairs of variable names | ||
and tuples containing default value, an eval function | ||
and node type. The inputs of each node are inferred from | ||
their anonymous functions. The returned object has a type | ||
GraphInfo{(sorted_vertices...)}. | ||
# Examples | ||
```jl-doctest | ||
julia> using AbstractPPL | ||
julia> Model( | ||
s2 = (0.0, () -> InverseGamma(2.0,3.0), :Stochastic), | ||
μ = (1.0, () -> 1.0, :Logical), | ||
y = (0.0, (μ, s2) -> MvNormal(μ, sqrt(s2)), :Stochastic) | ||
) | ||
Nodes: | ||
μ = (value = 1.0, input = (), eval = var"#6#9"(), kind = :Logical) | ||
s2 = (value = 0.0, input = (), eval = var"#5#8"(), kind = :Stochastic) | ||
y = (value = 0.0, input = (:μ, :s2), eval = var"#7#10"(), kind = :Stochastic) | ||
``` | ||
""" | ||
|
||
struct Model{T} <: AbstractProbabilisticProgram | ||
g::GraphInfo{T} | ||
end | ||
|
||
function Model(;kwargs...) | ||
for (i, node) in enumerate(values(kwargs)) | ||
@assert typeof(node) <: Tuple{Union{Array{Float64}, Float64}, Function, Symbol} "Check input order for node $(i) matches Tuple(value, function, kind)" | ||
end | ||
vals = getvals(NamedTuple(kwargs)) | ||
args = [argnames(f) for f in vals[2]] | ||
A, sorted_vertices = dag(NamedTuple{keys(kwargs)}(args)) | ||
modelinputs = NamedTuple{Tuple(sorted_vertices)}.([Tuple.(args), vals...]) | ||
Model(GraphInfo(modelinputs..., A, sorted_vertices)) | ||
end | ||
|
||
""" | ||
dag(inputs) | ||
Function taking in a NamedTuple containing the inputs to each node | ||
and returns the implied adjacency matrix and topologically ordered | ||
vertex list. | ||
""" | ||
function dag(inputs) | ||
input_names = Symbol[keys(inputs)...] | ||
A = adjacency_matrix(inputs) | ||
sorted_vertices = topological_sort_by_dfs(A) | ||
sorted_A = permute(A, collect(1:length(inputs)), sorted_vertices) | ||
sorted_A, input_names[sorted_vertices] | ||
end | ||
|
||
""" | ||
getvals(nt::NamedTuple{T}) | ||
Takes in the arguments to Model(;kwargs...) as a NamedTuple and | ||
reorders into a tuple of tuples each containing either of value, | ||
input, eval and kind, as required by the GraphInfo type. | ||
""" | ||
@generated function getvals(nt::NamedTuple{T}) where T | ||
values = [:(nt[$i][$j]) for i in 1:length(T), j in 1:3] | ||
m = [:($(values[:,i]...), ) for i in 1:3] | ||
return Expr(:tuple, m...) # :($(m...),) | ||
end | ||
|
||
""" | ||
argnames(f::Function) | ||
Returns a Vector{Symbol} of the inputs to an anonymous function `f`. | ||
""" | ||
argnames(f::Function) = Base.method_argnames(first(methods(f)))[2:end] | ||
|
||
""" | ||
adjacency_matrix(inputs) | ||
For a NamedTuple{T} with vertices `T` paired with tuples of input nodes, | ||
`adjacency_matrix` constructs the adjacency matrix using the order | ||
of variables given by `T`. | ||
# Examples | ||
```jl-doctest | ||
julia> inputs = (a = (), b = (), c = (:a, :b)) | ||
(a = (), b = (), c = (:a, :b)) | ||
julia> AbstractPPL.adjacency_matrix(inputs) | ||
3×3 SparseMatrixCSC{Float64, Int64} with 2 stored entries: | ||
⋅ ⋅ ⋅ | ||
⋅ ⋅ ⋅ | ||
1.0 1.0 ⋅ | ||
``` | ||
""" | ||
function adjacency_matrix(inputs::NamedTuple{nodes}) where {nodes} | ||
N = length(inputs) | ||
col_inds = NamedTuple{nodes}(ntuple(identity, N)) | ||
A = spzeros(Bool, N, N) | ||
for (row, node) in enumerate(nodes) | ||
for input in inputs[node] | ||
if input ∉ nodes | ||
error("Parent node of $(input) not found in node set: $(nodes)") | ||
end | ||
col = col_inds[input] | ||
A[row, col] = true | ||
end | ||
end | ||
return A | ||
end | ||
|
||
function outneighbors(A::SparseMatrixCSC, u::T) where T <: Int | ||
#adapted from Graph.jl https://github.com/JuliaGraphs/Graphs.jl/blob/06669054ed470bcfe4b2ad90ed974f2e65c84bb6/src/interface.jl#L302 | ||
inds, _ = findnz(A[:, u]) | ||
inds | ||
end | ||
|
||
function topological_sort_by_dfs(A) | ||
# lifted from Graphs.jl https://github.com/JuliaGraphs/Graphs.jl/blob/06669054ed470bcfe4b2ad90ed974f2e65c84bb6/src/traversals/dfs.jl#L44 | ||
# Depth first search implementation optimized from http://www.cs.nott.ac.uk/~psznza/G5BADS03/graphs2.pdf | ||
n_verts = size(A)[1] | ||
vcolor = zeros(UInt8, n_verts) | ||
verts = Vector{Int64}() | ||
for v in 1:n_verts | ||
vcolor[v] != 0 && continue | ||
S = Vector{Int64}([v]) | ||
vcolor[v] = 1 | ||
while !isempty(S) | ||
u = S[end] | ||
w = 0 | ||
for n in outneighbors(A, u) | ||
if vcolor[n] == 1 | ||
error("The input graph contains at least one loop.") # TODO 0.7 should we use a different error? | ||
elseif vcolor[n] == 0 | ||
w = n | ||
break | ||
end | ||
end | ||
if w != 0 | ||
vcolor[w] = 1 | ||
push!(S, w) | ||
else | ||
vcolor[u] = 2 | ||
push!(verts, u) | ||
pop!(S) | ||
end | ||
end | ||
end | ||
return reverse(verts) | ||
end | ||
|
||
""" | ||
Base.getindex(m::Model, vn::VarName{p}) | ||
Index a Model with a `VarName{p}` lens. Retrieves the `value``, `input`, | ||
`eval` and `kind` for node `p`. | ||
# Examples | ||
```jl-doctest | ||
julia> using AbstractPPL | ||
julia> m = Model( s2 = (0.0, () -> InverseGamma(2.0,3.0), :Stochastic), | ||
μ = (1.0, () -> 1.0, :Logical), | ||
y = (0.0, (μ, s2) -> MvNormal(μ, sqrt(s2)), :Stochastic)) | ||
(s2 = Symbol[], μ = Symbol[], y = [:μ, :s2]) | ||
Nodes: | ||
μ = (value = 0.0, input = (), eval = var"#43#46"(), kind = :Stochastic) | ||
s2 = (value = 1.0, input = (), eval = var"#44#47"(), kind = :Logical) | ||
y = (value = 0.0, input = (:μ, :s2), eval = var"#45#48"(), kind = :Stochastic) | ||
julia> m[@varname y] | ||
(value = 0.0, input = (:μ, :s2), eval = var"#45#48"(), kind = :Stochastic) | ||
``` | ||
""" | ||
@generated function Base.getindex(g::GraphInfo, vn::VarName{p}) where {p} | ||
fns = fieldnames(GraphInfo)[1:4] | ||
name_lens = Setfield.PropertyLens{p}() | ||
field_lenses = [Setfield.PropertyLens{f}() for f in fns] | ||
values = [:(get(g, Setfield.compose($l, $name_lens, getlens(vn)))) for l in field_lenses] | ||
return :(NamedTuple{$(fns)}(($(values...),))) | ||
end | ||
|
||
function Base.getindex(m::Model, vn::VarName) | ||
return m.g[vn] | ||
end | ||
|
||
function Base.show(io::IO, m::Model) | ||
print(io, "Nodes: \n") | ||
for node in nodes(m) | ||
print(io, "$node = ", m[VarName{node}()], "\n") | ||
end | ||
end | ||
|
||
|
||
function Base.iterate(m::Model, state=1) | ||
state > length(nodes(m)) ? nothing : (m[VarName{m.g.sorted_vertices[state]}()], state+1) | ||
end | ||
|
||
Base.eltype(m::Model) = NamedTuple{fieldnames(GraphInfo)[1:4]} | ||
Base.IteratorEltype(m::Model) = HasEltype() | ||
|
||
Base.keys(m::Model) = (VarName{n}() for n in m.g.sorted_vertices) | ||
Base.values(m::Model) = Base.Generator(identity, m) | ||
Base.length(m::Model) = length(nodes(m)) | ||
Base.keytype(m::Model) = eltype(keys(m)) | ||
Base.valtype(m::Model) = eltype(m) | ||
|
||
|
||
""" | ||
dag(m::Model) | ||
Returns the adjacency matrix of the model as a SparseArray. | ||
""" | ||
get_dag(m::Model) = m.g.A | ||
|
||
""" | ||
nodes(m::Model) | ||
Returns a `Vector{Symbol}` containing the sorted vertices | ||
of the DAG. | ||
""" | ||
nodes(m::Model) = m.g.sorted_vertices |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,64 @@ | ||
using AbstractPPL | ||
import AbstractPPL.GraphPPL: GraphInfo, Model, get_dag | ||
using SparseArrays | ||
using Test | ||
## Example taken from Mamba | ||
line = Dict{Symbol, Any}( | ||
:x => [1, 2, 3, 4, 5], | ||
:y => [1, 3, 3, 3, 5] | ||
) | ||
line[:xmat] = [ones(5) line[:x]] | ||
|
||
# just making it a NamedTuple so that the values can be tested later. Constructor should be used as Model(;kwargs...). | ||
model = ( | ||
β = (zeros(2), () -> MvNormal(2, sqrt(1000)), :Stochastic), | ||
xmat = (line[:xmat], () -> line[:xmat], :Logical), | ||
s2 = (0.0, () -> InverseGamma(2.0,3.0), :Stochastic), | ||
μ = (zeros(5), (xmat, β) -> xmat * β, :Logical), | ||
y = (zeros(5), (μ, s2) -> MvNormal(μ, sqrt(s2)), :Stochastic) | ||
) | ||
|
||
# construct the model! | ||
m = Model(; zip(keys(model), values(model))...) # uses Model(; kwargs...) constructor | ||
|
||
# test the type of the model is correct | ||
@test typeof(m) <: Model | ||
@test typeof(m) == Model{(:s2, :xmat, :β, :μ, :y)} | ||
@test typeof(m.g) <: GraphInfo <: AbstractModelTrace | ||
@test typeof(m.g) == GraphInfo{(:s2, :xmat, :β, :μ, :y)} | ||
|
||
# test the dag is correct | ||
A = sparse([0 0 0 0 0; 0 0 0 0 0; 0 0 0 0 0; 0 1 1 0 0; 1 0 0 1 0]) | ||
@test get_dag(m) == A | ||
|
||
@test length(m) == 5 | ||
@test eltype(m) == valtype(m) | ||
|
||
# check the values from the NamedTuple match the values in the fields of GraphInfo | ||
vals = AbstractPPL.GraphPPL.getvals(model) | ||
for (i, field) in enumerate([:value, :eval, :kind]) | ||
@test eval( :( values(m.g.$field) == vals[$i] ) ) | ||
end | ||
|
||
for node in m | ||
@test typeof(node) <: NamedTuple{fieldnames(GraphInfo)[1:4]} | ||
end | ||
|
||
# test the right inputs have been inferred | ||
@test m.g.input == (s2 = (), xmat = (), β = (), μ = (:xmat, :β), y = (:μ, :s2)) | ||
|
||
# test keys are VarNames | ||
for key in keys(m) | ||
@test typeof(key) <: VarName | ||
end | ||
|
||
# test Model constructor for model with single parent node | ||
single_parent_m = Model(μ = (1.0, () -> 3, :Logical), y = (1.0, (μ) -> MvNormal(μ, sqrt(1)), :Stochastic)) | ||
@test typeof(single_parent_m) == Model{(:μ, :y)} | ||
@test typeof(single_parent_m.g) == GraphInfo{(:μ, :y)} | ||
|
||
# test ErrorException for parent node not found | ||
@test_throws ErrorException Model( μ = (1.0, (β) -> 3, :Logical), y = (1.0, (μ) -> MvNormal(μ, sqrt(1)), :Stochastic)) | ||
|
||
# test AssertionError thrown for kwargs with the wrong order of inputs | ||
@test_throws AssertionError Model( μ = ((β) -> 3, 1.0, :Logical), y = (1.0, (μ) -> MvNormal(μ, sqrt(1)), :Stochastic)) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters