Skip to content

Commit

Permalink
Add initial support for version handling (#76)
Browse files Browse the repository at this point in the history
This makes it easier to work with data where either you already know the
version associated with it, or if you have a function that can
dynamically get the version. The first time I wrote code like this was
in PyOBO, but I've also written similar code many times in different
places in combination with pystow, so it makes sense to enable it
upstream here.

This PR just adds support for the join and ensure functions, but in a
second step can be added to all functions that transitively call them.
  • Loading branch information
cthoyt authored Nov 16, 2024
1 parent fd7de90 commit 5be6357
Show file tree
Hide file tree
Showing 5 changed files with 177 additions and 17 deletions.
1 change: 1 addition & 0 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ install_requires =
click
requests
tqdm
typing_extensions

zip_safe = false
python_requires = >=3.9
Expand Down
3 changes: 2 additions & 1 deletion src/pystow/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@
open_gz,
)
from .config_api import ConfigError, get_config, write_config
from .impl import Module
from .impl import Module, VersionHint
from .utils import ensure_readme

__all__ = [
Expand Down Expand Up @@ -97,6 +97,7 @@
"get_config",
"write_config",
"Module",
"VersionHint",
]

ensure_readme()
Expand Down
60 changes: 56 additions & 4 deletions src/pystow/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
)

from .constants import JSON, BytesOpener, Provider
from .impl import Module
from .impl import Module, VersionHint

if TYPE_CHECKING:
import lxml.etree
Expand Down Expand Up @@ -102,7 +102,13 @@ def module(key: str, *subkeys: str, ensure_exists: bool = True) -> Module:
return Module.from_key(key, *subkeys, ensure_exists=ensure_exists)


def join(key: str, *subkeys: str, name: Optional[str] = None, ensure_exists: bool = True) -> Path:
def join(
key: str,
*subkeys: str,
name: Optional[str] = None,
ensure_exists: bool = True,
version: VersionHint = None,
) -> Path:
"""Return the home data directory for the given module.
:param key:
Expand All @@ -116,11 +122,33 @@ def join(key: str, *subkeys: str, name: Optional[str] = None, ensure_exists: boo
:param ensure_exists:
Should all directories be created automatically?
Defaults to true.
:param version:
The optional version, or no-argument callable that returns
an optional version. This is prepended before the subkeys.
The following example describes how to store the versioned data
from the Rhea database for biologically relevant chemical reactions.
.. code-block::
import pystow
import requests
def get_rhea_version() -> str:
res = requests.get("https://ftp.expasy.org/databases/rhea/rhea-release.properties")
_, _, version = res.text.splitlines()[0].partition("=")
return version
# Assume you want to download the data from
# ftp://ftp.expasy.org/databases/rhea/rdf/rhea.rdf.gz, make a path
# with the same name
path = pystow.join("rhea", name="rhea.rdf.gz", version=get_rhea_version)
:return:
The path of the directory or subdirectory for the given module.
"""
_module = Module.from_key(key, ensure_exists=ensure_exists)
return _module.join(*subkeys, name=name, ensure_exists=ensure_exists)
return _module.join(*subkeys, name=name, ensure_exists=ensure_exists, version=version)


# docstr-coverage:excused `overload`
Expand Down Expand Up @@ -250,6 +278,7 @@ def ensure(
*subkeys: str,
url: str,
name: Optional[str] = None,
version: VersionHint = None,
force: bool = False,
download_kwargs: Optional[Mapping[str, Any]] = None,
) -> Path:
Expand All @@ -267,6 +296,29 @@ def ensure(
:param name:
Overrides the name of the file at the end of the URL, if given. Also
useful for URLs that don't have proper filenames with extensions.
:param version:
The optional version, or no-argument callable that returns
an optional version. This is prepended before the subkeys.
The following example describes how to store the versioned data
from the Rhea database for biologically relevant chemical reactions.
.. code-block::
import pystow
import requests
def get_rhea_version() -> str:
res = requests.get("https://ftp.expasy.org/databases/rhea/rhea-release.properties")
_, _, version = res.text.splitlines()[0].partition("=")
return version
path = pystow.ensure(
"rhea",
url="ftp://ftp.expasy.org/databases/rhea/rdf/rhea.rdf.gz",
version=get_rhea_version,
)
:param force:
Should the download be done again, even if the path already exists?
Defaults to false.
Expand All @@ -276,7 +328,7 @@ def ensure(
"""
_module = Module.from_key(key, ensure_exists=True)
return _module.ensure(
*subkeys, url=url, name=name, force=force, download_kwargs=download_kwargs
*subkeys, url=url, name=name, version=version, force=force, download_kwargs=download_kwargs
)


Expand Down
86 changes: 79 additions & 7 deletions src/pystow/impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
import json
import logging
import lzma
import os
import pickle
import sqlite3
import tarfile
import zipfile
Expand All @@ -17,6 +19,7 @@
from typing import (
TYPE_CHECKING,
Any,
Callable,
Dict,
Generator,
Literal,
Expand All @@ -27,6 +30,8 @@
overload,
)

from typing_extensions import TypeAlias

from . import utils
from .constants import JSON, BytesOpener, Provider
from .utils import (
Expand All @@ -46,22 +51,24 @@
read_zipfile_csv,
)

try:
import pickle5 as pickle
except ImportError:
import pickle

if TYPE_CHECKING:
import botocore.client
import lxml.etree
import numpy
import pandas as pd
import rdflib

__all__ = ["Module"]
__all__ = [
"Module",
"VersionHint",
]

logger = logging.getLogger(__name__)

#: A type hint for something that can be passed to the
#: `version` argument of Module.join, Module.ensure, etc.
VersionHint: TypeAlias = Union[None, str, Callable[[], Optional[str]]]


class Module:
"""The class wrapping the directory lookup implementation."""
Expand Down Expand Up @@ -121,6 +128,7 @@ def join(
*subkeys: str,
name: Optional[str] = None,
ensure_exists: bool = True,
version: VersionHint = None,
) -> Path:
"""Get a subdirectory of the current module.
Expand All @@ -132,17 +140,57 @@ def join(
Defaults to true.
:param name:
The name of the file (optional) inside the folder
:param version:
The optional version, or no-argument callable that returns
an optional version. This is prepended before the subkeys.
The following example describes how to store the versioned data
from the Rhea database for biologically relevant chemical reactions.
.. code-block::
import pystow
import requests
def get_rhea_version() -> str:
res = requests.get("https://ftp.expasy.org/databases/rhea/rhea-release.properties")
_, _, version = res.text.splitlines()[0].partition("=")
return version
# Assume you want to download the data from
# ftp://ftp.expasy.org/databases/rhea/rdf/rhea.rdf.gz, make a path
# with the same name
module = pystow.module("rhea")
path = module.join(name="rhea.rdf.gz", version=get_rhea_version)
:return:
The path of the directory or subdirectory for the given module.
"""
rv = self.base

# if the version is given as a no-argument callable,
# then it should be called and a version is returned
if callable(version):
version = version()
if version:
self._raise_for_invalid_version(version)
subkeys = (version, *subkeys)

if subkeys:
rv = rv.joinpath(*subkeys)
mkdir(rv, ensure_exists=ensure_exists)
if name:
rv = rv.joinpath(name)
return rv

@staticmethod
def _raise_for_invalid_version(version: str) -> None:
if "/" in version or os.sep in version:
raise ValueError(
f"slashes and `{os.sep}` not allowed in versions because of "
f"conflicts with file path construction: {version}"
)

def joinpath_sqlite(self, *subkeys: str, name: str) -> str:
"""Get an SQLite database connection string.
Expand All @@ -160,6 +208,7 @@ def ensure(
*subkeys: str,
url: str,
name: Optional[str] = None,
version: VersionHint = None,
force: bool = False,
download_kwargs: Optional[Mapping[str, Any]] = None,
) -> Path:
Expand All @@ -173,6 +222,29 @@ def ensure(
:param name:
Overrides the name of the file at the end of the URL, if given. Also
useful for URLs that don't have proper filenames with extensions.
:param version:
The optional version, or no-argument callable that returns
an optional version. This is prepended before the subkeys.
The following example describes how to store the versioned data
from the Rhea database for biologically relevant chemical reactions.
.. code-block::
import pystow
import requests
def get_rhea_version() -> str:
res = requests.get("https://ftp.expasy.org/databases/rhea/rhea-release.properties")
_, _, version = res.text.splitlines()[0].partition("=")
return version
module = pystow.module("rhea")
path = module.ensure(
url="ftp://ftp.expasy.org/databases/rhea/rdf/rhea.rdf.gz",
version=get_rhea_version,
)
:param force:
Should the download be done again, even if the path already exists?
Defaults to false.
Expand All @@ -182,7 +254,7 @@ def ensure(
"""
if name is None:
name = name_from_url(url)
path = self.join(*subkeys, name=name, ensure_exists=True)
path = self.join(*subkeys, name=name, version=version, ensure_exists=True)
utils.download(
url=url,
path=path,
Expand Down
44 changes: 39 additions & 5 deletions tests/test_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,6 @@
JSON_BZ2_URL = f"{n()}/{JSON_BZ2_NAME}"
JSON_BZ2_PATH = RESOURCES / JSON_BZ2_NAME


MOCK_FILES: Mapping[str, Path] = {
TSV_URL: RESOURCES / TSV_NAME,
JSON_URL: JSON_PATH,
Expand Down Expand Up @@ -124,7 +123,7 @@ def test_mock_name(self):
self.assertFalse(expected_path.exists())


class TestGet(unittest.TestCase):
class TestJoin(unittest.TestCase):
"""Tests for :mod:`pystow`."""

def setUp(self) -> None:
Expand Down Expand Up @@ -175,15 +174,15 @@ def join(self, *parts: str) -> Path:
:param parts: The file path parts that are joined with this test case's directory
:return: A path to the file
"""
return Path(os.path.join(self.directory.name, *parts))
return Path(self.directory.name).joinpath(*parts)

def test_mock(self):
"""Test that mocking the directory works properly for this test case."""
with self.mock_directory():
self.assertEqual(os.getenv(PYSTOW_HOME_ENVVAR), self.directory.name)

def test_get(self):
"""Test the :func:`get` function."""
def test_join(self):
"""Test the :func:`pystow.join` function."""
parts_examples = [
[n()],
[n(), n()],
Expand All @@ -194,6 +193,41 @@ def test_get(self):
with self.subTest(parts=parts):
self.assertEqual(self.join(*parts), join(*parts))

def test_join_with_version(self):
"""Test the join function when a version is present."""
with self.mock_directory():
key = "key"
version = "v1"
self.assertEqual(
self.join(key, version),
pystow.join(key, version=version),
)

parts = [n()]
self.assertEqual(
self.join(key, version, *parts), pystow.join(key, *parts, version=version)
)

parts = [n()]
name = "yup.tsv"
self.assertEqual(
self.join(key, version, *parts, name),
pystow.join(key, *parts, version=version, name=name),
)

def _version_getter() -> str:
return "v2"

parts = [n()]
name = "yup.tsv"
self.assertEqual(
self.join(key, _version_getter(), *parts, name),
pystow.join(key, *parts, version=_version_getter, name=name),
)

with self.assertRaises(ValueError):
pystow.join(key, version="/")

def test_ensure(self):
"""Test ensuring various files."""
write_pickle_gz(TEST_TSV_ROWS, path=PICKLE_GZ_PATH)
Expand Down

0 comments on commit 5be6357

Please sign in to comment.