Implement export with PyT Distributed checkpoints (#9058)
* Implement PyT Dist load with MCore

Signed-off-by: Mikołaj Błaż <[email protected]>

* Use plain PyT Dist utils

Signed-off-by: Mikołaj Błaż <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Implement TarPath compatible version

Signed-off-by: Mikołaj Błaż <[email protected]>

* Apply black

Signed-off-by: Mikołaj Błaż <[email protected]>

---------

Signed-off-by: Mikołaj Błaż <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
mikolajblaz and pre-commit-ci[bot] authored May 17, 2024
1 parent b715f5a commit 18eed4d
Showing 1 changed file with 51 additions and 0 deletions.
51 changes: 51 additions & 0 deletions nemo/export/trt_llm/nemo/nemo_ckpt_convert.py
@@ -14,6 +14,7 @@


import configparser
import json
import logging
import math
import multiprocessing
@@ -28,6 +29,8 @@
import torch
import zarr
from tensorrt_llm._utils import np_bfloat16, pad_vocab_size, str_dtype_to_torch, torch_to_numpy
from torch.distributed.checkpoint import FileSystemReader, TensorStorageMetadata
from torch.distributed.checkpoint.state_dict_loader import load_state_dict
from tqdm import tqdm
from transformers import AutoTokenizer, GPT2Tokenizer, LlamaConfig

@@ -122,6 +125,54 @@ def rename_key_dist_ckpt(old_key: str, layer: int):


def load_sharded_metadata(checkpoint_dir: Union[Path, TarPath], torch_tensor=True):
    with (checkpoint_dir / 'metadata.json').open(mode='r') as f:
        config_dict = json.load(f)
    if config_dict['sharded_backend'] == 'zarr':
        return load_sharded_metadata_zarr(checkpoint_dir, torch_tensor)
    elif config_dict['sharded_backend'] == 'torch_dist':
        return load_sharded_metadata_torch_dist(checkpoint_dir, torch_tensor)
    else:
        raise NotImplementedError(f'Distributed checkpoint backend {config_dict["sharded_backend"]} not supported')
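
# Editorial note (not part of the commit): given the dispatch above,
# metadata.json is expected to name the checkpoint backend, e.g.
#   {"sharded_backend": "torch_dist"}   or   {"sharded_backend": "zarr"}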


class TarFileSystemReader(FileSystemReader):
"""Reader that accepts both Path and TarPath checkpoint directory.
The FileSystemReader works with TarPath, but expects a pure Path.
It's enough to skip the Path check in __init__.
"""

def __init__(self, path: Union[Path, TarPath]) -> None:
"""No call to super().__init__ because it expects pure Path."""
self.path = path
self.storage_data = dict()
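
# Design note on the class above (editorial, not in the commit): as the docstring
# says, FileSystemReader.__init__ expects a pure pathlib.Path, which would break a
# TarPath pointing inside a .nemo archive. Re-assigning the same two attributes
# (path, storage_data) keeps the reader's behavior while accepting either path type.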


def load_sharded_metadata_torch_dist(checkpoint_dir: Union[Path, TarPath], torch_tensor=True):
    fs_reader = TarFileSystemReader(checkpoint_dir)
    metadata = fs_reader.read_metadata()

    state_dict = {
        k: torch.empty(tp.size, dtype=tp.properties.dtype)
        for k, tp in metadata.state_dict_metadata.items()
        if isinstance(tp, TensorStorageMetadata)
    }
    load_state_dict(
        state_dict,
        storage_reader=fs_reader,
        no_dist=True,
    )

    if not torch_tensor:
        for k, v in state_dict.items():
            if v.dtype == torch.bfloat16:
                state_dict[k] = v.view(torch.int16).numpy().view(np_bfloat16)
            else:
                state_dict[k] = v.numpy()
    return state_dict
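
# Two details worth noting about the function above (editorial comments, not in
# the commit):
# - no_dist=True lets load_state_dict run without an initialized process group,
#   so the full checkpoint is materialized in a single process.
# - numpy has no native bfloat16, so the bf16 branch reinterprets the raw 16-bit
#   storage (view as int16, then as TensorRT-LLM's np_bfloat16) rather than
#   casting, which leaves the bit patterns untouched.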


def load_sharded_metadata_zarr(checkpoint_dir: Union[Path, TarPath], torch_tensor=True):
    sharded_state_dict = {}
    for subdir in checkpoint_dir.iterdir():
        if not subdir.is_dir() or not (subdir / '.zarray').exists():
            # ... (rest of this function unchanged and collapsed in the diff view)
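
For context, here is a minimal usage sketch of the new entry point. The archive
path, the 'model_weights' subdirectory, and the TarPath import location are
assumptions for illustration; only load_sharded_metadata and its torch_tensor
flag come from this commit:

from nemo.export.tarutils import TarPath  # assumed import path for TarPath
from nemo.export.trt_llm.nemo.nemo_ckpt_convert import load_sharded_metadata

# Hypothetical layout: a .nemo tar archive whose sharded checkpoint lives
# under a 'model_weights' directory (an assumption, not from this commit).
checkpoint_dir = TarPath('/path/to/model.nemo') / 'model_weights'

weights = load_sharded_metadata(checkpoint_dir)  # str -> torch.Tensor
np_weights = load_sharded_metadata(checkpoint_dir, torch_tensor=False)  # str -> np.ndarray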
