Skip to content

Commit

Permalink
Merge pull request #18 from ipanepen/rle-bug
Browse files Browse the repository at this point in the history
fix: makes fit_transform behavior consistent with fit and transform
  • Loading branch information
ipanepen authored Jun 24, 2020
2 parents 8db46be + 9d683b2 commit dfbfa58
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 8 deletions.
9 changes: 1 addition & 8 deletions src/sagemaker_sklearn_extension/preprocessing/encoders.py
Original file line number Diff line number Diff line change
Expand Up @@ -278,8 +278,6 @@ def _is_sorted(self, iterable):
def fit_transform(self, y):
"""Fit label encoder and return encoded labels.
``fill_unseen_labels=True`` does nothing in ``fit_transform`` because there will be no unseen labels.
Parameters
----------
y : array-like of shape [n_samples]
Expand All @@ -290,12 +288,7 @@ def fit_transform(self, y):
y_encoded : array-like of shape [n_samples]
Encoded label values.
"""
y = column_or_1d(y, warn=True)
sorted_labels = self._check_labels_and_sort()
self.classes_, y_encoded = (
_encode(y, uniques=sorted_labels, encode=True) if sorted_labels else _encode(y, encode=True)
)
return y_encoded
return self.fit(y).transform(y)

def transform(self, y):
"""Transform labels to normalized encoding.
Expand Down
30 changes: 30 additions & 0 deletions test/test_preprocessing_encoders.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,13 @@ def test_robust_label_encoder_sorted_labels(labels):
assert_array_equal(list(enc.classes_), labels)
assert_array_equal(enc.transform([labels[2], labels[1], "173"]), [2, 1, 3])

# Test that fit_transform has the same behavior
enc = RobustLabelEncoder(labels=labels)
y_transformed = enc.fit_transform([labels[2], labels[1], "173"])

assert_array_equal(list(enc.classes_), labels)
assert_array_equal(y_transformed, [2, 1, 3])


@pytest.mark.parametrize("labels", (["-12", "9", "3"], ["-12.", "9.", "3."]))
def test_robust_label_encoder_unsorted_labels_warning(labels):
Expand All @@ -158,6 +165,29 @@ def test_robust_label_encoder_unsorted_labels_warning(labels):
assert_array_equal(list(enc.classes_), sorted(labels))
assert_array_equal(enc.transform([labels[1], labels[2], "173"]), [2, 1, 3])

# Test that fit_transform has the same behavior
enc = RobustLabelEncoder(labels=labels)
with pytest.warns(UserWarning):
y_transformed = enc.fit_transform([labels[1], labels[2], "173"])

assert_array_equal(list(enc.classes_), sorted(labels))
assert_array_equal(y_transformed, [2, 1, 3])


def test_robust_label_encoder_fill_label_value():
    """Check that ``fill_label_value`` is honoured by both fitting paths.

    With ``labels=[1]`` and ``fill_label_value=0``, the test expects the
    declared label ``1`` to encode to ``0``, the undeclared value ``0`` to
    encode to ``1``, and ``inverse_transform`` to recover the original input.
    The same expectations are checked for ``fit`` + ``transform`` and for
    ``fit_transform`` on a freshly constructed encoder.
    """
    samples = np.array([1, 1, 0, 1, 1])
    expected_codes = [0, 0, 1, 0, 0]

    # Path 1: explicit fit() followed by transform().
    encoder = RobustLabelEncoder(labels=[1], fill_label_value=0)
    encoder.fit(samples)
    codes = encoder.transform(samples)
    assert_array_equal(codes, expected_codes)
    assert_array_equal(encoder.inverse_transform(codes), samples)

    # Path 2: test that fit_transform has the same behavior.
    encoder = RobustLabelEncoder(labels=[1], fill_label_value=0)
    codes = encoder.fit_transform(samples)
    assert_array_equal(codes, expected_codes)
    assert_array_equal(encoder.inverse_transform(codes), samples)


@pytest.mark.parametrize(
"y, y_expected",
Expand Down

0 comments on commit dfbfa58

Please sign in to comment.