Skip to content

Commit

Permalink
Miscellaneous style and docs improvements (#622)
Browse files Browse the repository at this point in the history
* Fix docstring typo

* Add mention of context in the docstring of Model

* Add a docstring for DynamicTransformationContext

* Tiny style improvements
  • Loading branch information
mhauru authored Sep 2, 2024
1 parent 122ecd1 commit 3ffb003
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 17 deletions.
16 changes: 8 additions & 8 deletions src/contexts.jl
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,7 @@ getsampler(::IsParent, context::AbstractContext) = getsampler(childcontext(conte
"""
struct DefaultContext <: AbstractContext end
The `DefaultContext` is used by default to compute log the joint probability of the data
The `DefaultContext` is used by default to compute the log joint probability of the data
and parameters when running the model.
"""
struct DefaultContext <: AbstractContext end
Expand All @@ -199,7 +199,7 @@ NodeTrait(context::DefaultContext) = IsLeaf()
vars::Tvars
end
The `PriorContext` enables the computation of the log prior of the parameters `vars` when
The `PriorContext` enables the computation of the log prior of the parameters `vars` when
running the model.
"""
struct PriorContext{Tvars} <: AbstractContext
Expand All @@ -213,8 +213,8 @@ NodeTrait(context::PriorContext) = IsLeaf()
vars::Tvars
end
The `LikelihoodContext` enables the computation of the log likelihood of the parameters when
running the model. `vars` can be used to evaluate the log likelihood for specific values
The `LikelihoodContext` enables the computation of the log likelihood of the parameters when
running the model. `vars` can be used to evaluate the log likelihood for specific values
of the model's parameters. If `vars` is `nothing`, the parameter values inside the `VarInfo` will be used by default.
"""
struct LikelihoodContext{Tvars} <: AbstractContext
Expand All @@ -229,10 +229,10 @@ NodeTrait(context::LikelihoodContext) = IsLeaf()
loglike_scalar::T
end
The `MiniBatchContext` enables the computation of
`log(prior) + s * log(likelihood of a batch)` when running the model, where `s` is the
`loglike_scalar` field, typically equal to `the number of data points / batch size`.
This is useful in batch-based stochastic gradient descent algorithms to be optimizing
The `MiniBatchContext` enables the computation of
`log(prior) + s * log(likelihood of a batch)` when running the model, where `s` is the
`loglike_scalar` field, typically equal to `the number of data points / batch size`.
This is useful in batch-based stochastic gradient descent algorithms to be optimizing
`log(prior) + log(likelihood of all the data points)` in the expectation.
"""
struct MiniBatchContext{Tctx,T} <: AbstractContext
Expand Down
20 changes: 11 additions & 9 deletions src/model.jl
Original file line number Diff line number Diff line change
@@ -1,15 +1,17 @@
"""
struct Model{F,argnames,defaultnames,missings,Targs,Tdefaults}
struct Model{F,argnames,defaultnames,missings,Targs,Tdefaults,Ctx<:AbstactContext}
f::F
args::NamedTuple{argnames,Targs}
defaults::NamedTuple{defaultnames,Tdefaults}
context::Ctx=DefaultContext()
end
A `Model` struct with model evaluation function of type `F`, arguments of names `argnames`
types `Targs`, default arguments of names `defaultnames` with types `Tdefaults`, and missing
arguments `missings`.
types `Targs`, default arguments of names `defaultnames` with types `Tdefaults`, missing
arguments `missings`, and evaluation context of type `Ctx`.
Here `argnames`, `defaultargnames`, and `missings` are tuples of symbols, e.g. `(:a, :b)`.
`context` is by default `DefaultContext()`.
An argument with a type of `Missing` will be in `missings` by default. However, in
non-traditional use-cases `missings` can be defined differently. All variables in `missings`
Expand Down Expand Up @@ -1077,7 +1079,7 @@ end
Return an array of log joint probabilities evaluated at each sample in an MCMC `chain`.
# Examples
```jldoctest
julia> using MCMCChains, Distributions
Expand All @@ -1093,7 +1095,7 @@ julia> # construct a chain of samples using MCMCChains
chain = Chains(rand(10, 2, 3), [:s, :m]);
julia> logjoint(demo_model([1., 2.]), chain);
```
```
"""
function logjoint(model::Model, chain::AbstractMCMC.AbstractChains)
var_info = VarInfo(model) # extract variables info from the model
Expand Down Expand Up @@ -1124,7 +1126,7 @@ end
Return an array of log prior probabilities evaluated at each sample in an MCMC `chain`.
# Examples
```jldoctest
julia> using MCMCChains, Distributions
Expand All @@ -1140,7 +1142,7 @@ julia> # construct a chain of samples using MCMCChains
chain = Chains(rand(10, 2, 3), [:s, :m]);
julia> logprior(demo_model([1., 2.]), chain);
```
```
"""
function logprior(model::Model, chain::AbstractMCMC.AbstractChains)
var_info = VarInfo(model) # extract variables info from the model
Expand Down Expand Up @@ -1171,7 +1173,7 @@ end
Return an array of log likelihoods evaluated at each sample in an MCMC `chain`.
# Examples
```jldoctest
julia> using MCMCChains, Distributions
Expand All @@ -1187,7 +1189,7 @@ julia> # construct a chain of samples using MCMCChains
chain = Chains(rand(10, 2, 3), [:s, :m]);
julia> loglikelihood(demo_model([1., 2.]), chain);
```
```
"""
function Distributions.loglikelihood(model::Model, chain::AbstractMCMC.AbstractChains)
var_info = VarInfo(model) # extract variables info from the model
Expand Down
11 changes: 11 additions & 0 deletions src/transforming.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,14 @@
"""
struct DynamicTransformationContext{isinverse} <: AbstractContext
When a model is evaluated with this context, transform the accompanying `AbstractVarInfo` to
constrained space if `isinverse` or unconstrained if `!isinverse`.
Note that some `AbstractVarInfo` types, must notably `VarInfo`, override the
`DynamicTransformationContext` methods with more efficient implementations.
`DynamicTransformationContext` is a fallback for when we need to evaluate the model to know
how to do the transformation, used by e.g. `SimpleVarInfo`.
"""
struct DynamicTransformationContext{isinverse} <: AbstractContext end
NodeTrait(::DynamicTransformationContext) = IsLeaf()

Expand Down

0 comments on commit 3ffb003

Please sign in to comment.