From 5be63576339ad581f7484215e0f08f90d7d4c191 Mon Sep 17 00:00:00 2001 From: Charles Tapley Hoyt Date: Sat, 16 Nov 2024 10:36:41 +0100 Subject: [PATCH] Add initial support for version handling (#76) This makes it easier to work with data where either you already know the version associated with it, or if you have a function that can dynamically get the version. The first time I wrote code like this was in PyOBO, but I've also written similar code many times in different places in combination with pystow, so it makes sense to enable it upstream here. This PR just adds support for the join and ensure functions, but in a second step can be added to all functions that transitively call them. --- setup.cfg | 1 + src/pystow/__init__.py | 3 +- src/pystow/api.py | 60 +++++++++++++++++++++++++++-- src/pystow/impl.py | 86 ++++++++++++++++++++++++++++++++++++++---- tests/test_module.py | 44 ++++++++++++++++++--- 5 files changed, 177 insertions(+), 17 deletions(-) diff --git a/setup.cfg b/setup.cfg index 42c28b5..5c841f7 100644 --- a/setup.cfg +++ b/setup.cfg @@ -46,6 +46,7 @@ install_requires = click requests tqdm + typing_extensions zip_safe = false python_requires = >=3.9 diff --git a/src/pystow/__init__.py b/src/pystow/__init__.py index 11de9ea..1e874e4 100644 --- a/src/pystow/__init__.py +++ b/src/pystow/__init__.py @@ -47,7 +47,7 @@ open_gz, ) from .config_api import ConfigError, get_config, write_config -from .impl import Module +from .impl import Module, VersionHint from .utils import ensure_readme __all__ = [ @@ -97,6 +97,7 @@ "get_config", "write_config", "Module", + "VersionHint", ] ensure_readme() diff --git a/src/pystow/api.py b/src/pystow/api.py index 8389d50..0d4c13a 100644 --- a/src/pystow/api.py +++ b/src/pystow/api.py @@ -22,7 +22,7 @@ ) from .constants import JSON, BytesOpener, Provider -from .impl import Module +from .impl import Module, VersionHint if TYPE_CHECKING: import lxml.etree @@ -102,7 +102,13 @@ def module(key: str, *subkeys: str, ensure_exists: bool = True) -> Module: return Module.from_key(key, *subkeys, ensure_exists=ensure_exists) -def join(key: str, *subkeys: str, name: Optional[str] = None, ensure_exists: bool = True) -> Path: +def join( + key: str, + *subkeys: str, + name: Optional[str] = None, + ensure_exists: bool = True, + version: VersionHint = None, +) -> Path: """Return the home data directory for the given module. :param key: @@ -116,11 +122,33 @@ def join(key: str, *subkeys: str, name: Optional[str] = None, ensure_exists: boo :param ensure_exists: Should all directories be created automatically? Defaults to true. + :param version: + The optional version, or no-argument callable that returns + an optional version. This is prepended before the subkeys. + + The following example describes how to store the versioned data + from the Rhea database for biologically relevant chemical reactions. + + .. code-block:: + + import pystow + import requests + + def get_rhea_version() -> str: + res = requests.get("https://ftp.expasy.org/databases/rhea/rhea-release.properties") + _, _, version = res.text.splitlines()[0].partition("=") + return version + + # Assume you want to download the data from + # ftp://ftp.expasy.org/databases/rhea/rdf/rhea.rdf.gz, make a path + # with the same name + path = pystow.join("rhea", name="rhea.rdf.gz", version=get_rhea_version) + :return: The path of the directory or subdirectory for the given module. """ _module = Module.from_key(key, ensure_exists=ensure_exists) - return _module.join(*subkeys, name=name, ensure_exists=ensure_exists) + return _module.join(*subkeys, name=name, ensure_exists=ensure_exists, version=version) # docstr-coverage:excused `overload` @@ -250,6 +278,7 @@ def ensure( *subkeys: str, url: str, name: Optional[str] = None, + version: VersionHint = None, force: bool = False, download_kwargs: Optional[Mapping[str, Any]] = None, ) -> Path: @@ -267,6 +296,29 @@ def ensure( :param name: Overrides the name of the file at the end of the URL, if given. Also useful for URLs that don't have proper filenames with extensions. + :param version: + The optional version, or no-argument callable that returns + an optional version. This is prepended before the subkeys. + + The following example describes how to store the versioned data + from the Rhea database for biologically relevant chemical reactions. + + .. code-block:: + + import pystow + import requests + + def get_rhea_version() -> str: + res = requests.get("https://ftp.expasy.org/databases/rhea/rhea-release.properties") + _, _, version = res.text.splitlines()[0].partition("=") + return version + + path = pystow.ensure( + "rhea", + url="ftp://ftp.expasy.org/databases/rhea/rdf/rhea.rdf.gz", + version=get_rhea_version, + ) + :param force: Should the download be done again, even if the path already exists? Defaults to false. @@ -276,7 +328,7 @@ def ensure( """ _module = Module.from_key(key, ensure_exists=True) return _module.ensure( - *subkeys, url=url, name=name, force=force, download_kwargs=download_kwargs + *subkeys, url=url, name=name, version=version, force=force, download_kwargs=download_kwargs ) diff --git a/src/pystow/impl.py b/src/pystow/impl.py index c9ac9cf..bab9445 100644 --- a/src/pystow/impl.py +++ b/src/pystow/impl.py @@ -8,6 +8,8 @@ import json import logging import lzma +import os +import pickle import sqlite3 import tarfile import zipfile @@ -17,6 +19,7 @@ from typing import ( TYPE_CHECKING, Any, + Callable, Dict, Generator, Literal, @@ -27,6 +30,8 @@ overload, ) +from typing_extensions import TypeAlias + from . import utils from .constants import JSON, BytesOpener, Provider from .utils import ( @@ -46,11 +51,6 @@ read_zipfile_csv, ) -try: - import pickle5 as pickle -except ImportError: - import pickle - if TYPE_CHECKING: import botocore.client import lxml.etree @@ -58,10 +58,17 @@ import pandas as pd import rdflib -__all__ = ["Module"] +__all__ = [ + "Module", + "VersionHint", +] logger = logging.getLogger(__name__) +#: A type hint for something that can be passed to the +#: `version` argument of Module.join, Module.ensure, etc. +VersionHint: TypeAlias = Union[None, str, Callable[[], Optional[str]]] + class Module: """The class wrapping the directory lookup implementation.""" @@ -121,6 +128,7 @@ def join( *subkeys: str, name: Optional[str] = None, ensure_exists: bool = True, + version: VersionHint = None, ) -> Path: """Get a subdirectory of the current module. @@ -132,10 +140,42 @@ def join( Defaults to true. :param name: The name of the file (optional) inside the folder + :param version: + The optional version, or no-argument callable that returns + an optional version. This is prepended before the subkeys. + + The following example describes how to store the versioned data + from the Rhea database for biologically relevant chemical reactions. + + .. code-block:: + + import pystow + import requests + + def get_rhea_version() -> str: + res = requests.get("https://ftp.expasy.org/databases/rhea/rhea-release.properties") + _, _, version = res.text.splitlines()[0].partition("=") + return version + + # Assume you want to download the data from + # ftp://ftp.expasy.org/databases/rhea/rdf/rhea.rdf.gz, make a path + # with the same name + module = pystow.module("rhea") + path = module.join(name="rhea.rdf.gz", version=get_rhea_version) + :return: The path of the directory or subdirectory for the given module. """ rv = self.base + + # if the version is given as a no-argument callable, + # then it should be called and a version is returned + if callable(version): + version = version() + if version: + self._raise_for_invalid_version(version) + subkeys = (version, *subkeys) + if subkeys: rv = rv.joinpath(*subkeys) mkdir(rv, ensure_exists=ensure_exists) @@ -143,6 +183,14 @@ def join( rv = rv.joinpath(name) return rv + @staticmethod + def _raise_for_invalid_version(version: str) -> None: + if "/" in version or os.sep in version: + raise ValueError( + f"slashes and `{os.sep}` not allowed in versions because of " + f"conflicts with file path construction: {version}" + ) + def joinpath_sqlite(self, *subkeys: str, name: str) -> str: """Get an SQLite database connection string. @@ -160,6 +208,7 @@ def ensure( *subkeys: str, url: str, name: Optional[str] = None, + version: VersionHint = None, force: bool = False, download_kwargs: Optional[Mapping[str, Any]] = None, ) -> Path: @@ -173,6 +222,29 @@ def ensure( :param name: Overrides the name of the file at the end of the URL, if given. Also useful for URLs that don't have proper filenames with extensions. + :param version: + The optional version, or no-argument callable that returns + an optional version. This is prepended before the subkeys. + + The following example describes how to store the versioned data + from the Rhea database for biologically relevant chemical reactions. + + .. code-block:: + + import pystow + import requests + + def get_rhea_version() -> str: + res = requests.get("https://ftp.expasy.org/databases/rhea/rhea-release.properties") + _, _, version = res.text.splitlines()[0].partition("=") + return version + + module = pystow.module("rhea") + path = module.ensure( + url="ftp://ftp.expasy.org/databases/rhea/rdf/rhea.rdf.gz", + version=get_rhea_version, + ) + :param force: Should the download be done again, even if the path already exists? Defaults to false. @@ -182,7 +254,7 @@ def ensure( """ if name is None: name = name_from_url(url) - path = self.join(*subkeys, name=name, ensure_exists=True) + path = self.join(*subkeys, name=name, version=version, ensure_exists=True) utils.download( url=url, path=path, diff --git a/tests/test_module.py b/tests/test_module.py index 902a51f..476bcf0 100644 --- a/tests/test_module.py +++ b/tests/test_module.py @@ -60,7 +60,6 @@ JSON_BZ2_URL = f"{n()}/{JSON_BZ2_NAME}" JSON_BZ2_PATH = RESOURCES / JSON_BZ2_NAME - MOCK_FILES: Mapping[str, Path] = { TSV_URL: RESOURCES / TSV_NAME, JSON_URL: JSON_PATH, @@ -124,7 +123,7 @@ def test_mock_name(self): self.assertFalse(expected_path.exists()) -class TestGet(unittest.TestCase): +class TestJoin(unittest.TestCase): """Tests for :mod:`pystow`.""" def setUp(self) -> None: @@ -175,15 +174,15 @@ def join(self, *parts: str) -> Path: :param parts: The file path parts that are joined with this test case's directory :return: A path to the file """ - return Path(os.path.join(self.directory.name, *parts)) + return Path(self.directory.name).joinpath(*parts) def test_mock(self): """Test that mocking the directory works properly for this test case.""" with self.mock_directory(): self.assertEqual(os.getenv(PYSTOW_HOME_ENVVAR), self.directory.name) - def test_get(self): - """Test the :func:`get` function.""" + def test_join(self): + """Test the :func:`pystow.join` function.""" parts_examples = [ [n()], [n(), n()], @@ -194,6 +193,41 @@ def test_get(self): with self.subTest(parts=parts): self.assertEqual(self.join(*parts), join(*parts)) + def test_join_with_version(self): + """Test the join function when a version is present.""" + with self.mock_directory(): + key = "key" + version = "v1" + self.assertEqual( + self.join(key, version), + pystow.join(key, version=version), + ) + + parts = [n()] + self.assertEqual( + self.join(key, version, *parts), pystow.join(key, *parts, version=version) + ) + + parts = [n()] + name = "yup.tsv" + self.assertEqual( + self.join(key, version, *parts, name), + pystow.join(key, *parts, version=version, name=name), + ) + + def _version_getter() -> str: + return "v2" + + parts = [n()] + name = "yup.tsv" + self.assertEqual( + self.join(key, _version_getter(), *parts, name), + pystow.join(key, *parts, version=_version_getter, name=name), + ) + + with self.assertRaises(ValueError): + pystow.join(key, version="/") + def test_ensure(self): """Test ensuring various files.""" write_pickle_gz(TEST_TSV_ROWS, path=PICKLE_GZ_PATH)