Commit 3577a59: cleanups

WillAyd committed Oct 2, 2023
1 parent 2755100 commit 3577a59

Showing 2 changed files with 30 additions and 73 deletions.

pandas/io/sql.py (2 additions, 2 deletions)
@@ -2299,8 +2299,8 @@ def to_sql(
         else:
             table_name = name

-        # TODO: pandas if_exists="append" will still create the
-        # table if it does not exist; ADBC has append/create
+        # pandas if_exists="append" will still create the
+        # table if it does not exist; ADBC is more explicit with append/create
         # as applicable modes, so the semantics get blurred across
         # the libraries
         mode = "create"
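
The comment in this hunk captures the key design tension in the ADBC to_sql path: pandas' if_exists="append" silently creates a missing table, while ADBC ingestion is explicit about create versus append. A minimal sketch of those semantics outside pandas, assuming the adbc-driver-sqlite package is installed (table and column names here are illustrative):

    import pyarrow as pa
    from adbc_driver_sqlite import dbapi

    data = pa.table({"a": [1, 2, 3]})
    with dbapi.connect() as conn, conn.cursor() as cur:
        # ADBC mode="create" errors if the table already exists;
        # mode="append" errors if it does not exist yet.
        cur.adbc_ingest("demo_table", data, mode="create")
        # Emulating pandas if_exists="append" therefore means checking for the
        # table first and only then selecting mode="append".
        cur.adbc_ingest("demo_table", data, mode="append")
        conn.commit()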

pandas/tests/io/test_sql.py (28 additions, 71 deletions)
@@ -141,50 +141,29 @@ def create_and_load_iris_sqlite3(conn, iris_file: Path):
             "Name" TEXT
         )"""

-    if isinstance(conn, sqlite3.Connection):
-        cur = conn.cursor()
-        cur.execute(stmt)
-        with iris_file.open(newline=None, encoding="utf-8") as csvfile:
-            reader = csv.reader(csvfile)
-            next(reader)
-            stmt = "INSERT INTO iris VALUES($1, $2, $3, $4, $5)"
-            # ADBC requires explicit types - no implicit str -> float conversion
-            records = []
-            records = [
-                (
-                    float(row[0]),
-                    float(row[1]),
-                    float(row[2]),
-                    float(row[3]),
-                    row[4],
-                )
-                for row in reader
-            ]
+    cur = conn.cursor()
+    cur.execute(stmt)
+    with iris_file.open(newline=None, encoding="utf-8") as csvfile:
+        reader = csv.reader(csvfile)
+        next(reader)
+        stmt = "INSERT INTO iris VALUES($1, $2, $3, $4, $5)"
+        # ADBC requires explicit types - no implicit str -> float conversion
+        records = []
+        records = [
+            (
+                float(row[0]),
+                float(row[1]),
+                float(row[2]),
+                float(row[3]),
+                row[4],
+            )
+            for row in reader
+        ]

-            cur.executemany(stmt, records)
-    else:
-        with conn.cursor() as cur:
-            cur.execute(stmt)
-            with iris_file.open(newline=None, encoding="utf-8") as csvfile:
-                reader = csv.reader(csvfile)
-                next(reader)
-                stmt = "INSERT INTO iris VALUES($1, $2, $3, $4, $5)"
-                # ADBC requires explicit types - no implicit str -> float conversion
-                records = []
-                records = [
-                    (
-                        float(row[0]),
-                        float(row[1]),
-                        float(row[2]),
-                        float(row[3]),
-                        row[4],
-                    )
-                    for row in reader
-                ]
-
-                cur.executemany(stmt, records)
+        cur.executemany(stmt, records)
+    cur.close()

-    conn.commit()
+    conn.commit()


 def create_and_load_iris_postgresql(conn, iris_file: Path):
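
A note on the rewritten helper above: both sqlite3 and ADBC connections expose cursor(), execute(), and executemany(), so the isinstance branches collapse into one path; the old else branch used the cursor as a context manager, which sqlite3 cursors do not support, hence the explicit cur.close(). The float() calls matter because SQLite's type affinity coerces numeric-looking strings on insert, while ADBC binds typed parameters and performs no implicit str -> float conversion. A sqlite3-only sketch of the coercion the comment refers to:

    import sqlite3

    conn = sqlite3.connect(":memory:")
    cur = conn.cursor()
    cur.execute('CREATE TABLE t ("SepalLength" REAL)')
    # sqlite3 accepts the raw string; SQLite's REAL affinity converts it on insert.
    cur.execute("INSERT INTO t VALUES(?)", ("5.1",))
    print(cur.execute('SELECT "SepalLength", typeof("SepalLength") FROM t').fetchone())
    # (5.1, 'real') -- an ADBC driver would instead require float("5.1") up front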
@@ -320,8 +299,6 @@ def create_and_load_types_sqlite3(conn, types_data: list[dict]):
 def create_and_load_types_postgresql(conn, types_data: list[dict]):
     # Boolean support not added until 0.8.0
     adbc = import_optional_dependency("adbc_driver_manager")
-    from pandas.util.version import Version
-
     if Version(adbc.__version__) < Version("0.8.0"):
         bool_type = "INTEGER"
     else:
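
The two deleted lines here (and the matching deletions in the hunks below) drop a function-local import of Version, which only works if the commit also keeps Version imported at module scope. The gating pattern itself is unchanged; a self-contained sketch of it, using pandas' internal helpers the same way this file does:

    from pandas.compat._optional import import_optional_dependency
    from pandas.util.version import Version

    adbc = import_optional_dependency("adbc_driver_manager")
    # Per the comment above, boolean support arrived in adbc-driver-manager
    # 0.8.0, so older drivers fall back to an INTEGER column.
    if Version(adbc.__version__) < Version("0.8.0"):
        bool_type = "INTEGER"
    else:
        bool_type = "BOOLEAN"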
@@ -684,8 +661,6 @@ def postgresql_adbc_conn(iris_path, types_data):
             conn.rollback()
             # Boolean support not added until 0.8.0
             adbc = import_optional_dependency("adbc_driver_manager")
-            from pandas.util.version import Version
-
             if Version(adbc.__version__) < Version("0.8.0"):
                 new_data = []
                 for entry in types_data:
@@ -1589,7 +1564,6 @@ def test_api_to_sql_append(conn, request, test_frame1):

 @pytest.mark.parametrize("conn", all_connectable)
 def test_api_to_sql_type_mapping(conn, request, test_frame3):
-    conn_name = conn
     conn = request.getfixturevalue(conn)
     if sql.has_table("test_frame5", conn):
         with sql.SQLDatabase(conn, need_transaction=True) as pandasSQL:
@@ -1598,9 +1572,6 @@ def test_api_to_sql_type_mapping(conn, request, test_frame3):
     sql.to_sql(test_frame3, "test_frame5", conn, index=False)
     result = sql.read_sql("SELECT * FROM test_frame5", conn)

-    if conn_name == "postgresql_adbc_conn":
-        # postgresql driver does not maintain capitalization
-        result.columns = ["index", "A", "B"]
     tm.assert_frame_equal(test_frame3, result)


@@ -2223,13 +2194,9 @@ def test_api_escaped_table_name(conn, request):
 @pytest.mark.parametrize("conn", all_connectable)
 def test_api_read_sql_duplicate_columns(conn, request):
     # GH#53117
-    if conn == "postgresql_adbc_conn":
-        request.node.add_marker(
-            pytest.mark.xfail(reason="fails with syntax error", strict=True)
-        )
-    elif conn == "sqlite_adbc_conn":
+    if "adbc" in conn:
         request.node.add_marker(
-            pytest.mark.xfail(reason="fails with ValueError", strict=True)
+            pytest.mark.xfail(reason="pyarrow->pandas throws ValueError", strict=True)
         )
     conn = request.getfixturevalue(conn)
     if sql.has_table("test_table", conn):
@@ -2239,7 +2206,7 @@ def test_api_read_sql_duplicate_columns(conn, request):
     df = DataFrame({"a": [1, 2, 3], "b": [0.1, 0.2, 0.3], "c": 1})
     df.to_sql(name="test_table", con=conn, index=False)

-    result = pd.read_sql("SELECT a, b, a +1 as a, c FROM test_table;", conn)
+    result = pd.read_sql("SELECT a, b, a +1 as a, c FROM test_table", conn)
     expected = DataFrame(
         [[1, 0.1, 2, 1], [2, 0.2, 3, 1], [3, 0.3, 4, 1]],
         columns=["a", "b", "a", "c"],
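
The consolidated xfail above reflects that both ADBC backends now fail at the same point: converting the Arrow result set to pandas. A minimal reproduction of that conversion step, assuming pyarrow is installed (whether it raises depends on the pyarrow version in use):

    import pyarrow as pa

    # Shape of the result of "SELECT a, b, a +1 as a, c ...": two columns named "a".
    tbl = pa.table(
        [pa.array([1, 2, 3]), pa.array([0.1, 0.2, 0.3]), pa.array([2, 3, 4])],
        names=["a", "b", "a"],
    )
    df = tbl.to_pandas()  # per the xfail reason, this is where ValueError surfaces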
@@ -2629,8 +2596,6 @@ def test_roundtrip(conn, request, test_frame1):

     if "adbc" in conn_name:
         result = result.rename(columns={"__index_level_0__": "level_0"})
-        if conn_name == "postgresql_adbc_conn":
-            result = result.rename(columns={"a": "A", "b": "B", "c": "C", "d": "D"})
     result.set_index("level_0", inplace=True)
     # result.index.astype(int)

@@ -3050,7 +3015,9 @@ def test_nan_string(conn, request):
 def test_to_sql_save_index(conn, request):
     if "adbc" in conn:
         request.node.add_marker(
-            pytest.mark.xfail(reason="not working with ADBC drivers", strict=True)
+            pytest.mark.xfail(
+                reason="ADBC implementation does not create index", strict=True
+            )
         )
     conn_name = conn
     conn = request.getfixturevalue(conn)
@@ -3063,7 +3030,7 @@ def test_to_sql_save_index(conn, request):
     with pandasSQL.run_transaction():
         assert pandasSQL.to_sql(df, tbl_name) == 2

-    if conn_name in {"sqlite_buildin", "sqlite_str"} or "adbc" in conn_name:
+    if conn_name in {"sqlite_buildin", "sqlite_str"}:
         ixs = sql.read_sql_query(
             "SELECT * FROM sqlite_master WHERE type = 'index' "
             f"AND tbl_name = '{tbl_name}'",
@@ -3516,8 +3483,6 @@ def test_read_sql_dtype_backend(
     if "adbc" in conn_name:
         # Boolean support not added until 0.8.0
         adbc = import_optional_dependency("adbc_driver_manager")
-        from pandas.util.version import Version
-
         if Version(adbc.__version__) < Version("0.8.0"):
             df = df.drop(columns=["e", "f"])
     df.to_sql(name=table, con=conn, index=False, if_exists="replace")
@@ -3529,8 +3494,6 @@ def test_read_sql_dtype_backend(
     expected = dtype_backend_expected(string_storage, dtype_backend, conn_name)
     if "adbc" in conn_name:
         adbc = import_optional_dependency("adbc_driver_manager")
-        from pandas.util.version import Version
-
         if Version(adbc.__version__) < Version("0.8.0"):
             expected = expected.drop(columns=["e", "f"])
     tm.assert_frame_equal(result, expected)
@@ -3578,8 +3541,6 @@ def test_read_sql_dtype_backend_table(
     df = dtype_backend_data
     if "adbc" in conn_name:
         adbc = import_optional_dependency("adbc_driver_manager")
-        from pandas.util.version import Version
-
         if Version(adbc.__version__) < Version("0.8.0"):
             df = df.drop(columns=["e", "f"])
     df.to_sql(name=table, con=conn, index=False, if_exists="replace")
@@ -3589,8 +3550,6 @@ def test_read_sql_dtype_backend_table(
     expected = dtype_backend_expected(string_storage, dtype_backend, conn_name)
     if "adbc" in conn_name:
         adbc = import_optional_dependency("adbc_driver_manager")
-        from pandas.util.version import Version
-
         if Version(adbc.__version__) < Version("0.8.0"):
             expected = expected.drop(columns=["e", "f"])
     tm.assert_frame_equal(result, expected)
@@ -3620,8 +3579,6 @@ def test_read_sql_invalid_dtype_backend_table(conn, request, func, dtype_backend
     df = dtype_backend_data
     if "adbc" in conn_name:
         adbc = import_optional_dependency("adbc_driver_manager")
-        from pandas.util.version import Version
-
         if Version(adbc.__version__) < Version("0.8.0"):
             df = df.drop(columns=["e", "f"])
     df.to_sql(name=table, con=conn, index=False, if_exists="replace")