Skip to content

Commit

Permalink
added to unit tests, updated docs
Browse files Browse the repository at this point in the history
  • Loading branch information
ronikobrosly committed Jun 29, 2020
1 parent 8cb46bd commit 00fc7bf
Show file tree
Hide file tree
Showing 11 changed files with 244 additions and 22 deletions.
8 changes: 4 additions & 4 deletions causal_curve/mediation.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,13 +215,13 @@ def _validate_init_params(self):
but found type {type(self.bootstrap_draws)}"
)

if (isinstance(self.bootstrap_draws, float)) and self.bootstrap_draws < 100:
if (isinstance(self.bootstrap_draws, int)) and self.bootstrap_draws < 100:
raise ValueError(
f"bootstrap_draws parameter cannot be < 100, \
but found value {self.bootstrap_draws}"
)

if (isinstance(self.bootstrap_draws, float)) and self.bootstrap_draws > 500000:
if (isinstance(self.bootstrap_draws, int)) and self.bootstrap_draws > 500000:
raise ValueError(
f"bootstrap_draws parameter cannot > 500000, \
but found value {self.bootstrap_draws}"
Expand All @@ -235,15 +235,15 @@ def _validate_init_params(self):
)

if (
isinstance(self.bootstrap_replicates, float)
isinstance(self.bootstrap_replicates, int)
) and self.bootstrap_replicates < 50:
raise ValueError(
f"bootstrap_replicates parameter cannot be < 50, \
but found value {self.bootstrap_replicates}"
)

if (
isinstance(self.bootstrap_replicates, float)
isinstance(self.bootstrap_replicates, int)
) and self.bootstrap_replicates > 100000:
raise ValueError(
f"bootstrap_replicates parameter cannot > 100000, \
Expand Down
5 changes: 5 additions & 0 deletions docs/changelog.rst
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,11 @@
Change Log
==========

Version 0.2.4
-------------
- Strengthened unit tests


Version 0.2.3
-------------
- codecov integration
Expand Down
2 changes: 1 addition & 1 deletion docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
author = 'Roni Kobrosly'

# The full version, including alpha/beta/rc tags
release = '0.2.2'
release = '0.2.4'


# -- General configuration ---------------------------------------------------
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

setuptools.setup(
name="causal-curve",
version="0.2.3",
version="0.2.4",
author="Roni Kobrosly",
author_email="[email protected]",
description="A python library with tools to perform causal inference using \
Expand Down
2 changes: 1 addition & 1 deletion tests/integration/test_gps.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ def test_full_gps_flow(dataset_fixture):
n_splines=10,
max_iter=100,
random_seed=100,
verbose=False,
verbose=True,
)
gps.fit(
T=dataset_fixture["treatment"],
Expand Down
8 changes: 4 additions & 4 deletions tests/integration/test_mediation.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,14 @@ def test_full_mediation_flow(mediation_fixture):
treatment_grid_num=10,
lower_grid_constraint=0.01,
upper_grid_constraint=0.99,
bootstrap_draws=10,
bootstrap_replicates=10,
bootstrap_draws=100,
bootstrap_replicates=50,
spline_order=3,
n_splines=5,
lambda_=0.5,
max_iter=100,
max_iter=20,
random_seed=None,
verbose=False,
verbose=True,
)
med.fit(
T=mediation_fixture["treatment"],
Expand Down
2 changes: 1 addition & 1 deletion tests/integration/test_tmle.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ def test_full_tmle_flow(dataset_fixture):
tmle = TMLE(
treatment_grid_bins=[22.1, 30, 40, 50, 60, 70, 80.1],
random_seed=100,
verbose=False,
verbose=True,
)
tmle.fit(
T=dataset_fixture["treatment"],
Expand Down
17 changes: 17 additions & 0 deletions tests/unit/test_general.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
""" General unit tests of the causal-curve package """

from causal_curve.core import Core


def test_core():
"""
Tests the `Core` base class
"""

core = Core()
core.a = 5
core.b = 10

observed_results = core.get_params()

assert observed_results == {"a": 5, "b": 10}
84 changes: 79 additions & 5 deletions tests/unit/test_gps.py
Original file line number Diff line number Diff line change
@@ -1,30 +1,104 @@
""" Unit tests of the gps.py module """

from pygam import LinearGAM
import pytest

from causal_curve import GPS
from tests.conftest import full_example_dataset


def test_gps_fit(dataset_fixture):
@pytest.mark.parametrize(
("df_fixture", "family"),
[
(full_example_dataset, "normal"),
(full_example_dataset, "lognormal"),
(full_example_dataset, "gamma"),
(full_example_dataset, None),
],
)
def test_gps_fit(df_fixture, family):
"""
Tests the fit method of the GPS tool
"""

gps = GPS(
gps_family=family,
treatment_grid_num=10,
lower_grid_constraint=0.0,
upper_grid_constraint=1.0,
spline_order=3,
n_splines=10,
max_iter=100,
random_seed=100,
verbose=False,
verbose=True,
)
gps.fit(
T=dataset_fixture["treatment"],
X=dataset_fixture["x1"],
y=dataset_fixture["outcome"],
T=df_fixture()["treatment"], X=df_fixture()["x1"], y=df_fixture()["outcome"],
)

assert isinstance(gps.gam_results, LinearGAM)
assert gps.gps.shape == (500,)


@pytest.mark.parametrize(
(
"gps_family",
"treatment_grid_num",
"lower_grid_constraint",
"upper_grid_constraint",
"spline_order",
"n_splines",
"max_iter",
"random_seed",
"verbose",
),
[
(546, 10, 0, 1.0, 3, 10, 100, 100, True),
("linear", 10, 0, 1.0, 3, 10, 100, 100, True),
(None, "hehe", 0, 1.0, 3, 10, 100, 100, True),
(None, 2, 0, 1.0, 3, 10, 100, 100, True),
(None, 1e6, 0, 1.0, 3, 10, 100, 100, True),
(None, 10, "hehe", 1.0, 3, 10, 100, 100, True),
(None, 10, -1, 1.0, 3, 10, 100, 100, True),
(None, 10, 1.5, 1.0, 3, 10, 100, 100, True),
(None, 10, 0, "hehe", 3, 10, 100, 100, True),
(None, 10, 0, 1.5, 3, 10, 100, 100, True),
(None, 10, 0, -1, 3, 10, 100, 100, True),
(None, 10, 0, 1, 3, 10, 100, 100, True),
(None, 10, 0, 1, "splines", 10, 100, 100, True),
(None, 10, 0, 1, 0, 10, 100, 100, True),
(None, 10, 0, 1, 200, 10, 100, 100, True),
(None, 10, 0, 1, 3, 0, 100, 100, True),
(None, 10, 0, 1, 3, 1e6, 100, 100, True),
(None, 10, 0, 1, 3, 10, 100, 100, True),
(None, 10, 0, 1, 3, 10, "many", 100, True),
(None, 10, 0, 1, 3, 10, 5, 100, True),
(None, 10, 0, 1, 3, 10, 1e7, 100, True),
(None, 10, 0, 1, 3, 10, 100, "random", True),
(None, 10, 0, 1, 3, 10, 100, -1.5, True),
(None, 10, 0, 1, 3, 10, 100, 111, "True"),
],
)
def test_bad_gps_instantiation(
gps_family,
treatment_grid_num,
lower_grid_constraint,
upper_grid_constraint,
spline_order,
n_splines,
max_iter,
random_seed,
verbose,
):
with pytest.raises(Exception) as bad:
GPS(
gps_family=gps_family,
treatment_grid_num=treatment_grid_num,
lower_grid_constraint=lower_grid_constraint,
upper_grid_constraint=upper_grid_constraint,
spline_order=spline_order,
n_splines=n_splines,
max_iter=max_iter,
random_seed=random_seed,
verbose=verbose,
)
84 changes: 80 additions & 4 deletions tests/unit/test_mediation.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
""" Unit tests of the Mediation.py module """

import numpy as np
import pytest

from causal_curve import Mediation

Expand All @@ -14,14 +15,14 @@ def test_mediation_fit(mediation_fixture):
treatment_grid_num=10,
lower_grid_constraint=0.01,
upper_grid_constraint=0.99,
bootstrap_draws=10,
bootstrap_replicates=10,
bootstrap_draws=100,
bootstrap_replicates=50,
spline_order=3,
n_splines=5,
lambda_=0.5,
max_iter=100,
max_iter=20,
random_seed=None,
verbose=False,
verbose=True,
)
med.fit(
T=mediation_fixture["treatment"],
Expand All @@ -30,3 +31,78 @@ def test_mediation_fit(mediation_fixture):
)

assert len(med.final_bootstrap_results) == 9


@pytest.mark.parametrize(
(
"treatment_grid_num",
"lower_grid_constraint",
"upper_grid_constraint",
"bootstrap_draws",
"bootstrap_replicates",
"spline_order",
"n_splines",
"lambda_",
"max_iter",
"random_seed",
"verbose",
),
[
(10.5, 0.01, 0.99, 10, 10, 3, 5, 0.5, 100, None, True),
(0, 0.01, 0.99, 10, 10, 3, 5, 0.5, 100, None, True),
(1e6, 0.01, 0.99, 10, 10, 3, 5, 0.5, 100, None, True),
(10, "hehe", 0.99, 10, 10, 3, 5, 0.5, 100, None, True),
(10, -1, 0.99, 10, 10, 3, 5, 0.5, 100, None, True),
(10, 1.5, 0.99, 10, 10, 3, 5, 0.5, 100, None, True),
(10, 0.1, "hehe", 10, 10, 3, 5, 0.5, 100, None, True),
(10, 0.1, -1, 10, 10, 3, 5, 0.5, 100, None, True),
(10, 0.1, 1.5, 10, 10, 3, 5, 0.5, 100, None, True),
(10, 0.1, 0.9, 10.5, 10, 3, 5, 0.5, 100, None, True),
(10, 0.1, 0.9, -2, 10, 3, 5, 0.5, 100, None, True),
(10, 0.1, 0.9, 1e6, 10, 3, 5, 0.5, 100, None, True),
(10, 0.1, 0.9, 100, "10", 3, 5, 0.5, 100, None, True),
(10, 0.1, 0.9, 100, -1, 3, 5, 0.5, 100, None, True),
(10, 0.1, 0.9, 100, 1e6, 3, 5, 0.5, 100, None, True),
(10, 0.1, 0.9, 100, 200, "3", 5, 0.5, 100, None, True),
(10, 0.1, 0.9, 100, 200, 1, 5, 0.5, 100, None, True),
(10, 0.1, 0.9, 100, 200, 1e6, 5, 0.5, 100, None, True),
(10, 0.1, 0.9, 100, 200, 5, "10", 0.5, 100, None, True),
(10, 0.1, 0.9, 100, 200, 5, 1, 0.5, 100, None, True),
(10, 0.1, 0.9, 100, 200, 5, 10, "0.5", 100, None, True),
(10, 0.1, 0.9, 100, 200, 5, 10, -0.5, 100, None, True),
(10, 0.1, 0.9, 100, 200, 5, 10, 1e7, 100, None, True),
(10, 0.1, 0.9, 100, 200, 5, 10, 1, "100", None, True),
(10, 0.1, 0.9, 100, 200, 5, 10, 1, 1, None, True),
(10, 0.1, 0.9, 100, 200, 5, 10, 1, 1e8, None, True),
(10, 0.1, 0.9, 100, 200, 5, 10, 1, 100, "None", True),
(10, 0.1, 0.9, 100, 200, 5, 10, 1, 100, -5, True),
(10, 0.1, 0.9, 100, 200, 5, 10, 1, 100, 123, "True"),
],
)
def test_bad_mediation_instantiation(
treatment_grid_num,
lower_grid_constraint,
upper_grid_constraint,
bootstrap_draws,
bootstrap_replicates,
spline_order,
n_splines,
lambda_,
max_iter,
random_seed,
verbose,
):
with pytest.raises(Exception) as bad:
Mediation(
treatment_grid_num=treatment_grid_num,
lower_grid_constraint=lower_grid_constraint,
upper_grid_constraint=upper_grid_constraint,
bootstrap_draws=bootstrap_draws,
bootstrap_replicates=bootstrap_replicates,
spline_order=spline_order,
n_splines=n_splines,
lambda_=lambda_,
max_iter=max_iter,
random_seed=random_seed,
verbose=verbose,
)
Loading

0 comments on commit 00fc7bf

Please sign in to comment.