From 478b75c52999ec029a3e1585b9279323b274b02d Mon Sep 17 00:00:00 2001 From: Eduardo Blancas Date: Thu, 7 Nov 2024 09:31:05 -0600 Subject: [PATCH] fix --- src/sql/connection/connection.py | 8 ++++++-- src/tests/integration/test_duckDB.py | 14 +++++++++++--- 2 files changed, 17 insertions(+), 5 deletions(-) diff --git a/src/sql/connection/connection.py b/src/sql/connection/connection.py index c684427a4..0c3ff604f 100644 --- a/src/sql/connection/connection.py +++ b/src/sql/connection/connection.py @@ -991,9 +991,9 @@ class DBAPIConnection(AbstractConnection): def __init__(self, connection, alias=None, config=None): # detect if the engine is a native duckdb connection - _is_duckdb_native = _check_if_duckdb_dbapi_connection(connection) + self._is_duckdb_native = _check_if_duckdb_dbapi_connection(connection) - self._dialect = "duckdb" if _is_duckdb_native else None + self._dialect = "duckdb" if self._is_duckdb_native else None self._driver = None # TODO: implement the dialect blacklist and add unit tests @@ -1038,6 +1038,10 @@ def raw_execute(self, query, parameters=None, with_=None): query = self._resolve_cte(query, with_) cur = self._connection.cursor() + + if self._is_duckdb_native: + cur.execute("SET python_scan_all_frames=true") + cur.execute(query) if self._requires_manual_commit: diff --git a/src/tests/integration/test_duckDB.py b/src/tests/integration/test_duckDB.py index c17683aa6..4a0027212 100644 --- a/src/tests/integration/test_duckDB.py +++ b/src/tests/integration/test_duckDB.py @@ -237,8 +237,16 @@ def test_commits_all_statements(ip, sql, request): assert out.result.dict() == {"x": (1, 2)} -def test_can_query_existing_df(ip_with_duckdb_sqlalchemy_empty): +@pytest.mark.parametrize( + "ip", + [ + "ip_with_duckdb_native_empty", + "ip_with_duckdb_sqlalchemy_empty", + ], +) +def test_can_query_existing_df(ip, request): + ip = request.getfixturevalue(ip) df = pd.DataFrame({"city": ["NYC"]}) # noqa - ip_with_duckdb_sqlalchemy_empty.run_cell("%sql SET python_scan_all_frames=true") - out = ip_with_duckdb_sqlalchemy_empty.run_cell("%sql SELECT * FROM df;") + ip.run_cell("%sql SET python_scan_all_frames=true") + out = ip.run_cell("%sql SELECT * FROM df;") assert out.result.dict() == {"city": ("NYC",)}