[Feature Request] Use context managers to toggle the recurrent mode of RNN modules? #2562
Comments
If there's interest in building something like this, I'd be willing to invest time into it.
Yep, we could perfectly do this; I don't think anyone would wish to run the RNN in both modes in the same function. Wanna work on this, or should I? What's your opinion on these use cases:

```python
policy = make_policy()
policy(input)  # what's the default?

policy.set_recurrent_mode(False)  # should raise a deprecation warning?

with set_recurrent_mode(True):
    policy(input)  # does the decorator override the mode set in the line above?
    with set_recurrent_mode(False):
        policy.set_recurrent_mode(True)  # does this work?
        policy(input)
```

If the old API is to be deprecated, it's fine if the CM overrides the internal mode, but if we want both APIs to coexist it can be tricky to decide who should prevail.
The only thing I don't particularly like is the setting of a global variable. Would it make sense to have a lock for it?
Yes, this would def not work in this setting. This is only really applicable for local operations and, tbf, ideally the context where this is active should be short. The use case I imagine is something like this:

```python
for _ in range(num_steps):
    # Collect batch
    td = env.rollout(100, policy)
    # Train on batch
    with set_recurrent_mode(True):
        loss = loss_module(td)
        loss.backward()
    ...
```

To your questions:
1. I'd say we should keep the current default (recurrent mode off).
2. I'd say we could support both approaches of setting the recurrent mode. The context should be used in short-lived use cases, while the method is better in the case of distributed systems.
3. I'd say that the context takes precedence over the default recurrent mode.
4. Uuuh, tricky. Based on the previous statement ("context takes precedence over the default recurrent mode"), I'd say that this would run, but the context would still take precedence (see the sketch below).

Maybe this should also be accompanied by a change to the interface.
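To make these precedence rules concrete, here's a sketch of the intended semantics (this is the API under discussion, not an existing interface; `data` is a placeholder input):

```python
policy = make_policy()
policy(data)                         # default: recurrent mode off (step-by-step)

policy.set_recurrent_mode(True)      # the method sets the module's default mode

with set_recurrent_mode(False):      # the context takes precedence while active
    policy(data)                     # step-by-step, despite the method call above
    policy.set_recurrent_mode(True)  # runs, but only updates the default
    policy(data)                     # still step-by-step: the context still wins

policy(data)                         # back to the method-set default (sequence mode)
```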
I'd be interested, but it might take some time due to some other high-prio work. So feel free to take it over!
Sure, whatever works!
@thomasbbrunner one thing I learned in the PyTorch team is that having 2 similar APIs to do one thing should be avoided when possible.
Definitely agree with that! I also prefer the CM approach, so I'd be ok with deprecating the other. At the same time, I feel like the two APIs serve slightly different purposes. Maybe for some use cases (like distributed setups) it would be beneficial to set a default recurrent mode.
Also, feel free to work on this, I won't have much time this week!
Motivation

In TorchRL, users must manually set the mode of RNN modules using the `set_recurrent_mode` method. This toggles between processing steps individually and processing entire sequences of steps. We believe that this approach has some issues (e.g., having to explicitly call `set_recurrent_mode` for the policy).

Proposed Alternative
Can we leverage context managers for this? Similar to how `tensordict` does with `set_interaction_type`. For instance:
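A sketch of the proposed usage (note that `set_recurrent_mode` as an importable context manager is the API being proposed here, not an existing one, and the import path is a guess):

```python
from torchrl.modules import set_recurrent_mode  # hypothetical import path

policy = make_policy()

# Collection: RNN modules process one step at a time (current default)
td = env.rollout(100, policy)

# Training: within the context, RNN modules process whole sequences
with set_recurrent_mode(True):
    loss = loss_module(td)
loss.backward()
```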
Have you considered this approach in the past?
Potential Implementation
The
set_recurrent_mode
could be implemented in a similar fashion to theset_interaction_type
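A minimal sketch of one possible implementation, patterned after the global-state approach (the `_RECURRENT_MODE` variable and `recurrent_mode()` helper are hypothetical names; `tensordict`'s actual internals may differ):

```python
from contextlib import ContextDecorator

_RECURRENT_MODE: bool | None = None  # hypothetical module-level default


def recurrent_mode() -> bool | None:
    """Hypothetical helper an RNN module could query in its forward()."""
    return _RECURRENT_MODE


class set_recurrent_mode(ContextDecorator):
    """Toggles the recurrent mode within a `with` block (or as a decorator)."""

    def __init__(self, mode: bool = True) -> None:
        self.mode = mode
        self._previous: bool | None = None

    def __enter__(self) -> "set_recurrent_mode":
        global _RECURRENT_MODE
        self._previous = _RECURRENT_MODE
        _RECURRENT_MODE = self.mode
        return self

    def __exit__(self, exc_type, exc_value, traceback) -> None:
        global _RECURRENT_MODE
        _RECURRENT_MODE = self._previous  # restore the previous mode on exit
```

As noted in the discussion above, a plain module-level variable is not safe under concurrency; a `threading.local` or `contextvars.ContextVar` would be a natural refinement for such settings.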
Potential Issues
Checklist