Do not learn RNN initial state by default #808
Conversation
This is a fairly big change, given that it could break people's models, and it isn't obvious to me that not learning by default is better than learning by default. It'd be good to have some discussion of the pros and cons here. The other important thing is that any new APIs we expose for this kind of thing need to be reasonably compatible with the new Zygote-y way of doing RNNs, which will land soon (#648). Otherwise we're adding this just to remove it again soon.
This PR came up after a Slack thread with @oxinabox. My main argument "for" would be that I have never come across a trainable initial state in the literature, including the articles referenced from the Flux documentation (e.g. this one). So I thought that keeping the initial state constant at zero is the default convention in the DL community. Happy to change it the other way around, if persuaded. As for the "Zygote-y way": this change does not go inside …
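For readers skimming the thread, here is a small standalone sketch, not Flux's actual implementation, of the distinction being debated: an initial state that is always reset to a fixed zero vector versus one stored as a parameter that the optimiser updates. The types and fields below are invented purely for illustration.

```julia
# Schematic only: two toy wrappers that differ solely in how the initial
# hidden state is treated. Names and fields are illustrative, not Flux's.

struct FixedInitRNN{C}
    cell::C
    state::Vector{Float32}   # always reset to zeros; never trained
end

struct LearnedInitRNN{C}
    cell::C
    init::Vector{Float32}    # treated as a trainable parameter in this convention
    state::Vector{Float32}
end

reset!(m::FixedInitRNN)   = (m.state .= 0f0; m)
reset!(m::LearnedInitRNN) = (m.state .= m.init; m)
```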
Zygote completely removes any tracking or notion of parameters; it only cares about what happens inside the call to `gradient`. The good news is that the Zygote version of this will actually do what you want by default.

```julia
m = LSTM(...)
reset!(m) # redundant here but usually you'd want this per sample
gradient(m) do m
  for x in xs
    m(x) ...
```

Inside `gradient`:

```julia
m = LSTM(...)
gradient(m) do m
  reset!(m)
  for x in xs
    m(x) ...
```

Since now there's an explicit link between the state of the RNN and the parameter of the cell, this trains your hidden state. Unfortunately there's no way to simulate this distinction with Tracker (part of the reason that Zygote exists). So we'll have to wait for that; in the meantime, I think it'd be OK to add a (non-breaking) utility around this if it's needed. But I think you can already get that, effectively, with …
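The suggestion truncated above presumably points at excluding the initial state from the parameter collection by hand. A minimal sketch of that idea for a Zygote-based Flux; the `delete!`-on-`Params` idiom and the `rnn.init` field are assumptions that depend on the Flux version in use:

```julia
using Flux

rnn = Flux.RNN(2, 3)

# Collect the usual trainable parameters, then drop the initial-state array so
# the optimiser never updates it. `rnn.init` is assumed to be the field holding
# the initial hidden state on the Recur wrapper in this Flux version.
ps = Flux.params(rnn)
delete!(ps, rnn.init)

# `ps` can now be passed to `gradient`/`train!` as usual; the initial state
# stays fixed because it is no longer part of the tracked parameter set.
```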
That would indeed be the case had … As a side note, I got a completely wrong idea of what …
If Zygote will break things anyway, I think there is even less value in maintaining the current behaviour, as opposed to maintaining the mathematical conventions.
Your understanding of what … What could be happening here is that …
Initial states should be trainable. These are used in some RNN literature to provide initial "input" or "bias" for a sequence.
@tfburns the question is not whether it should be trainable, but rather whether it should be trainable by default.
Force-pushed from 814992d to d19988f
Why is this closed now?
Accidental force push to my fork; reopening.
Force-pushed from d19988f to e3105df
Would be good to update it as per master with Zygote.
This is no longer necessary, for the reasons mentioned in #808 (comment). We definitely do need more docs and tooling around RNNs with Zygote to show how to do these things, though.
@MikeInnes Any idea why the following code snippet computes a gradient for the initial state? Of course, I can just change my code to use …

```julia
rnn = Flux.RNN(2, 3)
seq = [rand(2) for i = 1:3]

grads = gradient(Flux.params(rnn)) do
    loss = 0f0
    for s in seq
        loss += sum(rnn(s))
    end
    loss
end

grads[rnn.init]
```
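For context, a plausible explanation, offered here as an assumption about the Flux version in use rather than a confirmed answer: in the Flux versions of that era, the `Recur` wrapper was constructed with its `init` and `state` fields pointing at the same array, so the gradient accumulated for the evolving state is keyed by the very array held in `rnn.init`, even though `reset!` was never called inside the gradient.

```julia
using Flux

# Hypothetical check, assuming a Flux version whose Recur wrapper exposes
# `init` and `state` fields and initialises them to the same array.
rnn = Flux.RNN(2, 3)
rnn.state === rnn.init  # expected to be `true` for a freshly built model,
                        # which is why a gradient ends up recorded for rnn.init
```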
@AlexLewandowski I suggest discussing this on #807.
Do not learn RNN initial state during backprop by default; leave this as an option #807