Skip to content

Commit

Permalink
String dtype: enable in SQL IO + resolve all xfails (#60255)
Browse files — browse the repository at this point in the history
  • Loading branch information
WillAyd authored Nov 14, 2024
1 parent 61f800d commit ba4d1cf
Show file tree
Hide file tree
Showing 4 changed files with 37 additions and 14 deletions.
2 changes: 2 additions & 0 deletions pandas/core/dtypes/cast.py
Original file line number Diff line number Diff line change
Expand Up @@ -1162,6 +1162,7 @@ def convert_dtypes(

def maybe_infer_to_datetimelike(
value: npt.NDArray[np.object_],
convert_to_nullable_dtype: bool = False,
) -> np.ndarray | DatetimeArray | TimedeltaArray | PeriodArray | IntervalArray:
"""
we might have an array (or single object) that is datetime like,
Expand Down Expand Up @@ -1199,6 +1200,7 @@ def maybe_infer_to_datetimelike(
# numpy would have done it for us.
convert_numeric=False,
convert_non_numeric=True,
convert_to_nullable_dtype=convert_to_nullable_dtype,
dtype_if_all_nat=np.dtype("M8[s]"),
)

Expand Down
5 changes: 3 additions & 2 deletions pandas/core/internals/construction.py
Original file line number Diff line number Diff line change
Expand Up @@ -966,8 +966,9 @@ def convert(arr):
if dtype is None:
if arr.dtype == np.dtype("O"):
# i.e. maybe_convert_objects didn't convert
arr = maybe_infer_to_datetimelike(arr)
if dtype_backend != "numpy" and arr.dtype == np.dtype("O"):
convert_to_nullable_dtype = dtype_backend != "numpy"
arr = maybe_infer_to_datetimelike(arr, convert_to_nullable_dtype)
if convert_to_nullable_dtype and arr.dtype == np.dtype("O"):
new_dtype = StringDtype()
arr_cls = new_dtype.construct_array_type()
arr = arr_cls._from_sequence(arr, dtype=new_dtype)
Expand Down
21 changes: 19 additions & 2 deletions pandas/io/sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,8 @@
from pandas.core.dtypes.common import (
is_dict_like,
is_list_like,
is_object_dtype,
is_string_dtype,
)
from pandas.core.dtypes.dtypes import (
ArrowDtype,
Expand All @@ -58,6 +60,7 @@
Series,
)
from pandas.core.arrays import ArrowExtensionArray
from pandas.core.arrays.string_ import StringDtype
from pandas.core.base import PandasObject
import pandas.core.common as com
from pandas.core.common import maybe_make_list
Expand Down Expand Up @@ -1316,7 +1319,12 @@ def _harmonize_columns(
elif dtype_backend == "numpy" and col_type is float:
# floats support NA, can always convert!
self.frame[col_name] = df_col.astype(col_type)

elif (
using_string_dtype()
and is_string_dtype(col_type)
and is_object_dtype(self.frame[col_name])
):
self.frame[col_name] = df_col.astype(col_type)
elif dtype_backend == "numpy" and len(df_col) == df_col.count():
# No NA values, can convert ints and bools
if col_type is np.dtype("int64") or col_type is bool:
Expand Down Expand Up @@ -1403,6 +1411,7 @@ def _get_dtype(self, sqltype):
DateTime,
Float,
Integer,
String,
)

if isinstance(sqltype, Float):
Expand All @@ -1422,6 +1431,10 @@ def _get_dtype(self, sqltype):
return date
elif isinstance(sqltype, Boolean):
return bool
elif isinstance(sqltype, String):
if using_string_dtype():
return StringDtype(na_value=np.nan)

return object


Expand Down Expand Up @@ -2205,7 +2218,7 @@ def read_table(
elif using_string_dtype():
from pandas.io._util import arrow_string_types_mapper

arrow_string_types_mapper()
mapping = arrow_string_types_mapper()
else:
mapping = None

Expand Down Expand Up @@ -2286,6 +2299,10 @@ def read_query(
from pandas.io._util import _arrow_dtype_mapping

mapping = _arrow_dtype_mapping().get
elif using_string_dtype():
from pandas.io._util import arrow_string_types_mapper

mapping = arrow_string_types_mapper()
else:
mapping = None

Expand Down
23 changes: 13 additions & 10 deletions pandas/tests/io/test_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@
pytest.mark.filterwarnings(
"ignore:Passing a BlockManager to DataFrame:DeprecationWarning"
),
pytest.mark.xfail(using_string_dtype(), reason="TODO(infer_string)", strict=False),
pytest.mark.single_cpu,
]


Expand Down Expand Up @@ -685,6 +685,7 @@ def postgresql_psycopg2_conn(postgresql_psycopg2_engine):

@pytest.fixture
def postgresql_adbc_conn():
pytest.importorskip("pyarrow")
pytest.importorskip("adbc_driver_postgresql")
from adbc_driver_postgresql import dbapi

Expand Down Expand Up @@ -817,6 +818,7 @@ def sqlite_conn_types(sqlite_engine_types):

@pytest.fixture
def sqlite_adbc_conn():
pytest.importorskip("pyarrow")
pytest.importorskip("adbc_driver_sqlite")
from adbc_driver_sqlite import dbapi

Expand Down Expand Up @@ -986,13 +988,13 @@ def test_dataframe_to_sql(conn, test_frame1, request):

@pytest.mark.parametrize("conn", all_connectable)
def test_dataframe_to_sql_empty(conn, test_frame1, request):
if conn == "postgresql_adbc_conn":
if conn == "postgresql_adbc_conn" and not using_string_dtype():
request.node.add_marker(
pytest.mark.xfail(
reason="postgres ADBC driver cannot insert index with null type",
strict=True,
reason="postgres ADBC driver < 1.2 cannot insert index with null type",
)
)

# GH 51086 if conn is sqlite_engine
conn = request.getfixturevalue(conn)
empty_df = test_frame1.iloc[:0]
Expand Down Expand Up @@ -3557,7 +3559,8 @@ def test_read_sql_dtype_backend(
result = getattr(pd, func)(
f"Select * from {table}", conn, dtype_backend=dtype_backend
)
expected = dtype_backend_expected(string_storage, dtype_backend, conn_name)
expected = dtype_backend_expected(string_storage, dtype_backend, conn_name)

tm.assert_frame_equal(result, expected)

if "adbc" in conn_name:
Expand Down Expand Up @@ -3607,7 +3610,7 @@ def test_read_sql_dtype_backend_table(

with pd.option_context("mode.string_storage", string_storage):
result = getattr(pd, func)(table, conn, dtype_backend=dtype_backend)
expected = dtype_backend_expected(string_storage, dtype_backend, conn_name)
expected = dtype_backend_expected(string_storage, dtype_backend, conn_name)
tm.assert_frame_equal(result, expected)

if "adbc" in conn_name:
Expand Down Expand Up @@ -4123,7 +4126,7 @@ def tquery(query, con=None):
def test_xsqlite_basic(sqlite_buildin):
frame = DataFrame(
np.random.default_rng(2).standard_normal((10, 4)),
columns=Index(list("ABCD"), dtype=object),
columns=Index(list("ABCD")),
index=date_range("2000-01-01", periods=10, freq="B"),
)
assert sql.to_sql(frame, name="test_table", con=sqlite_buildin, index=False) == 10
Expand All @@ -4150,7 +4153,7 @@ def test_xsqlite_basic(sqlite_buildin):
def test_xsqlite_write_row_by_row(sqlite_buildin):
frame = DataFrame(
np.random.default_rng(2).standard_normal((10, 4)),
columns=Index(list("ABCD"), dtype=object),
columns=Index(list("ABCD")),
index=date_range("2000-01-01", periods=10, freq="B"),
)
frame.iloc[0, 0] = np.nan
Expand All @@ -4173,7 +4176,7 @@ def test_xsqlite_write_row_by_row(sqlite_buildin):
def test_xsqlite_execute(sqlite_buildin):
frame = DataFrame(
np.random.default_rng(2).standard_normal((10, 4)),
columns=Index(list("ABCD"), dtype=object),
columns=Index(list("ABCD")),
index=date_range("2000-01-01", periods=10, freq="B"),
)
create_sql = sql.get_schema(frame, "test")
Expand All @@ -4194,7 +4197,7 @@ def test_xsqlite_execute(sqlite_buildin):
def test_xsqlite_schema(sqlite_buildin):
frame = DataFrame(
np.random.default_rng(2).standard_normal((10, 4)),
columns=Index(list("ABCD"), dtype=object),
columns=Index(list("ABCD")),
index=date_range("2000-01-01", periods=10, freq="B"),
)
create_sql = sql.get_schema(frame, "test")
Expand Down

0 comments on commit ba4d1cf

Please sign in to comment.