Skip to content

Commit

Permalink
addressing comments
Browse files Browse the repository at this point in the history
  • Loading branch information
vertefra committed Aug 29, 2024
1 parent 2afdc0b commit a478b14
Show file tree
Hide file tree
Showing 12 changed files with 169 additions and 95 deletions.
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -27,4 +27,4 @@ schema-build:
python schema.py

test-unit:
pytest -v tests/unit tests/unit
pytest -v tests/unit
37 changes: 31 additions & 6 deletions qcog_python_client/monitor/_wandb.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,19 +6,37 @@

from .interface import Monitor

WANDB_DEFAULT_PROJECT = "qcognitive-dev"
WANDB_DEFAULT_PROJECT = "qognitive-dev"


class WandbMonitor(Monitor):
"""Wandb Monitor implementation."""

def init(
def init( # noqa: D417 # Complains about parameters description not being present
self,
api_key: str | None = None,
project: str = WANDB_DEFAULT_PROJECT,
parameters: dict | None = None,
labels: list[str] | None = None,
) -> None:
"""Initialize the Wandb Monitor."""
"""Initialize the Wandb Monitor.
Parameters
----------
api_key : str | None
Wandb API key.
project : str
Name of the project.
parameters : dict | None
Hyperparameters to be logged.
labels : list | None
Tags to be associated with the project.
Raises
------
ValueError: If the Wandb API key is not provided.
"""
key = api_key or os.getenv("WANDB_API_KEY")

if not key:
Expand All @@ -33,12 +51,19 @@ def init(
wandb.init(
project=project,
config=parameters,
tags=labels,
)

def log(self, data: dict) -> None:
"""Log data to Wandb."""
def log(self, data: dict) -> None: # noqa: D417 # Complains about parameters description not being present
"""Log data to Wandb.
Parameters
----------
data : dict
Data to be logged.
"""
wandb.log(data)

def close(self) -> None:
"""Close the Wandb Monitor."""
"""Close the Wandb monitor."""
wandb.finish()
3 changes: 2 additions & 1 deletion qcog_python_client/qcog/_baseclient.py
Original file line number Diff line number Diff line change
Expand Up @@ -401,7 +401,8 @@ async def _progress(self) -> dict:
"""
if self.model.model_name == Model.pytorch.value:
raise ValueError("Progress is not available for PyTorch models.")
logger.warning("Progress is not available for PyTorch models.")
return {}

await self._load_trained_model()
return {
Expand Down
34 changes: 14 additions & 20 deletions qcog_python_client/qcog/pytorch/discover/discoverhandler.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from anyio import open_file

from qcog_python_client.qcog.pytorch import utils
from qcog_python_client.qcog.pytorch.discover.types import MaybeIsRelevantFile
from qcog_python_client.qcog.pytorch.discover.types import IsRelevantFile
from qcog_python_client.qcog.pytorch.discover.utils import pkg_name
from qcog_python_client.qcog.pytorch.handler import Command, Handler
from qcog_python_client.qcog.pytorch.types import (
Expand All @@ -31,21 +31,17 @@
)


async def _maybe_model_module(self: DiscoverHandler, file: QFile) -> QFile | None:
async def _is_model_module(self: DiscoverHandler, file: QFile) -> bool:
"""Check if the file is the model module."""
module_name = os.path.basename(file.path)
if module_name == self.model_module_name:
return file
return None
return module_name == self.model_module_name


async def _maybe_monitor_service_import_module(
self: DiscoverHandler, file: QFile
) -> QFile | None:
async def _is_service_import_module(self: DiscoverHandler, file: QFile) -> bool:
"""Check if the file is importing the monitor service."""
# Make sure the item is not a folder. If so, exit
if os.path.isdir(file.path):
return None
return False

tree = ast.parse(file.content.read())
file.content.seek(0)
Expand All @@ -64,14 +60,14 @@ async def _maybe_monitor_service_import_module(
"You cannot import anything from qcog_python_client other than monitor." # noqa: E501
)

if node.names[0].name == "monitor":
return file
return None
return node.names[0].name == "monitor"

return False

relevant_files_map: dict[RelevantFileId, MaybeIsRelevantFile] = {
"model_module": _maybe_model_module, # type: ignore
"monitor_service_import_module": _maybe_monitor_service_import_module, # type: ignore

relevant_files_map: dict[RelevantFileId, IsRelevantFile] = {
"model_module": _is_model_module, # type: ignore
"monitor_service_import_module": _is_service_import_module, # type: ignore
}


Expand All @@ -83,9 +79,8 @@ async def maybe_relevant_file(
retval: dict[RelevantFileId, QFile] = {}

for relevant_file_id, _maybe_relevant_file_fn in relevant_files_map.items():
relevant_file = await _maybe_relevant_file_fn(self, file)
if relevant_file:
retval.update({relevant_file_id: relevant_file})
if await _maybe_relevant_file_fn(self, file):
retval.update({relevant_file_id: file})

return retval

Expand Down Expand Up @@ -151,7 +146,6 @@ async def handle(self, payload: DiscoverCommand) -> ValidateCommand:

async with await open_file(item_path, "rb") as file:
io_file = io.BytesIO(await file.read())
io_file.seek(0)
self.directory[item_path] = QFile.model_validate(
{
"path": item_path,
Expand All @@ -167,7 +161,7 @@ async def handle(self, payload: DiscoverCommand) -> ValidateCommand:
# And used to index the file. Relevant files
# are key files that are used to run
# the training session and are further
# valifatede in the chain.
# validated in the chain.

self.relevant_files: RelevantFiles = {}

Expand Down
4 changes: 2 additions & 2 deletions qcog_python_client/qcog/pytorch/discover/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,6 @@
from qcog_python_client.qcog.pytorch.handler import Handler
from qcog_python_client.qcog.pytorch.types import DiscoverCommand, QFile

MaybeIsRelevantFile: TypeAlias = Callable[
[Handler[DiscoverCommand], QFile], Coroutine[Any, Any, QFile | None]
IsRelevantFile: TypeAlias = Callable[
[Handler[DiscoverCommand], QFile], Coroutine[Any, Any, bool]
]
4 changes: 3 additions & 1 deletion qcog_python_client/qcog/pytorch/handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@

from pydantic import BaseModel

from qcog_python_client.log import qcoglogger as logger


class BoundedCommand(BaseModel):
"""Command type."""
Expand Down Expand Up @@ -129,7 +131,7 @@ async def dispatch(self, payload: CommandPayloadType) -> Handler:
# 1 - revert the state of the handler
# 2 - wait for the specified time
# 3 - try again
print(f"Attempt {i}/{self.attempts}, error: {e}")
logger.info(f"Attempt {i}/{self.attempts}, error: {e}")
exception = e
await self.revert()
await asyncio.sleep(self.retry_after)
Expand Down
20 changes: 18 additions & 2 deletions qcog_python_client/qcog/pytorch/upload/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,30 @@


def compress_folder(directory: Directory, folder_path: str) -> io.BytesIO:
"""Compress a folder."""
"""Compress a folder from a in-memory directory.
Parameters
----------
directory : Directory
The directory to compress, where the key is the path of the file
and the value is a QFile object.
folder_path : str
The path of the folder to compress.
Returns
-------
io.BytesIO
The compressed folder as a BytesIO object representing a tarfile.
"""
buffer = io.BytesIO()

with tarfile.open(fileobj=buffer, mode="w:gz") as tar:
for qfile in directory.values():
# Get a relative path from the arcname
rel_path = os.path.relpath(qfile.path, folder_path)
logger.warning(f"Adding {rel_path} to the tarfile")
logger.info(f"Adding {rel_path} to training package...")

# Create TarInfo object and set the size
tarfinfo = tarfile.TarInfo(name=rel_path)
Expand Down
49 changes: 37 additions & 12 deletions qcog_python_client/qcog/pytorch/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,18 +7,27 @@

from qcog_python_client.qcog.pytorch.types import FilePath, QFile

default_rules = {
re.compile(r".*__pycache__.*"),
re.compile(r".*\.git.*"),
}

def exclude(file_path: str) -> bool:
"""Check against a list of regexes rules to exclude the file."""
rules = {
"pycache": r".*__pycache__.*",
"git": r".*\.git.*",
"venv_1": r".*\.venv.*",
"venv_2": r".*venv.*",
"venv_3": r".*\.env.*",
}

for pattern in rules.values():

def exclude(file_path: str, rules: set[re.Pattern[str]] | None = None) -> bool:
"""Check against a set of regexes rules.
Parameters
----------
file_path : str
The path to the file.
rules : set[re.Pattern[str]] | None
The set of regex rules to check against. Defaults to `default_rules`
if None (or an empty set) is passed.
"""
rules = rules or default_rules
for pattern in rules:
if re.match(pattern, file_path):
return True

Expand All @@ -28,7 +37,23 @@ def exclude(file_path: str) -> bool:
def get_folder_structure(
file_path: str, *, filter: Callable[[FilePath], bool] | None = None
) -> dict[FilePath, QFile]:
"""Return the folder structure as a dictionary."""
"""Return the folder structure as a dictionary.
Parameters
----------
file_path : str
The path to the folder.
filter : Callable[[FilePath], bool] | None
The filter function to apply to the folder structure
to exclude some paths.
Returns
-------
dict[FilePath, QFile]
The folder structure as a dictionary.
"""
folder_items = os.listdir(file_path)

retval: dict[FilePath, QFile] = {}
Expand Down
Loading

0 comments on commit a478b14

Please sign in to comment.