diff --git a/causal_curve/gps_regressor.py b/causal_curve/gps_regressor.py index 90a1e8a..97d4c5f 100644 --- a/causal_curve/gps_regressor.py +++ b/causal_curve/gps_regressor.py @@ -117,7 +117,7 @@ def _create_point_estimate(self, T): in the case of a continuous outcome. """ return self.gam_results.predict( - np.array([T, self.gps_function(T).mean()]).reshape(1, -1) + np.array([T[0], self.gps_function(T).mean()]).reshape(1, -1) ) def point_estimate_interval(self, T, ci=0.95): @@ -154,5 +154,5 @@ def _create_point_estimate_interval(self, T, width): associated with a point estimate in the case of a continuous outcome. """ return self.gam_results.prediction_intervals( - np.array([T, self.gps_function(T).mean()]).reshape(1, -1), width=width + np.array([T[0], self.gps_function(T).mean()]).reshape(1, -1), width=width )