Skip to content

Commit

Permalink
Immutable versions of link and invlink (#525)
Browse files Browse the repository at this point in the history
* added immutable versions of link and invlink

* added explicit invlink implementation for VarInfo

* remove false debug statement

* fixed default impls of invlink for AbstractVarInfo

* formatting

* use x to refer to the constrained space in invlink impl

* added immuatable link implementation for VarInfo

* added threadsafe versions of link and invlink

* added default implementations of link and invlink for DynamicTransformation

* formatting

* added tests for immutable link and invlink

* export link and invlink

* added link and invlink to docs

* fixed setall! for UntypedVarInfo

* added testing model demo_one_variable_multiple_constraints

* fixed BangBang.setindex!! for setting vector in vector

* added tests with unflatten + linking

* fixed reference to logabsdetjac in TestUtils

* improoved tests for unflatten + linking

* improved testing of unflatten + linking a bit

* added demo_lkjchol model to TestUtils

* formatting

* fixed impl of link for AbstractVarInfo

* epxanded comment on BangBang hack

* Apply suggestions from code review

Co-authored-by: Hong Ge <[email protected]>

* added references to BangBang issues and PRs talking about the
`possible` overloads

---------

Co-authored-by: Hong Ge <[email protected]>
  • Loading branch information
torfjelde and yebai authored Sep 1, 2023
1 parent 866eb6f commit ba16e3b
Show file tree
Hide file tree
Showing 10 changed files with 434 additions and 11 deletions.
2 changes: 2 additions & 0 deletions docs/src/api.md
Original file line number Diff line number Diff line change
Expand Up @@ -243,6 +243,8 @@ DynamicPPL.StaticTransformation
DynamicPPL.istrans
DynamicPPL.settrans!!
DynamicPPL.transformation
DynamicPPL.link
DynamicPPL.invlink
DynamicPPL.link!!
DynamicPPL.invlink!!
DynamicPPL.default_transformation
Expand Down
2 changes: 2 additions & 0 deletions src/DynamicPPL.jl
Original file line number Diff line number Diff line change
Expand Up @@ -60,8 +60,10 @@ export AbstractVarInfo,
updategid!,
setorder!,
istrans,
link,
link!,
link!!,
invlink,
invlink!,
invlink!!,
tonamedtuple,
Expand Down
43 changes: 41 additions & 2 deletions src/abstract_varinfo.jl
Original file line number Diff line number Diff line change
Expand Up @@ -368,7 +368,8 @@ function settrans!! end
link!!([t::AbstractTransformation, ]vi::AbstractVarInfo, model::Model)
link!!([t::AbstractTransformation, ]vi::AbstractVarInfo, spl::AbstractSampler, model::Model)
Transforms the variables in `vi` to their linked space, using the transformation `t`.
Transform the variables in `vi` to their linked space, using the transformation `t`,
mutating `vi` if possible.
If `t` is not provided, `default_transformation(model, vi)` will be used.
Expand All @@ -383,12 +384,31 @@ function link!!(vi::AbstractVarInfo, spl::AbstractSampler, model::Model)
return link!!(default_transformation(model, vi), vi, spl, model)
end

"""
link([t::AbstractTransformation, ]vi::AbstractVarInfo, model::Model)
link([t::AbstractTransformation, ]vi::AbstractVarInfo, spl::AbstractSampler, model::Model)
Transform the variables in `vi` to their linked space without mutating `vi`, using the transformation `t`.
If `t` is not provided, `default_transformation(model, vi)` will be used.
See also: [`default_transformation`](@ref), [`invlink`](@ref).
"""
link(vi::AbstractVarInfo, model::Model) = link(deepcopy(vi), SampleFromPrior(), model)
function link(t::AbstractTransformation, vi::AbstractVarInfo, model::Model)
return link(t, deepcopy(vi), SampleFromPrior(), model)
end
function link(vi::AbstractVarInfo, spl::AbstractSampler, model::Model)
# Use `default_transformation` to decide which transformation to use if none is specified.
return link(default_transformation(model, vi), deepcopy(vi), spl, model)
end

"""
invlink!!([t::AbstractTransformation, ]vi::AbstractVarInfo, model::Model)
invlink!!([t::AbstractTransformation, ]vi::AbstractVarInfo, spl::AbstractSampler, model::Model)
Transform the variables in `vi` to their constrained space, using the (inverse of)
transformation `t`.
transformation `t`, mutating `vi` if possible.
If `t` is not provided, `default_transformation(model, vi)` will be used.
Expand Down Expand Up @@ -434,6 +454,25 @@ function invlink!!(
return settrans!!(vi_new, NoTransformation())
end

"""
invlink([t::AbstractTransformation, ]vi::AbstractVarInfo, model::Model)
invlink([t::AbstractTransformation, ]vi::AbstractVarInfo, spl::AbstractSampler, model::Model)
Transform the variables in `vi` to their constrained space without mutating `vi`, using the (inverse of)
transformation `t`.
If `t` is not provided, `default_transformation(model, vi)` will be used.
See also: [`default_transformation`](@ref), [`link`](@ref).
"""
invlink(vi::AbstractVarInfo, model::Model) = invlink(vi, SampleFromPrior(), model)
function invlink(t::AbstractTransformation, vi::AbstractVarInfo, model::Model)
return invlink(t, vi, SampleFromPrior(), model)
end
function invlink(vi::AbstractVarInfo, spl::AbstractSampler, model::Model)
return invlink(transformation(vi), vi, spl, model)
end

"""
maybe_invlink_before_eval!!([t::Transformation,] vi, context, model)
Expand Down
109 changes: 109 additions & 0 deletions src/test_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,115 @@ function logprior_true_with_logabsdet_jacobian(
return (m=m, x=x_unconstrained), logprior_true(model, m, x) - Δlogp
end

"""
demo_one_variable_multiple_constraints()
A model with a single multivariate `x` whose components have multiple different constraints.
# Model
```julia
x[1] ~ Normal()
x[2] ~ InverseGamma(2, 3)
x[3] ~ truncated(Normal(), -5, 20)
x[4:5] ~ Dirichlet([1.0, 2.0])
```
"""
@model function demo_one_variable_multiple_constraints(
::Type{TV}=Vector{Float64}
) where {TV}
x = TV(undef, 5)
x[1] ~ Normal()
x[2] ~ InverseGamma(2, 3)
x[3] ~ truncated(Normal(), -5, 20)
x[4:5] ~ Dirichlet([1.0, 2.0])

return (x=x,)
end

function logprior_true(model::Model{typeof(demo_one_variable_multiple_constraints)}, x)
return (
logpdf(Normal(), x[1]) +
logpdf(InverseGamma(2, 3), x[2]) +
logpdf(truncated(Normal(), -5, 20), x[3]) +
logpdf(Dirichlet([1.0, 2.0]), x[4:5])
)
end
function loglikelihood_true(model::Model{typeof(demo_one_variable_multiple_constraints)}, x)
return zero(float(eltype(x)))
end
function varnames(model::Model{typeof(demo_one_variable_multiple_constraints)})
return [@varname(x[1]), @varname(x[2]), @varname(x[3]), @varname(x[4:5])]
end
function logprior_true_with_logabsdet_jacobian(
model::Model{typeof(demo_one_variable_multiple_constraints)}, x
)
b_x2 = Bijectors.bijector(InverseGamma(2, 3))
b_x3 = Bijectors.bijector(truncated(Normal(), -5, 20))
b_x4 = Bijectors.bijector(Dirichlet([1.0, 2.0]))
x_unconstrained = vcat(x[1], b_x2(x[2]), b_x3(x[3]), b_x4(x[4:5]))
Δlogp = (
Bijectors.logabsdetjac(b_x2, x[2]) +
Bijectors.logabsdetjac(b_x3, x[3]) +
Bijectors.logabsdetjac(b_x4, x[4:5])
)
return (x=x_unconstrained,), logprior_true(model, x) - Δlogp
end

function Random.rand(
rng::Random.AbstractRNG,
::Type{NamedTuple},
model::Model{typeof(demo_one_variable_multiple_constraints)},
)
x = Vector{Float64}(undef, 5)
x[1] = rand(rng, Normal())
x[2] = rand(rng, InverseGamma(2, 3))
x[3] = rand(rng, truncated(Normal(), -5, 20))
x[4:5] = rand(rng, Dirichlet([1.0, 2.0]))
return (x=x,)
end

"""
demo_lkjchol(d=2)
A model with a single variable `x` with support on the Cholesky factor of a
LKJ distribution.
# Model
```julia
x ~ LKJCholesky(d, 1.0)
```
"""
@model function demo_lkjchol(d::Int=2)
x ~ LKJCholesky(d, 1.0)
return (x=x,)
end

function logprior_true(model::Model{typeof(demo_lkjchol)}, x)
return logpdf(LKJCholesky(model.args.d, 1.0), x)
end

function loglikelihood_true(model::Model{typeof(demo_lkjchol)}, x)
return zero(float(eltype(x)))
end

function varnames(model::Model{typeof(demo_lkjchol)})
return [@varname(x)]
end

function logprior_true_with_logabsdet_jacobian(model::Model{typeof(demo_lkjchol)}, x)
b_x = Bijectors.bijector(LKJCholesky(model.args.d, 1.0))
x_unconstrained, Δlogp = Bijectors.with_logabsdet_jacobian(b_x, x)
return (x=x_unconstrained,), logprior_true(model, x) - Δlogp
end

function Random.rand(
rng::Random.AbstractRNG, ::Type{NamedTuple}, model::Model{typeof(demo_lkjchol)}
)
x = rand(rng, LKJCholesky(model.args.d, 1.0))
return (x=x,)
end

# A collection of models for which the posterior should be "similar".
# Some utility methods for these.
function _demo_logprior_true_with_logabsdet_jacobian(model, s, m)
Expand Down
12 changes: 12 additions & 0 deletions src/threadsafe.jl
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,18 @@ function invlink!!(
return invlink!!(t, vi.varinfo, spl, model)
end

function link(
t::AbstractTransformation, vi::ThreadSafeVarInfo, spl::AbstractSampler, model::Model
)
return link(t, vi.varinfo, spl, model)
end

function invlink(
t::AbstractTransformation, vi::ThreadSafeVarInfo, spl::AbstractSampler, model::Model
)
return invlink(t, vi.varinfo, spl, model)
end

function maybe_invlink_before_eval!!(
vi::ThreadSafeVarInfo, context::AbstractContext, model::Model
)
Expand Down
12 changes: 12 additions & 0 deletions src/transforming.jl
Original file line number Diff line number Diff line change
Expand Up @@ -94,3 +94,15 @@ function invlink!!(
NoTransformation(),
)
end

function link(
t::DynamicTransformation, vi::AbstractVarInfo, spl::AbstractSampler, model::Model
)
return link!!(t, deepcopy(vi), spl, model)
end

function invlink(
t::DynamicTransformation, vi::AbstractVarInfo, spl::AbstractSampler, model::Model
)
return invlink!!(t, deepcopy(vi), spl, model)
end
19 changes: 19 additions & 0 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -501,6 +501,8 @@ function splitlens(condition, lens)
return current_parent, current_child, condition(current_parent)
end

# HACK: All of these are related to https://github.com/JuliaFolds/BangBang.jl/issues/233
# and https://github.com/JuliaFolds/BangBang.jl/pull/238.
# HACK(torfjelde): Avoids type-instability in `dot_assume` for `SimpleVarInfo`.
function BangBang.possible(
::typeof(BangBang._setindex!), ::C, ::T, ::Colon, ::Integer
Expand All @@ -514,6 +516,23 @@ function BangBang.possible(
return BangBang.implements(setindex!, C) &&
promote_type(eltype(C), eltype(T)) <: eltype(C)
end
# HACK: Makes it possible to use ranges, etc. for setting a vector.
# For example, without this hack, BangBang.jl will consider
#
# x[1:2] = [1, 2]
#
# as NOT supported. This results is calling the immutable
# `BangBang.setindex` instead, which also ends up expanding the
# type of the containing array (`x` in the above scenario) to
# have element type `Any`.
# The below code just, correctly, marks this as possible and
# thus we hit the mutable `setindex!` instead.
function BangBang.possible(
::typeof(BangBang._setindex!), ::C, ::T, ::AbstractVector{<:Integer}
) where {C<:AbstractVector,T<:AbstractVector}
return BangBang.implements(setindex!, C) &&
promote_type(eltype(C), eltype(T)) <: eltype(C)
end

# HACK(torfjelde): This makes it so it works on iterators, etc. by default.
# TODO(torfjelde): Do better.
Expand Down
Loading

2 comments on commit ba16e3b

@torfjelde
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/90639

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.23.15 -m "<description of version>" ba16e3bc91e293c58b03ad287637472d6e11f52f
git push origin v0.23.15

Please sign in to comment.