Skip to content

Commit

Permalink
fixed negative treatment bug (#23)
Browse files Browse the repository at this point in the history
  • Loading branch information
ronikobrosly authored Oct 12, 2020
1 parent 8bff2e9 commit 5d35451
Show file tree
Hide file tree
Showing 7 changed files with 69 additions and 40 deletions.
2 changes: 1 addition & 1 deletion .travis.yml
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ env:

before_install:
# Here we download miniconda and install the dependencies
- pip install black coverage future joblib numpy numpydoc pandas patsy progressbar2 pygam pytest python-dateutil python-utils pytz scikit-learn scipy six statsmodels
- pip install black coverage future joblib numpy numpydoc pandas patsy progressbar2 pygam pytest python-dateutil python-utils pytz scikit-learn scipy six sphinx_rtd_theme statsmodels

install:
- python setup.py install
Expand Down
3 changes: 3 additions & 0 deletions causal_curve/core.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""
Core classes (with basic methods) that will be invoked when other, model classes are defined
"""
import pkg_resources


class Core:
Expand All @@ -24,3 +25,5 @@ def get_params(self):
return dict(
[(k, v) for k, v in list(attrs.items()) if (k[0] != "_") and (k[-1] != "_")]
)

__version__ = pkg_resources.require("causal-curve")[0].version
91 changes: 54 additions & 37 deletions causal_curve/gps.py
Original file line number Diff line number Diff line change
Expand Up @@ -429,43 +429,8 @@ def fit(self, T, X, y):
# Create grid_values
self.grid_values = self._grid_values()

# Estimating the GPS
self.best_gps_family = self.gps_family

# If no family specified, pick the best family
if self.gps_family == None:
if self.verbose:
print(f"Fitting several GPS models and picking the best fitting one...")

(
self.best_gps_family,
self.gps_function,
self.gps_deviance,
) = self._find_best_gps_model()

if self.verbose:
print(
f"Best fitting model was {self.best_gps_family}, which "
f"produced a deviance of {self.gps_deviance}"
)

# Otherwise, go with the what the user provided...
else:
if self.verbose:
print(f"Fitting GPS model of family '{self.best_gps_family}'...")

if self.best_gps_family == "normal":
(
self.gps_function,
self.gps_deviance,
) = self._create_normal_gps_function()
elif self.best_gps_family == "lognormal":
(
self.gps_function,
self.gps_deviance,
) = self._create_lognormal_gps_function()
elif self.best_gps_family == "gamma":
self.gps_function, self.gps_deviance = self._create_gamma_gps_function()
# Determine which GPS family to use
self._determine_gps_function()

# Estimate the GPS
if self.verbose:
Expand Down Expand Up @@ -710,6 +675,58 @@ def _fit_gam(self):
lam=self.lambda_,
).fit(X, y)

def _determine_gps_function(self):
"""Based on the user input, distribution of treatment values, and/or model deviances,
this function determines which GPS function family should be used.
"""

# If any negative values in treatment, you must use the normal GLM family.
if any(self.T <= 0):
self.best_gps_family = "normal"
self.gps_function, self.gps_deviance = self._create_normal_gps_function()
if self.verbose:
print(
f"Must fit `normal` GLM family to model treatment since treatment takes on zero or negative values..."
)

# If treatment has no negative values and user provides in put, use that.
elif (all(self.T > 0)) & (not isinstance(self.gps_family, type(None))):
if self.verbose:
print(f"Fitting GPS model of family '{self.gps_family}'...")

if self.gps_family == "normal":
self.best_gps_family = "normal"
(
self.gps_function,
self.gps_deviance,
) = self._create_normal_gps_function()
elif self.gps_family == "lognormal":
self.best_gps_family = "lognormal"
(
self.gps_function,
self.gps_deviance,
) = self._create_lognormal_gps_function()
elif self.gps_family == "gamma":
self.best_gps_family = "gamma"
self.gps_function, self.gps_deviance = self._create_gamma_gps_function()

# If no zero or negative treatment values and user didn't provide input, figure out best-fitting family
elif (all(self.T > 0)) & (isinstance(self.gps_family, type(None))):
if self.verbose:
print(f"Fitting several GPS models and picking the best fitting one...")

(
self.best_gps_family,
self.gps_function,
self.gps_deviance,
) = self._find_best_gps_model()

if self.verbose:
print(
f"Best fitting model was {self.best_gps_family}, which "
f"produced a deviance of {self.gps_deviance}"
)

def _create_normal_gps_function(self):
"""Models the GPS using a GLM of the Gaussian family"""
normal_gps_model = sm.GLM(
Expand Down
8 changes: 8 additions & 0 deletions docs/changelog.rst
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,14 @@
Change Log
==========


Version 0.4.1
-------------
- When using the GPS tool with a treatment that takes on zero or negative values, only the normal GLM family can be picked
- Added 'sphinx_rtd_theme' to dependency list in `.travis.yml` and `install.rst`
- core.py base class now has __version__ attribute


Version 0.4.0
-------------
- Added support for binary outcomes in GPS tool
Expand Down
2 changes: 1 addition & 1 deletion docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
author = 'Roni Kobrosly'

# The full version, including alpha/beta/rc tags
release = '0.4.0'
release = '0.4.1'

# -- General configuration ---------------------------------------------------

Expand Down
1 change: 1 addition & 0 deletions docs/install.rst
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ causal-curve requires:
- scikit-learn
- scipy
- six
- sphinx_rtd_theme
- statsmodels


Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

setuptools.setup(
name="causal-curve",
version="0.4.0",
version="0.4.1",
author="Roni Kobrosly",
author_email="[email protected]",
description="A python library with tools to perform causal inference using \
Expand Down

0 comments on commit 5d35451

Please sign in to comment.