torch.distributed.fsdp.fully_shard
==================================

PyTorch FSDP2 (fully_shard)
---------------------------

PyTorch FSDP2 provides a fully sharded data parallelism (FSDP) implementation that targets performant eager-mode execution while using per-parameter sharding for improved usability.

- If you are new to FSDP, we recommend that you start with FSDP2 due to its improved usability.
- If you are currently using FSDP1, consider evaluating the following differences to see if you should switch to FSDP2.

Compared to PyTorch FSDP1 (FullyShardedDataParallel):

- FSDP2 uses DTensor-based dim-0 per-parameter sharding for a simpler sharding representation compared to FSDP1's flat-parameter sharding, while preserving similar throughput performance. More specifically, FSDP2 chunks each parameter on dim-0 across the data parallel workers (using torch.chunk(dim=0)), whereas FSDP1 flattens, concatenates, and chunks a group of tensors together, which makes it hard to reason about what data is present on each worker and to reshard to different parallelisms. Per-parameter sharding provides a more intuitive user experience, relaxes constraints around frozen parameters, and allows for communication-free (sharded) state dicts, which otherwise require all-gathers in FSDP1.
- FSDP2 implements a different memory management approach for its multi-stream usage that avoids torch.Tensor.record_stream. This ensures deterministic and expected memory usage and does not require blocking the CPU as FSDP1's limit_all_gathers=True does.
- FSDP2 exposes APIs for manual control over prefetching and collective scheduling, allowing power users more customization. See the methods on FSDPModule below for details.
- FSDP2 simplifies some of the API surface: e.g. FSDP2 does not directly support full state dicts. Instead, users can reshard the sharded state dicts containing DTensors to full state dicts themselves using DTensor APIs like DTensor.full_tensor() (see the sketch after this list) or by using higher-level APIs like PyTorch Distributed Checkpoint's distributed state dict APIs. Also, some other args have been removed; see here for details.
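
As a concrete illustration of the last point, here is a minimal sketch of manually resharding an FSDP2 sharded state dict into a full state dict with DTensor.full_tensor(). It assumes ``model`` has already been wrapped with fully_shard and that torch.distributed is initialized:

.. code-block:: python

    from torch.distributed.tensor import DTensor

    full_state_dict = {}
    for name, value in model.state_dict().items():
        if isinstance(value, DTensor):
            # full_tensor() all-gathers the shards so every rank ends up with
            # the complete, unsharded tensor.
            full_state_dict[name] = value.full_tensor()
        else:
            full_state_dict[name] = value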

If you are onboarding FSDP for the first time, or if any of the above applies to your use case, we recommend that you consider using FSDP2.

See this RFC for details on system design and implementation.

Note

torch.distributed.fsdp.fully_shard is currently in prototype state and under development. The core API will likely not change, but we may make some API changes if necessary.

.. currentmodule:: torch.distributed.fsdp

The frontend API is fully_shard, which can be called on a module:

.. autofunction:: fully_shard

Calling fully_shard(module) dynamically constructs a new class that subclasses type(module) and an FSDP class FSDPModule. For example, if we call fully_shard(linear) on a module linear: nn.Linear, then FSDP constructs a new class FSDPLinear and changes linear's type to it. Otherwise, fully_shard does not change the module structure or parameter fully-qualified names. The FSDPModule class provides FSDP-specific methods on the module.
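
For example, the following is a minimal sketch of wrapping a model with fully_shard and running one forward/backward pass, assuming the script is launched with torchrun on a machine with GPUs:

.. code-block:: python

    import torch
    import torch.distributed as dist
    import torch.nn as nn
    from torch.distributed.fsdp import FSDPModule, fully_shard

    # Assumes launch via `torchrun --nproc_per_node=<num_gpus> example.py`.
    dist.init_process_group("nccl")
    torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())

    model = nn.Sequential(nn.Linear(16, 16), nn.Linear(16, 16)).cuda()

    # Shard each child first and then the root so that each call defines its
    # own parameter group for all-gather/reduce-scatter.
    for layer in model:
        fully_shard(layer)
    fully_shard(model)

    # fully_shard dynamically subclasses the module's type with FSDPModule
    # without changing the module structure or parameter names.
    assert isinstance(model, FSDPModule)
    assert isinstance(model, nn.Sequential)

    out = model(torch.randn(4, 16, device="cuda"))
    out.sum().backward()

    dist.destroy_process_group()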

.. autoclass:: FSDPModule
    :members:
    :member-order: bysource

.. autoclass:: UnshardHandle
    :members:
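
A hedged sketch of manual unsharding with UnshardHandle, assuming ``layer`` was wrapped with fully_shard: explicitly all-gather a module's parameters ahead of time, wait on the returned handle, and reshard when done.

.. code-block:: python

    # Kick off the all-gather for this module's parameters asynchronously.
    handle = layer.unshard(async_op=True)
    # ... overlap other work here ...
    handle.wait()      # block until the unsharded parameters are ready
    # ... run computation that needs the unsharded parameters ...
    layer.reshard()    # free the unsharded parameters again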

.. autofunction:: register_fsdp_forward_method
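
A hedged usage sketch (the ``generate`` method name is hypothetical): register a non-forward method so that FSDP's pre/post-forward logic (all-gathering and resharding parameters) also runs around it.

.. code-block:: python

    from torch.distributed.fsdp import register_fsdp_forward_method

    # Assume ``model`` is wrapped with fully_shard and uses ``generate`` for
    # inference instead of ``forward``.
    register_fsdp_forward_method(model, "generate")
    # output = model.generate(input_ids)  # parameters are now gathered/freed correctly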

.. autoclass:: MixedPrecisionPolicy
    :members:
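
A minimal sketch of low-precision training, assuming a freshly constructed ``model`` like the one in the earlier example: parameters are cast to bfloat16 for forward/backward compute while gradients are reduced in float32.

.. code-block:: python

    import torch
    from torch.distributed.fsdp import MixedPrecisionPolicy, fully_shard

    mp_policy = MixedPrecisionPolicy(
        param_dtype=torch.bfloat16,  # dtype of unsharded parameters for compute
        reduce_dtype=torch.float32,  # dtype used for gradient reduce-scatter
    )
    fully_shard(model, mp_policy=mp_policy)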

.. autoclass:: OffloadPolicy
    :members:

.. autoclass:: CPUOffloadPolicy
    :members:
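
A minimal sketch of CPU offloading, again assuming a freshly constructed ``model``: sharded parameters, gradients, and optimizer states are kept in CPU memory and moved to GPU only around the collectives that need them.

.. code-block:: python

    from torch.distributed.fsdp import CPUOffloadPolicy, fully_shard

    fully_shard(model, offload_policy=CPUOffloadPolicy())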