Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix for generated_quantities #534

Merged
merged 31 commits into from
Sep 9, 2023
Merged
Changes from 1 commit
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
9735066
added method for extracting the child lens from a varname subsumed by
torfjelde Sep 5, 2023
306946d
added nested_getindex and nested_setindex! for VarInfo
torfjelde Sep 5, 2023
af2a30f
added ConstructionBase.setproperties implementation for `Cholesky`
torfjelde Sep 5, 2023
ce0b2e7
fixed minor formatting issue
torfjelde Sep 5, 2023
3a8201c
added `supports_varname_indexing` for chains and use this in generate…
torfjelde Sep 6, 2023
dc7c675
use a private method rather than overloading getindex for Chains
torfjelde Sep 6, 2023
e4d964e
removed getindex overloads in nested_index testing
torfjelde Sep 6, 2023
a18b435
moved generated_quantities tests to test/model.jl
torfjelde Sep 6, 2023
34e422a
Apply suggestions from code review
torfjelde Sep 6, 2023
c4b2556
will now also correctly set variables to be resampled, etc.
torfjelde Sep 6, 2023
3b27435
Update test/model.jl
torfjelde Sep 6, 2023
ef67495
Update src/varinfo.jl
torfjelde Sep 6, 2023
df2d8e3
added Compat as a test dep so we can methods such as stack
torfjelde Sep 6, 2023
035d592
improved overload of ConstructionBase.setproperties
torfjelde Sep 6, 2023
24076b5
Apply suggestions from code review
torfjelde Sep 6, 2023
c604503
added docstring to remove_parent_lens
torfjelde Sep 6, 2023
ad4a5bd
removed methods which are not useful for the purpose of this PR
torfjelde Sep 6, 2023
378897e
noticed we're incorrectly using chain rather than chain_params in gen…
torfjelde Sep 6, 2023
220646e
Update ext/DynamicPPLMCMCChainsExt.jl
torfjelde Sep 6, 2023
08dd71a
fixed doctests
torfjelde Sep 6, 2023
c09b780
added Requires.jl
torfjelde Sep 6, 2023
44335d4
Update src/DynamicPPL.jl
torfjelde Sep 6, 2023
521775f
bump patch version
torfjelde Sep 6, 2023
3594017
Merge remote-tracking branch 'origin/torfjelde/nested-get-and-setindx…
torfjelde Sep 6, 2023
2ae113c
Update src/DynamicPPL.jl
torfjelde Sep 7, 2023
e625fc5
moved new generated_quantities functionality into setval_and_resample!
torfjelde Sep 7, 2023
6d16806
Apply suggestions from code review
torfjelde Sep 7, 2023
1e42770
Update ext/DynamicPPLMCMCChainsExt.jl
torfjelde Sep 7, 2023
628809c
Update src/chains.jl
torfjelde Sep 7, 2023
8b9e3e0
bump compat entry for ConstructionBase.jl
torfjelde Sep 8, 2023
46e5a94
Merge remote-tracking branch 'origin/torfjelde/nested-get-and-setindx…
torfjelde Sep 8, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
added Requires.jl
torfjelde committed Sep 6, 2023
commit c09b7802d053429b62848b0ec1b1134fa2ec2cf1
22 changes: 12 additions & 10 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -16,19 +16,11 @@ LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c"
MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
OrderedCollections = "bac558e1-5e72-5ebc-8fee-abe8a469f55d"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"

[weakdeps]
MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d"

[extensions]
DynamicPPLMCMCChainsExt = ["MCMCChains"]

[extras]
MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d"

[compat]
AbstractMCMC = "2, 3.0, 4"
AbstractPPL = "0.6"
@@ -39,9 +31,19 @@ ConstructionBase = "1"
Distributions = "0.23.8, 0.24, 0.25"
DocStringExtensions = "0.8, 0.9"
LogDensityProblems = "2"
MacroTools = "0.5.6"
MCMCChains = "6"
MacroTools = "0.5.6"
OrderedCollections = "1"
Requires = "1"
Setfield = "0.7.1, 0.8, 1"
ZygoteRules = "0.2"
julia = "1.6"

[extensions]
DynamicPPLMCMCChainsExt = ["MCMCChains"]

[extras]
MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d"

[weakdeps]
MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d"
9 changes: 7 additions & 2 deletions ext/DynamicPPLMCMCChainsExt.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,12 @@
module DynamicPPLMCMCChainsExt

using DynamicPPL: DynamicPPL
using MCMCChains: MCMCChains
if isdefined(Base, :get_extension)
using DynamicPPL: DynamicPPL
using MCMCChains: MCMCChains
else
using ..DynamicPPL: DynamicPPL
using ..MCMCChains: MCMCChains
end

_has_varname_to_symbol(info::NamedTuple{names}) where {names} = :varname_to_symbol in names
function DynamicPPL.supports_varname_indexing(chain::MCMCChains.Chains)
10 changes: 10 additions & 0 deletions src/DynamicPPL.jl
Original file line number Diff line number Diff line change
@@ -175,4 +175,14 @@ include("logdensityfunction.jl")
include("model_utils.jl")
include("extract_priors.jl")

if !isdefined(Base, :get_extension)
using Requires
end

function __init__()
@static if !isdefined(Base, :get_extension)
torfjelde marked this conversation as resolved.
Show resolved Hide resolved
@require MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d" include("../ext/DynamicPPLMCMCChainsExt.jl")
torfjelde marked this conversation as resolved.
Show resolved Hide resolved
end
end

end # module

Unchanged files with check annotations Beta

end
# HACK: Be better.
supports_varname_indexing(chain::AbstractChains) = false

Check warning on line 1261 in src/model.jl

Codecov / codecov/patch

src/model.jl#L1261

Added line #L1261 was not covered by tests
"""
generated_quantities(model::Model, parameters::NamedTuple)
end
# TODO: Remove as soon as https://github.com/JuliaObjects/ConstructionBase.jl/pull/80 goes through.
ConstructionBase.setproperties(C::LinearAlgebra.Cholesky, ::NamedTuple{()}) = C

Check warning on line 1086 in src/utils.jl

Codecov / codecov/patch

src/utils.jl#L1086

Added line #L1086 was not covered by tests
function ConstructionBase.setproperties(C::LinearAlgebra.Cholesky, patch::NamedTuple{(:L,)})
return LinearAlgebra.Cholesky(
C.uplo === 'U' ? permutedims(patch.L) : patch.L, C.uplo, C.info
)
end
function ConstructionBase.setproperties(C::LinearAlgebra.Cholesky, patch::NamedTuple{(:U,)})
return LinearAlgebra.Cholesky(

Check warning on line 1093 in src/utils.jl

Codecov / codecov/patch

src/utils.jl#L1092-L1093

Added lines #L1092 - L1093 were not covered by tests
C.uplo === 'L' ? permutedims(patch.U) : patch.U, C.uplo, C.info
)
end
function ConstructionBase.setproperties(

Check warning on line 1097 in src/utils.jl

Codecov / codecov/patch

src/utils.jl#L1097

Added line #L1097 was not covered by tests
C::LinearAlgebra.Cholesky, patch::NamedTuple{(:UL,)}
)
return LinearAlgebra.Cholesky(patch.UL, C.uplo, C.info)

Check warning on line 1100 in src/utils.jl

Codecov / codecov/patch

src/utils.jl#L1100

Added line #L1100 was not covered by tests
end
@nospecialize function ConstructionBase.setproperties(

Check warning on line 1102 in src/utils.jl

Codecov / codecov/patch

src/utils.jl#L1102

Added line #L1102 was not covered by tests
C::LinearAlgebra.Cholesky, patch::NamedTuple
)
return error("Can only patch one of :L, :U, :UL at the time")

Check warning on line 1105 in src/utils.jl

Codecov / codecov/patch

src/utils.jl#L1105

Added line #L1105 was not covered by tests
end
return Expr(:||, false, out...)
end
function nested_setindex_maybe!(vi::UntypedVarInfo, val, vn::VarName)
return _nested_setindex_maybe!(vi, getmetadata(vi, vn), val, vn)

Check warning on line 1068 in src/varinfo.jl

Codecov / codecov/patch

src/varinfo.jl#L1067-L1068

Added lines #L1067 - L1068 were not covered by tests
end
function nested_setindex_maybe!(
vi::VarInfo{<:NamedTuple{names}}, val, vn::VarName{sym}
# If `vn` is in `vns`, then we can just use the standard `setindex!`.
vns = md.vns
if vn in vns
setindex!(vi, val, vn)
return vn

Check warning on line 1084 in src/varinfo.jl

Codecov / codecov/patch

src/varinfo.jl#L1083-L1084

Added lines #L1083 - L1084 were not covered by tests
end
# Otherwise, we need to check if either of the `vns` subsumes `vn`.