diff --git a/setup.cfg b/setup.cfg index 42c28b5..5c841f7 100644 --- a/setup.cfg +++ b/setup.cfg @@ -46,6 +46,7 @@ install_requires = click requests tqdm + typing_extensions zip_safe = false python_requires = >=3.9 diff --git a/src/pystow/__init__.py b/src/pystow/__init__.py index 11de9ea..1e874e4 100644 --- a/src/pystow/__init__.py +++ b/src/pystow/__init__.py @@ -47,7 +47,7 @@ open_gz, ) from .config_api import ConfigError, get_config, write_config -from .impl import Module +from .impl import Module, VersionHint from .utils import ensure_readme __all__ = [ @@ -97,6 +97,7 @@ "get_config", "write_config", "Module", + "VersionHint", ] ensure_readme() diff --git a/src/pystow/api.py b/src/pystow/api.py index 8389d50..0d4c13a 100644 --- a/src/pystow/api.py +++ b/src/pystow/api.py @@ -22,7 +22,7 @@ ) from .constants import JSON, BytesOpener, Provider -from .impl import Module +from .impl import Module, VersionHint if TYPE_CHECKING: import lxml.etree @@ -102,7 +102,13 @@ def module(key: str, *subkeys: str, ensure_exists: bool = True) -> Module: return Module.from_key(key, *subkeys, ensure_exists=ensure_exists) -def join(key: str, *subkeys: str, name: Optional[str] = None, ensure_exists: bool = True) -> Path: +def join( + key: str, + *subkeys: str, + name: Optional[str] = None, + ensure_exists: bool = True, + version: VersionHint = None, +) -> Path: """Return the home data directory for the given module. :param key: @@ -116,11 +122,33 @@ def join(key: str, *subkeys: str, name: Optional[str] = None, ensure_exists: boo :param ensure_exists: Should all directories be created automatically? Defaults to true. + :param version: + The optional version, or no-argument callable that returns + an optional version. This is prepended before the subkeys. + + The following example describes how to store the versioned data + from the Rhea database for biologically relevant chemical reactions. + + .. code-block:: + + import pystow + import requests + + def get_rhea_version() -> str: + res = requests.get("https://ftp.expasy.org/databases/rhea/rhea-release.properties") + _, _, version = res.text.splitlines()[0].partition("=") + return version + + # Assume you want to download the data from + # ftp://ftp.expasy.org/databases/rhea/rdf/rhea.rdf.gz, make a path + # with the same name + path = pystow.join("rhea", name="rhea.rdf.gz", version=get_rhea_version) + :return: The path of the directory or subdirectory for the given module. """ _module = Module.from_key(key, ensure_exists=ensure_exists) - return _module.join(*subkeys, name=name, ensure_exists=ensure_exists) + return _module.join(*subkeys, name=name, ensure_exists=ensure_exists, version=version) # docstr-coverage:excused `overload` @@ -250,6 +278,7 @@ def ensure( *subkeys: str, url: str, name: Optional[str] = None, + version: VersionHint = None, force: bool = False, download_kwargs: Optional[Mapping[str, Any]] = None, ) -> Path: @@ -267,6 +296,29 @@ def ensure( :param name: Overrides the name of the file at the end of the URL, if given. Also useful for URLs that don't have proper filenames with extensions. + :param version: + The optional version, or no-argument callable that returns + an optional version. This is prepended before the subkeys. + + The following example describes how to store the versioned data + from the Rhea database for biologically relevant chemical reactions. + + .. code-block:: + + import pystow + import requests + + def get_rhea_version() -> str: + res = requests.get("https://ftp.expasy.org/databases/rhea/rhea-release.properties") + _, _, version = res.text.splitlines()[0].partition("=") + return version + + path = pystow.ensure( + "rhea", + url="ftp://ftp.expasy.org/databases/rhea/rdf/rhea.rdf.gz", + version=get_rhea_version, + ) + :param force: Should the download be done again, even if the path already exists? Defaults to false. @@ -276,7 +328,7 @@ def ensure( """ _module = Module.from_key(key, ensure_exists=True) return _module.ensure( - *subkeys, url=url, name=name, force=force, download_kwargs=download_kwargs + *subkeys, url=url, name=name, version=version, force=force, download_kwargs=download_kwargs ) diff --git a/src/pystow/impl.py b/src/pystow/impl.py index c9ac9cf..bab9445 100644 --- a/src/pystow/impl.py +++ b/src/pystow/impl.py @@ -8,6 +8,8 @@ import json import logging import lzma +import os +import pickle import sqlite3 import tarfile import zipfile @@ -17,6 +19,7 @@ from typing import ( TYPE_CHECKING, Any, + Callable, Dict, Generator, Literal, @@ -27,6 +30,8 @@ overload, ) +from typing_extensions import TypeAlias + from . import utils from .constants import JSON, BytesOpener, Provider from .utils import ( @@ -46,11 +51,6 @@ read_zipfile_csv, ) -try: - import pickle5 as pickle -except ImportError: - import pickle - if TYPE_CHECKING: import botocore.client import lxml.etree @@ -58,10 +58,17 @@ import pandas as pd import rdflib -__all__ = ["Module"] +__all__ = [ + "Module", + "VersionHint", +] logger = logging.getLogger(__name__) +#: A type hint for something that can be passed to the +#: `version` argument of Module.join, Module.ensure, etc. +VersionHint: TypeAlias = Union[None, str, Callable[[], Optional[str]]] + class Module: """The class wrapping the directory lookup implementation.""" @@ -121,6 +128,7 @@ def join( *subkeys: str, name: Optional[str] = None, ensure_exists: bool = True, + version: VersionHint = None, ) -> Path: """Get a subdirectory of the current module. @@ -132,10 +140,42 @@ def join( Defaults to true. :param name: The name of the file (optional) inside the folder + :param version: + The optional version, or no-argument callable that returns + an optional version. This is prepended before the subkeys. + + The following example describes how to store the versioned data + from the Rhea database for biologically relevant chemical reactions. + + .. code-block:: + + import pystow + import requests + + def get_rhea_version() -> str: + res = requests.get("https://ftp.expasy.org/databases/rhea/rhea-release.properties") + _, _, version = res.text.splitlines()[0].partition("=") + return version + + # Assume you want to download the data from + # ftp://ftp.expasy.org/databases/rhea/rdf/rhea.rdf.gz, make a path + # with the same name + module = pystow.module("rhea") + path = module.join(name="rhea.rdf.gz", version=get_rhea_version) + :return: The path of the directory or subdirectory for the given module. """ rv = self.base + + # if the version is given as a no-argument callable, + # then it should be called and a version is returned + if callable(version): + version = version() + if version: + self._raise_for_invalid_version(version) + subkeys = (version, *subkeys) + if subkeys: rv = rv.joinpath(*subkeys) mkdir(rv, ensure_exists=ensure_exists) @@ -143,6 +183,14 @@ def join( rv = rv.joinpath(name) return rv + @staticmethod + def _raise_for_invalid_version(version: str) -> None: + if "/" in version or os.sep in version: + raise ValueError( + f"slashes and `{os.sep}` not allowed in versions because of " + f"conflicts with file path construction: {version}" + ) + def joinpath_sqlite(self, *subkeys: str, name: str) -> str: """Get an SQLite database connection string. @@ -160,6 +208,7 @@ def ensure( *subkeys: str, url: str, name: Optional[str] = None, + version: VersionHint = None, force: bool = False, download_kwargs: Optional[Mapping[str, Any]] = None, ) -> Path: @@ -173,6 +222,29 @@ def ensure( :param name: Overrides the name of the file at the end of the URL, if given. Also useful for URLs that don't have proper filenames with extensions. + :param version: + The optional version, or no-argument callable that returns + an optional version. This is prepended before the subkeys. + + The following example describes how to store the versioned data + from the Rhea database for biologically relevant chemical reactions. + + .. code-block:: + + import pystow + import requests + + def get_rhea_version() -> str: + res = requests.get("https://ftp.expasy.org/databases/rhea/rhea-release.properties") + _, _, version = res.text.splitlines()[0].partition("=") + return version + + module = pystow.module("rhea") + path = module.ensure( + url="ftp://ftp.expasy.org/databases/rhea/rdf/rhea.rdf.gz", + version=get_rhea_version, + ) + :param force: Should the download be done again, even if the path already exists? Defaults to false. @@ -182,7 +254,7 @@ def ensure( """ if name is None: name = name_from_url(url) - path = self.join(*subkeys, name=name, ensure_exists=True) + path = self.join(*subkeys, name=name, version=version, ensure_exists=True) utils.download( url=url, path=path, diff --git a/tests/test_module.py b/tests/test_module.py index 902a51f..476bcf0 100644 --- a/tests/test_module.py +++ b/tests/test_module.py @@ -60,7 +60,6 @@ JSON_BZ2_URL = f"{n()}/{JSON_BZ2_NAME}" JSON_BZ2_PATH = RESOURCES / JSON_BZ2_NAME - MOCK_FILES: Mapping[str, Path] = { TSV_URL: RESOURCES / TSV_NAME, JSON_URL: JSON_PATH, @@ -124,7 +123,7 @@ def test_mock_name(self): self.assertFalse(expected_path.exists()) -class TestGet(unittest.TestCase): +class TestJoin(unittest.TestCase): """Tests for :mod:`pystow`.""" def setUp(self) -> None: @@ -175,15 +174,15 @@ def join(self, *parts: str) -> Path: :param parts: The file path parts that are joined with this test case's directory :return: A path to the file """ - return Path(os.path.join(self.directory.name, *parts)) + return Path(self.directory.name).joinpath(*parts) def test_mock(self): """Test that mocking the directory works properly for this test case.""" with self.mock_directory(): self.assertEqual(os.getenv(PYSTOW_HOME_ENVVAR), self.directory.name) - def test_get(self): - """Test the :func:`get` function.""" + def test_join(self): + """Test the :func:`pystow.join` function.""" parts_examples = [ [n()], [n(), n()], @@ -194,6 +193,41 @@ def test_get(self): with self.subTest(parts=parts): self.assertEqual(self.join(*parts), join(*parts)) + def test_join_with_version(self): + """Test the join function when a version is present.""" + with self.mock_directory(): + key = "key" + version = "v1" + self.assertEqual( + self.join(key, version), + pystow.join(key, version=version), + ) + + parts = [n()] + self.assertEqual( + self.join(key, version, *parts), pystow.join(key, *parts, version=version) + ) + + parts = [n()] + name = "yup.tsv" + self.assertEqual( + self.join(key, version, *parts, name), + pystow.join(key, *parts, version=version, name=name), + ) + + def _version_getter() -> str: + return "v2" + + parts = [n()] + name = "yup.tsv" + self.assertEqual( + self.join(key, _version_getter(), *parts, name), + pystow.join(key, *parts, version=_version_getter, name=name), + ) + + with self.assertRaises(ValueError): + pystow.join(key, version="/") + def test_ensure(self): """Test ensuring various files.""" write_pickle_gz(TEST_TSV_ROWS, path=PICKLE_GZ_PATH)