Skip to content

Commit

Permalink
change: test data reading when n_rows = 1 mod batch_size
Browse files Browse the repository at this point in the history
This test verifies that data is read correctly when the number of rows
is exactly one more than an integer multiple of the batch size. This case
matters because squeezing a single-element array yields a zero-dimensional
array (shape `()`), which `np.concatenate` cannot handle.
  • Loading branch information
Andre Perunicic authored and wiltonwu committed Jun 24, 2020
1 parent 9008b90 commit 8db46be
Showing 1 changed file with 21 additions and 0 deletions.
21 changes: 21 additions & 0 deletions test/test_read_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
from sagemaker_sklearn_extension.externals.read_data import _get_data
from sagemaker_sklearn_extension.externals.read_data import _get_reader
from sagemaker_sklearn_extension.externals.read_data import _get_size_total
from sagemaker_sklearn_extension.externals.read_data import _read_to_fit_memory
from sagemaker_sklearn_extension.externals.read_data import read_csv_data


Expand Down Expand Up @@ -268,6 +269,26 @@ def test_read_csv_data_split_limited_object():
assert y.dtype.kind == "O"


@pytest.mark.parametrize("output_dtype", ["O", "U"])
def test_read_to_fit_memory_dangling_element(tmpdir_factory, output_dtype):
    """Check `_read_to_fit_memory` when one row is left over after full batches.

    A 10x10 matrix of strings is written to a CSV file and read back through
    a reader created with the value 3 (presumably the batch size, so that
    10 rows leave a single dangling row -- confirm against `_get_reader`).
    Column 0 is requested as the target, so the remaining columns must come
    back as the features.
    """
    n_rows = 10
    expected = np.zeros((n_rows, n_rows)).astype(str)
    # Put a distinct value on the diagonal so every row differs from the rest.
    for idx in range(n_rows):
        expected[idx, idx] = str(idx + 1)

    csv_dir = tmpdir_factory.mktemp("ten_line_csv")
    csv_path = csv_dir.join("ten_lines.csv").strpath
    np.savetxt(csv_path, expected, delimiter=",", newline="\n", fmt="%s")

    reader = _get_reader(csv_dir.strpath, 3)
    # Use the machine's total memory as the limit passed to the reader.
    memory_limit = psutil.virtual_memory().total
    X_read, y_read = _read_to_fit_memory(
        reader,
        memory_limit,
        output_dtype=output_dtype,
        target_column_index=0,
    )

    expected_features = expected[:, 1:]
    expected_target = expected[:, 0]
    assert np.array_equal(expected_features, X_read)
    assert np.array_equal(expected_target, y_read)


def test_list_alphabetical():
"""Test for checking 'list_files' returns alphabetically"""
path = "test/data/csv/mock_datasplitter_output"
Expand Down

0 comments on commit 8db46be

Please sign in to comment.