change: minor performance optimizations and refactoring
wiltonwu committed Aug 12, 2020
1 parent 1eb4e4c commit 98cf73b
Showing 10 changed files with 122 additions and 133 deletions.
116 changes: 57 additions & 59 deletions src/sagemaker_sklearn_extension/feature_extraction/text.py
@@ -180,6 +180,48 @@ def __init__(
         self.vocabulary_sizes = vocabulary_sizes
         self.ignore_columns_with_zero_vocabulary_size = ignore_columns_with_zero_vocabulary_size

+    def _fit_vectorizer(self, col_idx, X):
+        max_features = self.max_features
+
+        # Override max_features for the current column in order to enforce the vocabulary size.
+        if self.max_features and self.vocabulary_sizes:
+            max_features = min(self.max_features, self.vocabulary_sizes[col_idx])
+        elif self.vocabulary_sizes:
+            max_features = self.vocabulary_sizes[col_idx]
+
+        try:
+            vectorizer = TfidfVectorizer(
+                strip_accents=self.strip_accents,
+                lowercase=self.lowercase,
+                preprocessor=self.preprocessor,
+                tokenizer=self.tokenizer,
+                stop_words=self.stop_words,
+                token_pattern=self.token_pattern,
+                ngram_range=self.ngram_range,
+                analyzer=self.analyzer,
+                max_df=self.max_df,
+                min_df=self.min_df,
+                max_features=max_features,
+                vocabulary=self.vocabulary,
+                dtype=self.dtype,
+                norm=self.norm,
+                use_idf=self.use_idf,
+                smooth_idf=self.smooth_idf,
+                sublinear_tf=self.sublinear_tf,
+            )
+            vectorizer.fit(X[:, col_idx])
+        except ValueError as err:
+            zero_vocab_errors = [
+                "After pruning, no terms remain. Try a lower min_df or a higher max_df.",
+                "max_df corresponds to < documents than min_df",
+                "empty vocabulary; perhaps the documents only contain stop words",
+            ]
+            if str(err) in zero_vocab_errors and self.ignore_columns_with_zero_vocabulary_size:
+                vectorizer = None
+            else:
+                raise
+        return vectorizer
+
     def fit(self, X, y=None):
         """Build the list of TfidfVectorizers for each column.
@@ -198,52 +240,23 @@ def fit(self, X, y=None):
         if self.vocabulary_sizes and len(self.vocabulary_sizes) != n_columns:
             raise ValueError("If specified, vocabulary_sizes has to have exactly one entry per data column.")

-        self.vectorizers_ = []
-        for col_idx in range(n_columns):
-            max_features = self.max_features
-
-            # Override max_features for the current column in order to enforce the vocabulary size.
-            if self.max_features and self.vocabulary_sizes:
-                max_features = min(self.max_features, self.vocabulary_sizes[col_idx])
-            elif self.vocabulary_sizes:
-                max_features = self.vocabulary_sizes[col_idx]
-
-            try:
-                vectorizer = TfidfVectorizer(
-                    strip_accents=self.strip_accents,
-                    lowercase=self.lowercase,
-                    preprocessor=self.preprocessor,
-                    tokenizer=self.tokenizer,
-                    stop_words=self.stop_words,
-                    token_pattern=self.token_pattern,
-                    ngram_range=self.ngram_range,
-                    analyzer=self.analyzer,
-                    max_df=self.max_df,
-                    min_df=self.min_df,
-                    max_features=max_features,
-                    vocabulary=self.vocabulary,
-                    dtype=self.dtype,
-                    norm=self.norm,
-                    use_idf=self.use_idf,
-                    smooth_idf=self.smooth_idf,
-                    sublinear_tf=self.sublinear_tf,
-                )
-                vectorizer.fit(X[:, col_idx])
-            except ValueError as err:
-                zero_vocab_errors = [
-                    "After pruning, no terms remain. Try a lower min_df or a higher max_df.",
-                    "max_df corresponds to < documents than min_df",
-                    "empty vocabulary; perhaps the documents only contain stop words",
-                ]
-                if str(err) in zero_vocab_errors and self.ignore_columns_with_zero_vocabulary_size:
-                    vectorizer = None
-                else:
-                    raise
-
-            self.vectorizers_.append(vectorizer)
+        self.vectorizers_ = [self._fit_vectorizer(i, X) for i in range(n_columns)]

         return self

+    def _transform_vectorizer(self, col_idx, X):
+        if self.vectorizers_[col_idx]:
+            tfidf_features = self.vectorizers_[col_idx].transform(X[:, col_idx])
+            # If the vocabulary size is specified and there are too few features, then pad the output with zeros.
+            if self.vocabulary_sizes and tfidf_features.shape[1] < self.vocabulary_sizes[col_idx]:
+                tfidf_features = sp.csr_matrix(
+                    (tfidf_features.data, tfidf_features.indices, tfidf_features.indptr),
+                    shape=(tfidf_features.shape[0], self.vocabulary_sizes[col_idx]),
+                )
+            return tfidf_features
+        # If ``TfidfVectorizer`` threw a value error, add an empty TF-IDF document-term matrix for the column
+        return sp.csr_matrix((X.shape[0], 0))
+
     def transform(self, X, y=None):
         """Transform documents to document term-matrix.
@@ -259,22 +272,7 @@ def transform(self, X, y=None):
         check_is_fitted(self, "vectorizers_")
         X = check_array(X, dtype=None)

-        ret = []
-        for col_idx in range(X.shape[1]):
-            if self.vectorizers_[col_idx]:
-                tfidf_features = self.vectorizers_[col_idx].transform(X[:, col_idx])
-                # If the vocabulary size is specified and there are too few features, then pad the output with zeros.
-                if self.vocabulary_sizes and tfidf_features.shape[1] < self.vocabulary_sizes[col_idx]:
-                    tfidf_features = sp.csr_matrix(
-                        (tfidf_features.data, tfidf_features.indices, tfidf_features.indptr),
-                        shape=(tfidf_features.shape[0], self.vocabulary_sizes[col_idx]),
-                    )
-            else:
-                # If ``TfidfVectorizer`` threw a value error, add an empty TF-IDF document-term matrix for the column
-                tfidf_features = sp.csr_matrix((X.shape[0], 0))
-            ret.append(tfidf_features)
-
-        return sp.hstack(ret)
+        return sp.hstack([self._transform_vectorizer(i, X) for i in range(X.shape[1])])

     def _more_tags(self):
         return {"X_types": ["string"]}
1 change: 0 additions & 1 deletion src/sagemaker_sklearn_extension/preprocessing/base.py
@@ -324,7 +324,6 @@ def fit(self, X, y=None):
         self : QuantileExtremeValueTransformer
         """
         super().fit(X)
-        X = check_array(X)
         self.quantile_transformer_ = QuantileTransformer(random_state=0, copy=True)
         self.quantile_transformer_.fit(X)
         return self
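The deleted `check_array` call looks redundant rather than behavioral: the preceding `super().fit(X)` presumably validates the input already, and `QuantileTransformer.fit` runs its own validation, so the extra pass only cost time. A minimal sketch of the pattern, with hypothetical class names:

    from sklearn.base import BaseEstimator, TransformerMixin
    from sklearn.utils.validation import check_array

    class ValidatingTransformer(BaseEstimator, TransformerMixin):
        def fit(self, X, y=None):
            X = check_array(X)  # input validated once, here
            self.n_features_ = X.shape[1]
            return self

    class DerivedTransformer(ValidatingTransformer):
        def fit(self, X, y=None):
            super().fit(X)  # validation already ran in the base fit
            # ... fit additional state without re-validating X ...
            return self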
4 changes: 1 addition & 3 deletions src/sagemaker_sklearn_extension/preprocessing/data.py
@@ -261,9 +261,7 @@ def fit(self, X, y=None):
             X, accept_sparse=("csr", "csc"), estimator=self, dtype=FLOAT_DTYPES, force_all_finite="allow-nan"
         )

-        with_mean = True
-        if issparse(X):
-            with_mean = False
+        with_mean = not issparse(X)

         self.scaler_ = StandardScaler(with_mean=with_mean, with_std=True, copy=self.copy)
         self.scaler_.fit(X)
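The collapsed flag is equivalent to the removed three lines: centering stays enabled only for dense input, since subtracting the column means from a sparse matrix would densify it, and scikit-learn's `StandardScaler` refuses `with_mean=True` on sparse data. A small sketch of the idiom (made-up data):

    import numpy as np
    from scipy.sparse import csr_matrix, issparse
    from sklearn.preprocessing import StandardScaler

    X_dense = np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
    X_sparse = csr_matrix(X_dense)

    for X in (X_dense, X_sparse):
        with_mean = not issparse(X)  # same one-liner as in the diff
        scaler = StandardScaler(with_mean=with_mean, with_std=True)
        print(type(scaler.fit_transform(X)).__name__, "with_mean =", with_mean)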
6 changes: 3 additions & 3 deletions src/sagemaker_sklearn_extension/preprocessing/encoders.py
@@ -109,7 +109,7 @@ def fit(self, X, y=None):
         super().fit(X, y)
         assert self.max_categories >= 1

-        _, n_samples, n_features = self._check_X(X)
+        n_samples, n_features = X.shape

         if not self.threshold:
             threshold = max(10, n_samples / 1000)
@@ -356,7 +356,7 @@ def inverse_transform(self, y):
         labels = np.arange(len(self.classes_))
         diff = np.setdiff1d(y, labels)

-        if diff and not self.fill_unseen_labels:
+        if diff.size > 0 and not self.fill_unseen_labels:
             raise ValueError("y contains previously unseen labels: %s" % str(diff))

         y_decoded = [self.classes_[idx] if idx in labels else self.fill_label_value for idx in y]
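The `diff.size > 0` change is a bug fix rather than style: `np.setdiff1d` returns an ndarray, and the truth value of an array with more than one element is ambiguous, so the old `if diff` would raise its own `ValueError` instead of the intended unseen-labels error. A quick illustration:

    import numpy as np

    diff = np.setdiff1d([5, 6], [0, 1, 2])  # array([5, 6])

    if diff.size > 0:  # new form: safe for any array length
        print("unseen labels:", diff)

    try:
        if diff:  # old form: raises for arrays with more than one element
            pass
    except ValueError as err:
        print(err)  # "The truth value of an array with more than one element is ambiguous..."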
@@ -513,7 +513,7 @@ class RobustOrdinalEncoder(OrdinalEncoder):
     """

     def __init__(self, categories="auto", dtype=np.float32, unknown_as_nan=False):
-        super(RobustOrdinalEncoder, self).__init__(categories, dtype)
+        super(RobustOrdinalEncoder, self).__init__(categories=categories, dtype=dtype)
         self.categories = categories
         self.dtype = dtype
         self.unknown_as_nan = unknown_as_nan
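Passing `categories` and `dtype` by keyword is defensive: if a future `OrdinalEncoder.__init__` gains a parameter between the two, positional arguments would silently bind to the wrong parameter, while keywords keep binding correctly. Sketched with a hypothetical base class:

    class Base:
        # imagine a later release becomes:
        #     def __init__(self, categories="auto", handle_unknown="error", dtype=float)
        def __init__(self, categories="auto", dtype=float):
            self.categories = categories
            self.dtype = dtype

    class Robust(Base):
        def __init__(self, categories="auto", dtype=float):
            # keywords stay correct even if Base gains parameters in between
            super().__init__(categories=categories, dtype=dtype)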
24 changes: 12 additions & 12 deletions test/test_common.py
@@ -35,18 +35,18 @@
 @pytest.mark.parametrize(
     "Estimator",
     [
-        DateTimeVectorizer,
-        LogExtremeValuesTransformer,
-        MultiColumnTfidfVectorizer,
-        NALabelEncoder,
-        QuadraticFeatures,
-        QuantileExtremeValuesTransformer,
-        RobustImputer,
-        RemoveConstantColumnsTransformer,
-        RobustLabelEncoder,
-        RobustMissingIndicator,
-        RobustStandardScaler,
-        ThresholdOneHotEncoder,
+        DateTimeVectorizer(),
+        LogExtremeValuesTransformer(),
+        MultiColumnTfidfVectorizer(),
+        NALabelEncoder(),
+        QuadraticFeatures(),
+        QuantileExtremeValuesTransformer(),
+        RobustImputer(),
+        RemoveConstantColumnsTransformer(),
+        RobustLabelEncoder(),
+        RobustMissingIndicator(),
+        RobustStandardScaler(),
+        ThresholdOneHotEncoder(),
     ],
 )
 def test_all_estimators(Estimator):
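Parametrizing with estimator instances instead of classes matches where scikit-learn's own API checks went: newer releases of `check_estimator` expect an instance, and passing a class was deprecated. A sketch of the shape such a test presumably takes (the test name below is hypothetical; the body of `test_all_estimators` is collapsed in this view):

    import pytest
    from sklearn.utils.estimator_checks import check_estimator

    from sagemaker_sklearn_extension.preprocessing import RobustStandardScaler

    @pytest.mark.parametrize("estimator", [RobustStandardScaler()])
    def test_sklearn_compatible_estimator(estimator):
        check_estimator(estimator)  # expects an instance in newer scikit-learn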
11 changes: 5 additions & 6 deletions test/test_data.py
@@ -16,7 +16,6 @@
 from scipy.sparse import csr_matrix, issparse

 from sagemaker_sklearn_extension.preprocessing import QuadraticFeatures, RobustStandardScaler
-from sklearn.utils.testing import assert_array_almost_equal, assert_array_equal


 def _n_choose_2(n):
@@ -50,7 +49,7 @@ def test_quadratic_features_explicit():
             (X_standardized[:, 0] * X_standardized[:, 1]).reshape((-1, 1)),
         ]
     )
-    assert_array_equal(X_observed, X_expected)
+    np.testing.assert_array_equal(X_observed, X_expected)


 def test_quadratic_features_max_n_features():
@@ -107,25 +106,25 @@ def test_quadratic_features_single_column_input_explicit():
     """Test that using a single-column matrix as input produces the expected output."""
     X_observed = QuadraticFeatures().fit_transform(X_standardized[:, 0].reshape((-1, 1)))
     X_expected = np.hstack([X_standardized[:, [0]], (X_standardized[:, 0] * X_standardized[:, 0]).reshape((-1, 1)),])
-    assert_array_equal(X_observed, X_expected)
+    np.testing.assert_array_equal(X_observed, X_expected)


 def test_robust_standard_scaler_dense():
     scaler = RobustStandardScaler()
     X_observed = scaler.fit_transform(X)

-    assert_array_equal(X_observed, X_standardized)
+    np.testing.assert_array_equal(X_observed, X_standardized)


 def test_robust_standard_scaler_sparse():
     scaler = RobustStandardScaler()
     X_observed = scaler.fit_transform(X_sparse)

     assert issparse(X_observed)
-    assert_array_almost_equal(X_observed.toarray(), X / np.std(X, axis=0))
+    np.testing.assert_array_almost_equal(X_observed.toarray(), X / np.std(X, axis=0))


 def test_robust_standard_dense_with_low_nnz_columns():
     scaler = RobustStandardScaler()
     X_observed = scaler.fit_transform(X_low_nnz)
-    assert_array_almost_equal(X_observed, X_low_nnz_standardized)
+    np.testing.assert_array_almost_equal(X_observed, X_low_nnz_standardized)
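The assertion swaps in this and the remaining test files follow from the deleted import: `sklearn.utils.testing` was deprecated as a public module (its helpers moved to the private `sklearn.utils._testing`), and NumPy's public `np.testing` module offers equivalent functions. For example:

    import numpy as np

    a = np.array([1.0, 2.0, 3.0])
    b = a + 1e-9

    np.testing.assert_array_equal(a, a)         # exact comparison
    np.testing.assert_array_almost_equal(a, b)  # approximate, 6 decimals by default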
9 changes: 4 additions & 5 deletions test/test_date_time.py
@@ -14,7 +14,6 @@
 import numpy as np
 import pytest

-from sklearn.utils.testing import assert_array_equal
 from dateutil import parser

 from sagemaker_sklearn_extension.feature_extraction.date_time import DateTimeVectorizer, DateTimeDefinition
@@ -136,10 +135,10 @@ def test_transform_categorical():
     assert np.all(output >= 0)

     loc_year = extract_keys.index("YEAR")
-    assert_array_equal(output[:, loc_year], np.array([2012, 2011, 2012, 2012, 2012, 2018]))
+    np.testing.assert_array_equal(output[:, loc_year], np.array([2012, 2011, 2012, 2012, 2012, 2018]))

     loc_month = extract_keys.index("MONTH")
-    assert_array_equal(output[:, loc_month], np.array([0, 1, 0, 11, 0, 0]))
+    np.testing.assert_array_equal(output[:, loc_month], np.array([0, 1, 0, 11, 0, 0]))


 def test_transform_cyclic_leaves_year():
@@ -152,7 +151,7 @@ def test_transform_cyclic_leaves_year():

     loc_year = extract_keys.index("YEAR")
     loc_year *= 2
-    assert_array_equal(output[:, loc_year], np.array([2012, 2011, 2012, 2012, 2012, 2018]))
+    np.testing.assert_array_equal(output[:, loc_year], np.array([2012, 2011, 2012, 2012, 2012, 2018]))

     assert output.shape[1] == len(extract) * 2 - 1

@@ -166,7 +165,7 @@ def test_fit_transform_cyclic_leaves_year():

     loc_year = extract_keys.index("YEAR")
     loc_year *= 2
-    assert_array_equal(output[:, loc_year], np.array([2012, 2011, 2012, 2012, 2012, 2018]))
+    np.testing.assert_array_equal(output[:, loc_year], np.array([2012, 2011, 2012, 2012, 2012, 2018]))

     assert output.shape[1] == len(dtv.extract_) * 2 - 1

12 changes: 5 additions & 7 deletions test/test_preprocessing.py
@@ -14,8 +14,6 @@
 import numpy as np
 import pytest

-from sklearn.utils.testing import assert_array_equal, assert_array_almost_equal
-
 from sagemaker_sklearn_extension.preprocessing import (
     LogExtremeValuesTransformer,
     QuantileExtremeValuesTransformer,
@@ -75,7 +73,7 @@ def test_remove_constant_columns_transformer(X, X_expected):
     transformer = RemoveConstantColumnsTransformer()
     X_observed = transformer.fit_transform(X)

-    assert_array_equal(X_observed, X_expected)
+    np.testing.assert_array_equal(X_observed, X_expected)


 @pytest.mark.parametrize(
@@ -91,15 +89,15 @@ def test_log_extreme_value_transformer(X, X_expected):
     transformer = LogExtremeValuesTransformer(threshold_std=2.0)
     X_observed = transformer.fit_transform(X)

-    assert_array_almost_equal(X_observed, X_expected)
+    np.testing.assert_array_almost_equal(X_observed, X_expected)


 def test_log_extreme_value_transformer_state():
     t = LogExtremeValuesTransformer(threshold_std=2.0)
     X_observed = t.fit_transform(X_extreme_vals)

-    assert_array_almost_equal(t.nonnegative_cols_, [1, 2])
-    assert_array_almost_equal(X_observed, X_log_extreme_vals)
+    np.testing.assert_array_almost_equal(t.nonnegative_cols_, [1, 2])
+    np.testing.assert_array_almost_equal(X_observed, X_log_extreme_vals)


 @pytest.mark.parametrize(
@@ -110,4 +108,4 @@ def test_extreme_value_transformer(X, X_expected):
     transformer = QuantileExtremeValuesTransformer(threshold_std=2.0)
     X_observed = transformer.fit_transform(X)

-    assert_array_almost_equal(X_observed, X_expected)
+    np.testing.assert_array_almost_equal(X_observed, X_expected)