Skip to content

Commit

Permalink
[backport 2.3.x] BUG/TST (string dtype): fix and update tests for Sta…
Browse files Browse the repository at this point in the history
…ta IO (pandas-dev#60130) (pandas-dev#60155)

BUG/TST (string dtype): fix and update tests for Stata IO (pandas-dev#60130)

(cherry picked from commit e7d54a5)
  • Loading branch information
jorisvandenbossche authored Oct 31, 2024
1 parent fa7c87b commit e620e9d
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 39 deletions.
5 changes: 5 additions & 0 deletions pandas/io/stata.py
Original file line number Diff line number Diff line change
Expand Up @@ -605,7 +605,11 @@ def _cast_to_stata_types(data: DataFrame) -> DataFrame:
if getattr(data[col].dtype, "numpy_dtype", None) is not None:
data[col] = data[col].astype(data[col].dtype.numpy_dtype)
elif is_string_dtype(data[col].dtype):
# TODO could avoid converting string dtype to object here,
# but handle string dtype in _encode_strings
data[col] = data[col].astype("object")
# generate_table checks for None values
data.loc[data[col].isna(), col] = None

dtype = data[col].dtype
empty_df = data.shape[0] == 0
Expand Down Expand Up @@ -2671,6 +2675,7 @@ def _encode_strings(self) -> None:
continue
column = self.data[col]
dtype = column.dtype
# TODO could also handle string dtype here specifically
if dtype.type is np.object_:
inferred_dtype = infer_dtype(column, skipna=True)
if not ((inferred_dtype == "string") or len(column) == 0):
Expand Down
82 changes: 43 additions & 39 deletions pandas/tests/io/test_stata.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,6 @@
import numpy as np
import pytest

from pandas._config import using_string_dtype

import pandas.util._test_decorators as td

import pandas as pd
Expand Down Expand Up @@ -347,9 +345,8 @@ def test_write_dta6(self, datapath):
check_index_type=False,
)

@pytest.mark.xfail(using_string_dtype(), reason="TODO(infer_string)")
@pytest.mark.parametrize("version", [114, 117, 118, 119, None])
def test_read_write_dta10(self, version):
def test_read_write_dta10(self, version, using_infer_string):
original = DataFrame(
data=[["string", "object", 1, 1.1, np.datetime64("2003-12-25")]],
columns=["string", "object", "integer", "floating", "datetime"],
Expand All @@ -362,12 +359,17 @@ def test_read_write_dta10(self, version):
with tm.ensure_clean() as path:
original.to_stata(path, convert_dates={"datetime": "tc"}, version=version)
written_and_read_again = self.read_dta(path)
# original.index is np.int32, read index is np.int64
tm.assert_frame_equal(
written_and_read_again.set_index("index"),
original,
check_index_type=False,
)

expected = original.copy()
if using_infer_string:
expected["object"] = expected["object"].astype("str")

# original.index is np.int32, read index is np.int64
tm.assert_frame_equal(
written_and_read_again.set_index("index"),
expected,
check_index_type=False,
)

def test_stata_doc_examples(self):
with tm.ensure_clean() as path:
Expand Down Expand Up @@ -1153,7 +1155,6 @@ def test_categorical_ordering(self, file, datapath):
assert parsed[col].cat.ordered
assert not parsed_unordered[col].cat.ordered

@pytest.mark.xfail(using_string_dtype(), reason="TODO(infer_string)", strict=False)
@pytest.mark.filterwarnings("ignore::UserWarning")
@pytest.mark.parametrize(
"file",
Expand Down Expand Up @@ -1215,6 +1216,10 @@ def _convert_categorical(from_frame: DataFrame) -> DataFrame:
if cat.categories.dtype == object:
categories = pd.Index._with_infer(cat.categories._values)
cat = cat.set_categories(categories)
elif cat.categories.dtype == "string" and len(cat.categories) == 0:
# if the read categories are empty, it comes back as object dtype
categories = cat.categories.astype(object)
cat = cat.set_categories(categories)
from_frame[col] = cat
return from_frame

Expand Down Expand Up @@ -1244,7 +1249,6 @@ def test_iterator(self, datapath):
from_chunks = pd.concat(itr)
tm.assert_frame_equal(parsed, from_chunks)

@pytest.mark.xfail(using_string_dtype(), reason="TODO(infer_string)", strict=False)
@pytest.mark.filterwarnings("ignore::UserWarning")
@pytest.mark.parametrize(
"file",
Expand Down Expand Up @@ -1548,12 +1552,11 @@ def test_inf(self, infval):
with tm.ensure_clean() as path:
df.to_stata(path)

@pytest.mark.xfail(using_string_dtype(), reason="TODO(infer_string)")
def test_path_pathlib(self):
df = DataFrame(
1.1 * np.arange(120).reshape((30, 4)),
columns=pd.Index(list("ABCD"), dtype=object),
index=pd.Index([f"i-{i}" for i in range(30)], dtype=object),
columns=pd.Index(list("ABCD")),
index=pd.Index([f"i-{i}" for i in range(30)]),
)
df.index.name = "index"
reader = lambda x: read_stata(x).set_index("index")
Expand Down Expand Up @@ -1584,13 +1587,12 @@ def test_value_labels_iterator(self, write_index):
value_labels = dta_iter.value_labels()
assert value_labels == {"A": {0: "A", 1: "B", 2: "C", 3: "E"}}

@pytest.mark.xfail(using_string_dtype(), reason="TODO(infer_string)")
def test_set_index(self):
# GH 17328
df = DataFrame(
1.1 * np.arange(120).reshape((30, 4)),
columns=pd.Index(list("ABCD"), dtype=object),
index=pd.Index([f"i-{i}" for i in range(30)], dtype=object),
columns=pd.Index(list("ABCD")),
index=pd.Index([f"i-{i}" for i in range(30)]),
)
df.index.name = "index"
with tm.ensure_clean() as path:
Expand Down Expand Up @@ -1618,8 +1620,7 @@ def test_date_parsing_ignores_format_details(self, column, datapath):
formatted = df.loc[0, column + "_fmt"]
assert unformatted == formatted

@pytest.mark.xfail(using_string_dtype(), reason="TODO(infer_string)")
def test_writer_117(self):
def test_writer_117(self, using_infer_string):
original = DataFrame(
data=[
[
Expand Down Expand Up @@ -1682,13 +1683,17 @@ def test_writer_117(self):
version=117,
)
written_and_read_again = self.read_dta(path)
# original.index is np.int32, read index is np.int64
tm.assert_frame_equal(
written_and_read_again.set_index("index"),
original,
check_index_type=False,
)
tm.assert_frame_equal(original, copy)

expected = original[:]
if using_infer_string:
# object dtype (with only strings/None) comes back as string dtype
expected["object"] = expected["object"].astype("str")

tm.assert_frame_equal(
written_and_read_again.set_index("index"),
expected,
)
tm.assert_frame_equal(original, copy)

def test_convert_strl_name_swap(self):
original = DataFrame(
Expand Down Expand Up @@ -1725,15 +1730,14 @@ def test_invalid_date_conversion(self):
with pytest.raises(ValueError, match=msg):
original.to_stata(path, convert_dates={"wrong_name": "tc"})

@pytest.mark.xfail(using_string_dtype(), reason="TODO(infer_string)")
@pytest.mark.parametrize("version", [114, 117, 118, 119, None])
def test_nonfile_writing(self, version):
# GH 21041
bio = io.BytesIO()
df = DataFrame(
1.1 * np.arange(120).reshape((30, 4)),
columns=pd.Index(list("ABCD"), dtype=object),
index=pd.Index([f"i-{i}" for i in range(30)], dtype=object),
columns=pd.Index(list("ABCD")),
index=pd.Index([f"i-{i}" for i in range(30)]),
)
df.index.name = "index"
with tm.ensure_clean() as path:
Expand All @@ -1744,13 +1748,12 @@ def test_nonfile_writing(self, version):
reread = read_stata(path, index_col="index")
tm.assert_frame_equal(df, reread)

@pytest.mark.xfail(using_string_dtype(), reason="TODO(infer_string)")
def test_gzip_writing(self):
# writing version 117 requires seek and cannot be used with gzip
df = DataFrame(
1.1 * np.arange(120).reshape((30, 4)),
columns=pd.Index(list("ABCD"), dtype=object),
index=pd.Index([f"i-{i}" for i in range(30)], dtype=object),
columns=pd.Index(list("ABCD")),
index=pd.Index([f"i-{i}" for i in range(30)]),
)
df.index.name = "index"
with tm.ensure_clean() as path:
Expand All @@ -1777,8 +1780,7 @@ def test_unicode_dta_118(self, datapath):

tm.assert_frame_equal(unicode_df, expected)

@pytest.mark.xfail(using_string_dtype(), reason="TODO(infer_string)")
def test_mixed_string_strl(self):
def test_mixed_string_strl(self, using_infer_string):
# GH 23633
output = [{"mixed": "string" * 500, "number": 0}, {"mixed": None, "number": 1}]
output = DataFrame(output)
Expand All @@ -1796,7 +1798,10 @@ def test_mixed_string_strl(self):
path, write_index=False, convert_strl=["mixed"], version=117
)
reread = read_stata(path)
expected = output.fillna("")
expected = output.copy()
if using_infer_string:
expected["mixed"] = expected["mixed"].astype("str")
expected = expected.fillna("")
tm.assert_frame_equal(reread, expected)

@pytest.mark.parametrize("version", [114, 117, 118, 119, None])
Expand Down Expand Up @@ -1875,7 +1880,7 @@ def test_stata_119(self, datapath):
reader._ensure_open()
assert reader._nvar == 32999

@pytest.mark.xfail(using_string_dtype(), reason="TODO(infer_string)")
@pytest.mark.filterwarnings("ignore:Downcasting behavior:FutureWarning")
@pytest.mark.parametrize("version", [118, 119, None])
def test_utf8_writer(self, version):
cat = pd.Categorical(["a", "β", "ĉ"], ordered=True)
Expand Down Expand Up @@ -2143,14 +2148,13 @@ def test_iterator_errors(datapath, chunksize):
pass


@pytest.mark.xfail(using_string_dtype(), reason="TODO(infer_string)")
def test_iterator_value_labels():
# GH 31544
values = ["c_label", "b_label"] + ["a_label"] * 500
df = DataFrame({f"col{k}": pd.Categorical(values, ordered=True) for k in range(2)})
with tm.ensure_clean() as path:
df.to_stata(path, write_index=False)
expected = pd.Index(["a_label", "b_label", "c_label"], dtype="object")
expected = pd.Index(["a_label", "b_label", "c_label"])
with read_stata(path, chunksize=100) as reader:
for j, chunk in enumerate(reader):
for i in range(2):
Expand Down

0 comments on commit e620e9d

Please sign in to comment.