From 2a6cc413f2fe2c6817c72502afc695cc8fc4c108 Mon Sep 17 00:00:00 2001 From: justold <1188067+pwwang@users.noreply.github.com> Date: Fri, 11 Aug 2023 23:59:21 +0000 Subject: [PATCH] =?UTF-8?q?=F0=9F=8E=A8=20Allow=20`datar.all.filter=20rega?= =?UTF-8?q?rdless`=20of=20`allow=5Fconflict=5Fnames`?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- datar/all.py | 18 +++++- datar/base.py | 18 +++++- datar/dplyr.py | 18 +++++- poetry.lock | 24 ++++---- tests/conflict_names.py | 112 +++++++++++++++++++++++++++++++++++ tests/test_conflict_names.py | 85 ++++++++++++++++++++++++++ tox.ini | 1 + 7 files changed, 261 insertions(+), 15 deletions(-) create mode 100644 tests/conflict_names.py create mode 100644 tests/test_conflict_names.py diff --git a/datar/all.py b/datar/all.py index 299b4772c..9ea434481 100644 --- a/datar/all.py +++ b/datar/all.py @@ -15,7 +15,23 @@ __all__ = [key for key in locals() if not key.startswith("_")] -if get_option("allow_conflict_names"): # noqa: F405 pragma: no cover +if get_option("allow_conflict_names"): # noqa: F405 __all__.extend(_base_conflict_names | _dplyr_conflict_names) for name in _base_conflict_names | _dplyr_conflict_names: locals()[name] = locals()[name + "_"] + + +def __getattr__(name): + """Even when allow_conflict_names is False, datar.base.sum should be fine + """ + if name in _base_conflict_names | _dplyr_conflict_names: + import sys + import ast + from executing import Source + node = Source.executing(sys._getframe(1)).node + if isinstance(node, (ast.Call, ast.Attribute)): + # import datar.all as d + # d.sum(...) or getattr(d, "sum")(...) + return globals()[name + "_"] + + raise AttributeError diff --git a/datar/base.py b/datar/base.py index c36a2f299..6c156291a 100644 --- a/datar/base.py +++ b/datar/base.py @@ -6,7 +6,23 @@ __all__ = [key for key in locals() if not key.startswith("_")] _conflict_names = {"min", "max", "sum", "abs", "round", "all", "any", "re"} -if get_option("allow_conflict_names"): # noqa: F405 pragma: no cover +if get_option("allow_conflict_names"): # noqa: F405 __all__.extend(_conflict_names) for name in _conflict_names: locals()[name] = locals()[name + "_"] + + +def __getattr__(name): + """Even when allow_conflict_names is False, datar.base.sum should be fine + """ + if name in _conflict_names: + import sys + import ast + from executing import Source + node = Source.executing(sys._getframe(1)).node + if isinstance(node, (ast.Call, ast.Attribute)): + # import datar.base as d + # d.sum(...) + return globals()[name + "_"] + + raise AttributeError diff --git a/datar/dplyr.py b/datar/dplyr.py index 3c609e512..6f71cb61a 100644 --- a/datar/dplyr.py +++ b/datar/dplyr.py @@ -7,7 +7,23 @@ __all__ = [key for key in locals() if not key.startswith("_")] _conflict_names = {"filter", "slice"} -if _get_option("allow_conflict_names"): # pragma: no cover +if _get_option("allow_conflict_names"): __all__.extend(_conflict_names) for name in _conflict_names: locals()[name] = locals()[name + "_"] + + +def __getattr__(name): + """Even when allow_conflict_names is False, datar.base.sum should be fine + """ + if name in _conflict_names: + import sys + import ast + from executing import Source + node = Source.executing(sys._getframe(1)).node + if isinstance(node, (ast.Call, ast.Attribute)): + # import datar.dplyr as d + # d.sum(...) + return globals()[name + "_"] + + raise AttributeError diff --git a/poetry.lock b/poetry.lock index 9eb869dbc..6e17e83a1 100644 --- a/poetry.lock +++ b/poetry.lock @@ -88,47 +88,47 @@ toml = ["tomli"] [[package]] name = "datar-arrow" -version = "0.0.0" +version = "0.0.1" description = "The pyarrow backend for datar" optional = true python-versions = ">=3.8,<4.0" files = [ - {file = "datar_arrow-0.0.0-py3-none-any.whl", hash = "sha256:41ddc5d3171ac76433c20380178a12e9f08a6761ea930c5626a843a2e5dba21b"}, - {file = "datar_arrow-0.0.0.tar.gz", hash = "sha256:1c713201830ed9232a4892ce73aa6a834d5cbe3c885ae5629b666749b1be4a86"}, + {file = "datar_arrow-0.0.1-py3-none-any.whl", hash = "sha256:20a0ead1d3bf74b800d1e0255647ed7bc287f546aca0d88b9f8489286ea5d553"}, + {file = "datar_arrow-0.0.1.tar.gz", hash = "sha256:0b4f8f4eb5446e0b64306def93ae63741f7c6099a24a59b13763b2037202842f"}, ] [package.dependencies] -datar = ">=0.12,<0.13" +datar = ">=0.13,<0.14" pyarrow = ">=11,<12" [[package]] name = "datar-numpy" -version = "0.2.0" +version = "0.2.1" description = "The numpy backend for datar" optional = true python-versions = ">=3.8,<4.0" files = [ - {file = "datar_numpy-0.2.0-py3-none-any.whl", hash = "sha256:2369f3d8419e7ffbbe950ad6b459a3f75ae6571155fab09f63874d49e96bab6b"}, - {file = "datar_numpy-0.2.0.tar.gz", hash = "sha256:a0e119f3ec208ba45100e31b22a8dc1d0a926a42f144bb5b2b859e2897d2e320"}, + {file = "datar_numpy-0.2.1-py3-none-any.whl", hash = "sha256:f9a85e514c612d6c400514f75242648d17c340aff188d8b8cf94d0a0de157bab"}, + {file = "datar_numpy-0.2.1.tar.gz", hash = "sha256:49d835f6fdbd856cd7ea00d2829da664b2138be40892d72be464d0df79b5252a"}, ] [package.dependencies] -datar = ">=0.12,<0.13" +datar = ">=0.13,<0.14" numpy = ">=1.20,<2.0" [[package]] name = "datar-pandas" -version = "0.3.0" +version = "0.3.1" description = "The pandas backend for datar" optional = true python-versions = ">=3.8,<4.0" files = [ - {file = "datar_pandas-0.3.0-py3-none-any.whl", hash = "sha256:738190a54c816dd3fa1529d0f48a601b134edc52a3e1d11ce6d3d0bcec91c2a7"}, - {file = "datar_pandas-0.3.0.tar.gz", hash = "sha256:97803e124f0e3ac2f6b65ddd087859fabb47e5d510fc1f4d3772a882853dbea9"}, + {file = "datar_pandas-0.3.1-py3-none-any.whl", hash = "sha256:ccd5db148b2d5a5d76aa408ea8b2b33b6bb4b4cc0a47d84062de980ef42a91f2"}, + {file = "datar_pandas-0.3.1.tar.gz", hash = "sha256:2d683cca9eb1991fe19e0669b41242dfed66e0d65d183e3c76a0014376e1a051"}, ] [package.dependencies] -datar = ">=0.12,<0.13" +datar = ">=0.13,<0.14" datar-numpy = ">=0.2,<0.3" pdtypes = ">=0.0.4,<0.0.5" diff --git a/tests/conflict_names.py b/tests/conflict_names.py new file mode 100644 index 000000000..f243e6ac1 --- /dev/null +++ b/tests/conflict_names.py @@ -0,0 +1,112 @@ +import argparse + + +def test_getattr(module, allow_conflict_names, fun, error): + from datar import options + options(allow_conflict_names=allow_conflict_names) + + if module == "all": + import datar.all as d + elif module == "base": + import datar.base as d + elif module == "dplyr": + import datar.dplyr as d + + if not error: + return getattr(d, fun) + + try: + getattr(d, fun) + except Exception as e: + raised = type(e).__name__ + assert raised == error, f"Raised {raised}, expected {error}" + else: + raise AssertionError(f"{error} should have raised") + + +def _import(module, fun): + if module == "all" and fun == "sum": + from datar.all import sum # noqa: F401 + elif module == "all" and fun == "slice": + from datar.all import slice # noqa: F401 + elif module == "base" and fun == "sum": + from datar.base import sum # noqa: F401 + elif module == "dplyr" and fun == "slice": + from datar.dplyr import slice # noqa: F401 + + +def test_import(module, allow_conflict_names, fun, error): + from datar import options + options(allow_conflict_names=allow_conflict_names) + + if not error: + return _import(module, fun) + + try: + _import(module, fun) + except Exception as e: + raised = type(e).__name__ + assert raised == error, f"Raised {raised}, expected {error}" + else: + raise AssertionError(f"{error} should have raised") + + +def make_test(module, allow_conflict_names, getattr, fun, error): + if fun == "_": + fun = "sum" if module in ["all", "base"] else "slice" + + if getattr: + return test_getattr(module, allow_conflict_names, fun, error) + + return test_import(module, allow_conflict_names, fun, error) + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--module", + choices=["all", "base", "dplyr"], + required=True, + help="The module to test" + ) + parser.add_argument( + "--allow-conflict-names", + action="store_true", + help="Whether to allow conflict names", + default=False, + ) + parser.add_argument( + "--getattr", + action="store_true", + help=( + "Whether to test datar.all.sum, " + "otherwise test from datar.all import sum." + ), + default=False, + ) + parser.add_argument( + "--fun", + help=( + "The function to test. " + "If _ then sum for all/base, slice for dplyr" + ), + choices=["sum", "filter", "_"], + default="_", + ) + parser.add_argument( + "--error", + help="The error to expect", + ) + args = parser.parse_args() + + make_test( + args.module, + args.allow_conflict_names, + args.getattr, + args.fun, + args.error, + ) + + +if __name__ == "__main__": + main() diff --git a/tests/test_conflict_names.py b/tests/test_conflict_names.py new file mode 100644 index 000000000..46ac91094 --- /dev/null +++ b/tests/test_conflict_names.py @@ -0,0 +1,85 @@ +import sys +import subprocess +from pathlib import Path + +import pytest + + +def _run_conflict_names(module, allow_conflict_names, getat, error): + here = Path(__file__).parent + conflict_names = here / "conflict_names.py" + cmd = [ + sys.executable, + str(conflict_names), + "--module", + module, + ] + if error: + cmd += ["--error", error] + if allow_conflict_names: + cmd.append("--allow-conflict-names") + if getat: + cmd.append("--getattr") + + p = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE) + return p.wait(), " ".join(cmd) + + +def test_from_all_import_allow_conflict_names_true(): + r, cmd = _run_conflict_names("all", True, False, None) + assert r == 0, cmd + + +def test_from_all_import_allow_conflict_names_false(): + r, cmd = _run_conflict_names("all", False, False, "ImportError") + assert r == 0, cmd + + +def test_all_getattr_allow_conflict_names_true(): + r, cmd = _run_conflict_names("all", True, True, None) + assert r == 0, cmd + + +def test_all_getattr_allow_conflict_names_false(): + r, cmd = _run_conflict_names("all", False, True, None) + assert r == 0, cmd + + +def test_from_base_import_allow_conflict_names_true(): + r, cmd = _run_conflict_names("base", True, False, None) + assert r == 0, cmd + + +def test_from_base_import_allow_conflict_names_false(): + r, cmd = _run_conflict_names("base", False, False, "ImportError") + assert r == 0, cmd + + +def test_base_getattr_allow_conflict_names_true(): + r, cmd = _run_conflict_names("base", True, True, None) + assert r == 0, cmd + + +def test_base_getattr_allow_conflict_names_false(): + r, cmd = _run_conflict_names("base", False, True, None) + assert r == 0, cmd + + +def test_from_dplyr_import_allow_conflict_names_true(): + r, cmd = _run_conflict_names("dplyr", True, False, None) + assert r == 0, cmd + + +def test_from_dplyr_import_allow_conflict_names_false(): + r, cmd = _run_conflict_names("dplyr", False, False, "ImportError") + assert r == 0, cmd + + +def test_dplyr_getattr_allow_conflict_names_true(): + r, cmd = _run_conflict_names("dplyr", True, True, None) + assert r == 0, cmd + + +def test_dplyr_getattr_allow_conflict_names_false(): + r, cmd = _run_conflict_names("dplyr", False, True, None) + assert r == 0, cmd diff --git a/tox.ini b/tox.ini index 75b7e9b70..9974eb949 100644 --- a/tox.ini +++ b/tox.ini @@ -14,4 +14,5 @@ per-file-ignores = datar/base.py: F401, F402, F403, E402 datar/dplyr.py: F401, F402, F403, E402 datar/data/metadata.py: E501 + tests/test_conflict_names.py: F401 max-line-length = 81