[Feature Request] Use context managers to toggle the recurrent mode of RNN modules? #2562

Open
thomasbbrunner opened this issue Nov 14, 2024 · 7 comments
Labels: enhancement (New feature or request)

@thomasbbrunner (Contributor)

Motivation

In TorchRL, users must manually set the mode of RNN modules using the set_recurrent_mode method. This toggles between processing steps individually and processing entire sequences of steps.

We believe that this approach has some issues:

  1. Requires keeping track of and maintaining two versions of the policy (yes, they share the same weights, but they are still two objects).
  2. Is cumbersome when dealing with large policies with multiple sub-modules (you have to re-implement set_recurrent_mode for the policy).
  3. Seems easy to get wrong for people new to the code.

Proposed Alternative

Can we leverage context managers for this, similar to how tensordict does it with set_interaction_type?

For instance:

input = TensorDict(...)
lstm = LSTMModule(...)
mlp = MLP(...)
policy = TensorDictSequential(lstm, mlp)

# By default, the lstm would not be in recurrent mode.
policy(input)

# Recurrent mode can be activated with a context manager.
with set_recurrent_mode(True):
    policy(input)

Have you considered this approach in the past?

Potential Implementation

The set_recurrent_mode context manager could be implemented in a similar fashion to set_interaction_type:

from typing import Any

# _DecoratorContextManager allows a context manager to also be used as a
# decorator; in recent PyTorch it lives in torch.utils._contextlib (older
# versions expose it via torch.autograd.grad_mode).
from torch.utils._contextlib import _DecoratorContextManager

_RECURRENT_MODE: bool = False


class set_recurrent_mode(_DecoratorContextManager):
    def __init__(self, mode: bool = False) -> None:
        super().__init__()
        self.mode = mode

    def clone(self) -> "set_recurrent_mode":
        # needed so that decorator usage re-creates the context with the same mode
        return type(self)(self.mode)

    def __enter__(self) -> None:
        global _RECURRENT_MODE
        self.prev = _RECURRENT_MODE
        _RECURRENT_MODE = self.mode

    def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
        global _RECURRENT_MODE
        _RECURRENT_MODE = self.prev
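
Because the class derives from _DecoratorContextManager, it could presumably also be used as a decorator. A usage sketch (train_on_batch is a made-up helper; policy and input are the objects from the snippet above):

# usage sketch: both forms toggle the same global flag
with set_recurrent_mode(True):
    policy(input)

@set_recurrent_mode(True)
def train_on_batch(td):
    return policy(td)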

Potential Issues

  1. Unsure of the implications of this in distributed settings (is a module-level global thread- or process-safe?); see the sketch after this list.
  2. Users could still forget to set this mode.
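
One way to address the thread-safety concern (a sketch only, not part of the proposal above) would be to store the flag in a contextvars.ContextVar, which is isolated per thread and per asyncio task; separate processes would still each hold their own copy of the flag.

import contextvars
from contextlib import contextmanager

_RECURRENT_MODE = contextvars.ContextVar("recurrent_mode", default=False)

@contextmanager
def set_recurrent_mode(mode: bool = False):
    # each thread/task sees its own value of the flag
    token = _RECURRENT_MODE.set(mode)
    try:
        yield
    finally:
        _RECURRENT_MODE.reset(token)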

Checklist

  • I have checked that there is no similar issue in the repo (required)
thomasbbrunner added the enhancement (New feature or request) label on Nov 14, 2024
@thomasbbrunner (Contributor, Author)

If there's interest in building something like this, I'd be willing to invest time into it.

@vmoens (Contributor) commented Nov 14, 2024

Yep, we could certainly do this; I don't think anyone would wish to run the RNN in both modes within the same function.
I like CMs, but I know some people don't. The main issue is usually how to make distributed ops aware of the context, but I don't think that's a problem here.

Wanna work on this, or should I?

What's your opinion on these use cases:

policy = make_policy()

policy(input)  # what's the default?
policy.set_recurrent_mode(False)  # should this raise a deprecation warning?
with set_recurrent_mode(True):
    policy(input)  # does the context manager override the mode set in the line above?

with set_recurrent_mode(False):
    policy.set_recurrent_mode(True)  # does this work?
    policy(input)

If the old API is to be deprecated, it's fine if the CM overrides the internal mode; but if we want both APIs to coexist, it can be tricky to decide which one should prevail.

@thomasbbrunner (Contributor, Author) commented Nov 15, 2024

I like CMs but I know some people don't.

The only thing I don't particularly like is the setting of a global variable. Would it make sense to have a lock for it?

The main issue is usually how to make distributed ops know what the context is, but I don't think it's a problem here.

Yes, this would definitely not work in that setting. It is only really applicable to local operations, and, to be fair, the context where this is active should ideally be short-lived.

The use-case I imagine is something like this:

for _ in range(num_steps):
    # Collect batch
    td = env.rollout(100, policy)

    # Train on batch
    with set_recurrent_mode(True):
        loss = loss_module(td)
    
    loss.backward()
    ...

To your questions:

policy(input) # what's the default?

I'd say we should keep the current default (recurrent mode off).

policy.set_recurrent_mode(False) # should raise a deprec warning?

I'd say we could support both approaches to setting the recurrent mode. The context manager should be used in short-lived use-cases, while the method is better suited to distributed systems.

with set_recurrent_mode(True):
   policy(input)  # does the context manager override the mode set in the line above?

I'd say that the context takes precedence over the default recurrent mode.

with set_recurrent_mode(False):
    policy.set_recurrent_mode(True)  # does this work?
    policy(input)

Uuuh, tricky. Based on the previous statement ("context takes precedence over the default recurrent mode") I'd say that this would run, but the context would still take precedence.

I'd suggest that the set_recurrent_mode method sets the default recurrent mode, which can then be overridden by the context.

Maybe this should also be accompanied by a change to the interface: set_recurrent_mode --> set_default_recurrent_mode.
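
For illustration, a hypothetical resolution rule (the sentinel and helper are made up for this sketch, not existing TorchRL API):

# no context manager active -> fall back to the module's default mode
_UNSET = object()
_RECURRENT_MODE = _UNSET  # set/reset by the set_recurrent_mode context manager

def _effective_recurrent_mode(default_mode: bool) -> bool:
    # the context manager, when active, takes precedence over the default
    if _RECURRENT_MODE is _UNSET:
        return default_mode
    return _RECURRENT_MODE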

Wanna work on this, or should I?

I'd be interested, but it might take some time due to other high-priority work, so feel free to take it over!

@vmoens (Contributor) commented Nov 15, 2024

The only thing I don't particularly like is the setting of a global variable. Would make sense to have a lock for it?

Sure, whatever works!

@vmoens (Contributor) commented Nov 15, 2024

@thomasbbrunner one thing I learned on the PyTorch team is that having two similar APIs to do one thing should be avoided when possible.
I like the CM better, so I'm inclined to adopt it and deprecate the other. If you think it's a viable way forward, I can make a PR pretty quickly.

@thomasbbrunner (Contributor, Author)

@vmoens

one thing I learned on the PyTorch team is that having two similar APIs to do one thing should be avoided when possible.

Definitely agree with that! I also prefer the CM approach, so I'd be ok with deprecating the other.

At the same time, I feel like the two APIs serve slightly different purposes. I guess it's similar to the relationship between with torch.no_grad(): and the requires_grad argument of torch.Tensor and Parameter.

Maybe for some use-cases (like distributed setups) it would be beneficial to set a default recurrent mode. I don't think the set_recurrent_mode method is ideal for this; I'd argue that a recurrent_mode argument in the constructor would work better. Either way, it's definitely something that could be implemented later if needed.
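
A minimal sketch of the constructor-default idea (the class name is hypothetical, not existing TorchRL API):

class RecurrentPolicy:
    def __init__(self, recurrent_mode: bool = False) -> None:
        # per-module default; a set_recurrent_mode context could still override it
        self.recurrent_mode = recurrent_mode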

@thomasbbrunner (Contributor, Author)

Also, feel free to work on this, I won't have much time this week!
