From 26c9ef01f763fb9bcc5319f8cfc61cf1a88f7b09 Mon Sep 17 00:00:00 2001 From: Francesco Vertemati <63065831+vertefra@users.noreply.github.com> Date: Thu, 29 Aug 2024 11:44:32 -0400 Subject: [PATCH] Feature/pytorch on the cloud training status 1 wandb (#50) * update structure * update data uploader * closing event loop * remove unused * fix async calls * format test fix * Update documentation * remove reference to base client * add docstring * format * add test * chain of resp * lint * fix * add newline * validation handler * linting * adding uploader handler * factory function * fix test error * commit before sync * training flow * create standalone handler for params upload * uploader * update error on default allowed modules * update training handle * lint and format * training handle and test * revert * checking out * better implementation for the pytorch agent * address all the comments * fix types * update setters for context and tools * remove custom exception * remove forward ref * update returning type * chekcing out * fix comments * fix undefined chain * update operator * checkout * removing unused if statements * checkout * fix docs requirements * newline * refactor of discovery and validation * fix third party modules discovery * checkout * adding installation of dependencies * format * change log level for pytest * adding unit tests * update comments * update dependendecies * add newline * add wandb * remove obsolete package * addressing comments * adding newline * update installation * update deepcopy * remove comment * lint fix * adding trace name * update installation command * wandb version constraints * more fix --- .../actions/setup_lint_and_test/action.yml | 2 +- Makefile | 3 + docs/requirements.txt | 2 + pyproject.toml | 7 + qcog_python_client/monitor/__init__.py | 13 ++ qcog_python_client/monitor/_wandb.py | 73 +++++++ qcog_python_client/monitor/interface.py | 23 +++ qcog_python_client/qcog/_baseclient.py | 57 ++++-- qcog_python_client/qcog/_httpclient.py | 11 ++ qcog_python_client/qcog/_interfaces.py | 4 +- .../qcog/pytorch/discover/discoverhandler.py | 172 ++++++++++++---- .../qcog/pytorch/discover/types.py | 10 + .../qcog/pytorch/discover/utils.py | 8 + qcog_python_client/qcog/pytorch/handler.py | 4 +- qcog_python_client/qcog/pytorch/types.py | 58 ++++++ .../qcog/pytorch/upload/uploadhandler.py | 46 +---- .../qcog/pytorch/upload/utils.py | 46 +++++ qcog_python_client/qcog/pytorch/utils.py | 82 ++++++++ .../qcog/pytorch/validate/__init__.py | 1 - .../pytorch/validate/_setup_monitor_import.py | 142 ++++++++++++++ ...te_module.py => _validate_model_module.py} | 64 +++--- .../qcog/pytorch/validate/utils.py | 86 ++++++++ .../qcog/pytorch/validate/validate_utils.py | 83 -------- .../qcog/pytorch/validate/validatehandler.py | 77 +++++--- tests/pytorch_model/model.py | 14 ++ tests/pytorch_model/ubiops_train_pt.py | 184 ------------------ tests/test_pytorch_agent.py | 39 +--- .../pytorch/discover/test_discover_handler.py | 170 ++++++++++++++++ .../validate/test_setup_monitor_import.py | 174 +++++++++++++++++ .../unit/qcog/pytorch/validate/test_utils.py | 42 ++++ 30 files changed, 1242 insertions(+), 455 deletions(-) create mode 100644 qcog_python_client/monitor/__init__.py create mode 100644 qcog_python_client/monitor/_wandb.py create mode 100644 qcog_python_client/monitor/interface.py create mode 100644 qcog_python_client/qcog/pytorch/discover/types.py create mode 100644 qcog_python_client/qcog/pytorch/discover/utils.py create mode 100644 
qcog_python_client/qcog/pytorch/types.py create mode 100644 qcog_python_client/qcog/pytorch/upload/utils.py create mode 100644 qcog_python_client/qcog/pytorch/utils.py create mode 100644 qcog_python_client/qcog/pytorch/validate/_setup_monitor_import.py rename qcog_python_client/qcog/pytorch/validate/{_validate_module.py => _validate_model_module.py} (72%) create mode 100644 qcog_python_client/qcog/pytorch/validate/utils.py delete mode 100644 qcog_python_client/qcog/pytorch/validate/validate_utils.py delete mode 100644 tests/pytorch_model/ubiops_train_pt.py create mode 100644 tests/unit/qcog/pytorch/discover/test_discover_handler.py create mode 100644 tests/unit/qcog/pytorch/validate/test_setup_monitor_import.py create mode 100644 tests/unit/qcog/pytorch/validate/test_utils.py diff --git a/.github/actions/setup_lint_and_test/action.yml b/.github/actions/setup_lint_and_test/action.yml index d97e4c6..f54905f 100644 --- a/.github/actions/setup_lint_and_test/action.yml +++ b/.github/actions/setup_lint_and_test/action.yml @@ -9,7 +9,7 @@ runs: with: python-version: '3.12.4' - name: install dev dependencies - run: pip install ".[dev]" + run: pip install .[dev,wandb] shell: bash - name: run linting run: make lint-check diff --git a/Makefile b/Makefile index 36a3e2f..8a40df1 100644 --- a/Makefile +++ b/Makefile @@ -25,3 +25,6 @@ lint-write: schema-build: python schema.py + +test-unit: + pytest -v tests/unit diff --git a/docs/requirements.txt b/docs/requirements.txt index 364d903..df3b466 100644 --- a/docs/requirements.txt +++ b/docs/requirements.txt @@ -8,3 +8,5 @@ typing_extensions pydantic pydantic_settings datamodel-code-generator +anyio +wandb diff --git a/pyproject.toml b/pyproject.toml index fc34962..1452b06 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -20,6 +20,7 @@ dependencies = [ "pydantic-settings", "anyio", ] + dynamic = ["version"] [project.optional-dependencies] @@ -42,6 +43,9 @@ examples = [ "torch", "pillow", ] +wandb = [ + "wandb>=0.17.7", +] [tool.setuptools_scm] version_file = "qcog_python_client/__version__.py" @@ -123,3 +127,6 @@ exclude = [ "**/__init__.py" = ["F401", "D104"] # Remove D rules from tests "**/tests/**" = ["D"] + +[tool.pytest.ini_options] +log_cli=true diff --git a/qcog_python_client/monitor/__init__.py b/qcog_python_client/monitor/__init__.py new file mode 100644 index 0000000..c601def --- /dev/null +++ b/qcog_python_client/monitor/__init__.py @@ -0,0 +1,13 @@ +from typing import Literal, TypeAlias + +from ._wandb import WandbMonitor + +Service: TypeAlias = Literal["wandb"] + + +def get_monitor(service: Service) -> WandbMonitor: + """Return the monitoring service.""" + if service == "wandb": + return WandbMonitor() + + raise ValueError(f"Unknown service: {service}") diff --git a/qcog_python_client/monitor/_wandb.py b/qcog_python_client/monitor/_wandb.py new file mode 100644 index 0000000..e0813c9 --- /dev/null +++ b/qcog_python_client/monitor/_wandb.py @@ -0,0 +1,73 @@ +"""Wandb Monitor implementation.""" + +import os + +import wandb +import wandb.sdk as wandbsdk + +from .interface import Monitor + +WANDB_DEFAULT_PROJECT = "qognitive-dev" + + +class WandbMonitor(Monitor): + """Wandb Monitor implementation.""" + + def init( # noqa: D417 # Complains about parameters description not being present + self, + api_key: str | None = None, + project: str = WANDB_DEFAULT_PROJECT, + parameters: dict | None = None, + labels: list[str] | None = None, + trace_name: str | None = None, + ) -> None: + """Initialize the Wandb Monitor. 
+ + Parameters + ---------- + api_key : str | None + Wandb API key. + project : str | None + Name of the project. + parameters : dict | None + Hyperparameters to be logged. + labels : list | None + Tags to be associated with the project. + + Raises + ------ + ValueError: If the Wandb API key is not provided. + + """ + key = api_key or os.getenv("WANDB_API_KEY") + + if not key: + raise ValueError( + "Wandb API key is required. Please provide a key, ether as an argument or as an environment variable -> WANDB_API_KEY" # noqa + ) + + wandbsdk.login( + anonymous="never", + key=key, + ) + + wandbsdk.init( + project=project, + config=parameters, + tags=labels, + name=trace_name, + ) + + def log(self, data: dict) -> None: # noqa: D417 # Complains about parameters description not being present + """Log data to Wandb. + + Parameters + ---------- + data : dict Data to be logged. + + """ + wandb.log(data) + + def close(self) -> None: + """Close the Wandb monitor.""" + wandbsdk.finish() diff --git a/qcog_python_client/monitor/interface.py b/qcog_python_client/monitor/interface.py new file mode 100644 index 0000000..961f0e8 --- /dev/null +++ b/qcog_python_client/monitor/interface.py @@ -0,0 +1,23 @@ +"""Define the interface for a monitoring service.""" + +from abc import ABC, abstractmethod +from typing import Any + + +class Monitor(ABC): + """Define the interface for a monitoring service.""" + + @abstractmethod + def init(self, *args: Any, **kwargs: Any) -> Any: + """Define the initialization method for the monitoring service.""" + ... + + @abstractmethod + def log(self, *args: Any, **kwargs: Any) -> Any: + """Define the logging method for the monitoring service.""" + ... + + @abstractmethod + def close(self) -> Any: + """Define the close method for the monitoring service.""" + ... 
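A minimal usage sketch of the new monitor API from inside a user's training
module (it mirrors the wandb calls added to tests/pytorch_model/model.py later
in this patch); the project name, hyperparameter values, and metric values
below are illustrative only, and init() falls back to the WANDB_API_KEY
environment variable when no api_key is passed:

    from qcog_python_client import monitor

    # Resolve the wandb-backed implementation via the factory.
    m_service = monitor.get_monitor("wandb")

    # Initialize a run; labels and trace_name are optional.
    m_service.init(
        project="qognitive-dev",
        parameters={"epochs": 10},
        trace_name="example-run",
    )

    for epoch in range(10):
        loss = 1.0 / (epoch + 1)  # placeholder metric for the sketch
        m_service.log({"loss": loss, "epoch": epoch})

    m_service.close()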
diff --git a/qcog_python_client/qcog/_baseclient.py b/qcog_python_client/qcog/_baseclient.py index c8a901d..91310bd 100644 --- a/qcog_python_client/qcog/_baseclient.py +++ b/qcog_python_client/qcog/_baseclient.py @@ -67,6 +67,8 @@ def __init__(self) -> None: # noqa: D107 self._inference_result: dict | None = None self._loss: Matrix | None = None self._pytorch_model: dict | None = None + self.last_status: TrainingStatus | None = None + self.metrics: dict | None = None @property def pytorch_model(self) -> dict: @@ -153,9 +155,18 @@ def trained_model(self) -> dict: @trained_model.setter def trained_model(self, value: dict) -> None: """Set and validate the trained model.""" - self._trained_model = AppSchemasTrainTrainedModelPayloadResponse.model_validate( - value - ).model_dump() + if self.model.model_name == Model.pytorch.value: + self._trained_model = ( + AppSchemasPytorchModelPytorchTrainedModelPayloadResponse.model_validate( + value + ).model_dump() + ) + else: + self._trained_model = ( + AppSchemasTrainTrainedModelPayloadResponse.model_validate( + value + ).model_dump() + ) @property def inference_result(self) -> dict: @@ -370,18 +381,7 @@ async def _train_pytorch( ) ) - generic_trained_model = AppSchemasTrainTrainedModelPayloadResponse( - qcog_version=self.pytorch_model["model_name"], - guid=pytorch_trained_model.guid, - dataset_guid=pytorch_trained_model.dataset_guid, - training_parameters_guid=pytorch_trained_model.training_parameters_guid, - status=TrainingStatus(pytorch_trained_model.status), - loss=None, - training_completion=0, - current_batch_completion=0, - ) - - self.trained_model = generic_trained_model.model_dump() + self.trained_model = pytorch_trained_model.model_dump() return self ############################ @@ -400,6 +400,10 @@ async def _progress(self) -> dict: `status` : TrainingStatus """ + if self.model.model_name == Model.pytorch.value: + logger.warning("Progress is not available for PyTorch models.") + return {} + await self._load_trained_model() return { "guid": self.trained_model.get("guid"), @@ -412,11 +416,30 @@ async def _progress(self) -> dict: async def _load_trained_model(self) -> None: """Load the status of the current trained model.""" + if self.model.model_name == Model.pytorch.value: + raise ValueError("Load trained model is not available for PyTorch models.") + self.trained_model = await self.http_client.get( f"model/{self.trained_model['guid']}" ) async def _status(self) -> TrainingStatus: + if self.model.model_name == Model.pytorch.value: + return await self._get_pt_trained_model_status() + return await self._get_trained_model_status() + + async def _get_pt_trained_model_status(self) -> TrainingStatus: + """Retrieve a PyTorch trained model status.""" + pt_model_guid = self.trained_model["pytorch_model_guid"] + trained_model_guid = self.trained_model["guid"] + response = await self.http_client.get( + f"pytorch_model/{pt_model_guid}/trained_model/{trained_model_guid}" + ) + self.metrics = response.get("metrics", None) + self.last_status = TrainingStatus(response["status"]) + return self.last_status + + async def _get_trained_model_status(self) -> TrainingStatus: """Check the status of the training job.""" # Load last status await self._load_trained_model() @@ -437,10 +460,14 @@ async def _get_loss(self) -> Matrix | None: Loss matrix is available only after training is completed. 
""" + if self.model.model_name == Model.pytorch.value: + raise ValueError("Loss matrix is not available for PyTorch models.") + # loss matrix is available only after training is completed if self._loss is None: await self._load_trained_model() self._loss = self.trained_model.get("loss", None) + return self._loss async def _wait_for_training(self, poll_time: int = 60) -> None: diff --git a/qcog_python_client/qcog/_httpclient.py b/qcog_python_client/qcog/_httpclient.py index b52f903..7681f04 100644 --- a/qcog_python_client/qcog/_httpclient.py +++ b/qcog_python_client/qcog/_httpclient.py @@ -11,6 +11,7 @@ import urllib3 import urllib3.util +from qcog_python_client.log import qcoglogger as logger from qcog_python_client.qcog._interfaces import ABCRequestClient @@ -109,6 +110,10 @@ async def _request_retry( is_data = isinstance(data, aiohttp.FormData) is_json = isinstance(data, dict) + logger.debug(f"Requesting {uri} with {method} method") + logger.debug(f"is_data: {is_data}") + logger.debug(f"is_json: {is_json}") + for _ in range(self.retries): try: async with aiohttp.ClientSession( @@ -130,6 +135,12 @@ async def _request_retry( json=data, ) + elif data is None and method == HttpMethod.get: + resp = await session.request( + method.value, + uri, + ) + else: raise ValueError(f"Invalid Content Type found: {type(data)}") diff --git a/qcog_python_client/qcog/_interfaces.py b/qcog_python_client/qcog/_interfaces.py index 2243dc5..88ef7ba 100644 --- a/qcog_python_client/qcog/_interfaces.py +++ b/qcog_python_client/qcog/_interfaces.py @@ -1,5 +1,5 @@ from abc import ABC, abstractmethod -from typing import Literal, overload +from typing import overload import aiohttp import pandas as pd @@ -28,8 +28,6 @@ async def post( self, url: str, data: dict | aiohttp.FormData, - *, - content_type: Literal["json", "data"] = "json", ) -> dict: """Execute a post request.""" ... diff --git a/qcog_python_client/qcog/pytorch/discover/discoverhandler.py b/qcog_python_client/qcog/pytorch/discover/discoverhandler.py index 65edc23..8c5fbbe 100644 --- a/qcog_python_client/qcog/pytorch/discover/discoverhandler.py +++ b/qcog_python_client/qcog/pytorch/discover/discoverhandler.py @@ -1,24 +1,88 @@ -"""Discover the module and the model.""" +"""Discover the module and the model. +Finds the folder, convert the folder into a dictionary +Create a `relevant_files` dictionary that contains the +relevant files for the model. that will be validated later. 
+""" + +from __future__ import annotations + +import ast +import asyncio +import io import os +from typing import ( + Iterable, +) from anyio import open_file -from qcog_python_client.qcog.pytorch.handler import BoundedCommand, Command, Handler -from qcog_python_client.qcog.pytorch.validate.validatehandler import ValidateCommand +from qcog_python_client.qcog.pytorch import utils +from qcog_python_client.qcog.pytorch.discover.types import IsRelevantFile +from qcog_python_client.qcog.pytorch.discover.utils import pkg_name +from qcog_python_client.qcog.pytorch.handler import Command, Handler +from qcog_python_client.qcog.pytorch.types import ( + Directory, + DiscoverCommand, + QFile, + RelevantFileId, + RelevantFiles, + ValidateCommand, +) + + +async def _is_model_module(self: DiscoverHandler, file: QFile) -> bool: + """Check if the file is the model module.""" + module_name = os.path.basename(file.path) + return module_name == self.model_module_name + + +async def _is_service_import_module(self: DiscoverHandler, file: QFile) -> bool: + """Check if the file is importing the monitor service.""" + # Make sure the item is not a folder. If so, exit + if os.path.isdir(file.path): + return False + + tree = ast.parse(file.content.read()) + file.content.seek(0) + + for node in ast.walk(tree): + if isinstance(node, ast.ImportFrom): + # We assume that the only way that the monitor service + # will be imported is like + # from qcog_python_client import monitor or eventually + # from qcog_python_client import monitor as . + if node.module == "qcog_python_client": + # We only support a single monitor import from qcog_python_client + # for now. + if len(node.names) > 1: + raise ValueError( + "You cannot import anything from qcog_python_client other than monitor." # noqa: E501 + ) + + return node.names[0].name == "monitor" + + return False -class DiscoverCommand(BoundedCommand): - """Payload to dispatch a discover command.""" +relevant_files_map: dict[RelevantFileId, IsRelevantFile] = { + "model_module": _is_model_module, # type: ignore + "monitor_service_import_module": _is_service_import_module, # type: ignore +} - model_name: str - model_path: str - command: Command = Command.discover +async def maybe_relevant_file( + self: DiscoverHandler, + file: QFile, +) -> dict[RelevantFileId, QFile]: + """Check if the file is relevant.""" + retval: dict[RelevantFileId, QFile] = {} -def pkg_name(package_path: str) -> str: - """From the package path, get the package name.""" - return os.path.basename(package_path) + for relevant_file_id, _maybe_relevant_file_fn in relevant_files_map.items(): + if await _maybe_relevant_file_fn(self, file): + retval.update({relevant_file_id: file}) + + return retval class DiscoverHandler(Handler): @@ -33,9 +97,10 @@ class DiscoverHandler(Handler): """ model_module_name = "model.py" # The name of the model module + monitor_service_import = "from qcog_python_client import monitor" retries = 0 commands = (Command.discover,) - relevant_files: dict + relevant_files: RelevantFiles async def handle(self, payload: DiscoverCommand) -> ValidateCommand: """Handle the discovery of a custom model. 
@@ -61,29 +126,65 @@ async def handle(self, payload: DiscoverCommand) -> ValidateCommand: # Check if the folder contains the model module content = os.listdir(self.model_path) - # Initialize the relevant files dictionary - self.relevant_files = {} + # --- Load training folder in memory --- + # The training folder is loaded in memory as a dictionary + # where the key is the path of the file and the value is + # a QFile object. The QFile object contains the filename, + # the path, the content of the file and the package name + + self.directory: Directory = {} + pkg_name_ = pkg_name(self.model_path) for item in content: - if item == self.model_module_name: - item_path = os.path.join(self.model_path, item) - - async with await open_file(item_path, "rb") as file: - encoded_content = await file.read() - self.relevant_files.update( - { - "model_module": { - "path": item_path, - "content": encoded_content, - "pkg_name": pkg_name(self.model_path), - } - } - ) + item_path = os.path.join(self.model_path, item) + # filter by exclusion rules + if utils.exclude(item_path): + continue + # Avoid folders + if os.path.isdir(item_path): + continue + + async with await open_file(item_path, "rb") as file: + io_file = io.BytesIO(await file.read()) + self.directory[item_path] = QFile.model_validate( + { + "path": item_path, + "filename": item, + "content": io_file, + "pkg_name": pkg_name_, + } + ) + + # --- Discover the relevant files --- + # Relevant files have a specific id that + # is specified in the `relevant_files_map` + # And used to index the file. Relevant files + # are key files that are used to run + # the training session and are further + # validate in the chain. + + self.relevant_files: RelevantFiles = {} + + # Process the files in parallel gathering the results + # from the coroutines returned by the `maybe_relevant_file` + # function. Some of the files might not be relevant. + # `lambda f: f is not None` will filter out those. 
+ + processed: Iterable[dict[RelevantFileId, QFile]] = filter( + lambda f: f is not None, # Filter out the files that are not relevant + await asyncio.gather( + *map(lambda f: maybe_relevant_file(self, f), self.directory.values()) + ), # Process the files in parallel + ) + + # Index the relevant files on the relevantFileId + self.relevant_files = { + fid: rel for file in processed for fid, rel in file.items() + } - # Once the discovery has been completed, - # Issue a validate command that will be executed next return ValidateCommand( relevant_files=self.relevant_files, + directory=self.directory, model_name=self.model_name, model_path=self.model_path, ) @@ -91,6 +192,11 @@ async def handle(self, payload: DiscoverCommand) -> ValidateCommand: async def revert(self) -> None: """Revert the changes.""" # Unset the attributes - delattr(self, "model_name") - delattr(self, "model_path") - delattr(self, "relevant_files") + if hasattr(self, "model_name"): + delattr(self, "model_name") + if hasattr(self, "model_path"): + delattr(self, "model_path") + if hasattr(self, "directory"): + delattr(self, "directory") + if hasattr(self, "relevant_files"): + delattr(self, "relevant_files") diff --git a/qcog_python_client/qcog/pytorch/discover/types.py b/qcog_python_client/qcog/pytorch/discover/types.py new file mode 100644 index 0000000..04a3cde --- /dev/null +++ b/qcog_python_client/qcog/pytorch/discover/types.py @@ -0,0 +1,10 @@ +"""Package related types.""" + +from typing import Any, Callable, Coroutine, TypeAlias + +from qcog_python_client.qcog.pytorch.handler import Handler +from qcog_python_client.qcog.pytorch.types import DiscoverCommand, QFile + +IsRelevantFile: TypeAlias = Callable[ + [Handler[DiscoverCommand], QFile], Coroutine[Any, Any, bool] +] diff --git a/qcog_python_client/qcog/pytorch/discover/utils.py b/qcog_python_client/qcog/pytorch/discover/utils.py new file mode 100644 index 0000000..ca41cac --- /dev/null +++ b/qcog_python_client/qcog/pytorch/discover/utils.py @@ -0,0 +1,8 @@ +"""Shared utilities.""" + +import os + + +def pkg_name(package_path: str) -> str: + """From the package path, get the package name.""" + return os.path.basename(package_path) diff --git a/qcog_python_client/qcog/pytorch/handler.py b/qcog_python_client/qcog/pytorch/handler.py index f8f1bbf..0df4d0c 100644 --- a/qcog_python_client/qcog/pytorch/handler.py +++ b/qcog_python_client/qcog/pytorch/handler.py @@ -19,6 +19,8 @@ from pydantic import BaseModel +from qcog_python_client.log import qcoglogger as logger + class BoundedCommand(BaseModel): """Command type.""" @@ -129,7 +131,7 @@ async def dispatch(self, payload: CommandPayloadType) -> Handler: # 1 - revert the state of the handler # 2 - wait for the specified time # 3 - try again - print(f"Attempt {i}, error: {e}") + logger.info(f"Attempt {i}/{self.attempts}, error: {e}") exception = e await self.revert() await asyncio.sleep(self.retry_after) diff --git a/qcog_python_client/qcog/pytorch/types.py b/qcog_python_client/qcog/pytorch/types.py new file mode 100644 index 0000000..9d200cb --- /dev/null +++ b/qcog_python_client/qcog/pytorch/types.py @@ -0,0 +1,58 @@ +"""Types shared between multiple handlers.""" + +from __future__ import annotations + +import io +from typing import Literal, TypeAlias + +from pydantic import BaseModel + +from qcog_python_client.qcog.pytorch.handler import BoundedCommand, Command + + +class DiscoverCommand(BoundedCommand): + """Payload to dispatch a discover command.""" + + model_name: str + model_path: str + command: Command = Command.discover 
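# Note (illustrative, not part of types.py): these command payloads form the
# chain of responsibility used by the PyTorch agent. Each handler consumes one
# command and returns the next one, roughly:
#
#     cmd = DiscoverCommand(model_name="my-model", model_path="path/to/model")
#     validate_cmd = await discover_handler.handle(cmd)        # -> ValidateCommand
#     upload_cmd = await validate_handler.handle(validate_cmd) # -> UploadCommand
#     await upload_handler.handle(upload_cmd)                  # uploads the package
#
# The handler instances above are hypothetical stand-ins; in practice the
# handlers are chained and dispatched (with retries and revert()) through
# Handler.dispatch.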
+ + """Shared Types.""" + + +class ValidateCommand(BoundedCommand): + """Validate command.""" + + model_name: str + model_path: str + relevant_files: RelevantFiles + directory: Directory + command: Command = Command.validate + + +class UploadCommand(BoundedCommand): + """Payload to dispatch an upload command.""" + + upload_folder: str + model_name: str + command: Command = Command.upload + directory: Directory + + +class QFile(BaseModel): + """File object.""" + + filename: str + path: str + content: io.BytesIO + pkg_name: str | None = None + + model_config = {"arbitrary_types_allowed": True} + + +FilePath: TypeAlias = str + +RelevantFileId: TypeAlias = Literal["model_module", "monitor_service_import_module"] + +Directory: TypeAlias = dict[FilePath, QFile] +RelevantFiles: TypeAlias = dict[RelevantFileId, QFile] diff --git a/qcog_python_client/qcog/pytorch/upload/uploadhandler.py b/qcog_python_client/qcog/pytorch/upload/uploadhandler.py index c12b156..049e1c1 100644 --- a/qcog_python_client/qcog/pytorch/upload/uploadhandler.py +++ b/qcog_python_client/qcog/pytorch/upload/uploadhandler.py @@ -1,50 +1,13 @@ """Handler for uploading the model to the server.""" -import io -import os -import tarfile - import aiohttp -from qcog_python_client.log import qcoglogger as logger from qcog_python_client.qcog.pytorch.handler import ( - BoundedCommand, Command, Handler, ) - - -def compress_folder(folder_path: str) -> io.BytesIO: - """Compress a folder.""" - # We define the arcname as the basename of the folder - # In this way we avoid the full path in the tar.gz file - arcname = os.path.basename(folder_path) - - # What to exclude (__pycache__, .git, etc) - def filter(tarinfo: tarfile.TarInfo) -> tarfile.TarInfo | None: - return ( - None - if any(name in tarinfo.name for name in {".git", "__pycache__"}) - else tarinfo - ) - - buffer = io.BytesIO() - - with tarfile.open(fileobj=buffer, mode="w:gz") as tar: - tar.add(folder_path, arcname=arcname, filter=filter) - logger.info("Compressed targzip with memebers: ", tar.getnames()) - - buffer.seek(0) - - return buffer - - -class UploadCommand(BoundedCommand): - """Payload to dispatch an upload command.""" - - upload_folder: str - model_name: str - command: Command = Command.upload +from qcog_python_client.qcog.pytorch.types import UploadCommand +from qcog_python_client.qcog.pytorch.upload.utils import compress_folder class UploadHandler(Handler[UploadCommand]): @@ -87,12 +50,11 @@ async def handle(self, payload: UploadCommand) -> None: """Handle the upload.""" folder_path = payload.upload_folder # Compress the folder - tar_gzip_folder = compress_folder(folder_path) + tar_gzip_folder = compress_folder(payload.directory, folder_path) # Retrieve the multipart request tool post_multipart = self.get_tool("post_multipart") self.data = aiohttp.FormData() - self.data.add_field( "model", tar_gzip_folder, @@ -109,4 +71,4 @@ async def handle(self, payload: UploadCommand) -> None: async def revert(self) -> None: """Revert the changes.""" - delattr(self, "data") + pass diff --git a/qcog_python_client/qcog/pytorch/upload/utils.py b/qcog_python_client/qcog/pytorch/upload/utils.py new file mode 100644 index 0000000..50fc643 --- /dev/null +++ b/qcog_python_client/qcog/pytorch/upload/utils.py @@ -0,0 +1,46 @@ +"""Utility functions for uploading files to the QCoG platform.""" + +import io +import os +import tarfile + +from qcog_python_client.log import qcoglogger as logger +from qcog_python_client.qcog.pytorch.types import Directory + + +def compress_folder(directory: Directory, 
folder_path: str) -> io.BytesIO: + """Compress a folder from a in-memory directory. + + Parameters + ---------- + directory : Directory + The directory to compress, where the key is the path of the file + and the value is a QFile object. + + folder_path : str + The path of the folder to compress. + + Returns + ------- + io.BytesIO + The compressed folder as a BytesIO object representing a tarfile. + + """ + buffer = io.BytesIO() + + with tarfile.open(fileobj=buffer, mode="w:gz") as tar: + for qfile in directory.values(): + # Get a relative path from the arcname + rel_path = os.path.relpath(qfile.path, folder_path) + logger.info(f"Adding {rel_path} to training package...") + + # Create TarInfo object and set the size + tarfinfo = tarfile.TarInfo(name=rel_path) + tarfinfo.size = len(qfile.content.getvalue()) + + # Add the file to the tar archive + tar.addfile(tarfinfo, fileobj=io.BytesIO(qfile.content.getvalue())) + + buffer.seek(0) + + return buffer diff --git a/qcog_python_client/qcog/pytorch/utils.py b/qcog_python_client/qcog/pytorch/utils.py new file mode 100644 index 0000000..e57b592 --- /dev/null +++ b/qcog_python_client/qcog/pytorch/utils.py @@ -0,0 +1,82 @@ +"""Utility functions for the PyTorch client.""" + +import io +import os +import re +from typing import Callable + +from qcog_python_client.qcog.pytorch.types import FilePath, QFile + +default_rules = { + re.compile(r".*__pycache__.*"), + re.compile(r".*\.git.*"), +} + + +def exclude(file_path: str, rules: set[re.Pattern[str]] | None = None) -> bool: + """Check against a set of regexes rules. + + Parameters + ---------- + file_path : str + The path to the file. + + rules : set[re.Pattern[str]] | None + The set of regex rules to check against. Defaults to `default_rules`. + if None is passed. + + """ + rules = rules or default_rules + for pattern in rules: + if re.match(pattern, file_path): + return True + + return False + + +def get_folder_structure( + file_path: str, *, filter: Callable[[FilePath], bool] | None = None +) -> dict[FilePath, QFile]: + """Return the folder structure as a dictionary. + + Parameters + ---------- + file_path : str + The path to the folder. + + filter : Callable[[FilePath], bool] | None + The filter function to apply to the folder structure. + to exclude some paths + + Returns + ------- + dict[FilePath, QFile] + The folder structure as a dictionary. 
+ + """ + folder_items = os.listdir(file_path) + + retval: dict[FilePath, QFile] = {} + for item in folder_items: + item_path = os.path.join(file_path, item) + + if filter and filter(item_path): + continue + + # Check if the item is a file + if os.path.isfile(item_path): + with open(item_path, "rb") as f: + retval[item_path] = QFile.model_validate( + { + "path": item_path, + "filename": item, + "content": io.BytesIO(f.read()), + } + ) + elif os.path.isdir(item_path): + retval.update(get_folder_structure(item_path)) + + else: + raise ValueError(f"Item {item_path} is neither a file nor a directory.") + + return retval diff --git a/qcog_python_client/qcog/pytorch/validate/__init__.py b/qcog_python_client/qcog/pytorch/validate/__init__.py index 17d9675..e69de29 100644 --- a/qcog_python_client/qcog/pytorch/validate/__init__.py +++ b/qcog_python_client/qcog/pytorch/validate/__init__.py @@ -1 +0,0 @@ -from .validatehandler import ValidateCommand, ValidateHandler diff --git a/qcog_python_client/qcog/pytorch/validate/_setup_monitor_import.py b/qcog_python_client/qcog/pytorch/validate/_setup_monitor_import.py new file mode 100644 index 0000000..a658dbe --- /dev/null +++ b/qcog_python_client/qcog/pytorch/validate/_setup_monitor_import.py @@ -0,0 +1,142 @@ +import ast +import copy +import io +import os +from pathlib import Path +from typing import Callable + +from qcog_python_client import monitor +from qcog_python_client.qcog.pytorch import utils +from qcog_python_client.qcog.pytorch.handler import Handler +from qcog_python_client.qcog.pytorch.types import ( + Directory, + QFile, + ValidateCommand, +) + +MONITOR_PACKAGE_NAME = "_monitor_" +MONITOR_PACKAGE_FOLDER_PATH = str(Path(os.path.abspath(monitor.__file__)).parent) + + +def setup_monitor_import( + self: Handler[ValidateCommand], + file: QFile, + directory: Directory, + monitor_package_folder_path: str = MONITOR_PACKAGE_FOLDER_PATH, + folder_content_getter: Callable[ + [str], Directory + ] = lambda folder_path: utils.get_folder_structure( + folder_path, filter=utils.exclude + ), +) -> Directory: + """Monitor import setup. + + Parameters + ---------- + self: Handler[ValidateCommand] + The handler calling this function + file : QFile + The file to validate. + + directory : Directory + The directory to validate. + + monitor_package_folder_path : str + The path to the monitor package folder. + + folder_content_getter : Callable[[str], Directory] + The function to get the content of the folder. + + Returns + ------- + Directory + The updated directory. + + """ + directory = copy.deepcopy(directory) + # We need to add the monitoring package from the qcog_package into + # the training directory and update the import on the file in order + # to point to the new location. + + # From the monitor_package_folder_path location, + # create a dictionary with the content of the + # monitor package. The dictionary will have the path of the file + # as the key and the file as the value. + monitor_package_content = folder_content_getter(monitor_package_folder_path) + + # Now we want to copy the package to the training directory. + # This "copy" is only happening in memory, we are not writing. + # The `folder` is defined by the `keys` of the dictionary. + # We need to change the keys and the `path` of the files in the + # monitor_package_content dictionary in order to match the new location + # defined by the keys of the `directory` dictionary. + + # The `root` of the folder is defined as the parent folder + # of the file to validate. 
+ + # We can use that to re-construct the new path of the files + # in the monitor_package_content dictionary. + + root = Path(file.path).parent + + for file_path, file_ in monitor_package_content.items(): + # Find the root of the monitor package + # Get the relative path of the file + relative_path = os.path.relpath(file_path, monitor_package_folder_path) + + # prepend the relative path of the content of the package + # with the package name, that, in this case is `_monitor_` + # to avoid conflicts with the user's package. + # Et voilĂ , we have the new path of the file in the training + new_path = os.path.join(root, MONITOR_PACKAGE_NAME, relative_path) + + # Update the path of the file + file_.path = new_path + file_.pkg_name = MONITOR_PACKAGE_NAME + + # Update the directory + directory[new_path] = file_ + + # Now we need to update the import in the file that has the + # import of the monitor package. The file is the same file + # at the address of the file to validate. + + # Generate the ast from the content + ast_tree = ast.parse(file.content.getvalue()) + + # Now we need to find the import statement `from qcog_python_client import monitor` + # Remove it and add a new statement `import _monitor_ as monitor` + + for node in ast.walk(ast_tree): + if isinstance(node, ast.ImportFrom) and node.module == "qcog_python_client": + if len(node.names) > 1: + raise ValueError( + "Only one import is allowed from the qcog_python_client package." # noqa: E501 + ) + + package_name = node.names[0].name + + if package_name != "monitor": + raise ValueError("The only package that can be imported is monitor.") + + # Now we need to remove the import statement + # and add a new one + ast_tree.body.remove(node) + + # Add the new import statement + new_import = ast.Import( + names=[ast.alias(name=MONITOR_PACKAGE_NAME, asname="monitor")] + ) + + ast_tree.body.insert(0, new_import) + + # Now re-write the content of the file + # starting from the modified AST tree + # Parse the AST tree to a string see: https://stackoverflow.com/questions/768634/parse-a-py-file-read-the-ast-modify-it-then-write-back-the-modified-source-c + file.content = io.BytesIO(ast.unparse(ast_tree).encode()) + + directory[file.path] = file + + return directory + + raise ValueError("No monitor import found in the file.") diff --git a/qcog_python_client/qcog/pytorch/validate/_validate_module.py b/qcog_python_client/qcog/pytorch/validate/_validate_model_module.py similarity index 72% rename from qcog_python_client/qcog/pytorch/validate/_validate_module.py rename to qcog_python_client/qcog/pytorch/validate/_validate_model_module.py index c5836a6..acd346b 100644 --- a/qcog_python_client/qcog/pytorch/validate/_validate_module.py +++ b/qcog_python_client/qcog/pytorch/validate/_validate_model_module.py @@ -1,34 +1,23 @@ +import copy import importlib import inspect import os import sys -from dataclasses import dataclass +from typing import Any from pydantic import BaseModel -from qcog_python_client.qcog.pytorch.validate.validate_utils import ( +from qcog_python_client.qcog.pytorch.handler import Handler +from qcog_python_client.qcog.pytorch.types import ( + Directory, + QFile, + ValidateCommand, +) +from qcog_python_client.qcog.pytorch.validate.utils import ( get_third_party_imports, is_package_module, ) - -class FileToValidate(BaseModel): - path: str - content: str - pkg_name: str - - -@dataclass -class TrainFnAnnotation: - arg_name: str - arg_type: type - - -@dataclass -class ValidateModelModule: - train_fn: dict[str, TrainFnAnnotation] - - # 
whitelist of allowed modules default_allowed_modules = { "torch", @@ -39,34 +28,36 @@ class ValidateModelModule: } +class TrainFnAnnotation(BaseModel): + """Train function annotation.""" + + arg_name: str + arg_type: Any + + def validate_model_module( - file: FileToValidate, + self: Handler[ValidateCommand], + file: QFile, + directory: Directory, allowed_modules: set[str] | None = None, -) -> ValidateModelModule: +) -> Directory: """Validate the model module.""" + directory = copy.deepcopy(directory) allowed_modules = allowed_modules or default_allowed_modules dir_path = os.path.dirname(file.path) - content = os.listdir(dir_path) # Very naive way to inspect all the package. # Assumes one level deep and doesn't recurse. modules_found = set() - - for item in content: - # Inspect each python file and try to find third-party modules - if item.endswith(".py"): - third_party_modules = get_third_party_imports(os.path.join(dir_path, item)) + for item_path, item in directory.items(): + if item_path.endswith(".py"): + third_party_modules = get_third_party_imports(item.content, dir_path) for module_name in third_party_modules: # If the module_name name is the package, skip it if module_name == file.pkg_name: continue - - # For each module check if it's part of the current package - # If not, raise an error - module_path = os.path.join(dir_path, module_name) - # If the module is contained is not in the package - if not is_package_module(module_path): + if not is_package_module(item_path): modules_found.add(module_name) # Check if the modules found are allowed @@ -111,4 +102,7 @@ def validate_model_module( arg_name=ann, arg_type=train_fn.__annotations__[ann] ) - return ValidateModelModule(train_fn=train_fn_annotations) + # Set the train function annotations on the handler + self.train_fn_annotations = train_fn_annotations # type: ignore + # Directory has been validated + return directory diff --git a/qcog_python_client/qcog/pytorch/validate/utils.py b/qcog_python_client/qcog/pytorch/validate/utils.py new file mode 100644 index 0000000..c7b4310 --- /dev/null +++ b/qcog_python_client/qcog/pytorch/validate/utils.py @@ -0,0 +1,86 @@ +"""Utility functions for validating the input data.""" + +import ast +import distutils +import distutils.sysconfig +import importlib +import io +import os +import sys + +# def validate_directory(dir: dict) -> Directory: +# """Validate the directory. + +# It takes the raw dictionary and validate +# """ +# return {k: QFile(**v) for k, v in dir.items()} + + +def get_third_party_imports(source_code: io.BytesIO, package_path: str) -> set[str]: + """Get all third-party packages imported in a Python module. + + Parameters + ---------- + source_code : io.BytesIO + The source code of the module. + package_path : str + The path of the package to which the module belongs. + + Returns + ------- + A set of third-party packages imported by the module. + + """ + # Parse the source code + tree = ast.parse(source_code.getvalue()) + + # Find all import statements + imports: set[str] = set() + for node in ast.walk(tree): + # Import nodes can be of type ast.Import or ast.ImportFrom + # as the import statement can be of the form `import module` + # or `from module import submodule` + if isinstance(node, ast.Import): + for name in node.names: + imports.add(name.name) + elif isinstance(node, ast.ImportFrom): + if node.module: + imports.add(node.module) + + # Identify third-party packages + third_party_packages = set() + + # Get the path of the standard library. 
+ # All the modules that are OS dependent are on this path + python_sys_lib = distutils.sysconfig.get_python_lib(standard_lib=True) + + for imp_ in imports: + # Split the package name to handle submodules + base_package = imp_.split(".")[0] + + # Check if it's a package that belongs to the current package + if is_package_module(os.path.join(package_path, base_package)): + continue + + spec = importlib.util.find_spec(base_package) + + if spec is None or spec.origin is None: + continue + + path = spec.origin + # If the path of the module matches the path of the standard library + # or the module is a built-in module, then it is not a third-party + + if ( + base_package not in sys.builtin_module_names + and spec.origin != "built-in" + and python_sys_lib not in path + ): + third_party_packages.add(base_package) + return third_party_packages + + +def is_package_module(module_path: str) -> bool: + """Check if a Python module exists in the specified path.""" + module_path = module_path if module_path.endswith(".py") else module_path + ".py" + return os.path.isfile(module_path) diff --git a/qcog_python_client/qcog/pytorch/validate/validate_utils.py b/qcog_python_client/qcog/pytorch/validate/validate_utils.py deleted file mode 100644 index 30a15c2..0000000 --- a/qcog_python_client/qcog/pytorch/validate/validate_utils.py +++ /dev/null @@ -1,83 +0,0 @@ -"""Utility functions for validating the model package.""" - -import ast -import os -import pkgutil -import sys -from functools import lru_cache - -from qcog_python_client.log import qcoglogger as logger - - -@lru_cache -def get_stdlib_modules() -> set[str]: - """Get a set of all standard library modules.""" - stdlib_modules: set[str] = set() - for importer, modname, ispkg in pkgutil.iter_modules(): - # Exclude packages - stdlib_modules.add(modname) - - return stdlib_modules - - -def get_third_party_imports(module_path: str) -> set[str]: - """Get all third-party packages imported in a Python module. - - Parameters - ---------- - module_path : str - The absolute path to the module file (e.g., /path/to/module.py). - - Returns - ------- - A set of third-party packages imported by the module. 
- - """ - # Check if the file exists - if not os.path.isfile(module_path): - logger.warning(f"Module file not found: {module_path}") - return set() - - # Read the module's source code - with open(module_path, "r") as file: - source_code = file.read() - - # Parse the source code - tree = ast.parse(source_code) - - # Find all import statements - imports: set[str] = set() - for node in ast.walk(tree): - # Import nodes can be of type ast.Import or ast.ImportFrom - # as the import statement can be of the form `import module` - # or `from module import submodule` - if isinstance(node, ast.Import): - for alias in node.names: - imports.add(alias.name) - elif isinstance(node, ast.ImportFrom): - if node.module: - imports.add(node.module) - - # Get list of standard library modules - stdlib_modules = get_stdlib_modules() - - # Identify third-party packages - third_party_packages = set() - for imp in imports: - # Split the package name to handle submodules - base_package = imp.split(".")[0] - if ( - base_package not in stdlib_modules - and base_package not in sys.builtin_module_names - ): - third_party_packages.add(base_package) - - return third_party_packages - - -def is_package_module(module_path: str) -> bool: - """Check if a Python module exists in the specified path.""" - # Check if the file exists - - module_path = module_path if module_path.endswith(".py") else module_path + ".py" - return os.path.isfile(module_path) diff --git a/qcog_python_client/qcog/pytorch/validate/validatehandler.py b/qcog_python_client/qcog/pytorch/validate/validatehandler.py index fa0b0a6..d1e0315 100644 --- a/qcog_python_client/qcog/pytorch/validate/validatehandler.py +++ b/qcog_python_client/qcog/pytorch/validate/validatehandler.py @@ -1,51 +1,84 @@ """Validate the model module.""" -from typing import Any, Callable +from __future__ import annotations -from qcog_python_client.qcog.pytorch.handler import BoundedCommand, Command, Handler +import os +from typing import Callable + +from qcog_python_client.log import qcoglogger as logger +from qcog_python_client.qcog.pytorch.handler import Command, Handler +from qcog_python_client.qcog.pytorch.types import ( + Directory, + QFile, + RelevantFileId, + ValidateCommand, +) from qcog_python_client.qcog.pytorch.upload.uploadhandler import UploadCommand -from qcog_python_client.qcog.pytorch.validate._validate_module import ( - FileToValidate, +from qcog_python_client.qcog.pytorch.validate._setup_monitor_import import ( + setup_monitor_import, +) +from qcog_python_client.qcog.pytorch.validate._validate_model_module import ( validate_model_module, ) -class ValidateCommand(BoundedCommand): - """Validate command.""" - - model_name: str - model_path: str - relevant_files: dict - command: Command = Command.validate - - class ValidateHandler(Handler): """Validate the model module.""" commands = (Command.validate,) attempts = 1 + directory: Directory - validate_map: dict[str, Callable[[FileToValidate], Any]] = { - "model_module": validate_model_module + validate_map: dict[ + RelevantFileId, Callable[[ValidateHandler, QFile, Directory], Directory] + ] = { + "model_module": validate_model_module, + "monitor_service_import_module": setup_monitor_import, } async def handle(self, payload: ValidateCommand) -> UploadCommand: """Handle the validation.""" - validated: list = [] + self.directory = payload.directory + # `directory` will go through a series of validations + # based on the `validate_map` keys. 
for key, validate_fn in self.validate_map.items(): - file = payload.relevant_files.get(key) + relevant_file = payload.relevant_files.get(key) - if not file: - raise FileNotFoundError(f"File {key} not found in the relevant files.") + if not relevant_file: + raise FileNotFoundError( + f"File {key} not found in the relevant files. Keys: {payload.relevant_files.keys()}" # noqa: E501 + ) - parsed = FileToValidate.model_validate(file) - validated.append(validate_fn(parsed)) + parsed = QFile.model_validate(relevant_file) + self.directory = validate_fn(self, parsed, self.directory) + + verify_directory(self.directory) return UploadCommand( - upload_folder=payload.model_path, model_name=payload.model_name + upload_folder=payload.model_path, + model_name=payload.model_name, + directory=self.directory, ) async def revert(self) -> None: """Revert the changes.""" pass + + +def verify_directory(d: Directory) -> None: + """Verify the directory.""" + for file_path, file in d.items(): + if file_path != file.path: + raise ValueError(f"File path mismatch: {file_path} != {file.path}") + + if file.filename != os.path.basename(file_path): + raise ValueError( + f"File path is not a basename: {file_path} - {file.filename}" + ) + + # Make sure the file content is not empty + if not file.content.read(): + logger.warning(f"File content is empty: {file_path}") + + file.content.seek(0) diff --git a/tests/pytorch_model/model.py b/tests/pytorch_model/model.py index 827b0a7..c6785bb 100644 --- a/tests/pytorch_model/model.py +++ b/tests/pytorch_model/model.py @@ -9,6 +9,8 @@ from sklearn.calibration import LabelEncoder from torch.autograd import Variable +from qcog_python_client import monitor + def train( data: pd.DataFrame, @@ -18,6 +20,15 @@ def train( ) -> dict: device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + m_service = monitor.get_monitor("wandb") + + m_service.init( + api_key="", + parameters={ + "epochs": epochs, + }, + ) + cols = data.columns # Show the data @@ -46,9 +57,12 @@ def train( y_pred = model(x_data.float()) loss = criterion(y_pred, y_data.view(-1, 1).float()) # print('Epoch', epoch, 'Loss:',e loss.item(), '- Pred:', y_pred.data[0]) + m_service.log({"loss": loss.item(), "epoch": epoch}) loss_list.append(loss.item()) optimizer.zero_grad() loss.backward() optimizer.step() + m_service.close() + return {"model": model, "metrics": {"loss": loss_list}} diff --git a/tests/pytorch_model/ubiops_train_pt.py b/tests/pytorch_model/ubiops_train_pt.py deleted file mode 100644 index 3e6e0b2..0000000 --- a/tests/pytorch_model/ubiops_train_pt.py +++ /dev/null @@ -1,184 +0,0 @@ -"""Wrapper around a customer defined training function. -Local development version for local testing. -""" - -import base64 -import gzip -import io -import json -import os -import re -import sys -from typing import Any, Callable, Literal, TypedDict - -import pandas as pd -import torch - -train_fn: Callable[..., dict] | None = None - -# LOG PWD - -print("-------> PWD: ", os.getcwd()) - -# LOG MODULE PATH - -print("-------> MODULE PATH: ", __file__) - -sys.path.append(os.getcwd()) - -# Try to import the customer function -try: - from model import train as _train_fn - - train_fn = _train_fn -except ImportError as e: - raise ImportError( - "Failed to import the train function from the model.py file: ", e - ) from e - - -def decode_base64(encoded_string: str) -> str: - """Decode into original string str type. 
- - Parameters - ---------- - encoded_string: str - encoded base64 string - - Returns - ------- - str: decoded string - """ - base64_bytes: bytes = encoded_string.encode("ascii") - decoded_bytes: bytes = base64.b64decode(base64_bytes) - return decoded_bytes.decode("ascii") - - -Compression = Literal["gzip"] - - -class DataFramePayload(TypedDict): - blob: str - indexing: list[int] - - -def base642dataframe(encoded_string: str) -> pd.DataFrame: - """Decode a base64 string and parse as DataFrame. - - Parameters - ---------- - encoded_string: str - base64 encoded string - - Returns - ------- - pd.DataFrame: parsed csv dataframe - """ - decoded_string: str = decode_base64(encoded_string) - raw_dict: dict = json.loads(decoded_string) - - payload: DataFramePayload = DataFramePayload( - blob=raw_dict["blob"], - indexing=raw_dict["indexing"], - ) - s = io.StringIO(payload["blob"]) - return pd.read_csv(s, index_col=payload["indexing"]) - - -def base64Location2dataframe( # noqa - location: str, *, compression: Compression | None = None -) -> pd.DataFrame: - """Return from a file location, a pandas dataframe. - - Parameters - ---------- - location: str - the location of the file to be read - - compression: 'gzip' | None - specify if a compression has been applied to the file - - Returns - ------- - pd.DataFrame: parsed csv dataframe - """ - encoded_string: str | bytes | None = None - - import boto3 - - s3_client = boto3.client("s3") - key = "/".join(location.split("/")[3:]) - - # Location has two parts that are interesting to us: - # 1. The bucket name - # 2. The key of the file in the bucket - # The url is classic s3 url format: s3://bucket-name/key - # We need to extract the bucket name and the key - - pattern = r"s3://(?P[^/]+)/(?P.+)" - match = re.match(pattern, location) - - if not match: - raise ValueError("Invalid s3 url format") - - bucket, key = match.group("bucket"), match.group("key") - - print("=============================") - print("Bucket: ", bucket) - print("Key: ", key) - print("=============================") - - file_object = io.BytesIO() - s3_client.download_fileobj(bucket, key, file_object) - file_object.seek(0) - - # If the file is compressed, open it with the right method - if compression == "gzip": - with gzip.GzipFile(fileobj=file_object, mode="r") as f: - encoded_string = f.read() - - # Otherwise assume it's just a base64 encoded string - else: - with open(location, "r") as f: - encoded_string = f.read() - - # Decompress the file if it's a bytes object and gzip compression is specified - if isinstance(encoded_string, bytes): - encoded_string = encoded_string.decode("ascii") - - return base642dataframe(encoded_string) - - -# UbiOps training function -def train( - training_data: str, parameters: dict, context: dict, base_directory: Any -) -> dict: - """Train function adapter for UbiOps.""" - print("Context: ", context) - if train_fn is None: - raise ValueError("No train function found") - - training_data_location = training_data - - print("DEBUG: Training with parameters: ", parameters) - print("Training data Location: ", training_data_location) - print("Running Training Function...") - - data = base64Location2dataframe(training_data_location, compression="gzip") - result = train_fn(data, **parameters["parameters"]) - - print("Training Function Completed...") - - model: torch.nn.Module = result.get("model") - metrics: dict = result.get("metrics", {}) - - # Save the model - artifact_filename = parameters.get("artifact_filename", "model.pt") - torch.save(model, artifact_filename) 
- - return { - "artifact": artifact_filename, - "metadata": {}, - "metrics": {"run_id": context.get("run_id"), **metrics}, - "additional_output_files": [], - } diff --git a/tests/test_pytorch_agent.py b/tests/test_pytorch_agent.py index 334e520..53a0f0c 100644 --- a/tests/test_pytorch_agent.py +++ b/tests/test_pytorch_agent.py @@ -1,47 +1,13 @@ import os -import aiohttp import pytest from qcog_python_client.qcog import AsyncQcogClient -from qcog_python_client.qcog._httpclient import RequestClient -from qcog_python_client.qcog.pytorch.agent import PyTorchAgent from tests.datasets import get_wbc_data df_train, df_test, df_target = get_wbc_data() -@pytest.mark.asyncio -async def test_pytorch_agent_discovery(): - """Test basic discovery of PyTorchAgent""" - - model_path = "tests/pytorch_model" - model_name = "test_model_01" - # Register custom tools for the agent. - # The following `tools` will be available - # inside the handlers of the agent - request_client = RequestClient( - token=os.getenv("API_TOKEN"), - hostname="localhost", - port=8000, - ) - - async def post_multipart(url: str, data: aiohttp.FormData) -> dict: - return await request_client.post( - url, - data, - content_type="data", - ) - - agent = PyTorchAgent.create_agent( - post_request=request_client.post, - get_request=request_client.get, - post_multipart=post_multipart, - ) - - await agent.upload(model_path, model_name) - - @pytest.mark.asyncio async def test_pytorch_workflow(): client = await AsyncQcogClient.create( @@ -51,7 +17,7 @@ async def test_pytorch_workflow(): ) client = await client.pytorch( - model_name="test-model-03", + model_name="test-model-05", model_path="tests/pytorch_model", ) client = await client.data(df_train) @@ -62,3 +28,6 @@ async def test_pytorch_workflow(): "epochs": 10, } ) + print(await client.status()) + await client.wait_for_training(poll_time=10) + print(client.metrics) diff --git a/tests/unit/qcog/pytorch/discover/test_discover_handler.py b/tests/unit/qcog/pytorch/discover/test_discover_handler.py new file mode 100644 index 0000000..f650e1b --- /dev/null +++ b/tests/unit/qcog/pytorch/discover/test_discover_handler.py @@ -0,0 +1,170 @@ +# tests/pythorch_model/test_discoverhandler.py +import io +import os + +import pytest + +from qcog_python_client.qcog.pytorch.discover import DiscoverCommand, DiscoverHandler +from qcog_python_client.qcog.pytorch.discover.discoverhandler import ( + _maybe_model_module, + _maybe_monitor_service_import_module, +) +from qcog_python_client.qcog.pytorch.types import QFile +from qcog_python_client.qcog.pytorch.validate.validatehandler import ValidateCommand + + +@pytest.fixture +def mock_model_dir(tmp_path): + # Create a temporary directory structure + model_dir = tmp_path / "model" + model_dir.mkdir() + + # Create mock files + model_file = model_dir / "model.py" + model_file.write_text("print('This is a model file')") + + monitor_file = model_dir / "monitor.py" + monitor_file.write_text("from qcog_python_client import monitor") + + other_file = model_dir / "other.py" + other_file.write_text("print('This is another file')") + + return model_dir + + +@pytest.fixture +def discover_handler(): + return DiscoverHandler() + + +@pytest.mark.asyncio +async def test_handle(mock_model_dir, discover_handler): + payload = DiscoverCommand(model_name="test_model", model_path=str(mock_model_dir)) + result = await discover_handler.handle(payload) + + assert isinstance(result, ValidateCommand) + + relevant_file_ids = result.relevant_files.keys() + dir_file_names = {f.filename for f in 
result.directory.values()} + assert "model_module" in relevant_file_ids + assert "monitor_service_import_module" in relevant_file_ids + assert "model.py" in dir_file_names + assert "monitor.py" in dir_file_names + assert "other.py" in dir_file_names + + +@pytest.mark.asyncio +async def test_maybe_model_module_positive(mock_model_dir, discover_handler): + model_file_path = os.path.join(mock_model_dir, "model.py") + with open(model_file_path, "rb") as f: # noqa: ASYNC230 + file_content = io.BytesIO(f.read()) + + f = QFile( + filename="model.py", + path=model_file_path, + content=file_content, + pkg_name=None, + ) + discover_handler.model_path = mock_model_dir + result = await _maybe_model_module(discover_handler, f) + assert result is f + + +@pytest.mark.asyncio +async def test_maybe_model_module_negative(mock_model_dir, discover_handler): + model_file_path = os.path.join(mock_model_dir, "monitor.py") + with open(model_file_path, "rb") as f: # noqa: ASYNC230 + file_content = io.BytesIO(f.read()) + + f = QFile( + filename="monitor.py", + path=model_file_path, + content=file_content, + pkg_name=None, + ) + + discover_handler.model_path = mock_model_dir + result = await _maybe_model_module(discover_handler, f) + assert result is None + + +@pytest.mark.asyncio +async def test_maybe_monitor_service_import_module(mock_model_dir, discover_handler): + monitor_file_path = os.path.join(mock_model_dir, "monitor.py") + with open(monitor_file_path, "rb") as f: # noqa: ASYNC230 + file_content = io.BytesIO(f.read()) + + f = QFile( + filename="monitor.py", + path=monitor_file_path, + content=file_content, + pkg_name=None, + ) + + discover_handler.model_path = mock_model_dir + result = await _maybe_monitor_service_import_module(discover_handler, f) + assert result is f + + +@pytest.mark.asyncio +async def test_maybe_monitor_service_import_module_wrong_import( + mock_model_dir, discover_handler +): + # monitor_file_path = os.path.join(mock_model_dir, "monitor.py") + # with open(monitor_file_path, "rb") as f: # noqa: ASYNC230 + # file_content = io.BytesIO(f.read()) + + monitor_file_path = os.path.join(mock_model_dir, "monitor.py") + file_content = io.BytesIO(b"from qcog_python_client import monitor, extra") + + f = QFile( + filename="monitor.py", + path=monitor_file_path, + content=file_content, + pkg_name=None, + ) + + discover_handler.model_path = mock_model_dir + with pytest.raises(ValueError) as exc_info: + await _maybe_monitor_service_import_module(discover_handler, f) + assert ( + "You cannot import anything from qcog_python_client other than monitor." 
+
+
+@pytest.mark.asyncio
+async def test_maybe_monitor_service_import_module_with_alias(
+    mock_model_dir, discover_handler
+):
+    # monitor_file_path = os.path.join(mock_model_dir, "monitor.py")
+    # with open(monitor_file_path, "rb") as f:  # noqa: ASYNC230
+    #     file_content = io.BytesIO(f.read())
+
+    monitor_file_path = os.path.join(mock_model_dir, "monitor.py")
+    file_content = io.BytesIO(b"from qcog_python_client import monitor as mon")
+
+    f = QFile(
+        filename="monitor.py",
+        path=monitor_file_path,
+        content=file_content,
+        pkg_name=None,
+    )
+
+    discover_handler.model_path = mock_model_dir
+    result = await _maybe_monitor_service_import_module(discover_handler, f)
+
+    assert result is f
+
+
+@pytest.mark.asyncio
+async def test_revert(discover_handler):
+    discover_handler.model_name = "test_model"
+    discover_handler.model_path = "/path/to/model"
+    discover_handler.relevant_files = {"model_module": {}}
+
+    await discover_handler.revert()
+
+    assert not hasattr(discover_handler, "model_name")
+    assert not hasattr(discover_handler, "model_path")
+    assert not hasattr(discover_handler, "relevant_files")
diff --git a/tests/unit/qcog/pytorch/validate/test_setup_monitor_import.py b/tests/unit/qcog/pytorch/validate/test_setup_monitor_import.py
new file mode 100644
index 0000000..2e4cdd9
--- /dev/null
+++ b/tests/unit/qcog/pytorch/validate/test_setup_monitor_import.py
@@ -0,0 +1,174 @@
+import io
+from unittest.mock import Mock
+
+import pytest
+from anyio import Path
+
+from qcog_python_client.qcog.pytorch.handler import Handler
+from qcog_python_client.qcog.pytorch.types import QFile
+
+
+@pytest.fixture
+def mock_handler():
+    return Mock(spec=Handler)
+
+
+@pytest.fixture
+def monitor_package_folder_path():
+    return "/package_path/to/monitor_package"
+
+
+@pytest.fixture
+def training_package_folder_path():
+    return "/package_path/to/training_package"
+
+
+@pytest.fixture
+def mock_relevant_file(training_package_folder_path):
+    """Mock the relevant file that has a monitor import statement."""
+    # Sample file content with import statement
+    sample_content = b"""
+from qcog_python_client import monitor
+
+def dummy_function():
+    pass
+    """
+    return QFile(
+        filename="train.py",
+        path=f"{training_package_folder_path}/train.py",
+        content=io.BytesIO(sample_content),
+        pkg_name="training_package",
+    )
+
+
+@pytest.fixture
+def mock_training_package(training_package_folder_path, mock_relevant_file):
+    """Mock a training package with a train.py file that contains the import statement."""  # noqa: E501
+    return {
+        f"{training_package_folder_path}/__init__.py": QFile(
+            filename="__init__.py",
+            path=f"{training_package_folder_path}/__init__.py",
+            content=io.BytesIO(b""),
+            pkg_name="training_package",
+        ),
+        f"{training_package_folder_path}/train.py": mock_relevant_file,
+    }
+
+
+@pytest.fixture
+def mock_monitor_package(monitor_package_folder_path):
+    """Mock a monitor package that is located in another path."""
+    return {
+        f"{monitor_package_folder_path}/__init__.py": QFile(
+            filename="__init__.py",
+            path=f"{monitor_package_folder_path}/__init__.py",
+            content=io.BytesIO(b""),
+            pkg_name="monitor_package",
+        ),
+        f"{monitor_package_folder_path}/monitor.py": QFile(
+            filename="monitor.py",
+            path=f"{monitor_package_folder_path}/monitor.py",
+            content=io.BytesIO(b"def dummy_function(): \n\tpass"),
+            pkg_name="monitor_package",
+        ),
+    }
+
+
+def test_setup_monitor_import_directory_update(
+    mock_handler,
+    mock_monitor_package,
+    mock_relevant_file,
+    mock_training_package,
+    monitor_package_folder_path,
+    training_package_folder_path,
+):
+    from qcog_python_client.qcog.pytorch.validate._setup_monitor_import import (
+        setup_monitor_import,
+    )
+
+    # In this test we want to make sure that, no matter where the monitor
+    # package is located, the files are correctly copied into the directory
+    # and the path of the files is correctly moved inside the directory.
+
+    # We override the two getter functions for the monitor package folder
+    # and the monitor package content so that they return the mock values.
+
+    # We expect to find the same files that are inside the mocked
+    # monitor package inside the directory, within a _monitor_ folder.
+
+    updated_directory = setup_monitor_import(
+        mock_handler,
+        mock_relevant_file,
+        mock_training_package,
+        monitor_package_folder_path_getter=lambda: monitor_package_folder_path,
+        folder_content_getter=lambda folder_path: mock_monitor_package,
+    )
+
+    monitor_files = [
+        (path, f) for path, f in updated_directory.items() if "monitor" in path
+    ]
+
+    # The files moved are the same as those in the mocked monitor package
+    assert len(monitor_files) == len(mock_monitor_package)
+
+    # All the files moved have the same path as their keys
+    assert all(path == f.path for path, f in monitor_files)
+
+    # The base path of the moved files is the _monitor_ folder
+    # inside the training package
+    monitor_file_paths = [str(Path(path).parent) for path, _ in monitor_files]
+    assert all(
+        path == training_package_folder_path + "/_monitor_"
+        for path in monitor_file_paths
+    )
+
+    # The relevant file import has been updated to point to the new location
+    relevant_file = updated_directory.get(mock_relevant_file.path)
+
+    assert relevant_file is not None
+
+    relevant_file_content = relevant_file.content.read()
+    assert b"import _monitor_ as monitor" in relevant_file_content
+
+    # Make sure the old import is not there anymore
+    assert b"from qcog_python_client import monitor" not in relevant_file_content
+
+
+def test_update_import_exceptions_multiple_files_imported_from_qcog_python_client(
+    mock_handler,
+    mock_monitor_package,
+    mock_relevant_file,
+    mock_training_package,
+    monitor_package_folder_path,
+    training_package_folder_path,
+):
+    from qcog_python_client.qcog.pytorch.validate._setup_monitor_import import (
+        setup_monitor_import,
+    )
+
+    file_with_multiple_imports = b"""
+from qcog_python_client import monitor, other_module
+
+def dummy_function():
+    pass
+"""
+    file = QFile(
+        filename="train.py",
+        path=f"{training_package_folder_path}/train.py",
+        content=io.BytesIO(file_with_multiple_imports),
+        pkg_name="training_package",
+    )
+
+    # Override the file with one that has multiple imports from qcog_python_client
+    mock_training_package[file.path] = file
+
+    with pytest.raises(ValueError) as exc_info:
+        setup_monitor_import(
+            mock_handler,
+            file,
+            mock_training_package,
+            monitor_package_folder_path_getter=lambda: monitor_package_folder_path,
+            folder_content_getter=lambda folder_path: mock_monitor_package,
+        )
+
+    assert "Only one import is allowed" in str(exc_info.value)
diff --git a/tests/unit/qcog/pytorch/validate/test_utils.py b/tests/unit/qcog/pytorch/validate/test_utils.py
new file mode 100644
index 0000000..a877c58
--- /dev/null
+++ b/tests/unit/qcog/pytorch/validate/test_utils.py
@@ -0,0 +1,42 @@
+import io
+
+from qcog_python_client.qcog.pytorch.validate.utils import get_third_party_imports
+
+
+def test_get_third_party_imports_no_imports():
+    source_code = io.BytesIO(b"")
+    result = get_third_party_imports(source_code, "path")
+    assert result == set()
+
+
+def test_get_third_party_imports_standard_library_imports():
+    # asyncio and os are system-dependent modules
+    # and are trickier to test because they are
+    # in a specific location
+    source_code = io.BytesIO(b"import os\nimport sys\nimport asyncio")
+    result = get_third_party_imports(source_code, "path")
+    assert result == set()
+
+
+def test_get_third_party_imports_third_party_imports():
+    source_code = io.BytesIO(b"import requests\nimport numpy")
+    result = get_third_party_imports(source_code, "path")
+    assert result == {"requests", "numpy"}
+
+
+def test_get_third_party_imports_mixed_imports():
+    source_code = io.BytesIO(b"import os\nimport requests\nfrom numpy import array")
+    result = get_third_party_imports(source_code, "path")
+    assert result == {"requests", "numpy"}
+
+
+def test_get_third_party_imports_with_aliases():
+    source_code = io.BytesIO(b"import os as os_module\nimport requests as req")
+    result = get_third_party_imports(source_code, "path")
+    assert result == {"requests"}
+
+
+def test_get_third_party_imports_with_module_imports():
+    source_code = io.BytesIO(b"from os import path\nfrom requests import get")
+    result = get_third_party_imports(source_code, "path")
+    assert result == {"requests"}