Use deferred imports instead of lazy loading (#121)
PR #116 added lazy loading of the `run_training` function with a PEP 562
`__getattr__` module hook. Pylint does not understand the trick, which
caused some problems. Replace the lazy hook with a wrapper function and
deferred imports.

Also mark all public APIs with a comment, so developers are aware which
classes are used by other packages.

Signed-off-by: Christian Heimes <[email protected]>
tiran authored Jul 2, 2024
1 parent 506ce79 commit 8fa3cb8
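
For context, the PEP 562 hook that this commit removes has roughly the shape sketched below (generic placeholder names, not this repository's exact code). The name listed in `__all__` only comes into existence once the module-level `__getattr__` runs, so static analyzers such as Pylint report it as undefined and the hook needs suppression comments:

# Generic sketch of a PEP 562 lazy-loading hook; "heavy_function" and
# "heavy_module" are placeholder names, not names from this repository.
__all__ = ("heavy_function",)  # pylint: disable=undefined-all-variable


def __getattr__(name):
    # Only invoked when normal module attribute lookup fails, so the
    # expensive import is deferred until the name is first accessed.
    if name == "heavy_function":
        # pylint: disable=import-outside-toplevel
        from heavy_module import heavy_function

        return heavy_function
    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")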
Showing 3 changed files with 16 additions and 16 deletions.
22 changes: 7 additions & 15 deletions src/instructlab/training/__init__.py
@@ -6,7 +6,7 @@
     "QuantizeDataType",
     "TorchrunArgs",
     "TrainingArgs",
-    "run_training",  # pylint: disable=undefined-all-variable
+    "run_training",
 )

 # Local
@@ -21,18 +21,10 @@
 )


-def __dir__():
-    return globals().keys() | {"run_training"}
+# defer import of main_ds
+def run_training(torch_args: TorchrunArgs, train_args: TrainingArgs) -> None:
+    """Wrapper around the main training job that calls torchrun."""
+    # Local
+    from .main_ds import run_training

-
-def __getattr__(name):
-    # lazy import run_training
-    if name == "run_training":
-        # pylint: disable=global-statement,import-outside-toplevel
-        global run_training
-        # Local
-        from .main_ds import run_training
-
-        return run_training
-
-    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
+    return run_training(torch_args=torch_args, train_args=train_args)
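
With the wrapper above, downstream packages import only the names re-exported by `instructlab.training`, and the heavy `main_ds` module is imported inside `run_training` the first time it is called rather than at package-import time. A minimal caller-side sketch (the pydantic models are defined in config.py, whose full field lists are not part of this diff):

# Caller-side sketch: importing the package stays cheap because
# instructlab.training.main_ds is only imported when run_training is invoked.
from instructlab.training import TorchrunArgs, TrainingArgs, run_training


def launch(torch_args: TorchrunArgs, train_args: TrainingArgs) -> None:
    """Start a training run through the public wrapper."""
    # Keyword arguments match the wrapper signature in __init__.py above.
    run_training(torch_args=torch_args, train_args=train_args)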
7 changes: 7 additions & 0 deletions src/instructlab/training/config.py
@@ -10,6 +10,7 @@
 from pydantic import BaseModel, ConfigDict, Field


+# public API
 class DeepSpeedOffloadStrategy(Enum):
     """
     Defines the offload strategy for DeepSpeed.
@@ -24,6 +25,7 @@ class DeepSpeedOffloadStrategy(Enum):
     NONE = None


+# public API
 class QuantizeDataType(Enum):
     """
     Defines what datatype we use during quantization.
@@ -34,6 +36,7 @@ class QuantizeDataType(Enum):
     NONE = None


+# public API
 class DataProcessArgs(BaseModel):
     """
     All the arguments consumed by the training data pre-process script.
@@ -49,6 +52,7 @@ class DataProcessArgs(BaseModel):
     model_config = ConfigDict(protected_namespaces=())


+# public API
 class TorchrunArgs(BaseModel):
     """
     Representation of the arguments being used by torchrun.
@@ -63,6 +67,7 @@ class TorchrunArgs(BaseModel):
     rdzv_endpoint: str


+# public API
 class LoraOptions(BaseModel):
     """
     Options to specify when training using a LoRA.
@@ -81,6 +86,7 @@ class Config:
         use_enum_values = True


+# public API
 class DeepSpeedOptions(BaseModel):
     """
     Represents the available options we support when training with the DeepSpeed optimizer.
@@ -98,6 +104,7 @@ class DeepSpeedOptions(BaseModel):
     save_samples: int | None = None


+# public API
 class TrainingArgs(BaseModel):
     """
     This class represents the arguments being used by the training script.
3 changes: 2 additions & 1 deletion src/instructlab/training/main_ds.py
@@ -556,7 +556,8 @@ def main(args):
     torch.distributed.destroy_process_group()


-def run_training(torch_args: TorchrunArgs, train_args: TrainingArgs):
+# public API
+def run_training(torch_args: TorchrunArgs, train_args: TrainingArgs) -> None:
     """
     Wrapper around the main training job that calls torchrun.
     """
