tidy up, add NEWS
mcabbott committed Aug 25, 2022
1 parent 8981283 commit 2827cc1
Showing 5 changed files with 25 additions and 23 deletions.
5 changes: 5 additions & 0 deletions NEWS.md
@@ -1,5 +1,10 @@
# Flux Release Notes

+ ## v0.13.5
+
+ * New macro `Flux.@layer` which should be used in place of `@functor`.
+   This also adds `show` methods for pretty printing.
+
## v0.13.4
* Added [`PairwiseFusion` layer](https://github.com/FluxML/Flux.jl/pull/1983)

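To make the NEWS entry concrete: in the common case the new macro is a drop-in replacement for `@functor`. A minimal sketch, using a made-up `Poly` layer (not part of Flux or this commit):

```julia
using Flux

struct Poly                 # hypothetical two-field layer, for illustration
  weight
  bias
end

Poly(n::Int) = Poly(randn(Float32, n), zeros(Float32, n))

(p::Poly)(x) = p.weight .* x .+ p.bias

Flux.@layer Poly            # before v0.13.5 this was: Flux.@functor Poly

p = Poly(3)                 # now prints via the generated `show` methods,
Flux.params(p)              # and its arrays are visible for training / gpu movement
```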
32 changes: 17 additions & 15 deletions src/layers/macro.jl
@@ -4,42 +4,43 @@
@layer :expand Chain
@layer BatchNorm trainable=(β,γ)
@layer Struct children=(α,β) trainable=(β,)
- This macro replaces most uses of `@functor` in Flux 0.14. Its basic purpose is the same:
+ This macro replaces most uses of `@functor`. Its basic purpose is the same:
When you define a new layer, this tells Flux to explore inside it
to see the parameters it trains, and also to move them to the GPU, change precision, etc.
Like `@functor`, this assumes your struct has the default constructor, to enable re-building.
Some "keywords" allow control of the recursion:
Some "keywords" allow control of the recursion.
* If some fields look like parameters but should not be trained,
then `trainable` lets you specify fields to include, and ignore the rest.
* You can likewise add restrictions to Functors.jl's `children` (although this is seldom a good idea).
The defaults are `fieldnames(T)` for both. Any fields you specify must be a subset of this, and `trainable` must be a subset of `children`.
It also handles overloads of `show` for pretty printing.
* By default, it adds methods to 3-arg `Base.show` to treat your layer much like `Dense` or `Conv`.
* If your layer is a container, more like `Chain` or `Parallel`, then `:expand` makes `show` unfold its contents.
* To disable all `show` overloads, there is an `:ignore` option too.
- Note that re-running the macro with different options does not overwrite all methods; you will need to restart.
+ (You probably still want to define 2-arg `show(io::IO, x::Layer)`; the macro does not touch this.)
+ Note that re-running the macro with different options does not overwrite all methods; you will need to restart.
# Example
```jldoctest
julia> struct Trio; a; b; c end
- julia> tri = Trio(Dense([1.1 2.2],), Dense([3.3;;], false), Dropout(0.4))
- Trio(Dense(1 => 1, tanh), Dense(1 => 1; bias=false), Dropout(0.4))
+ julia> tri = Trio(Dense([1.1 2.2], [0.0], tanh), Dense([3.3;;], false), Dropout(0.4))
+ Trio(Dense(2 => 1, tanh), Dense(1 => 1; bias=false), Dropout(0.4))
- julia> Flux.destructure(tri) # parameters not visible to Flux
+ julia> Flux.destructure(tri) # parameters are not yet visible to Flux
(Bool[], Restructure(Trio, ..., 0))
julia> Flux.@layer :expand Trio
julia> Flux.destructure(tri) # now gpu, train!, etc will see inside too
([1.1, 2.2, 0.0, 3.3], Restructure(Trio, ..., 4))
- julia> tri
+ julia> tri # and layer is printed like Chain
Trio(
Dense(2 => 1, tanh), # 3 parameters
Dense(1 => 1; bias=false), # 1 parameters
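To make the `trainable` keyword above concrete: fields left off the list stay visible to `gpu`, `destructure` etc., but optimisers ignore them. A sketch with a hypothetical `MyNorm` carrying running statistics (not Flux code):

```julia
using Flux

struct MyNorm
  β; γ        # learnable shift & scale
  μ; σ²       # running statistics, updated outside the optimiser
end

Flux.@layer MyNorm trainable=(β,γ)

m = MyNorm(zeros32(2), ones32(2), zeros32(2), ones32(2))
Flux.trainable(m)   # (β = Float32[0.0, 0.0], γ = Float32[1.0, 1.0]) — μ, σ² excluded
```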
@@ -58,7 +59,7 @@ macro layer(exs...)
elseif exs[1] == QuoteNode(:ignore)
exs[2:end]
elseif exs[1] isa QuoteNode
error("before the type, only accepted options are `:expand` and `:ignore`")
error("`@layer` accepts only two options before the layer type, `:expand` and `:ignore` (to control `show`)")
else
push!(out.args, _macro_layer_show(esc(exs[1])))
exs
@@ -76,12 +77,14 @@ macro layer(exs...)
for j in 1:length(rest)
j == i && continue
ex = rest[j]
- Meta.isexpr(ex, :(=)) || error("expected keyword = fields")
+ Meta.isexpr(ex, :(=)) || error("The macro `@layer` expects `keyword = (fields...,)` here, got $ex")

name = if ex.args[1] == :trainable
:(Optimisers.trainable)
elseif ex.args[1] == :functor
error("Can't use `functor=(...)` as a keyword to `@layer`. Use `childen=(...)` to define a method for `functor`.")
else
@warn "trying to define a method for `$(ex.args[1])` in your scope... this is experimental" maxlog=1
@warn "Trying to define a method for `$(ex.args[1])` in your scope... this is experimental" maxlog=1
esc(ex.args[1])
end
push!(out.args, _macro_trainable(esc(type), name, ex.args[2]))
@@ -94,7 +97,7 @@ end

function _check_new_macro(x::T) where T
Functors.isleaf(x) && return
@warn "This type should now use Flux.@layer instead of @functor" T maxlog=1 _id=hash(T)
Base.depwarn("This type should probably now use `Flux.@layer` instead of `@functor`: $T", Symbol("@functor"))
end
_check_new_macro(::Tuple) = nothing # defined by Functors.jl, not by users
_check_new_macro(::NamedTuple) = nothing
@@ -159,11 +162,10 @@ function _macro_trainable(type, fun, fields)
gets = [:(getfield(x, $f)) for f in quoted]
quote
$fun(x::$type) = NamedTuple{$symbols}(($(gets...),))
- # Flux.trainable(x::$type) = NamedTuple{$symbols}(($(gets...),)) # ?? scope is weird
end
end
_macro_trainable(type, fun, field::Union{Symbol,QuoteNode}) = _macro_trainable(type, fun, :(($field,))) # lets you forget a comma

_noquotenode(s::Symbol) = s
_noquotenode(q::QuoteNode) = q.value # lets you write trainable=(:x,:y) instead of (x,y)
- _noquotenode(ex) = error("expected a symbol, got $ex")
+ _noquotenode(ex) = error("expected a symbol here, as a field name, but got $ex")
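For reference, the method that `_macro_trainable` generates is plain Julia; for the `BatchNorm` example above it is equivalent to writing by hand:

```julia
using Flux, Optimisers

# Hand-written equivalent of the method generated by `@layer BatchNorm trainable=(β,γ)`:
Optimisers.trainable(x::BatchNorm) =
    NamedTuple{(:β, :γ)}((getfield(x, :β), getfield(x, :γ)))  # i.e. (β = x.β, γ = x.γ)
```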
5 changes: 0 additions & 5 deletions src/layers/normalise.jl
@@ -104,7 +104,6 @@ function Dropout(p; dims=:, rng = default_rng_value())
end

@layer Dropout trainable=()
- # trainable(a::Dropout) = (;)

function (a::Dropout)(x)
_isactive(a) || return x
@@ -159,7 +158,6 @@ AlphaDropout(p, active) = AlphaDropout(p, active, default_rng_value())
AlphaDropout(p; rng = default_rng_value()) = AlphaDropout(p, nothing, rng)

@layer AlphaDropout trainable=()
- # trainable(a::AlphaDropout) = (;)

function (a::AlphaDropout)(x::AbstractArray{T}) where T
_isactive(a) || return x
@@ -355,7 +353,6 @@ function BatchNorm(chs::Int, λ=identity;
end

@layer BatchNorm trainable=(β,γ)
- # trainable(bn::BatchNorm) = hasaffine(bn) ? (β = bn.β, γ = bn.γ) : (;)

function (BN::BatchNorm)(x)
@assert size(x, ndims(x)-1) == BN.chs
@@ -445,7 +442,6 @@ function InstanceNorm(chs::Int, λ=identity;
end

@layer InstanceNorm trainable=(β,γ)
- # trainable(in::InstanceNorm) = hasaffine(in) ? (β = in.β, γ = in.γ) : (;)

function (l::InstanceNorm)(x)
@assert ndims(x) > 2
@@ -524,7 +520,6 @@ mutable struct GroupNorm{F,V,N,W}
end

@layer GroupNorm trainable=(β,γ)
- # trainable(gn::GroupNorm) = hasaffine(gn) ? (β = gn.β, γ = gn.γ) : (;)

function GroupNorm(chs::Int, G::Int, λ=identity;
initβ=zeros32, initγ=ones32,
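Each deleted comment above was a hand-written `trainable` overload, now replaced by the macro keyword on the line before it. In particular `trainable=()` declares that a layer has nothing to optimise; a quick check of the `Dropout` case, assuming this version of Flux:

```julia
using Flux

d = Dropout(0.4)
Flux.trainable(d)   # NamedTuple() — nothing to train, though fields like `p` and `rng` remain children
```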
1 change: 0 additions & 1 deletion src/layers/show.jl
@@ -55,7 +55,6 @@ end
_show_leaflike(x) = isleaf(x) # mostly follow Functors, except for:
_show_leaflike(::Tuple{Vararg{<:Number}}) = true # e.g. stride of Conv
_show_leaflike(::Tuple{Vararg{<:AbstractArray}}) = true # e.g. parameters of LSTMcell
- # _show_leaflike(::Scale) = true # appears inside LayerNorm
_show_leaflike(::AbstractArray{<:Number}) = true # e.g. transposed arrays

_show_children(x) = trainable(x)
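The `_show_leaflike` heuristics above control how deeply the macro-generated `show` unfolds a layer: leaf-like values print inline. A sketch of the intent, using this internal helper (subject to change):

```julia
using Flux

Flux._show_leaflike((2, 2))         # true — e.g. the stride of a Conv, printed inline
Flux._show_leaflike([1.0, 2.0])     # true — numeric arrays are leaves
Flux._show_leaflike(Dense(2 => 1))  # false — layers are unfolded by the `show` methods
```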
5 changes: 3 additions & 2 deletions test/layers/macro.jl
@@ -10,8 +10,6 @@ module MacroTest
@layer Trio trainable=(a,b) test=(c) # should be (c,) but it lets you forget

struct TwoThirds; a; b; c; end
- @layer :expand TwoThirds children=(a,c) trainable=(a)
-
end

@testset "@layer macro" begin
@@ -33,6 +31,9 @@ end
@test MacroTest.test(m3) == (c = [3.0],)

m23 = MacroTest.TwoThirds([1 2], [3 4], [5 6])
+ # Check that we can use the macro with a qualified type name, outside the defining module:
+ Flux.@layer :expand MacroTest.TwoThirds children=(:a,:c) trainable=(:a)  # documented as (a,c) but allow quotes
+
@test Functors.children(m23) == (a = [1 2], c = [5 6])
m23re = Functors.functor(m23)[2]((a = [10 20], c = [50 60]))
@test m23re isa MacroTest.TwoThirds
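The updated test covers two behaviours at once: the macro now works with a module-qualified type name from outside the defining module, and the `children=` keyword hides a field from Functors entirely, so `gpu`, `destructure` and friends never touch it. In sketch form, what the assertions check:

```julia
using Flux, Functors

struct TwoThirds; a; b; c; end
Flux.@layer :expand TwoThirds children=(a,c) trainable=(a)

m = TwoThirds([1 2], [3 4], [5 6])
Functors.children(m)  # (a = [1 2], c = [5 6]) — `b` is hidden from Functors
Flux.trainable(m)     # (a = [1 2],) — only `a` is optimised
```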
