From bcdf8278a7257b887813b731c7c73d4809c4c652 Mon Sep 17 00:00:00 2001 From: Patrick Wang Date: Thu, 14 Nov 2024 09:46:17 -0500 Subject: [PATCH] fixed typing around explain_data --- env/pg_conn.py | 3 ++- tune/protox/env/types.py | 2 +- tune/protox/env/util/execute.py | 3 +-- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/env/pg_conn.py b/env/pg_conn.py index 621fe084..acc4c13d 100644 --- a/env/pg_conn.py +++ b/env/pg_conn.py @@ -116,7 +116,8 @@ def time_query( Run a query with a timeout (in seconds). If you want to attach per-query knobs, attach them to the query string itself. Following Postgres's convention, timeout=0 indicates "disable timeout" - It returns the runtime, whether the query timed out, and the explain data if add_explain is True. + It returns the runtime, whether the query timed out, and the explain data if add_explain is True. Note that if + the query timed out, it won't have any explain data and thus explain_data will be None. If you write explain in the query manually instead of setting add_explain, it won't return explain_data. This is because it won't know the format of the explain data. diff --git a/tune/protox/env/types.py b/tune/protox/env/types.py index 35d3d8a0..47af8163 100644 --- a/tune/protox/env/types.py +++ b/tune/protox/env/types.py @@ -137,7 +137,7 @@ class ServerIndexMetadata(TypedDict, total=False): ("query_run", Optional[QueryRun]), ("runtime", Optional[float]), ("timed_out", bool), - ("explain_data", Optional[Any]), + ("explain_data", Optional[dict[str, Any]]), ("metric_data", Optional[dict[str, Any]]), ], ) diff --git a/tune/protox/env/util/execute.py b/tune/protox/env/util/execute.py index 52560682..7b6585e7 100644 --- a/tune/protox/env/util/execute.py +++ b/tune/protox/env/util/execute.py @@ -26,7 +26,7 @@ def _acquire_metrics_around_query( query: str, query_timeout: float = 0.0, observation_space: Optional[StateSpace] = None, -) -> tuple[float, bool, dict[str, Any], Any]: +) -> tuple[float, bool, Optional[dict[str, Any]], Any]: pg_conn.force_statement_timeout(0) if observation_space and observation_space.require_metrics(): initial_metrics = observation_space.construct_online(pg_conn.conn()) @@ -34,7 +34,6 @@ def _acquire_metrics_around_query( qid_runtime, did_time_out, explain_data = pg_conn.time_query( query, add_explain=True, timeout=query_timeout ) - assert explain_data is not None if observation_space and observation_space.require_metrics(): final_metrics = observation_space.construct_online(pg_conn.conn())