Skip to content

Commit

Permalink
small big fixes for all classes
Browse files Browse the repository at this point in the history
  • Loading branch information
ronikobrosly committed Jul 5, 2020
1 parent 74f776a commit 985bb9e
Show file tree
Hide file tree
Showing 6 changed files with 93 additions and 86 deletions.
62 changes: 31 additions & 31 deletions causal_curve/gps.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,29 +189,29 @@ def _validate_init_params(self):
# Checks for gps_family param
if not isinstance(self.gps_family, (str, type(None))):
raise TypeError(
f"gps_family parameter must be a string or None, \
but found type {type(self.gps_family)}"
f"gps_family parameter must be a string or None "
f"but found type {type(self.gps_family)}"
)

if (isinstance(self.gps_family, str)) and (
self.gps_family not in ["normal", "lognormal", "gamma"]
):
raise ValueError(
f"gps_family parameter must take on values of \
'normal', 'lognormal', or 'gamma', but found {self.gps_family}"
f"gps_family parameter must take on values of "
f"'normal', 'lognormal', or 'gamma', but found {self.gps_family}"
)

# Checks for treatment_grid_num
if not isinstance(self.treatment_grid_num, int):
raise TypeError(
f"treatment_grid_num parameter must be an integer, \
but found type {type(self.treatment_grid_num)}"
f"treatment_grid_num parameter must be an integer, "
f"but found type {type(self.treatment_grid_num)}"
)

if (isinstance(self.treatment_grid_num, int)) and self.treatment_grid_num < 10:
raise ValueError(
f"treatment_grid_num parameter should be >= 10 so your final curve \
has enough resolution, but found value {self.treatment_grid_num}"
f"treatment_grid_num parameter should be >= 10 so your final curve "
f"has enough resolution, but found value {self.treatment_grid_num}"
)

if (
Expand All @@ -222,47 +222,47 @@ def _validate_init_params(self):
# Checks for lower_grid_constraint
if not isinstance(self.lower_grid_constraint, float):
raise TypeError(
f"lower_grid_constraint parameter must be a float, \
but found type {type(self.lower_grid_constraint)}"
f"lower_grid_constraint parameter must be a float, "
f"but found type {type(self.lower_grid_constraint)}"
)

if (
isinstance(self.lower_grid_constraint, float)
) and self.lower_grid_constraint < 0:
raise ValueError(
f"lower_grid_constraint parameter cannot be < 0, \
but found value {self.lower_grid_constraint}"
f"lower_grid_constraint parameter cannot be < 0, "
f"but found value {self.lower_grid_constraint}"
)

if (
isinstance(self.lower_grid_constraint, float)
) and self.lower_grid_constraint >= 1.0:
raise ValueError(
f"lower_grid_constraint parameter cannot >= 1.0, \
but found value {self.lower_grid_constraint}"
f"lower_grid_constraint parameter cannot >= 1.0, "
f"but found value {self.lower_grid_constraint}"
)

# Checks for upper_grid_constraint
if not isinstance(self.upper_grid_constraint, float):
raise TypeError(
f"upper_grid_constraint parameter must be a float, \
but found type {type(self.upper_grid_constraint)}"
f"upper_grid_constraint parameter must be a float, "
f"but found type {type(self.upper_grid_constraint)}"
)

if (
isinstance(self.upper_grid_constraint, float)
) and self.upper_grid_constraint <= 0:
raise ValueError(
f"upper_grid_constraint parameter cannot be <= 0, \
but found value {self.upper_grid_constraint}"
f"upper_grid_constraint parameter cannot be <= 0, "
f"but found value {self.upper_grid_constraint}"
)

if (
isinstance(self.upper_grid_constraint, float)
) and self.upper_grid_constraint > 1.0:
raise ValueError(
f"upper_grid_constraint parameter cannot > 1.0, \
but found value {self.upper_grid_constraint}"
f"upper_grid_constraint parameter cannot > 1.0, "
f"but found value {self.upper_grid_constraint}"
)

# Checks for lower_grid_constraint isn't higher than upper_grid_constraint
Expand All @@ -274,8 +274,8 @@ def _validate_init_params(self):
# Checks for spline_order
if not isinstance(self.spline_order, int):
raise TypeError(
f"spline_order parameter must be an integer, \
but found type {type(self.spline_order)}"
f"spline_order parameter must be an integer, "
f"but found type {type(self.spline_order)}"
)

if (isinstance(self.spline_order, int)) and self.spline_order < 1:
Expand Down Expand Up @@ -361,8 +361,8 @@ def _validate_fit_data(self):
for column in self.X:
if not is_numeric_dtype(self.X[column]):
raise TypeError(
f"All covariate (X) columns must be int or float type \
(i.e. must be numeric)"
f"All covariate (X) columns must be int or float type "
f"(i.e. must be numeric)"
)

# Checks for Y column
Expand Down Expand Up @@ -399,9 +399,9 @@ def fit(self, T, X, y):
self : object
"""
self.T = T
self.X = X
self.y = y
self.T = T.reset_index(drop=True, inplace=False)
self.X = X.reset_index(drop=True, inplace=False)
self.y = y.reset_index(drop=True, inplace=False)

# Validate this input data
self._validate_fit_data()
Expand All @@ -425,8 +425,8 @@ def fit(self, T, X, y):

if self.verbose:
print(
f"Best fitting model was {self.best_gps_family}, which \
produced a deviance of {self.gps_deviance}"
f"Best fitting model was {self.best_gps_family}, which "
f"produced a deviance of {self.gps_deviance}"
)

# Otherwise, go with the what the user provided...
Expand Down Expand Up @@ -500,8 +500,8 @@ def calculate_CDRC(self, ci=0.95):

if self.verbose:
print(
f"Generating predictions for each value of treatment grid, \
and averaging to get the CDRC..."
"""Generating predictions for each value of treatment grid,
and averaging to get the CDRC..."""
)

# For each column of _cdrc_preds, calculate the mean and confidence interval bounds
Expand Down
72 changes: 36 additions & 36 deletions causal_curve/mediation.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,14 +149,14 @@ def _validate_init_params(self):
# Checks for treatment_grid_num
if not isinstance(self.treatment_grid_num, int):
raise TypeError(
f"treatment_grid_num parameter must be an integer, \
but found type {type(self.treatment_grid_num)}"
f"treatment_grid_num parameter must be an integer, "
f"but found type {type(self.treatment_grid_num)}"
)

if (isinstance(self.treatment_grid_num, int)) and self.treatment_grid_num < 4:
raise ValueError(
f"treatment_grid_num parameter should be >= 4 so the internal models \
have enough resolution, but found value {self.treatment_grid_num}"
f"treatment_grid_num parameter should be >= 4 so the internal models "
f"have enough resolution, but found value {self.treatment_grid_num}"
)

if (isinstance(self.treatment_grid_num, int)) and self.treatment_grid_num > 100:
Expand All @@ -165,89 +165,89 @@ def _validate_init_params(self):
# Checks for lower_grid_constraint
if not isinstance(self.lower_grid_constraint, float):
raise TypeError(
f"lower_grid_constraint parameter must be a float, \
but found type {type(self.lower_grid_constraint)}"
f"lower_grid_constraint parameter must be a float, "
f"but found type {type(self.lower_grid_constraint)}"
)

if (
isinstance(self.lower_grid_constraint, float)
) and self.lower_grid_constraint < 0:
raise ValueError(
f"lower_grid_constraint parameter cannot be < 0, \
but found value {self.lower_grid_constraint}"
f"lower_grid_constraint parameter cannot be < 0, "
f"but found value {self.lower_grid_constraint}"
)

if (
isinstance(self.lower_grid_constraint, float)
) and self.lower_grid_constraint >= 1.0:
raise ValueError(
f"lower_grid_constraint parameter cannot >= 1.0, \
but found value {self.lower_grid_constraint}"
f"lower_grid_constraint parameter cannot >= 1.0, "
f"but found value {self.lower_grid_constraint}"
)

# Checks for upper_grid_constraint
if not isinstance(self.upper_grid_constraint, float):
raise TypeError(
f"upper_grid_constraint parameter must be a float, \
but found type {type(self.upper_grid_constraint)}"
f"upper_grid_constraint parameter must be a float, "
f"but found type {type(self.upper_grid_constraint)}"
)

if (
isinstance(self.upper_grid_constraint, float)
) and self.upper_grid_constraint <= 0:
raise ValueError(
f"upper_grid_constraint parameter cannot be <= 0, \
but found value {self.upper_grid_constraint}"
f"upper_grid_constraint parameter cannot be <= 0, "
f"but found value {self.upper_grid_constraint}"
)

if (
isinstance(self.upper_grid_constraint, float)
) and self.upper_grid_constraint > 1.0:
raise ValueError(
f"upper_grid_constraint parameter cannot > 1.0, \
but found value {self.upper_grid_constraint}"
f"upper_grid_constraint parameter cannot > 1.0, "
f"but found value {self.upper_grid_constraint}"
)

# Checks for bootstrap_draws
if not isinstance(self.bootstrap_draws, int):
raise TypeError(
f"bootstrap_draws parameter must be a int, \
but found type {type(self.bootstrap_draws)}"
f"bootstrap_draws parameter must be a int, "
f"but found type {type(self.bootstrap_draws)}"
)

if (isinstance(self.bootstrap_draws, int)) and self.bootstrap_draws < 100:
raise ValueError(
f"bootstrap_draws parameter cannot be < 100, \
but found value {self.bootstrap_draws}"
f"bootstrap_draws parameter cannot be < 100, "
f"but found value {self.bootstrap_draws}"
)

if (isinstance(self.bootstrap_draws, int)) and self.bootstrap_draws > 500000:
raise ValueError(
f"bootstrap_draws parameter cannot > 500000, \
but found value {self.bootstrap_draws}"
f"bootstrap_draws parameter cannot > 500000, "
f"but found value {self.bootstrap_draws}"
)

# Checks for bootstrap_replicates
if not isinstance(self.bootstrap_replicates, int):
raise TypeError(
f"bootstrap_replicates parameter must be a int, \
but found type {type(self.bootstrap_replicates)}"
f"bootstrap_replicates parameter must be a int, "
f"but found type {type(self.bootstrap_replicates)}"
)

if (
isinstance(self.bootstrap_replicates, int)
) and self.bootstrap_replicates < 50:
raise ValueError(
f"bootstrap_replicates parameter cannot be < 50, \
but found value {self.bootstrap_replicates}"
f"bootstrap_replicates parameter cannot be < 50, "
f"but found value {self.bootstrap_replicates}"
)

if (
isinstance(self.bootstrap_replicates, int)
) and self.bootstrap_replicates > 100000:
raise ValueError(
f"bootstrap_replicates parameter cannot > 100000, \
but found value {self.bootstrap_replicates}"
f"bootstrap_replicates parameter cannot > 100000, "
f"but found value {self.bootstrap_replicates}"
)

# Checks for lower_grid_constraint isn't higher than upper_grid_constraint
Expand All @@ -259,8 +259,8 @@ def _validate_init_params(self):
# Checks for spline_order
if not isinstance(self.spline_order, int):
raise TypeError(
f"spline_order parameter must be an integer, \
but found type {type(self.spline_order)}"
f"spline_order parameter must be an integer, "
f"but found type {type(self.spline_order)}"
)

if (isinstance(self.spline_order, int)) and self.spline_order < 3:
Expand Down Expand Up @@ -394,9 +394,9 @@ def fit(self, T, M, y):
self : object
"""
self.T = T
self.M = M
self.y = y
self.T = T.reset_index(drop=True, inplace=False)
self.M = M.reset_index(drop=True, inplace=False)
self.y = y.reset_index(drop=True, inplace=False)

# Validate this input data
self._validate_fit_data()
Expand Down Expand Up @@ -504,7 +504,7 @@ def calculate_mediation(self, ci=0.95):
for i in range(0, 1000):
bootstrap_overall_means.append(
general_indirect.sample(
frac=0.25, replace=True, random_state=self.random_seed
frac=0.25, replace=True
).mean()
)

Expand All @@ -519,7 +519,7 @@ def calculate_mediation(self, ci=0.95):
}
)
.round(4)
.clip(lower=0)
.clip(lower=0, upper=1.0)
)

total_prop_mean = round(np.array(self.prop_indirect_list).mean(), 4)
Expand Down Expand Up @@ -586,7 +586,7 @@ def _create_bootstrap_replicate(self):
"""Creates a single bootstrap replicate from the data
"""
temp_t = self.T.sample(
n=self.bootstrap_draws, replace=True, random_state=self.random_seed
n=self.bootstrap_draws, replace=True
)
temp_m = self.M.iloc[temp_t.index]
temp_y = self.y.iloc[temp_t.index]
Expand Down
Loading

0 comments on commit 985bb9e

Please sign in to comment.