From df94f0d15868a60e48e3d9372cf3e48e303caadd Mon Sep 17 00:00:00 2001 From: Roni Kobrosly Date: Sun, 5 Jul 2020 08:55:08 -0500 Subject: [PATCH] ran black formatter --- causal_curve/gps.py | 2 +- causal_curve/mediation.py | 8 ++------ causal_curve/tmle.py | 4 ++-- 3 files changed, 5 insertions(+), 9 deletions(-) diff --git a/causal_curve/gps.py b/causal_curve/gps.py index 5946f4f..c01d325 100644 --- a/causal_curve/gps.py +++ b/causal_curve/gps.py @@ -190,7 +190,7 @@ def _validate_init_params(self): if not isinstance(self.gps_family, (str, type(None))): raise TypeError( f"gps_family parameter must be a string or None " - f"but found type {type(self.gps_family)}" + f"but found type {type(self.gps_family)}" ) if (isinstance(self.gps_family, str)) and ( diff --git a/causal_curve/mediation.py b/causal_curve/mediation.py index 8d85eb8..1b71fb7 100644 --- a/causal_curve/mediation.py +++ b/causal_curve/mediation.py @@ -503,9 +503,7 @@ def calculate_mediation(self, ci=0.95): bootstrap_overall_means = [] for i in range(0, 1000): bootstrap_overall_means.append( - general_indirect.sample( - frac=0.25, replace=True - ).mean() + general_indirect.sample(frac=0.25, replace=True).mean() ) bootstrap_overall_means = np.array(bootstrap_overall_means) @@ -585,9 +583,7 @@ def _bootstrap_analysis(self, temp_low_treatment, temp_high_treatment): def _create_bootstrap_replicate(self): """Creates a single bootstrap replicate from the data """ - temp_t = self.T.sample( - n=self.bootstrap_draws, replace=True - ) + temp_t = self.T.sample(n=self.bootstrap_draws, replace=True) temp_m = self.M.iloc[temp_t.index] temp_y = self.y.iloc[temp_t.index] diff --git a/causal_curve/tmle.py b/causal_curve/tmle.py index 636a5da..1974446 100644 --- a/causal_curve/tmle.py +++ b/causal_curve/tmle.py @@ -145,14 +145,14 @@ def _validate_init_params(self): if not isinstance(self.treatment_grid_bins, list): raise TypeError( f"treatment_grid_bins parameter must be a list, " - f"but found type {type(self.treatment_grid_bins)}" + f"but found type {type(self.treatment_grid_bins)}" ) for element in self.treatment_grid_bins: if not isinstance(element, (int, float)): raise TypeError( f"'{element}' in `treatment_grid_bins` list is not of type float or int, " - f"it is {type(element)}" + f"it is {type(element)}" ) if len(self.treatment_grid_bins) < 2: