Illegal type analysis - BigFloat #1621
Ah yeah, BigFloats aren't presently expected to work. Adding support for them is on our to-do list, but honestly it's not high priority at the moment.
There shouldn't be any BigFloats coming into play in the MWE, though; it's all multivariate normals with Float64s. The problem seems to come from the same or a related type instability as TuringLang/Turing.jl#1608 and TuringLang/DynamicPPL.jl#643.
Hm, Enzyme definitely thinks there's a code path that could involve a BigFloat, even if in practice it's never taken. If you can minimize this a bit more, I can work on making sure this error doesn't happen.
MWE that only depends on Accessors and Distributions:

```julia
module MWE

import Accessors
import Distributions
using Enzyme

#Enzyme.API.runtimeActivity!(true)

# Stripped-down stand-ins modelled on DynamicPPL's VarName and VarInfo.
struct VarName{sym,T}
    optic::T
    function VarName{sym}(optic=identity) where {sym}
        return new{sym,typeof(optic)}(optic)
    end
end

function Base.:(==)(x::VarName{symx}, y::VarName{symy}) where {symx,symy}
    return x.optic == y.optic && symx == symy
end

struct TypeWrap{T} end

struct VarInfo{Tval,Tlogp}
    vals::Tval
    logp::Base.RefValue{Tlogp}
end

# Return a copy of the two entries of vi.vals that belong to vn.
function getindex(vi::VarInfo, vn::VarName)
    range = vn == x_varname1 ? (1:2) : (3:4)
    return copy(vi.vals[range])
end

VarInfo(old_vi::VarInfo, x) = VarInfo(x, Base.RefValue{eltype(x)}(old_vi.logp[]))

function tilde_assume!!(right, vn, vi)
    y = [1.0, 1.0]
    _ = Distributions.logpdf(right, y)
    r = getindex(vi, vn)
    logp = 0.0
    vi.logp[] += logp
    return r, vi
end

x_varname1 = VarName{:x}((Accessors.@optic _[:, 1]))
x_varname2 = VarName{:x}((Accessors.@optic _[:, 2]))

function satellite_model_matrix(__varinfo__, ::(TypeWrap){TV}) where {TV}
    P0 = vcat([0.1 0.0], [0.0 0.1])
    x = TV(undef, 2, 2)
    v1 = Distributions.MvNormal([0.0, 0.0], P0)
    v3, __varinfo__ = tilde_assume!!(v1, x_varname1, __varinfo__)
    x[:, 1] .= v3
    v1 = Distributions.MvNormal(x[:, 1], P0)
    v4, __varinfo__ = tilde_assume!!(v1, x_varname2, __varinfo__)
    x[:, 2] .= v4
    return nothing, __varinfo__
end

vi = VarInfo(
    [0.0, 0.0, 0.0, 0.0],
    Base.RefValue{Float64}(0.0),
)

function g(x)
    vi_new = VarInfo(vi, x)
    # The model's x is allocated as Matrix{Real}, i.e. with an abstract element type.
    _, wrapper_new = satellite_model_matrix(vi_new, TypeWrap{Matrix{Real}}())
    return wrapper_new.logp[]
end

x = [1.0, 1.0, 1.0, 1.0]
# Reverse-mode gradient of g with respect to x; this is the call that errors.
Enzyme.autodiff(ReverseWithPrimal, g, Active, Enzyme.Duplicated(x, zero(x)))

# using ForwardDiff
# @show ForwardDiff.gradient(g, x)

end
```
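One concrete place where the type instability can be seen, without Enzyme involved: the MWE allocates `x` through `TypeWrap{Matrix{Real}}`, and `Real` is an abstract type, so every column slice passed to `MvNormal` is a `Vector{Real}` even though it only ever holds `Float64`s. A minimal sketch in plain Julia, assuming nothing beyond Base:

```julia
# Matrix{Real} has an abstract element type, so it stores boxed values.
x = Matrix{Real}(undef, 2, 2)
x[:, 1] .= [0.0, 0.0]    # fill the first column with Float64 values

col = x[:, 1]            # slicing copies into a new Vector{Real}
println(typeof(col))                    # Vector{Real}
println(isconcretetype(eltype(col)))    # false: the element type is abstract
println(all(v -> v isa Float64, col))   # true: the stored values are Float64

# For contrast, a concretely typed matrix yields concretely typed slices.
y = Matrix{Float64}(undef, 2, 2)
println(isconcretetype(eltype(y[:, 1])))  # true
```

Whether this abstract element type is what leads Enzyme's type analysis to a BigFloat code path is not established in this thread; the sketch only shows that the inputs reaching `MvNormal` carry an abstract element type rather than `Float64`.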
Fixed by #1658.
Output:
On 0.0.132 and latest Enzyme.jl main.
I appreciate that's not very minimal, but I'm done for the day and wanted to put this up here. I can try to minimise further tomorrow if that's helpful.
This started life as a reproduction of TuringLang/Turing.jl#1608, but after the latest Enzyme update it stopped producing the error in TuringLang/Turing.jl#1608 and instead started spitting out the above.