feature requests for RNNs #2514
Would it be possible to add to the list the option to use different initializers for the input matrix and the recurrent matrix? This is provided by both Keras/TF and Flax. It should be as straightforward as

```julia
function RNNCell((in, out)::Pair, σ = relu;
                 kernel_init = glorot_uniform,
                 recurrent_kernel_init = glorot_uniform,
                 bias = true)
    Wi = kernel_init(out, in)              # input-to-hidden weights
    U = recurrent_kernel_init(out, out)    # hidden-to-hidden weights
    b = create_bias(Wi, bias, size(Wi, 1))
    return RNNCell(σ, Wi, U, b)
end
```

I can also open a quick PR on this if needed.
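For reference, a call with the proposed keywords could look like this. This is only a sketch: `kernel_init` and `recurrent_kernel_init` are the names proposed above and do not exist in released Flux, while `glorot_uniform` and `orthogonal` are existing Flux initializers.

```julia
using Flux

# Hypothetical usage of the proposed keywords: Glorot for the input-to-hidden
# matrix, orthogonal for the hidden-to-hidden matrix.
cell = RNNCell(4 => 8, tanh;
               kernel_init = Flux.glorot_uniform,
               recurrent_kernel_init = Flux.orthogonal)

x = rand(Float32, 4)      # one input vector
h = zeros(Float32, 8)     # initial hidden state
h_new = cell(x, h)        # one step of the recurrence
```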
Yes! PR welcome.
Following up on this, should we also have an option to choose the init for the bias?
We don't do it for feedforward layers; if someone wants a non-zero bias, they can just change it manually in the constructor.
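Concretely, "change it manually in the constructor" can be read as passing an explicit bias vector through the existing `bias` keyword. This is a sketch and assumes `bias` accepts an `AbstractVector`, as it does for `Dense`:

```julia
using Flux

cell = RNNCell(4 => 8)                                   # bias defaults to zeros
cell_nonzero = RNNCell(4 => 8; bias = fill(0.1f0, 8))    # explicit non-zero bias vector
```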
Hey, a couple of questions on this feature request again. Would something simple like

```julia
function initialstates(rnn::RNNCell; init_state = zeros)
    state = init_state(size(rnn.Wh, 2))
    return state
end

function initialstates(lstm::LSTMCell; init_state = zeros, init_cstate = zeros)
    state = init_state(size(lstm.Wh, 2))
    cstate = init_cstate(size(lstm.Wh, 2))
    return state, cstate
end
```

suffice, or were you looking for something more? Maybe more control over the type would be needed.
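For what it's worth, with these definitions the call sites could look like the following. This is a sketch assuming the signatures above; note that the default `zeros` produces a `Float64` vector, which is exactly the type-control point mentioned here.

```julia
using Flux

rnn = RNNCell(4 => 8)
lstm = LSTMCell(4 => 8)

h0 = initialstates(rnn)                                    # zero vector of length 8 (Float64)
h0_glorot = initialstates(rnn; init_state = Flux.glorot_uniform)
h0_lstm, c0_lstm = initialstates(lstm)                     # hidden state and cell state
```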
I would just have
If different initializations are needed, we could add an
So this way we would simply do

```julia
function (rnn::RNNCell)(inp::AbstractVecOrMat)
    state = initialstates(rnn)
    return rnn(inp, state)
end
```

to keep compatibility with the current version, right? I think your point is good; additionally, no other library provides a specific
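Presumably the `LSTMCell` analogue would do the same with both states. A sketch under the same assumptions; the tuple-of-states calling convention and the eltype caveat from above still apply:

```julia
function (lstm::LSTMCell)(inp::AbstractVecOrMat)
    state, cstate = initialstates(lstm)
    return lstm(inp, (state, cstate))
end

# With these methods a cell can be stepped without threading the state explicitly:
rnn = RNNCell(4 => 8)
x = rand(Float32, 4)
h = rnn(x)    # equivalent to rnn(x, initialstates(rnn))
```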
Trying to tackle adding

```julia
struct TestRNN{A, B}
    cells::A
    dropout_layer::B
end

Flux.@layer :expand TestRNN

function TestRNN((in_size, out_size)::Pair;
                 n_layers::Int = 1,
                 dropout::Float64 = 0.0,
                 kwargs...)
    cells = []
    for i in 1:n_layers
        # first layer maps in_size => out_size, the rest map out_size => out_size
        tin_size = i == 1 ? in_size : out_size
        push!(cells, RNNCell(tin_size => out_size; kwargs...))
    end
    if dropout > 0.0
        dropout_layer = Dropout(dropout)
    else
        dropout_layer = nothing
    end
    return TestRNN(cells, dropout_layer)
end

function (rnn::TestRNN)(inp, state)
    @assert ndims(inp) == 2 || ndims(inp) == 3
    output = []
    num_layers = length(rnn.cells)
    for inp_t in eachslice(inp, dims = 2)          # iterate over timesteps
        new_states = []
        for (idx_cell, cell) in enumerate(rnn.cells)
            new_state = cell(inp_t, state[:, idx_cell])
            new_states = vcat(new_states, [new_state])
            inp_t = new_state
            # apply dropout between layers, but not after the last one
            if rnn.dropout_layer isa Dropout && idx_cell < num_layers
                inp_t = rnn.dropout_layer(inp_t)
            end
        end
        state = stack(new_states)                  # (out_size, n_layers)
        output = vcat(output, [inp_t])
    end
    output = stack(output, dims = 2)               # (out_size, timesteps)
    return output, state
end
```
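For what it's worth, the sketch above can be exercised on dummy data roughly as follows. This assumes the 2-D, unbatched case, since the forward pass indexes the state as a matrix with one column per layer; the shapes and initial state here are illustrative only.

```julia
using Flux

rnn = TestRNN(4 => 8; n_layers = 2, dropout = 0.1)

x = rand(Float32, 4, 10)          # (features, timesteps), unbatched
state0 = zeros(Float32, 8, 2)     # one zero column per layer
output, state = rnn(x, state0)

size(output)                      # (8, 10): last layer's output at every timestep
size(state)                       # (8, 2): final state of each layer
```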
After the redesign in #2500, here is a list of potential improvements for recurrent layers and recurrent cells:

- add an option in constructors to have trainable initial state: let's keep it simple (and also follow pytorch) by not having this
- `Bidirectional` (see Bidirectional RNN support for RNN layers, #1790); a rough sketch is included after this list
- an `initialstates` function. It could be useful in the LSTM case where the initial state is more complicated (two vectors).
- `Recur` (but maybe this is confusing) or `Recurrence` as in Lux.
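For the `Bidirectional` item, here is a minimal sketch of what such a wrapper could look like. The name `Bidirectional`, the reverse-along-`dims = 2` time convention, and feature-wise concatenation are assumptions for illustration, not an agreed design; it also assumes the wrapped layers map a whole sequence of shape (features, timesteps[, batch]) to an output sequence when called with a single argument.

```julia
using Flux

# Hypothetical wrapper: one recurrent layer runs forward in time, a second one
# runs over the time-reversed input, and their features are concatenated.
struct Bidirectional{F, B}
    forward::F
    backward::B
end

Flux.@layer Bidirectional

function (m::Bidirectional)(x::AbstractArray)
    yf = m.forward(x)
    yb = reverse(m.backward(reverse(x; dims = 2)); dims = 2)
    return vcat(yf, yb)    # concatenate along the feature dimension
end

# Example (assuming RNN(4 => 8) maps (4, timesteps) to (8, timesteps)):
birnn = Bidirectional(RNN(4 => 8), RNN(4 => 8))
x = rand(Float32, 4, 10)
y = birnn(x)               # expected size (16, 10)
```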