-
-
Notifications
You must be signed in to change notification settings - Fork 611
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
upgrade to at-layer macro, replaces at-functor
- Loading branch information
Showing
8 changed files
with
259 additions
and
88 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,177 @@ | ||
|
||
""" | ||
@layer Dense | ||
@layer :expand Chain | ||
@layer BatchNorm trainable=(β,γ) | ||
@layer Struct functor=(α,β) trainable=(β,) | ||
This macro replaces most uses of `@functor` in Flux 0.14. Its basic purpose is the same: | ||
When you define a new layer, this tells Flux to explore inside it | ||
to see the parameters it trains, and also to move them to the GPU, change precision, etc. | ||
Some "keywords" allow control of the recursion: | ||
* If some fields look like parameters but should not be trained, | ||
then `Optimisers.trainable` lets you specify fields to include, and ignore the rest. | ||
* We can likewise add restructions to `Functors.functor`, but not yet written. | ||
* In fact you can provide an arbitrary keyword with this syntax, and it will | ||
overload this function alla `trainable`... that might be a terrible idea. | ||
It also handles overloads of `show` for pretty printing. | ||
* By default, it adds methods to 3-arg `Base.show` to treat your layer much like `Dense` or `Conv`. | ||
* If your layer is a container, more like `Chain` or `Parallel`, then `:expand` makes `show` unfold its contents. | ||
* To disable all `show` overloads, maybe we want a `:ignore` option too. | ||
(You probably still want to define 2-arg `show(io::IO, x::Layer)`, the macro does not touch this.) | ||
""" | ||
macro layer(exs...) | ||
out = quote end | ||
|
||
# These functions are defined in show.jl, and each return an expression overloading Base.show | ||
type, rest... = if exs[1] == QuoteNode(:expand) | ||
push!(out.args, _macro_big_show(esc(exs[2]))) | ||
exs[2:end] | ||
elseif exs[1] == QuoteNode(:ignore) | ||
exs[2:end] | ||
elseif exs[1] isa QuoteNode | ||
error("before the type, only accepted options are `:expand` and `:ignore`") | ||
else | ||
push!(out.args, _macro_layer_show(esc(exs[1]))) | ||
exs | ||
end | ||
|
||
# This function exists only for depwarns when you use @functor directly | ||
push!(out.args, :(Flux._check_new_macro(::$(esc(type))) = nothing)) # scope is weird ?? can't use $ on func name? | ||
|
||
i = findfirst(ex -> Meta.isexpr(ex, :(=)) && ex.args[1] == :functor, rest) | ||
if isnothing(i) | ||
push!(out.args, _macro_functor(esc(type))) | ||
else | ||
push!(out.args, _macro_functor(esc(type), rest[i].args[2])) | ||
end | ||
for j in 1:length(rest) | ||
j == i && continue | ||
ex = rest[j] | ||
Meta.isexpr(ex, :(=)) || error("expected keyword = fields") | ||
if ex.args[1] == :trainable | ||
push!(out.args, _macro_trainable(type, trainable, ex.args[2])) # pass the function "trainable" not the symbol | ||
else | ||
error() | ||
# @warn "defining a method for $(ex.args[1]) in your scope" # ?? | ||
# push!(out.args, _macro_trainable(type, esc(ex.args[1]), ex.args[2])) | ||
end | ||
end | ||
|
||
out | ||
end | ||
|
||
# Temporary depwarn function: | ||
|
||
"""
    _check_new_macro(x)

Emit a one-time deprecation warning when a non-leaf type was set up with the
old `@functor` macro rather than `@layer`. (`@layer` overrides this with a
silent method for each type it processes.)
"""
function _check_new_macro(x::T) where T
  if !Functors.isleaf(x)
    @warn "you used @functor for this type, but should now use @layer" T maxlog=1 _id=hash(T)
  end
  return nothing
end

# These wrappers get their functor methods from Functors.jl itself, not from
# user code, so they must never trigger the deprecation warning:
_check_new_macro(::Tuple) = nothing
_check_new_macro(::NamedTuple) = nothing
_check_new_macro(::Transpose) = nothing
_check_new_macro(::Adjoint) = nothing
_check_new_macro(::Ref) = nothing
|
||
# @layer's code for Functors & Adapt | ||
# Unlike @functor, _default_functor doesn't need to eval anything | ||
|
||
# Build the expression that overloads `Functors.functor` and `Adapt.adapt_structure`
# for `type`. Unlike the old `@functor`, nothing here needs to be `eval`ed.
_macro_functor(type) = quote
  Functors.functor(::Type{T}, x) where {T<:$type} = _default_functor(T, x)
  Adapt.adapt_structure(to, layer::$type) = fmap(adapt(to), layer)
end
|
||
# Restricting which fields Functors recurses into (`functor=(...)`) is not implemented.
_macro_functor(type, fields) = error("the equivalent of @functor Layer (:x,) isn't written yet, sorry")
|
||
"""
    _default_functor(T, x) -> (children, reconstruct)

Fallback used by the `Functors.functor` method that `@layer` generates:
returns a `NamedTuple` of all fields of `x`, together with a function that
rebuilds a `T` from such a tuple of values.
"""
function _default_functor(::Type{T}, x) where {T}
  if @generated
    F = fieldnames(T)
    args = map(sy -> :(getfield(x, $(QuoteNode(sy)))), F)
    # Interpolate the parameterless constructor *object*, not its name: a bare
    # Symbol would be resolved in this module (wrong for user types), and the
    # unqualified name `Splat` is not in scope — it lives in `Base`.
    C = Base.typename(T).wrapper  # constructor
    recon = VERSION > v"1.9-" ? :(Base.Splat($C)) : :(Base.splat($C))
    :((NamedTuple{$F}(($(args...),)), $recon))
  else
    # Getting this parameterless type takes about 2μs, every time:
    namedtuple(x), Base.splat(Base.typename(T).wrapper)
  end
end
|
||
"""
    namedtuple(x)

Return a `NamedTuple` with the same field names and field values as the struct `x`.
"""
function namedtuple(x::T) where T
  ks = fieldnames(T)
  vs = map(k -> getfield(x, k), ks)
  return NamedTuple{ks}(vs)
end
|
||
# @layer's code for Optimisers.trainable, and perhaps anything else, | ||
# with the pattern that keywords mean function names & what fields they pick. | ||
|
||
"""
    _macro_trainable(type, fun, fields)

Build the expression overloading `fun` (e.g. `Optimisers.trainable`) for `type`,
so that it returns only the named `fields` as a `NamedTuple`.
"""
function _macro_trainable(type, fun, fields)
  Meta.isexpr(fields, :tuple) || error("expected a tuple of field names")
  symbols = Tuple(map(_noquotenode, fields.args))
  quoted = map(QuoteNode, symbols)
  gets = [:(getfield(x, $f)) for f in quoted]
  # Interpolate the function object `fun` itself (the caller passes `trainable`,
  # not a symbol) rather than hard-coding `Flux.trainable`: this honours the
  # `fun` argument and sidesteps any question of resolving the name `Flux`
  # in the expansion scope.
  quote
    $fun(x::$type) = NamedTuple{$symbols}(($(gets...),))
  end
end
_macro_trainable(type, fun, field::Union{Symbol,QuoteNode}) = _macro_trainable(type, fun, :(($field,)))  # lets you forget a comma
|
||
# Normalise one entry of a field-name tuple to a plain `Symbol`.
function _noquotenode(s::Symbol)
  return s
end
# Lets you write trainable=(:x,:y) instead of (x,y).
function _noquotenode(q::QuoteNode)
  return q.value
end
# Anything else (literals, arbitrary expressions) is a user error.
function _noquotenode(ex)
  error("expected a symbol, got $ex")
end
|
||
|
||
|
||
|
||
|
||
|
||
# @big_show Chain | ||
# @big_show Parallel | ||
# @big_show SkipConnection | ||
# @big_show Recur | ||
# @big_show Maxout | ||
|
||
|
||
|
||
|
||
""" | ||
@big_show MyContainer | ||
This macro lets you opt in to Flux's fancy printing.
When `model::MyContainer` is returned at the REPL it will be treated like `Chain`, | ||
and the printing routine will recursively unfold its children. | ||
This is triggered by adding a method to 3-arg `Base.show(io::IO, ::MIME"text/plain", l::MyContainer)`. | ||
Custom layers which do not contain other layers (more like `Dense` than like `Chain`) | ||
need not call this, and should simply define 2-arg `Base.show(io::IO, l::MyLayer)`. | ||
# Example | ||
```jldoctest | ||
julia> struct Trio{A,B,C}; a::A; b::B; c::C end | ||
julia> Flux.@functor Trio | ||
julia> Flux.@big_show Trio | ||
julia> tri = Trio(Dense(10=>5,tanh), Dense(5=>2), softmax) | ||
Trio( | ||
Dense(10 => 5, tanh), # 55 parameters | ||
Dense(5 => 2), # 12 parameters | ||
NNlib.softmax, | ||
) # Total: 4 arrays, 67 parameters, 492 bytes. | ||
``` | ||
Note that there is no automatic method for 2-arg `show`, and thus | ||
something like `(tri, tri)` will print all the type parameters. | ||
However, `Chain(tri, tri)` will always use Flux's recursive printing, | ||
even without using this macro: `Chain` is the entry point. | ||
""" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.