Commit

Akoumparouli/nemo ux fix dir or string artifact (#10936)
* Add __repr__ to Artifact

Signed-off-by: Alexandros Koumparoulis <[email protected]>

* nemo.lightning.io.artifact: represent strings as fdl.Config to avoid path adjustment during restoration

Signed-off-by: Alexandros Koumparoulis <[email protected]>

* t5 test minification

Signed-off-by: Alexandros Koumparoulis <[email protected]>

* Apply isort and black reformatting

Signed-off-by: akoumpa <[email protected]>

---------

Signed-off-by: Alexandros Koumparoulis <[email protected]>
Signed-off-by: akoumpa <[email protected]>
Co-authored-by: akoumpa <[email protected]>
akoumpa authored Oct 18, 2024
1 parent f4d1c5d commit c82a597
Showing 4 changed files with 21 additions and 10 deletions.
nemo/lightning/io/artifact/base.py (7 changes: 5 additions & 2 deletions)
```diff
@@ -6,10 +6,10 @@


 class Artifact(ABC, Generic[ValueT]):
-    def __init__(self, attr: str, required: bool = True):
+    def __init__(self, attr: str, required: bool = True, skip: bool = False):
         self.attr = attr
         self.required = required
-        self.skip = False
+        self.skip = skip

     @abstractmethod
     def dump(self, value: ValueT, absolute_dir: Path, relative_dir: Path) -> ValueT:
@@ -18,3 +18,6 @@ def dump(self, value: ValueT, absolute_dir: Path, relative_dir: Path) -> ValueT:
     @abstractmethod
     def load(self, path: Path) -> ValueT:
         pass
+
+    def __repr__(self):
+        return f"{type(self).__name__}(skip= {self.skip}, attr= {self.attr}, required= {self.required})"
```
nemo/lightning/io/artifact/file.py (7 changes: 3 additions & 4 deletions)
```diff
@@ -2,6 +2,7 @@
 import shutil
 from pathlib import Path
 from typing import Union
+import fiddle as fdl

 from nemo.lightning.io.artifact.base import Artifact

@@ -19,8 +20,7 @@ class FileArtifact(Artifact[str]):
     def dump(self, value: str, absolute_dir: Path, relative_dir: Path) -> str:
         if not pathize(value).exists():
             # This is Artifact is just a string.
-            self.skip = True
-            return value
+            return fdl.Config(FileArtifact, attr=value, skip=True)
         new_value = copy_file(value, absolute_dir, relative_dir)
         return str(new_value)
@@ -65,8 +65,7 @@ class DirOrStringArtifact(DirArtifact):
     def dump(self, value: str, absolute_dir: Path, relative_dir: Path) -> str:
         if not pathize(value).exists():
             # This is Artifact is just a string.
-            self.skip = True
-            return value
+            return fdl.Config(DirOrStringArtifact, attr=value, skip=True)
         return super().dump(value, absolute_dir, relative_dir)

     def load(self, path: str) -> str:
```
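The effect of the new save path, in a hedged sketch (the directories and the non-path string value are made up for illustration): a value that is not an existing filesystem path is wrapped in an `fdl.Config` instead of being copied into the checkpoint directory.

```python
# Illustrative only: dump() wraps a plain string rather than copying a file.
from pathlib import Path

import fiddle as fdl

from nemo.lightning.io.artifact.file import FileArtifact

art = FileArtifact("tokenizer")
out = art.dump("BPE", Path("/tmp/ckpt"), Path("artifacts"))  # "BPE" is not an existing path
assert isinstance(out, fdl.Config)
assert fdl.build(out).attr == "BPE"  # the original string survives unchanged
```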
nemo/lightning/io/mixin.py (9 changes: 9 additions & 0 deletions)
```diff
@@ -609,6 +609,15 @@ def _artifact_transform_save(cfg: fdl.Config, output_path: Path, relative_dir: P

 def _artifact_transform_load(cfg: fdl.Config, path: Path):
     for artifact in getattr(cfg.__fn_or_cls__, "__io_artifacts__", []):
+        # We expect an artifact.attr to be a string or a fdl.Config.
+        # Some parameteres can be a string or a filepath. When those parameters are just strings,
+        # we will represent it with a fdl.Config, and will skip the rest of the loop (base-dir adjustment).
+        current_val = getattr(cfg, artifact.attr)
+        if isinstance(current_val, fdl.Config):
+            # artifact.attr is a string not a path.
+            setattr(cfg, artifact.attr, fdl.build(current_val).attr)
+            continue
+
         if artifact.skip:
             continue
         current_val = getattr(cfg, artifact.attr)
```
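On restore, the transform unwraps such configs back into plain strings before any base-directory adjustment runs. A self-contained sketch of that branch in isolation (not NeMo's actual call site):

```python
# Mimics the unwrap step added to _artifact_transform_load, standalone.
import fiddle as fdl

from nemo.lightning.io.artifact.file import FileArtifact

stored = fdl.Config(FileArtifact, attr="just-a-string", skip=True)
restored = fdl.build(stored).attr if isinstance(stored, fdl.Config) else stored
assert restored == "just-a-string"  # no path rewriting was applied
```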
tests/collections/llm/megatron_t5_pretraining.py (8 changes: 4 additions & 4 deletions)
```diff
@@ -59,16 +59,16 @@ def get_args():
         paths=args.data_path,
         seq_length=512,
         seq_length_dec=128,
-        micro_batch_size=64,
-        global_batch_size=512,
+        micro_batch_size=args.devices,
+        global_batch_size=2 * args.devices,
         seed=1234,
         tokenizer=tokenizer,
         split="99982,9,9",
         index_mapping_dir=args.index_mapping_dir,
     )
     t5_config = llm.t5.model.t5.T5Config(
-        num_layers=12,
-        encoder_num_layers=12,
+        num_layers=args.devices,
+        encoder_num_layers=args.devices,
         hidden_size=768,
         ffn_hidden_size=3072,
         num_attention_heads=12,
```
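The test change shrinks batch sizes and layer counts so the pretraining smoke test scales with the device count. As a reference, a hedged sketch of the usual Megatron-style batch constraint such values have to satisfy (the data-parallel layout of this particular test is an assumption):

```python
def grad_accum_steps(global_bs: int, micro_bs: int, dp_size: int) -> int:
    """Global batch must split evenly into micro-batches across data-parallel ranks."""
    assert global_bs % (micro_bs * dp_size) == 0, "invalid batch configuration"
    return global_bs // (micro_bs * dp_size)

# Old test values on one data-parallel rank: 512 / (64 * 1) = 8 accumulation steps.
print(grad_accum_steps(512, 64, 1))  # -> 8
# New values with devices=1: 2 / (1 * 1) = 2 accumulation steps.
print(grad_accum_steps(2, 1, 1))  # -> 2
```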
