Skip to content

Commit

Permalink
use special construct to hijack assume for :=
Browse files Browse the repository at this point in the history
  • Loading branch information
torfjelde committed Apr 25, 2024
1 parent 503d4de commit b1c3206
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 3 deletions.
1 change: 0 additions & 1 deletion src/compiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -387,7 +387,6 @@ function generate_mainbody!(mod, found, expr::Expr, warn)
end

function generate_assign(left, right)
right_expr = :($(Distributions.Dirac)($right))
tilde_expr = generate_tilde(left, right_expr)
return quote
if $(is_extracting_values)(__context__)
Expand Down
22 changes: 20 additions & 2 deletions src/values_as_in_model.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,11 @@
struct TrackedValue{T}
value::T
end

is_tracked_value(::TrackedValue) = true
is_tracked_value(::Any) = false

check_tilde_rhs(x::TrackedValue) = x

"""
ValuesAsInModelContext
Expand Down Expand Up @@ -55,7 +63,12 @@ end

# `tilde_asssume`
function tilde_assume(context::ValuesAsInModelContext, right, vn, vi)
value, logp, vi = tilde_assume(childcontext(context), right, vn, vi)
if is_tracked_value(right)
value = right.value
logp = zero(getlogp(vi))
else
value, logp, vi = tilde_assume(childcontext(context), right, vn, vi)
end
# Save the value.
push!(context, vn, value)
# Save the value.
Expand All @@ -65,7 +78,12 @@ end
function tilde_assume(
rng::Random.AbstractRNG, context::ValuesAsInModelContext, sampler, right, vn, vi
)
value, logp, vi = tilde_assume(rng, childcontext(context), sampler, right, vn, vi)
if is_tracked_value(right)
value = right.value
logp = zero(getlogp(vi))
else
value, logp, vi = tilde_assume(rng, childcontext(context), sampler, right, vn, vi)
end
# Save the value.
push!(context, vn, value)
# Pass on.
Expand Down

0 comments on commit b1c3206

Please sign in to comment.