From e620e9dce4a40b46e768cca74220735852516223 Mon Sep 17 00:00:00 2001 From: Joris Van den Bossche Date: Thu, 31 Oct 2024 12:16:13 +0100 Subject: [PATCH] [backport 2.3.x] BUG/TST (string dtype): fix and update tests for Stata IO (#60130) (#60155) BUG/TST (string dtype): fix and update tests for Stata IO (#60130) (cherry picked from commit e7d54a54da8a179fbde5878dfb4e6440d0cfbac8) --- pandas/io/stata.py | 5 +++ pandas/tests/io/test_stata.py | 82 ++++++++++++++++++----------------- 2 files changed, 48 insertions(+), 39 deletions(-) diff --git a/pandas/io/stata.py b/pandas/io/stata.py index 4abf9af185a01..b5057a6681638 100644 --- a/pandas/io/stata.py +++ b/pandas/io/stata.py @@ -605,7 +605,11 @@ def _cast_to_stata_types(data: DataFrame) -> DataFrame: if getattr(data[col].dtype, "numpy_dtype", None) is not None: data[col] = data[col].astype(data[col].dtype.numpy_dtype) elif is_string_dtype(data[col].dtype): + # TODO could avoid converting string dtype to object here, + # but handle string dtype in _encode_strings data[col] = data[col].astype("object") + # generate_table checks for None values + data.loc[data[col].isna(), col] = None dtype = data[col].dtype empty_df = data.shape[0] == 0 @@ -2671,6 +2675,7 @@ def _encode_strings(self) -> None: continue column = self.data[col] dtype = column.dtype + # TODO could also handle string dtype here specifically if dtype.type is np.object_: inferred_dtype = infer_dtype(column, skipna=True) if not ((inferred_dtype == "string") or len(column) == 0): diff --git a/pandas/tests/io/test_stata.py b/pandas/tests/io/test_stata.py index 09509fb495034..32f1c8d65271b 100644 --- a/pandas/tests/io/test_stata.py +++ b/pandas/tests/io/test_stata.py @@ -11,8 +11,6 @@ import numpy as np import pytest -from pandas._config import using_string_dtype - import pandas.util._test_decorators as td import pandas as pd @@ -347,9 +345,8 @@ def test_write_dta6(self, datapath): check_index_type=False, ) - @pytest.mark.xfail(using_string_dtype(), reason="TODO(infer_string)") @pytest.mark.parametrize("version", [114, 117, 118, 119, None]) - def test_read_write_dta10(self, version): + def test_read_write_dta10(self, version, using_infer_string): original = DataFrame( data=[["string", "object", 1, 1.1, np.datetime64("2003-12-25")]], columns=["string", "object", "integer", "floating", "datetime"], @@ -362,12 +359,17 @@ def test_read_write_dta10(self, version): with tm.ensure_clean() as path: original.to_stata(path, convert_dates={"datetime": "tc"}, version=version) written_and_read_again = self.read_dta(path) - # original.index is np.int32, read index is np.int64 - tm.assert_frame_equal( - written_and_read_again.set_index("index"), - original, - check_index_type=False, - ) + + expected = original.copy() + if using_infer_string: + expected["object"] = expected["object"].astype("str") + + # original.index is np.int32, read index is np.int64 + tm.assert_frame_equal( + written_and_read_again.set_index("index"), + expected, + check_index_type=False, + ) def test_stata_doc_examples(self): with tm.ensure_clean() as path: @@ -1153,7 +1155,6 @@ def test_categorical_ordering(self, file, datapath): assert parsed[col].cat.ordered assert not parsed_unordered[col].cat.ordered - @pytest.mark.xfail(using_string_dtype(), reason="TODO(infer_string)", strict=False) @pytest.mark.filterwarnings("ignore::UserWarning") @pytest.mark.parametrize( "file", @@ -1215,6 +1216,10 @@ def _convert_categorical(from_frame: DataFrame) -> DataFrame: if cat.categories.dtype == object: categories = pd.Index._with_infer(cat.categories._values) cat = cat.set_categories(categories) + elif cat.categories.dtype == "string" and len(cat.categories) == 0: + # if the read categories are empty, it comes back as object dtype + categories = cat.categories.astype(object) + cat = cat.set_categories(categories) from_frame[col] = cat return from_frame @@ -1244,7 +1249,6 @@ def test_iterator(self, datapath): from_chunks = pd.concat(itr) tm.assert_frame_equal(parsed, from_chunks) - @pytest.mark.xfail(using_string_dtype(), reason="TODO(infer_string)", strict=False) @pytest.mark.filterwarnings("ignore::UserWarning") @pytest.mark.parametrize( "file", @@ -1548,12 +1552,11 @@ def test_inf(self, infval): with tm.ensure_clean() as path: df.to_stata(path) - @pytest.mark.xfail(using_string_dtype(), reason="TODO(infer_string)") def test_path_pathlib(self): df = DataFrame( 1.1 * np.arange(120).reshape((30, 4)), - columns=pd.Index(list("ABCD"), dtype=object), - index=pd.Index([f"i-{i}" for i in range(30)], dtype=object), + columns=pd.Index(list("ABCD")), + index=pd.Index([f"i-{i}" for i in range(30)]), ) df.index.name = "index" reader = lambda x: read_stata(x).set_index("index") @@ -1584,13 +1587,12 @@ def test_value_labels_iterator(self, write_index): value_labels = dta_iter.value_labels() assert value_labels == {"A": {0: "A", 1: "B", 2: "C", 3: "E"}} - @pytest.mark.xfail(using_string_dtype(), reason="TODO(infer_string)") def test_set_index(self): # GH 17328 df = DataFrame( 1.1 * np.arange(120).reshape((30, 4)), - columns=pd.Index(list("ABCD"), dtype=object), - index=pd.Index([f"i-{i}" for i in range(30)], dtype=object), + columns=pd.Index(list("ABCD")), + index=pd.Index([f"i-{i}" for i in range(30)]), ) df.index.name = "index" with tm.ensure_clean() as path: @@ -1618,8 +1620,7 @@ def test_date_parsing_ignores_format_details(self, column, datapath): formatted = df.loc[0, column + "_fmt"] assert unformatted == formatted - @pytest.mark.xfail(using_string_dtype(), reason="TODO(infer_string)") - def test_writer_117(self): + def test_writer_117(self, using_infer_string): original = DataFrame( data=[ [ @@ -1682,13 +1683,17 @@ def test_writer_117(self): version=117, ) written_and_read_again = self.read_dta(path) - # original.index is np.int32, read index is np.int64 - tm.assert_frame_equal( - written_and_read_again.set_index("index"), - original, - check_index_type=False, - ) - tm.assert_frame_equal(original, copy) + + expected = original[:] + if using_infer_string: + # object dtype (with only strings/None) comes back as string dtype + expected["object"] = expected["object"].astype("str") + + tm.assert_frame_equal( + written_and_read_again.set_index("index"), + expected, + ) + tm.assert_frame_equal(original, copy) def test_convert_strl_name_swap(self): original = DataFrame( @@ -1725,15 +1730,14 @@ def test_invalid_date_conversion(self): with pytest.raises(ValueError, match=msg): original.to_stata(path, convert_dates={"wrong_name": "tc"}) - @pytest.mark.xfail(using_string_dtype(), reason="TODO(infer_string)") @pytest.mark.parametrize("version", [114, 117, 118, 119, None]) def test_nonfile_writing(self, version): # GH 21041 bio = io.BytesIO() df = DataFrame( 1.1 * np.arange(120).reshape((30, 4)), - columns=pd.Index(list("ABCD"), dtype=object), - index=pd.Index([f"i-{i}" for i in range(30)], dtype=object), + columns=pd.Index(list("ABCD")), + index=pd.Index([f"i-{i}" for i in range(30)]), ) df.index.name = "index" with tm.ensure_clean() as path: @@ -1744,13 +1748,12 @@ def test_nonfile_writing(self, version): reread = read_stata(path, index_col="index") tm.assert_frame_equal(df, reread) - @pytest.mark.xfail(using_string_dtype(), reason="TODO(infer_string)") def test_gzip_writing(self): # writing version 117 requires seek and cannot be used with gzip df = DataFrame( 1.1 * np.arange(120).reshape((30, 4)), - columns=pd.Index(list("ABCD"), dtype=object), - index=pd.Index([f"i-{i}" for i in range(30)], dtype=object), + columns=pd.Index(list("ABCD")), + index=pd.Index([f"i-{i}" for i in range(30)]), ) df.index.name = "index" with tm.ensure_clean() as path: @@ -1777,8 +1780,7 @@ def test_unicode_dta_118(self, datapath): tm.assert_frame_equal(unicode_df, expected) - @pytest.mark.xfail(using_string_dtype(), reason="TODO(infer_string)") - def test_mixed_string_strl(self): + def test_mixed_string_strl(self, using_infer_string): # GH 23633 output = [{"mixed": "string" * 500, "number": 0}, {"mixed": None, "number": 1}] output = DataFrame(output) @@ -1796,7 +1798,10 @@ def test_mixed_string_strl(self): path, write_index=False, convert_strl=["mixed"], version=117 ) reread = read_stata(path) - expected = output.fillna("") + expected = output.copy() + if using_infer_string: + expected["mixed"] = expected["mixed"].astype("str") + expected = expected.fillna("") tm.assert_frame_equal(reread, expected) @pytest.mark.parametrize("version", [114, 117, 118, 119, None]) @@ -1875,7 +1880,7 @@ def test_stata_119(self, datapath): reader._ensure_open() assert reader._nvar == 32999 - @pytest.mark.xfail(using_string_dtype(), reason="TODO(infer_string)") + @pytest.mark.filterwarnings("ignore:Downcasting behavior:FutureWarning") @pytest.mark.parametrize("version", [118, 119, None]) def test_utf8_writer(self, version): cat = pd.Categorical(["a", "β", "ĉ"], ordered=True) @@ -2143,14 +2148,13 @@ def test_iterator_errors(datapath, chunksize): pass -@pytest.mark.xfail(using_string_dtype(), reason="TODO(infer_string)") def test_iterator_value_labels(): # GH 31544 values = ["c_label", "b_label"] + ["a_label"] * 500 df = DataFrame({f"col{k}": pd.Categorical(values, ordered=True) for k in range(2)}) with tm.ensure_clean() as path: df.to_stata(path, write_index=False) - expected = pd.Index(["a_label", "b_label", "c_label"], dtype="object") + expected = pd.Index(["a_label", "b_label", "c_label"]) with read_stata(path, chunksize=100) as reader: for j, chunk in enumerate(reader): for i in range(2):