From 73dae527940ac71c405383f6eb7171566c30f761 Mon Sep 17 00:00:00 2001 From: CarloLucibello Date: Mon, 14 Oct 2024 19:25:46 +0200 Subject: [PATCH] add tests finish RNNCell RNN rework LSTMCell LSTM more work gru extended testing runtests add tests finish RNNCell RNN rework LSTMCell LSTM more work gru extended testing reset! deprecation fix test unbreak l2 test fix tests fixes --- perf/recurrent.jl | 2 - src/Flux.jl | 1 + src/deprecations.jl | 5 + src/layers/recurrent.jl | 837 ++++++++++++++++------------ test/ext_amdgpu/runtests.jl | 5 + test/ext_common/recurrent_gpu_ad.jl | 163 ++++++ test/ext_cuda/curnn.jl | 67 --- test/ext_cuda/layers.jl | 6 +- test/ext_cuda/runtests.jl | 5 +- test/ext_enzyme/enzyme.jl | 2 - test/ext_metal/runtests.jl | 5 + test/layers/attention.jl | 6 +- test/layers/recurrent.jl | 472 ++++++++-------- test/test_utils.jl | 39 +- test/utils.jl | 18 +- 15 files changed, 957 insertions(+), 676 deletions(-) create mode 100644 test/ext_common/recurrent_gpu_ad.jl delete mode 100644 test/ext_cuda/curnn.jl diff --git a/perf/recurrent.jl b/perf/recurrent.jl index ef00a8d9a5..1550009bd3 100644 --- a/perf/recurrent.jl +++ b/perf/recurrent.jl @@ -7,12 +7,10 @@ Flux.@functor RNNWrapper # Need to specialize for RNNWrapper. fw(r::RNNWrapper, X::Vector{<:AbstractArray}) = begin - Flux.reset!(r.rnn) [r.rnn(x) for x in X] end fw(r::RNNWrapper, X) = begin - Flux.reset!(r.rnn) r.rnn(X) end diff --git a/src/Flux.jl b/src/Flux.jl index 47142dbdda..e0d01639ca 100644 --- a/src/Flux.jl +++ b/src/Flux.jl @@ -35,6 +35,7 @@ Optimisers.base(dx::Zygote.Grads) = error("Optimisers.jl cannot be used with Zyg export Chain, Dense, Embedding, EmbeddingBag, Maxout, SkipConnection, Parallel, PairwiseFusion, + RNNCell, LSTMCell, GRUCell, GRUv3Cell, RNN, LSTM, GRU, GRUv3, SamePad, Conv, CrossCor, ConvTranspose, DepthwiseConv, AdaptiveMaxPool, AdaptiveMeanPool, GlobalMaxPool, GlobalMeanPool, MaxPool, MeanPool, diff --git a/src/deprecations.jl b/src/deprecations.jl index 6148894dbe..8a9b67501d 100644 --- a/src/deprecations.jl +++ b/src/deprecations.jl @@ -167,3 +167,8 @@ end # where `loss_mxy` accepts the model as its first argument. # """ # )) + +function reset!(x) + Base.depwarn("reset!(m) is deprecated. You can remove this call as it is no more needed.", :reset!) + return x +end diff --git a/src/layers/recurrent.jl b/src/layers/recurrent.jl index 931eed65ca..b93e3a80c3 100644 --- a/src/layers/recurrent.jl +++ b/src/layers/recurrent.jl @@ -1,506 +1,647 @@ -gate(h, n) = (1:h) .+ h*(n-1) -gate(x::AbstractVector, h, n) = @view x[gate(h,n)] -gate(x::AbstractMatrix, h, n) = view(x, gate(h,n), :) - -# AD-friendly helper for dividing monolithic RNN params into equally sized gates -multigate(x::AbstractArray, h, ::Val{N}) where N = ntuple(n -> gate(x,h,n), N) - -function ChainRulesCore.rrule(::typeof(multigate), x::AbstractArray, h, c) - function multigate_pullback(dy) - dx = map!(zero, similar(x, float(eltype(x)), axes(x)), x) - foreach(multigate(dx, h, c), unthunk(dy)) do dxᵢ, dyᵢ - dyᵢ isa AbstractZero && return - @. dxᵢ += dyᵢ - end - return (NoTangent(), dx, NoTangent(), NoTangent()) - end - return multigate(x, h, c), multigate_pullback -end +# Vanilla RNN -# Type stable and AD-friendly helper for iterating over the last dimension of an array -function eachlastdim(A::AbstractArray{T,N}) where {T,N} - inds_before = ntuple(_ -> :, N-1) - return (view(A, inds_before..., i) for i in axes(A, N)) -end +@doc raw""" + RNNCell(in => out, σ = tanh; init = glorot_uniform, bias = true) -# adapted from https://github.com/JuliaDiff/ChainRules.jl/blob/f13e0a45d10bb13f48d6208e9c9d5b4a52b96732/src/rulesets/Base/indexing.jl#L77 -function ∇eachlastdim(dys_raw, x::AbstractArray{T, N}) where {T, N} - dys = unthunk(dys_raw) - i1 = findfirst(dy -> dy isa AbstractArray, dys) - if isnothing(i1) # all slices are Zero! - return fill!(similar(x, T, axes(x)), zero(T)) - end - # The whole point of this gradient is that we can allocate one `dx` array: - dx = similar(x, T, axes(x))::AbstractArray - for i in axes(x, N) - slice = selectdim(dx, N, i) - if dys[i] isa AbstractZero - fill!(slice, zero(eltype(slice))) - else - copyto!(slice, dys[i]) - end - end - return ProjectTo(x)(dx) -end +The most basic recurrent layer. Essentially acts as a `Dense` layer, but with the +output fed back into the input each time step. -function ChainRulesCore.rrule(::typeof(eachlastdim), x::AbstractArray{T,N}) where {T,N} - lastdims(dy) = (NoTangent(), ∇eachlastdim(unthunk(dy), x)) - collect(eachlastdim(x)), lastdims -end +In the forward pass, implements the function +```math +h^\prime = \sigma(W_i x + W_h h + b) +``` +and returns `h'`. -reshape_cell_output(h, x) = reshape(h, :, size(x)[2:end]...) +See [`RNN`](@ref) for a layer that processes entire sequences. -# Stateful recurrence +# Arguments -""" - Recur(cell) +- `in => out`: The input and output dimensions of the layer. +- `σ`: The non-linearity to apply to the output. Default is `tanh`. +- `init`: The initialization function to use for the weights. Default is `glorot_uniform`. +- `bias`: Whether to include a bias term initialized to zero. Default is `true`. -`Recur` takes a recurrent cell and makes it stateful, managing the hidden state -in the background. `cell` should be a model of the form: +# Forward - h, y = cell(h, x...) + rnncell(x, [h]) -For example, here's a recurrent network that keeps a running total of its inputs: +The arguments of the forward pass are: + +- `x`: The input to the RNN. It should be a vector of size `in` or a matrix of size `in x batch_size`. +- `h`: The hidden state of the RNN. It should be a vector of size `out` or a matrix of size `out x batch_size`. + If not provided, it is assumed to be a vector of zeros. # Examples -```jldoctest -julia> accum(h, x) = (h + x, x) -accum (generic function with 1 method) -julia> rnn = Flux.Recur(accum, 0) -Recur(accum) +```jldoctest +r = RNNCell(3 => 5) -julia> rnn(2) -2 +# A sequence of length 10 and batch size 4 +x = [rand(Float32, 3, 4) for _ in 1:10] -julia> rnn(3) -3 +# Initialize the hidden state +h = zeros(Float32, 5) -julia> rnn.state -5 -``` +# We collect the hidden states in an array `history` +# in case the loss depends on the entire sequence. +ŷ = [] -Folding over a 3d Array of dimensions `(features, batch, time)` is also supported: +for x_t in x + h = r(x_t, h) + ŷ = [ŷ..., h] # Cannot use `push!(ŷ, h)` here since mutation + # is not automatic differentiation friendly yet. + # Can use `y = vcat(y, [h])` as an alternative. +end -```jldoctest -julia> accum(h, x) = (h .+ x, x) -accum (generic function with 1 method) - -julia> rnn = Flux.Recur(accum, zeros(Int, 1, 1)) -Recur(accum) - -julia> rnn([2]) -1-element Vector{Int64}: - 2 - -julia> rnn([3]) -1-element Vector{Int64}: - 3 - -julia> rnn.state -1×1 Matrix{Int64}: - 5 - -julia> out = rnn(reshape(1:10, 1, 1, :)); # apply to a sequence of (features, batch, time) - -julia> out |> size -(1, 1, 10) - -julia> vec(out) -10-element Vector{Int64}: - 1 - 2 - 3 - 4 - 5 - 6 - 7 - 8 - 9 - 10 - -julia> rnn.state -1×1 Matrix{Int64}: - 60 +h # The final hidden state +ŷ # The hidden states at each time step ``` """ -mutable struct Recur{T,S} - cell::T - state::S +struct RNNCell{F,I,H,V} + σ::F + Wi::I + Wh::H + bias::V end -function (m::Recur)(x) - m.state, y = m.cell(m.state, x) - return y +@layer RNNCell + +function RNNCell((in, out)::Pair, σ=tanh; init = glorot_uniform, bias = true) + Wi = init(out, in) + Wh = init(out, out) + b = create_bias(Wi, bias, size(Wi, 1)) + return RNNCell(σ, Wi, Wh, b) end -@layer :expand Recur trainable=(cell,) +(m::RNNCell)(x::AbstractVecOrMat) = m(x, zeros_like(x, size(m.Wh, 1))) -Base.show(io::IO, m::Recur) = print(io, "Recur(", m.cell, ")") +function (m::RNNCell)(x::AbstractVecOrMat, h::AbstractVecOrMat) + _size_check(m, x, 1 => size(m.Wi,2)) + σ = NNlib.fast_act(m.σ, x) + h = σ.(m.Wi*x .+ m.Wh*h .+ m.bias) + return h +end -""" - reset!(rnn) +function Base.show(io::IO, m::RNNCell) + print(io, "RNNCell(", size(m.Wi, 2), " => ", size(m.Wi, 1)) + print(io, ", ", m.σ) + print(io, ")") +end -Reset the hidden state of a recurrent layer back to its original value. +@doc raw""" + RNN(in => out, σ = tanh; bias = true, init = glorot_uniform) -Assuming you have a `Recur` layer `rnn`, this is roughly equivalent to: +The most basic recurrent layer. Essentially acts as a `Dense` layer, but with the +output fed back into the input each time step. - rnn.state = hidden(rnn.cell) +In the forward pass computes -# Examples -```jldoctest -julia> r = Flux.RNNCell(relu, ones(1,1), zeros(1,1), ones(1,1), zeros(1,1)); # users should use the RNN wrapper struct instead +```math +h_t = \sigma(W_i x_t + W_h h_{t-1} + b) +``` +for all `len` steps `t` in the in input sequence. -julia> y = Flux.Recur(r, ones(1,1)); +See [`RNNCell`](@ref) for a layer that processes a single time step. -julia> y.state -1×1 Matrix{Float64}: - 1.0 +# Arguments -julia> y(ones(1,1)) # relu(1*1 + 1) -1×1 Matrix{Float64}: - 2.0 +- `in => out`: The input and output dimensions of the layer. +- `σ`: The non-linearity to apply to the output. Default is `tanh`. +- `init`: The initialization function to use for the weights. Default is `glorot_uniform`. +- `bias`: Whether to include a bias term initialized to zero. Default is `true`. -julia> y.state -1×1 Matrix{Float64}: - 2.0 +# Forward -julia> Flux.reset!(y) -1×1 Matrix{Float64}: - 0.0 + rnn(x, h) -julia> y.state -1×1 Matrix{Float64}: - 0.0 -``` -""" -reset!(m::Recur) = (m.state = m.cell.state0) -reset!(m) = foreach(reset!, functor(m)[1]) +The arguments of the forward pass are: -flip(f, xs) = reverse([f(x) for x in reverse(xs)]) +- `x`: The input to the RNN. It should be a matrix size `in x len` or an array of size `in x len x batch_size`. +- `h`: The initial hidden state of the RNN. It should be a vector of size `out` or a matrix of size `out x batch_size`. -function (m::Recur)(x::AbstractArray{T, 3}) where T - h = [m(x_t) for x_t in eachlastdim(x)] - sze = size(h[1]) - reshape(reduce(hcat, h), sze[1], sze[2], length(h)) -end +Returns all new hidden states `h_t` as an array of size `out x len x batch_size`. -# Vanilla RNN +# Examples -struct RNNCell{F,I,H,V,S} - σ::F - Wi::I - Wh::H - b::V - state0::S -end +```jldoctest +julia> d_in, d_out, len, batch_size = 4, 6, 3, 5; -RNNCell((in, out)::Pair, σ=tanh; init=Flux.glorot_uniform, initb=zeros32, init_state=zeros32) = - RNNCell(σ, init(out, in), init(out, out), initb(out), init_state(out,1)) +julia> x = rand(Float32, (d_in, len, batch_size)); -function (m::RNNCell{F,I,H,V,<:AbstractMatrix{T}})(h, x::AbstractVecOrMat) where {F,I,H,V,T} - Wi, Wh, b = m.Wi, m.Wh, m.b - _size_check(m, x, 1 => size(Wi,2)) - σ = NNlib.fast_act(m.σ, x) - xT = _match_eltype(m, T, x) - h = σ.(Wi*xT .+ Wh*h .+ b) - return h, reshape_cell_output(h, x) -end +julia> h = zeros(Float32, (d_out, batch_size)); -@layer RNNCell # state0 is trainable, see issue 807 about this. +julia> rnn = RNN(d_in => d_out) +RNN( + RNNCell(4 => 6, tanh), # 66 parameters +) # Total: 3 arrays, 66 parameters, 424 bytes. -function Base.show(io::IO, l::RNNCell) - print(io, "RNNCell(", size(l.Wi, 2), " => ", size(l.Wi, 1)) - l.σ == identity || print(io, ", ", l.σ) - print(io, ")") +julia> y = rnn(x, h); # [y] = [d_out, len, batch_size] +``` + +Sometimes, the initial hidden state is a learnable parameter. +In this case, the `RNN` should be wrapped in a custom struct. + +```jldoctest +struct Model + rnn::RNN + h0::AbstractVector end +Flux.@layer :expand Model + +(m::Model)(x) = m.rnn(x, m.h0) + +model = Model(RNN(32 => 64), zeros(Float32, 64)) +``` """ - RNN(in => out, σ = tanh) +struct RNN{M} + cell::M +end -The most basic recurrent layer; essentially acts as a `Dense` layer, but with the -output fed back into the input each time step. +@layer :expand RNN + +function RNN((in, out)::Pair, σ = tanh; bias = true, init = glorot_uniform) + cell = RNNCell(in => out, σ; bias, init) + return RNN(cell) +end -The arguments `in` and `out` describe the size of the feature vectors passed as input and as output. That is, it accepts a vector of length `in` or a batch of vectors represented as a `in x B` matrix and outputs a vector of length `out` or a batch of vectors of size `out x B`. +(m::RNN)(x) = m(x, zeros_like(x, size(m.cell.Wh, 1))) + +function (m::RNN)(x, h) + @assert ndims(x) == 2 || ndims(x) == 3 + # [x] = [in, L] or [in, L, B] + # [h] = [out] or [out, B] + y = [] + for x_t in eachslice(x, dims=2) + h = m.cell(x_t, h) + # y = [y..., h] + y = vcat(y, [h]) + end + return stack(y, dims=2) +end -This constructor is syntactic sugar for `Recur(RNNCell(a...))`, and so RNNs are stateful. Note that the state shape can change depending on the inputs, and so it is good to `reset!` the model between inference calls if the batch size changes. See the examples below. -# Examples -```jldoctest -julia> r = RNN(3 => 5) -Recur( - RNNCell(3 => 5, tanh), # 50 parameters -) # Total: 4 trainable arrays, 50 parameters, - # plus 1 non-trainable, 5 parameters, summarysize 424 bytes. +# LSTM +@doc raw""" + LSTMCell(in => out; init = glorot_uniform, bias = true) -julia> r(rand(Float32, 3)) |> size -(5,) +The [Long Short Term Memory](https://www.researchgate.net/publication/13853244_Long_Short-term_Memory) cell. +Behaves like an RNN but generally exhibits a longer memory span over sequences. -julia> Flux.reset!(r); +In the forward pass, computes -julia> r(rand(Float32, 3, 10)) |> size # batch size of 10 -(5, 10) +```math +i_t = \sigma(W_{xi} x_t + W_{hi} h_{t-1} + b_i) +f_t = \sigma(W_{xf} x_t + W_{hf} h_{t-1} + b_f) +c_t = f_t \odot c_{t-1} + i_t \odot \tanh(W_{xc} x_t + W_{hc} h_{t-1} + b_c) +o_t = \sigma(W_{xo} x_t + W_{ho} h_{t-1} + b_o) +h_t = o_t \odot \tanh(c_t) ``` -!!! warning "Batch size changes" - - Failing to call `reset!` when the input batch size changes can lead to unexpected behavior. See the following example: +The `LSTMCell` returns the new hidden state `h_t` and cell state `c_t` for a single time step. +See also [`LSTM`](@ref) for a layer that processes entire sequences. - ```julia - julia> r = RNN(3 => 5) - Recur( - RNNCell(3 => 5, tanh), # 50 parameters - ) # Total: 4 trainable arrays, 50 parameters, - # plus 1 non-trainable, 5 parameters, summarysize 432 bytes. +# Arguments - julia> r.state |> size - (5, 1) +- `in => out`: The input and output dimensions of the layer. +- `init`: The initialization function to use for the weights. Default is `glorot_uniform`. +- `bias`: Whether to include a bias term initialized to zero. Default is `true`. - julia> r(rand(Float32, 3)) |> size - (5,) +# Forward - julia> r.state |> size - (5, 1) + lstmcell(x, (h, c)) + lstmcell(x) - julia> r(rand(Float32, 3, 10)) |> size # batch size of 10 - (5, 10) +The arguments of the forward pass are: +- `x`: The input to the LSTM. It should be a matrix of size `in` or an array of size `in x batch_size`. +- `(h, c)`: A tuple containing the hidden and cell states of the LSTM. + They should be vectors of size `out` or matrices of size `out x batch_size`. + If not provided, they are assumed to be vectors of zeros. - julia> r.state |> size # state shape has changed - (5, 10) +Returns a tuple `(h′, c′)` containing the new hidden state and cell state in tensors of size `out` or `out x batch_size`. - julia> r(rand(Float32, 3)) |> size # erroneously outputs a length 5*10 = 50 vector. - (50,) - ``` +# Examples -# Note: -`RNNCell`s can be constructed directly by specifying the non-linear function, the `Wi` and `Wh` internal matrices, a bias vector `b`, and a learnable initial state `state0`. The `Wi` and `Wh` matrices do not need to be the same type, but if `Wh` is `dxd`, then `Wi` should be of shape `dxN`. +```jldoctest +julia> l = LSTMCell(3 => 5) +LSTMCell(3 => 5) # 180 parameters -```julia -julia> using LinearAlgebra +julia> h = zeros(Float32, 5); # hidden state -julia> r = Flux.Recur(Flux.RNNCell(tanh, rand(5, 4), Tridiagonal(rand(5, 5)), rand(5), rand(5, 1))) +julia> c = zeros(Float32, 5); # cell state -julia> r(rand(4, 10)) |> size # batch size of 10 -(5, 10) -``` -""" -RNN(a...; ka...) = Recur(RNNCell(a...; ka...)) -Recur(m::RNNCell) = Recur(m, m.state0) +julia> x = rand(Float32, 3, 4); # in x batch_size -# LSTM +julia> h′, c′ = l(x, (h, c)); -struct LSTMCell{I,H,V,S} +julia> size(h′) # out x batch_size +(5, 4) +""" +struct LSTMCell{I,H,V} Wi::I Wh::H - b::V - state0::S + bias::V end -function LSTMCell((in, out)::Pair; - init = glorot_uniform, - initb = zeros32, - init_state = zeros32) - cell = LSTMCell(init(out * 4, in), init(out * 4, out), initb(out * 4), (init_state(out,1), init_state(out,1))) - cell.b[gate(out, 2)] .= 1 +@layer LSTMCell + +function LSTMCell((in, out)::Pair; init = glorot_uniform, bias = true) + Wi = init(out * 4, in) + Wh = init(out * 4, out) + b = create_bias(Wi, bias, out * 4) + cell = LSTMCell(Wi, Wh, b) return cell end -function (m::LSTMCell{I,H,V,<:NTuple{2,AbstractMatrix{T}}})((h, c), x::AbstractVecOrMat) where {I,H,V,T} - _size_check(m, x, 1 => size(m.Wi,2)) - b, o = m.b, size(h, 1) - xT = _match_eltype(m, T, x) - g = muladd(m.Wi, xT, muladd(m.Wh, h, b)) - input, forget, cell, output = multigate(g, o, Val(4)) +function (m::LSTMCell)(x::AbstractVecOrMat) + h = zeros_like(x, size(m.Wh, 2)) + c = zeros_like(h) + return m(x, (h, c)) +end + +function (m::LSTMCell)(x::AbstractVecOrMat, (h, c)) + _size_check(m, x, 1 => size(m.Wi, 2)) + b = m.bias + g = m.Wi * x .+ m.Wh * h .+ b + input, forget, cell, output = chunk(g, 4; dims=1) c′ = @. sigmoid_fast(forget) * c + sigmoid_fast(input) * tanh_fast(cell) h′ = @. sigmoid_fast(output) * tanh_fast(c′) - return (h′, c′), reshape_cell_output(h′, x) + return h′, c′ end -@layer LSTMCell +Base.show(io::IO, m::LSTMCell) = + print(io, "LSTMCell(", size(m.Wi, 2), " => ", size(m.Wi, 1)÷4, ")") -Base.show(io::IO, l::LSTMCell) = - print(io, "LSTMCell(", size(l.Wi, 2), " => ", size(l.Wi, 1)÷4, ")") -""" - LSTM(in => out) +@doc raw"""" + LSTM(in => out; init = glorot_uniform, bias = true) [Long Short Term Memory](https://www.researchgate.net/publication/13853244_Long_Short-term_Memory) recurrent layer. Behaves like an RNN but generally exhibits a longer memory span over sequences. -The arguments `in` and `out` describe the size of the feature vectors passed as input and as output. That is, it accepts a vector of length `in` or a batch of vectors represented as a `in x B` matrix and outputs a vector of length `out` or a batch of vectors of size `out x B`. - -This constructor is syntactic sugar for `Recur(LSTMCell(a...))`, and so LSTMs are stateful. Note that the state shape can change depending on the inputs, and so it is good to `reset!` the model between inference calls if the batch size changes. See the examples below. - See [this article](https://colah.github.io/posts/2015-08-Understanding-LSTMs/) for a good overview of the internals. +In the forward pass, computes + +```math +i_t = \sigma(W_{xi} x_t + W_{hi} h_{t-1} + b_i) +f_t = \sigma(W_{xf} x_t + W_{hf} h_{t-1} + b_f) +c_t = f_t \odot c_{t-1} + i_t \odot \tanh(W_{xc} x_t + W_{hc} h_{t-1} + b_c) +o_t = \sigma(W_{xo} x_t + W_{ho} h_{t-1} + b_o) +h_t = o_t \odot \tanh(c_t) +``` +for all `len` steps `t` in the input sequence. +See [`LSTMCell`](@ref) for a layer that processes a single time step. + +# Arguments + +- `in => out`: The input and output dimensions of the layer. +- `init`: The initialization function to use for the weights. Default is `glorot_uniform`. +- `bias`: Whether to include a bias term initialized to zero. Default is `true`. + +# Forward + + lstm(x, (h, c)) + lstm(x) + +The arguments of the forward pass are: +- `x`: The input to the LSTM. It should be a matrix of size `in x len` or an array of size `in x len x batch_size`. +- `(h, c)`: A tuple containing the hidden and cell states of the LSTM. + They should be vectors of size `out` or matrices of size `out x batch_size`. + If not provided, they are assumed to be vectors of zeros. + +Returns a tuple `(h′, c′)` containing all new hidden states `h_t` and cell states `c_t` +in tensors of size `out x len` or `out x len x batch_size`. + # Examples + ```jldoctest -julia> l = LSTM(3 => 5) -Recur( - LSTMCell(3 => 5), # 190 parameters -) # Total: 5 trainable arrays, 190 parameters, - # plus 2 non-trainable, 10 parameters, summarysize 1.023 KiB. +struct Model + lstm::LSTM + h0::AbstractVector + c0::AbstractVector +end -julia> l(rand(Float32, 3)) |> size -(5,) +Flux.@layer :expand Model -julia> Flux.reset!(l); +(m::Model)(x) = m.lstm(x, (m.h0, m.c0)) -julia> l(rand(Float32, 3, 10)) |> size # batch size of 10 -(5, 10) +d_in, d_out, len, batch_size = 2, 3, 4, 5 +x = rand(Float32, (d_in, len, batch_size)) +model = Model(LSTM(d_in => d_out), zeros(Float32, d_out), zeros(Float32, d_out)) +h, c = model(x) +size(h) # out x len x batch_size ``` +""" +struct LSTM{M} + cell::M +end -!!! warning "Batch size changes" - Failing to call `reset!` when the input batch size changes can lead to unexpected behavior. See the example in [`RNN`](@ref). +@layer :expand LSTM -# Note: - `LSTMCell`s can be constructed directly by specifying the non-linear function, the `Wi` and `Wh` internal matrices, a bias vector `b`, and a learnable initial state `state0`. The `Wi` and `Wh` matrices do not need to be the same type. See the example in [`RNN`](@ref). -""" -LSTM(a...; ka...) = Recur(LSTMCell(a...; ka...)) -Recur(m::LSTMCell) = Recur(m, m.state0) +function LSTM((in, out)::Pair; init = glorot_uniform, bias = true) + cell = LSTMCell(in => out; init, bias) + return LSTM(cell) +end -# GRU +function (m::LSTM)(x) + h = zeros_like(x, size(m.cell.Wh, 1)) + c = zeros_like(h) + return m(x, (h, c)) +end -function _gru_output(gxs, ghs, bs) - r = @. sigmoid_fast(gxs[1] + ghs[1] + bs[1]) - z = @. sigmoid_fast(gxs[2] + ghs[2] + bs[2]) - return r, z +function (m::LSTM)(x, (h, c)) + @assert ndims(x) == 2 || ndims(x) == 3 + h′ = [] + c′ = [] + for x_t in eachslice(x, dims=2) + h, c = m.cell(x_t, (h, c)) + h′ = vcat(h′, [h]) + c′ = vcat(c′, [c]) + end + return stack(h′, dims=2), stack(c′, dims=2) end -struct GRUCell{I,H,V,S} +# GRU + +@doc raw""" + GRUCell(in => out; init = glorot_uniform, bias = true) + +[Gated Recurrent Unit](https://arxiv.org/abs/1406.1078v1) layer. +Behaves like an RNN but generally exhibits a longer memory span over sequences. +This implements the variant proposed in v1 of the referenced paper. + +In the forward pass, computes + +```math +r = \sigma(W_{xi} x + W_{hi} h + b_i) +z = \sigma(W_{xz} x + W_{hz} h + b_z) +h̃ = \tanh(W_{xh} x + r \odot W_{hh} h + b_h) +h' = (1 - z) \odot h̃ + z \odot h +``` + +See also [`GRU`](@ref) for a layer that processes entire sequences. + +# Arguments + +- `in => out`: The input and output dimensions of the layer. +- `init`: The initialization function to use for the weights. Default is `glorot_uniform`. +- `bias`: Whether to include a bias term initialized to zero. Default is `true`. + +# Forward + + grucell(x, h) + grucell(x) + +The arguments of the forward pass are: +- `x`: The input to the GRU. It should be a vector of size `in` or a matrix of size `in x batch_size`. +- `h`: The hidden state of the GRU. It should be a vector of size `out` or a matrix of size `out x batch_size`. + If not provided, it is assumed to be a vector of zeros. + +Returns the new hidden state `h'` as an array of size `out` or `out x batch_size`. + +# Examples + +TODO add loop +```jldoctest +julia> g = GRUCell(3 => 5) +GRUCell(3 => 5) # 140 parameters + +julia> h = zeros(Float32, 5); # hidden state + +julia> x = rand(Float32, 3, 4); # in x batch_size + +julia> h′ = g(x, h); +``` +""" +struct GRUCell{I,H,V} Wi::I Wh::H b::V - state0::S end -GRUCell((in, out)::Pair; init = glorot_uniform, initb = zeros32, init_state = zeros32) = - GRUCell(init(out * 3, in), init(out * 3, out), initb(out * 3), init_state(out,1)) +@layer GRUCell + +function GRUCell((in, out)::Pair; init = glorot_uniform, bias = true) + Wi = init(out * 3, in) + Wh = init(out * 3, out) + b = create_bias(Wi, bias, size(Wi, 1)) + return GRUCell(Wi, Wh, b) +end + +(m::GRUCell)(x::AbstractVecOrMat) = m(x, zeros_like(x, size(m.Wh, 2))) -function (m::GRUCell{I,H,V,<:AbstractMatrix{T}})(h, x::AbstractVecOrMat) where {I,H,V,T} +function (m::GRUCell)(x::AbstractVecOrMat, h) _size_check(m, x, 1 => size(m.Wi,2)) - Wi, Wh, b, o = m.Wi, m.Wh, m.b, size(h, 1) - xT = _match_eltype(m, T, x) - gxs, ghs, bs = multigate(Wi*xT, o, Val(3)), multigate(Wh*h, o, Val(3)), multigate(b, o, Val(3)) - r, z = _gru_output(gxs, ghs, bs) + gxs = chunk(m.Wi * x, 3, dims=1) + ghs = chunk(m.Wh * h, 3, dims=1) + if m.b isa AbstractArray + bs = chunk(m.b, 3, dims=1) + else # b == false + bs = [false, false, false] + end + r = @. sigmoid_fast(gxs[1] + ghs[1] + bs[1]) + z = @. sigmoid_fast(gxs[2] + ghs[2] + bs[2]) h̃ = @. tanh_fast(gxs[3] + r * ghs[3] + bs[3]) h′ = @. (1 - z) * h̃ + z * h - return h′, reshape_cell_output(h′, x) + return h′ end -@layer GRUCell +Base.show(io::IO, m::GRUCell) = + print(io, "GRUCell(", size(m.Wi, 2), " => ", size(m.Wi, 1)÷3, ")") -Base.show(io::IO, l::GRUCell) = - print(io, "GRUCell(", size(l.Wi, 2), " => ", size(l.Wi, 1)÷3, ")") - -""" - GRU(in => out) +@doc raw""" + GRU(in => out; init = glorot_uniform, bias = true) [Gated Recurrent Unit](https://arxiv.org/abs/1406.1078v1) layer. Behaves like an RNN but generally exhibits a longer memory span over sequences. This implements the variant proposed in v1 of the referenced paper. -The integer arguments `in` and `out` describe the size of the feature vectors passed as input and as output. That is, it accepts a vector of length `in` or a batch of vectors represented as a `in x B` matrix and outputs a vector of length `out` or a batch of vectors of size `out x B`. +The forward pass computes -This constructor is syntactic sugar for `Recur(GRUCell(a...))`, and so GRUs are stateful. Note that the state shape can change depending on the inputs, and so it is good to `reset!` the model between inference calls if the batch size changes. See the examples below. +```math +r_t = \sigma(W_{xi} x_t + W_{hi} h_{t-1} + b_i) +z_t = \sigma(W_{xz} x_t + W_{hz} h_{t-1} + b_z) +h̃_t = \tanh(W_{xh} x_t + r_t \odot W_{hh} h_{t-1} + b_h) +h_t = (1 - z_t) \odot h̃_t + z_t \odot h_{t-1} +``` +for all `len` steps `t` in the input sequence. +See [`GRUCell`](@ref) for a layer that processes a single time step. -See [this article](https://colah.github.io/posts/2015-08-Understanding-LSTMs/) -for a good overview of the internals. +# Forward + + gru(x, h) + gru(x) + +The arguments of the forward pass are: + +- `x`: The input to the GRU. It should be a matrix of size `in x len` or an array of size `in x len x batch_size`. +- `h`: The initial hidden state of the GRU. It should be a vector of size `out` or a matrix of size `out x batch_size`. + If not provided, it is assumed to be a vector of zeros. + +Returns all new hidden states `h_t` as an array of size `out x len x batch_size`. # Examples + ```jldoctest -julia> g = GRU(3 => 5) -Recur( - GRUCell(3 => 5), # 140 parameters -) # Total: 4 trainable arrays, 140 parameters, - # plus 1 non-trainable, 5 parameters, summarysize 784 bytes. +d_in, d_out, len, batch_size = 2, 3, 4, 5 +gru = GRU(d_in => d_out) +x = rand(Float32, (d_in, len, batch_size)) +h0 = zeros(Float32, d_out) +h = gru(x, h0) # out x len x batch_size +``` +""" +struct GRU{M} + cell::M +end + +@layer :expand GRU + +function GRU((in, out)::Pair; init = glorot_uniform, bias = true) + cell = GRUCell(in => out; init, bias) + return GRU(cell) +end -julia> g(rand(Float32, 3)) |> size -(5,) +function (m::GRU)(x) + h = zeros_like(x, size(m.cell.Wh, 2)) + return m(x, h) +end -julia> Flux.reset!(g); +function (m::GRU)(x, h) + @assert ndims(x) == 2 || ndims(x) == 3 + h′ = [] + # [x] = [in, L] or [in, L, B] + for x_t in eachslice(x, dims=2) + h = m.cell(x_t, h) + h′ = vcat(h′, [h]) + end + return stack(h′, dims=2) +end -julia> g(rand(Float32, 3, 10)) |> size # batch size of 10 -(5, 10) +# GRU v3 +@doc raw""" + GRUv3Cell(in => out, init = glorot_uniform, bias = true) + +[Gated Recurrent Unit](https://arxiv.org/abs/1406.1078v3) layer. +Behaves like an RNN but generally exhibits a longer memory span over sequences. +This implements the variant proposed in v3 of the referenced paper. + +The forward pass computes +```math +r = \sigma(W_{xi} x + W_{hi} h + b_i) +z = \sigma(W_{xz} x + W_{hz} h + b_z) +h̃ = \tanh(W_{xh} x + W_{hh̃} (r \odot W_{hh} h) + b_h) +h' = (1 - z) \odot h̃ + z \odot h ``` +and returns `h'`. This is a single time step of the GRU. -!!! warning "Batch size changes" - Failing to call `reset!` when the input batch size changes can lead to unexpected behavior. See the example in [`RNN`](@ref). +See [`GRUv3`](@ref) for a layer that processes entire sequences. +See [`GRU`](@ref) and [`GRUCell`](@ref) for variants of this layer. -# Note: - `GRUCell`s can be constructed directly by specifying the non-linear function, the `Wi` and `Wh` internal matrices, a bias vector `b`, and a learnable initial state `state0`. The `Wi` and `Wh` matrices do not need to be the same type. See the example in [`RNN`](@ref). -""" -GRU(a...; ka...) = Recur(GRUCell(a...; ka...)) -Recur(m::GRUCell) = Recur(m, m.state0) +# Arguments -# GRU v3 +- `in => out`: The input and output dimensions of the layer. +- `init`: The initialization function to use for the weights. Default is `glorot_uniform`. +- `bias`: Whether to include a bias term initialized to zero. Default is `true`. -struct GRUv3Cell{I,H,V,HH,S} +# Forward + + gruv3cell(x, h) + gruv3cell(x) + +The arguments of the forward pass are: +- `x`: The input to the GRU. It should be a vector of size `in` or a matrix of size `in x batch_size`. +- `h`: The hidden state of the GRU. It should be a vector of size `out` or a matrix of size `out x batch_size`. + If not provided, it is assumed to be a vector of zeros. + +Returns the new hidden state `h'` as an array of size `out` or `out x batch_size`. +""" +struct GRUv3Cell{I,H,V,HH} Wi::I Wh::H b::V Wh_h̃::HH - state0::S end -GRUv3Cell((in, out)::Pair; init = glorot_uniform, initb = zeros32, init_state = zeros32) = - GRUv3Cell(init(out * 3, in), init(out * 2, out), initb(out * 3), - init(out, out), init_state(out,1)) +@layer GRUv3Cell -function (m::GRUv3Cell{I,H,V,HH,<:AbstractMatrix{T}})(h, x::AbstractVecOrMat) where {I,H,V,HH,T} +function GRUv3Cell((in, out)::Pair; init = glorot_uniform, bias = true) + Wi = init(out * 3, in) + Wh = init(out * 3, out) + Wh_h̃ = init(out, out) + b = create_bias(Wi, bias, out * 3) + return GRUv3Cell(Wi, Wh, b, Wh_h̃) +end + +(m::GRUv3Cell)(x::AbstractVecOrMat) = m(x, zeros_like(x, size(m.Wh, 2))) + +function (m::GRUv3Cell)(x::AbstractVecOrMat, h) _size_check(m, x, 1 => size(m.Wi,2)) - Wi, Wh, b, Wh_h̃, o = m.Wi, m.Wh, m.b, m.Wh_h̃, size(h, 1) - xT = _match_eltype(m, T, x) - gxs, ghs, bs = multigate(Wi*xT, o, Val(3)), multigate(Wh*h, o, Val(2)), multigate(b, o, Val(3)) - r, z = _gru_output(gxs, ghs, bs) - h̃ = tanh_fast.(gxs[3] .+ (Wh_h̃ * (r .* h)) .+ bs[3]) + gxs = chunk(m.Wi * x, 3, dims=1) + ghs = chunk(m.Wh * h, 3, dims=1) + if m.b isa AbstractArray + bs = chunk(m.b, 3, dims=1) + else # m.b == false + bs = [false, false, false] + end + r = @. sigmoid_fast(gxs[1] + ghs[1] + bs[1]) + z = @. sigmoid_fast(gxs[2] + ghs[2] + bs[2]) + h̃ = tanh_fast.(gxs[3] .+ (m.Wh_h̃ * (r .* h)) .+ bs[3]) h′ = @. (1 - z) * h̃ + z * h - return h′, reshape_cell_output(h′, x) + return h′ end -@layer GRUv3Cell +Base.show(io::IO, m::GRUv3Cell) = + print(io, "GRUv3Cell(", size(m.Wi, 2), " => ", size(m.Wi, 1)÷3, ")") -Base.show(io::IO, l::GRUv3Cell) = - print(io, "GRUv3Cell(", size(l.Wi, 2), " => ", size(l.Wi, 1)÷3, ")") -""" +@doc raw""" GRUv3(in => out) [Gated Recurrent Unit](https://arxiv.org/abs/1406.1078v3) layer. Behaves like an RNN but generally exhibits a longer memory span over sequences. This implements the variant proposed in v3 of the referenced paper. -The arguments `in` and `out` describe the size of the feature vectors passed as input and as output. That is, it accepts a vector of length `in` or a batch of vectors represented as a `in x B` matrix and outputs a vector of length `out` or a batch of vectors of size `out x B`. - -This constructor is syntactic sugar for `Recur(GRUv3Cell(a...))`, and so GRUv3s are stateful. Note that the state shape can change depending on the inputs, and so it is good to `reset!` the model between inference calls if the batch size changes. See the examples below. +The forward pass computes -See [this article](https://colah.github.io/posts/2015-08-Understanding-LSTMs/) -for a good overview of the internals. +```math +r_t = \sigma(W_{xi} x_t + W_{hi} h_{t-1} + b_i) +z_t = \sigma(W_{xz} x_t + W_{hz} h_{t-1} + b_z) +h̃_t = \tanh(W_{xh} x_t + W_{hh̃} (r_t \odot W_{hh} h_{t-1}) + b_h) +h_t = (1 - z_t) \odot h̃_t + z_t \odot h_{t-1} +``` +for all `len` steps `t` in the input sequence. +See [`GRUv3Cell`](@ref) for a layer that processes a single time step. +See [`GRU`](@ref) and [`GRUCell`](@ref) for variants of this layer. # Examples -```jldoctest -julia> g = GRUv3(3 => 5) -Recur( - GRUv3Cell(3 => 5), # 140 parameters -) # Total: 5 trainable arrays, 140 parameters, - # plus 1 non-trainable, 5 parameters, summarysize 840 bytes. +TODO +""" +struct GRUv3{M} + cell::M +end -julia> g(rand(Float32, 3)) |> size -(5,) +@layer :expand GRUv3 -julia> Flux.reset!(g); +function GRUv3((in, out)::Pair; init = glorot_uniform, bias = true) + cell = GRUv3Cell(in => out; init, bias) + return GRUv3(cell) +end -julia> g(rand(Float32, 3, 10)) |> size # batch size of 10 -(5, 10) -``` +function (m::GRUv3)(x) + h = zeros_like(x, size(m.cell.Wh, 2)) + return m(x, h) +end -!!! warning "Batch size changes" - Failing to call `reset!` when the input batch size changes can lead to unexpected behavior. See the example in [`RNN`](@ref). +function (m::GRUv3)(x, h) + @assert ndims(x) == 2 || ndims(x) == 3 + h′ = [] + for x_t in eachslice(x, dims=2) + h = m.cell(x_t, h) + h′ = vcat(h′, [h]) + end + return stack(h′, dims=2) +end -# Note: - `GRUv3Cell`s can be constructed directly by specifying the non-linear function, the `Wi`, `Wh`, and `Wh_h` internal matrices, a bias vector `b`, and a learnable initial state `state0`. The `Wi`, `Wh`, and `Wh_h` matrices do not need to be the same type. See the example in [`RNN`](@ref). -""" -GRUv3(a...; ka...) = Recur(GRUv3Cell(a...; ka...)) -Recur(m::GRUv3Cell) = Recur(m, m.state0) diff --git a/test/ext_amdgpu/runtests.jl b/test/ext_amdgpu/runtests.jl index ec779dedea..42972f285e 100644 --- a/test/ext_amdgpu/runtests.jl +++ b/test/ext_amdgpu/runtests.jl @@ -9,3 +9,8 @@ end @testset "Basic" begin include("basic.jl") end + +@testset "Recurrent" begin + global BROKEN_TESTS = [] + include("../ext_common/recurrent_gpu_ad.jl") +end diff --git a/test/ext_common/recurrent_gpu_ad.jl b/test/ext_common/recurrent_gpu_ad.jl new file mode 100644 index 0000000000..d2ef3fe34b --- /dev/null +++ b/test/ext_common/recurrent_gpu_ad.jl @@ -0,0 +1,163 @@ + +@testset "RNNCell GPU AD" begin + function loss(r, x, h) + y = [] + for x_t in x + h = r(x_t, h) + y = vcat(y, [h]) + end + # return mean(h) + y = stack(y, dims=2) # [D, L] or [D, L, B] + return mean(y) + end + + d_in, d_out, len, batch_size = 2, 3, 4, 5 + r = RNNCell(d_in => d_out) + x = [randn(Float32, d_in, batch_size) for _ in 1:len] + h = zeros(Float32, d_out) + # Single Step + @test test_gradients(r, x[1], h; test_gpu=true, compare_finite_diff=false) broken = :rnncell_single ∈ BROKEN_TESTS + # Multiple Steps + @test test_gradients(r, x, h; test_gpu=true, compare_finite_diff=false, loss) broken = :rnncell_multiple ∈ BROKEN_TESTS +end + +@testset "RNN GPU AD" begin + struct ModelRNN + rnn::RNN + h0::AbstractVector + end + + Flux.@layer :expand ModelRNN + + (m::ModelRNN)(x) = m.rnn(x, m.h0) + + d_in, d_out, len, batch_size = 2, 3, 4, 5 + model = ModelRNN(RNN(d_in => d_out), zeros(Float32, d_out)) + x_nobatch = randn(Float32, d_in, len) + @test test_gradients(model, x_nobatch; test_gpu=true, compare_finite_diff=false) broken = :rnn_nobatch ∈ BROKEN_TESTS + x = randn(Float32, d_in, batch_size) + @test test_gradients(model, x, test_gpu=true, compare_finite_diff=false) broken = :rnn ∈ BROKEN_TESTS +end + +@testset "LSTMCell" begin + + function loss(r, x, hc) + h, c = hc + h′ = [] + c′ = [] + for x_t in x + h, c = r(x_t, (h, c)) + h′ = vcat(h′, [h]) + c′ = [c′..., c] + end + hnew = stack(h′, dims=2) + cnew = stack(c′, dims=2) + return mean(hnew) + mean(cnew) + end + + d_in, d_out, len, batch_size = 2, 3, 4, 5 + cell = LSTMCell(d_in => d_out) + x = [randn(Float32, d_in, batch_size) for _ in 1:len] + h = zeros(Float32, d_out) + c = zeros(Float32, d_out) + # Single Step + @test test_gradients(cell, x[1], (h, c); test_gpu=true, compare_finite_diff=false, + loss = (m, x, (h, c)) -> mean(m(x, (h, c))[1])) broken = :lstmcell_single ∈ BROKEN_TESTS + # Multiple Steps + @test test_gradients(cell, x, (h, c); test_gpu=true, compare_finite_diff=false, loss) broken = :lstmcell_multiple ∈ BROKEN_TESTS +end + +@testset "LSTM" begin + struct ModelLSTM + lstm::LSTM + h0::AbstractVector + c0::AbstractVector + end + + Flux.@layer :expand ModelLSTM + + (m::ModelLSTM)(x) = m.lstm(x, (m.h0, m.c0)) + + d_in, d_out, len, batch_size = 2, 3, 4, 5 + model = ModelLSTM(LSTM(d_in => d_out), zeros(Float32, d_out), zeros(Float32, d_out)) + x_nobatch = randn(Float32, d_in, len) + @test test_gradients(model, x_nobatch; test_gpu=true, compare_finite_diff=false, + loss = (m, x) -> mean(m(x)[1])) broken = :lstm_nobatch ∈ BROKEN_TESTS + x = randn(Float32, d_in, len, batch_size) + @test test_gradients(model, x; test_gpu=true, compare_finite_diff=false, + loss = (m, x) -> mean(m(x)[1])) broken = :lstm ∈ BROKEN_TESTS +end + +@testset "GRUCell" begin + function loss(r, x, h) + y = [] + for x_t in x + h = r(x_t, h) + y = vcat(y, [h]) + end + y = stack(y, dims=2) # [D, L] or [D, L, B] + return mean(y) + end + + d_in, d_out, len, batch_size = 2, 3, 4, 5 + r = GRUCell(d_in => d_out) + x = [randn(Float32, d_in, batch_size) for _ in 1:len] + h = zeros(Float32, d_out) + @test test_gradients(r, x[1], h; test_gpu=true, compare_finite_diff=false) broken = :grucell_single ∈ BROKEN_TESTS + @test test_gradients(r, x, h; test_gpu=true, compare_finite_diff=false, loss) broken = :grucell_multiple ∈ BROKEN_TESTS +end + +@testset "GRU GPU AD" begin + struct ModelGRU + gru::GRU + h0::AbstractVector + end + + Flux.@layer :expand ModelGRU + + (m::ModelGRU)(x) = m.gru(x, m.h0) + + d_in, d_out, len, batch_size = 2, 3, 4, 5 + model = ModelGRU(GRU(d_in => d_out), zeros(Float32, d_out)) + x_nobatch = randn(Float32, d_in, len) + @test test_gradients(model, x_nobatch; test_gpu=true, compare_finite_diff=false) broken = :gru_nobatch ∈ BROKEN_TESTS + x = randn(Float32, d_in, len, batch_size) + @test test_gradients(model, x; test_gpu=true, compare_finite_diff=false) broken = :gru ∈ BROKEN_TESTS +end + +@testset "GRUv3Cell GPU AD" begin + function loss(r, x, h) + y = [] + for x_t in x + h = r(x_t, h) + y = vcat(y, [h]) + end + y = stack(y, dims=2) # [D, L] or [D, L, B] + return mean(y) + end + + d_in, d_out, len, batch_size = 2, 3, 4, 5 + r = GRUv3Cell(d_in => d_out) + x = [randn(Float32, d_in, batch_size) for _ in 1:len] + h = zeros(Float32, d_out) + @test test_gradients(r, x[1], h; test_gpu=true, compare_finite_diff=false) broken = :gruv3cell_single ∈ BROKEN_TESTS + @test test_gradients(r, x, h; test_gpu=true, compare_finite_diff=false, loss) broken = :gruv3cell_multiple ∈ BROKEN_TESTS +end + +@testset "GRUv3 GPU AD" begin + struct ModelGRUv3 + gru::GRUv3 + h0::AbstractVector + end + + Flux.@layer :expand ModelGRUv3 + + (m::ModelGRUv3)(x) = m.gru(x, m.h0) + + d_in, d_out, len, batch_size = 2, 3, 4, 5 + model = ModelGRUv3(GRUv3(d_in => d_out), zeros(Float32, d_out)) + x_nobatch = randn(Float32, d_in, len) + @test test_gradients(model, x_nobatch; test_gpu=true, compare_finite_diff=false) broken = :gruv3_nobatch ∈ BROKEN_TESTS + x = randn(Float32, d_in, len, batch_size) + @test test_gradients(model, x; test_gpu=true, compare_finite_diff=false) broken = :gruv3 ∈ BROKEN_TESTS +end diff --git a/test/ext_cuda/curnn.jl b/test/ext_cuda/curnn.jl deleted file mode 100644 index 5c460d2aa4..0000000000 --- a/test/ext_cuda/curnn.jl +++ /dev/null @@ -1,67 +0,0 @@ -using Flux, CUDA, Test - -@testset for R in [RNN, GRU, LSTM, GRUv3] - m = R(10, 5) |> gpu - x = gpu(rand(10)) - (m̄,) = gradient(m -> sum(m(x)), m) - Flux.reset!(m) - θ = gradient(() -> sum(m(x)), params(m)) - @test x isa CuArray - @test θ[m.cell.Wi] isa CuArray - @test collect(m̄.cell.Wi) == collect(θ[m.cell.Wi]) -end - -@testset "RNN" begin - @testset for R in [RNN, GRU, LSTM, GRUv3], batch_size in (1, 5) - rnn = R(10, 5) - curnn = fmap(gpu, rnn) - - Flux.reset!(rnn) - Flux.reset!(curnn) - x = batch_size == 1 ? - rand(Float32, 10) : - rand(Float32, 10, batch_size) - cux = gpu(x) - - y, back = pullback((r, x) -> r(x), rnn, x) - cuy, cuback = pullback((r, x) -> r(x), curnn, cux) - - @test y ≈ collect(cuy) - - ȳ = randn(size(y)) - m̄, x̄ = back(ȳ) - cum̄, cux̄ = cuback(gpu(ȳ)) - - @test x̄ ≈ collect(cux̄) - @test m̄[].cell.Wi ≈ collect(cum̄[].cell.Wi) - @test m̄[].cell.Wh ≈ collect(cum̄[].cell.Wh) - @test m̄[].cell.b ≈ collect(cum̄[].cell.b) - if m̄[].state isa Tuple - for (x, cx) in zip(m̄[].state, cum̄[].state) - @test x ≈ collect(cx) - end - else - @test m̄[].state ≈ collect(cum̄[].state) - end - - Flux.reset!(rnn) - Flux.reset!(curnn) - ohx = batch_size == 1 ? - Flux.onehot(rand(1:10), 1:10) : - Flux.onehotbatch(rand(1:10, batch_size), 1:10) - cuohx = gpu(ohx) - y = (rnn(ohx); rnn(ohx)) - - cuy = (curnn(cuohx); curnn(cuohx)) - @test y ≈ collect(cuy) - - Flux.reset!(rnn) - Flux.reset!(curnn) - fx = rand(Float32, 10, batch_size, 3) - cufx = gpu(fx) - fy = (rnn(fx); rnn(fx)) - - cufy = (curnn(cufx); curnn(cufx)) - @test fy ≈ collect(cufy) - end -end diff --git a/test/ext_cuda/layers.jl b/test/ext_cuda/layers.jl index cba95cee75..c458832b94 100644 --- a/test/ext_cuda/layers.jl +++ b/test/ext_cuda/layers.jl @@ -312,6 +312,10 @@ end @test Array(y_gpu) ≈ y_cpu atol=1e-4 @test Array(α_gpu) ≈ α_cpu atol=1e-4 - test_gradients(mha_cpu, x_cpu, loss = o -> sum(o[1].^2) + sum(o[2].^2), + function loss(m, x) + y, α = m(x) + return sum(y.^2) + sum(α.^2) + end + test_gradients(mha_cpu, x_cpu; loss, test_gpu=true, compare_finite_diff=false) end diff --git a/test/ext_cuda/runtests.jl b/test/ext_cuda/runtests.jl index 012a62d41a..be02409077 100644 --- a/test/ext_cuda/runtests.jl +++ b/test/ext_cuda/runtests.jl @@ -22,8 +22,9 @@ end @testset "cudnn" begin include("cudnn.jl") end -@testset "curnn" begin - include("curnn.jl") +@testset "Recurrent" begin + global BROKEN_TESTS = [] + include("../ext_common/recurrent_gpu_ad.jl") end @testset "ctc" begin include("ctc.jl") diff --git a/test/ext_enzyme/enzyme.jl b/test/ext_enzyme/enzyme.jl index bae14fd246..aa04150cf0 100644 --- a/test/ext_enzyme/enzyme.jl +++ b/test/ext_enzyme/enzyme.jl @@ -93,7 +93,6 @@ end @testset "Models" begin function loss(model, x) - Flux.reset!(model) sum(model(x)) end @@ -126,7 +125,6 @@ end @testset "Recurrence Tests" begin function loss(model, x) - Flux.reset!(model) for i in 1:3 x = model(x) end diff --git a/test/ext_metal/runtests.jl b/test/ext_metal/runtests.jl index cb9532390e..86e1068cf3 100644 --- a/test/ext_metal/runtests.jl +++ b/test/ext_metal/runtests.jl @@ -32,6 +32,11 @@ end include("basic.jl") end +@testset "Recurrent" begin + global BROKEN_TESTS = [:lstm, :gru, :gruv3] + include("../ext_common/recurrent_gpu_ad.jl") +end + @testset "Huber Loss test" begin X = Flux.gpu(Float32[0,1]) Y = Flux.gpu(Float32[1,0]) diff --git a/test/layers/attention.jl b/test/layers/attention.jl index 2c6fd7d514..be0da20c82 100644 --- a/test/layers/attention.jl +++ b/test/layers/attention.jl @@ -54,7 +54,11 @@ end @testset "gradient" begin - test_gradients(mha, q, loss = o -> sum(o[1].^2) + sum(o[2].^2)) + function loss(m, q) + y, α = m(q) + return sum(y.^2) + sum(α.^2) + end + test_gradients(mha, q; loss) end end diff --git a/test/layers/recurrent.jl b/test/layers/recurrent.jl index 7df8b0d4c2..98e072cdb1 100644 --- a/test/layers/recurrent.jl +++ b/test/layers/recurrent.jl @@ -1,250 +1,272 @@ -using LinearAlgebra - -@testset "RNN gradients-implicit" begin - layer = Flux.Recur(Flux.RNNCell(1, 1, identity)) - layer.cell.Wi .= 5.0 - layer.cell.Wh .= 4.0 - layer.cell.b .= 0.0f0 - layer.cell.state0 .= 7.0 - x = [[2.0f0], [3.0f0]] - - # theoretical primal gradients - primal = - layer.cell.Wh .* (layer.cell.Wh * layer.cell.state0 .+ x[1] .* layer.cell.Wi) .+ - x[2] .* layer.cell.Wi - ∇Wi = x[1] .* layer.cell.Wh .+ x[2] - ∇Wh = 2 .* layer.cell.Wh .* layer.cell.state0 .+ x[1] .* layer.cell.Wi - ∇b = layer.cell.Wh .+ 1 - ∇state0 = layer.cell.Wh .^ 2 - - Flux.reset!(layer) - ps = Flux.params(layer) - e, g = Flux.withgradient(ps) do - out = [layer(xi) for xi in x] - sum(out[2]) + +@testset "RNNCell" begin + function loss1(r, x, h) + for x_t in x + h = r(x_t, h) + end + return mean(h.^2) end - @test primal[1] ≈ e - @test ∇Wi ≈ g[ps[1]] - @test ∇Wh ≈ g[ps[2]] - @test ∇b ≈ g[ps[3]] - @test ∇state0 ≈ g[ps[4]] + function loss2(r, x, h) + y = [r(x_t, h) for x_t in x] + return sum(mean, y) + end -end + function loss3(r, x, h) + y = [] + for x_t in x + h = r(x_t, h) + y = [y..., h] + end + return sum(mean, y) + end -@testset "RNN gradients-explicit" begin - layer = Flux.Recur(Flux.RNNCell(1, 1, identity)) - layer.cell.Wi .= 5.0f0 - layer.cell.Wh .= 4.0f0 - layer.cell.b .= 0.0f0 - layer.cell.state0 .= 7.0f0 - x = [[2.0f0], [3.0f0]] - - # theoretical primal gradients - primal = - layer.cell.Wh .* (layer.cell.Wh * layer.cell.state0 .+ x[1] .* layer.cell.Wi) .+ - x[2] .* layer.cell.Wi - ∇Wi = x[1] .* layer.cell.Wh .+ x[2] - ∇Wh = 2 .* layer.cell.Wh .* layer.cell.state0 .+ x[1] .* layer.cell.Wi - ∇b = layer.cell.Wh .+ 1 - ∇state0 = layer.cell.Wh .^ 2 - - Flux.reset!(layer) - e, g = Flux.withgradient(layer) do m - out = [m(xi) for xi in x] - sum(out[2]) + function loss4(r, x, h) + y = [] + for x_t in x + h = r(x_t, h) + y = vcat(y, [h]) + end + y = stack(y, dims=2) # [D, L] or [D, L, B] + return mean(y.^2) end - grads = g[1][:cell] - @test primal[1] ≈ e + r = RNNCell(3 => 5) + @test length(Flux.trainables(r)) == 3 + # An input sequence of length 6 and batch size 4. + x = [rand(Float32, 3, 4) for _ in 1:6] + + # Initial State is a single vector + h = randn(Float32, 5) + test_gradients(r, x, h, loss=loss1) # for loop + test_gradients(r, x, h, loss=loss2) # comprehension + test_gradients(r, x, h, loss=loss3) # splat + test_gradients(r, x, h, loss=loss4) # vcat and stack + + # no initial state same as zero initial state + @test r(x[1]) ≈ r(x[1], zeros(Float32, 5)) - @test_broken ∇Wi ≈ grads[:Wi] - @test_broken ∇Wh ≈ grads[:Wh] - @test_broken ∇b ≈ grads[:b] - @test_broken ∇state0 ≈ grads[:state0] + # Now initial state has a batch dimension. + h = randn(Float32, 5, 4) + test_gradients(r, x, h, loss=loss4) + + # The input sequence has no batch dimension. + x = [rand(Float32, 3) for _ in 1:6] + h = randn(Float32, 5) + test_gradients(r, x, h, loss=loss4) + + + # No Bias + r = RNNCell(3 => 5, bias=false) + @test length(Flux.trainables(r)) == 2 + test_gradients(r, x, h, loss=loss4) end -# Ref FluxML/Flux.jl#1209 1D input -@testset "BPTT-1D" begin - seq = [rand(Float32, 2) for i = 1:3] - for r ∈ [RNN,] - rnn = r(2 => 3) - Flux.reset!(rnn) - grads_seq = gradient(Flux.params(rnn)) do - sum([rnn(s) for s in seq][3]) +@testset "RNN" begin + struct ModelRNN + rnn::RNN + h0::AbstractVector end - Flux.reset!(rnn); - bptt = gradient(Wh -> sum(tanh.(rnn.cell.Wi * seq[3] + Wh * - tanh.(rnn.cell.Wi * seq[2] + Wh * - tanh.(rnn.cell.Wi * seq[1] + - Wh * rnn.cell.state0 - + rnn.cell.b) - + rnn.cell.b) - + rnn.cell.b)), - rnn.cell.Wh) - @test grads_seq[rnn.cell.Wh] ≈ bptt[1] - end + + Flux.@layer :expand ModelRNN + + (m::ModelRNN)(x) = m.rnn(x, m.h0) + + model = ModelRNN(RNN(2 => 4), zeros(Float32, 4)) + + x = rand(Float32, 2, 3, 1) + y = model(x) + @test y isa Array{Float32, 3} + @test size(y) == (4, 3, 1) + test_gradients(model, x) + + # no initial state same as zero initial state + rnn = model.rnn + @test rnn(x) ≈ rnn(x, zeros(Float32, 4)) + + x = rand(Float32, 2, 3) + y = model(x) + @test y isa Array{Float32, 2} + @test size(y) == (4, 3) + test_gradients(model, x) end -# Ref FluxML/Flux.jl#1209 2D input -@testset "BPTT-2D" begin - seq = [rand(Float32, (2, 1)) for i = 1:3] - for r ∈ [RNN,] - rnn = r(2 => 3) - Flux.reset!(rnn) - grads_seq = gradient(Flux.params(rnn)) do - sum([rnn(s) for s in seq][3]) +@testset "LSTMCell" begin + + function loss(r, x, hc) + h, c = hc + h′ = [] + c′ = [] + for x_t in x + h, c = r(x_t, (h, c)) + h′ = vcat(h′, [h]) + c′ = [c′..., c] + end + hnew = stack(h′, dims=2) + cnew = stack(c′, dims=2) + return mean(hnew.^2) + mean(cnew.^2) end - Flux.reset!(rnn); - bptt = gradient(Wh -> sum(tanh.(rnn.cell.Wi * seq[3] + Wh * - tanh.(rnn.cell.Wi * seq[2] + Wh * - tanh.(rnn.cell.Wi * seq[1] + - Wh * rnn.cell.state0 - + rnn.cell.b) - + rnn.cell.b) - + rnn.cell.b)), - rnn.cell.Wh) - @test grads_seq[rnn.cell.Wh] ≈ bptt[1] - end -end -@testset "BPTT-3D" begin - seq = rand(Float32, (2, 1, 3)) - rnn = RNN(2 => 3) - Flux.reset!(rnn) - grads_seq = gradient(Flux.params(rnn)) do - sum(rnn(seq)[:, :, 3]) - end - Flux.reset!(rnn); - bptt = gradient(rnn.cell.Wh) do Wh - # calculate state 1 - s1 = tanh.(rnn.cell.Wi * seq[:, :, 1] + - Wh * rnn.cell.state0 + - rnn.cell.b) - #calculate state 2 - s2 = tanh.(rnn.cell.Wi * seq[:, :, 2] + - Wh * s1 + - rnn.cell.b) - #calculate state 3 - s3 = tanh.(rnn.cell.Wi * seq[:, :, 3] + - Wh * s2 + - rnn.cell.b) - sum(s3) # loss is sum of state 3 - end - @test grads_seq[rnn.cell.Wh] ≈ bptt[1] -end + cell = LSTMCell(3 => 5) + @test length(Flux.trainables(cell)) == 3 + x = [rand(Float32, 3, 4) for _ in 1:6] + h = zeros(Float32, 5, 4) + c = zeros(Float32, 5, 4) + hnew, cnew = cell(x[1], (h, c)) + @test hnew isa Matrix{Float32} + @test cnew isa Matrix{Float32} + @test size(hnew) == (5, 4) + @test size(cnew) == (5, 4) + test_gradients(cell, x[1], (h, c), loss = (m, x, hc) -> mean(m(x, hc)[1])) + test_gradients(cell, x, (h, c), loss = loss) + + # no initial state same as zero initial state + hnew1, cnew1 = cell(x[1]) + hnew2, cnew2 = cell(x[1], (zeros(Float32, 5), zeros(Float32, 5))) + @test hnew1 ≈ hnew2 + @test cnew1 ≈ cnew2 -@testset "RNN-shapes" begin - @testset for R in [RNN, GRU, LSTM, GRUv3] - m1 = R(3 => 5) - m2 = R(3 => 5) - m3 = R(3, 5) # leave one to test the silently deprecated "," not "=>" notation - x1 = rand(Float32, 3) - x2 = rand(Float32, 3, 1) - x3 = rand(Float32, 3, 1, 2) - Flux.reset!(m1) - Flux.reset!(m2) - Flux.reset!(m3) - @test size(m1(x1)) == (5,) - @test size(m1(x1)) == (5,) # repeat in case of effect from change in state shape - @test size(m2(x2)) == (5, 1) - @test size(m2(x2)) == (5, 1) - @test size(m3(x3)) == (5, 1, 2) - @test size(m3(x3)) == (5, 1, 2) - end + # no bias + cell = LSTMCell(3 => 5, bias=false) + @test length(Flux.trainables(cell)) == 2 end -@testset "multigate" begin - x = rand(6, 5) - res, (dx,) = Flux.withgradient(x) do x - x1, _, x3 = Flux.multigate(x, 2, Val(3)) - sum(x1) + sum(x3 .* 2) - end - @test res == sum(x[1:2, :]) + 2sum(x[5:6, :]) - @test dx == [ones(2, 5); zeros(2, 5); fill(2, 2, 5)] +@testset "LSTM" begin + struct ModelLSTM + lstm::LSTM + h0::AbstractVector + c0::AbstractVector + end + + Flux.@layer :expand ModelLSTM + + (m::ModelLSTM)(x) = m.lstm(x, (m.h0, m.c0)) + + model = ModelLSTM(LSTM(2 => 4), zeros(Float32, 4), zeros(Float32, 4)) + + x = rand(Float32, 2, 3, 1) + h, c = model(x) + @test h isa Array{Float32, 3} + @test size(h) == (4, 3, 1) + @test c isa Array{Float32, 3} + @test size(c) == (4, 3, 1) + test_gradients(model, x, loss = (m, x) -> mean(m(x)[1])) + + x = rand(Float32, 2, 3) + h, c = model(x) + @test h isa Array{Float32, 2} + @test size(h) == (4, 3) + @test c isa Array{Float32, 2} + @test size(c) == (4, 3) + test_gradients(model, x, loss = (m, x) -> mean(m(x)[1])) end -@testset "eachlastdim" begin - x = rand(3, 3, 1, 2, 4) - @test length(Flux.eachlastdim(x)) == size(x, ndims(x)) - @test collect(@inferred(Flux.eachlastdim(x))) == collect(eachslice(x; dims=ndims(x))) - slicedim = (size(x)[1:end-1]..., 1) - res, (dx,) = Flux.withgradient(x) do x - x1, _, x3, _ = Flux.eachlastdim(x) - sum(x1) + sum(x3 .* 3) - end - @test res ≈ sum(selectdim(x, ndims(x), 1)) + 3sum(selectdim(x, ndims(x), 3)) - @test dx ≈ cat(fill(1, slicedim), fill(0, slicedim), - fill(3, slicedim), fill(0, slicedim); dims=ndims(x)) +@testset "GRUCell" begin + function loss(r, x, h) + y = [] + for x_t in x + h = r(x_t, h) + y = vcat(y, [h]) + end + y = stack(y, dims=2) # [D, L] or [D, L, B] + return mean(y.^2) + end + + r = GRUCell(3 => 5) + @test length(Flux.trainables(r)) == 3 + # An input sequence of length 6 and batch size 4. + x = [rand(Float32, 3, 4) for _ in 1:6] + + # Initial State is a single vector + h = randn(Float32, 5) + test_gradients(r, x, h; loss) + + # no initial state same as zero initial state + @test r(x[1]) ≈ r(x[1], zeros(Float32, 5)) + + # Now initial state has a batch dimension. + h = randn(Float32, 5, 4) + test_gradients(r, x, h; loss) + + # The input sequence has no batch dimension. + x = [rand(Float32, 3) for _ in 1:6] + h = randn(Float32, 5) + test_gradients(r, x, h; loss) + + # No Bias + r = GRUCell(3 => 5, bias=false) + @test length(Flux.trainables(r)) == 2 end -@testset "∇eachlastdim" begin - x = rand(3, 3, 1, 2, 4) - x_size = size(x) - y = collect(eachslice(x; dims=ndims(x))) - @test @inferred(Flux.∇eachlastdim(y, x)) == x - ZeroTangent = Flux.Zygote.ZeroTangent - NoTangent = Flux.Zygote.NoTangent - abstract_zeros_vector = [ZeroTangent(), ZeroTangent(), NoTangent(), NoTangent()] - @test @inferred(Flux.∇eachlastdim(abstract_zeros_vector, x)) == zeros(size(x)) - x2 = rand(Float64, x_size[1:end-1]) - x3 = rand(Float64, x_size[1:end-1]) - mixed_vector = [ZeroTangent(), x2, x3, ZeroTangent()] - @test @inferred(Flux.∇eachlastdim(mixed_vector, x)) ≈ cat(zeros(x_size[1:end-1]), - x2, - x3, - zeros(x_size[1:end-1]); dims=ndims(x)) +@testset "GRU" begin + struct ModelGRU + gru::GRU + h0::AbstractVector + end + + Flux.@layer :expand ModelGRU + + (m::ModelGRU)(x) = m.gru(x, m.h0) + + model = ModelGRU(GRU(2 => 4), zeros(Float32, 4)) + + x = rand(Float32, 2, 3, 1) + y = model(x) + @test y isa Array{Float32, 3} + @test size(y) == (4, 3, 1) + test_gradients(model, x) + + # no initial state same as zero initial state + gru = model.gru + @test gru(x) ≈ gru(x, zeros(Float32, 4)) + + # No Bias + gru = GRU(2 => 4, bias=false) + @test length(Flux.trainables(gru)) == 2 + test_gradients(gru, x) end -@testset "Different Internal Matrix Types" begin - R = Flux.Recur(Flux.RNNCell(tanh, rand(5, 3), Tridiagonal(rand(5, 5)), rand(5), rand(5, 1))) - # don't want to pull in SparseArrays just for this test, but there aren't any - # non-square structured matrix types in LinearAlgebra. so we will use a different - # eltype matrix, which would fail before when `W_i` and `W_h` were required to be the - # same type. - L = Flux.Recur(Flux.LSTMCell(rand(5*4, 3), rand(1:20, 5*4, 5), rand(5*4), (rand(5, 1), rand(5, 1)))) - G = Flux.Recur(Flux.GRUCell(rand(5*3, 3), rand(1:20, 5*3, 5), rand(5*3), rand(5, 1))) - G3 = Flux.Recur(Flux.GRUv3Cell(rand(5*3, 3), rand(1:20, 5*2, 5), rand(5*3), Tridiagonal(rand(5, 5)), rand(5, 1))) - - for m in [R, L, G, G3] - - x1 = rand(3) - x2 = rand(3, 1) - x3 = rand(3, 1, 2) - Flux.reset!(m) - @test size(m(x1)) == (5,) - Flux.reset!(m) - @test size(m(x1)) == (5,) # repeat in case of effect from change in state shape - @test size(m(x2)) == (5, 1) - Flux.reset!(m) - @test size(m(x2)) == (5, 1) - Flux.reset!(m) - @test size(m(x3)) == (5, 1, 2) - Flux.reset!(m) - @test size(m(x3)) == (5, 1, 2) - end +@testset "GRUv3Cell" begin + r = GRUv3Cell(3 => 5) + @test length(Flux.trainables(r)) == 4 + x = rand(Float32, 3) + + # Initial State is a single vector + h = randn(Float32, 5) + test_gradients(r, x, h) + + # no initial state same as zero initial state + @test r(x) ≈ r(x, zeros(Float32, 5)) + + # Now initial state has a batch dimension. + h = randn(Float32, 5, 4) + test_gradients(r, x, h) + + # The input sequence has no batch dimension. + x = rand(Float32, 3) + h = randn(Float32, 5) + test_gradients(r, x, h) end -@testset "type matching" begin - x = rand(Float64, 2, 4) - m1 = RNN(2=>3) - @test m1(x) isa Matrix{Float32} # uses _match_eltype, may print a warning - @test m1.state isa Matrix{Float32} - @test (@inferred m1(x); true) - @test Flux.outputsize(m1, size(x)) == size(m1(x)) - - m2 = LSTM(2=>3) - @test m2(x) isa Matrix{Float32} - @test (@inferred m2(x); true) - @test Flux.outputsize(m2, size(x)) == size(m2(x)) - - m3 = GRU(2=>3) - @test m3(x) isa Matrix{Float32} - @test (@inferred m3(x); true) - @test Flux.outputsize(m3, size(x)) == size(m3(x)) - - m4 = GRUv3(2=>3) - @test m4(x) isa Matrix{Float32} - @test (@inferred m4(x); true) - @test Flux.outputsize(m4, size(x)) == size(m4(x)) -end \ No newline at end of file +@testset "GRUv3" begin + struct ModelGRUv3 + gru::GRUv3 + h0::AbstractVector + end + + Flux.@layer :expand ModelGRUv3 + + (m::ModelGRUv3)(x) = m.gru(x, m.h0) + + model = ModelGRUv3(GRUv3(2 => 4), zeros(Float32, 4)) + + x = rand(Float32, 2, 3, 1) + y = model(x) + @test y isa Array{Float32, 3} + @test size(y) == (4, 3, 1) + test_gradients(model, x) + + # no initial state same as zero initial state + gru = model.gru + @test gru(x) ≈ gru(x, zeros(Float32, 4)) +end diff --git a/test/test_utils.jl b/test/test_utils.jl index f9a6b6655f..25a4f1af47 100644 --- a/test/test_utils.jl +++ b/test/test_utils.jl @@ -39,7 +39,7 @@ function test_gradients( test_grad_f = true, test_grad_x = true, compare_finite_diff = true, - loss = mean, + loss = (f, xs...) -> mean(f(xs...)), ) if !test_gpu && !compare_finite_diff @@ -47,27 +47,34 @@ function test_gradients( or CPU AD vs GPU AD.") end + ## Let's make sure first that the forward pass works. + l = loss(f, xs...) + @test l isa Number + if test_gpu + gpu_dev = gpu_device(force=true) + cpu_dev = cpu_device() + xs_gpu = xs |> gpu_dev + f_gpu = f |> gpu_dev + l_gpu = loss(f_gpu, xs_gpu...) + @test l_gpu isa Number + end + if test_grad_x # Zygote gradient with respect to input. - y, g = Zygote.withgradient((xs...) -> loss(f(xs...)), xs...) + y, g = Zygote.withgradient((xs...) -> loss(f, xs...), xs...) if compare_finite_diff # Cast to Float64 to avoid precision issues. f64 = f |> Flux.f64 xs64 = xs .|> Flux.f64 - y_fd, g_fd = finitediff_withgradient((xs...) -> loss(f64(xs...)), xs64...) + y_fd, g_fd = finitediff_withgradient((xs...) -> loss(f64, xs...), xs64...) @test y ≈ y_fd rtol=rtol atol=atol check_equal_leaves(g, g_fd; rtol, atol) end if test_gpu - gpu_dev = gpu_device(force=true) - cpu_dev = cpu_device() - xs_gpu = xs |> gpu_dev - f_gpu = f |> gpu_dev - # Zygote gradient with respect to input on GPU. - y_gpu, g_gpu = Zygote.withgradient((xs...) -> loss(f_gpu(xs...)), xs_gpu...) + y_gpu, g_gpu = Zygote.withgradient((xs...) -> loss(f_gpu, xs...), xs_gpu...) @test get_device(g_gpu) == get_device(xs_gpu) @test y_gpu ≈ y rtol=rtol atol=atol check_equal_leaves(g_gpu |> cpu_dev, g; rtol, atol) @@ -76,31 +83,25 @@ function test_gradients( if test_grad_f # Zygote gradient with respect to f. - y, g = Zygote.withgradient(f -> loss(f(xs...)), f) + y, g = Zygote.withgradient(f -> loss(f, xs...), f) if compare_finite_diff - # Use finite differences gradient as a reference. - # y_fd, g_fd = finitediff_withgradient(f -> loss(f(x)), f) # Cast to Float64 to avoid precision issues. f64 = f |> Flux.f64 ps, re = Flux.destructure(f64) - y_fd, g_fd = finitediff_withgradient(ps -> loss(re(ps)(xs...)), ps) + y_fd, g_fd = finitediff_withgradient(ps -> loss(re(ps), xs...), ps) g_fd = (re(g_fd[1]),) @test y ≈ y_fd rtol=rtol atol=atol check_equal_leaves(g, g_fd; rtol, atol) end if test_gpu - gpu_dev = gpu_device(force=true) - cpu_dev = cpu_device() - xs_gpu = xs |> gpu_dev - f_gpu = f |> gpu_dev - # Zygote gradient with respect to f on GPU. - y_gpu, g_gpu = Zygote.withgradient(f -> loss(f(xs_gpu...)), f_gpu) + y_gpu, g_gpu = Zygote.withgradient(f -> loss(f, xs_gpu...), f_gpu) # @test get_device(g_gpu) == get_device(xs_gpu) @test y_gpu ≈ y rtol=rtol atol=atol check_equal_leaves(g_gpu |> cpu_dev, g; rtol, atol) end end + return true end diff --git a/test/utils.jl b/test/utils.jl index 79eebded49..0236a3d636 100644 --- a/test/utils.jl +++ b/test/utils.jl @@ -251,19 +251,19 @@ end end @testset "Params" begin - m = Dense(10, 5) + m = Dense(10 => 5) @test size.(params(m)) == [(5, 10), (5,)] - m = RNN(10, 5) - @test size.(params(m)) == [(5, 10), (5, 5), (5,), (5, 1)] + m = RNN(10 => 5) + @test size.(params(m)) == [(5, 10), (5, 5), (5,)] # Layer duplicated in same chain, params just once pls. c = Chain(m, m) - @test size.(params(c)) == [(5, 10), (5, 5), (5,), (5, 1)] + @test size.(params(c)) == [(5, 10), (5, 5), (5,)] # Self-referential array. Just want params, no stack overflow pls. r = Any[nothing,m] r[1] = r - @test size.(params(r)) == [(5, 10), (5, 5), (5,), (5, 1)] + @test size.(params(r)) == [(5, 10), (5, 5), (5,)] # Ensure functor explores inside Transpose but not SubArray m = (x = view([1,2,3]pi, 1:2), y = transpose([4 5]pi)) @@ -273,7 +273,7 @@ end @testset "params gradient" begin m = (x=[1,2.0], y=[3.0]); - # Explicit -- was broken by #2054 / then fixed / now broken again on julia v1.11 + # Explicit -- was broken by #2054 gnew = gradient(m -> (sum(norm, Flux.params(m))), m)[1] @test gnew.x ≈ [0.4472135954999579, 0.8944271909999159] @test gnew.y ≈ [1.0] @@ -286,7 +286,7 @@ end end @testset "Precision" begin - m = Chain(Dense(10, 5, relu; bias=false), Dense(5, 2)) + m = Chain(Dense(10 => 5, relu; bias=false), Dense(5 => 2)) x64 = rand(Float64, 10) x32 = rand(Float32, 10) i64 = rand(Int64, 10) @@ -467,10 +467,10 @@ end @test modules[5] === m2 @test modules[6] === m3 - mod_par = Flux.modules(Parallel(Flux.Bilinear(2,2,2,cbrt), Dense(2,2,abs), Dense(2,2,abs2))) + mod_par = Flux.modules(Parallel(Flux.Bilinear(2,2,2,cbrt), Dense(2=>2,abs), Dense(2=>2,abs2))) @test length(mod_par) == 5 - mod_rnn = Flux.modules(Chain(Dense(2,3), BatchNorm(3), LSTM(3,4))) + mod_rnn = Flux.modules(Chain(Dense(2=>3), BatchNorm(3), LSTM(3=>4))) @test length(mod_rnn) == 6 @test mod_rnn[end] isa Flux.LSTMCell