review suggestions
mcabbott committed Nov 20, 2022
1 parent bd91d81 commit 385adc3
Showing 4 changed files with 25 additions and 16 deletions.
29 changes: 18 additions & 11 deletions docs/src/models/advanced.md
@@ -18,8 +18,8 @@ function (m::CustomModel)(x)
return m.chain(x) + x
end

-# Call @functor to allow for training. Described below in more detail.
-Flux.@functor CustomModel
+# Call @layer to allow for training. Described below in more detail.
+Flux.@layer CustomModel
```

You can then use the model like:
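The usage example itself is collapsed in this diff; a minimal sketch (the layer sizes here are hypothetical, not taken from the original file) might look like:

```julia
chain = Chain(Dense(10 => 10, relu), Dense(10 => 10))
model = CustomModel(chain)

model(rand(Float32, 10))  # applies the chain, then adds the skip connection
```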
@@ -41,29 +41,36 @@ By default all the fields in the `Affine` type are collected as its parameters,
The first way of achieving this is through overloading the `trainable` function.

```julia-repl
-julia> @functor Affine
+julia> @layer Affine
julia> a = Affine(rand(3,3), rand(3))
Affine{Array{Float64,2},Array{Float64,1}}([0.66722 0.774872 0.249809; 0.843321 0.403843 0.429232; 0.683525 0.662455 0.065297], [0.42394, 0.0170927, 0.544955])
julia> Flux.params(a) # default behavior
Params([[0.66722 0.774872 0.249809; 0.843321 0.403843 0.429232; 0.683525 0.662455 0.065297], [0.42394, 0.0170927, 0.544955]])
-julia> Flux.trainable(a::Affine) = (a.W,)
+julia> Flux.trainable(a::Affine) = (W = a.W,) # must return a NamedTuple
julia> Flux.params(a)
Params([[0.66722 0.774872 0.249809; 0.843321 0.403843 0.429232; 0.683525 0.662455 0.065297]])
```

Only the fields returned by `trainable` will be collected as trainable parameters of the layer when calling `Flux.params`.
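The same restriction carries over to training: optimiser state is only set up for the trainable fields. A minimal sketch, assuming the `Affine` struct and the `trainable` method defined above:

```julia
using Flux, Optimisers

a = Affine(rand(3, 3), rand(3))
state = Optimisers.setup(Optimisers.Descent(0.1), a)
# `state.W` is an `Optimisers.Leaf` holding update state for `W`;
# `state.b` is an empty placeholder, so `b` is never updated.
```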

-Another way of achieving this is through the `@functor` macro directly. Here, we can mark the fields we are interested in by grouping them in the second argument:
+The exact same method of `trainable` can also be defined using the macro, for convenience:

```julia
-Flux.@functor Affine (W,)
+Flux.@layer Affine trainable=(W,)
```

-However, doing this requires the `struct` to have a corresponding constructor that accepts those parameters.
+There is a second, more severe, kind of restriction possible:

+```julia
+Flux.@layer Affine children=(W,)
+```

+This is equivalent to `Functors.@functor Affine (W,)`. It means that no exploration of the model will ever visit the other fields: they will not be moved to the GPU by [`gpu`](@ref), and their precision will not be changed by `f32`. This is not usually recommended.
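As a rough sketch of the consequence (hypothetical, reusing the `Affine` example above): with only `W` listed as a child, `f32` converts `W` but leaves `b` untouched.

```julia
a = Affine(rand(3, 3), rand(3))  # both fields are Float64
a32 = Flux.f32(a)

eltype(a32.W)  # Float32 -- `W` is still visited
eltype(a32.b)  # Float64 -- `b` is no longer part of the traversed tree
```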


## Freezing Layer Parameters

@@ -127,9 +134,9 @@ Join(combine, paths...) = Join(combine, paths)
```
Notice that we parameterized the type of the `paths` field. This is necessary for fast Julia code; in general, `T` might be a `Tuple` or `Vector`, but we don't need to pay attention to what it specifically is. The same goes for the `combine` field.
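For reference, the struct this paragraph refers to is collapsed above in the diff; it looks roughly like the following (a reconstruction based on the constructor shown in the hunk header, not part of the change):

```julia
struct Join{T, F}
  combine::F
  paths::T
end

# Allow Join(+, m1, m2, ...) as well as Join(+, (m1, m2, ...))
Join(combine, paths...) = Join(combine, paths)
```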

-The next step is to use [`Functors.@functor`](@ref) to make our struct behave like a Flux layer. This is important so that calling `params` on a `Join` returns the underlying weight arrays on each path.
+The next step is to use [`Flux.@layer`](@ref) to make our struct behave like a Flux layer. This is important so that calling `params` on a `Join` returns the underlying weight arrays on each path.
```julia
-Flux.@functor Join
+Flux.@layer Join
```
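A quick hypothetical check of that claim (the layer sizes are invented for illustration):

```julia
j = Join(vcat, Dense(3 => 2), Dense(3 => 4))
length(Flux.params(j))  # 4 arrays: one weight matrix and one bias per path
```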

Finally, we define the forward pass. For `Join`, this means applying each `path` in `paths` to each input array, then using `combine` to merge the results.
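That definition is collapsed in this diff; a sketch of what it plausibly looks like, based on the surrounding description (treat the exact signatures as an assumption):

```julia
# Apply each path to the matching input, then merge the results with `combine`:
(m::Join)(xs::Tuple) = m.combine(map((f, x) -> f(x), m.paths, xs)...)
(m::Join)(xs...) = m(xs)
```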
@@ -182,7 +189,7 @@ model(xs)

Our custom `Split` layer will accept a single input, then pass the input through a separate path to produce multiple outputs.

-We start by following the same steps as the `Join` layer: define a struct, use [`Functors.@functor`](@ref), and define the forward pass.
+We start by following the same steps as the `Join` layer: define a struct, use [`@layer`](@ref), and define the forward pass.
```julia
using Flux
using CUDA
@@ -194,7 +201,7 @@ end

Split(paths...) = Split(paths)

-Flux.@functor Split
+Flux.@layer Split

(m::Split)(x::AbstractArray) = map(f -> f(x), m.paths)
```
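A hypothetical usage sketch of `Split` (the sizes are invented for illustration):

```julia
model = Chain(
  Dense(10 => 5),
  Split(Dense(5 => 1, tanh), Dense(5 => 3, tanh)),
)

model(rand(Float32, 10))  # returns a Tuple of two outputs, sizes 1 and 3
```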
5 changes: 3 additions & 2 deletions src/layers/macro.jl
@@ -84,8 +84,9 @@ macro layer(exs...)
elseif ex.args[1] == :functor
error("Can't use `functor=(...)` as a keyword to `@layer`. Use `childen=(...)` to define a method for `functor`.")
else
@warn "Trying to define a method for `$(ex.args[1])` in your scope... this is experimental" maxlog=1
esc(ex.args[1])
error("`@layer` cannot define a method for `$(ex.args[1])` at the moment, sorry.")
# @warn "Trying to define a method for `$(ex.args[1])` in your scope... this is experimental" maxlog=1
# esc(ex.args[1])
end
push!(out.args, _macro_trainable(esc(type), name, ex.args[2]))
end
2 changes: 1 addition & 1 deletion src/layers/recurrent.jl
@@ -206,7 +206,7 @@ function (m::RNNCell{F,I,H,V,<:AbstractMatrix{T}})(h, x::Union{AbstractVecOrMat{
return h, reshape_cell_output(h, x)
end

-@layer RNNCell # trainable=(Wi, Wh, b)
+@layer RNNCell # state0 is trainable, see issue 807 about this.

function Base.show(io::IO, l::RNNCell)
print(io, "RNNCell(", size(l.Wi, 2), " => ", size(l.Wi, 1))
5 changes: 3 additions & 2 deletions test/layers/macro.jl
@@ -7,7 +7,8 @@ module MacroTest
@layer :expand Duo

struct Trio; a; b; c end
-@layer Trio trainable=(a,b) test=(c) # should be (c,) but it lets you forget
+# @layer Trio trainable=(a,b) test=(c) # should be (c,) but it lets you forget
+@layer Trio trainable=(a,b) # defining a method for test is made an error, for now

struct TwoThirds; a; b; c; end
end
@@ -28,7 +29,7 @@ end
@test Optimisers.trainable(m3) isa NamedTuple{(:a, :b)}
@test Optimisers.destructure(m3)[1] == [1, 2]

-@test MacroTest.test(m3) == (c = [3.0],)
+# @test MacroTest.test(m3) == (c = [3.0],) # removed, for now

m23 = MacroTest.TwoThirds([1 2], [3 4], [5 6])
# Check that we can use the macro with a qualified type name, outside the defining module:
