Skip to content

Commit

Permalink
more fixes to link and invlink
Browse files Browse the repository at this point in the history
  • Loading branch information
torfjelde committed Oct 7, 2023
1 parent 4de2a01 commit 1e4d9f1
Showing 1 changed file with 18 additions and 17 deletions.
35 changes: 18 additions & 17 deletions src/varinfo.jl
Original file line number Diff line number Diff line change
Expand Up @@ -909,30 +909,30 @@ end
function _link(varinfo::UntypedVarInfo, spl::AbstractSampler)
varinfo = deepcopy(varinfo)
return VarInfo(
_link_metadata!(varinfo, varinfo.metadata, Val(getspace(spl))),
_link_metadata!(varinfo, varinfo.metadata, _getvns(spl)),
Base.Ref(getlogp(varinfo)),
Ref(get_num_produce(varinfo)),
)
end

function _link(varinfo::TypedVarInfo, spl::AbstractSampler)
varinfo = deepcopy(varinfo)
md = _link_metadata!(varinfo, varinfo.metadata, Val(getspace(spl)))
# TODO: Update logp, etc.
md = _link_metadata_namedtuple!(varinfo, varinfo.metadata, _getvns(varinfo, spl), Val(getspace(spl)))
return VarInfo(md, Base.Ref(getlogp(varinfo)), Ref(get_num_produce(varinfo)))
end

@generated function _link_metadata!(
@generated function _link_metadata_namedtuple!(
varinfo::VarInfo,
metadata::NamedTuple{names},
vns::NamedTuple,
::Val{space}
) where {names,space}
vals = Expr(:tuple)
for f in names
if inspace(f, space) || length(space) == 0
push!(
expr.args,
:(_link_metadata!(varinfo, metadata.$f))
vals.args,
:(_link_metadata!(varinfo, metadata.$f, vns.$f))
)
else
push!(vals.args, :(metadata.$f))
Expand All @@ -941,13 +941,13 @@ end

return :(NamedTuple{$names}($vals))
end
function _link_metadata!(varinfo::VarInfo, metadata::Metadata)
function _link_metadata!(varinfo::VarInfo, metadata::Metadata, target_vns)
vns = metadata.vns

# Construct the new transformed values, and keep track of their lengths.
vals_new = map(vns) do vn
# Return early if we're already in unconstrained space.
if istrans(varinfo, vn)
if istrans(varinfo, vn) || vn target_vns
return metadata.vals[getrange(metadata, vn)]
end

Expand Down Expand Up @@ -997,30 +997,30 @@ end
function _invlink(varinfo::UntypedVarInfo, spl::AbstractSampler)
varinfo = deepcopy(varinfo)
return VarInfo(
_invlink_metadata!(varinfo, varinfo.metadata, Val(getspace(spl))),
_invlink_metadata!(varinfo, varinfo.metadata, _getvns(spl)),
Base.Ref(getlogp(varinfo)),
Ref(get_num_produce(varinfo)),
)
end

function _invlink(varinfo::TypedVarInfo, spl::AbstractSampler)
varinfo = deepcopy(varinfo)
md = _invlink_metadata!(varinfo, varinfo.metadata, Val(getspace(spl)))
# TODO: Update logp, etc.
md = _invlink_metadata_namedtuple!(varinfo, varinfo.metadata, _getvns(varinfo, spl), Val(getspace(spl)))
return VarInfo(md, Base.Ref(getlogp(varinfo)), Ref(get_num_produce(varinfo)))
end

@generated function _invlink_metadata!(
@generated function _invlink_metadata_namedtuple!(
varinfo::VarInfo,
metadata::NamedTuple{names},
vns::NamedTuple,
::Val{space}
) where {names,space}
vals = Expr(:tuple)
for f in names
if inspace(f, space) || length(space) == 0
push!(
expr.args,
:(_invlink_metadata!(varinfo, metadata.$f))
vals.args,
:(_invlink_metadata!(varinfo, metadata.$f, vns.$f))
)
else
push!(vals.args, :(metadata.$f))
Expand All @@ -1029,13 +1029,14 @@ end

return :(NamedTuple{$names}($vals))
end
function _invlink_metadata!(varinfo::VarInfo, metadata::Metadata)
function _invlink_metadata!(varinfo::VarInfo, metadata::Metadata, target_vns)
vns = metadata.vns

# Construct the new transformed values, and keep track of their lengths.
vals_new = map(vns) do vn
# Return early if we're already in constrained space.
if !istrans(varinfo, vn)
# Return early if we're already in constrained space OR if we're not
# supposed to touch this `vn`.
if !istrans(varinfo, vn) || vn target_vns
return metadata.vals[getrange(metadata, vn)]
end

Expand Down

0 comments on commit 1e4d9f1

Please sign in to comment.