diff --git a/src/sagemaker_sklearn_extension/preprocessing/encoders.py b/src/sagemaker_sklearn_extension/preprocessing/encoders.py index 306dc5a..46383c1 100644 --- a/src/sagemaker_sklearn_extension/preprocessing/encoders.py +++ b/src/sagemaker_sklearn_extension/preprocessing/encoders.py @@ -153,6 +153,19 @@ def fit(self, X, y=None): return self + def fit_transform(self, X, y=None): + # This method is overloaded here in order to fix a minor bug in OneHotEncoder. See the last line of this method + self._validate_keywords() + + self._handle_deprecations(X) + + if self._legacy_mode: + return super()._transform_selected( + X, self._legacy_fit_transform, self.dtype, self._categorical_features, copy=True + ) + # The y was added to fit. In OneHotEncoder the next line is: return self.fit(X).transform(X) + return self.fit(X, y).transform(X) + def _more_tags(self): return {"X_types": ["categorical"]}