diff --git a/integration_tests/test_dbapi.py b/integration_tests/test_dbapi.py index c20c18b..d76c26a 100644 --- a/integration_tests/test_dbapi.py +++ b/integration_tests/test_dbapi.py @@ -11,9 +11,12 @@ # limitations under the License. from __future__ import absolute_import, division, print_function -import integration_tests.fixtures as fixtures +from datetime import date, datetime +import numpy as np import prestodb import pytest + +import integration_tests.fixtures as fixtures from integration_tests.fixtures import run_presto from prestodb.transaction import IsolationLevel @@ -116,7 +119,64 @@ def test_select_failed_query(presto_connection): cur.execute("select * from catalog.schema.do_not_exist") cur.fetchall() +def test_select_query_result_iteration_statement_params(presto_connection): + cur = presto_connection.cursor() + cur.execute( + """ + select * from ( + values + (1, 'one', 'a'), + (2, 'two', 'b'), + (3, 'three', 'c'), + (4, 'four', 'd'), + (5, 'five', 'e') + ) x (id, name, letter) + where id >= ? + """, + params=(3,) # expecting all the rows with id >= 3 + ) + + rows = cur.fetchall() + assert len(rows) == 3 + + for row in rows: + # Validate that all the ids of the returned rows are greater than or equal to 3 + assert row[0] >= 3 + + +def test_select_query_param_types(presto_connection): + cur = presto_connection.cursor() + + date_param = date.today() + timestamp_param = datetime.now().replace(microsecond=0) + float_param = 1.5 + list_param = (1,2,3) + cur.execute( + """ + select ?,?,?,? 
+ """, + params=(date_param, timestamp_param, float_param, list_param,) + ) + + rows = cur.fetchall() + assert len(rows) == 1 + for row in rows: + assert date.fromisoformat(row[0]) == date_param + assert datetime.strptime(row[1], "%Y-%m-%d %H:%M:%S.%f") == timestamp_param + assert row[2] == float_param + assert (row[3] == np.array(list_param)).all() + +@pytest.mark.parametrize('params', [ + 'NOT A LIST OR TUPPLE', + {'invalid', 'params'}, + object, +]) +def test_select_query_invalid_params(presto_connection, params): + cur = presto_connection.cursor() + with pytest.raises(AssertionError): + cur.execute('select ?', params=params) + def test_select_tpch_1000(presto_connection): cur = presto_connection.cursor() cur.execute("SELECT * FROM tpch.sf1.customer LIMIT 1000") diff --git a/prestodb/client.py b/prestodb/client.py index 5b335ae..1bf82fe 100644 --- a/prestodb/client.py +++ b/prestodb/client.py @@ -238,6 +238,7 @@ def __init__( ) self._http_session.headers.update(self.get_oauth_token()) + self.prepared_statements = [] self._http_session.headers.update(self.http_headers) self._exceptions = self.HTTP_EXCEPTIONS self._auth = auth @@ -270,6 +271,8 @@ def http_headers(self): headers[constants.HEADER_SCHEMA] = self._client_session.schema headers[constants.HEADER_SOURCE] = self._client_session.source headers[constants.HEADER_USER] = self._client_session.user + if len(self.prepared_statements) > 0: + headers[constants.HEADER_PREPARED_STATEMENT] = ",".join(self.prepared_statements) headers[constants.HEADER_SESSION] = ",".join( # ``name`` must not contain ``=`` @@ -417,6 +420,11 @@ def process(self, http_response): ): self._client_session.properties[key] = value + if constants.HEADER_ADDED_PREPARE in http_response.headers: + self._http_session.headers[ + constants.HEADER_PREPARED_STATEMENT + ] = http_response.headers[constants.HEADER_ADDED_PREPARE] + self._next_uri = response.get("nextUri") return PrestoStatus( @@ -529,12 +537,12 @@ def execute(self): response = 
self._request.post(self._sql) status = self._request.process(response) - if status.next_uri is None: - self._finished = True self.query_id = status.id self._stats.update({"queryId": self.query_id}) self._stats.update(status.stats) self._warnings = getattr(status, "warnings", []) + if status.next_uri is None: + self._finished = True self._result = PrestoResult(self, status.rows) while ( not self._finished and not self._cancelled diff --git a/prestodb/constants.py b/prestodb/constants.py index ab7cc55..99e72d9 100644 --- a/prestodb/constants.py +++ b/prestodb/constants.py @@ -42,6 +42,9 @@ HEADER_STARTED_TRANSACTION = HEADER_PREFIX + "Started-Transaction-Id" HEADER_TRANSACTION = HEADER_PREFIX + "Transaction-Id" +HEADER_PREPARED_STATEMENT = 'X-Presto-Prepared-Statement' +HEADER_ADDED_PREPARE = 'X-Presto-Added-Prepare' + PRESTO_EXTRA_CREDENTIAL = "X-Presto-Extra-Credential" GCS_CREDENTIALS_OAUTH_TOKEN_KEY = "hive.gcs.oauth" diff --git a/prestodb/dbapi.py b/prestodb/dbapi.py index 4041959..cc60bf7 100644 --- a/prestodb/dbapi.py +++ b/prestodb/dbapi.py @@ -242,10 +242,126 @@ def setoutputsize(self, size, column): raise prestodb.exceptions.NotSupportedError def execute(self, operation, params=None): - self._query = prestodb.client.PrestoQuery(self._request, sql=operation) - result = self._query.execute() - self._iterator = iter(result) - return result + if params: + assert isinstance(params, (list, tuple)), ( + "params must be a list or tuple containing the query " + "parameter values" + ) + + statement_name = self._generate_unique_statement_name() + self._prepare_statement(operation, statement_name) + + try: + # Send the EXECUTE statement; the cursor's iterator is built from + # the query result (the method itself returns the cursor) + self._query = self._execute_prepared_statement(statement_name, params) + self._iterator = iter(self._query.execute()) + finally: + # Send deallocate statement + # At this point the query can be deallocated since it has already + # been executed + # TODO: 
Consider caching prepared statements if requested by caller + self._deallocate_prepared_statement(statement_name) + else: + self._query = prestodb.client.PrestoQuery(self._request, sql=operation) + self._iterator = iter(self._query.execute()) + return self + + def _generate_unique_statement_name(self): + return "st_" + uuid.uuid4().hex.replace("-", "") + + def _prepare_statement(self, statement: str, name: str) -> None: + sql = f"PREPARE {name} FROM {statement}" + query = prestodb.client.PrestoQuery(self._request, sql=sql) + query.execute() + + def _execute_prepared_statement(self, statement_name, params): + sql = ( + "EXECUTE " + + statement_name + + " USING " + + ",".join(map(self._format_prepared_param, params)) + ) + return prestodb.client.PrestoQuery(self._request, sql=sql) + + def _deallocate_prepared_statement(self, statement_name: str) -> None: + sql = "DEALLOCATE PREPARE " + statement_name + query = prestodb.client.PrestoQuery(self._request, sql=sql) + query.execute() + + def _format_prepared_param(self, param): + """ + Formats parameters to be passed in an + EXECUTE statement. 
+ """ + if param is None: + return "NULL" + + if isinstance(param, bool): + return "true" if param else "false" + + if isinstance(param, int): + # TODO represent numbers exceeding 64-bit (BIGINT) as DECIMAL + return "%d" % param + + if isinstance(param, float): + if param == float("+inf"): + return "infinity()" + if param == float("-inf"): + return "-infinity()" + return "DOUBLE '%s'" % param + + if isinstance(param, str): + return "'%s'" % param.replace("'", "''") + + if isinstance(param, bytes): + return "X'%s'" % param.hex() + + if isinstance(param, datetime.datetime) and param.tzinfo is None: + datetime_str = param.strftime("%Y-%m-%d %H:%M:%S.%f") + return "TIMESTAMP '%s'" % datetime_str + + if isinstance(param, datetime.datetime) and param.tzinfo is not None: + datetime_str = param.strftime("%Y-%m-%d %H:%M:%S.%f") + # offset-based timezones + return "TIMESTAMP '%s %s'" % (datetime_str, param.tzinfo.tzname(param)) + + # We can't calculate the offset for a time without a point in time + if isinstance(param, datetime.time) and param.tzinfo is None: + time_str = param.strftime("%H:%M:%S.%f") + return "TIME '%s'" % time_str + + if isinstance(param, datetime.time) and param.tzinfo is not None: + time_str = param.strftime("%H:%M:%S.%f") + # offset-based timezones + return "TIME '%s %s'" % (time_str, param.strftime("%Z")[3:]) + + if isinstance(param, datetime.date): + date_str = param.strftime("%Y-%m-%d") + return "DATE '%s'" % date_str + + if isinstance(param, list): + return "ARRAY[%s]" % ",".join(map(self._format_prepared_param, param)) + + if isinstance(param, tuple): + return "ROW(%s)" % ",".join(map(self._format_prepared_param, param)) + + if isinstance(param, dict): + keys = list(param.keys()) + values = [param[key] for key in keys] + return "MAP({}, {})".format( + self._format_prepared_param(keys), self._format_prepared_param(values) + ) + + if isinstance(param, uuid.UUID): + return "UUID '%s'" % param + + if isinstance(param, (bytes, bytearray)): + return 
"X'%s'" % binascii.hexlify(param).decode("utf-8") + + raise prestodb.exceptions.NotSupportedError( + "Query parameter of type '%s' is not supported." % type(param) + ) def executemany(self, operation, seq_of_params): raise prestodb.exceptions.NotSupportedError