feat: 110 initialize tests for the repo (#117)
picsalex authored Mar 26, 2024
1 parent f06baed commit f8abecf
Showing 19 changed files with 1,239 additions and 121 deletions.
2 changes: 2 additions & 0 deletions .coveragerc
@@ -0,0 +1,2 @@
+ [coverage:report]
+ skip_empty = true
3 changes: 3 additions & 0 deletions .gitignore
@@ -153,3 +153,6 @@ checkpoint

*/poetry.lock
yolox-detection-test/*
+
+ # MacOS files
+ .DS_Store
2 changes: 1 addition & 1 deletion src/__init__.py
@@ -1,2 +1,2 @@
from src.decorators.pipeline_decorator import Pipeline, pipeline # noqa
- from src.decorators.step_decorator import step # noqa
+ from src.decorators.step_decorator import Step, step # noqa
261 changes: 203 additions & 58 deletions src/decorators/pipeline_decorator.py
@@ -1,6 +1,6 @@
import ast
import inspect
- from typing import Optional, Union, Callable, TypeVar, Any, List
+ from typing import Optional, Union, Callable, TypeVar, Any, List, Tuple

from tabulate import tabulate

@@ -31,12 +31,13 @@ def __init__(
name: The name of the pipeline.
log_folder_path: The path to the log folder. If left empty, a temporary folder will be created.
remove_logs_on_completion: Whether to remove the logs on completion. Defaults to True.
- entrypoint: The entrypoint of the pipeline. This is the function that will be called when the pipeline is run.
+ entrypoint: The entrypoint of the pipeline.
+     This is the function that will be called when the pipeline is run.
"""
self._context = context
self.name = name
self.logger_manager = LoggerManager(
- pipeline_name=name, log_folder_path=log_folder_path
+ pipeline_name=name, log_folder_root_path=log_folder_path
)
self.remove_logs_on_completion = remove_logs_on_completion
self.entrypoint = entrypoint
@@ -48,14 +49,34 @@ def __init__(
self.initialization_log_file_path = None

def __call__(self, *args, **kwargs) -> Any:
"""Handles the pipeline call.
This method first analyses and registers the steps of the pipeline.
Then, it configures the logging, flags the pipeline as running,
logs the pipeline context, and runs the entrypoint.
Args:
*args: Arguments to be passed to the entrypoint function.
**kwargs: Keyword arguments to be passed to the entrypoint function.
Returns:
The outputs of the entrypoint function call.
"""
with self:
self._analyze_and_register_steps()
self.finalize_initialization()
self.log_pipeline_context()
self._scan_steps()
self._configure_logging()
self._flag_pipeline(state=PipelineState.RUNNING)
self._log_pipeline_context()

return self.entrypoint(*args, **kwargs)
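
For orientation, a minimal usage sketch of this call flow (hypothetical step and pipeline definitions; the `step`/`pipeline` decorators are re-exported from src/__init__.py, but their exact signatures are not shown in this diff):

from src import pipeline, step

@step
def load_data() -> list:
    return [1, 2, 3]

@step
def train_model(data: list) -> float:
    return sum(data) / len(data)

# Assumed decorator signature, mirroring Pipeline.__init__ above.
@pipeline(name="demo", context={"hyperparameters": {"batch_size": 32}})
def training_pipeline():
    data = load_data()
    return train_model(data)

training_pipeline()  # scans the entrypoint, configures logging, then runs it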

def __enter__(self):
"""Activate the pipeline context.
Raises:
RuntimeError: If another pipeline is already active.
Typically, occurs when a pipeline is run within another pipeline.
"""
if Pipeline.ACTIVE_PIPELINE is not None:
raise RuntimeError(
"Another pipeline is already active."
@@ -64,20 +85,34 @@ def __enter__(self):

Pipeline.ACTIVE_PIPELINE = self

- def __exit__(self, exc_type, exc_val, exc_tb):
+ def __exit__(self, *args: Any):
+     """Deactivates the pipeline context and cleans the log folder if requested.
+     Args:
+         *args: The arguments passed to the context exit handler.
+     """
Pipeline.ACTIVE_PIPELINE = None

if self.remove_logs_on_completion:
self.logger_manager.clean()

+ @property
+ def is_initialized(self) -> bool:
+     return self._is_pipeline_initialized

@property
def state(self) -> PipelineState:
"""
The state of the pipeline is determined by the state of its steps.
"""The state of the pipeline.
The pipeline's state is determined by the states of its steps.
The pipeline can be in one of the following states:
- PENDING: All the steps are pending.
- RUNNING: At least one step is running, and no step has failed.
- SUCCESS: All the steps have succeeded.
- FAILED: At least one step has failed and no step has succeeded after it. By default, a step is skipped if a
previous step has failed. This behavior can be changed by setting the `continue_on_failure` parameter to
`True` when defining a step.
- PARTIAL_SUCCESS: At least one step has succeeded after a failed step.
Returns:
The state of the pipeline.
"""
if all(
step_metadata.state == StepState.PENDING
Expand All @@ -86,17 +121,17 @@ def state(self) -> PipelineState:
return PipelineState.PENDING

elif all(
- step_metadata.state
- in [StepState.RUNNING, StepState.PENDING, StepState.SUCCESS]
+ step_metadata.state == StepState.SUCCESS
for step_metadata in self.steps_metadata
):
- return PipelineState.RUNNING
+ return PipelineState.SUCCESS

elif all(
- step_metadata.state == StepState.SUCCESS
+ step_metadata.state
+ in [StepState.RUNNING, StepState.PENDING, StepState.SUCCESS]
for step_metadata in self.steps_metadata
):
- return PipelineState.SUCCESS
+ return PipelineState.RUNNING

else:
for step_metadata in self.steps_metadata:
@@ -110,19 +145,38 @@ def state(self) -> PipelineState:
return PipelineState.PARTIAL_SUCCESS
return PipelineState.FAILED

+ raise RuntimeError(
+     "Pipeline state could not be determined."
+     "You need at least one step to determine the pipeline state.",
+ )
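
To make the precedence of these checks concrete, here is a condensed, hypothetical restatement of the same rules as a standalone function (not part of the diff; `StepState` and `PipelineState` are the enums used in this module, and at least one step is assumed, as in the property above):

from typing import List

def resolve_pipeline_state(step_states: List[StepState]) -> PipelineState:
    # Checked in the same order as the property above.
    if all(s == StepState.PENDING for s in step_states):
        return PipelineState.PENDING
    if all(s == StepState.SUCCESS for s in step_states):
        return PipelineState.SUCCESS
    if all(s in (StepState.RUNNING, StepState.PENDING, StepState.SUCCESS) for s in step_states):
        return PipelineState.RUNNING
    # A step failed: any success after the first failure counts as partial success.
    failed_seen = False
    for s in step_states:
        if s == StepState.FAILED:
            failed_seen = True
        elif failed_seen and s == StepState.SUCCESS:
            return PipelineState.PARTIAL_SUCCESS
    return PipelineState.FAILED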

@property
def steps_metadata(self) -> List[StepMetadata]:
"""All the pipeline's steps' metadata.
Returns:
All the pipeline's steps' metadata.
"""
return self._registered_steps_metadata

- def finalize_initialization(self) -> None:
-     self.logger_manager.configure(steps_metadata=self.steps_metadata)
-     self._is_pipeline_initialized = True
-     self._state = PipelineState.RUNNING
+ def log_pipeline_info(self, log_content: str) -> None:
+     """Log the provided content inside the pipeline log file.
+     Args:
+         log_content: The content to log.
+     """
+     self.logger_manager.logger.info(f"{log_content}")

+ def register_active_step_metadata(self, step_metadata: StepMetadata) -> None:
+     """Register the metadata of a step found during the pipeline scan.
+     Args:
+         step_metadata: The metadata of the step to register.
+     """
+     self._registered_steps_metadata.append(step_metadata)

+ def _configure_logging(self) -> None:
+     """Configures the logging for the pipeline.
+     This method configures the pipeline logger, the pipeline's dedicated log file and prepares the logger if
+     the pipeline needs to log something before the first step is run.
+     """
+     self.logger_manager.configure_log_files(steps_metadata=self.steps_metadata)

self.initialization_log_file_path = (
self.logger_manager.configure_pipeline_initialization_log_file()
@@ -131,53 +185,68 @@ def finalize_initialization(self) -> None:
log_file_path=self.initialization_log_file_path
)

- def log_pipeline_context(self):
+ def _flag_pipeline(self, state: PipelineState) -> None:
+     """Flags the pipeline with the provided state.
+     Args:
+         state: The state to flag the pipeline with.
+     """
+     self._state = state

+ def _log_pipeline_context(self):
+     """Log the pipeline context.
+     This method logs an introduction sentence, followed by the pipeline's context.
+     - If the context exposes a `to_dict` method, then it will be called to convert the context to a dictionary.
+     - If the context is already a dictionary, it will be used as is.
+     - If the context is neither a dictionary nor exposes a `to_dict` method, then it will be ignored.
+     Typically, a context would look like this: `{'hyperparameters': {'learning_rate': 0.01, 'batch_size': 32}}`.
+     In this case, the context will be printed as a Markdown table with the following format:
+     | hyperparameters | values |
+     |-------------------|----------|
+     | learning_rate | 0.01 |
+     | batch_size | 32 |
+     """
self.log_pipeline_info(
f"Pipeline \033[94m{self.name}\033[0m is starting with the following context:"
)

- # Normalize the context to always be a dictionary
- if hasattr(self._context, "to_dict") and callable(
-     getattr(self._context, "to_dict")
- ):
-     context_dict = self._context.to_dict()
- elif isinstance(self._context, dict):
-     context_dict = self._context
- else:
-     self.log_pipeline_info(
-         "Cannot print the context. It should be a dictionary or an object with a `to_dict() -> dict` method."
-     )
-     return
+ context_dict = self._parse_context_to_dict(self._context)

- # Separate flat parameters from nested ones
- flat_parameters = {}
- nested_parameters = {}
-
- for key, value in context_dict.items():
-     if isinstance(value, dict):
-         nested_parameters[key] = value
-     else:
-         flat_parameters[key] = value
+ flat_parameters, nested_parameters = self._extract_parameters_from_context_dict(
+     context_dict
+ )

# Log flat parameters
if flat_parameters:
- markdown_table = self._get_markdown_table("parameters", flat_parameters)
+ markdown_table = self._compute_markdown_table("parameters", flat_parameters)
self.log_pipeline_info(f"{markdown_table}\n")

# Log each nested dictionary under its key
for key, nested_dict in nested_parameters.items():
- markdown_table = self._get_markdown_table(key, nested_dict)
+ markdown_table = self._compute_markdown_table(key, nested_dict)
self.log_pipeline_info(f"{markdown_table}\n")
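
As a concrete illustration of the rendering described in `_log_pipeline_context`, `tabulate` (imported at the top of this module) produces the GitHub-flavored table directly from the dictionary items — a standalone snippet, not part of the diff:

from tabulate import tabulate

hyperparameters = {"learning_rate": 0.01, "batch_size": 32}
print(tabulate(hyperparameters.items(), headers=["hyperparameters", "values"], tablefmt="github"))
# Output (approximately):
# | hyperparameters   |   values |
# |-------------------|----------|
# | learning_rate     |     0.01 |
# | batch_size        |       32 |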

- def log_pipeline_info(self, log_content: str) -> None:
-     self.logger_manager.logger.info(f"{log_content}")
- def register_active_step_metadata(self, step_metadata: StepMetadata) -> None:
-     self._registered_steps_metadata.append(step_metadata)
+ def _scan_steps(self) -> None:
+     """Analyze the pipeline entrypoint function to identify and register step calls.
+     The pipeline is scanned using the `inspect` module to extract the source code of the entrypoint function.
+     Each node is matched against the global STEPS_REGISTRY to identify the steps that are called.
+     Raises:
+         ValueError: If the provided entrypoint cannot be scanned.
+     """
+     try:
+         src = inspect.getsource(self.entrypoint)
+     except TypeError as e:
+         raise ValueError(
+             f"The provided entrypoint cannot be scanned: {str(e)}"
+         ) from e

- def _analyze_and_register_steps(self) -> None:
-     """Analyze the pipeline entrypoint function to identify and register step calls."""
-     src = inspect.getsource(self.entrypoint)
tree = ast.parse(src)

pipeline_instance = self
@@ -195,7 +264,17 @@ def visit_Call(self, node):
visitor = StepCallVisitor()
visitor.visit(tree)
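
The scanning technique itself — pull the entrypoint's source with `inspect`, parse it with `ast`, and match call names against a registry — can be sketched in isolation like this (hypothetical registry contents; `textwrap.dedent` added so nested function definitions also parse):

import ast
import inspect
import textwrap

STEPS_REGISTRY = {"load_data", "train_model"}  # hypothetical step names

def find_step_calls(entrypoint) -> list:
    """Return registry-known function names called inside `entrypoint`."""
    tree = ast.parse(textwrap.dedent(inspect.getsource(entrypoint)))
    found = []

    class StepCallVisitor(ast.NodeVisitor):
        def visit_Call(self, node: ast.Call) -> None:
            # Plain calls such as `load_data()` expose their name via an ast.Name node.
            if isinstance(node.func, ast.Name) and node.func.id in STEPS_REGISTRY:
                found.append(node.func.id)
            self.generic_visit(node)

    StepCallVisitor().visit(tree)
    return found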

- def _get_markdown_table(self, category: str, context: dict) -> str:
+ def _compute_markdown_table(self, category: str, context: dict) -> str:
+     """Format a dictionary as a Markdown table with two columns: keys and values.
+     Args:
+         category: The category of the provided context. This will be the header of the first column.
+             For example, "hyperparameters", "augmentation_parameters", etc.
+         context: The context to format.
+     Returns:
+         The context formatted as a Markdown table.
+     """
headers = (
[category, "values"]
if isinstance(context, dict)
@@ -212,8 +291,61 @@ def _get_markdown_table(self, category: str, context: dict) -> str:
tablefmt="github",
)

+ def _extract_parameters_from_context_dict(self, context: dict) -> Tuple[dict, dict]:
+     """Extracts flat parameters and nested parameters from a context dictionary.
+     For example, given the following context: `{'nested': {'learning_rate': 0.01}, 'flat': "value"}`,
+     the output will be:
+     - flat_parameters: `{'flat': "value"}`
+     - nested_parameters: `{'nested': {'learning_rate': 0.01}}`
+     Args:
+         context: The context dictionary to extract the parameters from.
+     Returns:
+         A tuple containing the flat parameters and the nested parameters.
+     """
+     flat_parameters = {}
+     nested_parameters = {}
+
+     for key, value in context.items():
+         if isinstance(value, dict):
+             nested_parameters[key] = value
+         else:
+             flat_parameters[key] = value
+
+     return flat_parameters, nested_parameters
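
A quick worked example of the split, as a hypothetical REPL session on a constructed pipeline instance `p`:

>>> p._extract_parameters_from_context_dict(
...     {"nested": {"learning_rate": 0.01}, "flat": "value"}
... )
({'flat': 'value'}, {'nested': {'learning_rate': 0.01}})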

+ def _parse_context_to_dict(self, context: Any) -> Optional[dict]:
+     """Parse the context to a dictionary.
+     This method only works if the context exposes a `to_dict` method or is already a dictionary.
+     Args:
+         context: The context to parse.
+     Returns:
+         The context as a dictionary.
+     """
+     if hasattr(context, "to_dict") and callable(getattr(context, "to_dict")):
+         return context.to_dict()
+     elif isinstance(context, dict):
+         return context
+     else:
+         return None
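
Either form satisfies this parser; for instance, a dataclass-backed context (hypothetical, for illustration) exposes `to_dict` trivially:

from dataclasses import asdict, dataclass

@dataclass
class TrainingContext:
    learning_rate: float = 0.01
    batch_size: int = 32

    def to_dict(self) -> dict:
        return asdict(self)

# _parse_context_to_dict(TrainingContext()) and
# _parse_context_to_dict({"learning_rate": 0.01, "batch_size": 32})
# both yield {'learning_rate': 0.01, 'batch_size': 32}.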

@staticmethod
def get_active_context() -> Any:
"""Get the context of the currently running pipeline.
Returns:
The context of the currently running pipeline.
Raises:
RuntimeError: If no current pipeline is running.
RuntimeError: If no context has been set for the current pipeline.
"""
if Pipeline.ACTIVE_PIPELINE is None:
raise RuntimeError(
"No current pipeline running."
@@ -231,6 +363,19 @@ def get_active_context() -> Any:

@staticmethod
def register_step_metadata(step_metadata: StepMetadata) -> None:
"""Register a step metadata into the global steps' registry.
This method is only used to register the steps' metadata within the global steps registry,
typically when a step is defined. This registry will then be used during the pipeline scan to identify the steps
used inside.
Args:
step_metadata: The metadata of the step to register.
Raises:
ValueError: If a step with the same name has already been registered. The step names must be unique,
so two functions in two different modules cannot be decorated with @step if they have the same name.
"""
step_name = step_metadata.name

if step_name in Pipeline.STEPS_REGISTRY: