Skip to content

Commit

Permalink
Support parameterized statements in Presto Python Client
Browse files Browse the repository at this point in the history
Cherry-pick of trinodb/trino-python-client@a743855

Co-authored-by: Harrington Joseph <[email protected]>
  • Loading branch information
mlyublena and harph committed Aug 22, 2023
1 parent 08c2cca commit 5a687b4
Show file tree
Hide file tree
Showing 4 changed files with 194 additions and 7 deletions.
62 changes: 61 additions & 1 deletion integration_tests/test_dbapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,12 @@
# limitations under the License.
from __future__ import absolute_import, division, print_function

import integration_tests.fixtures as fixtures
from datetime import date, datetime
import numpy as np
import prestodb
import pytest

import integration_tests.fixtures as fixtures
from integration_tests.fixtures import run_presto
from prestodb.transaction import IsolationLevel

Expand Down Expand Up @@ -116,7 +119,64 @@ def test_select_failed_query(presto_connection):
cur.execute("select * from catalog.schema.do_not_exist")
cur.fetchall()

def test_select_query_result_iteration_statement_params(presto_connection):
    """A ``?`` placeholder bound through ``params`` filters the VALUES rows."""
    cursor = presto_connection.cursor()
    bound_params = (3,)  # expecting all the rows with id >= 3
    cursor.execute(
        """
select * from (
values
(1, 'one', 'a'),
(2, 'two', 'b'),
(3, 'three', 'c'),
(4, 'four', 'd'),
(5, 'five', 'e')
) x (id, name, letter)
where id >= ?
""",
        params=bound_params,
    )

    fetched = cursor.fetchall()
    assert len(fetched) == 3

    # Every returned row must have an id greater than or equal to 3.
    for record in fetched:
        record_id = record[0]
        assert record_id >= 3


def test_select_query_param_types(presto_connection):
    """Date, timestamp, float and tuple parameters round-trip via ``select ?``."""
    cursor = presto_connection.cursor()

    expected_date = date.today()
    # Microseconds are zeroed before the value is sent and compared.
    expected_timestamp = datetime.now().replace(microsecond=0)
    expected_float = 1.5
    expected_sequence = (1, 2, 3)
    bound_params = (
        expected_date,
        expected_timestamp,
        expected_float,
        expected_sequence,
    )
    cursor.execute(
        """
select ?,?,?,?
""",
        params=bound_params,
    )

    fetched = cursor.fetchall()
    assert len(fetched) == 1
    for record in fetched:
        # Date and timestamp columns are parsed from their string form
        # before being compared with the Python values that were sent.
        returned_date = date.fromisoformat(record[0])
        returned_timestamp = datetime.strptime(record[1], "%Y-%m-%d %H:%M:%S.%f")
        assert returned_date == expected_date
        assert returned_timestamp == expected_timestamp
        assert record[2] == expected_float
        assert (record[3] == np.array(expected_sequence)).all()

@pytest.mark.parametrize(
    'params',
    [
        'NOT A LIST OR TUPPLE',
        {'invalid', 'params'},
        object,
    ],
)
def test_select_query_invalid_params(presto_connection, params):
    """``execute`` rejects a ``params`` value that is not a list or tuple."""
    cursor = presto_connection.cursor()
    with pytest.raises(AssertionError):
        cursor.execute('select ?', params=params)


def test_select_tpch_1000(presto_connection):
cur = presto_connection.cursor()
cur.execute("SELECT * FROM tpch.sf1.customer LIMIT 1000")
Expand Down
12 changes: 10 additions & 2 deletions prestodb/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,6 +238,7 @@ def __init__(
)
self._http_session.headers.update(self.get_oauth_token())

self.prepared_statements = []
self._http_session.headers.update(self.http_headers)
self._exceptions = self.HTTP_EXCEPTIONS
self._auth = auth
Expand Down Expand Up @@ -270,6 +271,8 @@ def http_headers(self):
headers[constants.HEADER_SCHEMA] = self._client_session.schema
headers[constants.HEADER_SOURCE] = self._client_session.source
headers[constants.HEADER_USER] = self._client_session.user
if len(self.prepared_statements) > 0:
headers[constants.HEADER_PREPARED_STATEMENT] = ",".join(self.prepared_statements)

headers[constants.HEADER_SESSION] = ",".join(
# ``name`` must not contain ``=``
Expand Down Expand Up @@ -417,6 +420,11 @@ def process(self, http_response):
):
self._client_session.properties[key] = value

if constants.HEADER_ADDED_PREPARE in http_response.headers:
self._http_session.headers[
constants.HEADER_PREPARED_STATEMENT
] = http_response.headers[constants.HEADER_ADDED_PREPARE]

self._next_uri = response.get("nextUri")

return PrestoStatus(
Expand Down Expand Up @@ -529,12 +537,12 @@ def execute(self):

response = self._request.post(self._sql)
status = self._request.process(response)
if status.next_uri is None:
self._finished = True
self.query_id = status.id
self._stats.update({"queryId": self.query_id})
self._stats.update(status.stats)
self._warnings = getattr(status, "warnings", [])
if status.next_uri is None:
self._finished = True
self._result = PrestoResult(self, status.rows)
while (
not self._finished and not self._cancelled
Expand Down
3 changes: 3 additions & 0 deletions prestodb/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,9 @@
HEADER_STARTED_TRANSACTION = HEADER_PREFIX + "Started-Transaction-Id"
HEADER_TRANSACTION = HEADER_PREFIX + "Transaction-Id"

# Headers used for prepared statements. They are built from HEADER_PREFIX like
# the transaction headers above instead of repeating the literal prefix.
# NOTE(review): assumes HEADER_PREFIX is "X-Presto-" (defined outside this
# hunk), which keeps the resulting header names unchanged — confirm.
HEADER_PREPARED_STATEMENT = HEADER_PREFIX + "Prepared-Statement"
HEADER_ADDED_PREPARE = HEADER_PREFIX + "Added-Prepare"

PRESTO_EXTRA_CREDENTIAL = "X-Presto-Extra-Credential"
GCS_CREDENTIALS_OAUTH_TOKEN_KEY = "hive.gcs.oauth"

Expand Down
124 changes: 120 additions & 4 deletions prestodb/dbapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,10 +242,126 @@ def setoutputsize(self, size, column):
raise prestodb.exceptions.NotSupportedError

def execute(self, operation, params=None):
    """Run ``operation``, optionally binding ``params`` to its placeholders.

    When ``params`` is truthy the statement is sent as a prepared
    statement: it is first prepared under a generated name, then executed
    with the formatted parameter values, and finally deallocated.
    Otherwise ``operation`` is sent unchanged as a plain query.

    :param operation: SQL text to run.
    :param params: list or tuple of parameter values, or ``None``.
    :returns: this cursor.
    :raises AssertionError: if ``params`` is truthy but is neither a list
        nor a tuple.
    """
    if params:
        # Raised explicitly instead of through an ``assert`` statement so
        # the validation is not stripped when Python runs with -O. The
        # exception type and message are the same as before.
        if not isinstance(params, (list, tuple)):
            raise AssertionError(
                "params must be a list or tuple containing the query "
                "parameter values"
            )

        statement_name = self._generate_unique_statement_name()
        self._prepare_statement(operation, statement_name)

        try:
            # Build the EXECUTE query, run it, and expose its rows through
            # the cursor's iterator.
            self._query = self._execute_prepared_statement(statement_name, params)
            self._iterator = iter(self._query.execute())
        finally:
            # The statement has been executed (or has failed) by now, so
            # the prepared statement can be deallocated either way.
            # TODO: Consider caching prepared statements if requested by caller
            self._deallocate_prepared_statement(statement_name)
    else:
        self._query = prestodb.client.PrestoQuery(self._request, sql=operation)
        self._iterator = iter(self._query.execute())
    return self

def _generate_unique_statement_name(self):
    """Return a fresh prepared-statement name: ``st_`` plus 32 hex digits."""
    # ``UUID.hex`` never contains dashes, so no further clean-up is needed.
    return "st_" + uuid.uuid4().hex

def _prepare_statement(self, statement: str, name: str) -> None:
    """Send ``PREPARE <name> FROM <statement>`` to the server."""
    prepare_sql = "PREPARE {} FROM {}".format(name, statement)
    prestodb.client.PrestoQuery(self._request, sql=prepare_sql).execute()

def _execute_prepared_statement(self, statement_name, params):
    """Build the ``EXECUTE ... USING ...`` query for a prepared statement.

    The query object is returned without being executed; each value in
    ``params`` is rendered with ``_format_prepared_param``.
    """
    rendered_params = ",".join(
        self._format_prepared_param(value) for value in params
    )
    execute_sql = "EXECUTE " + statement_name + " USING " + rendered_params
    return prestodb.client.PrestoQuery(self._request, sql=execute_sql)

def _deallocate_prepared_statement(self, statement_name: str) -> None:
    """Send ``DEALLOCATE PREPARE <statement_name>`` to the server."""
    deallocate_sql = "".join(("DEALLOCATE PREPARE ", statement_name))
    prestodb.client.PrestoQuery(self._request, sql=deallocate_sql).execute()

def _format_prepared_param(self, param):
    """
    Formats a parameter as a SQL literal to be passed in an
    EXECUTE statement.

    Supported values: ``None``, ``bool``, ``int``, ``float`` (including
    NaN and the infinities), finite ``decimal.Decimal``, ``str``,
    ``bytes``/``bytearray``, ``datetime.datetime``, ``datetime.time``,
    ``datetime.date``, ``list`` (ARRAY), ``tuple`` (ROW), ``dict`` (MAP)
    and ``uuid.UUID``. Containers are formatted recursively.

    :param param: the Python value to render.
    :returns: the SQL text for ``param``.
    :raises prestodb.exceptions.NotSupportedError: for any other type.
    """
    # Imported locally so this method does not rely on module-level imports
    # that the original code did not need.
    import math
    from decimal import Decimal

    if param is None:
        return "NULL"

    # bool is a subclass of int, so it has to be checked first.
    if isinstance(param, bool):
        return "true" if param else "false"

    if isinstance(param, int):
        # TODO represent numbers exceeding 64-bit (BIGINT) as DECIMAL
        return "%d" % param

    if isinstance(param, float):
        # NaN and the infinities are produced through SQL functions rather
        # than through Python's textual forms ('nan', 'inf').
        if math.isnan(param):
            return "nan()"
        if param == float("+inf"):
            return "infinity()"
        if param == float("-inf"):
            return "-infinity()"
        return "DOUBLE '%s'" % param

    # Non-finite decimals fall through to the NotSupportedError below.
    if isinstance(param, Decimal) and param.is_finite():
        # The "f" format avoids exponent notation such as '1E+2'.
        return "DECIMAL '%s'" % format(param, "f")

    if isinstance(param, str):
        # Single quotes are escaped by doubling them.
        return "'%s'" % param.replace("'", "''")

    if isinstance(param, (bytes, bytearray)):
        return "X'%s'" % param.hex()

    # datetime.datetime is a subclass of datetime.date, so it has to be
    # checked before the plain date branch further down.
    if isinstance(param, datetime.datetime):
        datetime_str = param.strftime("%Y-%m-%d %H:%M:%S.%f")
        if param.tzinfo is None:
            return "TIMESTAMP '%s'" % datetime_str
        # offset-based timezones
        return "TIMESTAMP '%s %s'" % (datetime_str, param.tzinfo.tzname(param))

    # We can't calculate the offset for a time without a point in time
    if isinstance(param, datetime.time):
        time_str = param.strftime("%H:%M:%S.%f")
        if param.tzinfo is None:
            return "TIME '%s'" % time_str
        # offset-based timezones: the first three characters of the zone
        # name are dropped (e.g. 'UTC+05:00' becomes '+05:00').
        return "TIME '%s %s'" % (time_str, param.strftime("%Z")[3:])

    if isinstance(param, datetime.date):
        date_str = param.strftime("%Y-%m-%d")
        return "DATE '%s'" % date_str

    if isinstance(param, list):
        return "ARRAY[%s]" % ",".join(map(self._format_prepared_param, param))

    if isinstance(param, tuple):
        return "ROW(%s)" % ",".join(map(self._format_prepared_param, param))

    if isinstance(param, dict):
        # Keys and values are emitted as two parallel arrays.
        keys = list(param)
        values = [param[key] for key in keys]
        return "MAP({}, {})".format(
            self._format_prepared_param(keys), self._format_prepared_param(values)
        )

    if isinstance(param, uuid.UUID):
        return "UUID '%s'" % param

    raise prestodb.exceptions.NotSupportedError(
        "Query parameter of type '%s' is not supported." % type(param)
    )

def executemany(self, operation, seq_of_params):
raise prestodb.exceptions.NotSupportedError
Expand Down

0 comments on commit 5a687b4

Please sign in to comment.