Skip to content

Commit

Permalink
Refactored the async code
Browse files Browse the repository at this point in the history
  • Loading branch information
jprakash-db committed Nov 24, 2024
1 parent 8bf4442 commit b44b298
Show file tree
Hide file tree
Showing 3 changed files with 55 additions and 26 deletions.
59 changes: 50 additions & 9 deletions src/databricks/sql/client.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import time
from typing import Dict, Tuple, List, Optional, Any, Union, Sequence

import pandas
Expand Down Expand Up @@ -430,6 +431,8 @@ def __init__(
self.escaper = ParamEscaper()
self.lastrowid = None

self.ASYNC_DEFAULT_POLLING_INTERVAL = 2

# The ideal return type for this method is perhaps Self, but that was not added until 3.11, and we support pre-3.11 pythons, currently.
def __enter__(self) -> "Cursor":
return self
Expand Down Expand Up @@ -733,7 +736,6 @@ def execute(
self,
operation: str,
parameters: Optional[TParameterCollection] = None,
async_op=False,
) -> "Cursor":
"""
Execute a query and wait for execution to complete.
Expand Down Expand Up @@ -802,15 +804,14 @@ def execute(
cursor=self,
use_cloud_fetch=self.connection.use_cloud_fetch,
parameters=prepared_params,
async_op=async_op,
async_op=False,
)
self.active_result_set = ResultSet(
self.connection,
execute_response,
self.thrift_backend,
self.buffer_size_bytes,
self.arraysize,
async_op,
)

if execute_response.is_staging_operation:
Expand All @@ -829,12 +830,43 @@ def execute_async(
Execute a query without waiting for it to complete.
Internally it calls thrift_backend.execute_command with async_op=True
:param operation:
:param parameters:
:return:
"""
self.execute(operation, parameters, True)
param_approach = self._determine_parameter_approach(parameters)
if param_approach == ParameterApproach.NONE:
prepared_params = NO_NATIVE_PARAMS
prepared_operation = operation

elif param_approach == ParameterApproach.INLINE:
prepared_operation, prepared_params = self._prepare_inline_parameters(
operation, parameters
)
elif param_approach == ParameterApproach.NATIVE:
normalized_parameters = self._normalize_tparametercollection(parameters)
param_structure = self._determine_parameter_structure(normalized_parameters)
transformed_operation = transform_paramstyle(
operation, normalized_parameters, param_structure
)
prepared_operation, prepared_params = self._prepare_native_parameters(
transformed_operation, normalized_parameters, param_structure
)

self._check_not_closed()
self._close_and_clear_active_result_set()
self.thrift_backend.execute_command(
operation=prepared_operation,
session_handle=self.connection._session_handle,
max_rows=self.arraysize,
max_bytes=self.buffer_size_bytes,
lz4_compression=self.connection.lz4_compression,
cursor=self,
use_cloud_fetch=self.connection.use_cloud_fetch,
parameters=prepared_params,
async_op=True,
)

return self

def get_query_state(self) -> "TOperationState":
Expand All @@ -846,15 +878,25 @@ def get_query_state(self) -> "TOperationState":
self._check_not_closed()
return self.thrift_backend.get_query_state(self.active_op_handle)

def get_execution_result(self):
def get_async_execution_result(self):
"""
Checks the status of the asynchronously executing query and fetches the result once the query has finished.
If the query has finished, sets the active_result_set to the obtained result.
Otherwise, it keeps polling the status of the query until it is no longer in a pending or running state.
:return:
"""
self._check_not_closed()

def is_executing(operation_state) -> "bool":
return not operation_state or operation_state in [
ttypes.TOperationState.RUNNING_STATE,
ttypes.TOperationState.PENDING_STATE,
]

while(is_executing(self.get_query_state())):
# Poll after some default time
time.sleep(self.ASYNC_DEFAULT_POLLING_INTERVAL)

operation_state = self.get_query_state()
if operation_state == ttypes.TOperationState.FINISHED_STATE:
execute_response = self.thrift_backend.get_execution_result(
Expand Down Expand Up @@ -1164,7 +1206,6 @@ def __init__(
thrift_backend: ThriftBackend,
result_buffer_size_bytes: int = DEFAULT_RESULT_BUFFER_SIZE_BYTES,
arraysize: int = 10000,
async_op=False,
):
"""
A ResultSet manages the results of a single command.
Expand All @@ -1187,7 +1228,7 @@ def __init__(
self._arrow_schema_bytes = execute_response.arrow_schema_bytes
self._next_row_index = 0

if execute_response.arrow_queue or async_op:
if execute_response.arrow_queue:
# In this case the server has taken the fast path and returned an initial batch of
# results
self.results = execute_response.arrow_queue
Expand Down
15 changes: 1 addition & 14 deletions src/databricks/sql/thrift_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -914,7 +914,7 @@ def execute_command(
resp = self.make_request(self._client.ExecuteStatement, req)

if async_op:
return self._handle_execute_response_async(resp, cursor)
self._handle_execute_response_async(resp, cursor)
else:
return self._handle_execute_response(resp, cursor)

Expand Down Expand Up @@ -1018,19 +1018,6 @@ def _handle_execute_response(self, resp, cursor):
def _handle_execute_response_async(self, resp, cursor):
cursor.active_op_handle = resp.operationHandle
self._check_direct_results_for_error(resp.directResults)
operation_status = resp.status.statusCode

return ExecuteResponse(
arrow_queue=None,
status=operation_status,
has_been_closed_server_side=None,
has_more_rows=None,
lz4_compressed=None,
is_staging_operation=None,
command_handle=resp.operationHandle,
description=None,
arrow_schema_bytes=None,
)

def fetch_results(
self,
Expand Down
7 changes: 4 additions & 3 deletions tests/e2e/test_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ class PySQLPytestTestCase:
}
arraysize = 1000
buffer_size_bytes = 104857600
POLLING_INTERVAL = 2

@pytest.fixture(autouse=True)
def get_details(self, connection_details):
Expand Down Expand Up @@ -187,12 +188,12 @@ def isExecuting(operation_state):
with self.cursor() as cursor:
cursor.execute_async(long_running_query)

## Polling after every 10 seconds
## Polling after every POLLING_INTERVAL seconds
while isExecuting(cursor.get_query_state()):
time.sleep(10)
time.sleep(self.POLLING_INTERVAL)
log.info("Polling the status in test_execute_async")

cursor.get_execution_result()
cursor.get_async_execution_result()
result = cursor.fetchall()

assert result[0].asDict() == {"count(1)": 0}
Expand Down

0 comments on commit b44b298

Please sign in to comment.