upgrade to at-layer macro, replaces at-functor
mcabbott committed Aug 23, 2022
1 parent 9f9051f commit 6370374
Showing 8 changed files with 259 additions and 88 deletions.
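The migration in this commit is mechanical: each `@functor Layer` line, plus any hand-written `trainable` method, collapses into a single `@layer` call. A rough before/after sketch with a made-up layer (the struct and field names are illustrative, not from the diff):

```julia
using Flux

struct MyNorm{V}              # hypothetical layer with one non-trainable field
    β::V
    γ::V
    momentum::Float32
end

# Old style (Flux 0.13), shown for comparison only:
#   Flux.@functor MyNorm
#   Flux.trainable(m::MyNorm) = (β = m.β, γ = m.γ)   # exclude `momentum`

# New style from this commit: one declaration covers Functors recursion,
# Adapt, pretty-printing, and the trainable subset.
Flux.@layer MyNorm trainable=(β,γ)
```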
4 changes: 3 additions & 1 deletion src/Flux.jl
@@ -44,13 +44,15 @@ include("functor.jl")
# Pirate error to catch a common mistake.
Functors.functor(::Type{<:MLUtils.DataLoader}, x) = error("`DataLoader` does not support Functors.jl, thus functions like `Flux.gpu` will not act on its contents.")

include("layers/show.jl")
include("layers/macro.jl")

include("layers/stateless.jl")
include("layers/basic.jl")
include("layers/conv.jl")
include("layers/recurrent.jl")
include("layers/normalise.jl")
include("layers/upsample.jl")
include("layers/show.jl")

include("loading.jl")

1 change: 1 addition & 0 deletions src/functor.jl
@@ -42,6 +42,7 @@ function params!(p::Params, x, seen = IdSet())
elseif x in seen
nothing
else
_check_new_macro(x) # complains if you used @functor not @layer
push!(seen, x)
for child in trainable(x)
params!(p, child, seen)
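The `_check_new_macro` hook added above means collecting parameters now nudges anyone whose layer type was declared with the old macro. A minimal sketch of what trips the warning (the `OldStyle` struct is made up):

```julia
using Flux

struct OldStyle{W}            # hypothetical layer still declared the old way
    w::W
end
Flux.@functor OldStyle

# `params!` reaches the OldStyle node, finds no specific `_check_new_macro` method,
# and the fallback warns once per type, roughly:
#   Warning: you used @functor for this type, but should now use @layer
Flux.params(OldStyle(rand(2, 2)))
```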
18 changes: 9 additions & 9 deletions src/layers/basic.jl
@@ -46,7 +46,7 @@ end
@forward Chain.layers Base.getindex, Base.length, Base.first, Base.last,
Base.iterate, Base.lastindex, Base.keys, Base.firstindex

@functor Chain
@layer :expand Chain  # the :expand option opts-in to container-style pretty-printing

(c::Chain)(x) = _applychain(c.layers, x)

@@ -165,7 +165,7 @@ function Dense((in, out)::Pair{<:Integer, <:Integer}, σ = identity;
Dense(init(out, in), bias, σ)
end

@functor Dense
@layer Dense

function (a::Dense)(x::AbstractVecOrMat)
σ = NNlib.fast_act(a.σ, x) # replaces tanh => tanh_fast, etc
@@ -236,7 +236,7 @@ end
Scale(s1::Integer, s23::Integer...; bias = true, init = ones32, _act = identity) = Scale(init(s1, s23...), bias, _act)
Scale(size_act...; bias = true, init = ones32) = Scale(size_act[1:end-1]...; bias, init, _act = size_act[end])

@functor Scale
@layer Scale

function (a::Scale)(x::AbstractArray)
σ = NNlib.fast_act(a.σ, x) # replaces tanh => tanh_fast, etc
@@ -291,7 +291,7 @@ end
Maxout(layers...) = Maxout(layers)
Maxout(f::Function, n_alts::Integer) = Maxout((f() for _ in 1:n_alts)...)

@functor Maxout
@layer :expand Maxout

function (mo::Maxout)(input::AbstractArray)
# Perhaps surprisingly, pairwise max broadcast is often faster,
@@ -338,7 +338,7 @@ struct SkipConnection{T,F}
connection::F #user can pass arbitrary connections here, such as (a,b) -> a + b
end

@functor SkipConnection
@layer SkipConnection # should this be expand?

function (skip::SkipConnection)(input)
skip.connection(skip.layers(input), input)
@@ -408,7 +408,7 @@ struct Bilinear{F,A,B}
end
end

@functor Bilinear
@layer Bilinear

function Bilinear(((in1, in2), out)::Pair{<:Tuple, <:Integer}, σ = identity;
bias = true, init = glorot_uniform)
@@ -507,7 +507,7 @@ function Parallel(connection; kw...)
Parallel(connection, layers)
end

@functor Parallel
@layer :expand Parallel

(m::Parallel)(x) = m.connection(map(f -> f(x), Tuple(m.layers))...)
(m::Parallel)(xs::Tuple) = m(xs...)
@@ -628,7 +628,7 @@ end
end
applypairwisefusion(layers::NamedTuple, connection, x) = applypairwisefusion(Tuple(layers), connection, x)

@functor PairwiseFusion
@layer :expand PairwiseFusion

Base.getindex(m::PairwiseFusion, i) = m.layers[i]
Base.getindex(m::PairwiseFusion, i::AbstractVector) = PairwiseFusion(m.connection, m.layers[i])
@@ -676,7 +676,7 @@ struct Embedding{W}
weight::W
end

@functor Embedding
@layer Embedding

Embedding((in, out)::Pair{<:Integer, <:Integer}; init = randn32) = Embedding(init(out, in))

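In this file the containers (`Chain`, `Maxout`, `Parallel`, `PairwiseFusion`) take the `:expand` option, while `Dense`, `Scale`, `Bilinear`, `Embedding` and, for now, `SkipConnection` keep the compact form. A hedged sketch of the difference at the REPL (parameter counts worked out by hand, byte totals omitted):

```julia
using Flux

m = Chain(Dense(2 => 3, relu), Dense(3 => 1))

# With `@layer :expand Chain`, the 3-arg show unfolds the children, roughly:
#   Chain(
#     Dense(2 => 3, relu),  # 9 parameters
#     Dense(3 => 1),        # 4 parameters
#   )  # Total: 4 arrays, 13 parameters, ...
#
# With plain `@layer Dense`, a lone layer keeps the compact one-line form, roughly:
#   Dense(2 => 3, relu)  # 9 parameters
show(stdout, MIME"text/plain"(), m)
```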
6 changes: 3 additions & 3 deletions src/layers/conv.jl
@@ -187,7 +187,7 @@ function convfilter(filter::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer};
init(filter..., cin÷groups, cout)
end

@functor Conv
@layer Conv

conv_dims(c::Conv, x::AbstractArray) =
DenseConvDims(x, c.weight; stride = c.stride, padding = c.pad, dilation = c.dilation, groups = c.groups)
@@ -307,7 +307,7 @@ function ConvTranspose(k::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer}, σ =
ConvTranspose(weight, bias, σ; stride, pad, dilation, groups)
end

@functor ConvTranspose
@layer ConvTranspose

function conv_transpose_dims(c::ConvTranspose, x::AbstractArray)
# Calculate size of "input", from ∇conv_data()'s perspective...
@@ -453,7 +453,7 @@ function CrossCor(k::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer}, σ = iden
return CrossCor(weight, bias, σ; stride, pad, dilation)
end

@functor CrossCor
@layer CrossCor

function crosscor(x, w, ddims::DenseConvDims)
ddims = DenseConvDims(ddims, F=true)
177 changes: 177 additions & 0 deletions src/layers/macro.jl
@@ -0,0 +1,177 @@

"""
@layer Dense
@layer :expand Chain
@layer BatchNorm trainable=(β,γ)
@layer Struct functor=(α,β) trainable=(β,)
This macro replaces most uses of `@functor` in Flux 0.14. Its basic purpose is the same:
When you define a new layer, this tells Flux to explore inside it
to see the parameters it trains, and also to move them to the GPU, change precision, etc.
Some "keywords" allow control of the recursion:
* If some fields look like parameters but should not be trained,
then the `trainable` keyword lets you name the fields to include; the rest are ignored
(this overloads `Optimisers.trainable` for your type).
* A `functor` keyword could likewise restrict `Functors.functor`, but this is not yet written.
* In principle an arbitrary keyword with this syntax could overload a function of that
name, à la `trainable`... but that might be a terrible idea.
It also handles overloads of `show` for pretty printing.
* By default, it adds methods to 3-arg `Base.show` to treat your layer much like `Dense` or `Conv`.
* If your layer is a container, more like `Chain` or `Parallel`, then `:expand` makes `show` unfold its contents.
* To disable all `show` overloads, use the `:ignore` option.
(You probably still want to define 2-arg `show(io::IO, x::Layer)`, the macro does not touch this.)
"""
macro layer(exs...)
out = quote end

# These functions are defined in show.jl, and each return an expression overloading Base.show
type, rest... = if exs[1] == QuoteNode(:expand)
push!(out.args, _macro_big_show(esc(exs[2])))
exs[2:end]
elseif exs[1] == QuoteNode(:ignore)
exs[2:end]
elseif exs[1] isa QuoteNode
error("before the type, only accepted options are `:expand` and `:ignore`")
else
push!(out.args, _macro_layer_show(esc(exs[1])))
exs
end

# This function exists only for depwarns when you use @functor directly
push!(out.args, :(Flux._check_new_macro(::$(esc(type))) = nothing)) # scope is weird ?? can't use $ on func name?

i = findfirst(ex -> Meta.isexpr(ex, :(=)) && ex.args[1] == :functor, rest)
if isnothing(i)
push!(out.args, _macro_functor(esc(type)))
else
push!(out.args, _macro_functor(esc(type), rest[i].args[2]))
end
for j in 1:length(rest)
j == i && continue
ex = rest[j]
Meta.isexpr(ex, :(=)) || error("expected keyword = fields")
if ex.args[1] == :trainable
push!(out.args, _macro_trainable(type, trainable, ex.args[2])) # pass the function "trainable" not the symbol
else
error("`@layer` currently only accepts the keywords `trainable` and `functor`")
# @warn "defining a method for $(ex.args[1]) in your scope" # ??
# push!(out.args, _macro_trainable(type, esc(ex.args[1]), ex.args[2]))
end
end

out
end
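# A sketch, not part of this file: roughly what `Flux.@layer MyLayer trainable=(a,)`
# expands to for a hypothetical two-field struct. The `show` overloads come from
# _macro_layer_show / _macro_big_show in show.jl and are omitted here.
struct MyLayer; a; b; end                                # made-up layer, for illustration
Flux._check_new_macro(::MyLayer) = nothing               # silences the depwarn below
Functors.functor(::Type{T}, x) where {T<:MyLayer} = Flux._default_functor(T, x)
Adapt.adapt_structure(to, layer::MyLayer) = fmap(adapt(to), layer)
Flux.trainable(x::MyLayer) = (a = getfield(x, :a),)      # only `a` is optimised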

# Temporary depwarn function:

function _check_new_macro(x::T) where T
Functors.isleaf(x) && return
@warn "you used @functor for this type, but should now use @layer" T maxlog=1 _id=hash(T)
end
_check_new_macro(::Tuple) = nothing # defined by Functors.jl, not by users
_check_new_macro(::NamedTuple) = nothing
_check_new_macro(::Transpose) = nothing
_check_new_macro(::Adjoint) = nothing
_check_new_macro(::Ref) = nothing

# @layer's code for Functors & Adapt
# Unlike @functor, _default_functor doesn't need to eval anything

function _macro_functor(type)
quote
Functors.functor(::Type{T}, x) where {T<:$type} = _default_functor(T, x)
Adapt.adapt_structure(to, layer::$type) = fmap(adapt(to), layer)
end
end

function _macro_functor(type, fields)
error("the equivalent of @functor Layer (:x,) isn't written yet, sorry")
end

function _default_functor(::Type{T}, x) where {T}
if @generated
F = fieldnames(T)
args = map(sy -> :(getfield(x, $(QuoteNode(sy)))), F)
C = Base.typename(T).name # constructor
recon = VERSION > v"1.9-" ? :(Splat($C)) : :(Base.splat($C))
:((NamedTuple{$F}(($(args...),)), $recon))
else
# Getting this parameterless type takes about 2μs, every time:
namedtuple(x), Base.splat(Base.typename(T).wrapper)
end
end

function namedtuple(x::T) where T
F = fieldnames(T)
NamedTuple{F}(map(sy -> getfield(x, sy), F))
end
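# A sketch, not part of this file: what the default functor yields for a made-up layer.
# The first return value is the fields as a NamedTuple; the second splats a NamedTuple
# of (possibly transformed) children back into the constructor, which is how fmap,
# gpu, f32, etc. reassemble the layer.
struct Affine{W,B}; weight::W; bias::B end               # hypothetical layer, for illustration
Flux.@layer Affine
affine_children, affine_rebuild = Functors.functor(Affine([1.0 2.0], [0.0]))
affine_children                                          # (weight = [1.0 2.0], bias = [0.0])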

# @layer's code for Optimisers.trainable, and perhaps anything else,
# with the pattern that keywords mean function names & what fields they pick.

function _macro_trainable(type, fun, fields)
Meta.isexpr(fields, :tuple) || error("expected a tuple of field names")
symbols = Tuple(map(_noquotenode, fields.args))
quoted = map(QuoteNode, symbols)
gets = [:(getfield(x, $f)) for f in quoted]
quote
# $fun(x::$type) = NamedTuple{$names}(($(gets...),))
Flux.trainable(x::$type) = NamedTuple{$symbols}(($(gets...),)) # ?? scope is weird
end
end
_macro_trainable(type, fun, field::Union{Symbol,QuoteNode}) = _macro_trainable(type, fun, :(($field,))) # lets you forget a comma

_noquotenode(s::Symbol) = s
_noquotenode(q::QuoteNode) = q.value # lets you write trainable=(:x,:y) instead of (x,y)
_noquotenode(ex) = error("expected a symbol, got $ex")






# @big_show Chain
# @big_show Parallel
# @big_show SkipConnection
# @big_show Recur
# @big_show Maxout




"""
@big_show MyContainer
This macro lets you opt-in to Flux's fancy printing.
When `model::MyContainer` is returned at the REPL it will be treated like `Chain`,
and the printing routine will recursively unfold its children.
This is triggered by adding a method to 3-arg `Base.show(io::IO, ::MIME"text/plain", l::MyContainer)`.
Custom layers which do not contain other layers (more like `Dense` than like `Chain`)
need not call this, and should simply define 2-arg `Base.show(io::IO, l::MyLayer)`.
# Example
```jldoctest
julia> struct Trio{A,B,C}; a::A; b::B; c::C end

julia> Flux.@functor Trio

julia> Flux.@big_show Trio

julia> tri = Trio(Dense(10=>5,tanh), Dense(5=>2), softmax)
Trio(
Dense(10 => 5, tanh), # 55 parameters
Dense(5 => 2), # 12 parameters
NNlib.softmax,
) # Total: 4 arrays, 67 parameters, 492 bytes.
```
Note that there is no automatic method for 2-arg `show`, and thus
something like `(tri, tri)` will print all the type parameters.
However, `Chain(tri, tri)` will always use Flux's recursive printing,
even without using this macro: `Chain` is the entry point.
"""
22 changes: 11 additions & 11 deletions src/layers/normalise.jl
@@ -103,8 +103,8 @@ function Dropout(p; dims=:, rng = default_rng_value())
Dropout(p, dims, nothing, rng)
end

@functor Dropout
trainable(a::Dropout) = (;)
@layer Dropout trainable=()
# trainable(a::Dropout) = (;)

function (a::Dropout)(x)
_isactive(a) || return x
@@ -158,8 +158,8 @@ end
AlphaDropout(p, active) = AlphaDropout(p, active, default_rng_value())
AlphaDropout(p; rng = default_rng_value()) = AlphaDropout(p, nothing, rng)

@functor AlphaDropout
trainable(a::AlphaDropout) = (;)
@layer AlphaDropout trainable=()
# trainable(a::AlphaDropout) = (;)

function (a::AlphaDropout)(x::AbstractArray{T}) where T
_isactive(a) || return x
@@ -224,7 +224,7 @@ end
LayerNorm(size::Integer...; kw...) = LayerNorm(Int.(size); kw...)
LayerNorm(size_act...; kw...) = LayerNorm(Int.(size_act[1:end-1]), size_act[end]; kw...)

@functor LayerNorm
@layer LayerNorm

(a::LayerNorm)(x) = a.diag(normalise(x, dims=1:length(a.size), ϵ=a.ϵ))

@@ -352,8 +352,8 @@ function BatchNorm(chs::Int, λ=identity;
nothing, chs)
end

@functor BatchNorm
trainable(bn::BatchNorm) = hasaffine(bn) ? (β = bn.β, γ = bn.γ) : (;)
@layer BatchNorm trainable=(β,γ)
# trainable(bn::BatchNorm) = hasaffine(bn) ? (β = bn.β, γ = bn.γ) : (;)

function (BN::BatchNorm)(x)
@assert size(x, ndims(x)-1) == BN.chs
@@ -442,8 +442,8 @@ function InstanceNorm(chs::Int, λ=identity;
nothing, chs)
end

@functor InstanceNorm
trainable(in::InstanceNorm) = hasaffine(in) ? (β = in.β, γ = in.γ) : (;)
@layer InstanceNorm trainable=(β,γ)
# trainable(in::InstanceNorm) = hasaffine(in) ? (β = in.β, γ = in.γ) : (;)

function (l::InstanceNorm)(x)
@assert ndims(x) > 2
@@ -521,8 +521,8 @@ mutable struct GroupNorm{F,V,N,W}
chs::Int # number of channels
end

@functor GroupNorm
trainable(gn::GroupNorm) = hasaffine(gn) ? (β = gn.β, γ = gn.γ) : (;)
@layer GroupNorm trainable=(β,γ)
# trainable(gn::GroupNorm) = hasaffine(gn) ? (β = gn.β, γ = gn.γ) : (;)

function GroupNorm(chs::Int, G::Int, λ=identity;
initβ=zeros32, initγ=ones32,
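With the hand-written methods replaced by the `trainable=(β,γ)` keyword, the optimised subset of the norm layers now comes from the macro. A quick check, assuming the default `affine=true` constructor (a sketch, not from the diff):

```julia
using Flux

bn = BatchNorm(3)
keys(Flux.trainable(bn))             # (:β, :γ); the running statistics μ and σ² are excluded
length.(values(Flux.trainable(bn)))  # (3, 3)
```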