From 6370374714605bd69558d484bc120cd6fa53d129 Mon Sep 17 00:00:00 2001
From: Michael Abbott <32575566+mcabbott@users.noreply.github.com>
Date: Tue, 23 Aug 2022 16:49:06 -0400
Subject: [PATCH] upgrade to at-layer macro, replaces at-functor

---
 src/Flux.jl             |   4 +-
 src/functor.jl          |   1 +
 src/layers/basic.jl     |  18 ++--
 src/layers/conv.jl      |   6 +-
 src/layers/macro.jl     | 177 ++++++++++++++++++++++++++++++++++++++++
 src/layers/normalise.jl |  22 ++---
 src/layers/recurrent.jl |  12 +--
 src/layers/show.jl      | 107 +++++++++++-------------
 8 files changed, 259 insertions(+), 88 deletions(-)
 create mode 100644 src/layers/macro.jl

diff --git a/src/Flux.jl b/src/Flux.jl
index b7d27406b0..9629a6e66d 100644
--- a/src/Flux.jl
+++ b/src/Flux.jl
@@ -44,13 +44,15 @@ include("functor.jl")
 # Pirate error to catch a common mistake.
 Functors.functor(::Type{<:MLUtils.DataLoader}, x) = error("`DataLoader` does not support Functors.jl, thus functions like `Flux.gpu` will not act on its contents.")
 
+include("layers/show.jl")
+include("layers/macro.jl")
+
 include("layers/stateless.jl")
 include("layers/basic.jl")
 include("layers/conv.jl")
 include("layers/recurrent.jl")
 include("layers/normalise.jl")
 include("layers/upsample.jl")
-include("layers/show.jl")
 
 include("loading.jl")
 
diff --git a/src/functor.jl b/src/functor.jl
index d05489104f..e3b61b8ef3 100644
--- a/src/functor.jl
+++ b/src/functor.jl
@@ -42,6 +42,7 @@ function params!(p::Params, x, seen = IdSet())
   elseif x in seen
     nothing
   else
+    _check_new_macro(x)  # complains if you used @functor, not @layer
     push!(seen, x)
     for child in trainable(x)
       params!(p, child, seen)
diff --git a/src/layers/basic.jl b/src/layers/basic.jl
index 51c5fda9b1..9b9d3077d9 100644
--- a/src/layers/basic.jl
+++ b/src/layers/basic.jl
@@ -46,7 +46,7 @@ end
 @forward Chain.layers Base.getindex, Base.length, Base.first, Base.last,
   Base.iterate, Base.lastindex, Base.keys, Base.firstindex
 
-@functor Chain
+@layer :expand Chain  # :expand opts in to container-style pretty-printing
 
 (c::Chain)(x) = _applychain(c.layers, x)
 
@@ -165,7 +165,7 @@ function Dense((in, out)::Pair{<:Integer, <:Integer}, σ = identity;
   Dense(init(out, in), bias, σ)
 end
 
-@functor Dense
+@layer Dense
 
 function (a::Dense)(x::AbstractVecOrMat)
   σ = NNlib.fast_act(a.σ, x)  # replaces tanh => tanh_fast, etc
@@ -236,7 +236,7 @@ end
 Scale(s1::Integer, s23::Integer...; bias = true, init = ones32, _act = identity) = Scale(init(s1, s23...), bias, _act)
 Scale(size_act...; bias = true, init = ones32) = Scale(size_act[1:end-1]...; bias, init, _act = size_act[end])
 
-@functor Scale
+@layer Scale
 
 function (a::Scale)(x::AbstractArray)
   σ = NNlib.fast_act(a.σ, x)  # replaces tanh => tanh_fast, etc
@@ -291,7 +291,7 @@ end
 Maxout(layers...) = Maxout(layers)
 Maxout(f::Function, n_alts::Integer) = Maxout((f() for _ in 1:n_alts)...)
 
-@functor Maxout
+@layer :expand Maxout
 
 function (mo::Maxout)(input::AbstractArray)
   # Perhaps surprisingly, pairwise max broadcast is often faster,
@@ -338,7 +338,7 @@ struct SkipConnection{T,F}
   connection::F  #user can pass arbitrary connections here, such as (a,b) -> a + b
 end
 
-@functor SkipConnection
+@layer SkipConnection  # should this be expand?
 
 function (skip::SkipConnection)(input)
   skip.connection(skip.layers(input), input)
@@ -408,7 +408,7 @@ struct Bilinear{F,A,B}
   end
 end
 
-@functor Bilinear
+@layer Bilinear
 
 function Bilinear(((in1, in2), out)::Pair{<:Tuple, <:Integer}, σ = identity;
                   bias = true, init = glorot_uniform)
@@ -507,7 +507,7 @@ function Parallel(connection; kw...)
   Parallel(connection, layers)
 end
 
-@functor Parallel
+@layer :expand Parallel
 
 (m::Parallel)(x) = m.connection(map(f -> f(x), Tuple(m.layers))...)
 (m::Parallel)(xs::Tuple) = m(xs...)
@@ -628,7 +628,7 @@ end
 end
 applypairwisefusion(layers::NamedTuple, connection, x) = applypairwisefusion(Tuple(layers), connection, x)
 
-@functor PairwiseFusion
+@layer :expand PairwiseFusion
 
 Base.getindex(m::PairwiseFusion, i) = m.layers[i]
 Base.getindex(m::PairwiseFusion, i::AbstractVector) = PairwiseFusion(m.connection, m.layers[i])
@@ -676,7 +676,7 @@ struct Embedding{W}
   weight::W
 end
 
-@functor Embedding
+@layer Embedding
 
 Embedding((in, out)::Pair{<:Integer, <:Integer}; init = randn32) = Embedding(init(out, in))
 
diff --git a/src/layers/conv.jl b/src/layers/conv.jl
index 36aa5c8430..60f09cd0c3 100644
--- a/src/layers/conv.jl
+++ b/src/layers/conv.jl
@@ -187,7 +187,7 @@ function convfilter(filter::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer};
   init(filter..., cin÷groups, cout)
 end
 
-@functor Conv
+@layer Conv
 
 conv_dims(c::Conv, x::AbstractArray) =
   DenseConvDims(x, c.weight; stride = c.stride, padding = c.pad, dilation = c.dilation, groups = c.groups)
@@ -307,7 +307,7 @@ function ConvTranspose(k::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer}, σ =
   ConvTranspose(weight, bias, σ; stride, pad, dilation, groups)
 end
 
-@functor ConvTranspose
+@layer ConvTranspose
 
 function conv_transpose_dims(c::ConvTranspose, x::AbstractArray)
   # Calculate size of "input", from ∇conv_data()'s perspective...
@@ -453,7 +453,7 @@ function CrossCor(k::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer}, σ = iden
   return CrossCor(weight, bias, σ; stride, pad, dilation)
 end
 
-@functor CrossCor
+@layer CrossCor
 
 function crosscor(x, w, ddims::DenseConvDims)
   ddims = DenseConvDims(ddims, F=true)
diff --git a/src/layers/macro.jl b/src/layers/macro.jl
new file mode 100644
index 0000000000..5ec4fe4487
--- /dev/null
+++ b/src/layers/macro.jl
@@ -0,0 +1,177 @@
+
+"""
+    @layer Dense
+    @layer :expand Chain
+    @layer BatchNorm trainable=(β,γ)
+    @layer Struct functor=(α,β) trainable=(β,)
+
+This macro replaces most uses of `@functor` in Flux 0.14. Its basic purpose is the same:
+when you define a new layer, it tells Flux to explore inside it
+to find the parameters it trains, and also to move them to the GPU, change precision, etc.
+
+Some "keywords" allow control of the recursion:
+* If some fields look like parameters but should not be trained,
+  then `Optimisers.trainable` lets you specify which fields to include, ignoring the rest.
+* We can likewise restrict `Functors.functor`, but this is not yet written.
+* In fact you can provide an arbitrary keyword with this syntax, and it will
+  overload that function à la `trainable`... that might be a terrible idea.
+
+It also handles overloads of `show` for pretty printing.
+* By default, it adds methods to 3-arg `Base.show` to treat your layer much like `Dense` or `Conv`.
+* If your layer is a container, more like `Chain` or `Parallel`, then `:expand` makes `show` unfold its contents.
+* To disable all `show` overloads, write `:ignore` before the type instead.
+
+(You probably still want to define 2-arg `show(io::IO, x::Layer)`; the macro does not touch this.)
+"""
+macro layer(exs...)
+  out = quote end
+
+  # These functions are defined in show.jl, and each returns an expression overloading Base.show
+  type, rest... = if exs[1] == QuoteNode(:expand)
+    push!(out.args, _macro_big_show(esc(exs[2])))
+    exs[2:end]
+  elseif exs[1] == QuoteNode(:ignore)
+    exs[2:end]
+  elseif exs[1] isa QuoteNode
+    error("before the type, only accepted options are `:expand` and `:ignore`")
+  else
+    push!(out.args, _macro_layer_show(esc(exs[1])))
+    exs
+  end
+
+  # This function exists only for depwarns when you use @functor directly
+  push!(out.args, :(Flux._check_new_macro(::$(esc(type))) = nothing))  # scope is weird ?? can't use $ on func name?
+
+  i = findfirst(ex -> Meta.isexpr(ex, :(=)) && ex.args[1] == :functor, rest)
+  if isnothing(i)
+    push!(out.args, _macro_functor(esc(type)))
+  else
+    push!(out.args, _macro_functor(esc(type), rest[i].args[2]))
+  end
+  for j in 1:length(rest)
+    j == i && continue
+    ex = rest[j]
+    Meta.isexpr(ex, :(=)) || error("expected keyword = fields")
+    if ex.args[1] == :trainable
+      push!(out.args, _macro_trainable(type, trainable, ex.args[2]))  # pass the function "trainable" not the symbol
+    else
+      error("`@layer` currently accepts only `functor = ...` and `trainable = ...` as keywords")
+      # @warn "defining a method for $(ex.args[1]) in your scope"  # ??
+      # push!(out.args, _macro_trainable(type, esc(ex.args[1]), ex.args[2]))
+    end
+  end
+
+  out
+end
+
+# Temporary depwarn function:
+
+function _check_new_macro(x::T) where T
+  Functors.isleaf(x) && return
+  @warn "you used @functor for this type, but should now use @layer" T maxlog=1 _id=hash(T)
+end
+_check_new_macro(::Tuple) = nothing  # defined by Functors.jl, not by users
+_check_new_macro(::NamedTuple) = nothing
+_check_new_macro(::Transpose) = nothing
+_check_new_macro(::Adjoint) = nothing
+_check_new_macro(::Ref) = nothing
+
+# @layer's code for Functors & Adapt
+# Unlike @functor, _default_functor doesn't need to eval anything
+
+function _macro_functor(type)
+  quote
+    Functors.functor(::Type{T}, x) where {T<:$type} = _default_functor(T, x)
+    Adapt.adapt_structure(to, layer::$type) = fmap(adapt(to), layer)
+  end
+end

+function _macro_functor(type, fields)
+  error("the equivalent of @functor Layer (:x,) isn't written yet, sorry")
+end
+
+function _default_functor(::Type{T}, x) where {T}
+  if @generated
+    F = fieldnames(T)
+    args = map(sy -> :(getfield(x, $(QuoteNode(sy)))), F)
+    C = Base.typename(T).name  # constructor
+    recon = VERSION > v"1.9-" ? :(Splat($C)) : :(Base.splat($C))
+    :((NamedTuple{$F}(($(args...),)), $recon))
+  else
+    # Getting this parameterless type takes about 2μs, every time:
+    namedtuple(x), Base.splat(Base.typename(T).wrapper)
+  end
+end
+
+function namedtuple(x::T) where T
+  F = fieldnames(T)
+  NamedTuple{F}(map(sy -> getfield(x, sy), F))
+end
+
+# @layer's code for Optimisers.trainable, and perhaps anything else,
+# with the pattern that keywords mean function names & what fields they pick.
+
+function _macro_trainable(type, fun, fields)
+  Meta.isexpr(fields, :tuple) || error("expected a tuple of field names")
+  symbols = Tuple(map(_noquotenode, fields.args))
+  quoted = map(QuoteNode, symbols)
+  gets = [:(getfield(x, $f)) for f in quoted]
+  quote
+    # $fun(x::$type) = NamedTuple{$names}(($(gets...),))
+    Flux.trainable(x::$type) = NamedTuple{$symbols}(($(gets...),))  # ?? scope is weird
+  end
+end
+_macro_trainable(type, fun, field::Union{Symbol,QuoteNode}) = _macro_trainable(type, fun, :(($field,)))  # lets you forget a comma
+
+_noquotenode(s::Symbol) = s
+_noquotenode(q::QuoteNode) = q.value  # lets you write trainable=(:x,:y) instead of (x,y)
+_noquotenode(ex) = error("expected a symbol, got $ex")
+
+
+
+
+
+
+# @big_show Chain
+# @big_show Parallel
+# @big_show SkipConnection
+# @big_show Recur
+# @big_show Maxout
+
+
+
+"""
+    @big_show MyContainer
+
+This macro lets you opt-in to Flux's fancy printing.
+
+When `model::MyContainer` is returned at the REPL it will be treated like `Chain`,
+and the printing routine will recursively unfold its children.
+This is triggered by adding a method to 3-arg `Base.show(io::IO, ::MIME"text/plain", l::MyContainer)`.
+
+Custom layers which do not contain other layers (more like `Dense` than like `Chain`)
+need not call this, and should simply define 2-arg `Base.show(io::IO, l::MyLayer)`.
+
+# Example
+```jldoctest
+julia> struct Trio{A,B,C}; a::A; b::B; c::C end
+
+julia> Flux.@functor Trio
+
+julia> Flux.@big_show Trio
+
+julia> tri = Trio(Dense(10=>5,tanh), Dense(5=>2), softmax)
+Trio(
+  Dense(10 => 5, tanh),                 # 55 parameters
+  Dense(5 => 2),                        # 12 parameters
+  NNlib.softmax,
+)                   # Total: 4 arrays, 67 parameters, 492 bytes.
+```
+
+Note that there is no automatic method for 2-arg `show`, and thus
+something like `(tri, tri)` will print all the type parameters.
+
+However, `Chain(tri, tri)` will always use Flux's recursive printing,
+even without using this macro: `Chain` is the entry point.
+"""
\ No newline at end of file
diff --git a/src/layers/normalise.jl b/src/layers/normalise.jl
index 446575f355..2dd116731b 100644
--- a/src/layers/normalise.jl
+++ b/src/layers/normalise.jl
@@ -103,8 +103,8 @@ function Dropout(p; dims=:, rng = default_rng_value())
   Dropout(p, dims, nothing, rng)
 end
 
-@functor Dropout
-trainable(a::Dropout) = (;)
+@layer Dropout trainable=()
+# trainable(a::Dropout) = (;)
 
 function (a::Dropout)(x)
   _isactive(a) || return x
@@ -158,8 +158,8 @@ end
 AlphaDropout(p, active) = AlphaDropout(p, active, default_rng_value())
 AlphaDropout(p; rng = default_rng_value()) = AlphaDropout(p, nothing, rng)
 
-@functor AlphaDropout
-trainable(a::AlphaDropout) = (;)
+@layer AlphaDropout trainable=()
+# trainable(a::AlphaDropout) = (;)
 
 function (a::AlphaDropout)(x::AbstractArray{T}) where T
   _isactive(a) || return x
@@ -224,7 +224,7 @@ end
 LayerNorm(size::Integer...; kw...) = LayerNorm(Int.(size); kw...)
 LayerNorm(size_act...; kw...) = LayerNorm(Int.(size_act[1:end-1]), size_act[end]; kw...)
 
-@functor LayerNorm
+@layer LayerNorm
 
 (a::LayerNorm)(x) = a.diag(normalise(x, dims=1:length(a.size), ϵ=a.ϵ))
 
@@ -352,8 +352,8 @@ function BatchNorm(chs::Int, λ=identity;
             nothing, chs)
 end
 
-@functor BatchNorm
-trainable(bn::BatchNorm) = hasaffine(bn) ? (β = bn.β, γ = bn.γ) : (;)
+@layer BatchNorm trainable=(β,γ)
+# trainable(bn::BatchNorm) = hasaffine(bn) ? (β = bn.β, γ = bn.γ) : (;)
 
 function (BN::BatchNorm)(x)
   @assert size(x, ndims(x)-1) == BN.chs
@@ -442,8 +442,8 @@ function InstanceNorm(chs::Int, λ=identity;
             nothing, chs)
 end
 
-@functor InstanceNorm
-trainable(in::InstanceNorm) = hasaffine(in) ? (β = in.β, γ = in.γ) : (;)
+@layer InstanceNorm trainable=(β,γ)
+# trainable(in::InstanceNorm) = hasaffine(in) ? (β = in.β, γ = in.γ) : (;)
 
 function (l::InstanceNorm)(x)
   @assert ndims(x) > 2
@@ -521,8 +521,8 @@ mutable struct GroupNorm{F,V,N,W}
   chs::Int # number of channels
 end
 
-@functor GroupNorm
-trainable(gn::GroupNorm) = hasaffine(gn) ? (β = gn.β, γ = gn.γ) : (;)
+@layer GroupNorm trainable=(β,γ)
+# trainable(gn::GroupNorm) = hasaffine(gn) ? (β = gn.β, γ = gn.γ) : (;)
 
 function GroupNorm(chs::Int, G::Int, λ=identity;
                    initβ=zeros32, initγ=ones32,
diff --git a/src/layers/recurrent.jl b/src/layers/recurrent.jl
index 760933bb96..419ef465e0 100644
--- a/src/layers/recurrent.jl
+++ b/src/layers/recurrent.jl
@@ -135,8 +135,8 @@ function (m::Recur)(x)
   return y
 end
 
-@functor Recur
-trainable(a::Recur) = (; cell = a.cell)
+@layer :expand Recur trainable=(cell,)
+# trainable(a::Recur) = (; cell = a.cell)
 
 Base.show(io::IO, m::Recur) = print(io, "Recur(", m.cell, ")")
 
@@ -207,7 +207,7 @@ function (m::RNNCell{F,A,V,<:AbstractMatrix{T}})(h, x::Union{AbstractVecOrMat{T}
   return h, reshape_cell_output(h, x)
 end
 
-@functor RNNCell
+@layer RNNCell  # trainable=(Wi, Wh, b)
 
 function Base.show(io::IO, l::RNNCell)
   print(io, "RNNCell(", size(l.Wi, 2), " => ", size(l.Wi, 1))
@@ -302,7 +302,7 @@ function (m::LSTMCell{A,V,<:NTuple{2,AbstractMatrix{T}}})((h, c), x::Union{Abstr
   return (h′, c′), reshape_cell_output(h′, x)
 end
 
-@functor LSTMCell
+@layer LSTMCell
 
 Base.show(io::IO, l::LSTMCell) =
   print(io, "LSTMCell(", size(l.Wi, 2), " => ", size(l.Wi, 1)÷4, ")")
@@ -370,7 +370,7 @@ function (m::GRUCell{A,V,<:AbstractMatrix{T}})(h, x::Union{AbstractVecOrMat{T},O
   return h′, reshape_cell_output(h′, x)
 end
 
-@functor GRUCell
+@layer GRUCell
 
 Base.show(io::IO, l::GRUCell) =
   print(io, "GRUCell(", size(l.Wi, 2), " => ", size(l.Wi, 1)÷3, ")")
@@ -435,7 +435,7 @@ function (m::GRUv3Cell{A,V,<:AbstractMatrix{T}})(h, x::Union{AbstractVecOrMat{T}
   return h′, reshape_cell_output(h′, x)
 end
 
-@functor GRUv3Cell
+@layer GRUv3Cell
 
 Base.show(io::IO, l::GRUv3Cell) =
   print(io, "GRUv3Cell(", size(l.Wi, 2), " => ", size(l.Wi, 1)÷3, ")")
diff --git a/src/layers/show.jl b/src/layers/show.jl
index db1ba21b85..b425e4e6e7 100644
--- a/src/layers/show.jl
+++ b/src/layers/show.jl
@@ -1,42 +1,10 @@
-"""
-    @big_show MyContainer
+@nospecialize  # just for this file, for startup time
 
-This macro lets you opt-in to Flux's fancy printing.
-
-When `model::MyContainer` is returned at the REPL it will be treated like `Chain`,
-and the printing routine will recursively unfold its children.
-This is triggered by adding a method to 3-arg `Base.show(io::IO, ::MIME"text/plain", l::MyContainer)`.
-
-Custom layers which do not contain other layers (more like `Dense` than like `Chain`)
-need not call this, and should simply define 2-arg `Base.show(io::IO, l::MyLayer)`.
-
-# Example
-```jldoctest
-julia> struct Trio{A,B,C}; a::A; b::B; c::C end
-
-julia> Flux.@functor Trio
-
-julia> Flux.@big_show Trio
-
-julia> tri = Trio(Dense(10=>5,tanh), Dense(5=>2), softmax)
-Trio(
-  Dense(10 => 5, tanh),                 # 55 parameters
-  Dense(5 => 2),                        # 12 parameters
-  NNlib.softmax,
-)                   # Total: 4 arrays, 67 parameters, 492 bytes.
-```
-
-Note that there is no automatic method for 2-arg `show`, and thus
-something like `(tri, tri)` will print all the type parameters.
-
-However, `Chain(tri, tri)` will always use Flux's recursive printing,
-even without using this macro: `Chain` is the entry point.
-"""
-macro big_show(ex)
-  ex isa Symbol || error("usage is `Flux.@big_show Chain`")
-  eex = esc(ex)
+# This is called by @layer, on layers which should be treated like Chain
+function _macro_big_show(ex)
   quote
-    function Base.show(io::IO, m::MIME"text/plain", x::$eex)
+    # Entry point:
+    function Base.show(io::IO, m::MIME"text/plain", x::$ex)
       if get(io, :typeinfo, nothing) === nothing  # e.g. top level in REPL
         _big_show(io, x)
       elseif !get(io, :compact, false)  # e.g. printed inside a Vector, but not a Matrix
@@ -45,15 +13,12 @@ macro big_show(ex)
         show(io, x)
       end
     end
+
+    # Don't show Chain(Tuple(...)), always splat that:
+    _show_children(x::$ex) = _flat_children(x)
   end
 end
 
-@big_show Chain
-@big_show Parallel
-@big_show SkipConnection
-@big_show Recur
-@big_show Maxout
-
 function _big_show(io::IO, obj, indent::Int=0, name=nothing)
   pre, post = obj isa Chain{<:AbstractVector} ? ("([", "])") : ("(", ")")
   children = _show_children(obj)
@@ -88,27 +53,51 @@ end
 _show_leaflike(x) = isleaf(x)  # mostly follow Functors, except for:
 _show_leaflike(::Tuple{Vararg{<:Number}}) = true         # e.g. stride of Conv
 _show_leaflike(::Tuple{Vararg{<:AbstractArray}}) = true  # e.g. parameters of LSTMcell
-_show_leaflike(::Scale) = true  # appears inside LayerNorm
+# _show_leaflike(::Scale) = true  # appears inside LayerNorm
 _show_leaflike(::AbstractArray{<:Number}) = true  # e.g. transposed arrays
 
 _show_children(x) = trainable(x)  # except for layers which hide their Tuple:
-_show_children(c::Chain) = c.layers
-_show_children(m::Maxout) = m.layers
-_show_children(p::Parallel) = (p.connection, p.layers...)
-_show_children(f::PairwiseFusion) = (f.connection, f.layers...)
+# _show_children(c::Chain) = c.layers
+# _show_children(m::Maxout) = m.layers
+# _show_children(p::Parallel) = (p.connection, p.layers...)
+# _show_children(f::PairwiseFusion) = (f.connection, f.layers...)
+
+function _flat_children(x)
+  alpha = map(f -> getfield(x, f), fieldnames(typeof(x)))
+  beta = map(y -> y isa Union{Tuple, NamedTuple} ? y : (y,), alpha)
+  gamma = ((beta...)...,)
+end
 
-for T in [
-  :Conv, :ConvTranspose, :CrossCor, :Dense, :Scale, :Bilinear, :Embedding,
-  :BatchNorm, :LayerNorm, :InstanceNorm, :GroupNorm,
-  ]
-  @eval function Base.show(io::IO, m::MIME"text/plain", x::$T)
-    if !get(io, :compact, false)
-      _layer_show(io, x)
-    else
-      show(io, x)
+# This is called by @layer, on layers which should be treated like Dense
+function _macro_layer_show(ex)
+  quote
+    # Entry point:
+    function Base.show(io::IO, m::MIME"text/plain", x::$ex)
+      if !get(io, :compact, false)
+        _layer_show(io, x)
+      else
+        show(io, x)
+      end
+    end
+
+    # Exit from _big_show recursion, do we need this and _show_leaflike?
+    _big_show(io::IO, obj::$ex, indent::Int=0, name=nothing) = _layer_show(io, obj, indent, name)
+    # Since this isn't a container, do not recurse into its children, if any:
+    _show_leaflike(::$ex) = true
   end
-  end
 end
 
+# for T in [
+#   :Conv, :ConvTranspose, :CrossCor, :Dense, :Scale, :Bilinear, :Embedding,
+#   :BatchNorm, :LayerNorm, :InstanceNorm, :GroupNorm,
+#   ]
+#   @eval function Base.show(io::IO, m::MIME"text/plain", x::$T)
+#     if !get(io, :compact, false)
+#       _layer_show(io, x)
+#     else
+#       show(io, x)
+#     end
+#   end
+# end
 
 function _layer_show(io::IO, layer, indent::Int=0, name=nothing)
   _str = isnothing(name) ? "" : "$name = "
@@ -163,6 +152,8 @@ function _nan_show(io::IO, x)
   end
 end
 
+@specialize  # undoes @nospecialize at the top of this file
+
 _any(f, xs::AbstractArray{<:Number}) = any(f, xs)
 # _any(f, xs::Union{Tuple,NamedTuple,Zygote.Params}) = any(x -> _any(f, x), xs)
 _any(f, xs) = any(x -> _any(f, x), xs)
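
As a usage sketch, not part of the patch itself: the `PolyDense` layer below is hypothetical, invented only to illustrate the syntax that the `@layer` docstring above describes — the plain (Dense-style) form for a non-container layer, plus a `trainable` keyword so that the integer `degree` field is ignored by optimisers.

```julia
# Hypothetical example layer, shown only to illustrate the @layer macro introduced in this patch.
using Flux

struct PolyDense{M<:AbstractMatrix, B<:AbstractVector, F}
  weight::M    # trained
  bias::B      # trained
  degree::Int  # hyper-parameter, should not be treated as a parameter
  σ::F
end

PolyDense((in, out)::Pair{<:Integer,<:Integer}, degree::Integer = 2, σ = identity) =
  PolyDense(Flux.glorot_uniform(out, in), zeros(Float32, out), Int(degree), σ)

(p::PolyDense)(x::AbstractVecOrMat) = p.σ.((p.weight * x .+ p.bias) .^ p.degree)

# Non-container layer: plain @layer, with trainable restricted to the array fields,
# so optimisers leave `degree` alone while weight and bias are still collected.
Flux.@layer PolyDense trainable=(weight, bias)
```

A container-style layer would instead opt in with `Flux.@layer :expand MyBlock`, matching the `Chain`/`Parallel` lines changed above, so that `show` unfolds its children.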