Move predict from Turing #716

Merged on Dec 20, 2024 (19 commits). This view shows changes from 15 commits.

150 changes: 150 additions & 0 deletions ext/DynamicPPLMCMCChainsExt.jl
@@ -42,6 +42,156 @@
return keys(c.info.varname_to_symbol)
end

"""
predict([rng::AbstractRNG,] model::Model, chain::MCMCChains.Chains; include_all=false)

Sample from the posterior predictive distribution by executing `model` with parameters fixed to each sample
in `chain`, and return the resulting `Chains`.

If `include_all` is `false`, the returned `Chains` will contain only those variables that were not fixed by
the samples in `chain`. This is useful when you want to sample only new variables from the posterior
predictive distribution.

# Examples
```jldoctest
julia> using DynamicPPL, AbstractMCMC, AdvancedHMC, ForwardDiff;
Member: Same here: there is no need to use AdvancedHMC (or any of the other packages); just construct the `Chains` by hand.
The example also doesn't show that you need to import MCMCChains for `predict` to work, which would be worth showing.


julia> @model function linear_reg(x, y, σ = 0.1)
β ~ Normal(0, 1)
for i ∈ eachindex(y)
y[i] ~ Normal(β * x[i], σ)
end
end;

julia> σ = 0.1; f(x) = 2 * x + 0.1 * randn();

julia> Δ = 0.1; xs_train = 0:Δ:10; ys_train = f.(xs_train);

julia> xs_test = [10 + Δ, 10 + 2 * Δ]; ys_test = f.(xs_test);

julia> m_train = linear_reg(xs_train, ys_train, σ);

julia> n_train_logdensity_function = DynamicPPL.LogDensityFunction(m_train, DynamicPPL.VarInfo(m_train));

julia> chain_lin_reg = AbstractMCMC.sample(n_train_logdensity_function, NUTS(0.65), 200; chain_type=MCMCChains.Chains, param_names=[:β], discard_initial=100)
┌ Info: Found initial step size
└ ϵ = 0.003125

julia> m_test = linear_reg(xs_test, Vector{Union{Missing, Float64}}(undef, length(ys_test)), σ);

julia> predictions = predict(m_test, chain_lin_reg)
Object of type Chains, with data of type 100×2×1 Array{Float64,3}

Iterations = 1:100
Thinning interval = 1
Chains = 1
Samples per chain = 100
parameters = y[1], y[2]

2-element Array{ChainDataFrame,1}

Summary Statistics
parameters mean std naive_se mcse ess r_hat
────────── ─────── ────── ──────── ─────── ──────── ──────
y[1] 20.1974 0.1007 0.0101 missing 101.0711 0.9922
y[2] 20.3867 0.1062 0.0106 missing 101.4889 0.9903

Quantiles
parameters 2.5% 25.0% 50.0% 75.0% 97.5%
────────── ─────── ─────── ─────── ─────── ───────
y[1] 20.0342 20.1188 20.2135 20.2588 20.4188
y[2] 20.1870 20.3178 20.3839 20.4466 20.5895

julia> ys_pred = vec(mean(Array(group(predictions, :y)); dims = 1));

julia> sum(abs2, ys_test - ys_pred) ≤ 0.1
true
```
"""
function DynamicPPL.predict(
rng::DynamicPPL.Random.AbstractRNG,
model::DynamicPPL.Model,
chain::MCMCChains.Chains;
include_all=false,
)
parameter_only_chain = MCMCChains.get_sections(chain, :parameters)
prototypical_varinfo = DynamicPPL.VarInfo(model)

iters = Iterators.product(1:size(chain, 1), 1:size(chain, 3))
predictive_samples = map(iters) do (sample_idx, chain_idx)
varinfo = deepcopy(prototypical_varinfo)
DynamicPPL.setval_and_resample!(
varinfo, parameter_only_chain, sample_idx, chain_idx
)
model(rng, varinfo, DynamicPPL.SampleFromPrior())

vals = DynamicPPL.values_as_in_model(model, varinfo)
Member: This actually changes the behavior relative to Turing.jl's implementation: the result will also include variables defined in `:=` statements, which Turing.jl's `predict` currently does not include.

Member Author: [comment body not captured]

Member: Ooooh, nice catch; thanks! Hmm, I'm not sure this is the desired behavior though 😕

Member Author: I saw your issue on `:=` and totally understand the concern. But since we are not exporting `predict`, we can change this in the near future. We might also want to use `fix` in the future, in which case the behavior will be right.

We would need to make a minor release of Turing if we change this now.

Member: But isn't that the purpose of this PR, to move `predict` from Turing.jl to DynamicPPL.jl?

> also we might want to use fix in the future

Whether we use `fix` or not is just an internal implementation detail and is not relevant to how `predict` is used, right?

Member Author:

> But isn't this the purpose of this PR? To move the predict from Turing.jl to DynamicPPL.jl?

Ideally, I would want this PR to deliver a proper implementation of `predict` in DynamicPPL. For now, though, I am okay with the PR being only a first step towards that.

> Whether we're using fix or not is just an internal impl detail, and is not relevant for its usage, right?

What I was trying to say is that with `fix` it should have the right behavior (with regard to `:=`). Of course, that is not the only way to reach the desired behavior.

Member (@yebai, Dec 4, 2024): Improving it in a separate PR sounds good, but please create an issue to track @torfjelde's comment.

varname_vals = mapreduce(
collect,
vcat,
map(DynamicPPL.varname_and_value_leaves, keys(vals), values(vals)),
)

return (varname_and_values=varname_vals, logp=DynamicPPL.getlogp(varinfo))
end

chain_result = reduce(
MCMCChains.chainscat,
[
_predictive_samples_to_chains(predictive_samples[:, chain_idx]) for
chain_idx in 1:size(predictive_samples, 2)
],
)
parameter_names = if include_all
MCMCChains.names(chain_result, :parameters)

Codecov warning (codecov/patch): added line #L146 of ext/DynamicPPLMCMCChainsExt.jl was not covered by tests.
else
filter(
k -> !(k in MCMCChains.names(parameter_only_chain, :parameters)),
names(chain_result, :parameters),
)
end
return chain_result[parameter_names]
end

function _predictive_samples_to_arrays(predictive_samples)
variable_names_set = DynamicPPL.OrderedCollections.OrderedSet{DynamicPPL.VarName}()

sample_dicts = map(predictive_samples) do sample
varname_value_pairs = sample.varname_and_values
varnames = map(first, varname_value_pairs)
values = map(last, varname_value_pairs)
for varname in varnames
push!(variable_names_set, varname)
end

return DynamicPPL.OrderedCollections.OrderedDict(zip(varnames, values))
end

variable_names = collect(variable_names_set)
variable_values = [
get(sample_dicts[i], key, missing) for i in eachindex(sample_dicts),
key in variable_names
]

return variable_names, variable_values
end

function _predictive_samples_to_chains(predictive_samples)
variable_names, variable_values = _predictive_samples_to_arrays(predictive_samples)
variable_names_symbols = map(Symbol, variable_names)

internal_parameters = [:lp]
log_probabilities = reshape([sample.logp for sample in predictive_samples], :, 1)

parameter_names = [variable_names_symbols; internal_parameters]
parameter_values = hcat(variable_values, log_probabilities)
parameter_values = MCMCChains.concretize(parameter_values)

return MCMCChains.Chains(
parameter_values, parameter_names, (internals=internal_parameters,)
)
end

"""
generated_quantities(model::Model, chain::MCMCChains.Chains)


2 changes: 1 addition & 1 deletion src/DynamicPPL.jl
@@ -5,7 +5,7 @@ using AbstractPPL
using Bijectors
using Compat
using Distributions
-using OrderedCollections: OrderedDict
+using OrderedCollections: OrderedCollections, OrderedDict

using AbstractMCMC: AbstractMCMC
using ADTypes: ADTypes

16 changes: 16 additions & 0 deletions src/model.jl
@@ -1203,6 +1203,22 @@ function Distributions.loglikelihood(model::Model, chain::AbstractMCMC.AbstractC
end
end

"""
predict([rng::AbstractRNG,] model::Model, chain; include_all=false)

Sample from the posterior predictive distribution by executing `model` with parameters fixed to each sample
in `chain`.

If `include_all` is `false`, the returned `Chains` will contain only those variables that were not fixed by
the samples in `chain`. This is useful when you want to sample only new variables from the posterior
predictive distribution.
"""
function predict(model::Model, chain; include_all=false)
Member: In Turing.jl we currently overload `StatsBase.predict`, so we should probably do the same here, no?

Member Author: Agreed, but it is probably not the time yet. Definitely after TuringLang/AbstractPPL.jl#81 is merged 👍

Member: But is this PR held up until that PR is merged, then?

Member: Also, that PR doesn't really matter: overloading `StatsBase.predict` here and now just means that we'll immediately be compliant with the AbstractPPL.jl interface when that PR merges?

Member Author: Grey area. For me it is okay, because this PR only introduces a Turing-facing `predict`, not a user-facing one yet. At the moment `predict` is not public API.

Member: If nothing significant is missing in TuringLang/AbstractPPL.jl#81, let's merge it and overload `AbstractPPL.predict` here.

# this is only defined in `ext/DynamicPPLMCMCChainsExt.jl`
# TODO: add other methods for different type of `chain` arguments: e.g., `VarInfo`, `NamedTuple`, and `OrderedDict`
return predict(Random.default_rng(), model, chain; include_all)
end
Member: If so, we should definitely inform the user of this, no? Otherwise they'll just be like "oh, why is this not defined?"
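
One way to read the suggestion above is a catch-all method that explains which inputs are supported, so that callers do not hit a bare `MethodError`. The following is only a minimal sketch of that dispatch pattern in plain Julia. `ToyModel`, `ToyChain`, and `toy_predict` are hypothetical stand-ins and are not part of DynamicPPL or this PR.

```julia
# Hypothetical stand-ins; none of these names come from DynamicPPL.
struct ToyModel end

struct ToyChain
    draws::Vector{Float64}
end

# Specific method for the one supported container type
# (in the PR, the analogous method lives in the MCMCChains extension).
toy_predict(::ToyModel, chain::ToyChain) = 2 .* chain.draws

# Catch-all fallback: tell the caller what is supported instead of
# leaving them with an unexplained MethodError.
function toy_predict(::ToyModel, chain)
    throw(
        ArgumentError(
            "toy_predict is not implemented for input of type $(typeof(chain)); " *
            "supported types: ToyChain",
        ),
    )
end

supported = toy_predict(ToyModel(), ToyChain([1.0, 2.0]))

# An unsupported input now produces the explanatory ArgumentError.
err = try
    toy_predict(ToyModel(), (a=1,))
    nothing
catch e
    e
end
```

Because Julia picks the most specific method, the fallback only fires for types that have no dedicated method, so a package extension can add support for a new chain type without touching the fallback.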

Member Author: I don't think we want to export `predict` right now, so `predict` is only available through Turing.jl, give or take.

Would a "function not defined" error be meaningful enough if the user passes other types of input?

Member: If Turing exports it, it's better for DynamicPPL to export it, too.

Member Author: I agree; I was proposing to delay this until a good `predict` spec is reached.


"""
generated_quantities(model::Model, parameters::NamedTuple)
generated_quantities(model::Model, values, keys)

2 changes: 2 additions & 0 deletions test/Project.toml
@@ -2,6 +2,7 @@
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
AbstractPPL = "7a57a42e-76ec-4ea3-a279-07e840d6d9cf"
AdvancedHMC = "0bf59076-c3b1-5ca4-86bd-e02cd72cde3d"
Member: Hmm, this doesn't quite seem worth it just to test `predict`, no? What's the reasoning here?

Member Author: I didn't add anything or change the implementation in this PR.

I agree AHMC is a heavy dependency, but tests like https://github.com/TuringLang/DynamicPPL.jl/blob/fd1277b7201477448d3257cab65557b850bcf5b4/test/ext/DynamicPPLMCMCChainsExt.jl#L48C1-L55C45 rely on the quality of the samples.

Member: Sure, but we should just replace them with samples from the prior or something. This is just checking that the statistics are correct; it doesn't matter whether these statistics come from the prior or the posterior 🤷

Member Author: Actually, would it be really bad to make AdvancedHMC a test dependency of DynamicPPL? (Again, I don't like this either, but it's not too bad. I would rather open an issue for removing this dependency later than tamper with this PR any more.)

Member: I can't look at this PR properly until Wednesday, but in https://github.com/TuringLang/DynamicPPL.jl/pull/733/files#diff-3981168ff1709b3f48c35e40f491c26d9b91fc29373e512f1272f3b928cea6c0 I wrote a function that generates a chain by sampling from the prior. (It's called `make_chain_from_prior`, in case the link doesn't bring you to the right place.)
Feel free to take it if you think it's useful :)

Member: @sunxd3, @penelopeysm the posterior of Bayesian linear regression can be obtained in closed form (i.e. it is a Gaussian, see here). I suggest that we:

  1. add this BLR model to the DynamicPPL test models
  2. implement its analytical posterior
  3. sample from the analytical posterior directly and drop the AHMC dependency.
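
The three steps above can be sketched for the `linear_reg` model used in the tests (`β ~ Normal(0, 1)`, `y[i] ~ Normal(β * x[i], σ)` with known `σ`). For this conjugate setup the posterior of `β` is Gaussian with precision `1 + Σ x[i]^2 / σ^2` and mean `(Σ x[i] * y[i] / σ^2) / precision`. The helper name `blr_posterior` and the seed are assumptions made for illustration; this is not code from the PR.

```julia
using Random

# Closed-form posterior for β ~ Normal(0, 1), y[i] ~ Normal(β * x[i], σ), σ known.
# Returns the posterior mean and standard deviation of β.
# `blr_posterior` is a hypothetical helper name, not a DynamicPPL function.
function blr_posterior(x, y, σ)
    post_precision = 1 + sum(abs2, x) / σ^2          # prior precision 1 + data precision
    post_mean = (sum(x .* y) / σ^2) / post_precision # precision-weighted mean (prior mean is 0)
    return post_mean, 1 / sqrt(post_precision)
end

rng = MersenneTwister(100)
σ = 0.1
xs_train = 0:0.1:10
# Same data-generating process as the test: y = 2x + 0.1 * noise.
ys_train = 2 .* xs_train .+ σ .* randn(rng, length(xs_train))

μ_post, sd_post = blr_posterior(xs_train, ys_train, σ)

# Exact posterior draws of β, with no MCMC sampler involved.
β_draws = μ_post .+ sd_post .* randn(rng, 1000)
```

Draws like `β_draws` could then be wrapped in a `Chains` object and passed to `predict`, which would remove the need for AdvancedHMC in the test, provided the parameter name in the chain matches the model's `β`.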

Member: Though the closed-form posterior is a good idea, there's really no need to run this test on posterior samples :) These were just some stats that were picked to have something to compare against; a prior chain is the way to go, I think 👍
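
A prior chain of the kind discussed here can be built without any sampler. The sketch below assumes MCMCChains is installed and uses its array-plus-names `Chains` constructor; the variable names and the seed are illustrative, and this is not the `make_chain_from_prior` function mentioned earlier in the thread.

```julia
using MCMCChains, Random

rng = MersenneTwister(42)

# Draw β directly from its prior, Normal(0, 1), instead of running NUTS.
n_samples, n_chains = 1000, 2
β_prior = randn(rng, n_samples, 1, n_chains)  # iterations × parameters × chains

# Hand-built chain with a single parameter named β.
prior_chain = MCMCChains.Chains(β_prior, [:β])
```

Under the assumption that the parameter name matches the model, such a chain could be passed to `DynamicPPL.predict(m_lin_reg_test, prior_chain)`. The test would then compare the predictive statistics against what the prior draws imply rather than against the held-out `ys_test`.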

Member Author: A prior chain makes sense. Should we generate samples from the prior, take out the samples of a particular variable, and try to predict it?

Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697"
Bijectors = "76274a88-744f-5084-9051-94815aaf08c4"
Combinatorics = "861a8166-3701-5b0c-9a16-15d98fcdc6aa"
@@ -32,6 +33,7 @@ AbstractMCMC = "5"
AbstractPPL = "0.8.4, 0.9"
Accessors = "0.1"
Bijectors = "0.13.9, 0.14, 0.15"
AdvancedHMC = "0.3.0, 0.4.0, 0.5.2, 0.6"
Combinatorics = "1"
Compat = "4.3.0"
Distributions = "0.25"

167 changes: 167 additions & 0 deletions test/ext/DynamicPPLMCMCChainsExt.jl
@@ -7,3 +7,170 @@
@test size(chain_generated) == (1000, 1)
@test mean(chain_generated) ≈ 0 atol = 0.1
end

@testset "predict" begin
DynamicPPL.Random.seed!(100)

@model function linear_reg(x, y, σ=0.1)
β ~ Normal(0, 1)

for i in eachindex(y)
y[i] ~ Normal(β * x[i], σ)
end
end

@model function linear_reg_vec(x, y, σ=0.1)
β ~ Normal(0, 1)
return y ~ MvNormal(β .* x, σ^2 * I)
end

f(x) = 2 * x + 0.1 * randn()

Δ = 0.1
xs_train = 0:Δ:10
ys_train = f.(xs_train)
xs_test = [10 + Δ, 10 + 2 * Δ]
ys_test = f.(xs_test)

# Infer
m_lin_reg = linear_reg(xs_train, ys_train)
chain_lin_reg = sample(
DynamicPPL.LogDensityFunction(m_lin_reg),
AdvancedHMC.NUTS(0.65),
Member: It really doesn't seem necessary to use NUTS here. Just construct a `Chains` by hand or something, no?

Member Author: Same reason as above: some tests rely on the quality of the samples.

1000;
chain_type=MCMCChains.Chains,
param_names=[:β],
discard_initial=100,
n_adapt=100,
)

# Predict on two last indices
m_lin_reg_test = linear_reg(xs_test, fill(missing, length(ys_test)))
predictions = DynamicPPL.predict(m_lin_reg_test, chain_lin_reg)

ys_pred = vec(mean(Array(group(predictions, :y)); dims=1))

# test like this depends on the variance of the posterior
# this only makes sense if the posterior variance is about 0.002
@test sum(abs2, ys_test - ys_pred) ≤ 0.1

# Ensure that `rng` is respected
predictions1 = let rng = MersenneTwister(42)
DynamicPPL.predict(rng, m_lin_reg_test, chain_lin_reg[1:2])
end
predictions2 = let rng = MersenneTwister(42)
DynamicPPL.predict(rng, m_lin_reg_test, chain_lin_reg[1:2])
end
@test all(Array(predictions1) .== Array(predictions2))

# Predict on two last indices for vectorized
m_lin_reg_test = linear_reg_vec(xs_test, missing)
predictions_vec = DynamicPPL.predict(m_lin_reg_test, chain_lin_reg)
ys_pred_vec = vec(mean(Array(group(predictions_vec, :y)); dims=1))

@test sum(abs2, ys_test - ys_pred_vec) ≤ 0.1

# Multiple chains
chain_lin_reg = sample(
DynamicPPL.LogDensityFunction(m_lin_reg, DynamicPPL.VarInfo(m_lin_reg)),
AdvancedHMC.NUTS(0.65),
MCMCThreads(),
1000,
2;
chain_type=MCMCChains.Chains,
param_names=[:β],
discard_initial=100,
n_adapt=100,
)
m_lin_reg_test = linear_reg(xs_test, fill(missing, length(ys_test)))
predictions = DynamicPPL.predict(m_lin_reg_test, chain_lin_reg)

@test size(chain_lin_reg, 3) == size(predictions, 3)

for chain_idx in MCMCChains.chains(chain_lin_reg)
ys_pred = vec(mean(Array(group(predictions[:, :, chain_idx], :y)); dims=1))
@test sum(abs2, ys_test - ys_pred) ≤ 0.1
end

# Predict on two last indices for vectorized
m_lin_reg_test = linear_reg_vec(xs_test, missing)
predictions_vec = DynamicPPL.predict(m_lin_reg_test, chain_lin_reg)

for chain_idx in MCMCChains.chains(chain_lin_reg)
ys_pred_vec = vec(mean(Array(group(predictions_vec[:, :, chain_idx], :y)); dims=1))
@test sum(abs2, ys_test - ys_pred_vec) ≤ 0.1
end

# https://github.com/TuringLang/Turing.jl/issues/1352
@model function simple_linear1(x, y)
intercept ~ Normal(0, 1)
coef ~ MvNormal(zeros(2), I)
coef = reshape(coef, 1, size(x, 1))

mu = vec(intercept .+ coef * x)
error ~ truncated(Normal(0, 1), 0, Inf)
return y ~ MvNormal(mu, error^2 * I)
end

@model function simple_linear2(x, y)
intercept ~ Normal(0, 1)
coef ~ filldist(Normal(0, 1), 2)
coef = reshape(coef, 1, size(x, 1))

mu = vec(intercept .+ coef * x)
error ~ truncated(Normal(0, 1), 0, Inf)
return y ~ MvNormal(mu, error^2 * I)
end

@model function simple_linear3(x, y)
intercept ~ Normal(0, 1)
coef = Vector(undef, 2)
for i in axes(coef, 1)
coef[i] ~ Normal(0, 1)
end
coef = reshape(coef, 1, size(x, 1))

mu = vec(intercept .+ coef * x)
error ~ truncated(Normal(0, 1), 0, Inf)
return y ~ MvNormal(mu, error^2 * I)
end

@model function simple_linear4(x, y)
intercept ~ Normal(0, 1)
coef1 ~ Normal(0, 1)
coef2 ~ Normal(0, 1)
coef = [coef1, coef2]
coef = reshape(coef, 1, size(x, 1))

mu = vec(intercept .+ coef * x)
error ~ truncated(Normal(0, 1), 0, Inf)
return y ~ MvNormal(mu, error^2 * I)
end

x = randn(2, 100)
y = [1 + 2 * a + 3 * b for (a, b) in eachcol(x)]

param_names = Dict(
simple_linear1 => [:intercept, Symbol("coef[1]"), Symbol("coef[2]"), :error],
simple_linear2 => [:intercept, Symbol("coef[1]"), Symbol("coef[2]"), :error],
simple_linear3 => [:intercept, Symbol("coef[1]"), Symbol("coef[2]"), :error],
simple_linear4 => [:intercept, :coef1, :coef2, :error],
)
@testset "$model" for model in
[simple_linear1, simple_linear2, simple_linear3, simple_linear4]
m = model(x, y)
chain = sample(
DynamicPPL.LogDensityFunction(m),
AdvancedHMC.NUTS(0.65),
400;
initial_params=rand(4),
chain_type=MCMCChains.Chains,
param_names=param_names[model],
discard_initial=100,
n_adapt=100,
)
chain_predict = DynamicPPL.predict(model(x, missing), chain)
mean_prediction = [mean(chain_predict["y[$i]"].data) for i in 1:length(y)]
@test mean(abs2, mean_prediction - y) ≤ 1e-3
end
end
1 change: 1 addition & 0 deletions test/runtests.jl
@@ -1,5 +1,6 @@
using Accessors
using ADTypes
using AdvancedHMC: AdvancedHMC
using DynamicPPL
using AbstractMCMC
using AbstractPPL