Skip to content

Commit

Permalink
CompatHelper: bump compat for AbstractMCMC to 5, (keep existing compa…
Browse files Browse the repository at this point in the history
…t) (#551)

* CompatHelper: bump compat for AbstractMCMC to 5, (keep existing compat)

* CompatHelper: bump compat for AbstractMCMC to 5 for package test, (keep existing compat) (#552)

Co-authored-by: CompatHelper Julia <[email protected]>

* Update to AbstractMCMC 5

* Apply suggestions from code review

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

* Update sampler.jl

* CompatHelper: bump compat for AbstractMCMC to 5 for package test, (keep existing compat) (#553)

* Fix for `rand` + replace overloads of `rand` with `rand_prior_true` for testing models (#541)

* preserve context from model in `rand`

* replace rand overloads in TestUtils with definitions of
rand_prior_true so we can properly test rand

* removed NamedTuple from signature of TestUtils.rand_prior_true

* updated references to previous overloads of rand to now use rand_prior_true

* test rand for DEMO_MODELS

* formatting

* fixed tests for rand for DEMO_MODELS

* fixed linkning tests

* added missing impl of rand_prior_true for demo_static_transformation

* formatting

* fixed rand_prior_true for demo_static_transformation

* bump minor version as this will be breaking

* bump patch version

* fixed old usage of rand

* Update test/varinfo.jl

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

* fixed another usage of rand

---------

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

* Remove `tonamedtuple` (#547)

* Remove dependencies to `tonamedtuple`

* Remove `tonamedtuple`s

* Minor version bump

---------

Co-authored-by: Hong Ge <[email protected]>

* CompatHelper: bump compat for AbstractMCMC to 5 for package test, (keep existing compat)

---------

Co-authored-by: Tor Erlend Fjelde <[email protected]>
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: Xianda Sun <[email protected]>
Co-authored-by: Hong Ge <[email protected]>
Co-authored-by: CompatHelper Julia <[email protected]>

* bump AbstractPPL version to 0.7

* Update AbstractPPL test dependency

* add `Random.AbstractRNG`

* Update sampler.jl (#557)

* Update src/sampler.jl

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

---------

Co-authored-by: CompatHelper Julia <[email protected]>
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: David Widmann <[email protected]>
Co-authored-by: Tor Erlend Fjelde <[email protected]>
Co-authored-by: Xianda Sun <[email protected]>
Co-authored-by: Hong Ge <[email protected]>
Co-authored-by: Xianda Sun <[email protected]>
  • Loading branch information
7 people authored Nov 9, 2023
1 parent 2e940aa commit bda441b
Show file tree
Hide file tree
Showing 5 changed files with 63 additions and 37 deletions.
14 changes: 7 additions & 7 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -29,21 +29,21 @@ MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d"
DynamicPPLMCMCChainsExt = ["MCMCChains"]

[compat]
AbstractMCMC = "2, 3.0, 4"
AbstractPPL = "0.6"
AbstractMCMC = "5"
AbstractPPL = "0.7"
BangBang = "0.3"
Bijectors = "0.13"
ChainRulesCore = "0.9.7, 0.10, 1"
ConstructionBase = "1.5.4"
ChainRulesCore = "1"
Compat = "4"
Distributions = "0.23.8, 0.24, 0.25"
DocStringExtensions = "0.8, 0.9"
ConstructionBase = "1.5.4"
Distributions = "0.25"
DocStringExtensions = "0.9"
LogDensityProblems = "2"
MCMCChains = "6"
MacroTools = "0.5.6"
OrderedCollections = "1"
Requires = "1"
Setfield = "0.7.1, 0.8, 1"
Setfield = "1"
ZygoteRules = "0.2"
julia = "1.6"

Expand Down
12 changes: 12 additions & 0 deletions ext/DynamicPPLMCMCChainsExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,18 @@ function _check_varname_indexing(c::MCMCChains.Chains)
error("Chains do not support indexing using $vn.")
end

# Load state from a `Chains`: By convention, it is stored in `:samplerstate` metadata
function DynamicPPL.loadstate(chain::MCMCChains.Chains)
if !haskey(chain.info, :samplerstate)
throw(
ArgumentError(
"The chain object does not contain the final state of the sampler: Metadata `:samplerstate` missing.",
),
)
end
return chain.info[:samplerstate]
end

# A few methods needed.
function DynamicPPL.supports_varname_indexing(chain::MCMCChains.Chains)
return _has_varname_to_symbol(chain.info)
Expand Down
44 changes: 29 additions & 15 deletions src/sampler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -80,26 +80,31 @@ function default_varinfo(
return VarInfo(rng, model, init_sampler, context)
end

# initial step: general interface for resuming and
function AbstractMCMC.step(
function AbstractMCMC.sample(
rng::Random.AbstractRNG,
model::Model,
spl::Sampler;
sampler::Sampler,
N::Integer;
chain_type=default_chain_type(sampler),
resume_from=nothing,
init_params=nothing,
initial_state=loadstate(resume_from),
kwargs...,
)
if resume_from !== nothing
state = loadstate(resume_from)
return AbstractMCMC.step(rng, model, spl, state; kwargs...)
end
return AbstractMCMC.mcmcsample(
rng, model, sampler, N; chain_type, initial_state, kwargs...
)
end

# initial step: general interface for resuming and
function AbstractMCMC.step(
rng::Random.AbstractRNG, model::Model, spl::Sampler; initial_params=nothing, kwargs...
)
# Sample initial values.
vi = default_varinfo(rng, model, spl)

# Update the parameters if provided.
if init_params !== nothing
vi = initialize_parameters!!(vi, init_params, spl, model)
if initial_params !== nothing
vi = initialize_parameters!!(vi, initial_params, spl, model)

# Update joint log probability.
# This is a quick fix for https://github.com/TuringLang/Turing.jl/issues/1588
Expand All @@ -108,15 +113,24 @@ function AbstractMCMC.step(
vi = last(evaluate!!(model, vi, DefaultContext()))
end

return initialstep(rng, model, spl, vi; init_params=init_params, kwargs...)
return initialstep(rng, model, spl, vi; initial_params, kwargs...)
end

"""
loadstate(data)
Load sampler state from `data`.
By default, `data` is returned.
"""
loadstate(data) = data

"""
default_chaintype(sampler)
Default type of the chain of posterior samples from `sampler`.
"""
function loadstate end
default_chain_type(sampler::Sampler) = Any

"""
initialsampler(sampler::Sampler)
Expand All @@ -129,12 +143,12 @@ By default, it returns an instance of [`SampleFromPrior`](@ref).
initialsampler(spl::Sampler) = SampleFromPrior()

function initialize_parameters!!(
vi::AbstractVarInfo, init_params, spl::Sampler, model::Model
vi::AbstractVarInfo, initial_params, spl::Sampler, model::Model
)
@debug "Using passed-in initial variable values" init_params
@debug "Using passed-in initial variable values" initial_params

# Flatten parameters.
init_theta = mapreduce(vcat, init_params) do x
init_theta = mapreduce(vcat, initial_params) do x
vec([x;])
end

Expand Down
12 changes: 6 additions & 6 deletions test/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -22,19 +22,19 @@ Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[compat]
AbstractMCMC = "2.1, 3.0, 4"
AbstractPPL = "0.6"
AbstractMCMC = "5"
AbstractPPL = "0.7"
Bijectors = "0.13"
Compat = "4.3.0"
Distributions = "0.25"
DistributionsAD = "0.6.3"
Documenter = "0.26.1, 0.27, 1"
Documenter = "1"
ForwardDiff = "0.10.12"
LogDensityProblems = "2"
MCMCChains = "4.0.4, 5, 6"
MCMCChains = "6.0.4"
MacroTools = "0.5.5"
Setfield = "0.7.1, 0.8, 1"
Setfield = "1"
StableRNGs = "1"
Tracker = "0.2.23"
Zygote = "0.5.4, 0.6"
Zygote = "0.6"
julia = "1.6"
18 changes: 9 additions & 9 deletions test/sampler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@
model = coinflip()
sampler = Sampler(alg)
lptrue = logpdf(Binomial(25, 0.2), 10)
chain = sample(model, sampler, 1; init_params=0.2, progress=false)
chain = sample(model, sampler, 1; initial_params=0.2, progress=false)
@test chain[1].metadata.p.vals == [0.2]
@test getlogp(chain[1]) == lptrue

Expand All @@ -95,7 +95,7 @@
MCMCThreads(),
1,
10;
init_params=fill(0.2, 10),
initial_params=fill(0.2, 10),
progress=false,
)
for c in chains
Expand All @@ -110,7 +110,7 @@
end
model = twovars()
lptrue = logpdf(InverseGamma(2, 3), 4) + logpdf(Normal(0, 2), -1)
chain = sample(model, sampler, 1; init_params=[4, -1], progress=false)
chain = sample(model, sampler, 1; initial_params=[4, -1], progress=false)
@test chain[1].metadata.s.vals == [4]
@test chain[1].metadata.m.vals == [-1]
@test getlogp(chain[1]) == lptrue
Expand All @@ -122,7 +122,7 @@
MCMCThreads(),
1,
10;
init_params=fill([4, -1], 10),
initial_params=fill([4, -1], 10),
progress=false,
)
for c in chains
Expand All @@ -132,7 +132,7 @@
end

# set only m = -1
chain = sample(model, sampler, 1; init_params=[missing, -1], progress=false)
chain = sample(model, sampler, 1; initial_params=[missing, -1], progress=false)
@test !ismissing(chain[1].metadata.s.vals[1])
@test chain[1].metadata.m.vals == [-1]

Expand All @@ -143,19 +143,19 @@
MCMCThreads(),
1,
10;
init_params=fill([missing, -1], 10),
initial_params=fill([missing, -1], 10),
progress=false,
)
for c in chains
@test !ismissing(c[1].metadata.s.vals[1])
@test c[1].metadata.m.vals == [-1]
end

# specify `init_params=nothing`
# specify `initial_params=nothing`
Random.seed!(1234)
chain1 = sample(model, sampler, 1; progress=false)
Random.seed!(1234)
chain2 = sample(model, sampler, 1; init_params=nothing, progress=false)
chain2 = sample(model, sampler, 1; initial_params=nothing, progress=false)
@test chain1[1].metadata.m.vals == chain2[1].metadata.m.vals
@test chain1[1].metadata.s.vals == chain2[1].metadata.s.vals

Expand All @@ -164,7 +164,7 @@
chains1 = sample(model, sampler, MCMCThreads(), 1, 10; progress=false)
Random.seed!(1234)
chains2 = sample(
model, sampler, MCMCThreads(), 1, 10; init_params=nothing, progress=false
model, sampler, MCMCThreads(), 1, 10; initial_params=nothing, progress=false
)
for (c1, c2) in zip(chains1, chains2)
@test c1[1].metadata.m.vals == c2[1].metadata.m.vals
Expand Down

0 comments on commit bda441b

Please sign in to comment.