Skip to content

Commit

Permalink
add tests
Browse files Browse the repository at this point in the history
finish RNNCell

RNN rework

LSTMCell

LSTM

more work

gru

extended testing

runtests

add tests

finish RNNCell

RNN rework

LSTMCell

LSTM

more work

gru

extended testing

reset! deprecation

fix test

unbreak l2 test

fix tests

fixes
  • Loading branch information
CarloLucibello committed Oct 21, 2024
1 parent 31dccd1 commit 73dae52
Show file tree
Hide file tree
Showing 15 changed files with 957 additions and 676 deletions.
2 changes: 0 additions & 2 deletions perf/recurrent.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,10 @@ Flux.@functor RNNWrapper

# Need to specialize for RNNWrapper.
fw(r::RNNWrapper, X::Vector{<:AbstractArray}) = begin
Flux.reset!(r.rnn)
[r.rnn(x) for x in X]
end

fw(r::RNNWrapper, X) = begin
Flux.reset!(r.rnn)
r.rnn(X)
end

Expand Down
1 change: 1 addition & 0 deletions src/Flux.jl
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ Optimisers.base(dx::Zygote.Grads) = error("Optimisers.jl cannot be used with Zyg

export Chain, Dense, Embedding, EmbeddingBag,
Maxout, SkipConnection, Parallel, PairwiseFusion,
RNNCell, LSTMCell, GRUCell, GRUv3Cell,
RNN, LSTM, GRU, GRUv3,
SamePad, Conv, CrossCor, ConvTranspose, DepthwiseConv,
AdaptiveMaxPool, AdaptiveMeanPool, GlobalMaxPool, GlobalMeanPool, MaxPool, MeanPool,
Expand Down
5 changes: 5 additions & 0 deletions src/deprecations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -167,3 +167,8 @@ end
# where `loss_mxy` accepts the model as its first argument.
# """
# ))

function reset!(x)
Base.depwarn("reset!(m) is deprecated. You can remove this call as it is no more needed.", :reset!)
return x
end
Loading

0 comments on commit 73dae52

Please sign in to comment.