diff --git a/src/databricks/sql/client.py b/src/databricks/sql/client.py index 9ade21fc..c72b100d 100755 --- a/src/databricks/sql/client.py +++ b/src/databricks/sql/client.py @@ -1,3 +1,4 @@ +import time from typing import Dict, Tuple, List, Optional, Any, Union, Sequence import pandas @@ -430,6 +431,8 @@ def __init__( self.escaper = ParamEscaper() self.lastrowid = None + self.ASYNC_DEFAULT_POLLING_INTERVAL = 2 + # The ideal return type for this method is perhaps Self, but that was not added until 3.11, and we support pre-3.11 pythons, currently. def __enter__(self) -> "Cursor": return self @@ -733,7 +736,6 @@ def execute( self, operation: str, parameters: Optional[TParameterCollection] = None, - async_op=False, ) -> "Cursor": """ Execute a query and wait for execution to complete. @@ -802,7 +804,7 @@ def execute( cursor=self, use_cloud_fetch=self.connection.use_cloud_fetch, parameters=prepared_params, - async_op=async_op, + async_op=False, ) self.active_result_set = ResultSet( self.connection, @@ -810,7 +812,6 @@ def execute( self.thrift_backend, self.buffer_size_bytes, self.arraysize, - async_op, ) if execute_response.is_staging_operation: @@ -829,12 +830,43 @@ def execute_async( Execute a query and do not wait for it to complete and just move ahead - Internally it calls execute function with async_op=True :param operation: :param parameters: :return: """ - self.execute(operation, parameters, True) + param_approach = self._determine_parameter_approach(parameters) + if param_approach == ParameterApproach.NONE: + prepared_params = NO_NATIVE_PARAMS + prepared_operation = operation + + elif param_approach == ParameterApproach.INLINE: + prepared_operation, prepared_params = self._prepare_inline_parameters( + operation, parameters + ) + elif param_approach == ParameterApproach.NATIVE: + normalized_parameters = self._normalize_tparametercollection(parameters) + param_structure = self._determine_parameter_structure(normalized_parameters) + transformed_operation = 
transform_paramstyle( + operation, normalized_parameters, param_structure + ) + prepared_operation, prepared_params = self._prepare_native_parameters( + transformed_operation, normalized_parameters, param_structure + ) + + self._check_not_closed() + self._close_and_clear_active_result_set() + self.thrift_backend.execute_command( + operation=prepared_operation, + session_handle=self.connection._session_handle, + max_rows=self.arraysize, + max_bytes=self.buffer_size_bytes, + lz4_compression=self.connection.lz4_compression, + cursor=self, + use_cloud_fetch=self.connection.use_cloud_fetch, + parameters=prepared_params, + async_op=True, + ) + return self def get_query_state(self) -> "TOperationState": @@ -846,15 +878,25 @@ def get_query_state(self) -> "TOperationState": self._check_not_closed() return self.thrift_backend.get_query_state(self.active_op_handle) - def get_execution_result(self): + def get_async_execution_result(self): """ Checks for the status of the async executing query and fetches the result if the query is finished - If executed sets the active_result_set to the obtained result + Otherwise it keeps polling the query status until the query is no longer pending or running :return: """ self._check_not_closed() + def is_executing(operation_state) -> "bool": + return not operation_state or operation_state in [ + ttypes.TOperationState.RUNNING_STATE, + ttypes.TOperationState.PENDING_STATE, + ] + + while(is_executing(self.get_query_state())): + # Wait for the default polling interval before checking the state again + time.sleep(self.ASYNC_DEFAULT_POLLING_INTERVAL) + operation_state = self.get_query_state() if operation_state == ttypes.TOperationState.FINISHED_STATE: execute_response = self.thrift_backend.get_execution_result( @@ -1164,7 +1206,6 @@ def __init__( thrift_backend: ThriftBackend, result_buffer_size_bytes: int = DEFAULT_RESULT_BUFFER_SIZE_BYTES, arraysize: int = 10000, - async_op=False, ): """ A ResultSet manages the results of a single command. 
@@ -1187,7 +1228,7 @@ def __init__( self._arrow_schema_bytes = execute_response.arrow_schema_bytes self._next_row_index = 0 - if execute_response.arrow_queue or async_op: + if execute_response.arrow_queue: # In this case the server has taken the fast path and returned an initial batch of # results self.results = execute_response.arrow_queue diff --git a/src/databricks/sql/thrift_backend.py b/src/databricks/sql/thrift_backend.py index ad7d41e3..dbfd5936 100644 --- a/src/databricks/sql/thrift_backend.py +++ b/src/databricks/sql/thrift_backend.py @@ -914,7 +914,7 @@ def execute_command( resp = self.make_request(self._client.ExecuteStatement, req) if async_op: - return self._handle_execute_response_async(resp, cursor) + self._handle_execute_response_async(resp, cursor) else: return self._handle_execute_response(resp, cursor) @@ -1018,19 +1018,6 @@ def _handle_execute_response(self, resp, cursor): def _handle_execute_response_async(self, resp, cursor): cursor.active_op_handle = resp.operationHandle self._check_direct_results_for_error(resp.directResults) - operation_status = resp.status.statusCode - - return ExecuteResponse( - arrow_queue=None, - status=operation_status, - has_been_closed_server_side=None, - has_more_rows=None, - lz4_compressed=None, - is_staging_operation=None, - command_handle=resp.operationHandle, - description=None, - arrow_schema_bytes=None, - ) def fetch_results( self, diff --git a/tests/e2e/test_driver.py b/tests/e2e/test_driver.py index 54fc7a38..2f0881cd 100644 --- a/tests/e2e/test_driver.py +++ b/tests/e2e/test_driver.py @@ -79,6 +79,7 @@ class PySQLPytestTestCase: } arraysize = 1000 buffer_size_bytes = 104857600 + POLLING_INTERVAL = 2 @pytest.fixture(autouse=True) def get_details(self, connection_details): @@ -187,12 +188,12 @@ def isExecuting(operation_state): with self.cursor() as cursor: cursor.execute_async(long_running_query) - ## Polling after every 10 seconds + ## Polling after every POLLING_INTERVAL seconds while 
isExecuting(cursor.get_query_state()): - time.sleep(10) + time.sleep(self.POLLING_INTERVAL) log.info("Polling the status in test_execute_async") - cursor.get_execution_result() + cursor.get_async_execution_result() result = cursor.fetchall() assert result[0].asDict() == {"count(1)": 0}