Segfault when calling rand with Bijectors #2074

Closed
Red-Portal opened this issue Nov 8, 2024 · 24 comments · Fixed by #2117

Red-Portal commented Nov 8, 2024

Hi!

The following code segfaults on Julia 1.10:

using Bijectors
using Distributions  # for Normal, ContinuousDistribution, and Distributions.rand
using LinearAlgebra
using Functors
using Optimisers
using Enzyme
using Random, StableRNGs

struct TestProb1 end

logdensity(::TestProb1, θ) = sum(abs2, θ)

function Bijectors.bijector(::TestProb1)
    return Bijectors.Stacked(
        [Base.Fix1(broadcast, log), identity],
        [1:1, 2:3],
    )
end

struct TestProb2 end

logdensity(::TestProb2, θ) = sum(abs2, θ)

struct MvLocationScale{S,D<:ContinuousDistribution,L,E<:Real} <:
       ContinuousMultivariateDistribution
    location::L
    scale::S
    dist::D
    scale_eps::E
end

Base.length(q::MvLocationScale) = length(q.location)

Functors.@functor MvLocationScale (location, scale)

# This specialization improves AD performance of the sampling path
function Distributions.rand(
    rng::AbstractRNG, q::MvLocationScale{<:Diagonal,L}, num_samples::Int
) where {L}
    (; location, scale) = q
    n_dims = length(location)
    scale_diag = diag(scale)
    return scale_diag .* randn(rng, n_dims, num_samples) .+ location
end

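# Asserting the concrete model type keeps `restructure(params)` type-stable
# through the AD forward pass; this annotation turns out to matter for the bug.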
function restructure_ad_forward(restructure, params)
    return restructure(params)::typeof(restructure.model)
end

function estimate_repgradelbo_ad_forward(params′, aux)
    (; rng, problem, restructure) = aux
    q = restructure_ad_forward(restructure, params′)
    zs = rand(rng, q, 10)
    return mean(Base.Fix1(logdensity, problem), eachcol(zs))
end

function main()
    d = 5

    seed = (0x38bef07cf9cc549d)
    rng = StableRNG(seed)

    for prob in [TestProb1(), TestProb2()]
        q = if prob isa TestProb1
            MvLocationScale(zeros(d), Diagonal(ones(d)), Normal(), 1e-5)
        else
            Bijectors.TransformedDistribution(
                MvLocationScale(zeros(d), Diagonal(ones(d)), Normal(), 1e-5),
                inverse(
                    Bijectors.Stacked(
                        [Base.Fix1(broadcast, log), identity],
                        [1:1, 2:d],
                    )
                )
            )
        end

        params, re = Optimisers.destructure(q)
        buf = zero(params)
        aux = (rng=rng, problem=prob, restructure=re)
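        # Reverse-mode AD that also returns the primal; runtime activity is
        # enabled so constant-vs-active values are distinguished at runtime.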
        Enzyme.autodiff(
            set_runtime_activity(Enzyme.ReverseWithPrimal, true),
            estimate_repgradelbo_ad_forward,
            Enzyme.Active,
            Enzyme.Duplicated(params, buf),
            Enzyme.Const(aux),
        )
    end
end

main()

This bug is very sensitive: seemingly very minor changes (like swapping the order of TestProb1 and TestProb2) immediately make it go away. As such, it was pretty hard to pin down, but the snippet above seems to do the trick. Below is the segfault error message.

[9153] signal (11.128): Segmentation fault
in expression starting at REPL[2]:1
runtime_generic_augfwd at /home/krkim/.julia/packages/Enzyme/RvNgp/src/rules/jitrules.jl:486
unknown function (ip: 0x7f0e60997750)
_jl_invoke at /cache/build/builder-amdci5-5/julialang/julia-release-1-dot-10/src/gf.c:2895 [inlined]
ijl_apply_generic at /cache/build/builder-amdci5-5/julialang/julia-release-1-dot-10/src/gf.c:3077
estimate_repgradelbo_ad_forward at /home/krkim/.julia/dev/AdvancedVI/test/scratch2.jl:64 [inlined]
estimate_repgradelbo_ad_forward at /home/krkim/.julia/dev/AdvancedVI/test/scratch2.jl:0 [inlined]
augmented_julia_estimate_repgradelbo_ad_forward_7697_inner_1wrap at /home/krkim/.julia/dev/AdvancedVI/test/scratch2.jl:0
macro expansion at /home/krkim/.julia/packages/Enzyme/RvNgp/src/compiler.jl:8305 [inlined]
enzyme_call at /home/krkim/.julia/packages/Enzyme/RvNgp/src/compiler.jl:7868 [inlined]
AugmentedForwardThunk at /home/krkim/.julia/packages/Enzyme/RvNgp/src/compiler.jl:7705 [inlined]
autodiff at /home/krkim/.julia/packages/Enzyme/RvNgp/src/Enzyme.jl:384
unknown function (ip: 0x7f0e60db7dd1)
_jl_invoke at /cache/build/builder-amdci5-5/julialang/julia-release-1-dot-10/src/gf.c:2895 [inlined]
ijl_apply_generic at /cache/build/builder-amdci5-5/julialang/julia-release-1-dot-10/src/gf.c:3077
autodiff at /home/krkim/.julia/packages/Enzyme/RvNgp/src/Enzyme.jl:512 [inlined]
main at /home/krkim/.julia/dev/AdvancedVI/test/scratch2.jl:94
unknown function (ip: 0x7f0ed2d17c92)
_jl_invoke at /cache/build/builder-amdci5-5/julialang/julia-release-1-dot-10/src/gf.c:2895 [inlined]
ijl_apply_generic at /cache/build/builder-amdci5-5/julialang/julia-release-1-dot-10/src/gf.c:3077
jl_apply at /cache/build/builder-amdci5-5/julialang/julia-release-1-dot-10/src/julia.h:1982 [inlined]
do_call at /cache/build/builder-amdci5-5/julialang/julia-release-1-dot-10/src/interpreter.c:126
eval_value at /cache/build/builder-amdci5-5/julialang/julia-release-1-dot-10/src/interpreter.c:223
eval_stmt_value at /cache/build/builder-amdci5-5/julialang/julia-release-1-dot-10/src/interpreter.c:174 [inlined]
eval_body at /cache/build/builder-amdci5-5/julialang/julia-release-1-dot-10/src/interpreter.c:617
jl_interpret_toplevel_thunk at /cache/build/builder-amdci5-5/julialang/julia-release-1-dot-10/src/interpreter.c:775
jl_toplevel_eval_flex at /cache/build/builder-amdci5-5/julialang/julia-release-1-dot-10/src/toplevel.c:934
jl_toplevel_eval_flex at /cache/build/builder-amdci5-5/julialang/julia-release-1-dot-10/src/toplevel.c:877
jl_toplevel_eval_flex at /cache/build/builder-amdci5-5/julialang/julia-release-1-dot-10/src/toplevel.c:877
ijl_toplevel_eval_in at /cache/build/builder-amdci5-5/julialang/julia-release-1-dot-10/src/toplevel.c:985
eval at ./boot.jl:385 [inlined]
eval_user_input at /cache/build/builder-amdci5-5/julialang/julia-release-1-dot-10/usr/share/julia/stdlib/v1.10/REPL/src/REPL.jl:150
repl_backend_loop at /cache/build/builder-amdci5-5/julialang/julia-release-1-dot-10/usr/share/julia/stdlib/v1.10/REPL/src/REPL.jl:246
#start_repl_backend#46 at /cache/build/builder-amdci5-5/julialang/julia-release-1-dot-10/usr/share/julia/stdlib/v1.10/REPL/src/REPL.jl:231
start_repl_backend at /cache/build/builder-amdci5-5/julialang/julia-release-1-dot-10/usr/share/julia/stdlib/v1.10/REPL/src/REPL.jl:228
_jl_invoke at /cache/build/builder-amdci5-5/julialang/julia-release-1-dot-10/src/gf.c:2895 [inlined]
ijl_apply_generic at /cache/build/builder-amdci5-5/julialang/julia-release-1-dot-10/src/gf.c:3077
#run_repl#59 at /cache/build/builder-amdci5-5/julialang/julia-release-1-dot-10/usr/share/julia/stdlib/v1.10/REPL/src/REPL.jl:389
run_repl at /cache/build/builder-amdci5-5/julialang/julia-release-1-dot-10/usr/share/julia/stdlib/v1.10/REPL/src/REPL.jl:375
jfptr_run_repl_91949.1 at /home/krkim/.julia/juliaup/julia-1.10.6+0.x64.linux.gnu/lib/julia/sys.so (unknown line)
_jl_invoke at /cache/build/builder-amdci5-5/julialang/julia-release-1-dot-10/src/gf.c:2895 [inlined]
ijl_apply_generic at /cache/build/builder-amdci5-5/julialang/julia-release-1-dot-10/src/gf.c:3077
#1013 at ./client.jl:437
jfptr_YY.1013_82918.1 at /home/krkim/.julia/juliaup/julia-1.10.6+0.x64.linux.gnu/lib/julia/sys.so (unknown line)
_jl_invoke at /cache/build/builder-amdci5-5/julialang/julia-release-1-dot-10/src/gf.c:2895 [inlined]
ijl_apply_generic at /cache/build/builder-amdci5-5/julialang/julia-release-1-dot-10/src/gf.c:3077
jl_apply at /cache/build/builder-amdci5-5/julialang/julia-release-1-dot-10/src/julia.h:1982 [inlined]
jl_f__call_latest at /cache/build/builder-amdci5-5/julialang/julia-release-1-dot-10/src/builtins.c:812
#invokelatest#2 at ./essentials.jl:892 [inlined]
invokelatest at ./essentials.jl:889 [inlined]
run_main_repl at ./client.jl:421
exec_options at ./client.jl:338
_start at ./client.jl:557
jfptr__start_82944.1 at /home/krkim/.julia/juliaup/julia-1.10.6+0.x64.linux.gnu/lib/julia/sys.so (unknown line)
_jl_invoke at /cache/build/builder-amdci5-5/julialang/julia-release-1-dot-10/src/gf.c:2895 [inlined]
ijl_apply_generic at /cache/build/builder-amdci5-5/julialang/julia-release-1-dot-10/src/gf.c:3077
jl_apply at /cache/build/builder-amdci5-5/julialang/julia-release-1-dot-10/src/julia.h:1982 [inlined]
true_main at /cache/build/builder-amdci5-5/julialang/julia-release-1-dot-10/src/jlapi.c:582
jl_repl_entrypoint at /cache/build/builder-amdci5-5/julialang/julia-release-1-dot-10/src/jlapi.c:731
main at /cache/build/builder-amdci5-5/julialang/julia-release-1-dot-10/cli/loader_exe.c:58
unknown function (ip: 0x7f0edf07ce07)
__libc_start_main at /usr/lib/libc.so.6 (unknown line)
unknown function (ip: 0x4010b8)
Allocations: 138892172 (Pool: 138519917; Big: 372255); GC: 197
zsh: segmentation fault (core dumped)  julia
wsmoses commented Nov 8, 2024

Hm, so sadly this does not error on my computer.

Red-Portal (Author) commented:

@wsmoses Hmm, let me check a few things on my system. As a last resort, would it be useful if I reproduced this in a Docker container?

Red-Portal (Author) commented:

Starting from a fresh .julia depot gave the same result. I created a Docker container where I can reproduce the segfault. You can access it through:

docker pull kyrkim/enzymeissue2074
docker run -it kyrkim/enzymeissue2074 bash

Then copy-paste the snippet above into the pre-installed Julia REPL.

Red-Portal (Author) commented:

@wsmoses Would this be enough to reproduce on your end? This bug is breaking all the Enzyme tests in AdvancedVI, so it would be really great to have it fixed.

wsmoses commented Nov 16, 2024

I managed to repro and slightly reduce with the following, but I'll need more help reducing:

using Bijectors
using Distributions  # for Normal, ContinuousDistribution, and Distributions.rand
using LinearAlgebra
using Functors
using Optimisers
using Enzyme
using Random, StableRNGs

struct TestProb1 end

logdensity(::TestProb1, θ) = sum(θ)

function Bijectors.bijector(::TestProb1)
    return Bijectors.Stacked(
        [Base.Fix1(broadcast, log), identity],
        [1:1, 2:3],
    )
end

struct TestProb2 end

logdensity(::TestProb2, θ) = sum(θ)

struct MvLocationScale{S,D<:ContinuousDistribution,L,E<:Real} <:
       ContinuousMultivariateDistribution
    location::L
    scale::S
    dist::D
    scale_eps::E
end

Base.length(q::MvLocationScale) = length(q.location)

Functors.@functor MvLocationScale (location, scale)

# This specialization improves AD performance of the sampling path
function Distributions.rand(
    rng::AbstractRNG, q::MvLocationScale{<:Diagonal,L}, num_samples::Int
) where {L}
    (; location, scale) = q
    n_dims = length(location)
    scale_diag = diag(scale)
    return randn(rng, n_dims, num_samples) 
end

function estimate_repgradelbo_ad_forward(params, aux)
    (; rng, problem, restructure) = aux
    q = restructure(params)
    zs = rand(rng, q, 10)
    return mean(Base.Fix1(logdensity, problem), eachcol(zs))
end

function main()
    d = 5

    seed = (0x38bef07cf9cc549d)
    rng = StableRNG(seed)

    for prob in [TestProb1(), TestProb2()]
        q = if prob isa TestProb1
            MvLocationScale(zeros(d), Diagonal(ones(d)), Normal(), 1e-5)
        else
            Bijectors.TransformedDistribution(
                MvLocationScale(zeros(d), Diagonal(ones(d)), Normal(), 1e-5),
                inverse(
                    Bijectors.Stacked(
                        [Base.Fix1(broadcast, log), identity],
                        [1:1, 2:d],
                    )
                )
            )
        end

        params, re = Optimisers.destructure(q)
        buf = zero(params)
        aux = (rng=rng, problem=prob, restructure=re)
        Enzyme.autodiff(
            set_runtime_activity(Enzyme.ReverseWithPrimal, true),
            estimate_repgradelbo_ad_forward,
            Enzyme.Active,
            Enzyme.Duplicated(params, buf),
            Enzyme.Const(aux),
        )
    end
end

main()

Red-Portal (Author) commented:

@wsmoses Unfortunately, your version does not fail on my system 😓

wsmoses commented Nov 17, 2024

Naturally

Red-Portal (Author) commented:

Oh, adding back the ::typeof(restructure.model) assertion brings back the sorrows. Let me try to axe it from here.
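
For reference, this is the variant with the assertion restored (the one that triggers the segfault on this machine; the snippet in the next comment omits it):

function estimate_repgradelbo_ad_forward(params, aux)
    (; rng, problem, restructure) = aux
    q = restructure(params)::typeof(restructure.model)
    zs = rand(rng, q, 10)
    return mean(Base.Fix1(logdensity, problem), eachcol(zs))
end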

Red-Portal (Author) commented:

Hope the following works on your end:

using Bijectors
using Distributions  # for ContinuousMultivariateDistribution and Distributions.rand
using LinearAlgebra
using Functors
using Optimisers
using Enzyme
using Random

struct TestProb1 end

logdensity(::TestProb1, θ) = sum(θ)

function Bijectors.bijector(::TestProb1)
    return Bijectors.Stacked(
        [Base.Fix1(broadcast, log), identity],
        [1:1, 2:3],
    )
end

struct TestProb2 end

logdensity(::TestProb2, θ) = sum(θ)

struct MvLocationScale{L} <: ContinuousMultivariateDistribution
    location::L
end

Base.length(q::MvLocationScale) = length(q.location)

Functors.@functor MvLocationScale (location,)

# This specialization improves AD performance of the sampling path
function Distributions.rand(
    rng::AbstractRNG, q::MvLocationScale, num_samples::Int
)
    (; location,) = q
    n_dims = length(location)
    return randn(rng, n_dims, num_samples) 
end

function estimate_repgradelbo_ad_forward(params, aux)
    (; rng, problem, restructure) = aux
    q = restructure(params)
    zs = rand(rng, q, 10)
    return mean(Base.Fix1(logdensity, problem), eachcol(zs))
end

function main()
    d = 5
    rng = Random.default_rng()

    for prob in [TestProb1(), TestProb2()]
        q = if prob isa TestProb1
            MvLocationScale(zeros(d))
        else
            Bijectors.TransformedDistribution(
                MvLocationScale(zeros(d)),
                inverse(
                    Bijectors.Stacked(
                        [Base.Fix1(broadcast, log), identity],
                        [1:1, 2:d],
                    )
                )
            )
        end

        params, re = Optimisers.destructure(q)
        buf = zero(params)
        aux = (rng=rng, problem=prob, restructure=re)
        Enzyme.autodiff(
            set_runtime_activity(Enzyme.ReverseWithPrimal, true),
            estimate_repgradelbo_ad_forward,
            Enzyme.Active,
            Enzyme.Duplicated(params, buf),
            Enzyme.Const(aux),
        )
    end
end

main()

wsmoses commented Nov 17, 2024

Still segfaults for me!

Any chance you can reduce further (and also ideally get rid of Bijectors)?

Red-Portal (Author) commented:

@wsmoses I strongly suspect Bijectors is the offender here; the tests not involving Bijectors never failed. I could try opening up Bijectors, though that might take some time.

Red-Portal commented Nov 17, 2024

Bingo. I got it distilled.

using Statistics
using Base.Iterators
using Distributions  # for the ContinuousMultivariateDistribution supertype below
using LinearAlgebra
using Functors
using Optimisers
using Enzyme
using Random

struct TransformedDistribution{D,B}
    dist::D
    transform::B
end

Functors.@functor TransformedDistribution

function rand(rng::AbstractRNG, td::TransformedDistribution, num_samples::Int)
    samples = rand(rng, td.dist, num_samples)
    res = reduce(
        hcat,
        map(axes(samples, 2)) do i
            return td.transform(view(samples, :, i))
        end,
    )
    return res
end

struct Stacked{Bs,Rs<:Union{Tuple,AbstractArray}}
    bs::Bs
    ranges_in::Rs
    ranges_out::Rs
    length_in::Int
    length_out::Int
end

function mapvcat(f, args...)
    out = map(f, args...)
    init = vcat(out[1])
    return reduce(vcat, drop(out, 1); init=init)
end

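# Unrolled at compile time: for N ranges this generates
# vcat(bs[1](x[rs[1]]), bs[2](x[rs[2]]), ..., bs[N](x[rs[N]])).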
@generated function _transform_stacked_recursive(
    x, rs::NTuple{N,UnitRange{Int}}, bs...
) where {N}
    exprs = []
    for i in 1:N
        push!(exprs, :(bs[$i](x[rs[$i]])))
    end
    return :(vcat($(exprs...)))
end

function _transform_stacked_recursive(x, rs::NTuple{1,UnitRange{Int}}, b)
    rs[1] == 1:length(x) || error("range must be 1:length(x)")
    return b(x)
end

function _transform_stacked(sb::Stacked{<:Tuple,<:Tuple}, x::AbstractVector{<:Real})
    y = _transform_stacked_recursive(x, sb.ranges_in, sb.bs...)
    return y
end

function _transform_stacked(sb::Stacked{<:AbstractArray}, x::AbstractVector{<:Real})
    N = length(sb.bs)
    N == 1 && return sb.bs[1](x[sb.ranges_in[1]])

    y = mapvcat(1:N) do i
        sb.bs[i](x[sb.ranges_in[i]])
    end
    return y
end

function (sb::Stacked)(x::AbstractVector{<:Real})
    y = _transform_stacked(sb, x)
    return y
end

struct TestProb1 end

logdensity(::TestProb1, θ) = sum(θ)

struct TestProb2 end

logdensity(::TestProb2, θ) = sum(θ)

struct MvLocationScale{L} <: ContinuousMultivariateDistribution
    location::L
end

Base.length(q::MvLocationScale) = length(q.location)

Functors.@functor MvLocationScale (location,)

# This specialization improves AD performance of the sampling path
function rand(
    rng::AbstractRNG, q::MvLocationScale, num_samples::Int
)
    (; location,) = q
    n_dims = length(location)
    return randn(rng, n_dims, num_samples) 
end

function estimate_repgradelbo_ad_forward(params, aux)
    (; rng, problem, restructure) = aux
    q = restructure(params)
    zs = rand(rng, q, 10)
    return mean(Base.Fix1(logdensity, problem), eachcol(zs))
end

function main()
    d = 5
    rng = Random.default_rng()

    for prob in [TestProb1(), TestProb2()]
        q = if prob isa TestProb1
            MvLocationScale(zeros(d))
        else
            TransformedDistribution(
                MvLocationScale(zeros(d)),
                Stacked(
                    [Base.Fix1(broadcast, exp), identity],
                    [1:1, 2:d],
                    [1:1, 2:d],
                    d, d,
                )
            )
        end

        params, re = Optimisers.destructure(q)
        buf = zero(params)
        aux = (rng=rng, problem=prob, restructure=re)
        Enzyme.autodiff(
            set_runtime_activity(Enzyme.ReverseWithPrimal, true),
            estimate_repgradelbo_ad_forward,
            Enzyme.Active,
            Enzyme.Duplicated(params, buf),
            Enzyme.Const(aux),
        )
    end
end

main()

The Statistics dependency has to stay: mean is necessary to trigger the segfault for some bizarre reason. (Realized this the hard way.)

wsmoses commented Nov 17, 2024

Statistics is fine. Would it be possible to also get rid of Functors and Optimisers?

Red-Portal (Author) commented:

That... I am afraid it is going to be too much of a pain.

wsmoses commented Nov 17, 2024

Let me see if we can debug the segfault as is. (This will likely generate hundreds of thousands of instructions, so any simplification here will be immensely helpful; reducing the dependencies will also let us add this as a test.)

Red-Portal commented Nov 17, 2024

I could try, but restructure/destructure make up a big portion of what Functors and Optimisers are built to do, so the MWE will not be very minimal. It would help if I could replicate on my end the segfaults you see on yours; since I can't, it is hard to weed out the irrelevant parts of Functors and Optimisers. But let me give it a shot later today.
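
For context, Optimisers.destructure flattens a model's trainable numeric arrays into a single parameter vector and returns a closure that rebuilds the model from such a vector. A minimal usage sketch (nothing beyond the documented Optimisers.jl API):

using Optimisers

model = (location=zeros(3), scale=ones(3))
params, re = Optimisers.destructure(model)  # params == [0.0, 0.0, 0.0, 1.0, 1.0, 1.0]
model2 = re(params .+ 1)                    # same structure, rebuilt with new values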

wsmoses commented Nov 17, 2024

Sounds good! And yeah, sorry this is so much of a pain (sadly, this is what segfault hunts usually look like now, and they often end up being bugs in Julia itself =/).

Red-Portal (Author) commented:

Huge pain incoming:

using Statistics
using Base.Iterators
using LinearAlgebra
using Enzyme
using Random

struct FunctionConstructor{F} end

_isgensym(s::Symbol) = occursin("#", string(s))

@generated function (fc::FunctionConstructor{F})(args...) where F
    isempty(args) && return Expr(:new, F)

    T = getfield(parentmodule(F), nameof(F))
    # We assume all gensym names are anonymous functions
    _isgensym(nameof(F)) || return :($T(args...))
    # Define `new` for rebuilt function type that matches args
    exp = Expr(:new, Expr(:curly, T, args...))
    for i in 1:length(args)
        push!(exp.args, :(args[$i]))
    end
    return exp
end

const NoChildren = Tuple{}

function constructorof(f::Type{F}) where F <: Function
    FunctionConstructor{F}()
end

_vec(x::Number) = LinRange(x,x,1)
_vec(x::AbstractArray) = vec(x)

struct ExcludeWalk{T, F, G}
  walk::T
  fn::F
  exclude::G
end

_map(f, x, ys...) = map(f, x, ys...)  # generic fallback (present in Functors.jl proper)
_map(f, x::Dict, ys...) = Dict(k => f(v, (y[k] for y in ys)...) for (k, v) in x)
_map(f, x::D, ys...) where {D<:AbstractDict} =
  constructorof(D)([k => f(v, (y[k] for y in ys)...) for (k, v) in x]...)

struct DefaultWalk end

function (::DefaultWalk)(recurse, x, ys...)
  func, re = functor(x)
  yfuncs = map(y -> functor(typeof(x), y)[1], ys)
  re(_map(recurse, func, yfuncs...))
end

(walk::ExcludeWalk)(recurse, x, ys...) =
  walk.exclude(x) ? walk.fn(x, ys...) : walk.walk(recurse, x, ys...)

struct NoKeyword end

struct CachedWalk{T, S, C <: AbstractDict}
  walk::T
  prune::S
  cache::C
end

CachedWalk(walk; prune = NoKeyword(), cache = IdDict()) =
  CachedWalk(walk, prune, cache)

# Cache-hit lookup; this helper's definition was dropped in the distillation
# (in Functors.jl it just returns the previously stored result).
cacheget(cache::AbstractDict, x, args...) = cache[x]

function (walk::CachedWalk)(recurse, x, ys...)
  should_cache = usecache(walk.cache, x)
  if should_cache && haskey(walk.cache, x)
    return walk.prune isa NoKeyword ? cacheget(walk.cache, x, recurse, x, ys...) : walk.prune
  else
    ret = walk.walk(recurse, x, ys...)
    if should_cache
      walk.cache[x] = ret
    end
    return ret
  end
end

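# Compile-time check for reachable mutability: returns true if T itself is
# mutable, otherwise generates anymutable(field1) || anymutable(field2) || ...
# (skipping self-typed fields to avoid recursion in the generated expression).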
@generated function anymutable(x::T) where {T}
  ismutabletype(T) && return true
  fns = QuoteNode.(filter(n -> fieldtype(T, n) != T, fieldnames(T)))
  subs =  [:(anymutable(getfield(x, $f))) for f in fns]
  return Expr(:(||), subs...)
end

usecache(::Union{AbstractDict, AbstractSet}, x) =
  isleaf(x) ? anymutable(x) : ismutable(x)
usecache(::Nothing, x) = false

struct WalkCache{K, V, W, C <: AbstractDict{K, V}} <: AbstractDict{K, V}
  walk::W
  cache::C
  WalkCache(walk, cache::AbstractDict{K, V} = IdDict()) where {K, V} = new{K, V, typeof(walk), typeof(cache)}(walk, cache)
end
Base.length(cache::WalkCache) = length(cache.cache)
Base.empty!(cache::WalkCache) = empty!(cache.cache)
Base.haskey(cache::WalkCache, x) = haskey(cache.cache, x)
Base.get(cache::WalkCache, x, default) = haskey(cache.cache, x) ? cache[x] : default
Base.iterate(cache::WalkCache, state...) = iterate(cache.cache, state...)
Base.setindex!(cache::WalkCache, value, key) = setindex!(cache.cache, value, key)
Base.getindex(cache::WalkCache, x) = cache.cache[x]

function functor(T, x)
  names = fieldnames(T)
  if isempty(names)
    return NoChildren(), _ -> x
  end
  S = constructorof(T) # remove parameters from parametric types and support anonymous functions
  vals = ntuple(i -> getfield(x, names[i]), length(names))
  return NamedTuple{names}(vals), y -> S(y...)
end
functor(::Type{<:Tuple}, x) = x, identity
functor(::Type{<:NamedTuple{L}}, x) where L = NamedTuple{L}(map(s -> getproperty(x, s), L)), identity
functor(::Type{<:Dict}, x) = Dict(k => x[k] for k in keys(x)), identity
functor(::Type{<:AbstractArray}, x) = x, identity

macro leaf(T)
  :(functor(::Type{<:$(esc(T))}, x) = (NoChildren(), _ -> x))
end

@leaf Type
@leaf Number
@leaf AbstractArray{<:Number}

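# Ties the recursion knot: `recurse` passes itself (via var"#self#") back to
# the walk, so each walk decides whether and how to descend into children.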
function execute(walk, x, ys...)
  recurse(xs...) = walk(var"#self#", xs...)
  walk(recurse, x, ys...)
end

function fmap(f, x, ys...; exclude = isleaf,
                           walk = DefaultWalk(),
                           cache = IdDict(),
                           prune = NoKeyword())
  _walk = ExcludeWalk(walk, f, exclude)
  if !isnothing(cache)
    _walk = CachedWalk(_walk, prune, WalkCache(_walk, cache))
  end
  execute(_walk, x, ys...)
end

isnumeric(x::AbstractArray{<:Number}) = isleaf(x)
isnumeric(x::AbstractArray{<:Integer}) = false
isnumeric(x) = false
children(x) = functor(x)[1]
isleaf(@nospecialize(x)) = children(x) === NoChildren()

struct TrainableStructWalk end

mapvalue(f, x...) = map(f, x...)
mapvalue(f, x::Dict, ys...) = Dict(k => f(v, (get(y, k, nothing) for y in ys)...) for (k,v) in x)

trainable(x) = functor(x)[1]
_trainable(x) = _trainable(functor(x)[1], trainable(x))
_trainable(ch::NamedTuple, tr::NamedTuple) = merge(map(_ -> nothing, ch), tr)
_trainable(ch::Tuple{Vararg{Any,N}}, tr::Tuple{Vararg{Any,N}}) where N = tr
_trainable(ch::AbstractArray, tr::AbstractArray) = tr
_trainable(ch::Dict, tr::Dict) = merge(mapvalue(_ -> nothing, ch), tr)
(::TrainableStructWalk)(recurse, x) = mapvalue(recurse, _trainable(x))

function _flatten(x)
  isnumeric(x) && return vcat(_vec(x)), 0, length(x)  # trivial case
  arrays = AbstractVector[]
  len = Ref(0)
  off = fmap(x; exclude = isnumeric, walk = TrainableStructWalk()) do y
    push!(arrays, _vec(y))
    o = len[]
    len[] = o + length(y)
    o
  end
  isempty(arrays) && return Bool[], off, 0
  return reduce(vcat, arrays), off, len[]
end

function destructure(x)
  flat, off, len = _flatten(x)
  flat, Restructure(x, off, len)
end

struct Restructure{T,S}
  model::T
  offsets::S
  length::Int
end

struct _Trainable_biwalk end


_getat(y::Number, o::Int, flat::AbstractVector) = flat[o + 1]
_getat(y::AbstractArray, o::Int, flat::AbstractVector) = reshape(flat[o .+ (1:length(y))], axes(y))

function _trainmap(f, ch, tr, aux)
  map(ch, tr, aux) do c, t, a  # isnothing(t) indicates non-trainable field, safe given isnumeric(c)
    isnothing(t) ? c : f(t, a)
  end
end

function (::_Trainable_biwalk)(f, x, aux)
  ch, re = functor(typeof(x), x)
  au, _ = functor(typeof(x), aux)
  _trainmap(f, ch, _trainable(x), au) |> re
end

function _rebuild(x, off, flat::AbstractVector, len = length(flat); walk = _Trainable_biwalk(), kw...)
  fmap(x, off; exclude = isnumeric, walk, kw...) do y, o
    _getat(y, o, flat)
  end
end

(re::Restructure)(flat::AbstractVector) = _rebuild(re.model, re.offsets, flat, re.length)

struct TransformedDistribution{D,B}
    dist::D
    transform::B
end

function makefunctor(T, fs = fieldnames(T))
  fidx = Ref(0)
  escargs = map(fieldnames(T)) do f
    f in fs ? :(y[$(fidx[] += 1)]) : :(x.$f)
  end
  escargs_nt = map(fieldnames(T)) do f
    f in fs ? :(y[$(Meta.quot(f))]) : :(x.$f)
  end
  escfs = [:($f=x.$f) for f in fs]
  
  @eval begin
    function functor(::Type{<:$T}, x)
      reconstruct(y) = $T($(escargs...))
      reconstruct(y::NamedTuple) = $T($(escargs_nt...))
      return (;$(escfs...)), reconstruct
    end
  end
end

functor(x) = functor(typeof(x), x)

makefunctor(TransformedDistribution)

function rand(rng::AbstractRNG, td::TransformedDistribution, num_samples::Int)
    samples = rand(rng, td.dist, num_samples)
    res = reduce(
        hcat,
        map(axes(samples, 2)) do i
            return td.transform(view(samples, :, i))
        end,
    )
    return res
end

struct Stacked{Bs,Rs<:Union{Tuple,AbstractArray}}
    bs::Bs
    ranges_in::Rs
    ranges_out::Rs
    length_in::Int
    length_out::Int
end

makefunctor(Stacked, (:bs,))

function mapvcat(f, args...)
    out = map(f, args...)
    init = vcat(out[1])
    return reduce(vcat, drop(out, 1); init=init)
end

@generated function _transform_stacked_recursive(
    x, rs::NTuple{N,UnitRange{Int}}, bs...
) where {N}
    exprs = []
    for i in 1:N
        push!(exprs, :(bs[$i](x[rs[$i]])))
    end
    return :(vcat($(exprs...)))
end

function _transform_stacked_recursive(x, rs::NTuple{1,UnitRange{Int}}, b)
    rs[1] == 1:length(x) || error("range must be 1:length(x)")
    return b(x)
end

function _transform_stacked(sb::Stacked{<:Tuple,<:Tuple}, x::AbstractVector{<:Real})
    y = _transform_stacked_recursive(x, sb.ranges_in, sb.bs...)
    return y
end

function _transform_stacked(sb::Stacked{<:AbstractArray}, x::AbstractVector{<:Real})
    N = length(sb.bs)
    N == 1 && return sb.bs[1](x[sb.ranges_in[1]])

    y = mapvcat(1:N) do i
        sb.bs[i](x[sb.ranges_in[i]])
    end
    return y
end

function (sb::Stacked)(x::AbstractVector{<:Real})
    y = _transform_stacked(sb, x)
    return y
end

struct TestProb1 end

logdensity(::TestProb1, θ) = sum(θ)

struct TestProb2 end

logdensity(::TestProb2, θ) = sum(θ)

struct MvLocationScale{L}
    location::L
end

Base.length(q::MvLocationScale) = length(q.location)

makefunctor(MvLocationScale)

# This specialization improves AD performance of the sampling path
function rand(
    rng::AbstractRNG, q::MvLocationScale, num_samples::Int
)
    (; location,) = q
    n_dims = length(location)
    return randn(rng, n_dims, num_samples) 
end

function estimate_repgradelbo_ad_forward(params, aux)
    (; rng, problem, restructure) = aux
    q = restructure(params)
    zs = rand(rng, q, 10)
    return mean(Base.Fix1(logdensity, problem), eachcol(zs))
end

function main()
    d = 5
    rng = Random.default_rng()

    for prob in [TestProb1(), TestProb2()]
        q = if prob isa TestProb1
            MvLocationScale(zeros(d))
        else
            TransformedDistribution(
                MvLocationScale(zeros(d)),
                Stacked(
                    [Base.Fix1(broadcast, exp), identity],
                    [1:1, 2:d],
                    [1:1, 2:d],
                    d, d,
                )
            )
        end

        params, re = destructure(q)
        buf = zero(params)
        aux = (rng=rng, problem=prob, restructure=re)
        Enzyme.autodiff(
            set_runtime_activity(Enzyme.ReverseWithPrimal, true),
            estimate_repgradelbo_ad_forward,
            Enzyme.Active,
            Enzyme.Duplicated(params, buf),
            Enzyme.Const(aux),
        )
    end
end

main()

wsmoses commented Nov 18, 2024

darn, sadly this no longer segfaults

Red-Portal commented Nov 18, 2024

@wsmoses Even with q = restructure(params)::typeof(restructure.model)? On my end that assertion is necessary to trigger the segfault (though things have weirdly been the opposite on your end).

Red-Portal (Author) commented:

@wsmoses Can you confirm whether you still can't reproduce?

wsmoses commented Nov 23, 2024

Got it, and significantly reduced it to this:

using Statistics
using Base.Iterators
using LinearAlgebra
using Enzyme
using Random

Enzyme.Compiler.DumpPostOpt[] = true
Enzyme.API.printall!(true)

struct Stacked
end

@inline function myrand(rng::AbstractRNG, td::Stacked, num_samples::Int)
    return Base.inferencebarrier(ones(1))
end

struct TestProb1 end

logdensity(::TestProb1, θ) = sum(θ)

struct TestProb2 end

logdensity(::TestProb2, θ) = sum(θ)

struct MvLocationScale
end

# This specialization improves AD performance of the sampling path
@inline function myrand(
    rng::AbstractRNG, q::MvLocationScale, num_samples::Int
)
    return ones(5, num_samples)
end

function mymean(problem, A::AbstractArray)
    isempty(A) && return sum(Base.Fix1(logdensity, problem), A)
    x1 = sum(@inbounds first(A))
    return 1.0
end

function estimate_repgradelbo_ad_forward(rng, problem, model)
    zs = myrand(rng, model, 10)
    return mymean(problem, eachcol(zs))
end

function main()
    d = 5
    rng = Random.default_rng()

    for prob in [TestProb1(), TestProb2()]
        q = if prob isa TestProb1
            MvLocationScale()
        else
            Stacked()
        end

        Enzyme.autodiff(
            set_runtime_activity(Enzyme.ReverseWithPrimal, true),
            estimate_repgradelbo_ad_forward,
            Enzyme.Active,
            Enzyme.Const(rng),
            Enzyme.Const(prob),
            Enzyme.Const(q),
        )
    end
end

main()
# main0()
# main1()
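
Note: Base.inferencebarrier hides a value's type from the compiler, so the call site becomes type-unstable and is dispatched dynamically; that appears to be the path handled by runtime_generic_augfwd in the backtrace at the top of this issue. A minimal illustration of the effect (plain Base, independent of this repro):

f_stable(x) = x + 1
f_unstable(x) = Base.inferencebarrier(x) + 1

Base.return_types(f_stable, (Int,))    # Any[Int64]
Base.return_types(f_unstable, (Int,))  # Any[Any]: inference gives up at the barrier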

Red-Portal (Author) commented:

No way... that's magical

Red-Portal (Author) commented:

Wow, it also seems like it was quite a fix.
