diff --git a/src/sagemaker_sklearn_extension/preprocessing/encoders.py b/src/sagemaker_sklearn_extension/preprocessing/encoders.py index 46383c1..edfb5dc 100644 --- a/src/sagemaker_sklearn_extension/preprocessing/encoders.py +++ b/src/sagemaker_sklearn_extension/preprocessing/encoders.py @@ -278,8 +278,6 @@ def _is_sorted(self, iterable): def fit_transform(self, y): """Fit label encoder and return encoded labels. - ``fill_unseen_labels=True`` does nothing in ``fit_transform`` because there will be no unseen labels. - Parameters ---------- y : array-like of shape [n_samples] @@ -290,12 +288,7 @@ def fit_transform(self, y): y_encoded : array-like of shape [n_samples] Encoded label values. """ - y = column_or_1d(y, warn=True) - sorted_labels = self._check_labels_and_sort() - self.classes_, y_encoded = ( - _encode(y, uniques=sorted_labels, encode=True) if sorted_labels else _encode(y, encode=True) - ) - return y_encoded + return self.fit(y).transform(y) def transform(self, y): """Transform labels to normalized encoding. diff --git a/test/test_preprocessing_encoders.py b/test/test_preprocessing_encoders.py index 0a027e8..73cd89a 100644 --- a/test/test_preprocessing_encoders.py +++ b/test/test_preprocessing_encoders.py @@ -148,6 +148,13 @@ def test_robust_label_encoder_sorted_labels(labels): assert_array_equal(list(enc.classes_), labels) assert_array_equal(enc.transform([labels[2], labels[1], "173"]), [2, 1, 3]) + # Test that fit_transform has the same behavior + enc = RobustLabelEncoder(labels=labels) + y_transformed = enc.fit_transform([labels[2], labels[1], "173"]) + + assert_array_equal(list(enc.classes_), labels) + assert_array_equal(y_transformed, [2, 1, 3]) + @pytest.mark.parametrize("labels", (["-12", "9", "3"], ["-12.", "9.", "3."])) def test_robust_label_encoder_unsorted_labels_warning(labels): @@ -158,6 +165,29 @@ def test_robust_label_encoder_unsorted_labels_warning(labels): assert_array_equal(list(enc.classes_), sorted(labels)) assert_array_equal(enc.transform([labels[1], labels[2], "173"]), [2, 1, 3]) + # Test that fit_transform has the same behavior + enc = RobustLabelEncoder(labels=labels) + with pytest.warns(UserWarning): + y_transformed = enc.fit_transform([labels[1], labels[2], "173"]) + + assert_array_equal(list(enc.classes_), sorted(labels)) + assert_array_equal(y_transformed, [2, 1, 3]) + + +def test_robust_label_encoder_fill_label_value(): + y = np.array([1, 1, 0, 1, 1]) + enc = RobustLabelEncoder(labels=[1], fill_label_value=0) + enc.fit(y) + y_transform = enc.transform(y) + assert_array_equal(y_transform, [0, 0, 1, 0, 0]) + assert_array_equal(enc.inverse_transform(y_transform), y) + + # Test that fit_transform has the same behavior + enc = RobustLabelEncoder(labels=[1], fill_label_value=0) + y_transform = enc.fit_transform(y) + assert_array_equal(y_transform, [0, 0, 1, 0, 0]) + assert_array_equal(enc.inverse_transform(y_transform), y) + @pytest.mark.parametrize( "y, y_expected",