From 17df8c502f9024005a854381081822f876aae51c Mon Sep 17 00:00:00 2001 From: SoniaGrh Date: Thu, 26 Sep 2024 15:44:14 +0200 Subject: [PATCH] feat: add docstrings --- .../picsellia_datalake_processing_context.py | 89 ++++++++++++++++--- 1 file changed, 75 insertions(+), 14 deletions(-) diff --git a/src/models/contexts/processing/picsellia_datalake_processing_context.py b/src/models/contexts/processing/picsellia_datalake_processing_context.py index c94ca88b..9de045f7 100644 --- a/src/models/contexts/processing/picsellia_datalake_processing_context.py +++ b/src/models/contexts/processing/picsellia_datalake_processing_context.py @@ -13,6 +13,27 @@ class PicselliaDatalakeProcessingContext(PicselliaContext, Generic[TParameters]): + """ + A context class designed for handling Picsellia datalake processing jobs. + + This class extends `PicselliaContext` to manage the specific setup and execution + of a datalake processing job in Picsellia, including fetching job context, + input/output datalakes, and model versions. + + Attributes: + job_id (str): The job ID, either passed or fetched from environment variables. + job (picsellia.Job): The Picsellia job object initialized from the job ID. + job_type (str): The type of the job (e.g., pre-annotation). + job_context (dict): The context of the job containing model version, datalakes, and other details. + input_datalake (Datalake): The input datalake used in the job. + output_datalake (Optional[Datalake]): The output datalake (optional). + model_version (Optional[ModelVersion]): The model version associated with the job. + data_ids (list[UUID]): List of data IDs fetched from the job payload. + use_id (bool): A flag indicating whether to use data IDs. + download_annotations (bool): A flag indicating whether to download annotations. + processing_parameters (TParameters): The parameters used for the processing job. 
+ """ + def __init__( self, processing_parameters_cls: Type[TParameters], @@ -23,6 +44,21 @@ def __init__( use_id: Optional[bool] = True, download_annotations: Optional[bool] = True, ): + """ + Initializes the PicselliaDatalakeProcessingContext with parameters to run a processing job. + + Args: + processing_parameters_cls (Type[TParameters]): The class used to define the processing parameters. + api_token (Optional[str], optional): The API token to authenticate with Picsellia. Defaults to None. + host (Optional[str], optional): The host URL for the Picsellia platform. Defaults to None. + organization_id (Optional[str], optional): The organization ID within Picsellia. Defaults to None. + job_id (Optional[str], optional): The ID of the job to be processed. Defaults to None. + use_id (Optional[bool], optional): Whether to use data IDs in the processing job. Defaults to True. + download_annotations (Optional[bool], optional): Whether to download annotations for the datalake. Defaults to True. + + Raises: + ValueError: If the job ID is neither provided nor found in the environment variables. + """ super().__init__(api_token, host, organization_id) self.job_id = job_id or os.environ.get("job_id") @@ -66,6 +102,15 @@ def __init__( @property def model_version_id(self) -> Union[str, None]: + """ + Retrieves the model version ID if available, and ensures it is present for job types that require it. + + Returns: + Union[str, None]: The model version ID or None if not applicable. + + Raises: + ValueError: If the model version ID is missing when it is required for the job type. + """ if ( not self._model_version_id and self.job_type == ProcessingType.PRE_ANNOTATION @@ -77,6 +122,13 @@ def model_version_id(self) -> Union[str, None]: return self._model_version_id def to_dict(self) -> Dict[str, Any]: + """ + Converts the current processing context to a dictionary format for easier serialization.
+ + Returns: + Dict[str, Any]: A dictionary representation of the context, including job parameters, + datalake IDs, and processing parameters. + """ return { "context_parameters": { "host": self.host, @@ -93,32 +145,35 @@ def to_dict(self) -> Dict[str, Any]: } def _initialize_job_context(self) -> Dict[str, Any]: - """Initializes the context by fetching the necessary information from the job.""" + """ + Initializes the job context by synchronizing the job data from Picsellia. + + Returns: + Dict[str, Any]: The job context containing information such as input datalake, output datalake, + model version ID, and other processing parameters. + """ job_context = self.job.sync()["datalake_processing_job"] return job_context def _initialize_job(self) -> picsellia.Job: """ - Fetches the job from Picsellia using the job ID. - - The Job, in a Picsellia processing context, - is the entity that contains all the information needed to run a processing job. + Initializes and retrieves the job from Picsellia using the job ID. Returns: - The job fetched from Picsellia. + picsellia.Job: The initialized job object fetched from Picsellia. """ return self.client.get_job_by_id(self.job_id) def get_datalake(self, datalake_id: str) -> Datalake: """ - Fetches the datalake from Picsellia using the datalake ID. + Fetches a datalake from Picsellia using the datalake ID. - The Datalake, in a Picsellia processing context, - is the entity that contains all the data needed to process a model. + Args: + datalake_id (str): The ID of the datalake to fetch. Returns: - The datalake fetched from Picsellia. + Datalake: The datalake object fetched from Picsellia. """ return self.client.get_datalake(id=datalake_id) @@ -126,15 +181,21 @@ def get_model_version(self) -> ModelVersion: """ Fetches the model version from Picsellia using the model version ID. - The ModelVersion, in a Picsellia processing context, - is the entity that contains all the information needed to process a model. 
- Returns: - The model version fetched from Picsellia. + ModelVersion: The model version object fetched from Picsellia. """ return self.client.get_model_version_by_id(self.model_version_id) def get_data_ids(self) -> list[UUID]: + """ + Retrieves the list of data IDs from the job's payload presigned URL. + + Returns: + list[UUID]: A list of UUIDs representing the data IDs. + + Raises: + ValueError: If the payload presigned URL is missing from the job context. + """ if self._payload_presigned_url: payload = requests.get(self._payload_presigned_url).json() data_ids = payload["data_ids"]