Skip to content

Commit

Permalink
Feature/fix sphinx (#28)
Browse files Browse the repository at this point in the history
* fixed issue where sphinx didn't show modules

* updates

* added more test cases

* tests for lambda_ in GPS

* more checks for TMLE tool

* fixed TMLE unit tests

Co-authored-by: rkobrosly <[email protected]>
  • Loading branch information
ronikobrosly and rkobrosly authored Dec 25, 2020
1 parent 954af30 commit 4626e03
Show file tree
Hide file tree
Showing 7 changed files with 99 additions and 59 deletions.
2 changes: 1 addition & 1 deletion causal_curve/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,4 +26,4 @@ def get_params(self):
[(k, v) for k, v in list(attrs.items()) if (k[0] != "_") and (k[-1] != "_")]
)

__version__ = pkg_resources.require("causal-curve")[0].version
__version__ = "0.5.2"
2 changes: 1 addition & 1 deletion causal_curve/gps.py
Original file line number Diff line number Diff line change
Expand Up @@ -324,7 +324,7 @@ def _validate_init_params(self):

if (isinstance(self.lambda_, (int, float))) and self.lambda_ <= 0:
raise ValueError(
f"lambda_ parameter should be >= 2, but found {self.lambda_}"
f"lambda_ parameter should be > 0, but found {self.lambda_}"
)

if (isinstance(self.lambda_, (int, float))) and self.lambda_ >= 1000:
Expand Down
6 changes: 6 additions & 0 deletions docs/changelog.rst
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,12 @@ Change Log
==========


Version 0.5.2
-------------
- Fixed bug that prevented `causal-curve` modules from being shown in Sphinx documentation
- Augmented tests to capture more error states and improve code coverage


Version 0.5.1
-------------
- Removed working test file
Expand Down
29 changes: 15 additions & 14 deletions docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,39 +12,40 @@

import os
import sys
sys.path.insert(0, os.path.abspath('../'))

sys.path.insert(0, os.path.abspath("../"))


# -- Project information -----------------------------------------------------

project = 'causal_curve'
copyright = '2020, Roni Kobrosly'
author = 'Roni Kobrosly'
project = "causal_curve"
copyright = "2020, Roni Kobrosly"
author = "Roni Kobrosly"

# The full version, including alpha/beta/rc tags
release = '0.5.1'
release = "0.5.2"

# -- General configuration ---------------------------------------------------

# Add any Sphinx extension module names here, as strings. They can be
# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
# ones.
extensions = [
'sphinx.ext.autodoc',
'sphinx.ext.autosummary',
'numpydoc',
"sphinx.ext.autodoc",
"sphinx.ext.autosummary",
"numpydoc",
]

# Add any paths that contain templates here, relative to this directory.
templates_path = ['_templates']
templates_path = ["_templates"]

# List of patterns, relative to source directory, that match files and
# directories to ignore when looking for source files.
# This pattern also affects html_static_path and html_extra_path.
exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store']
exclude_patterns = ["_build", "Thumbs.db", ".DS_Store"]

# The name of the Pygments (syntax highlighting) style to use.
pygments_style = 'sphinx'
pygments_style = "sphinx"

# this is needed for some reason...
# see https://github.com/numpy/numpydoc/issues/69
Expand All @@ -53,16 +54,16 @@
# generate autosummary even if no references
autosummary_generate = True

master_doc = 'index'
master_doc = "index"

# -- Options for HTML output -------------------------------------------------

# The theme to use for HTML and HTML Help pages. See the documentation for
# a list of builtin themes.
#
html_theme = 'sphinx_rtd_theme'
html_theme = "sphinx_rtd_theme"

# Add any paths that contain custom static files (such as style sheets) here,
# relative to this directory. They are copied after the builtin static files,
# so a file named "default.css" will overwrite the builtin "default.css".
html_static_path = ['_static']
html_static_path = []
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

setuptools.setup(
name="causal-curve",
version="0.5.1",
version="0.5.2",
author="Roni Kobrosly",
author_email="[email protected]",
description="A python library with tools to perform causal inference using \
Expand Down
71 changes: 47 additions & 24 deletions tests/unit/test_gps.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,35 +51,56 @@ def test_gps_fit(df_fixture, family):
"upper_grid_constraint",
"spline_order",
"n_splines",
"lambda_",
"max_iter",
"random_seed",
"verbose",
),
[
(546, 10, 0, 1.0, 3, 10, 100, 100, True),
("linear", 10, 0, 1.0, 3, 10, 100, 100, True),
(None, "hehe", 0, 1.0, 3, 10, 100, 100, True),
(None, 2, 0, 1.0, 3, 10, 100, 100, True),
(None, 1e6, 0, 1.0, 3, 10, 100, 100, True),
(None, 10, "hehe", 1.0, 3, 10, 100, 100, True),
(None, 10, -1, 1.0, 3, 10, 100, 100, True),
(None, 10, 1.5, 1.0, 3, 10, 100, 100, True),
(None, 10, 0, "hehe", 3, 10, 100, 100, True),
(None, 10, 0, 1.5, 3, 10, 100, 100, True),
(None, 10, 0, -1, 3, 10, 100, 100, True),
(None, 10, 0, 1, 3, 10, 100, 100, True),
(None, 10, 0, 1, "splines", 10, 100, 100, True),
(None, 10, 0, 1, 0, 10, 100, 100, True),
(None, 10, 0, 1, 200, 10, 100, 100, True),
(None, 10, 0, 1, 3, 0, 100, 100, True),
(None, 10, 0, 1, 3, 1e6, 100, 100, True),
(None, 10, 0, 1, 3, 10, 100, 100, True),
(None, 10, 0, 1, 3, 10, "many", 100, True),
(None, 10, 0, 1, 3, 10, 5, 100, True),
(None, 10, 0, 1, 3, 10, 1e7, 100, True),
(None, 10, 0, 1, 3, 10, 100, "random", True),
(None, 10, 0, 1, 3, 10, 100, -1.5, True),
(None, 10, 0, 1, 3, 10, 100, 111, "True"),
(546, 10, 0, 1.0, 3, 10, 0.5, 100, 100, True),
("linear", 10, 0, 1.0, 3, 10, 0.5, 100, 100, True),
(None, "hehe", 0, 1.0, 3, 10, 0.5, 100, 100, True),
(None, 2, 0, 1.0, 3, 10, 0.5, 100, 100, True),
(None, 100000, 0, 1.0, 3, 10, 0.5, 100, 100, True),
(None, 10, "hehe", 1.0, 3, 10, 0.5, 100, 100, True),
(None, 10, -1.0, 1.0, 3, 10, 0.5, 100, 100, True),
(None, 10, 1.5, 1.0, 3, 10, 0.5, 100, 100, True),
(None, 10, 0, "hehe", 3, 10, 0.5, 100, 100, True),
(None, 10, 0, 1.5, 3, 10, 0.5, 100, 100, True),
(None, 100, -3.0, 0.99, 3, 30, 0.5, 100, None, True),
(None, 100, 0.01, 1, 3, 30, 0.5, 100, None, True),
(None, 100, 0.01, -4.5, 3, 30, 0.5, 100, None, True),
(None, 100, 0.01, 5.5, 3, 30, 0.5, 100, None, True),
(None, 100, 0.99, 0.01, 3, 30, 0.5, 100, None, True),
(None, 100, 0.01, 0.99, 3.0, 30, 0.5, 100, None, True),
(None, 100, 0.01, 0.99, -2, 30, 0.5, 100, None, True),
(None, 100, 0.01, 0.99, 30, 30, 0.5, 100, None, True),
(None, 100, 0.01, 0.99, 3, 30.0, 0.5, 100, None, True),
(None, 100, 0.01, 0.99, 3, -2, 0.5, 100, None, True),
(None, 100, 0.01, 0.99, 3, 500, 0.5, 100, None, True),
(None, 100, 0.01, 0.99, 3, 30, 0.5, 100.0, None, True),
(None, 100, 0.01, 0.99, 3, 30, 0.5, -100, None, True),
(None, 100, 0.01, 0.99, 3, 30, 0.5, 10000000000, None, True),
(None, 100, 0.01, 0.99, 3, 30, 0.5, 100, 234.5, True),
(None, 100, 0.01, 0.99, 3, 30, 0.5, 100, -5, True),
(None, 100, 0.01, 0.99, 3, 30, 0.5, 100, None, 4.0),
(None, 10, 0, -1, 3, 10, 0.5, 100, 100, True),
(None, 10, 0, 1, 3, 10, 0.5, 100, 100, True),
(None, 10, 0, 1, "splines", 10, 0.5, 100, 100, True),
(None, 10, 0, 1, 0, 10, 0.5, 100, 100, True),
(None, 10, 0, 1, 200, 10, 0.5, 100, 100, True),
(None, 10, 0, 1, 3, 0, 0.5, 100, 100, True),
(None, 10, 0, 1, 3, 1e6, 0.5, 100, 100, True),
(None, 10, 0, 1, 3, 10, 0.5, 100, 100, True),
(None, 10, 0, 1, 3, 10, 0.5, "many", 100, True),
(None, 10, 0, 1, 3, 10, 0.5, 5, 100, True),
(None, 10, 0, 1, 3, 10, 0.5, 1e7, 100, True),
(None, 10, 0, 1, 3, 10, 0.5, 100, "random", True),
(None, 10, 0, 1, 3, 10, 0.5, 100, -1.5, True),
(None, 10, 0, 1, 3, 10, 0.5, 100, 111, "True"),
(None, 100, 0.01, 0.99, 3, 30, "lambda", 100, None, True),
(None, 100, 0.01, 0.99, 3, 30, -1.0, 100, None, True),
(None, 100, 0.01, 0.99, 3, 30, 2000.0, 100, None, True),
],
)
def test_bad_gps_instantiation(
Expand All @@ -89,6 +110,7 @@ def test_bad_gps_instantiation(
upper_grid_constraint,
spline_order,
n_splines,
lambda_,
max_iter,
random_seed,
verbose,
Expand All @@ -104,6 +126,7 @@ def test_bad_gps_instantiation(
upper_grid_constraint=upper_grid_constraint,
spline_order=spline_order,
n_splines=n_splines,
lambda_=lambda_,
max_iter=max_iter,
random_seed=random_seed,
verbose=verbose,
Expand Down
46 changes: 28 additions & 18 deletions tests/unit/test_tmle.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,43 +32,53 @@ def test_tmle_fit(continuous_dataset_fixture):
"n_estimators",
"learning_rate",
"max_depth",
"gamma",
"random_seed",
"verbose",
),
[
([0, 1, 2, 3, 4], 100, 0.1, 5, 1.0, None, False),
(5, 100, 0.1, 5, 1.0, None, False),
("5", 100, 0.1, 5, 1.0, None, False),
([22.1, 30, 40, 50, 60, 70, 80.1], "100", 0.1, 5, 1.0, None, False),
([22.1, 30, 40, 50, 60, 70, 80.1], 1, 0.1, 5, 1.0, None, False),
([22.1, 30, 40, 50, 60, 70, 80.1], 100, "0.1", 5, 1.0, None, False),
([22.1, 30, 40, 50, 60, 70, 80.1], 100, 1e6, 5, 1.0, None, False),
([22.1, 30, 40, 50, 60, 70, 80.1], 100, 0.1, "5", 1.0, None, False),
([22.1, 30, 40, 50, 60, 70, 80.1], 100, 0.1, -5, 1.0, None, False),
([22.1, 30, 40, 50, 60, 70, 80.1], 100, 0.1, 5, "1.0", None, False),
([22.1, 30, 40, 50, 60, 70, 80.1], 100, 0.1, 5, -1, None, False),
([22.1, 30, 40, 50, 60, 70, 80.1], 100, 0.1, 5, 1.0, "None", False),
([22.1, 30, 40, 50, 60, 70, 80.1], 100, 0.1, 5, 1.0, -10, False),
([22.1, 30, 40, 50, 60, 70, 80.1], 100, 0.1, 5, 1.0, None, "False"),
([0, 1, "2", 3, 4], 100, 0.1, 5, None, False),
(5, 100, 0.1, 5, None, False),
("5", 100, 0.1, 5, None, False),
([22.1, 30, 40, 50, 60, 70, 80.1], "100", 0.1, 5, None, False),
([22.1, 30, 40, 50, 60, 70, 80.1], 1, 0.1, 5, None, False),
([22.1, 30, 40, 50, 60, 70, 80.1], 100, "0.1", 5, None, False),
([22.1, 30, 40, 50, 60, 70, 80.1], 100, 1000000, 5, None, False),
([22.1, 30, 40, 50, 60, 70, 80.1], 100, 0.1, "5", None, False),
([22.1, 30, 40, 50, 60, 70, 80.1], 100, 0.1, -5, None, False),
([22.1, 30, 40, 50, 60, 70, 80.1], 100, 0.1, 5, "None", False),
([22.1, 30, 40, 50, 60, 70, 80.1], 100, 0.1, 5, -10, False),
([22.1, 30, 40, 50, 60, 70, 80.1], 100, 0.1, 5, None, "False"),
({"a": 5, "b": 6}, 100, 0.1, 5, None, False),
(["a", "b", "c"], 100, 0.1, 5, None, False),
([1.0], 100, 0.1, 5, None, False),
([22.1, 30, 40, 50, 60, 70, 80.1], 100.0, 0.1, 5, None, False),
([22.1, 30, 40, 50, 60, 70, 80.1], 1.0, 0.1, 5, None, False),
([22.1, 30, 40, 50, 60, 70, 80.1], 10000000, 0.1, 5, None, False),
([22.1, 30, 40, 50, 60, 70, 80.1], 100, "hehe", 5, None, False),
([22.1, 30, 40, 50, 60, 70, 80.1], 100, -0.1, 5, None, False),
([22.1, 30, 40, 50, 60, 70, 80.1], 100, 10000000, 5, None, False),
([22.1, 30, 40, 50, 60, 70, 80.1], 100, 0.1, "hehe", None, False),
([22.1, 30, 40, 50, 60, 70, 80.1], 100, 0.1, 5.0, None, False),
([22.1, 30, 40, 50, 60, 70, 80.1], 100, 0.1, -5, None, False),
([22.1, 30, 40, 50, 60, 70, 80.1], 100, 0.1, 5, "hehe", False),
([22.1, 30, 40, 50, 60, 70, 80.1], 100, 0.1, 5, -10, False),
([22.1, 30, 40, 50, 60, 70, 80.1], 100, 0.1, 5, None, "thirty two"),
],
)
def test_bad_tmle_instantiation(
treatment_grid_bins,
n_estimators,
learning_rate,
max_depth,
gamma,
random_seed,
verbose,
):
with pytest.raises(Exception) as bad:
GPS(
TMLE(
treatment_grid_bins=treatment_grid_bins,
n_estimators=n_estimators,
learning_rate=learning_rate,
max_depth=max_depth,
gamma=gamma,
random_seed=random_seed,
verbose=verbose,
)

0 comments on commit 4626e03

Please sign in to comment.