From 9008b90c6bc6ae01c2377fcaec27a025a3c45146 Mon Sep 17 00:00:00 2001 From: Yotam Elor Date: Wed, 25 Mar 2020 22:29:04 -0400 Subject: [PATCH] fix: fix a minor bug in OneHotEncoder by by overloading the buggy method in ThresholdOneHotEncoder and fixing it --- .../preprocessing/encoders.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/src/sagemaker_sklearn_extension/preprocessing/encoders.py b/src/sagemaker_sklearn_extension/preprocessing/encoders.py index 306dc5a..46383c1 100644 --- a/src/sagemaker_sklearn_extension/preprocessing/encoders.py +++ b/src/sagemaker_sklearn_extension/preprocessing/encoders.py @@ -153,6 +153,19 @@ def fit(self, X, y=None): return self + def fit_transform(self, X, y=None): + # This method is overloaded here in order to fix a minor bug in OneHotEncoder. See the last line of this method + self._validate_keywords() + + self._handle_deprecations(X) + + if self._legacy_mode: + return super()._transform_selected( + X, self._legacy_fit_transform, self.dtype, self._categorical_features, copy=True + ) + # The y was added to fit. In OneHotEncoder the next line is: return self.fit(X).transform(X) + return self.fit(X, y).transform(X) + def _more_tags(self): return {"X_types": ["categorical"]}