From 4626e0303ba32d9e1548cee4eb1f76355ef873c7 Mon Sep 17 00:00:00 2001 From: Roni Kobrosly Date: Thu, 24 Dec 2020 22:55:04 -0500 Subject: [PATCH] Feature/fix sphinx (#28) * fixed issue where sphinx didn't show modules * updates * added more test cases * tests for lambda_ in GPS * more checks for TMLE tool * fixed TMLE unit tests Co-authored-by: rkobrosly --- causal_curve/core.py | 2 +- causal_curve/gps.py | 2 +- docs/changelog.rst | 6 ++++ docs/conf.py | 29 +++++++++-------- setup.py | 2 +- tests/unit/test_gps.py | 71 +++++++++++++++++++++++++++-------------- tests/unit/test_tmle.py | 46 +++++++++++++++----------- 7 files changed, 99 insertions(+), 59 deletions(-) diff --git a/causal_curve/core.py b/causal_curve/core.py index 3e0b083..80ba09a 100644 --- a/causal_curve/core.py +++ b/causal_curve/core.py @@ -26,4 +26,4 @@ def get_params(self): [(k, v) for k, v in list(attrs.items()) if (k[0] != "_") and (k[-1] != "_")] ) - __version__ = pkg_resources.require("causal-curve")[0].version + __version__ = "0.5.2" diff --git a/causal_curve/gps.py b/causal_curve/gps.py index 8e6b841..41767f2 100644 --- a/causal_curve/gps.py +++ b/causal_curve/gps.py @@ -324,7 +324,7 @@ def _validate_init_params(self): if (isinstance(self.lambda_, (int, float))) and self.lambda_ <= 0: raise ValueError( - f"lambda_ parameter should be >= 2, but found {self.lambda_}" + f"lambda_ parameter should be > 0, but found {self.lambda_}" ) if (isinstance(self.lambda_, (int, float))) and self.lambda_ >= 1000: diff --git a/docs/changelog.rst b/docs/changelog.rst index 6374987..631c7d5 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -5,6 +5,12 @@ Change Log ========== +Version 0.5.2 +------------- +- Fixed bug that prevented `causal-curve` modules from being shown in Sphinx documentation +- Augmented tests to capture more error states and improve code coverage + + Version 0.5.1 ------------- - Removed working test file diff --git a/docs/conf.py b/docs/conf.py index 67166b3..d83b66f 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -12,17 +12,18 @@ import os import sys -sys.path.insert(0, os.path.abspath('../')) + +sys.path.insert(0, os.path.abspath("../")) # -- Project information ----------------------------------------------------- -project = 'causal_curve' -copyright = '2020, Roni Kobrosly' -author = 'Roni Kobrosly' +project = "causal_curve" +copyright = "2020, Roni Kobrosly" +author = "Roni Kobrosly" # The full version, including alpha/beta/rc tags -release = '0.5.1' +release = "0.5.2" # -- General configuration --------------------------------------------------- @@ -30,21 +31,21 @@ # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom # ones. extensions = [ - 'sphinx.ext.autodoc', - 'sphinx.ext.autosummary', - 'numpydoc', + "sphinx.ext.autodoc", + "sphinx.ext.autosummary", + "numpydoc", ] # Add any paths that contain templates here, relative to this directory. -templates_path = ['_templates'] +templates_path = ["_templates"] # List of patterns, relative to source directory, that match files and # directories to ignore when looking for source files. # This pattern also affects html_static_path and html_extra_path. -exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store'] +exclude_patterns = ["_build", "Thumbs.db", ".DS_Store"] # The name of the Pygments (syntax highlighting) style to use. -pygments_style = 'sphinx' +pygments_style = "sphinx" # this is needed for some reason... # see https://github.com/numpy/numpydoc/issues/69 @@ -53,16 +54,16 @@ # generate autosummary even if no references autosummary_generate = True -master_doc = 'index' +master_doc = "index" # -- Options for HTML output ------------------------------------------------- # The theme to use for HTML and HTML Help pages. See the documentation for # a list of builtin themes. # -html_theme = 'sphinx_rtd_theme' +html_theme = "sphinx_rtd_theme" # Add any paths that contain custom static files (such as style sheets) here, # relative to this directory. They are copied after the builtin static files, # so a file named "default.css" will overwrite the builtin "default.css". -html_static_path = ['_static'] +html_static_path = [] diff --git a/setup.py b/setup.py index 9dc9837..b956828 100644 --- a/setup.py +++ b/setup.py @@ -5,7 +5,7 @@ setuptools.setup( name="causal-curve", - version="0.5.1", + version="0.5.2", author="Roni Kobrosly", author_email="roni.kobrosly@gmail.com", description="A python library with tools to perform causal inference using \ diff --git a/tests/unit/test_gps.py b/tests/unit/test_gps.py index d38bb98..e550dd4 100644 --- a/tests/unit/test_gps.py +++ b/tests/unit/test_gps.py @@ -51,35 +51,56 @@ def test_gps_fit(df_fixture, family): "upper_grid_constraint", "spline_order", "n_splines", + "lambda_", "max_iter", "random_seed", "verbose", ), [ - (546, 10, 0, 1.0, 3, 10, 100, 100, True), - ("linear", 10, 0, 1.0, 3, 10, 100, 100, True), - (None, "hehe", 0, 1.0, 3, 10, 100, 100, True), - (None, 2, 0, 1.0, 3, 10, 100, 100, True), - (None, 1e6, 0, 1.0, 3, 10, 100, 100, True), - (None, 10, "hehe", 1.0, 3, 10, 100, 100, True), - (None, 10, -1, 1.0, 3, 10, 100, 100, True), - (None, 10, 1.5, 1.0, 3, 10, 100, 100, True), - (None, 10, 0, "hehe", 3, 10, 100, 100, True), - (None, 10, 0, 1.5, 3, 10, 100, 100, True), - (None, 10, 0, -1, 3, 10, 100, 100, True), - (None, 10, 0, 1, 3, 10, 100, 100, True), - (None, 10, 0, 1, "splines", 10, 100, 100, True), - (None, 10, 0, 1, 0, 10, 100, 100, True), - (None, 10, 0, 1, 200, 10, 100, 100, True), - (None, 10, 0, 1, 3, 0, 100, 100, True), - (None, 10, 0, 1, 3, 1e6, 100, 100, True), - (None, 10, 0, 1, 3, 10, 100, 100, True), - (None, 10, 0, 1, 3, 10, "many", 100, True), - (None, 10, 0, 1, 3, 10, 5, 100, True), - (None, 10, 0, 1, 3, 10, 1e7, 100, True), - (None, 10, 0, 1, 3, 10, 100, "random", True), - (None, 10, 0, 1, 3, 10, 100, -1.5, True), - (None, 10, 0, 1, 3, 10, 100, 111, "True"), + (546, 10, 0, 1.0, 3, 10, 0.5, 100, 100, True), + ("linear", 10, 0, 1.0, 3, 10, 0.5, 100, 100, True), + (None, "hehe", 0, 1.0, 3, 10, 0.5, 100, 100, True), + (None, 2, 0, 1.0, 3, 10, 0.5, 100, 100, True), + (None, 100000, 0, 1.0, 3, 10, 0.5, 100, 100, True), + (None, 10, "hehe", 1.0, 3, 10, 0.5, 100, 100, True), + (None, 10, -1.0, 1.0, 3, 10, 0.5, 100, 100, True), + (None, 10, 1.5, 1.0, 3, 10, 0.5, 100, 100, True), + (None, 10, 0, "hehe", 3, 10, 0.5, 100, 100, True), + (None, 10, 0, 1.5, 3, 10, 0.5, 100, 100, True), + (None, 100, -3.0, 0.99, 3, 30, 0.5, 100, None, True), + (None, 100, 0.01, 1, 3, 30, 0.5, 100, None, True), + (None, 100, 0.01, -4.5, 3, 30, 0.5, 100, None, True), + (None, 100, 0.01, 5.5, 3, 30, 0.5, 100, None, True), + (None, 100, 0.99, 0.01, 3, 30, 0.5, 100, None, True), + (None, 100, 0.01, 0.99, 3.0, 30, 0.5, 100, None, True), + (None, 100, 0.01, 0.99, -2, 30, 0.5, 100, None, True), + (None, 100, 0.01, 0.99, 30, 30, 0.5, 100, None, True), + (None, 100, 0.01, 0.99, 3, 30.0, 0.5, 100, None, True), + (None, 100, 0.01, 0.99, 3, -2, 0.5, 100, None, True), + (None, 100, 0.01, 0.99, 3, 500, 0.5, 100, None, True), + (None, 100, 0.01, 0.99, 3, 30, 0.5, 100.0, None, True), + (None, 100, 0.01, 0.99, 3, 30, 0.5, -100, None, True), + (None, 100, 0.01, 0.99, 3, 30, 0.5, 10000000000, None, True), + (None, 100, 0.01, 0.99, 3, 30, 0.5, 100, 234.5, True), + (None, 100, 0.01, 0.99, 3, 30, 0.5, 100, -5, True), + (None, 100, 0.01, 0.99, 3, 30, 0.5, 100, None, 4.0), + (None, 10, 0, -1, 3, 10, 0.5, 100, 100, True), + (None, 10, 0, 1, 3, 10, 0.5, 100, 100, True), + (None, 10, 0, 1, "splines", 10, 0.5, 100, 100, True), + (None, 10, 0, 1, 0, 10, 0.5, 100, 100, True), + (None, 10, 0, 1, 200, 10, 0.5, 100, 100, True), + (None, 10, 0, 1, 3, 0, 0.5, 100, 100, True), + (None, 10, 0, 1, 3, 1e6, 0.5, 100, 100, True), + (None, 10, 0, 1, 3, 10, 0.5, 100, 100, True), + (None, 10, 0, 1, 3, 10, 0.5, "many", 100, True), + (None, 10, 0, 1, 3, 10, 0.5, 5, 100, True), + (None, 10, 0, 1, 3, 10, 0.5, 1e7, 100, True), + (None, 10, 0, 1, 3, 10, 0.5, 100, "random", True), + (None, 10, 0, 1, 3, 10, 0.5, 100, -1.5, True), + (None, 10, 0, 1, 3, 10, 0.5, 100, 111, "True"), + (None, 100, 0.01, 0.99, 3, 30, "lambda", 100, None, True), + (None, 100, 0.01, 0.99, 3, 30, -1.0, 100, None, True), + (None, 100, 0.01, 0.99, 3, 30, 2000.0, 100, None, True), ], ) def test_bad_gps_instantiation( @@ -89,6 +110,7 @@ def test_bad_gps_instantiation( upper_grid_constraint, spline_order, n_splines, + lambda_, max_iter, random_seed, verbose, @@ -104,6 +126,7 @@ def test_bad_gps_instantiation( upper_grid_constraint=upper_grid_constraint, spline_order=spline_order, n_splines=n_splines, + lambda_=lambda_, max_iter=max_iter, random_seed=random_seed, verbose=verbose, diff --git a/tests/unit/test_tmle.py b/tests/unit/test_tmle.py index f691155..b9671c1 100644 --- a/tests/unit/test_tmle.py +++ b/tests/unit/test_tmle.py @@ -32,25 +32,37 @@ def test_tmle_fit(continuous_dataset_fixture): "n_estimators", "learning_rate", "max_depth", - "gamma", "random_seed", "verbose", ), [ - ([0, 1, 2, 3, 4], 100, 0.1, 5, 1.0, None, False), - (5, 100, 0.1, 5, 1.0, None, False), - ("5", 100, 0.1, 5, 1.0, None, False), - ([22.1, 30, 40, 50, 60, 70, 80.1], "100", 0.1, 5, 1.0, None, False), - ([22.1, 30, 40, 50, 60, 70, 80.1], 1, 0.1, 5, 1.0, None, False), - ([22.1, 30, 40, 50, 60, 70, 80.1], 100, "0.1", 5, 1.0, None, False), - ([22.1, 30, 40, 50, 60, 70, 80.1], 100, 1e6, 5, 1.0, None, False), - ([22.1, 30, 40, 50, 60, 70, 80.1], 100, 0.1, "5", 1.0, None, False), - ([22.1, 30, 40, 50, 60, 70, 80.1], 100, 0.1, -5, 1.0, None, False), - ([22.1, 30, 40, 50, 60, 70, 80.1], 100, 0.1, 5, "1.0", None, False), - ([22.1, 30, 40, 50, 60, 70, 80.1], 100, 0.1, 5, -1, None, False), - ([22.1, 30, 40, 50, 60, 70, 80.1], 100, 0.1, 5, 1.0, "None", False), - ([22.1, 30, 40, 50, 60, 70, 80.1], 100, 0.1, 5, 1.0, -10, False), - ([22.1, 30, 40, 50, 60, 70, 80.1], 100, 0.1, 5, 1.0, None, "False"), + ([0, 1, "2", 3, 4], 100, 0.1, 5, None, False), + (5, 100, 0.1, 5, None, False), + ("5", 100, 0.1, 5, None, False), + ([22.1, 30, 40, 50, 60, 70, 80.1], "100", 0.1, 5, None, False), + ([22.1, 30, 40, 50, 60, 70, 80.1], 1, 0.1, 5, None, False), + ([22.1, 30, 40, 50, 60, 70, 80.1], 100, "0.1", 5, None, False), + ([22.1, 30, 40, 50, 60, 70, 80.1], 100, 1000000, 5, None, False), + ([22.1, 30, 40, 50, 60, 70, 80.1], 100, 0.1, "5", None, False), + ([22.1, 30, 40, 50, 60, 70, 80.1], 100, 0.1, -5, None, False), + ([22.1, 30, 40, 50, 60, 70, 80.1], 100, 0.1, 5, "None", False), + ([22.1, 30, 40, 50, 60, 70, 80.1], 100, 0.1, 5, -10, False), + ([22.1, 30, 40, 50, 60, 70, 80.1], 100, 0.1, 5, None, "False"), + ({"a": 5, "b": 6}, 100, 0.1, 5, None, False), + (["a", "b", "c"], 100, 0.1, 5, None, False), + ([1.0], 100, 0.1, 5, None, False), + ([22.1, 30, 40, 50, 60, 70, 80.1], 100.0, 0.1, 5, None, False), + ([22.1, 30, 40, 50, 60, 70, 80.1], 1.0, 0.1, 5, None, False), + ([22.1, 30, 40, 50, 60, 70, 80.1], 10000000, 0.1, 5, None, False), + ([22.1, 30, 40, 50, 60, 70, 80.1], 100, "hehe", 5, None, False), + ([22.1, 30, 40, 50, 60, 70, 80.1], 100, -0.1, 5, None, False), + ([22.1, 30, 40, 50, 60, 70, 80.1], 100, 10000000, 5, None, False), + ([22.1, 30, 40, 50, 60, 70, 80.1], 100, 0.1, "hehe", None, False), + ([22.1, 30, 40, 50, 60, 70, 80.1], 100, 0.1, 5.0, None, False), + ([22.1, 30, 40, 50, 60, 70, 80.1], 100, 0.1, -5, None, False), + ([22.1, 30, 40, 50, 60, 70, 80.1], 100, 0.1, 5, "hehe", False), + ([22.1, 30, 40, 50, 60, 70, 80.1], 100, 0.1, 5, -10, False), + ([22.1, 30, 40, 50, 60, 70, 80.1], 100, 0.1, 5, None, "thirty two"), ], ) def test_bad_tmle_instantiation( @@ -58,17 +70,15 @@ def test_bad_tmle_instantiation( n_estimators, learning_rate, max_depth, - gamma, random_seed, verbose, ): with pytest.raises(Exception) as bad: - GPS( + TMLE( treatment_grid_bins=treatment_grid_bins, n_estimators=n_estimators, learning_rate=learning_rate, max_depth=max_depth, - gamma=gamma, random_seed=random_seed, verbose=verbose, )