Skip to content

Commit

Permalink
Added check_model and sub-module DebugUtils (#540)
Browse files Browse the repository at this point in the history
* initial work on model checking

* use record_pre_tilde!, record_post_tilde!, etc. instead of just a
single record_tilde! + support for dot tilde + return issuccess and
additional info in check_model

* added test_context_interface to TestUtils

* added tests for check_model

* moved debug contexts and check_model to a separate file

* export check_model + make DebugContext take the model as input so we
can further customize

* noticed I forgot to include check_models.jl file

* fixed tests

* added record-methods for observe statements too

* use explicit types for the recorded tilde statements + added
convenient show methods to make displaying the trace nicer

* renamd check__model to debug_utils and put it into a module

* renamed test/check_model.jl to test/debug_utils.jl

* removed unnecessary stuff in tests

* added test for logging of statements

* removed unnecessary splatting in broadcasting + improved errors for
encountering missing

* added missing implementation of tilde_observe for PrefixContext

* re-ordered method implementations for DebugContext to make things a
bit more readable

* addeed error message indicating that usage of missing for
de-conditioning is restricted to univariate distributions

* added missing left field to ObserveStmt

* fixed conditioned

* fixed `fixed` too, and moved the `_merge` to a more sensible location

* added check_model_post_evaluation and made it so we're using
SamplingContext by default since we're using an empty VarInfo by default

* removed show_statements

* perform some simple checks to make sure show is working for statements

* improved test for show of statements a tiny bit

* added some more docs

* more docs

* Apply suggestions from code review

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

* fixed typo in warning

* moved inclusion of trace and others in return-value from check_model
to check_model_and_extras

* formatting

* drop returning varnames_seen and renamed check_model_and_extras to check_model_and_trace

* drop export of DebugContext

* added check_model and check_model_and_trace to docs

* updated tests

* more updates to tests

* formatting

* added rng as an optional positional argument to check_model methods

* added an example in the docstring of check_model_and_trace

* added example of correct and incorrect model in check_model_and_trace docstring

* Update src/debug_utils.jl

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

* fixed docs maybe

* fixed docstring maybe

* fixed reference to Setfield and tests

* fixed docs

* added some conveinence methods in addition to a
`has_static_constraints` method to empirically check whether the model
has static constraints or if they are indeed changing dependent on realizations

* improved show for large arrays of varnames whiich can occur in
dot-tilde statements

---------

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
  • Loading branch information
torfjelde and github-actions[bot] authored Apr 19, 2024
1 parent 33a84c7 commit 824f712
Show file tree
Hide file tree
Showing 8 changed files with 815 additions and 6 deletions.
15 changes: 15 additions & 0 deletions docs/src/api.md
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,21 @@ DynamicPPL.TestUtils.update_values!!
DynamicPPL.TestUtils.test_values
```

## Debugging Utilities

DynamicPPL provides a few methods for checking validity of a model-definition.

```@docs
check_model
check_model_and_trace
```

And some which might be useful to determine certain properties of the model based on the debug trace.

```@docs
DynamicPPL.has_static_constraints
```

## Advanced

### Variable names
Expand Down
8 changes: 6 additions & 2 deletions src/DynamicPPL.jl
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,6 @@ export AbstractVarInfo,
vectorize,
reconstruct,
reconstruct!,
Sample,
init,
vectorize,
OrderedDict,
Expand Down Expand Up @@ -130,7 +129,9 @@ export AbstractVarInfo,
# Convenience macros
@addlogprob!,
@submodel,
value_iterator_from_chain
value_iterator_from_chain,
check_model,
check_model_and_trace

# Reexport
using Distributions: loglikelihood
Expand Down Expand Up @@ -179,6 +180,9 @@ include("logdensityfunction.jl")
include("model_utils.jl")
include("extract_priors.jl")

include("debug_utils.jl")
using .DebugUtils

if !isdefined(Base, :get_extension)
using Requires
end
Expand Down
3 changes: 3 additions & 0 deletions src/context_implementations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,9 @@ end
function tilde_observe(context::PrefixContext, right, left, vi)
return tilde_observe(context.context, right, left, vi)
end
function tilde_observe(context::PrefixContext, sampler, right, left, vi)
return tilde_observe(context.context, sampler, right, left, vi)
end

"""
tilde_observe!!(context, right, left, vname, vi)
Expand Down
8 changes: 4 additions & 4 deletions src/contexts.jl
Original file line number Diff line number Diff line change
Expand Up @@ -479,14 +479,14 @@ a merged version of the condition values.
function conditioned(context::AbstractContext)
return conditioned(NodeTrait(conditioned, context), context)
end
conditioned(::IsLeaf, context) = ()
conditioned(::IsLeaf, context) = NamedTuple()
conditioned(::IsParent, context) = conditioned(childcontext(context))
function conditioned(context::ConditionContext)
# Note the order of arguments to `merge`. The behavior of the rest of DPPL
# is that the outermost `context` takes precendence, hence when resolving
# the `conditioned` variables we need to ensure that `context.values` takes
# precedence over decendants of `context`.
return merge(context.values, conditioned(childcontext(context)))
return _merge(context.values, conditioned(childcontext(context)))
end

struct FixedContext{Values,Ctx<:AbstractContext} <: AbstractContext
Expand Down Expand Up @@ -655,12 +655,12 @@ Note that this will recursively traverse the context stack and return
a merged version of the fix values.
"""
fixed(context::AbstractContext) = fixed(NodeTrait(fixed, context), context)
fixed(::IsLeaf, context) = ()
fixed(::IsLeaf, context) = NamedTuple()
fixed(::IsParent, context) = fixed(childcontext(context))
function fixed(context::FixedContext)
# Note the order of arguments to `merge`. The behavior of the rest of DPPL
# is that the outermost `context` takes precendence, hence when resolving
# the `fixed` variables we need to ensure that `context.values` takes
# precedence over decendants of `context`.
return merge(context.values, fixed(childcontext(context)))
return _merge(context.values, fixed(childcontext(context)))
end
Loading

0 comments on commit 824f712

Please sign in to comment.