Akoumparouli/nemo ux fix dir or string artifact #10936

Merged: 4 commits, Oct 18, 2024
7 changes: 5 additions & 2 deletions nemo/lightning/io/artifact/base.py
@@ -6,10 +6,10 @@


 class Artifact(ABC, Generic[ValueT]):
-    def __init__(self, attr: str, required: bool = True):
+    def __init__(self, attr: str, required: bool = True, skip: bool = False):
         self.attr = attr
         self.required = required
-        self.skip = False
+        self.skip = skip
 
     @abstractmethod
     def dump(self, value: ValueT, absolute_dir: Path, relative_dir: Path) -> ValueT:
@@ -18,3 +18,6 @@ def dump(self, value: ValueT, absolute_dir: Path, relative_dir: Path) -> ValueT:
     @abstractmethod
     def load(self, path: Path) -> ValueT:
         pass
+
+    def __repr__(self):
+        return f"{type(self).__name__}(skip={self.skip}, attr={self.attr}, required={self.required})"
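As a quick illustration of this change (a hypothetical usage sketch, not taken from the PR): any `Artifact` subclass can now opt out of serialization at construction time, and the new `__repr__` makes the flag visible.

```python
# Hypothetical usage sketch: FileArtifact inherits Artifact.__init__, so the
# new `skip` flag can be passed directly at construction time.
from nemo.lightning.io.artifact.file import FileArtifact

art = FileArtifact("tokenizer_model", skip=True)
print(art)  # -> FileArtifact(skip=True, attr=tokenizer_model, required=True)
```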
7 changes: 3 additions & 4 deletions nemo/lightning/io/artifact/file.py
@@ -2,6 +2,7 @@
 import shutil
 from pathlib import Path
 from typing import Union
+import fiddle as fdl
 
 from nemo.lightning.io.artifact.base import Artifact
 
@@ -19,8 +20,7 @@ class FileArtifact(Artifact[str]):
     def dump(self, value: str, absolute_dir: Path, relative_dir: Path) -> str:
         if not pathize(value).exists():
             # This Artifact is just a string.
-            self.skip = True
-            return value
+            return fdl.Config(FileArtifact, attr=value, skip=True)
Collaborator (review comment):

    Why is skip True by default here? Shouldn't it be equal to self.skip?

Member (author):

    It's True inside the if's body, but it does not matter: later, the code checks whether the value is an fdl.Config rather than consulting skip.
         new_value = copy_file(value, absolute_dir, relative_dir)
         return str(new_value)

@@ -65,8 +65,7 @@ class DirOrStringArtifact(DirArtifact):
     def dump(self, value: str, absolute_dir: Path, relative_dir: Path) -> str:
         if not pathize(value).exists():
             # This Artifact is just a string.
-            self.skip = True
-            return value
+            return fdl.Config(DirOrStringArtifact, attr=value, skip=True)
         return super().dump(value, absolute_dir, relative_dir)

     def load(self, path: str) -> str:
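To make the review exchange above concrete, here is a minimal sketch of the dump-side behavior (paths and values are assumed, for illustration only): dump() now returns an fdl.Config marker instead of mutating self.skip on the shared artifact instance, and, as the author notes, the loader dispatches on the marker's type rather than on the skip value stored inside it.

```python
# Minimal sketch of the new dump() path (assumed values, for illustration).
from pathlib import Path

import fiddle as fdl

from nemo.lightning.io.artifact.file import FileArtifact

art = FileArtifact("tokenizer")
# "not-an-existing-path" fails the pathize(value).exists() check, so dump()
# returns a marker instead of copying a file:
dumped = art.dump("not-an-existing-path", Path("/tmp/ckpt"), Path("."))
assert isinstance(dumped, fdl.Config)  # the loader checks this type, not skip
```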
9 changes: 9 additions & 0 deletions nemo/lightning/io/mixin.py
@@ -609,6 +609,15 @@ def _artifact_transform_save(cfg: fdl.Config, output_path: Path, relative_dir: Path

 def _artifact_transform_load(cfg: fdl.Config, path: Path):
     for artifact in getattr(cfg.__fn_or_cls__, "__io_artifacts__", []):
+        # We expect artifact.attr to be a string or an fdl.Config.
+        # Some parameters can be either a string or a filepath. When such a parameter is just a string,
+        # it is represented with an fdl.Config, and the rest of the loop (base-dir adjustment) is skipped.
+        current_val = getattr(cfg, artifact.attr)
+        if isinstance(current_val, fdl.Config):
+            # artifact.attr is a string, not a path.
+            setattr(cfg, artifact.attr, fdl.build(current_val).attr)
+            continue
+
         if artifact.skip:
             continue
         current_val = getattr(cfg, artifact.attr)
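A sketch of the load-time half of the round trip (simplified; this mirrors the branch added above rather than reproducing NeMo's full load path):

```python
# Sketch of what the new isinstance branch does (simplified, for illustration).
import fiddle as fdl

from nemo.lightning.io.artifact.file import FileArtifact

# What a string-valued artifact looks like after the dump side runs:
marker = fdl.Config(FileArtifact, attr="plain-string-value", skip=True)

# The new branch rebuilds the marker and unwraps the original string;
# real filepaths skip this branch and get the base-dir adjustment instead.
restored = fdl.build(marker).attr
assert restored == "plain-string-value"
```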
8 changes: 4 additions & 4 deletions tests/collections/llm/megatron_t5_pretraining.py
@@ -59,16 +59,16 @@ def get_args():
         paths=args.data_path,
         seq_length=512,
         seq_length_dec=128,
-        micro_batch_size=64,
-        global_batch_size=512,
+        micro_batch_size=args.devices,
+        global_batch_size=2 * args.devices,
         seed=1234,
         tokenizer=tokenizer,
         split="99982,9,9",
         index_mapping_dir=args.index_mapping_dir,
     )
     t5_config = llm.t5.model.t5.T5Config(
-        num_layers=12,
-        encoder_num_layers=12,
+        num_layers=args.devices,
+        encoder_num_layers=args.devices,
         hidden_size=768,
         ffn_hidden_size=3072,
         num_attention_heads=12,
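A hedged reading of why the test's batch sizes now scale with the device count (this rationale is not stated in the PR): Megatron requires global_batch_size to be divisible by micro_batch_size times the data-parallel size, and the new values satisfy that for the small device counts a smoke test typically uses.

```python
# Hedged arithmetic check (assumes data-parallel size == args.devices; that
# mapping is an assumption, not stated in this PR).
for devices in (1, 2):
    micro = devices
    global_ = 2 * devices
    # gradient-accumulation steps = global / (micro * data-parallel size)
    steps, rem = divmod(global_, micro * devices)
    assert rem == 0, f"invalid batch config for {devices} devices"
    print(f"{devices} device(s): {steps} accumulation step(s)")
```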