diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index 7472e7c9..69ffc80b 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -23,6 +23,14 @@ jobs: # Do not tear down Testcontainers TC_KEEPALIVE: true + # https://docs.github.com/en/actions/using-containerized-services/about-service-containers + services: + cratedb: + image: crate/crate:nightly + ports: + - 4200:4200 + - 5432:5432 + name: Python ${{ matrix.python-version }} on OS ${{ matrix.os }} steps: diff --git a/.gitignore b/.gitignore index d4bde8fd..0f24e1a4 100644 --- a/.gitignore +++ b/.gitignore @@ -3,6 +3,7 @@ *.egg-info .eggs __pycache__ +*.pyc dist .coverage coverage.xml diff --git a/CHANGES.md b/CHANGES.md index e032df26..853f9c95 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -3,6 +3,7 @@ ## Unreleased +- Add SQL runner utility primitives to `io.sql` namespace ## 2023/11/06 v0.0.2 diff --git a/cratedb_toolkit/io/__init__.py b/cratedb_toolkit/io/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/cratedb_toolkit/io/sql.py b/cratedb_toolkit/io/sql.py new file mode 100644 index 00000000..01441afe --- /dev/null +++ b/cratedb_toolkit/io/sql.py @@ -0,0 +1 @@ +from cratedb_toolkit.util.database import DatabaseAdapter, run_sql # noqa: F401 diff --git a/cratedb_toolkit/util/database.py b/cratedb_toolkit/util/database.py index 999e8ba5..7ef709ac 100644 --- a/cratedb_toolkit/util/database.py +++ b/cratedb_toolkit/util/database.py @@ -1,6 +1,11 @@ # Copyright (c) 2023, Crate.io Inc. # Distributed under the terms of the AGPLv3 license, see LICENSE. +import io +import typing as t +from pathlib import Path + import sqlalchemy as sa +import sqlparse from sqlalchemy.exc import ProgrammingError from sqlalchemy.sql.elements import AsBoolean @@ -22,12 +27,23 @@ def __init__(self, dburi: str): self.engine = sa.create_engine(self.dburi, echo=False) self.connection = self.engine.connect() - def run_sql(self, sql: str, records: bool = False, ignore: str = None): + def run_sql(self, sql: t.Union[str, Path, io.IOBase], records: bool = False, ignore: str = None): """ Run SQL statement, and return results, optionally ignoring exceptions. """ + + sql_effective: str + if isinstance(sql, str): + sql_effective = sql + elif isinstance(sql, Path): + sql_effective = sql.read_text() + elif isinstance(sql, io.IOBase): + sql_effective = sql.read() + else: + raise TypeError("SQL statement type must be either string, Path, or IO handle") + try: - return self.run_sql_real(sql=sql, records=records) + return self.run_sql_real(sql=sql_effective, records=records) except Exception as ex: if not ignore: raise @@ -38,12 +54,23 @@ def run_sql_real(self, sql: str, records: bool = False): """ Invoke SQL statement, and return results. """ - result = self.connection.execute(sa.text(sql)) - if records: - rows = result.mappings().fetchall() - return [dict(row.items()) for row in rows] + results = [] + with self.engine.connect() as connection: + for statement in sqlparse.split(sql): + result = connection.execute(sa.text(statement)) + data: t.Any + if records: + rows = result.mappings().fetchall() + data = [dict(row.items()) for row in rows] + else: + data = result.fetchall() + results.append(data) + + # Backward-compatibility. + if len(results) == 1: + return results[0] else: - return result.fetchall() + return results def count_records(self, tablename_full: str): """ diff --git a/pyproject.toml b/pyproject.toml index 4a8603d9..49d257a9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -91,6 +91,7 @@ dependencies = [ "crash", "crate[sqlalchemy]>=0.34", "sqlalchemy>=2", + "sqlparse<0.5", ] [project.optional-dependencies] develop = [ diff --git a/tests/io/__init__.py b/tests/io/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/io/test_sql.py b/tests/io/test_sql.py new file mode 100644 index 00000000..8659389b --- /dev/null +++ b/tests/io/test_sql.py @@ -0,0 +1,66 @@ +import io + +import pytest +import sqlalchemy as sa + +from cratedb_toolkit.io.sql import run_sql + + +@pytest.fixture +def sqlcmd(cratedb): + def workhorse(sql: str, records: bool = True): + return run_sql(dburi=cratedb.database.dburi, sql=sql, records=records) + + return workhorse + + +def test_run_sql_direct(cratedb): + sql_string = "SELECT 1;" + outcome = run_sql(dburi=cratedb.database.dburi, sql=sql_string, records=True) + assert outcome == [{"1": 1}] + + +def test_run_sql_from_string(sqlcmd): + sql_string = "SELECT 1;" + outcome = sqlcmd(sql_string) + assert outcome == [{"1": 1}] + + +def test_run_sql_from_file(sqlcmd, tmp_path): + sql_file = tmp_path / "temp.sql" + sql_file.write_text("SELECT 1;") + outcome = sqlcmd(sql_file) + assert outcome == [{"1": 1}] + + +def test_run_sql_from_buffer(sqlcmd): + sql_buffer = io.StringIO("SELECT 1;") + outcome = sqlcmd(sql_buffer) + assert outcome == [{"1": 1}] + + +def test_run_sql_no_records(sqlcmd): + sql_string = "SELECT 1;" + outcome = sqlcmd(sql_string, records=False) + assert outcome == [(1,)] + + +def test_run_sql_multiple_statements(sqlcmd): + sql_string = "SELECT 1; SELECT 42;" + outcome = sqlcmd(sql_string) + assert outcome == [[{"1": 1}], [{"42": 42}]] + + +def test_run_sql_invalid_host(capsys): + with pytest.raises(sa.exc.OperationalError) as ex: + run_sql(dburi="crate://localhost:12345", sql="SELECT 1;") + assert ex.match( + ".*ConnectionError.*No more Servers available.*HTTPConnectionPool.*" + "Failed to establish a new connection.*Connection refused.*" + ) + + +def test_run_sql_invalid_sql_type(capsys, sqlcmd): + with pytest.raises(TypeError) as ex: + sqlcmd(None) + assert ex.match("SQL statement type must be either string, Path, or IO handle") diff --git a/tests/retention/test_cli.py b/tests/retention/test_cli.py index 01e1b5e4..b100e50a 100644 --- a/tests/retention/test_cli.py +++ b/tests/retention/test_cli.py @@ -61,7 +61,7 @@ def test_setup_verbose(caplog, cratedb, settings): assert result.exit_code == 0 assert cratedb.database.table_exists(settings.policy_table.fullname) is True - assert 3 <= len(caplog.records) <= 7 + assert 3 <= len(caplog.records) <= 10 def test_setup_dryrun(caplog, cratedb, settings):