PyTorch FSDP2 provides a fully sharded data parallelism (FSDP) implementation that targets performant eager-mode execution while using per-parameter sharding for improved usability.
- If you are new to FSDP, we recommend that you start with FSDP2 due to improved usability.
- If you are currently using FSDP1, consider evaluating the following differences to see if you should switch to FSDP2:
Compared to PyTorch FSDP1 (``FullyShardedDataParallel``):

- FSDP2 uses ``DTensor``-based dim-0 per-parameter sharding for a simpler sharding
  representation compared to FSDP1's flat-parameter sharding, while preserving similar
  throughput performance. More specifically, FSDP2 chunks each parameter on dim-0
  across the data parallel workers (using ``torch.chunk(dim=0)``), whereas FSDP1
  flattens, concatenates, and chunks a group of tensors together, making it complex to
  reason about what data is present on each worker and to reshard to different
  parallelisms. Per-parameter sharding provides a more intuitive user experience,
  relaxes constraints around frozen parameters, and allows for communication-free
  (sharded) state dicts, which otherwise require all-gathers in FSDP1 (see the sketch
  after this list).
- FSDP2 implements a different memory management approach for its multi-stream usage
  that avoids ``torch.Tensor.record_stream``. This ensures deterministic and expected
  memory usage and does not require blocking the CPU like FSDP1's
  ``limit_all_gathers=True``.
- FSDP2 exposes APIs for manual control over prefetching and collective scheduling,
  allowing power users more customization. See the methods on ``FSDPModule`` below for
  details.
- FSDP2 simplifies some of the API surface: e.g., FSDP2 does not directly support full
  state dicts. Instead, users can reshard the sharded state dicts containing
  ``DTensor`` objects to full state dicts themselves using ``DTensor`` APIs like
  ``DTensor.full_tensor()`` or by using higher-level APIs like PyTorch Distributed
  Checkpoint's distributed state dict APIs. Also, some other args have been removed;
  see here for details.
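For example, a sharded state dict can be converted to a full state dict with
``DTensor.full_tensor()``. The following is a minimal sketch (not an official recipe);
it assumes ``model`` is a module already wrapped with ``fully_shard``, and note that
``full_tensor()`` all-gathers each parameter across the mesh:

.. code-block:: python

    from torch.distributed.tensor import DTensor

    # Sharded state dict: communication-free, values are DTensors.
    sharded_sd = model.state_dict()

    # Manually materialize a full state dict by all-gathering each DTensor.
    full_sd = {}
    for name, value in sharded_sd.items():
        if isinstance(value, DTensor):
            full_sd[name] = value.full_tensor()  # all-gather across the mesh
        else:
            full_sd[name] = value

Higher-level alternatives such as PyTorch Distributed Checkpoint's distributed state
dict APIs can produce the same result without a hand-written loop.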
If you are onboarding FSDP for the first time or if any of the above appeals to your use case, we recommend that you consider using FSDP2.
See this RFC for details on system design and implementation.
.. note::
    ``torch.distributed.fsdp.fully_shard`` is currently in prototype state and under
    development. The core API will likely not change, but we may make some API changes
    if necessary.
.. currentmodule:: torch.distributed.fsdp
The frontend API is ``fully_shard``, which can be called on a module:
.. autofunction:: fully_shard
Calling ``fully_shard(module)`` dynamically constructs a new class that subclasses
``type(module)`` and an FSDP class ``FSDPModule``. For example, if we call
``fully_shard(linear)`` on a module ``linear: nn.Linear``, then FSDP constructs a new
class ``FSDPLinear`` and changes ``linear``'s type to this. Otherwise, ``fully_shard``
does not change the module structure or parameter fully-qualified names. The class
``FSDPModule`` provides some FSDP-specific methods on the module.
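The following is a minimal sketch of this behavior, assuming ``torch.distributed`` is
already initialized (e.g., via ``torchrun``) so that ``fully_shard`` can construct its
default device mesh:

.. code-block:: python

    import torch.nn as nn
    from torch.distributed.fsdp import FSDPModule, fully_shard

    model = nn.Sequential(nn.Linear(16, 16), nn.ReLU(), nn.Linear(16, 16))

    # A common pattern: shard each child that should form its own
    # all-gather group, then shard the root module.
    for layer in model:
        if isinstance(layer, nn.Linear):
            fully_shard(layer)
    fully_shard(model)

    # The classes are swapped in place; module structure and parameter
    # fully-qualified names are unchanged.
    assert isinstance(model, nn.Sequential) and isinstance(model, FSDPModule)
    assert isinstance(model[0], nn.Linear) and isinstance(model[0], FSDPModule)
    print([name for name, _ in model.named_parameters()])
    # ['0.weight', '0.bias', '2.weight', '2.bias']  (same FQNs as before sharding)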
.. autoclass:: FSDPModule
    :members:
    :member-order: bysource
.. autoclass:: UnshardHandle
    :members:
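As a rough illustration of the manual control mentioned above, the sketch below (reusing
the hypothetical ``model`` from the previous example) unshards a later layer
asynchronously, overlaps the all-gather with earlier compute, waits on the returned
``UnshardHandle``, and reshards afterward:

.. code-block:: python

    import torch

    inp = torch.randn(4, 16)

    # Start the all-gather for the last linear early, overlap it with the
    # earlier layers' compute, then wait before using its parameters.
    handle = model[2].unshard(async_op=True)
    hidden = model[1](model[0](inp))
    handle.wait()
    out = model[2](hidden)
    model[2].reshard()  # free the unsharded parameters again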
.. autofunction:: register_fsdp_forward_method
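For instance, a module method other than ``forward`` can be registered so that FSDP
all-gathers the parameters before it runs and reshards them afterward. This is a minimal
sketch; the ``Encoder`` class and its ``encode`` method are hypothetical, and it again
assumes ``torch.distributed`` is initialized:

.. code-block:: python

    import torch
    import torch.nn as nn
    from torch.distributed.fsdp import fully_shard, register_fsdp_forward_method

    class Encoder(nn.Module):
        def __init__(self):
            super().__init__()
            self.proj = nn.Linear(16, 16)

        def encode(self, x: torch.Tensor) -> torch.Tensor:
            return self.proj(x)

    encoder = Encoder()
    fully_shard(encoder)
    register_fsdp_forward_method(encoder, "encode")  # treat encode() like forward()
    out = encoder.encode(torch.randn(4, 16))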
.. autoclass:: MixedPrecisionPolicy
    :members:
.. autoclass:: OffloadPolicy
    :members:
.. autoclass:: CPUOffloadPolicy
    :members:
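As a rough sketch of how these policies are used, both can be passed to ``fully_shard``
(again assuming ``torch.distributed`` is initialized); the dtype choices here are
illustrative only:

.. code-block:: python

    import torch
    import torch.nn as nn
    from torch.distributed.fsdp import (
        CPUOffloadPolicy,
        MixedPrecisionPolicy,
        fully_shard,
    )

    model = nn.Linear(16, 16)
    fully_shard(
        model,
        mp_policy=MixedPrecisionPolicy(
            param_dtype=torch.bfloat16,  # dtype of unsharded params used for compute
            reduce_dtype=torch.float32,  # dtype used for gradient reduce-scatter
        ),
        offload_policy=CPUOffloadPolicy(),  # keep sharded params/grads on CPU
    )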