Skip to content

Commit

Permalink
Add SQL runner utility .io.sql.run_sql
Browse files Browse the repository at this point in the history
  • Loading branch information
amotl committed Nov 7, 2023
1 parent da0f0f5 commit 486e720
Show file tree
Hide file tree
Showing 10 changed files with 113 additions and 8 deletions.
8 changes: 8 additions & 0 deletions .github/workflows/main.yml
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,14 @@ jobs:
# Do not tear down Testcontainers
TC_KEEPALIVE: true

# https://docs.github.com/en/actions/using-containerized-services/about-service-containers
services:
cratedb:
image: crate/crate:nightly
ports:
- 4200:4200
- 5432:5432

name: Python ${{ matrix.python-version }} on OS ${{ matrix.os }}
steps:

Expand Down
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
*.egg-info
.eggs
__pycache__
*.pyc
dist
.coverage
coverage.xml
1 change: 1 addition & 0 deletions CHANGES.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

## Unreleased

- Add SQL runner utility primitives to `io.sql` namespace


## 2023/11/06 v0.0.2
Expand Down
Empty file added cratedb_toolkit/io/__init__.py
Empty file.
1 change: 1 addition & 0 deletions cratedb_toolkit/io/sql.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from cratedb_toolkit.util.database import DatabaseAdapter, run_sql # noqa: F401
41 changes: 34 additions & 7 deletions cratedb_toolkit/util/database.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,11 @@
# Copyright (c) 2023, Crate.io Inc.
# Distributed under the terms of the AGPLv3 license, see LICENSE.
import io
import typing as t
from pathlib import Path

import sqlalchemy as sa
import sqlparse
from sqlalchemy.exc import ProgrammingError
from sqlalchemy.sql.elements import AsBoolean

Expand All @@ -22,12 +27,23 @@ def __init__(self, dburi: str):
self.engine = sa.create_engine(self.dburi, echo=False)
self.connection = self.engine.connect()

def run_sql(self, sql: str, records: bool = False, ignore: str = None):
def run_sql(self, sql: t.Union[str, Path, io.IOBase], records: bool = False, ignore: str = None):
"""
Run SQL statement, and return results, optionally ignoring exceptions.
"""

sql_effective: str
if isinstance(sql, str):
sql_effective = sql
elif isinstance(sql, Path):
sql_effective = sql.read_text()
elif isinstance(sql, io.IOBase):
sql_effective = sql.read()
else:
raise TypeError("SQL statement type must be either string, Path, or IO handle")

try:
return self.run_sql_real(sql=sql, records=records)
return self.run_sql_real(sql=sql_effective, records=records)
except Exception as ex:
if not ignore:
raise
Expand All @@ -38,12 +54,23 @@ def run_sql_real(self, sql: str, records: bool = False):
"""
Invoke SQL statement, and return results.
"""
result = self.connection.execute(sa.text(sql))
if records:
rows = result.mappings().fetchall()
return [dict(row.items()) for row in rows]
results = []
with self.engine.connect() as connection:
for statement in sqlparse.split(sql):
result = connection.execute(sa.text(statement))
data: t.Any
if records:
rows = result.mappings().fetchall()
data = [dict(row.items()) for row in rows]
else:
data = result.fetchall()
results.append(data)

# Backward-compatibility.
if len(results) == 1:
return results[0]
else:
return result.fetchall()
return results

def count_records(self, tablename_full: str):
"""
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,7 @@ dependencies = [
"crash",
"crate[sqlalchemy]>=0.34",
"sqlalchemy>=2",
"sqlparse<0.5",
]
[project.optional-dependencies]
develop = [
Expand Down
Empty file added tests/io/__init__.py
Empty file.
66 changes: 66 additions & 0 deletions tests/io/test_sql.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
import io

import pytest
import sqlalchemy as sa

from cratedb_toolkit.io.sql import run_sql


@pytest.fixture
def sqlcmd(cratedb):
def workhorse(sql: str, records: bool = True):
return run_sql(dburi=cratedb.database.dburi, sql=sql, records=records)

return workhorse


def test_run_sql_direct(cratedb):
sql_string = "SELECT 1;"
outcome = run_sql(dburi=cratedb.database.dburi, sql=sql_string, records=True)
assert outcome == [{"1": 1}]


def test_run_sql_from_string(sqlcmd):
sql_string = "SELECT 1;"
outcome = sqlcmd(sql_string)
assert outcome == [{"1": 1}]


def test_run_sql_from_file(sqlcmd, tmp_path):
sql_file = tmp_path / "temp.sql"
sql_file.write_text("SELECT 1;")
outcome = sqlcmd(sql_file)
assert outcome == [{"1": 1}]


def test_run_sql_from_buffer(sqlcmd):
sql_buffer = io.StringIO("SELECT 1;")
outcome = sqlcmd(sql_buffer)
assert outcome == [{"1": 1}]


def test_run_sql_no_records(sqlcmd):
sql_string = "SELECT 1;"
outcome = sqlcmd(sql_string, records=False)
assert outcome == [(1,)]


def test_run_sql_multiple_statements(sqlcmd):
sql_string = "SELECT 1; SELECT 42;"
outcome = sqlcmd(sql_string)
assert outcome == [[{"1": 1}], [{"42": 42}]]


def test_run_sql_invalid_host(capsys):
with pytest.raises(sa.exc.OperationalError) as ex:
run_sql(dburi="crate://localhost:12345", sql="SELECT 1;")
assert ex.match(
".*ConnectionError.*No more Servers available.*HTTPConnectionPool.*"
"Failed to establish a new connection.*Connection refused.*"
)


def test_run_sql_invalid_sql_type(capsys, sqlcmd):
with pytest.raises(TypeError) as ex:
sqlcmd(None)
assert ex.match("SQL statement type must be either string, Path, or IO handle")
2 changes: 1 addition & 1 deletion tests/retention/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def test_setup_verbose(caplog, cratedb, settings):
assert result.exit_code == 0

assert cratedb.database.table_exists(settings.policy_table.fullname) is True
assert 3 <= len(caplog.records) <= 7
assert 3 <= len(caplog.records) <= 10


def test_setup_dryrun(caplog, cratedb, settings):
Expand Down

0 comments on commit 486e720

Please sign in to comment.