Backprop through time #648
Comments
What about the `train_step!` proposed at #607 (comment)? I'm trying to translate some PyTorch code that does BPTT, and in the code I'm translating (ENAS-Pytorch) they put an explicit loop inside their forward function to do the time steps (35 time steps in this case). I suppose that's possible with Flux's `train!` function: you could put the `train!` call inside a loop that counts up to the number of time steps you want for BPTT and give it a batch of data in each iteration. Would that work?
Yeah, right now the Flux approach to this is essentially the same as PyTorch.
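A minimal sketch of that pattern, assuming the implicit-parameter `params`/`train!` API and a stateful recurrent model; the sizes, the toy data, and the 35-step window are illustrative:

```julia
using Flux

m   = Chain(RNN(2, 5), Dense(5, 1))         # illustrative stateful recurrent model
ps  = Flux.params(m)
opt = ADAM()

seq_x = [rand(Float32, 2) for _ in 1:700]   # toy sequence data
seq_y = [rand(Float32, 1) for _ in 1:700]

# loss over one window of time steps; the time loop lives inside the loss,
# so a single gradient call backprops through the whole window
loss(xs, ys) = sum(Flux.mse(m(x), y) for (x, y) in zip(xs, ys))

window = 35                                 # BPTT truncation length
Flux.reset!(m)                              # fresh hidden state at the start of the sequence
for chunk in Iterators.partition(collect(zip(seq_x, seq_y)), window)
    xs, ys = first.(chunk), last.(chunk)
    # one train! call per window: gradients stop at the window boundary,
    # but the hidden state carried by `m` flows into the next window
    Flux.train!(loss, ps, [(xs, ys)], opt)
end
```

Because each window is a single `gradient` call, you get truncated BPTT: gradients cover the window, and the carried state links windows only in the forward direction.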
Hi, where's the code implementing BPTT? (In recurrent.jl?) I wrote some code doing BPTT in plain Julia and would like to compare the intermediate results with Flux. The fprop is easy to understand, but I didn't find where the bprop is implemented. Could someone point me to where to check? Thanks. (A Google search landed me on this page, which seems the best place to ask.)
There is no code directly in Flux.jl implementing BPTT. This is "just" calling `gradient` on a loss that runs the model over the sequence; the AD handles the backward pass through the loop.
Does the `gradient` function accumulate gradients through time somehow? In the seq2one case, how does the function know how to loop? Ideally, I would like to be able to intercept and verify the accumulation process. For comparison, a manually written BPTT has something like the following (from karpathy's code): `for t in reversed(xrange(len(inputs))): ...`, which accumulates into gradients like `dWhh`. Are there variables similar to `dWhh` kept somewhere?
Yes, the AD backend, Zygote, will handle gradient accumulation through the loop. See this comment for how you can implement many-to-many or many-to-one models. Also check the recurrent docs.
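For instance, a seq2one loss can just run the model over the whole sequence and take the loss on the last output; one `gradient` call then replays the loop in reverse and accumulates the recurrent-weight gradients (the analogue of `dWhh`) across all time steps. A minimal sketch with made-up sizes and toy data:

```julia
using Flux

m = Chain(RNN(4, 8), Dense(8, 1))        # illustrative model

xs = [rand(Float32, 4) for _ in 1:10]    # toy 10-step sequence
y  = rand(Float32, 1)

# sequence-to-one: run the model over every step, take the loss on the last output
function seq2one_loss(xs, y)
    outs = [m(x) for x in xs]            # the time loop Zygote differentiates through
    Flux.mse(outs[end], y)
end

Flux.reset!(m)                           # start from a fresh hidden state
gs = gradient(() -> seq2one_loss(xs, y), Flux.params(m))
# gs now holds the accumulated gradients, including those of the recurrent weights
```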
Thanks for your reply, darsnack. I posted another question in issue 144, as that one seems more recent. Do you mind taking a look?
Closing as old.
Continuing our series "cool things we can't have yet", and inspired by this comment I was thinking about how we'll expose BPTT. Currently, given a forward pass like this:
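(For concreteness, a minimal illustrative sketch; the hand-rolled cell and the weight names `Wxh`, `Whh`, `Wy` are just placeholders:)

```julia
# a hand-rolled RNN forward pass carrying a hidden state h through the loop
Wxh, Whh, Wy = randn(Float32, 5, 3), randn(Float32, 5, 5), randn(Float32, 1, 5)

xs = [rand(Float32, 3) for _ in 1:100]   # toy sequence
ys = [rand(Float32, 1) for _ in 1:100]

function forward(Wxh, Whh, Wy, xs, ys)
    h = zeros(Float32, 5)
    l = 0f0
    for (x, y) in zip(xs, ys)
        h = tanh.(Wxh * x .+ Whh * h)    # one time step
        l += sum(abs2, Wy * h .- y)      # per-step loss
    end
    return l
end
```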
If we don't want to backprop over the whole sequence at once (`gradient` outside the loop) or over only a single step at a time (`gradient` inside the loop), then we need to split the loop as follows:
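(Again a sketch, reusing the toy cell and data above; `run_chunk` and the chunking are illustrative. The inner loop lives inside `gradient`, so each backward pass covers `n` steps, while the outer loop only carries the hidden state forward:)

```julia
using Flux                                   # `gradient` comes from Zygote

# same cell as above, run over one chunk of the sequence,
# returning the loss for that chunk and the carried hidden state
function run_chunk(Wxh, Whh, Wy, h, chunk)
    l = 0f0
    for (x, y) in chunk
        h = tanh.(Wxh * x .+ Whh * h)
        l += sum(abs2, Wy * h .- y)
    end
    return l, h
end

n = 5                                        # backprop through n steps at a time
h = zeros(Float32, 5)
for chunk in Iterators.partition(collect(zip(xs, ys)), n)
    global h
    # one gradient call per chunk: BPTT over n steps, truncated at the chunk boundary
    gs = gradient((a, b, c) -> run_chunk(a, b, c, h, chunk)[1], Wxh, Whh, Wy)
    # re-run the forward pass to advance the carried state (no gradient across chunks)
    h = run_chunk(Wxh, Whh, Wy, h, chunk)[2]
    # ...apply an update from gs here
end
```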
An alternative to this is to just expose primitives that let us fiddle with time steps directly. Consider:
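(Usage might look roughly like this; a sketch only. `record` is the hypothetical primitive being proposed here, not an existing Flux or Zygote function, its signature is made up, and `m` stands for some stateful recurrent model:)

```julia
# hypothetical: `record(f, n)` evaluates the closure f, keeps the pullbacks from
# the last n calls at this site, and chains them together on the backward pass
for (x, y) in zip(xs, ys)
    gs = gradient(Flux.params(m)) do
        ŷ = record(5) do
            m(x)                    # stateful model evaluated one step at a time
        end
        Flux.mse(ŷ, y)
    end
    # each per-step gradient now reaches back through the previous 5 steps
end
```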
Alright, bear with me here. This is written as if we were backprop-ing only across a single time step at a time, but with model evaluation wrapped in `record`. The idea is that `record` will log 5 previous backpropagators for the closure it is passed, and then chain these together for the backwards pass, which means we can actually backpropagate through `n` previous iterations of the loop -- i.e. backpropagation through time.

What's cool about this is that it makes BPTT completely orthogonal to the structure of the forward pass. The recorder can equally well be set up to backprop the last `n` steps each iteration (sliding window BTTF) or only every `n`th iteration (normal BTTF), or anything in between, and this can be set up differently for different parts of the model. It also isn't specific to any particular RNN implementation, e.g. this will work even though we have to backprop through `h` over loop iterations:
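(Still a sketch with `record` hypothetical: a hand-written cell whose state `h` is threaded through the loop, where the recorder would have to chain pullbacks through `h` across iterations:)

```julia
# hypothetical sketch: backprop must flow through the carried state h,
# not just through the model's parameters
h = zeros(Float32, 5)
for (x, y) in zip(xs, ys)
    global h
    gs = gradient(Flux.params(Wxh, Whh, Wy)) do
        # one step of a hand-written RNN cell, wrapped so the recorder can
        # link this step's pullback to previous steps' pullbacks via h
        h, ŷ = record(5) do
            hnew = tanh.(Wxh * x .+ Whh * h)
            (hnew, Wy * hnew)
        end
        sum(abs2, ŷ .- y)
    end
    # ...apply an update from gs here
end
```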
The main question is whether this is actually going to be intuitive for people (who aren't travelling at 88mph). If it looks weird right now I think that's partly because we're not used to using `gradient` this way, so getting used to that will make the extra feature easier to reason about. At least for sliding windows, I think it's strictly better than flow-based alternatives.