diff --git a/src/contexts.jl b/src/contexts.jl
index 63f624b4e..53b454df6 100644
--- a/src/contexts.jl
+++ b/src/contexts.jl
@@ -188,7 +188,7 @@ getsampler(::IsParent, context::AbstractContext) = getsampler(childcontext(conte
 """
     struct DefaultContext <: AbstractContext end
 
-The `DefaultContext` is used by default to compute log the joint probability of the data
+The `DefaultContext` is used by default to compute the log joint probability of the data
 and parameters when running the model.
 """
 struct DefaultContext <: AbstractContext end
@@ -199,7 +199,7 @@ NodeTrait(context::DefaultContext) = IsLeaf()
         vars::Tvars
     end
 
-The `PriorContext` enables the computation of the log prior of the parameters `vars` when 
+The `PriorContext` enables the computation of the log prior of the parameters `vars` when
 running the model.
 """
 struct PriorContext{Tvars} <: AbstractContext
@@ -213,8 +213,8 @@ NodeTrait(context::PriorContext) = IsLeaf()
         vars::Tvars
     end
 
-The `LikelihoodContext` enables the computation of the log likelihood of the parameters when 
-running the model. `vars` can be used to evaluate the log likelihood for specific values 
+The `LikelihoodContext` enables the computation of the log likelihood of the parameters when
+running the model. `vars` can be used to evaluate the log likelihood for specific values
 of the model's parameters. If `vars` is `nothing`, the parameter values inside the `VarInfo` will be used by default.
 """
 struct LikelihoodContext{Tvars} <: AbstractContext
@@ -229,10 +229,10 @@ NodeTrait(context::LikelihoodContext) = IsLeaf()
         loglike_scalar::T
     end
 
-The `MiniBatchContext` enables the computation of 
-`log(prior) + s * log(likelihood of a batch)` when running the model, where `s` is the 
-`loglike_scalar` field, typically equal to `the number of data points / batch size`. 
-This is useful in batch-based stochastic gradient descent algorithms to be optimizing 
+The `MiniBatchContext` enables the computation of
+`log(prior) + s * log(likelihood of a batch)` when running the model, where `s` is the
+`loglike_scalar` field, typically equal to `the number of data points / batch size`.
+This is useful in batch-based stochastic gradient descent algorithms for optimizing
 `log(prior) + log(likelihood of all the data points)` in the expectation.
 """
 struct MiniBatchContext{Tctx,T} <: AbstractContext
diff --git a/src/model.jl b/src/model.jl
index 09c0c1be1..082ec3871 100644
--- a/src/model.jl
+++ b/src/model.jl
@@ -1,15 +1,17 @@
 """
-    struct Model{F,argnames,defaultnames,missings,Targs,Tdefaults}
+    struct Model{F,argnames,defaultnames,missings,Targs,Tdefaults,Ctx<:AbstractContext}
         f::F
         args::NamedTuple{argnames,Targs}
         defaults::NamedTuple{defaultnames,Tdefaults}
+        context::Ctx=DefaultContext()
     end
 
 A `Model` struct with model evaluation function of type `F`, arguments of names `argnames`
-types `Targs`, default arguments of names `defaultnames` with types `Tdefaults`, and missing
-arguments `missings`.
+types `Targs`, default arguments of names `defaultnames` with types `Tdefaults`, missing
+arguments `missings`, and evaluation context of type `Ctx`.
 
 Here `argnames`, `defaultargnames`, and `missings` are tuples of symbols, e.g. `(:a, :b)`.
+`context` is by default `DefaultContext()`.
 
 An argument with a type of `Missing` will be in `missings` by default. However, in
 non-traditional use-cases `missings` can be defined differently. All variables in `missings`
@@ -1077,7 +1079,7 @@ end
 Return an array of log joint probabilities evaluated at each sample in an MCMC `chain`.
 
 # Examples
-
+
 ```jldoctest
 julia> using MCMCChains, Distributions
 
@@ -1093,7 +1095,7 @@ julia> # construct a chain of samples using MCMCChains
        chain = Chains(rand(10, 2, 3), [:s, :m]);
 
 julia> logjoint(demo_model([1., 2.]), chain);
-```
+```
 """
 function logjoint(model::Model, chain::AbstractMCMC.AbstractChains)
     var_info = VarInfo(model) # extract variables info from the model
@@ -1124,7 +1126,7 @@ end
 Return an array of log prior probabilities evaluated at each sample in an MCMC `chain`.
 
 # Examples
-
+
 ```jldoctest
 julia> using MCMCChains, Distributions
 
@@ -1140,7 +1142,7 @@ julia> # construct a chain of samples using MCMCChains
        chain = Chains(rand(10, 2, 3), [:s, :m]);
 
 julia> logprior(demo_model([1., 2.]), chain);
-```
+```
 """
 function logprior(model::Model, chain::AbstractMCMC.AbstractChains)
     var_info = VarInfo(model) # extract variables info from the model
@@ -1171,7 +1173,7 @@ end
 Return an array of log likelihoods evaluated at each sample in an MCMC `chain`.
 
 # Examples
-
+
 ```jldoctest
 julia> using MCMCChains, Distributions
 
@@ -1187,7 +1189,7 @@ julia> # construct a chain of samples using MCMCChains
        chain = Chains(rand(10, 2, 3), [:s, :m]);
 
 julia> loglikelihood(demo_model([1., 2.]), chain);
-```
+```
 """
 function Distributions.loglikelihood(model::Model, chain::AbstractMCMC.AbstractChains)
     var_info = VarInfo(model) # extract variables info from the model
diff --git a/src/transforming.jl b/src/transforming.jl
index 41c877c91..1f6c55e24 100644
--- a/src/transforming.jl
+++ b/src/transforming.jl
@@ -1,3 +1,14 @@
+"""
+    struct DynamicTransformationContext{isinverse} <: AbstractContext
+
+When a model is evaluated with this context, transform the accompanying `AbstractVarInfo` to
+constrained space if `isinverse` or unconstrained if `!isinverse`.
+
+Note that some `AbstractVarInfo` types, most notably `VarInfo`, override the
+`DynamicTransformationContext` methods with more efficient implementations.
+`DynamicTransformationContext` is a fallback for when we need to evaluate the model to know
+how to do the transformation, used by e.g. `SimpleVarInfo`.
+"""
 struct DynamicTransformationContext{isinverse} <: AbstractContext end
 NodeTrait(::DynamicTransformationContext) = IsLeaf()
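
As a usage note for reviewers (not part of the patch): the sketch below shows how the chain-wise `logjoint`, `logprior`, and `loglikelihood` methods touched above relate to the contexts documented in `contexts.jl`. The body of `demo_model` is an assumption here — it is consistent with the doctest calls but lies outside the hunks shown — and the chain construction mirrors the doctests.

```julia
# Sketch only: assumes DynamicPPL, Distributions, and MCMCChains are available,
# and that `demo_model` looks roughly like the model used in the doctests above.
using DynamicPPL, Distributions, MCMCChains

@model function demo_model(x)
    # Hypothetical priors over a variance `s` and a mean `m`,
    # matching the `[:s, :m]` layout of the doctest chain.
    s ~ InverseGamma(2, 3)
    m ~ Normal(0, sqrt(s))
    # Likelihood of the observations `x` given `s` and `m`.
    for i in eachindex(x)
        x[i] ~ Normal(m, sqrt(s))
    end
end

model = demo_model([1.0, 2.0])

# A chain with 10 iterations, the two parameters `:s` and `:m`, and 3 chains,
# exactly as constructed in the doctests.
chain = Chains(rand(10, 2, 3), [:s, :m])

lj = logjoint(model, chain)       # joint, as evaluated under `DefaultContext`
lp = logprior(model, chain)       # prior only, as under `PriorContext`
ll = loglikelihood(model, chain)  # likelihood only, as under `LikelihoodContext`

# Per sample and per chain, the joint should decompose as prior + likelihood.
all(lj .≈ lp .+ ll)  # expected to be `true`
```

For mini-batch training, `MiniBatchContext` (documented above) wraps one of these contexts and rescales the likelihood term by `loglike_scalar`, typically `number of data points / batch size`, so that the scaled objective matches `log(prior) + log(likelihood of all the data points)` in expectation.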