diff --git a/CHANGELOG.md b/CHANGELOG.md index 5bd6e0146..dc26e75cb 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,6 +2,9 @@ ## 0.10.16dev +* [Fix] Updates docs for querying data frames when using DuckDB SQLAlchemy connections +* [Fix] Support for scanning data frames when using native DuckDB connections due to changes in DuckDB's API + ## 0.10.15 (2024-11-05) *Drops compatibility with Python 3.8* diff --git a/doc/integrations/duckdb.md b/doc/integrations/duckdb.md index fe45d0b05..6e09d9c95 100644 --- a/doc/integrations/duckdb.md +++ b/doc/integrations/duckdb.md @@ -265,6 +265,14 @@ df = pd.DataFrame({"x": range(100)}) %sql engine ``` +```{important} +If you're using DuckDB 1.1.0 or higher, you must run this before querying a data frame + +~~~sql +%sql SET python_scan_all_frames=true +~~~ +``` + ```{code-cell} ipython3 %%sql SELECT * diff --git a/src/sql/connection/connection.py b/src/sql/connection/connection.py index c684427a4..6c1c1d525 100644 --- a/src/sql/connection/connection.py +++ b/src/sql/connection/connection.py @@ -991,9 +991,9 @@ class DBAPIConnection(AbstractConnection): def __init__(self, connection, alias=None, config=None): # detect if the engine is a native duckdb connection - _is_duckdb_native = _check_if_duckdb_dbapi_connection(connection) + self._is_duckdb_native = _check_if_duckdb_dbapi_connection(connection) - self._dialect = "duckdb" if _is_duckdb_native else None + self._dialect = "duckdb" if self._is_duckdb_native else None self._driver = None # TODO: implement the dialect blacklist and add unit tests @@ -1038,6 +1038,15 @@ def raw_execute(self, query, parameters=None, with_=None): query = self._resolve_cte(query, with_) cur = self._connection.cursor() + + # NOTE: this is a workaround for duckdb 1.1.0 and higher so we keep the + # existing behavior of being able to query data frames + if self._is_duckdb_native: + try: + cur.execute("SET python_scan_all_frames=true") + except Exception: + pass + cur.execute(query) if self._requires_manual_commit: diff --git a/src/tests/integration/test_duckDB.py b/src/tests/integration/test_duckDB.py index d0c460e25..4a0027212 100644 --- a/src/tests/integration/test_duckDB.py +++ b/src/tests/integration/test_duckDB.py @@ -235,3 +235,18 @@ def test_commits_all_statements(ip, sql, request): out = ip.run_cell(sql) assert out.error_in_exec is None assert out.result.dict() == {"x": (1, 2)} + + +@pytest.mark.parametrize( + "ip", + [ + "ip_with_duckdb_native_empty", + "ip_with_duckdb_sqlalchemy_empty", + ], +) +def test_can_query_existing_df(ip, request): + ip = request.getfixturevalue(ip) + df = pd.DataFrame({"city": ["NYC"]}) # noqa + ip.run_cell("%sql SET python_scan_all_frames=true") + out = ip.run_cell("%sql SELECT * FROM df;") + assert out.result.dict() == {"city": ("NYC",)} diff --git a/src/tests/test_connection.py b/src/tests/test_connection.py index 8369ff690..ae1e0e4a8 100644 --- a/src/tests/test_connection.py +++ b/src/tests/test_connection.py @@ -888,6 +888,10 @@ def mock_dbapi_raw_execute(monkeypatch, conn_dbapi_duckdb): def test_raw_execute_doesnt_transpile_sql_query(fixture_name, request): mock_execute, conn = request.getfixturevalue(fixture_name) + # to prevent the "SET python_scan_all_frames=true" call, since we don't want to + # test that here + conn._is_duckdb_native = False + conn.raw_execute("CREATE TABLE foo (bar INT)") conn.raw_execute("INSERT INTO foo VALUES (42), (43)") conn.raw_execute("SELECT * FROM foo LIMIT 1") @@ -949,6 +953,10 @@ def mock_dbapi_execute(monkeypatch): def test_execute_transpiles_sql_query(fixture_name, request): mock_execute, conn = request.getfixturevalue(fixture_name) + # to prevent the "SET python_scan_all_frames=true" call, since we don't want to + # test that here + conn._is_duckdb_native = False + conn.execute("CREATE TABLE foo (bar INT)") conn.execute("INSERT INTO foo VALUES (42), (43)") conn.execute("SELECT * FROM foo LIMIT 1")