diff --git a/test/test_read_data.py b/test/test_read_data.py
index ee446e5..8139c45 100644
--- a/test/test_read_data.py
+++ b/test/test_read_data.py
@@ -28,6 +28,7 @@
 from sagemaker_sklearn_extension.externals.read_data import _get_data
 from sagemaker_sklearn_extension.externals.read_data import _get_reader
 from sagemaker_sklearn_extension.externals.read_data import _get_size_total
+from sagemaker_sklearn_extension.externals.read_data import _read_to_fit_memory
 from sagemaker_sklearn_extension.externals.read_data import read_csv_data
 
 
@@ -268,6 +269,26 @@ def test_read_csv_data_split_limited_object():
     assert y.dtype.kind == "O"
 
 
+@pytest.mark.parametrize("output_dtype", ["O", "U"])
+def test_read_to_fit_memory_dangling_element(tmpdir_factory, output_dtype):
+    """Verify rows are read back correctly when `len(data) = 1 mod batch_size`."""
+    table = np.zeros((10, 10)).astype(str)
+    diag = np.arange(table.shape[0])
+    table[diag, diag] = (diag + 1).astype(str)  # unique marker per row
+    csv_dir = tmpdir_factory.mktemp("ten_line_csv")
+    csv_path = csv_dir.join("ten_lines.csv").strpath
+    np.savetxt(csv_path, table, fmt="%s", delimiter=",", newline="\n")
+
+    reader = _get_reader(csv_dir.strpath, 3)
+    total_memory = psutil.virtual_memory().total
+    X_read, y_read = _read_to_fit_memory(
+        reader, total_memory, output_dtype=output_dtype, target_column_index=0
+    )
+    expected_X, expected_y = table[:, 1:], table[:, 0]
+    assert np.array_equal(expected_X, X_read)
+    assert np.array_equal(expected_y, y_read)
+
+
 def test_list_alphabetical():
     """Test for checking 'list_files' returns alphabetically"""
     path = "test/data/csv/mock_datasplitter_output"