
Do not learn RNN initial state by default #808

Closed

Conversation

@tanhevg commented Jul 9, 2019

Do not learn RNN initial state during backprop by default; leave this as an option
#807

@tanhevg changed the title from "Do not learn RNN initial state during backprop by default" to "Do not learn RNN initial state by default" on Jul 9, 2019
@MikeInnes (Member)

This is a fairly big change, given that it could break people's models, and it isn't obvious to me that not learning by default is better than learning by default. It'd be good to have some discussion of the pros and cons here.

The other important thing is that any new APIs we expose for this kind of thing need to be reasonably compatible with the new Zygote-y way of doing RNNs, which will land soon (#648). Otherwise we're adding this only to remove it again shortly.

@tanhevg (Author) commented Jul 11, 2019

This PR came up after a Slack thread with @oxinabox. My main argument "for" would be that I have never come across a trainable initial state in the literature, including the articles referenced from the Flux documentation (e.g. this one). So I thought that keeping the initial state constant at zero is the default convention in the DL community. Happy to change it the other way around if persuaded.

As for the "Zygote-y way" - this change does not go inside Tracker, so from (briefly) reading the intro to Zygote I thought that the API should still be compatible. Please let me know if that is not the case. Is there a place where I could have a look at how Flux with Zygote is shaping up?

@MikeInnes (Member)

Zygote completely removes any tracking or notion of parameters; it only cares about what happens inside the call to gradient, so that's likely to turn the intuition about this a bit on its head.

The good news is that the Zygote version of this will actually do what you want by default.

m = LSTM(...)
reset!(m) # redundant here but usually you'd want this per sample
gradient(m) do m
  for x in xs
    m(x) ...

Inside gradient, the cell's hidden state parameter is never actually touched (since we copied it to Recur.state), so the gradient is zero. Things are actually different if you do:

m = LSTM(...)
gradient(m) do m
  reset!(m)
  for x in xs
    m(x) ...

Now there's an explicit link between the state of the RNN and the parameter of the cell, so this trains your hidden state.

Unfortunately there's no way to simulate this distinction with Tracker (part of the reason that Zygote exists). So we'll have to wait for that; in the meantime, I think it'd be OK to add a (non-breaking) utility around this if it's needed. But I think you can already get that, effectively, with truncate!, so that might be enough in the short term.
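
Spelled out, the two patterns look roughly like this. This is only a sketch: the sizes, data, and loss below are arbitrary, and it assumes a Flux/Zygote combination that supports taking gradients with respect to the model itself.

using Flux

m  = LSTM(2, 3)                       # arbitrary sizes, for illustration
xs = [rand(Float32, 2) for _ in 1:5]  # arbitrary input sequence

# Pattern 1: reset! outside gradient. The initial state has already been
# copied into Recur.state before differentiation starts, so it picks up
# no gradient.
Flux.reset!(m)
gs1 = gradient(m) do m
  loss = 0f0
  for x in xs
    loss += sum(m(x))
  end
  loss
end

# Pattern 2: reset! inside gradient. The copy from the stored initial state
# into the working state is now part of the differentiated code, so the
# initial state does receive a gradient.
gs2 = gradient(m) do m
  Flux.reset!(m)
  loss = 0f0
  for x in xs
    loss += sum(m(x))
  end
  loss
end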

@tanhevg (Author) commented Jul 13, 2019

But I think you can already get that, effectively, with truncate!, so that might be enough in the short term.

That would indeed be the case had truncate! erased the gradients for init. But it does not; it only clears them for state. Should both of them be erased at the same time?

As a side note, I got a completely wrong idea of what truncate! does from the documentation. The docs do not mention that it affects only the hidden state; I thought that the model weights would also be truncated. I also thought that the gradient would resume accumulating after its history is wiped by truncate!, whereas in reality there is no gradient at all.

I think it'd be OK to add a (non-breaking) utility around this if it's needed.

If Zygote is going to break things anyway, I think there is even less value in maintaining the current behaviour, as opposed to following the mathematical conventions.

@MikeInnes (Member)

Your understanding of what truncate! is supposed to do from the docs seems right to me. It should drop gradients for anything that was used to calculate state, which includes model weights from past time steps as well as init.

What could be happening here is that state === init is a bit of an edge case that we're potentially not handling correctly. If that's the case then it should be easy to fix.
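
As a rough sketch of that intended usage (assuming the Tracker-based Flux of the time; the sizes and loss here are arbitrary), truncate! is called after the loss for one sequence has been computed, so the stored state keeps its value but stops carrying gradient history into the next sequence:

using Flux

m = RNN(2, 3)

function loss(xs, ys)
  l = sum(Flux.mse(m(x), y) for (x, y) in zip(xs, ys))
  Flux.truncate!(m)  # keep the current state value, drop its gradient history
  return l
end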

@tfburns commented Aug 20, 2019

Initial states should be trainable. These are used in some RNN literature to provide initial "input" or "bias" for a sequence.

@oxinabox (Member) commented Aug 20, 2019

@tfburns the question is not whether it should be trainable, but whether it should be trainable by default.
(To which I say it should not; in my experience, training it is the exception rather than the rule.)

@tanhevg force-pushed the tanhevg/rnn_initial_state_807 branch from 814992d to d19988f on September 6, 2019 09:05
@tanhevg closed this on September 6, 2019
@tanhevg deleted the tanhevg/rnn_initial_state_807 branch on September 6, 2019 15:31
@oxinabox (Member) commented Sep 6, 2019

Why is this closed now?

@tanhevg (Author) commented Sep 6, 2019

Accidental force push to my fork; reopening.

@tanhevg reopened this on September 6, 2019
@tanhevg force-pushed the tanhevg/rnn_initial_state_807 branch from d19988f to e3105df on September 6, 2019 16:48
@DhairyaLGandhi (Member)

It would be good to update this to match master, which now uses Zygote.

@MikeInnes (Member)

This is no longer necessary, for the reasons mentioned in #808 (comment). We definitely do need more docs and tooling around RNNs with Zygote to show how to do these things, though.

@AlexLewandowski commented Jul 21, 2020

Zygote completely removes any tracking or notion of parameters; it only cares about what happens inside the call to gradient, so that's likely to turn the intuition about this a bit on its head.

The good news is that the Zygote version of this will actually do what you want by default.

m = LSTM(...)
reset!(m) # redundant here but usually you'd want this per sample
gradient(m) do m
  for x in xs
    m(x) ...

Inside gradient, the cell's hidden state parameter is never actually touched (since we copied it to Recur.state), so the gradient is zero. Things are actually different if you do:

m = LSTM(...)
gradient(m) do m
  reset!(m)
  for x in xs
    m(x) ...

Now there's an explicit link between the state of the RNN and the parameter of the cell, so this trains your hidden state.

Unfortunately there's no way to simulate this distinction with Tracker (part of the reason that Zygote exists). So we'll have to wait for that; in the meantime, I think it'd be OK to add a (non-breaking) utility around this if it's needed. But I think you can already get that, effectively, with truncate!, so that might be enough in the short term.

@MikeInnes Any idea why the following code snippet computes a gradient for the initial state? Of course, I can just change my code to use gradient(m) do m, but I'm curious why this is different. Edit: my only guess is that gradient(Flux.params(rnn)) do explicitly asks for the gradient of every param. I'm just not sure how gradient(m) do m handles things differently.

rnn = Flux.RNN(2, 3)
seq = [rand(2) for i = 1:3]

grads = gradient(Flux.params(rnn)) do
    loss = 0f0
    for s in seq
        loss += sum(rnn(s))
    end
    loss
end

grads[rnn.init]
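
For completeness, with the implicit-parameters style the initial state can be excluded by hand via the usual parameter-freezing trick of deleting an array from the Params collection before taking the gradient. A sketch reusing rnn and seq from above (the init field name follows that snippet's Flux version):

ps = Flux.params(rnn)
delete!(ps, rnn.init)  # rnn.init is no longer tracked as a parameter

grads = gradient(ps) do
    loss = 0f0
    for s in seq
        loss += sum(rnn(s))
    end
    loss
end
# no gradient is collected for rnn.init, since it is not in ps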

@oxinabox (Member)

@AlexLewandowski I suggest discussing this on #807.
