diff --git a/README.md b/README.md index 832e7e7..83c8614 100644 --- a/README.md +++ b/README.md @@ -40,11 +40,22 @@ For example, when you would like to: * Estimate how changing neighborhood income inequality (Gini index) could be causally related to neighborhood crime rate. This library attempts to address this gap, providing tools to estimate causal curves (AKA causal dose-response curves). +Both continuous and binary outcomes can be modeled against a continuous treatment. ## Installation +Available via PyPI: + `pip install causal-curve` +You can also get the latest version of causal-curve by cloning the repository:: + +``` +git clone -b master https://github.com/ronikobrosly/causal-curve.git +cd causal-curve +pip install . +``` + ## Documentation [Documentation is available at readthedocs.org](https://causal-curve.readthedocs.io/en/latest/) diff --git a/causal_curve/core.py b/causal_curve/core.py index 77f481e..06772a2 100644 --- a/causal_curve/core.py +++ b/causal_curve/core.py @@ -4,8 +4,7 @@ class Core: - """ Base class for causal_curve module - """ + """Base class for causal_curve module""" def __init__(self): pass diff --git a/causal_curve/gps.py b/causal_curve/gps.py index c01d325..f801efc 100644 --- a/causal_curve/gps.py +++ b/causal_curve/gps.py @@ -9,8 +9,9 @@ import numpy as np import pandas as pd -from pandas.api.types import is_float_dtype, is_numeric_dtype -from pygam import LinearGAM, s +from pandas.api.types import is_float_dtype, is_integer_dtype, is_numeric_dtype +from pygam import LinearGAM, LogisticGAM, s +from scipy.special import logit from scipy.stats import gamma, norm import statsmodels.api as sm from statsmodels.genmod.families.links import inverse_power as Inverse_Power @@ -24,7 +25,8 @@ class GPS(Core): """ In a multi-stage approach, this computes the generalized propensity score (GPS) function, and uses this in a generalized additive model (GAM) to correct treatment prediction of - the outcome variable. Assumes continuous treatment and outcome variable. + the outcome variable. Assumes continuous treatment, but the outcome variable may be + continuous or binary. WARNING: @@ -146,7 +148,6 @@ class GPS(Core): Hirano K and Imbens GW. The propensity score with continuous treatments. In: Gelman A and Meng XL (eds) Applied bayesian modeling and causal inference from incomplete-data perspectives. Oxford, UK: Wiley, 2004, pp.73–84. - """ def __init__( @@ -344,8 +345,7 @@ def _validate_init_params(self): ) def _validate_fit_data(self): - """Verifies that T, X, and y are formatted the right way - """ + """Verifies that T, X, and y are formatted the right way""" # Checks for T column if not is_float_dtype(self.T): raise TypeError(f"Treatment data must be of type float") @@ -366,12 +366,19 @@ def _validate_fit_data(self): ) # Checks for Y column - if not is_float_dtype(self.y): - raise TypeError(f"Outcome data must be of type float") + if not (is_float_dtype(self.y) or is_integer_dtype(self.y)): + raise TypeError(f"Outcome data must be of type float or integer") + + if is_integer_dtype(self.y) and ( + not np.array_equal(np.sort(self.y.unique()), np.array([0, 1])) + ): + raise TypeError( + f"If your outcome data is of type integer (binary outcome)," + f"it should only contain 1's and 0's." + ) def _grid_values(self): - """Produces initial grid values for the treatment variable - """ + """Produces initial grid values for the treatment variable""" return np.quantile( self.T, q=np.linspace( @@ -383,16 +390,20 @@ def _grid_values(self): def fit(self, T, X, y): """Fits the GPS causal dose-response model. For now, this only accepts pandas columns. + While the treatment variable must be continuous (or ordinal with many levels), the + outcome variable may be continuous or binary. Parameters ---------- T: array-like, shape (n_samples,) - A continuous treatment variable + A continuous treatment variable. X: array-like, shape (n_samples, m_features) Covariates, where n_samples is the number of samples - and m_features is the number of features + and m_features is the number of features. Features can be a mix of continuous + and nominal/categorical variables. y: array-like, shape (n_samples,) - Outcome variable + Outcome variable. May be continuous or binary. If continuous, this must + be a series of type `float`, if binary must be a series of type `integer`. Returns ---------- @@ -403,6 +414,15 @@ def fit(self, T, X, y): self.X = X.reset_index(drop=True, inplace=False) self.y = y.reset_index(drop=True, inplace=False) + # Determine what type of outcome variable we're working with + if is_float_dtype(self.y): + self.outcome_type = "continuous" + elif is_integer_dtype(self.y): + self.outcome_type = "binary" + + if self.verbose: + print(f"Determined the outcome variable is of type {self.outcome_type}...") + # Validate this input data self._validate_fit_data() @@ -495,42 +515,81 @@ def calculate_CDRC(self, ci=0.95): """ self._validate_calculate_CDRC_params(ci) - # Create CDRC predictions from trained GAM - self._cdrc_preds = self._cdrc_predictions(ci) - if self.verbose: print( """Generating predictions for each value of treatment grid, and averaging to get the CDRC...""" ) - # For each column of _cdrc_preds, calculate the mean and confidence interval bounds - results = [] + # Create CDRC predictions from trained GAM + # If working with a continuous outcome variable, use this path + if self.outcome_type == "continuous": + self._cdrc_preds = self._cdrc_predictions_continuous(ci) + + results = [] + + for i in range(0, self.treatment_grid_num): + temp_grid_value = self.grid_values[i] + temp_point_estimate = self._cdrc_preds[:, i, 0].mean() + mean_ci_width = ( + self._cdrc_preds[:, i, 2].mean() - self._cdrc_preds[:, i, 1].mean() + ) / 2 + temp_lower_bound = temp_point_estimate - mean_ci_width + temp_upper_bound = temp_point_estimate + mean_ci_width + results.append( + [ + temp_grid_value, + temp_point_estimate, + temp_lower_bound, + temp_upper_bound, + ] + ) + + outcome_name = "Causal_Dose_Response" - for i in range(0, self.treatment_grid_num): - temp_grid_value = self.grid_values[i] - temp_point_estimate = self._cdrc_preds[:, i, 0].mean() - mean_ci_width = ( - self._cdrc_preds[:, i, 2].mean() - self._cdrc_preds[:, i, 1].mean() - ) / 2 - temp_lower_bound = temp_point_estimate - mean_ci_width - temp_upper_bound = temp_point_estimate + mean_ci_width - results.append( - [ - temp_grid_value, - temp_point_estimate, - temp_lower_bound, - temp_upper_bound, - ] - ) + # If working with a binary outcome variable, use this path + else: + self._cdrc_preds = self._cdrc_predictions_binary(ci) + + # Capture the first prediction's mean log odds. + # This will serve as a reference for calculating the odds ratios + log_odds_reference = self._cdrc_preds[:, 0, 0].mean() + + results = [] + + for i in range(0, self.treatment_grid_num): + temp_grid_value = self.grid_values[i] + + temp_log_odds_estimate = ( + self._cdrc_preds[:, i, 0].mean() - log_odds_reference + ) + temp_OR_estimate = np.exp(temp_log_odds_estimate) + + temp_lower_bound = np.exp( + temp_log_odds_estimate + - (self._calculate_z_score(ci) * self._cdrc_preds[:, i, 1].mean()) + ) + temp_upper_bound = np.exp( + temp_log_odds_estimate + + (self._calculate_z_score(ci) * self._cdrc_preds[:, i, 1].mean()) + ) + results.append( + [ + temp_grid_value, + temp_OR_estimate, + temp_lower_bound, + temp_upper_bound, + ] + ) + + outcome_name = "Causal_Odds_Ratio" return pd.DataFrame( - results, columns=["Treatment", "CDRC", "Lower_CI", "Upper_CI"] + results, columns=["Treatment", outcome_name, "Lower_CI", "Upper_CI"] ).round(3) def _validate_calculate_CDRC_params(self, ci): - """Validates the parameters given to `calculate_CDRC` - """ + """Validates the parameters given to `calculate_CDRC`""" if not isinstance(ci, float): raise TypeError( @@ -540,9 +599,14 @@ def _validate_calculate_CDRC_params(self, ci): if isinstance(ci, float) and ((ci <= 0) or (ci >= 1.0)): raise ValueError("`ci` parameter should be between (0, 1)") - def _cdrc_predictions(self, ci): + def _calculate_z_score(self, ci): + """Calculates the critical z-score for a desired two-sided, confidence interval width.""" + return norm.ppf((1 + ci) / 2) + + def _cdrc_predictions_continuous(self, ci): """Returns the predictions of CDRC for each value of the treatment grid. Essentially, - we're making predictions using the original treatment and gps_at_grid + we're making predictions using the original treatment and gps_at_grid. + To be used when the outcome of interest is continuous. """ # To keep track of cdrc predictions, we create an empty 3d array of shape @@ -569,6 +633,42 @@ def _cdrc_predictions(self, ci): return np.round(cdrc_preds, 3) + def _cdrc_predictions_binary(self, ci): + """Returns the predictions of CDRC for each value of the treatment grid. Essentially, + we're making predictions using the original treatment and gps_at_grid. + To be used when the outcome of interest is binary. + """ + # To keep track of cdrc predictions, we create an empty 2d array of shape + # (n_samples, treatment_grid_num, 2). The last dimension is of length 2 because + # we are going to keep track of the point estimate (log-odds) of the prediction, as well as + # the standard error of the prediction interval (again, this is for the log odds) + cdrc_preds = np.zeros((len(self.T), self.treatment_grid_num, 2), dtype=float) + + # Loop through each of the grid values, predict point estimate and get prediction interval + for i in range(0, self.treatment_grid_num): + + temp_T = np.repeat(self.grid_values[i], repeats=len(self.T)) + temp_gps = self.gps_at_grid[:, i] + + temp_cdrc_preds = logit( + self.gam_results.predict_proba(np.column_stack((temp_T, temp_gps))) + ) + + temp_cdrc_interval = logit( + self.gam_results.confidence_intervals( + np.column_stack((temp_T, temp_gps)), width=ci + ) + ) + + standard_error = ( + temp_cdrc_interval[:, 1] - temp_cdrc_preds + ) / self._calculate_z_score(ci) + + cdrc_preds[:, i, 0] = temp_cdrc_preds + cdrc_preds[:, i, 1] = standard_error + + return np.round(cdrc_preds, 3) + def _gps_values_at_grid(self): """Returns an array where we get the GPS-derived values for each element of the treatment grid. Resulting array will be of shape (n_samples, treatment_grid_num) @@ -592,19 +692,18 @@ def print_gam_summary(self): Returns ---------- self: object - - """ print(self._gam_summary_str) def _fit_gam(self): - """Fits a GAM that predicts the outcome from the treatment and GPS - """ + """Fits a GAM that predicts the outcome (continuous or binary) from the treatment and GPS""" X = np.column_stack((self.T.values, self.gps)) y = np.asarray(self.y) - return LinearGAM( + model_type_dict = {"continuous": LinearGAM, "binary": LogisticGAM} + + return model_type_dict[self.outcome_type]( s(0, n_splines=self.n_splines, spline_order=self.spline_order) + s(1, n_splines=self.n_splines, spline_order=self.spline_order), max_iter=self.max_iter, @@ -612,8 +711,7 @@ def _fit_gam(self): ).fit(X, y) def _create_normal_gps_function(self): - """Models the GPS using a GLM of the Gaussian family - """ + """Models the GPS using a GLM of the Gaussian family""" normal_gps_model = sm.GLM( self.T, add_constant(self.X), family=sm.families.Gaussian() ).fit() @@ -627,8 +725,7 @@ def gps_function(treatment_val, pred_treat=pred_treat, sigma=sigma): return gps_function, normal_gps_model.deviance def _create_lognormal_gps_function(self): - """Models the GPS using a GLM of the Gaussian family (assumes treatment is lognormal) - """ + """Models the GPS using a GLM of the Gaussian family (assumes treatment is lognormal)""" lognormal_gps_model = sm.GLM( np.log(self.T), add_constant(self.X), family=sm.families.Gaussian() ).fit() @@ -642,8 +739,7 @@ def gps_function(treatment_val, pred_log_treat=pred_log_treat, sigma=sigma): return gps_function, lognormal_gps_model.deviance def _create_gamma_gps_function(self): - """Models the GPS using a GLM of the Gamma family - """ + """Models the GPS using a GLM of the Gamma family""" gamma_gps_model = sm.GLM( self.T, add_constant(self.X), family=sm.families.Gamma(Inverse_Power()) ).fit() diff --git a/causal_curve/mediation.py b/causal_curve/mediation.py index e6dd66c..d571aa5 100644 --- a/causal_curve/mediation.py +++ b/causal_curve/mediation.py @@ -329,8 +329,7 @@ def _validate_init_params(self): ) def _validate_fit_data(self): - """Verifies that T, M, and y are formatted the right way - """ + """Verifies that T, M, and y are formatted the right way""" # Checks for T column if not is_float_dtype(self.T): raise TypeError(f"Treatment data must be of type float") @@ -344,8 +343,7 @@ def _validate_fit_data(self): raise TypeError(f"Outcome data must be of type float") def _grid_values(self): - """Produces initial grid values for the treatment variable - """ + """Produces initial grid values for the treatment variable""" return np.quantile( self.T, q=np.linspace( @@ -356,8 +354,7 @@ def _grid_values(self): ) def _collect_mean_t_levels(self): - """Collects the mean treatment value within each treatment bucket in the grid_values - """ + """Collects the mean treatment value within each treatment bucket in the grid_values""" t_bin_means = [] @@ -539,15 +536,13 @@ def calculate_mediation(self, ci=0.95): return final_results def _clip_negatives(self, number): - """Helper function to clip negative numbers to zero - """ + """Helper function to clip negative numbers to zero""" if number < 0: return 0 return number def _bootstrap_analysis(self, temp_low_treatment, temp_high_treatment): - """The top-level function used in the fitting method - """ + """The top-level function used in the fitting method""" bootstrap_collection = [] @@ -584,8 +579,7 @@ def _bootstrap_analysis(self, temp_low_treatment, temp_high_treatment): return bootstrap_results def _create_bootstrap_replicate(self): - """Creates a single bootstrap replicate from the data - """ + """Creates a single bootstrap replicate from the data""" temp_t = self.T.sample(n=self.bootstrap_draws, replace=True) temp_m = self.M.iloc[temp_t.index] temp_y = self.y.iloc[temp_t.index] @@ -593,8 +587,7 @@ def _create_bootstrap_replicate(self): return temp_t, temp_m, temp_y def _fit_gams(self, temp_t, temp_m, temp_y): - """Fits the mediator and outcome GAMs - """ + """Fits the mediator and outcome GAMs""" temp_mediator_model = LinearGAM( s(0, n_splines=self.n_splines, spline_order=self.spline_order), fit_intercept=True, @@ -623,8 +616,7 @@ def _mediator_prediction( temp_low_treatment, temp_high_treatment, ): - """Makes predictions based on the mediator models - """ + """Makes predictions based on the mediator models""" m1_mean = temp_mediator_model.predict(temp_high_treatment)[0] m0_mean = temp_mediator_model.predict(temp_low_treatment)[0] @@ -648,8 +640,7 @@ def _outcome_prediction( predict_m0, temp_outcome_model, ): - """Makes predictions based on the outcome models - """ + """Makes predictions based on the outcome models""" outcome_preds = {} diff --git a/causal_curve/tmle.py b/causal_curve/tmle.py index b15b640..7dd1a72 100644 --- a/causal_curve/tmle.py +++ b/causal_curve/tmle.py @@ -200,8 +200,7 @@ def _validate_init_params(self): ) def _validate_fit_data(self): - """Verifies that T, X, and y are formatted the right way - """ + """Verifies that T, X, and y are formatted the right way""" # Checks for T column if not is_float_dtype(self.t_data): raise TypeError(f"Treatment data must be of type float") @@ -226,8 +225,7 @@ def _validate_fit_data(self): raise TypeError(f"Outcome data must be of type float") def _validate_calculate_CDRC_params(self, ci): - """Validates the parameters given to `calculate_CDRC` - """ + """Validates the parameters given to `calculate_CDRC`""" if not isinstance(ci, float): raise TypeError( @@ -276,8 +274,7 @@ def _create_treatment_comparison_df( return temp_y, temp_x, temp_t def _collect_mean_t_levels(self): - """Collects the mean treatment value within each treatment bucket in treatment_grid_bins - """ + """Collects the mean treatment value within each treatment bucket in treatment_grid_bins""" t_bin_means = [] @@ -455,19 +452,22 @@ def calculate_CDRC(self, ci=0.95, CDRC_grid_num=100): return pd.DataFrame( { "Treatment": Treatment, - "CDRC": CDRC, + "Causal_Dose_Response": CDRC, "Lower_CI": Lower_CI, "Upper_CI": Upper_CI, } ).round(3) def _grid_values(self, CDRC_grid_num, t_values): - """Produces grid values for use in estimating the final CDRC and confidence intervals. - """ + """Produces grid values for use in estimating the final CDRC and confidence intervals.""" return np.quantile( self.t_data[((self.t_data > t_values[0]) & (self.t_data < t_values[-1]))], - q=np.linspace(start=0, stop=1, num=CDRC_grid_num,), + q=np.linspace( + start=0, + stop=1, + num=CDRC_grid_num, + ), ) def _q_model(self, temp_y, temp_x, temp_t): @@ -501,8 +501,7 @@ def _q_model(self, temp_y, temp_x, temp_t): return y_hat_a, y_hat_1, y_hat_0 def _g_model(self, temp_x, temp_t): - """Produces the G-model and gets treatment assignment predictions - """ + """Produces the G-model and gets treatment assignment predictions""" X = temp_x.to_numpy() t = temp_t.to_numpy() @@ -523,8 +522,7 @@ def _g_model(self, temp_x, temp_t): return pi_hat1, pi_hat0 def _delta_hat_estimation(self, temp_y, temp_x, temp_t): - """Estimates delta to correct treatment estimation - """ + """Estimates delta to correct treatment estimation""" H_a = [] for idx, treatment in enumerate(np.asarray(temp_t)): diff --git a/docs/GPS_example.rst b/docs/GPS_example.rst index 6cdf957..580b57c 100644 --- a/docs/GPS_example.rst +++ b/docs/GPS_example.rst @@ -84,8 +84,8 @@ half the error of a simple LOESS estimate using only the treatment and the outco .. image:: ../imgs/cdrc/CDRC.png - - +A binary outcome can also be handled with the GPS tool. As long as the outcome series contains +binary integer values (e.g. 0's and 1's) the GPS `fit` method will work as it's supposed to. References diff --git a/docs/changelog.rst b/docs/changelog.rst index e3b3c87..7908bcd 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -4,6 +4,12 @@ Change Log ========== +Version 0.4.0 +------------- +- Added support for binary outcomes in GPS tool +- Small changes to repo README + + Version 0.3.8 ------------- - Added citation (yay!) diff --git a/docs/conf.py b/docs/conf.py index 6c41670..4181576 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -22,7 +22,7 @@ author = 'Roni Kobrosly' # The full version, including alpha/beta/rc tags -release = '0.3.8' +release = '0.4.0' # -- General configuration --------------------------------------------------- diff --git a/docs/index.rst b/docs/index.rst index 8f41843..3644667 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -64,6 +64,7 @@ There are many available methods to perform causal inference when your intervent but few methods exist to handle continuous treatments. This is unfortunate because there are many scenarios (in industry and research) where these methods would be useful. This library attempts to address this gap, providing tools to estimate causal curves (AKA causal dose-response curves). +Both continuous and binary outcomes can be modeled with this package. Quick example (of the ``GPS`` tool) diff --git a/docs/intro.rst b/docs/intro.rst index d28b6a2..3d8b8af 100644 --- a/docs/intro.rst +++ b/docs/intro.rst @@ -14,7 +14,6 @@ However, for ethical or financial reasons, experiments may not always be feasibl * It's not ethical to randomly assign some people to receive a possible carcinogen in pill form while others receive a sugar pill, and then see which group is more likely to develop cancer. * It's not feasible to increase the household incomes of some New York neighborhoods, while leaving others unchanged to see if changing a neighborhood's income inequality would improve the local crime rate. - "Causal inference" methods are a set of approaches that attempt to estimate causal effects from observational rather than experimental data, correcting for the biases that are inherent to analyzing observational data (e.g. confounding and selection bias) [@Hernán:2020]. @@ -24,11 +23,14 @@ and potentially confounding variables across your units of analysis (in addition then you can essentially simulate a proper experiment and make causal claims. -Interpretting the causal curve +Interpreting the causal curve ------------------------------ Two of the methods contained within this package produce causal curves for continuous treatments -(see the GPS and TMLE methods). +(see the GPS and TMLE methods). Both continuous and binary treatments can be modeled +(only the GPS tool can handle binary outcomes). + +**Continuous outcome:** .. image:: ../imgs/welcome_plot.png @@ -45,6 +47,18 @@ generated through standard multivariable regression modeling in a few important * This curve represents a population-level effect, and should not be used to infer effects at the individual-level (or whatever the unit of analysis is). * To generate a similar-looking plot using multivariable regression, you would have to hold covariates constant, and any treatment effect that is inferred occurs within the levels of the covariates specified in the model. The causal curve averages out across all of these strata and gives us the population marginal effect. +**Binary outcome:** + +.. image:: ../imgs/binary_OR_fig.png + +In the case of binary outcome, the GPS tool can be used to estimate a curve of odds ratio. Every +point on the curve is relative to the lowest treatment value. The highest effect (relative to the lowest treatment value) +is around a treatment value of -1.2. At this point in the treatment, the odds of a positive class +occurring is 5.6 times higher compared with the lowest treatment value. This curve is always on +the relative scale. This is why the odds ratio for the lowest point is always 1.0, because it is +relative to itself. Odds ratios are bounded [0, inf] and cannot take on a negative value. Note that +the confidence intervals at any given point in the curve isn't symmetric. + A caution about causal inference assumptions -------------------------------------------- diff --git a/docs/requirements.txt b/docs/requirements.txt index 6c001a9..aae46cd 100644 --- a/docs/requirements.txt +++ b/docs/requirements.txt @@ -15,4 +15,5 @@ pytz scikit-learn scipy six +sphinx_rtd_theme statsmodels diff --git a/imgs/binary_OR_fig.png b/imgs/binary_OR_fig.png new file mode 100644 index 0000000..f4e3b7f Binary files /dev/null and b/imgs/binary_OR_fig.png differ diff --git a/requirements.txt b/requirements.txt index 6c001a9..aae46cd 100644 --- a/requirements.txt +++ b/requirements.txt @@ -15,4 +15,5 @@ pytz scikit-learn scipy six +sphinx_rtd_theme statsmodels diff --git a/setup.py b/setup.py index 3a04540..8d44cf9 100644 --- a/setup.py +++ b/setup.py @@ -5,7 +5,7 @@ setuptools.setup( name="causal-curve", - version="0.3.8", + version="0.4.0", author="Roni Kobrosly", author_email="roni.kobrosly@gmail.com", description="A python library with tools to perform causal inference using \ @@ -38,6 +38,7 @@ 'scikit-learn', 'scipy', 'six', + 'sphinx_rtd_theme', 'statsmodels' ] ) diff --git a/tests/conftest.py b/tests/conftest.py index 2726b79..2adede4 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -3,16 +3,17 @@ import numpy as np import pandas as pd import pytest +from scipy.stats import norm @pytest.fixture(scope="module") -def dataset_fixture(): - """Returns full_example_dataset""" - return full_example_dataset() +def continuous_dataset_fixture(): + """Returns full_continuous_example_dataset (with a continuous outcome)""" + return full_continuous_example_dataset() -def full_example_dataset(): - """Example dataset with a treatment, two covariates, and outcome variable""" +def full_continuous_example_dataset(): + """Example dataset with a treatment, two covariates, and continuous outcome variable""" np.random.seed(500) @@ -26,7 +27,41 @@ def full_example_dataset(): fixture = pd.DataFrame( {"treatment": treatment, "x1": x_1, "x2": x_2, "outcome": outcome} ) + fixture.reset_index(drop=True, inplace=True) + + return fixture + + +@pytest.fixture(scope="module") +def binary_dataset_fixture(): + """Returns full_binary_example_dataset (with a binary outcome)""" + return full_binary_example_dataset() + + +def full_binary_example_dataset(): + """Example dataset with a treatment, two covariates, and binary outcome variable""" + np.random.seed(500) + treatment = np.linspace( + start=0, + stop=100, + num=100, + ) + x_1 = norm.rvs(size=100, loc = 50, scale = 5) + outcome = [ + 0,0,0,0,0,0,0,0,0,0, + 0,0,0,0,0,0,0,0,0,0, + 0,0,0,0,0,0,0,1,0,0, + 0,0,0,0,0,0,0,0,1,1, + 0,0,0,0,0,0,0,0,0,1, + 1,1,1,1,1,0,1,1,1,1, + 1,0,1,1,1,1,1,0,1,1, + 1,1,1,1,1,1,1,1,1,1, + 1,1,1,1,1,1,1,1,1,1, + 1,1,1,1,1,1,1,1,1,1, + ] + + fixture = pd.DataFrame({'treatment':treatment, 'x1': x_1, 'outcome': outcome}) fixture.reset_index(drop=True, inplace=True) return fixture diff --git a/tests/integration/test_gps.py b/tests/integration/test_gps.py index 9205868..564c8d9 100644 --- a/tests/integration/test_gps.py +++ b/tests/integration/test_gps.py @@ -5,9 +5,9 @@ from causal_curve import GPS -def test_full_gps_flow(dataset_fixture): +def test_full_continuous_gps_flow(continuous_dataset_fixture): """ - Tests the full flow of the GPS tool + Tests the full flow of the GPS tool when used with a continuous outcome """ gps = GPS( @@ -21,12 +21,40 @@ def test_full_gps_flow(dataset_fixture): verbose=True, ) gps.fit( - T=dataset_fixture["treatment"], - X=dataset_fixture[["x1", "x2"]], - y=dataset_fixture["outcome"], + T=continuous_dataset_fixture["treatment"], + X=continuous_dataset_fixture[["x1", "x2"]], + y=continuous_dataset_fixture["outcome"], ) gps_results = gps.calculate_CDRC(0.95) assert isinstance(gps_results, pd.DataFrame) - check = gps_results.columns == ["Treatment", "CDRC", "Lower_CI", "Upper_CI"] + check = gps_results.columns == ["Treatment", "Causal_Dose_Response", "Lower_CI", "Upper_CI"] + assert check.all() + + +def test_binary_continuous_gps_flow(binary_dataset_fixture): + """ + Tests the full flow of the GPS tool when used with a binary outcome + """ + + gps = GPS( + gps_family="normal", + treatment_grid_num=10, + lower_grid_constraint=0.0, + upper_grid_constraint=1.0, + spline_order=3, + n_splines=10, + max_iter=100, + random_seed=100, + verbose=True, + ) + gps.fit( + T=binary_dataset_fixture["treatment"], + X=binary_dataset_fixture["x1"], + y=binary_dataset_fixture["outcome"], + ) + gps_results = gps.calculate_CDRC(0.95) + + assert isinstance(gps_results, pd.DataFrame) + check = gps_results.columns == ["Treatment", "Causal_Odds_Ratio", "Lower_CI", "Upper_CI"] assert check.all() diff --git a/tests/integration/test_tmle.py b/tests/integration/test_tmle.py index 89f1af7..bc11b89 100644 --- a/tests/integration/test_tmle.py +++ b/tests/integration/test_tmle.py @@ -5,7 +5,7 @@ from causal_curve import TMLE -def test_full_tmle_flow(dataset_fixture): +def test_full_tmle_flow(continuous_dataset_fixture): """ Tests the full flow of the TMLE tool """ @@ -16,12 +16,12 @@ def test_full_tmle_flow(dataset_fixture): verbose=True, ) tmle.fit( - T=dataset_fixture["treatment"], - X=dataset_fixture[["x1", "x2"]], - y=dataset_fixture["outcome"], + T=continuous_dataset_fixture["treatment"], + X=continuous_dataset_fixture[["x1", "x2"]], + y=continuous_dataset_fixture["outcome"], ) tmle_results = tmle.calculate_CDRC(0.95) assert isinstance(tmle_results, pd.DataFrame) - check = tmle_results.columns == ["Treatment", "CDRC", "Lower_CI", "Upper_CI"] + check = tmle_results.columns == ["Treatment", "Causal_Dose_Response", "Lower_CI", "Upper_CI"] assert check.all() diff --git a/tests/unit/test_gps.py b/tests/unit/test_gps.py index b6a00a3..57beb1c 100644 --- a/tests/unit/test_gps.py +++ b/tests/unit/test_gps.py @@ -4,16 +4,16 @@ import pytest from causal_curve import GPS -from tests.conftest import full_example_dataset +from tests.conftest import full_continuous_example_dataset @pytest.mark.parametrize( ("df_fixture", "family"), [ - (full_example_dataset, "normal"), - (full_example_dataset, "lognormal"), - (full_example_dataset, "gamma"), - (full_example_dataset, None), + (full_continuous_example_dataset, "normal"), + (full_continuous_example_dataset, "lognormal"), + (full_continuous_example_dataset, "gamma"), + (full_continuous_example_dataset, None), ], ) def test_gps_fit(df_fixture, family): @@ -90,6 +90,9 @@ def test_bad_gps_instantiation( random_seed, verbose, ): + """ + Tests for exceptions when the GPS class if call with bad inputs. + """ with pytest.raises(Exception) as bad: GPS( gps_family=gps_family, @@ -102,3 +105,13 @@ def test_bad_gps_instantiation( random_seed=random_seed, verbose=verbose, ) + +def test_calculate_z_score(): + """ + Tests that that `_calculate_z_score` methods returns expected z-scores + """ + gps = GPS() + assert round(gps._calculate_z_score(0.99), 2) == 2.58 + assert round(gps._calculate_z_score(0.95), 2) == 1.96 + assert round(gps._calculate_z_score(0.90), 2) == 1.64 + assert round(gps._calculate_z_score(0.80), 2) == 1.28 diff --git a/tests/unit/test_tmle.py b/tests/unit/test_tmle.py index 452aa8b..f691155 100644 --- a/tests/unit/test_tmle.py +++ b/tests/unit/test_tmle.py @@ -5,7 +5,7 @@ from causal_curve import TMLE -def test_tmle_fit(dataset_fixture): +def test_tmle_fit(continuous_dataset_fixture): """ Tests the fit method GPS tool """ @@ -16,9 +16,9 @@ def test_tmle_fit(dataset_fixture): verbose=True, ) tmle.fit( - T=dataset_fixture["treatment"], - X=dataset_fixture[["x1", "x2"]], - y=dataset_fixture["outcome"], + T=continuous_dataset_fixture["treatment"], + X=continuous_dataset_fixture[["x1", "x2"]], + y=continuous_dataset_fixture["outcome"], ) assert tmle.n_obs == 72