diff --git a/src/instructlab/training/__init__.py b/src/instructlab/training/__init__.py index 5319e0c6..151989a6 100644 --- a/src/instructlab/training/__init__.py +++ b/src/instructlab/training/__init__.py @@ -6,7 +6,7 @@ "QuantizeDataType", "TorchrunArgs", "TrainingArgs", - "run_training", # pylint: disable=undefined-all-variable + "run_training", ) # Local @@ -21,18 +21,10 @@ ) -def __dir__(): - return globals().keys() | {"run_training"} +# defer import of main_ds +def run_training(torch_args: TorchrunArgs, train_args: TrainingArgs) -> None: + """Wrapper around the main training job that calls torchrun.""" + # Local + from .main_ds import run_training - -def __getattr__(name): - # lazy import run_training - if name == "run_training": - # pylint: disable=global-statement,import-outside-toplevel - global run_training - # Local - from .main_ds import run_training - - return run_training - - raise AttributeError(f"module {__name__!r} has no attribute {name!r}") + return run_training(torch_args=torch_args, train_args=train_args) diff --git a/src/instructlab/training/config.py b/src/instructlab/training/config.py index db50c397..6d2ed50d 100644 --- a/src/instructlab/training/config.py +++ b/src/instructlab/training/config.py @@ -10,6 +10,7 @@ from pydantic import BaseModel, ConfigDict, Field +# public API class DeepSpeedOffloadStrategy(Enum): """ Defines the offload strategy for DeepSpeed. @@ -24,6 +25,7 @@ class DeepSpeedOffloadStrategy(Enum): NONE = None +# public API class QuantizeDataType(Enum): """ Defines what datatype we use during quantization. @@ -34,6 +36,7 @@ class QuantizeDataType(Enum): NONE = None +# public API class DataProcessArgs(BaseModel): """ All the arguments consumed by the training data pre-process script. @@ -49,6 +52,7 @@ class DataProcessArgs(BaseModel): model_config = ConfigDict(protected_namespaces=()) +# public API class TorchrunArgs(BaseModel): """ Representation of the arguments being used by torchrun. @@ -63,6 +67,7 @@ class TorchrunArgs(BaseModel): rdzv_endpoint: str +# public API class LoraOptions(BaseModel): """ Options to specify when training using a LoRA. @@ -81,6 +86,7 @@ class Config: use_enum_values = True +# public API class DeepSpeedOptions(BaseModel): """ Represents the available options we support when training with the DeepSpeed optimizer. @@ -98,6 +104,7 @@ class DeepSpeedOptions(BaseModel): save_samples: int | None = None +# public API class TrainingArgs(BaseModel): """ This class represents the arguments being used by the training script. diff --git a/src/instructlab/training/main_ds.py b/src/instructlab/training/main_ds.py index 1fe082bf..f43100c8 100644 --- a/src/instructlab/training/main_ds.py +++ b/src/instructlab/training/main_ds.py @@ -556,7 +556,8 @@ def main(args): torch.distributed.destroy_process_group() -def run_training(torch_args: TorchrunArgs, train_args: TrainingArgs): +# public API +def run_training(torch_args: TorchrunArgs, train_args: TrainingArgs) -> None: """ Wrapper around the main training job that calls torchrun. """