Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Illegal type analysis - BigFloat #1621

Closed
mhauru opened this issue Jul 8, 2024 · 5 comments
Closed

Illegal type analysis - BigFloat #1621

mhauru opened this issue Jul 8, 2024 · 5 comments

Comments

@mhauru
Copy link
Contributor

mhauru commented Jul 8, 2024

# Reproducer for Enzyme.jl issue #1621 ("Illegal type analysis - BigFloat").
#
# Reverse-mode differentiation of `g` (bottom of the module) fails while Enzyme compiles
# a code path that constructs a `BigFloat`, even though every input here is `Float64`.
# The model allocates its working matrix with type `Matrix{Real}` (passed in via
# `DynamicPPL.TypeWrap{Matrix{Real}}`); per the issue discussion, using
# `Matrix{Float64}` instead makes the error go away.
module MWE

import DynamicPPL
import AbstractPPL
import Accessors
using Distributions: MvNormal

using Enzyme

# Optional Enzyme setting (runtime activity), left disabled for this reproducer.
#Enzyme.API.runtimeActivity!(true)

# Stripped-down stand-in for DynamicPPL's `tilde_assume!!`.
# The result of `invlink_with_logpdf` is discarded; according to the reported stack
# trace, the failing compilation is reached through this call
# (`invlink_with_logpdf` -> `logpdf` -> `_logpdf` -> `sqmahal` -> broadcast `copy`).
# It then reads the current value stored for `vn`, adds the constant 0.0 to the
# accumulated log-probability `vi.logp[]`, and returns `(value, vi)`.
function tilde_assume!!(right, vn, vi)
    _, _ = DynamicPPL.invlink_with_logpdf(vi, vn, right)
    r = vi[vn]
    logp = 0.0
    vi.logp[] += logp
    return r, vi
end

# Variable names addressing the first and second column of `x` (`x[:, 1]`, `x[:, 2]`).
# NOTE(review): non-const globals; they are read inside `satellite_model_matrix`.
x_varname1 = AbstractPPL.VarName{:x}((Accessors.@optic _[:, 1]))
x_varname2 = AbstractPPL.VarName{:x}((Accessors.@optic _[:, 2]))

# Hand-expanded model body: two chained MvNormal variables, each written into one column
# of a 2x2 matrix `x` allocated as `TV(undef, 2, 2)`.
# The second distribution's mean is `x[:, 1]`. When `TV === Matrix{Real}`, that slice is
# a `Vector{Real}`, i.e. it has an abstract element type.
# Returns `(nothing, varinfo)`.
function satellite_model_matrix(__varinfo__::DynamicPPL.AbstractVarInfo, ::(DynamicPPL.TypeWrap){TV}) where {TV}
    # 2x2 covariance [0.1 0.0; 0.0 0.1]
    P0 = vcat([0.1 0.0], [0.0 0.1])
    x = TV(undef, 2, 2)

    v1 = MvNormal([0.0, 0.0], P0)
    v3, __varinfo__ = tilde_assume!!(v1, x_varname1, __varinfo__)
    # in-place write of the first column
    x[:, 1] .= v3

    # mean taken from the (possibly abstractly typed) first column
    v1 = MvNormal(x[:, 1], P0)
    v4, __varinfo__ = tilde_assume!!(v1, x_varname2, __varinfo__)
    x[:, 2] .= v4
    return nothing, __varinfo__
end

# Build the global VarInfo.
# - Register both columns of `x`, each with initial value [0.0, 0.0] and an
#   MvNormal([0.0, 0.0], P0) distribution.
# - Convert the result to a `TypedVarInfo`.
vi = DynamicPPL.VarInfo()
P0 = vcat([0.1 0.0], [0.0 0.1])
v1 = [0.0, 0.0]
DynamicPPL.push!!(vi, x_varname1, v1, MvNormal([0.0, 0.0], P0), DynamicPPL.SampleFromPrior())
DynamicPPL.push!!(vi, x_varname2, v1, MvNormal([0.0, 0.0], P0), DynamicPPL.SampleFromPrior())
vi = DynamicPPL.TypedVarInfo(vi)

# Objective differentiated below.
# - Build `vi_new` from the global `vi` and the flat parameter vector `x` via
#   `DynamicPPL.unflatten`.
# - Run the model with a `Matrix{Real}` working matrix.
# - Return the accumulated log-probability.
# NOTE(review): `vi` is a non-const global read inside `g`.
function g(x)
    context = DynamicPPL.DefaultContext()
    vi_new = DynamicPPL.unflatten(vi, context, x)
    _, wrapper_new = satellite_model_matrix(vi_new, DynamicPPL.TypeWrap{Matrix{Real}}())
    return DynamicPPL.getlogp(wrapper_new)
end

# Reverse-mode AD of `g` at `x` (primal also returned; `zero(x)` is the shadow for the
# gradient of `x`). This is the call that fails with the reported error.
x = [1.0, 1.0, 1.0, 1.0]
Enzyme.autodiff(ReverseWithPrimal, g, Active, Enzyme.Duplicated(x, zero(x)))
# Cross-check with ForwardDiff (uncomment to compare):
# using ForwardDiff
# @show ForwardDiff.gradient(g, x)

end

Output:

┌ Warning: Unknown floating point type
│   T = BigFloat
└ @ Enzyme ~/.julia/packages/GPUCompiler/Y4hSX/src/utils.jl:59
┌ Warning: Unknown floating point type
│   T = BigFloat
└ @ Enzyme ~/.julia/packages/GPUCompiler/Y4hSX/src/utils.jl:59
┌ Warning: Unknown floating point type
│   T = BigFloat
└ @ Enzyme ~/.julia/packages/GPUCompiler/Y4hSX/src/utils.jl:59
┌ Warning: Unknown floating point type
│   T = BigFloat
└ @ Enzyme ~/.julia/packages/GPUCompiler/Y4hSX/src/utils.jl:59
┌ Warning: Unknown floating point type
│   T = BigFloat
└ @ Enzyme ~/.julia/packages/GPUCompiler/Y4hSX/src/utils.jl:59
┌ Warning: Unknown floating point type
│   T = BigFloat
└ @ Enzyme ~/.julia/packages/GPUCompiler/Y4hSX/src/utils.jl:59
┌ Warning: Unknown floating point type
│   T = BigFloat
└ @ Enzyme ~/.julia/packages/GPUCompiler/Y4hSX/src/utils.jl:59
┌ Warning: Unknown floating point type
│   T = BigFloat
└ @ Enzyme ~/.julia/packages/GPUCompiler/Y4hSX/src/utils.jl:59
ERROR: LoadError: Enzyme compilation failed due to illegal type analysis.
Current scope:
; Function Attrs: mustprogress willreturn
define internal fastcc noalias nonnull dereferenceable(40) "enzyme_type"="{[-1]:Pointer}" {} addrspace(10)* @preprocess_julia___1_28309(i64 signext "enzyme_inactive" "enzyme_type"="{[-1]:Integer}" "enzymejl_parmtype"="5091561456" "enzymejl_parmtype_ref"="0" %0) unnamed_addr #32 !dbg !686 {
top:
  %1 = call {}*** @julia.get_pgcstack() #33
  %ptls_field5 = getelementptr inbounds {}**, {}*** %1, i64 2
  %2 = bitcast {}*** %ptls_field5 to i64***
  %ptls_load67 = load i64**, i64*** %2, align 8, !tbaa !14
  %3 = getelementptr inbounds i64*, i64** %ptls_load67, i64 2
  %safepoint = load i64*, i64** %3, align 8, !tbaa !18
  fence syncscope("singlethread") seq_cst
  call void @julia.safepoint(i64* %safepoint) #33, !dbg !687
  fence syncscope("singlethread") seq_cst
  %4 = icmp sgt i64 %0, 0, !dbg !688
  br i1 %4, label %L6, label %L3, !dbg !689

L3:                                               ; preds = %top
  %5 = call noalias nonnull "enzyme_inactive" "enzyme_type"="{[-1]:Pointer, [-1,-1]:Integer}" {} addrspace(10)* @ijl_box_int64(i64 signext %0) #34, !dbg !689
  %6 = call nonnull "enzyme_type"="{[-1]:Pointer}" {} addrspace(10)* ({} addrspace(10)* ({} addrspace(10)*, {} addrspace(10)**, i32, {} addrspace(10)*)*, {} addrspace(10)*, {} addrspace(10)*, ...) @julia.call2({} addrspace(10)* ({} addrspace(10)*, {} addrspace(10)**, i32, {} addrspace(10)*)* noundef nonnull @ijl_invoke, {} addrspace(10)* noundef addrspacecast ({}* inttoptr (i64 5063918096 to {}*) to {} addrspace(10)*), {} addrspace(10)* noundef addrspacecast ({}* inttoptr (i64 5021202192 to {}*) to {} addrspace(10)*), {} addrspace(10)* nonnull %5, {} addrspace(10)* addrspacecast ({}* inttoptr (i64 5101639360 to {}*) to {} addrspace(10)*)) #35, !dbg !689
  %7 = addrspacecast {} addrspace(10)* %6 to {} addrspace(12)*, !dbg !689
  call void @ijl_throw({} addrspace(12)* %7) #36, !dbg !689
  unreachable, !dbg !689

L6:                                               ; preds = %top
  %current_task14 = getelementptr inbounds {}**, {}*** %1, i64 -14
  %current_task1 = bitcast {}*** %current_task14 to {}**
  %8 = call i64 @mpfr_custom_get_size(i64 %0) #33, !dbg !690
  %9 = add i64 %8, 7, !dbg !691
  %10 = and i64 %9, -8, !dbg !694
  %11 = call nonnull "enzyme_type"="{[-1]:Pointer}" {} addrspace(10)* (i64, ...) @ijl_alloc_string(i64 %10) #33, !dbg !697
  %12 = addrspacecast {} addrspace(10)* %11 to {} addrspace(11)*, !dbg !698
  %13 = call nonnull {}* @julia.pointer_from_objref({} addrspace(11)* %12) #37, !dbg !698
  %14 = bitcast {}* %13 to {} addrspace(10)**, !dbg !698
  %15 = getelementptr inbounds {} addrspace(10)*, {} addrspace(10)** %14, i64 1, !dbg !698
  %string_ptr = ptrtoint {} addrspace(10)** %15 to i64, !dbg !698
  %newstruct = call noalias nonnull dereferenceable(40) "enzyme_type"="{[-1]:Pointer}" {} addrspace(10)* @julia.gc_alloc_obj({}** nonnull %current_task1, i64 noundef 40, {} addrspace(10)* noundef addrspacecast ({}* inttoptr (i64 5023625904 to {}*) to {} addrspace(10)*)) #38, !dbg !700
  %16 = addrspacecast {} addrspace(10)* %newstruct to {} addrspace(10)* addrspace(11)*, !dbg !700
  %17 = getelementptr inbounds {} addrspace(10)*, {} addrspace(10)* addrspace(11)* %16, i64 4, !dbg !700
  store {} addrspace(10)* null, {} addrspace(10)* addrspace(11)* %17, align 8, !dbg !700, !tbaa !45, !alias.scope !49, !noalias !702
  %18 = addrspacecast {} addrspace(10)* %newstruct to i64 addrspace(11)*, !dbg !700
  store i64 %0, i64 addrspace(11)* %18, align 8, !dbg !700, !tbaa !45, !alias.scope !49, !noalias !702
  %19 = addrspacecast {} addrspace(10)* %newstruct to i8 addrspace(11)*, !dbg !700
  %20 = getelementptr inbounds i8, i8 addrspace(11)* %19, i64 8, !dbg !700
  %memcpy_refined_dst = bitcast i8 addrspace(11)* %20 to i32 addrspace(11)*, !dbg !700
  store i32 1, i32 addrspace(11)* %memcpy_refined_dst, align 8, !dbg !700, !tbaa !45, !alias.scope !49, !noalias !702
  %21 = getelementptr inbounds i8, i8 addrspace(11)* %19, i64 16, !dbg !700
  %memcpy_refined_dst3 = bitcast i8 addrspace(11)* %21 to i64 addrspace(11)*, !dbg !700
  store i64 -9223372036854775806, i64 addrspace(11)* %memcpy_refined_dst3, align 8, !dbg !700, !tbaa !45, !alias.scope !49, !noalias !702
  %22 = getelementptr inbounds i8, i8 addrspace(11)* %19, i64 24, !dbg !700
  %23 = bitcast i8 addrspace(11)* %22 to i64 addrspace(11)*, !dbg !700
  store i64 %string_ptr, i64 addrspace(11)* %23, align 8, !dbg !700, !tbaa !45, !alias.scope !49, !noalias !702
  %24 = getelementptr inbounds i8, i8 addrspace(11)* %19, i64 32, !dbg !700
  %25 = bitcast i8 addrspace(11)* %24 to {} addrspace(10)* addrspace(11)*, !dbg !700
  store atomic {} addrspace(10)* %11, {} addrspace(10)* addrspace(11)* %25 release, align 8, !dbg !700, !tbaa !45, !alias.scope !49, !noalias !702
  ret {} addrspace(10)* %newstruct, !dbg !701
}

 Type analysis state:
<analysis>
  %6 = call nonnull "enzyme_type"="{[-1]:Pointer}" {} addrspace(10)* ({} addrspace(10)* ({} addrspace(10)*, {} addrspace(10)**, i32, {} addrspace(10)*)*, {} addrspace(10)*, {} addrspace(10)*, ...) @julia.call2({} addrspace(10)* ({} addrspace(10)*, {} addrspace(10)**, i32, {} addrspace(10)*)* noundef nonnull @ijl_invoke, {} addrspace(10)* noundef addrspacecast ({}* inttoptr (i64 5063918096 to {}*) to {} addrspace(10)*), {} addrspace(10)* noundef addrspacecast ({}* inttoptr (i64 5021202192 to {}*) to {} addrspace(10)*), {} addrspace(10)* nonnull %5, {} addrspace(10)* addrspacecast ({}* inttoptr (i64 5101639360 to {}*) to {} addrspace(10)*)) #35, !dbg !24: {[-1]:Pointer}, intvals: {}
  %newstruct = call noalias nonnull dereferenceable(40) "enzyme_type"="{[-1]:Pointer}" {} addrspace(10)* @julia.gc_alloc_obj({}** nonnull %current_task1, i64 noundef 40, {} addrspace(10)* noundef addrspacecast ({}* inttoptr (i64 5023625904 to {}*) to {} addrspace(10)*)) #38, !dbg !42: {[-1]:Pointer, [-1,0]:Float@double}, intvals: {}
{} addrspace(10)* null: {[-1]:Pointer, [-1,-1]:Anything}, intvals: {0,}
  %1 = call {}*** @julia.get_pgcstack() #33: {}, intvals: {}
  %3 = getelementptr inbounds i64*, i64** %ptls_load67, i64 2: {[-1]:Pointer}, intvals: {}
  %11 = call nonnull "enzyme_type"="{[-1]:Pointer}" {} addrspace(10)* (i64, ...) @ijl_alloc_string(i64 %10) #33, !dbg !34: {[-1]:Pointer}, intvals: {}
  %8 = call i64 @mpfr_custom_get_size(i64 %0) #33, !dbg !25: {}, intvals: {}
  %5 = call noalias nonnull "enzyme_inactive" "enzyme_type"="{[-1]:Pointer, [-1,-1]:Integer}" {} addrspace(10)* @ijl_box_int64(i64 signext %0) #34, !dbg !24: {[-1]:Pointer, [-1,-1]:Integer}, intvals: {}
  %ptls_field5 = getelementptr inbounds {}**, {}*** %1, i64 2: {}, intvals: {}
  %15 = getelementptr inbounds {} addrspace(10)*, {} addrspace(10)** %14, i64 1, !dbg !37: {[-1]:Pointer}, intvals: {}
  %17 = getelementptr inbounds {} addrspace(10)*, {} addrspace(10)* addrspace(11)* %16, i64 4, !dbg !42: {[-1]:Pointer, [-1,0]:Pointer}, intvals: {}
  %current_task14 = getelementptr inbounds {}**, {}*** %1, i64 -14: {}, intvals: {}
  %13 = call nonnull {}* @julia.pointer_from_objref({} addrspace(11)* %12) #37, !dbg !37: {[-1]:Pointer}, intvals: {}
  %current_task1 = bitcast {}*** %current_task14 to {}**: {}, intvals: {}
  %14 = bitcast {}* %13 to {} addrspace(10)**, !dbg !37: {[-1]:Pointer}, intvals: {}
  %2 = bitcast {}*** %ptls_field5 to i64***: {[-1]:Pointer}, intvals: {}
  %ptls_load67 = load i64**, i64*** %2, align 8, !tbaa !14: {}, intvals: {}
  %string_ptr = ptrtoint {} addrspace(10)** %15 to i64, !dbg !37: {[-1]:Pointer}, intvals: {}
  %16 = addrspacecast {} addrspace(10)* %newstruct to {} addrspace(10)* addrspace(11)*, !dbg !42: {[-1]:Pointer, [-1,0]:Float@double}, intvals: {}
  %safepoint = load i64*, i64** %3, align 8, !tbaa !18: {}, intvals: {}
  %18 = addrspacecast {} addrspace(10)* %newstruct to i64 addrspace(11)*, !dbg !42: {[-1]:Pointer, [-1,0]:Float@double}, intvals: {}
  %12 = addrspacecast {} addrspace(10)* %11 to {} addrspace(11)*, !dbg !37: {[-1]:Pointer}, intvals: {}
i64 7: {[-1]:Integer}, intvals: {7,}
i64 -8: {[-1]:Integer}, intvals: {-8,}
i64 0: {[-1]:Anything}, intvals: {0,}
i64 %0: {[-1]:Integer}, intvals: {}
  %4 = icmp sgt i64 %0, 0, !dbg !21: {[-1]:Integer}, intvals: {}
  %9 = add i64 %8, 7, !dbg !26: {}, intvals: {}
  %10 = and i64 %9, -8, !dbg !30: {}, intvals: {}
</analysis>

Illegal updateAnalysis prev:{[-1]:Pointer, [-1,0]:Float@double} new: {[-1]:Pointer, [-1,0]:Integer, [-1,1]:Integer, [-1,2]:Integer, [-1,3]:Integer, [-1,4]:Integer, [-1,5]:Integer, [-1,6]:Integer, [-1,7]:Integer}
val:   %18 = addrspacecast {} addrspace(10)* %newstruct to i64 addrspace(11)*, !dbg !42 origin=  store i64 %0, i64 addrspace(11)* %18, align 8, !dbg !42, !tbaa !45, !alias.scope !49, !noalias !52
MethodInstance for (::Base.MPFR.var"#_#1#2")(::Int64, ::Type{BigFloat})


Caused by:
Stacktrace:
 [1] _BigFloat
   @ ./mpfr.jl:119
 [2] _
   @ ./mpfr.jl:129

Stacktrace:
  [1] julia_error(cstr::Cstring, val::Ptr{…}, errtype::Enzyme.API.ErrorType, data::Ptr{…}, data2::Ptr{…}, B::Ptr{…})
    @ Enzyme.Compiler ~/.julia/dev/Enzyme/src/compiler.jl:1996
  [2] EnzymeCreateAugmentedPrimal(logic::Enzyme.Logic, todiff::LLVM.Function, retType::Enzyme.API.CDIFFE_TYPE, constant_args::Vector{…}, TA::Enzyme.TypeAnalysis, returnUsed::Bool, shadowReturnUsed::Bool, typeInfo::Enzyme.FnTypeInfo, uncacheable_args::Vector{…}, forceAnonymousTape::Bool, width::Int64, atomicAdd::Bool)
    @ Enzyme.API ~/.julia/dev/Enzyme/src/api.jl:192
  [3] enzyme!(job::GPUCompiler.CompilerJob{…}, mod::LLVM.Module, primalf::LLVM.Function, TT::Type, mode::Enzyme.API.CDerivativeMode, width::Int64, parallel::Bool, actualRetType::Type, wrap::Bool, modifiedBetween::NTuple{…}, returnPrimal::Bool, expectedTapeType::Type, loweredArgs::Set{…}, boxedArgs::Set{…})
    @ Enzyme.Compiler ~/.julia/dev/Enzyme/src/compiler.jl:3673
  [4] codegen(output::Symbol, job::GPUCompiler.CompilerJob{…}; libraries::Bool, deferred_codegen::Bool, optimize::Bool, toplevel::Bool, strip::Bool, validate::Bool, only_entry::Bool, parent_job::Nothing)
    @ Enzyme.Compiler ~/.julia/dev/Enzyme/src/compiler.jl:5867
  [5] codegen
    @ ~/.julia/dev/Enzyme/src/compiler.jl:5143 [inlined]
  [6] _thunk(job::GPUCompiler.CompilerJob{Enzyme.Compiler.EnzymeTarget, Enzyme.Compiler.EnzymeCompilerParams}, postopt::Bool)
    @ Enzyme.Compiler ~/.julia/dev/Enzyme/src/compiler.jl:6674
  [7] _thunk
    @ ~/.julia/dev/Enzyme/src/compiler.jl:6674 [inlined]
  [8] cached_compilation
    @ ~/.julia/dev/Enzyme/src/compiler.jl:6712 [inlined]
  [9] (::Enzyme.Compiler.var"#28595#28596"{…})(ctx::LLVM.Context)
    @ Enzyme.Compiler ~/.julia/dev/Enzyme/src/compiler.jl:6781
 [10] JuliaContext(f::Enzyme.Compiler.var"#28595#28596"{…}; kwargs::@Kwargs{})
    @ GPUCompiler ~/.julia/packages/GPUCompiler/Y4hSX/src/driver.jl:52
 [11] JuliaContext(f::Function)
    @ GPUCompiler ~/.julia/packages/GPUCompiler/Y4hSX/src/driver.jl:42
 [12] #s2010#28594
    @ ~/.julia/dev/Enzyme/src/compiler.jl:6732 [inlined]
 [13]
    @ Enzyme.Compiler ./none:0
 [14] (::Core.GeneratedFunctionStub)(::UInt64, ::LineNumberNode, ::Any, ::Vararg{Any})
    @ Core ./boot.jl:602
 [15] runtime_generic_augfwd(activity::Type{…}, width::Val{…}, ModifiedBetween::Val{…}, RT::Val{…}, f::typeof(Base.Broadcast.copyto_nonleaf!), df::Nothing, primal_1::Vector{…}, shadow_1_1::Vector{…}, primal_2::Base.Broadcast.Broadcasted{…}, shadow_2_1::Base.Broadcast.Broadcasted{…}, primal_3::Base.OneTo{…}, shadow_3_1::Nothing, primal_4::Int64, shadow_4_1::Nothing, primal_5::Int64, shadow_5_1::Nothing)
    @ Enzyme.Compiler ~/.julia/dev/Enzyme/src/rules/jitrules.jl:307
 [16] copy
    @ ./broadcast.jl:950 [inlined]
 [17] materialize
    @ ./broadcast.jl:903 [inlined]
 [18] sqmahal
    @ ~/.julia/packages/Distributions/ji8PW/src/multivariate/mvnormal.jl:267
 [19] _logpdf
    @ ~/.julia/packages/Distributions/ji8PW/src/multivariate/mvnormal.jl:143
 [20] logpdf
    @ ~/.julia/packages/Distributions/ji8PW/src/common.jl:263 [inlined]
 [21] invlink_with_logpdf
    @ ~/.julia/packages/DynamicPPL/ACaKr/src/abstract_varinfo.jl:856
 [22] invlink_with_logpdf
    @ ~/.julia/packages/DynamicPPL/ACaKr/src/abstract_varinfo.jl:850 [inlined]
 [23] tilde_assume!!
    @ ~/projects/Enzyme-mwes/ref_minus_float/mwe.jl:16 [inlined]
 [24] tilde_assume!!
    @ ~/projects/Enzyme-mwes/ref_minus_float/mwe.jl:0 [inlined]
 [25] augmented_julia_tilde_assume___28212_inner_1wrap
    @ ~/projects/Enzyme-mwes/ref_minus_float/mwe.jl:0
 [26] macro expansion
    @ ~/.julia/dev/Enzyme/src/compiler.jl:6622 [inlined]
 [27] enzyme_call(::Val{…}, ::Ptr{…}, ::Type{…}, ::Val{…}, ::Val{…}, ::Type{…}, ::Type{…}, ::EnzymeCore.Const{…}, ::Type{…}, ::EnzymeCore.Duplicated{…}, ::EnzymeCore.Const{…}, ::EnzymeCore.Duplicated{…})
    @ Enzyme.Compiler ~/.julia/dev/Enzyme/src/compiler.jl:6223
 [28] (::Enzyme.Compiler.AugmentedForwardThunk{…})(::EnzymeCore.Const{…}, ::EnzymeCore.Duplicated{…}, ::Vararg{…})
    @ Enzyme.Compiler ~/.julia/dev/Enzyme/src/compiler.jl:6111
 [29] runtime_generic_augfwd(activity::Type{…}, width::Val{…}, ModifiedBetween::Val{…}, RT::Val{…}, f::typeof(Main.MWE.tilde_assume!!), df::Nothing, primal_1::Distributions.MvNormal{…}, shadow_1_1::Distributions.MvNormal{…}, primal_2::AbstractPPL.VarName{…}, shadow_2_1::Nothing, primal_3::DynamicPPL.TypedVarInfo{…}, shadow_3_1::DynamicPPL.TypedVarInfo{…})
    @ Enzyme.Compiler ~/.julia/dev/Enzyme/src/rules/jitrules.jl:311
 [30] satellite_model_matrix
    @ ~/projects/Enzyme-mwes/ref_minus_float/mwe.jl:35 [inlined]
 [31] satellite_model_matrix
    @ ~/projects/Enzyme-mwes/ref_minus_float/mwe.jl:0 [inlined]
 [32] augmented_julia_satellite_model_matrix_27657_inner_1wrap
    @ ~/projects/Enzyme-mwes/ref_minus_float/mwe.jl:0
 [33] macro expansion
    @ ~/.julia/dev/Enzyme/src/compiler.jl:6622 [inlined]
 [34] enzyme_call(::Val{…}, ::Ptr{…}, ::Type{…}, ::Val{…}, ::Val{…}, ::Type{…}, ::Type{…}, ::EnzymeCore.Const{…}, ::Type{…}, ::EnzymeCore.Duplicated{…}, ::EnzymeCore.Const{…})
    @ Enzyme.Compiler ~/.julia/dev/Enzyme/src/compiler.jl:6223
 [35] (::Enzyme.Compiler.AugmentedForwardThunk{…})(::EnzymeCore.Const{…}, ::EnzymeCore.Duplicated{…}, ::Vararg{…})
    @ Enzyme.Compiler ~/.julia/dev/Enzyme/src/compiler.jl:6111
 [36] runtime_generic_augfwd(activity::Type{…}, width::Val{…}, ModifiedBetween::Val{…}, RT::Val{…}, f::typeof(Main.MWE.satellite_model_matrix), df::Nothing, primal_1::DynamicPPL.TypedVarInfo{…}, shadow_1_1::DynamicPPL.TypedVarInfo{…}, primal_2::DynamicPPL.TypeWrap{…}, shadow_2_1::Nothing)
    @ Enzyme.Compiler ~/.julia/dev/Enzyme/src/rules/jitrules.jl:311
 [37] g
    @ ~/projects/Enzyme-mwes/ref_minus_float/mwe.jl:50 [inlined]
 [38] augmented_julia_g_27631wrap
    @ ~/projects/Enzyme-mwes/ref_minus_float/mwe.jl:0
 [39] macro expansion
    @ ~/.julia/dev/Enzyme/src/compiler.jl:6622 [inlined]
 [40] enzyme_call
    @ ~/.julia/dev/Enzyme/src/compiler.jl:6223 [inlined]
 [41] AugmentedForwardThunk
    @ ~/.julia/dev/Enzyme/src/compiler.jl:6111 [inlined]
 [42] autodiff
    @ ~/.julia/dev/Enzyme/src/Enzyme.jl:253 [inlined]
 [43] autodiff(mode::EnzymeCore.ReverseMode{…}, f::typeof(Main.MWE.g), ::Type{…}, args::EnzymeCore.Duplicated{…})
    @ Enzyme ~/.julia/dev/Enzyme/src/Enzyme.jl:321
 [44] top-level scope
    @ ~/projects/Enzyme-mwes/ref_minus_float/mwe.jl:55
 [45] include(fname::String)
    @ Base.MainInclude ./client.jl:489
 [46] top-level scope
    @ REPL[11]:1
in expression starting at /Users/mhauru/projects/Enzyme-mwes/ref_minus_float/mwe.jl:4
Some type information was truncated. Use `show(err)` to see complete types.

Reproduced on 0.0.132 and on the latest Enzyme.jl main.

I appreciate that's not very minimal, but I'm done for the day and wanted to put this up here. I can try to minimise further tomorrow if that's helpful.

This started life as a reproduction of TuringLang/Turing.jl#1608, but with the latest Enzyme update it no longer produces the error from TuringLang/Turing.jl#1608 and instead fails with the error above.

@wsmoses
Copy link
Member

wsmoses commented Jul 8, 2024

Ah yeah, BigFloats aren't presently expected to work. Adding support for them is on our to-do list, but honestly it's not a high priority at the moment.

@mhauru
Copy link
Contributor Author

mhauru commented Jul 9, 2024

There shouldn't be any BigFloats coming into play in the MWE, though; it's all multivariate normals with Float64s.

The problem seems to come from the same (or a related) type instability as TuringLang/Turing.jl#1608 and TuringLang/DynamicPPL.jl#643. The type parameter `TV` in the above function is `Matrix{Real}`, even though it could be inferred as `Matrix{Float64}`; if you fix it to `Matrix{Float64}`, the issue goes away.

@wsmoses
Copy link
Member

wsmoses commented Jul 9, 2024

Hm, it definitely thinks there’s a code path that could construct a BigFloat — even if in practice that path is never taken.

If you can minimize this a bit more, I can work on making sure this error doesn’t happen.

@mhauru
Copy link
Contributor Author

mhauru commented Jul 11, 2024

MWE that depends only on Accessors and Distributions (besides Enzyme itself):

# Reduced reproducer for Enzyme.jl issue #1621.
# It depends only on Accessors and Distributions (plus Enzyme).
# The DynamicPPL/AbstractPPL types used by the first reproducer are replaced by minimal
# local stand-ins (`VarName`, `TypeWrap`, `VarInfo`) defined below.
module MWE

import Accessors
import Distributions

using Enzyme

# Optional Enzyme setting (runtime activity), left disabled for this reproducer.
#Enzyme.API.runtimeActivity!(true)

# Minimal variable name: symbol `sym` as a type parameter plus an `optic` (defaults to
# `identity`) selecting part of that variable.
struct VarName{sym,T}
    optic::T

    function VarName{sym}(optic=identity) where {sym}
        return new{sym,typeof(optic)}(optic)
    end
end

# Two VarNames are equal when both their optics and their symbols are equal.
# NOTE(review): `==` is overloaded without a matching `Base.hash` method.
function Base.:(==)(x::VarName{symx}, y::VarName{symy}) where {symx,symy}
    return x.optic == y.optic && symx == symy
end

# Empty marker type carrying a container type `T` (used as `TypeWrap{Matrix{Real}}`).
struct TypeWrap{T} end

# Minimal VarInfo: a flat value container plus a mutable log-probability reference.
struct VarInfo{Tval,Tlogp}
    vals::Tval
    logp::Base.RefValue{Tlogp}
end

# Module-local `getindex` (this does not extend `Base.getindex`).
# `x_varname1` maps to `vals[1:2]`; any other name maps to `vals[3:4]`. Returns a copy.
# NOTE(review): reads the non-const global `x_varname1`. The range slice already
# allocates, so the extra `copy` is redundant but harmless.
function getindex(vi::VarInfo, vn::VarName)
    range = vn == x_varname1 ? (1:2) : (3:4)
    return copy(vi.vals[range])
end

# New VarInfo holding values `x`.
# Its log-probability ref has element type `eltype(x)` and starts at `old_vi.logp[]`.
VarInfo(old_vi::VarInfo, x) = VarInfo(x, Base.RefValue{eltype(x)}(old_vi.logp[]))

# Stand-in for `tilde_assume!!`.
# - Evaluate `logpdf(right, [1.0, 1.0])` and discard the result. Presumably this is kept
#   so that Enzyme compiles the `logpdf` path where the original failure occurs.
# - Read the stored value for `vn` and add the constant 0.0 to `vi.logp[]`.
# - Return `(value, vi)`.
function tilde_assume!!(right, vn, vi)
    y = [1.0, 1.0]
    _ = Distributions.logpdf(right, y)
    r = getindex(vi, vn)
    logp = 0.0
    vi.logp[] += logp
    return r, vi
end

# Variable names addressing the first and second column of `x` (non-const globals).
x_varname1 = VarName{:x}((Accessors.@optic _[:, 1]))
x_varname2 = VarName{:x}((Accessors.@optic _[:, 2]))

# Hand-expanded model body: two chained MvNormal variables, each written into one column
# of a 2x2 matrix `x` allocated as `TV(undef, 2, 2)`.
# With `TV === Matrix{Real}`, the second distribution's mean `x[:, 1]` is a
# `Vector{Real}`. Returns `(nothing, varinfo)`.
function satellite_model_matrix(__varinfo__, ::(TypeWrap){TV}) where {TV}
    # 2x2 covariance [0.1 0.0; 0.0 0.1]
    P0 = vcat([0.1 0.0], [0.0 0.1])
    x = TV(undef, 2, 2)

    v1 = Distributions.MvNormal([0.0, 0.0], P0)
    v3, __varinfo__ = tilde_assume!!(v1, x_varname1, __varinfo__)
    x[:, 1] .= v3

    # mean taken from the (possibly abstractly typed) first column
    v1 = Distributions.MvNormal(x[:, 1], P0)
    v4, __varinfo__ = tilde_assume!!(v1, x_varname2, __varinfo__)
    x[:, 2] .= v4
    return nothing, __varinfo__
end

# Global template VarInfo: four zero values and a Float64 log-probability of 0.0.
vi = VarInfo(
    [0.0, 0.0, 0.0, 0.0],
    Base.RefValue{Float64}(0.0),
)

# Objective differentiated below.
# - Copy the log-probability of the global `vi` into a fresh VarInfo holding `x`.
# - Run the model with a `Matrix{Real}` working matrix.
# - Return the accumulated log-probability.
function g(x)
    vi_new = VarInfo(vi, x)
    _, wrapper_new = satellite_model_matrix(vi_new, TypeWrap{Matrix{Real}}())
    return wrapper_new.logp[]
end

# Reverse-mode AD of `g` at `x` (primal also returned; `zero(x)` is the shadow for the
# gradient of `x`).
x = [1.0, 1.0, 1.0, 1.0]
Enzyme.autodiff(ReverseWithPrimal, g, Active, Enzyme.Duplicated(x, zero(x)))
# Cross-check with ForwardDiff (uncomment to compare):
# using ForwardDiff
# @show ForwardDiff.gradient(g, x)

end

@wsmoses
Copy link
Member

wsmoses commented Jul 21, 2024

Fixed by #1658

@wsmoses wsmoses closed this as completed Jul 21, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

No branches or pull requests

2 participants