Skip to content

Commit

Permalink
feat: add docstrings
Browse files Browse the repository at this point in the history
  • Loading branch information
SoniaGrh committed Nov 4, 2024
1 parent 33ca68f commit 17df8c5
Showing 1 changed file with 75 additions and 14 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,27 @@


class PicselliaDatalakeProcessingContext(PicselliaContext, Generic[TParameters]):
"""
A context class designed for handling Picsellia datalake processing jobs.
This class extends `PicselliaContext` to manage the specific setup and execution
of a datalake processing job in Picsellia, including fetching job context,
input/output datalakes, and model versions.
Attributes:
job_id (str): The job ID, either passed or fetched from environment variables.
job (picsellia.Job): The Picsellia job object initialized from the job ID.
job_type (str): The type of the job (e.g., pre-annotation).
job_context (dict): The context of the job containing model version, datalakes, and other details.
input_datalake (Datalake): The input datalake used in the job.
output_datalake (Optional[Datalake]): The output datalake (optional).
model_version (Optional[ModelVersion]): The model version associated with the job.
data_ids (list[UUID]): List of data IDs fetched from the job payload.
use_id (bool): A flag indicating whether to use data IDs.
download_annotations (bool): A flag indicating whether to download annotations.
processing_parameters (TParameters): The parameters used for the processing job.
"""

def __init__(
self,
processing_parameters_cls: Type[TParameters],
Expand All @@ -23,6 +44,21 @@ def __init__(
use_id: Optional[bool] = True,
download_annotations: Optional[bool] = True,
):
"""
Initializes the PicselliaDatalakeProcessingContext with parameters to run a processing job.
Args:
processing_parameters_cls (Type[TParameters]): The class used to define the processing parameters.
api_token (Optional[str], optional): The API token to authenticate with Picsellia. Defaults to None.
host (Optional[str], optional): The host URL for the Picsellia platform. Defaults to None.
organization_id (Optional[str], optional): The organization ID within Picsellia. Defaults to None.
job_id (Optional[str], optional): The ID of the job to be processed. Defaults to None.
use_id (Optional[bool], optional): Whether to use data IDs in the processing job. Defaults to True.
download_annotations (Optional[bool], optional): Whether to download annotations for the datalake. Defaults to True.
Raises:
ValueError: If the job ID is not provided or found in the environment variables.
"""
super().__init__(api_token, host, organization_id)

self.job_id = job_id or os.environ.get("job_id")
Expand Down Expand Up @@ -66,6 +102,15 @@ def __init__(

@property
def model_version_id(self) -> Union[str, None]:
"""
Retrieves the model version ID if available, and ensures it is required for certain job types.
Returns:
Union[str, None]: The model version ID or None if not applicable.
Raises:
ValueError: If the model version ID is missing when it is required for the job type.
"""
if (
not self._model_version_id
and self.job_type == ProcessingType.PRE_ANNOTATION
Expand All @@ -77,6 +122,13 @@ def model_version_id(self) -> Union[str, None]:
return self._model_version_id

def to_dict(self) -> Dict[str, Any]:
"""
Converts the current processing context to a dictionary format for easier serialization.
Returns:
Dict[str, Any]: A dictionary representation of the context, including job parameters,
datalake IDs, and processing parameters.
"""
return {
"context_parameters": {
"host": self.host,
Expand All @@ -93,48 +145,57 @@ def to_dict(self) -> Dict[str, Any]:
}

def _initialize_job_context(self) -> Dict[str, Any]:
    """
    Sync the job and return its datalake-processing section.

    Calls ``self.job.sync()`` and returns the entry stored under the
    ``"datalake_processing_job"`` key of the synced payload.

    Returns:
        Dict[str, Any]: The value found under ``"datalake_processing_job"``.
    """
    # A missing key propagates as the lookup error raised by the payload itself.
    synced_payload = self.job.sync()
    return synced_payload["datalake_processing_job"]

def _initialize_job(self) -> picsellia.Job:
    """
    Fetch the job identified by ``self.job_id`` through the Picsellia client.

    Returns:
        picsellia.Job: Whatever ``self.client.get_job_by_id`` returns for ``self.job_id``.
    """
    fetch_job = self.client.get_job_by_id
    return fetch_job(self.job_id)

def get_datalake(self, datalake_id: str) -> Datalake:
    """
    Look up a datalake by its ID through the Picsellia client.

    Args:
        datalake_id (str): The ID of the datalake to fetch; forwarded to the
            client as the ``id`` keyword argument.

    Returns:
        Datalake: The object returned by ``self.client.get_datalake``.
    """
    fetch_datalake = self.client.get_datalake
    return fetch_datalake(id=datalake_id)

def get_model_version(self) -> ModelVersion:
    """
    Look up the model version referenced by ``self.model_version_id``.

    The ID is read from the ``model_version_id`` attribute and passed to
    ``self.client.get_model_version_by_id``; any error raised while reading
    that attribute propagates to the caller.

    Returns:
        ModelVersion: The object returned by the client for that ID.
    """
    # Resolve the client method first, then the ID, mirroring the evaluation
    # order of a direct ``self.client.get_model_version_by_id(...)`` call.
    fetch_version = self.client.get_model_version_by_id
    return fetch_version(self.model_version_id)

def get_data_ids(self) -> list[UUID]:
"""
Retrieves the list of data IDs from the job's payload presigned URL.
Returns:
list[UUID]: A list of UUIDs representing the data IDs.
Raises:
ValueError: If the payload presigned URL is missing from the job context.
"""
if self._payload_presigned_url:
payload = requests.get(self._payload_presigned_url).json()
data_ids = payload["data_ids"]
Expand Down

0 comments on commit 17df8c5

Please sign in to comment.