Skip to content

Commit

Permalink
Add a build option to load_context (#10713)
Browse files Browse the repository at this point in the history
* Add a build option to load_context

Signed-off-by: Marc Romeijn <[email protected]>
Signed-off-by: Alexandros Koumparoulis <[email protected]>

* Adding test

Signed-off-by: Marc Romeijn <[email protected]>
Signed-off-by: Alexandros Koumparoulis <[email protected]>

* Trying to fix failing CPU test

Signed-off-by: Marc Romeijn <[email protected]>
Signed-off-by: Alexandros Koumparoulis <[email protected]>

* cherry-pick fix

Signed-off-by: Alexandros Koumparoulis <[email protected]>

---------

Signed-off-by: Marc Romeijn <[email protected]>
Signed-off-by: Alexandros Koumparoulis <[email protected]>
Co-authored-by: Alexandros Koumparoulis <[email protected]>
  • Loading branch information
marcromeyn and akoumpa authored Oct 24, 2024
1 parent 938f570 commit 652093c
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 11 deletions.
18 changes: 14 additions & 4 deletions nemo/lightning/io/api.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,30 @@
from pathlib import Path
from typing import Callable, Optional, Type
from typing import Callable, Literal, Optional, Type, overload
import fiddle as fdl

import pytorch_lightning as pl

from nemo.lightning.io.mixin import ConnectorMixin, ConnT, ModelConnector, load
from nemo.lightning.io.pl import TrainerContext


def load_context(path: Path, subpath: Optional[str] = None) -> TrainerContext:
# Typing-only overloads: the return type of ``load_context`` depends on the
# value of ``build``. ``Literal`` is needed for a type checker to tell the
# variants apart; annotating every variant as plain ``bool`` makes them
# overlap, so the first one would always be selected.

# Default / ``build=True``: the built TrainerContext is returned.
@overload
def load_context(path: Path, subpath: Optional[str] = None, build: Literal[True] = True) -> TrainerContext: ...


# ``build=False`` passed by keyword: the unbuilt fiddle Config is returned.
@overload
def load_context(path: Path, subpath: Optional[str] = None, *, build: Literal[False]) -> fdl.Config[TrainerContext]: ...


# ``build=False`` passed positionally (``subpath`` must then be given as well).
@overload
def load_context(path: Path, subpath: Optional[str], build: Literal[False]) -> fdl.Config[TrainerContext]: ...


def load_context(path: Path, subpath: Optional[str] = None, build: bool = True):
"""
Loads a TrainerContext from a json-file or directory.
Args:
path (Path): The path to the json-file or directory containing 'io.json'.
subpath (Optional[str]): Subpath to selectively load only specific objects inside the TrainerContext. Defaults to None.
build (bool): Whether to build the TrainerContext. Defaults to True.
Otherwise, the TrainerContext is returned as a Config[TrainerContext] object.
Returns
-------
TrainerContext: The loaded TrainerContext instance, or a Config[TrainerContext] when build is False.
Expand All @@ -27,7 +37,7 @@ def load_context(path: Path, subpath: Optional[str] = None) -> TrainerContext:
checkpoint: TrainerContext = load_context("/path/to/checkpoint", subpath="model.config")
"""
return load(path, output_type=TrainerContext, subpath=subpath)
return load(path, output_type=TrainerContext, subpath=subpath, build=build)


def model_importer(target: Type[ConnectorMixin], ext: str) -> Callable[[Type[ConnT]], Type[ConnT]]:
Expand Down
5 changes: 4 additions & 1 deletion nemo/lightning/io/mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -636,7 +636,7 @@ def _artifact_transform_load(cfg: fdl.Config, path: Path):
pass


def load(path: Path, output_type: Type[CkptType] = Any, subpath: Optional[str] = None) -> CkptType:
def load(path: Path, output_type: Type[CkptType] = Any, subpath: Optional[str] = None, build: bool = True) -> CkptType:
"""
Loads a configuration from a JSON file and, unless `build` is False, constructs an object of the specified type.
Expand Down Expand Up @@ -700,4 +700,7 @@ def load(path: Path, output_type: Type[CkptType] = Any, subpath: Optional[str] =
config = serialization.Deserialization(json_config).result
_artifact_transform_load(config, path)

if not build:
return config

return fdl.build(config)
13 changes: 7 additions & 6 deletions tests/lightning/_io/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from functools import partial
from pathlib import Path

import fiddle as fdl
import pytest
import yaml
from pytorch_lightning.loggers import TensorBoardLogger
Expand Down Expand Up @@ -69,9 +70,9 @@ def test_reload_ckpt(self, tmpdir, partial_function_with_pos_and_key_args):
loaded_func = loaded.extra["dummy"]
assert loaded_func(b=2) == partial_function_with_pos_and_key_args(b=2)

model_yaml = Path(tmpdir) / "model.yaml"
assert model_yaml.exists()

observed = yaml.safe_load(model_yaml.read_text())
expected = yaml.safe_load((Path(ARTIFACTS_DIR) / "model.yaml").read_text())
assert observed.keys() == expected.keys()
config = io.load_context(tmpdir, build=False)
assert isinstance(config, fdl.Config)
assert config.model.config.seq_length == ckpt.model.config.seq_length
assert config.model.tokenizer.vocab_file.startswith(str(tmpdir))
assert config.model.tokenizer.merges_file.startswith(str(tmpdir))
assert config.extra["dummy"] == fdl.Partial(dummy_extra, 10, c=15)

0 comments on commit 652093c

Please sign in to comment.