[BUG] SliceSampler should return unique IDs when sampling multiple times from the same trajectory #2588

dmelcer9 opened this issue Nov 20, 2024 · 4 comments

@dmelcer9

Describe the bug

When using SliceSampler with strict_length=False, the documentation recommends using split_trajectories. However, if two samples from the same episode end up next to each other, this produces the wrong output, because consecutive samples can carry the same trajectory_key even though they are logically independent.

To Reproduce

import torch
from tensordict import TensorDict
from torchrl.collectors.utils import split_trajectories
from torchrl.data import ReplayBuffer, LazyTensorStorage, SliceSampler

rb = ReplayBuffer(
    storage=LazyTensorStorage(max_size=1000),
    sampler=SliceSampler(slice_len=5, traj_key="episode", strict_length=False),
)

ep_1 = TensorDict(
    {"obs": torch.arange(100),
     "episode": torch.zeros(100)},
    batch_size=[100],
)
ep_2 = TensorDict(
    {"obs": torch.arange(4),
     "episode": torch.ones(4)},
    batch_size=[4],
)
rb.extend(ep_1)
rb.extend(ep_2)

s = rb.sample(50)
t = split_trajectories(s, trajectory_key="episode")

split_trajectories returns nonsensical results when the trajectory_key entries contain non-contiguous duplicates.
Even if it handled that case, there would still be a bug:

When SliceSampler is drawing from relatively few trajectories, there will be situations where multiple slices of the same trajectory are returned next to each other:

episode 0  0  0  0  0  0  0  0  0  0...
obs     2  3  4  5  6  41 42 43 44 45...
        |-1st slice-|  |-2nd slice--|

However, split_trajectories will see that episode is the same for both slices and incorrectly combine them into one longer slice.
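
For concreteness, here is a minimal sketch (using the values from the diagram above) of how the two slices get fused, because the episode id alone cannot tell them apart:

import torch
from tensordict import TensorDict
from torchrl.collectors.utils import split_trajectories

# Two logically independent 5-step slices of episode 0, laid out back to back,
# exactly as SliceSampler can return them when sampling with replacement.
two_slices = TensorDict(
    {"obs": torch.tensor([2, 3, 4, 5, 6, 41, 42, 43, 44, 45]),
     "episode": torch.zeros(10)},
    batch_size=[10],
)

t = split_trajectories(two_slices, trajectory_key="episode")
print(t["obs"])
# Expected: two trajectories of length 5; instead everything comes back as one
# trajectory of length 10, since the episode id never changes at the boundary.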

Expected behavior

SliceSampler should add an additional key to its returned dict to distinguish samples, at least when strict_length=False:

episode 0  0  0  0  0  0  0  0  0  0...
obs     2  3  4  5  6  41 42 43 44 45...
slice   0  0  0  0  0  1  1  1  1  1
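
With a per-slice id like this, downstream splitting becomes trivial. A hypothetical sketch (the "slice" key and its values are illustrative, not something SliceSampler currently returns):

import torch
from tensordict import TensorDict
from torchrl.collectors.utils import split_trajectories

# Hypothetical sample in which the sampler has tagged each slice with a unique id.
sample = TensorDict(
    {"obs": torch.tensor([2, 3, 4, 5, 6, 41, 42, 43, 44, 45]),
     "episode": torch.zeros(10),
     "slice": torch.tensor([0, 0, 0, 0, 0, 1, 1, 1, 1, 1])},
    batch_size=[10],
)

# Splitting on the slice id recovers the two 5-step samples, even though they
# come from the same episode.
t = split_trajectories(sample, trajectory_key="slice")
print(t["obs"])  # shape [2, 5]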


System info

M1 Mac, version 15.1

import torchrl, numpy, sys
print(torchrl.__version__, numpy.__version__, sys.version, sys.platform)
0.6.0+7bf320c 1.26.4 3.11.9 (main, Apr 19 2024, 11:44:45) [Clang 14.0.6 ] darwin

Both torchrl and tensordict were installed from source.

Checklist

  • I have checked that there is no similar issue in the repo (required)
  • I have read the documentation (required)
  • I have provided a minimal working example to reproduce the bug (required)
@dmelcer9
Author

Taking a fresh look at this again, it seems that a workaround may be to do something like:

sample, info = rb.sample(minibatch_size, return_info=True)

sample["next", "end_of_slice"] = (
    info["next", "truncated"]
    | info["next", "done"]
    | info["next", "terminated"]
)

sample = split_trajectories(sample, done_key="end_of_slice")

But this is hardly ergonomic; at the very least it should be shown as an example in the documentation.

@vmoens
Contributor

vmoens commented Nov 20, 2024

Hey
Thanks for reporting this

  1. One option is to use SliceSamplerWithoutReplacement:
import torch
from tensordict import TensorDict
from torchrl.collectors.utils import split_trajectories
from torchrl.data import ReplayBuffer, LazyTensorStorage, SliceSampler, SliceSamplerWithoutReplacement

rb = ReplayBuffer(
    storage=LazyTensorStorage(max_size=1000),
    sampler=SliceSamplerWithoutReplacement(
        slice_len=5, traj_key="episode", strict_length=False
    ),
)

ep_1 = TensorDict(
    {"obs": torch.arange(100),
     "episode": torch.zeros(100)},
    batch_size=[100],
)
ep_2 = TensorDict(
    {"obs": torch.arange(4),
     "episode": torch.ones(4)},
    batch_size=[4],
)
rb.extend(ep_1)
rb.extend(ep_2)

s = rb.sample(50)
t = split_trajectories(s, trajectory_key="episode")
print(t["obs"])
print(t["episode"])

That will ensure that you don't sample the same item twice.

  2. Another option is to use TensorDictReplayBuffer with the slice sampler. That will update the ("next", "truncated") key in the sampled data, which split_trajectories can understand:
import torch
from tensordict import TensorDict
from torchrl.collectors.utils import split_trajectories
from torchrl.data import TensorDictReplayBuffer, ReplayBuffer, LazyTensorStorage, SliceSampler, SliceSamplerWithoutReplacement

rb = TensorDictReplayBuffer(
    storage=LazyTensorStorage(max_size=1000),
    sampler=SliceSampler(slice_len=5, traj_key="episode", strict_length=False),
)

ep_1 = TensorDict(
    {"obs": torch.arange(100),
     "episode": torch.zeros(100)},
    batch_size=[100],
)
ep_2 = TensorDict(
    {"obs": torch.arange(4),
     "episode": torch.ones(4)},
    batch_size=[4],
)
rb.extend(ep_1)
rb.extend(ep_2)

s = rb.sample(50)
print(s)
t = split_trajectories(s, done_key="truncated")
print(t["obs"])
print(t["episode"])
  3. Finally, there is your solution (doing it manually), but as you mention it's clunky. If you do it manually, you could also just do:
s, info = rb.sample(50, return_info=True)
print(s)
s["next", "truncated"] = info[("next", "truncated")]
t = split_trajectories(s, done_key="truncated")

But in general I do agree that we need better documentation.
Aside from the docstrings of the slice sampler, where would you look for that info?

@dmelcer9
Author

Thanks for responding so quickly!

In my particular case, I am collecting a few episodes (of wildly varying length), training on a few large-ish batches of short-ish slices, and then clearing the replay buffer, so unfortunately SliceSamplerWithoutReplacement wouldn't work (though the documentation should clarify whether "without replacement" means never sampling two different slices of the same episode, or allowing the same episode to be sampled multiple times on non-overlapping slices).

I first looked at the SliceSampler docs for the strict_length parameter. This then led me to the docs for split_trajectories, which showed the basic usage but should probably include a warning about its input assumptions (no duplicate trajectory_key values from different slices, even non-contiguous ones).

I then looked at the SliceSampler docs again for the truncated_key parameter, which led me to discover the return_info=True option. The docs also seem to imply that ("next", "truncated") is False in cases where the last step of a slice is simply the final (done) step of an episode.

@vmoens
Contributor

vmoens commented Nov 25, 2024

Is this a good edit?
#2607
Would you add anything?
Is there anything you think is broken in the API?
