added add_explain option to time_query
wangpatrick57 committed Nov 14, 2024
1 parent 3d714e8 commit 3dd734a
Showing 3 changed files with 29 additions and 20 deletions.
4 changes: 2 additions & 2 deletions env/integtest_pg_conn.py
@@ -190,14 +190,14 @@ def test_time_query(self) -> None:

         # Testing with explain.
         runtime, did_time_out, explain_data = pg_conn.time_query(
-            "explain (analyze, format json, timing off) select pg_sleep(1)"
+            "select pg_sleep(1)", add_explain=True
         )
         self.assertTrue(abs(runtime - 1_000_000) < 100_000)
         self.assertFalse(did_time_out)
         self.assertIsNotNone(explain_data)

         # Testing with timeout.
-        runtime, did_time_out, _ = pg_conn.time_query("select pg_sleep(3)", 2)
+        runtime, did_time_out, _ = pg_conn.time_query("select pg_sleep(3)", timeout=2)
         # The runtime should be about what the timeout is.
         self.assertTrue(abs(runtime - 2_000_000) < 100_000)
         self.assertTrue(did_time_out)
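
For reference, a minimal usage sketch of the new keyword arguments (not part of the commit), assuming an already-connected PostgresConn named pg_conn as in the test above. Runtimes are reported in microseconds.

# Sketch only, not part of the commit. pg_conn is assumed to be a connected
# PostgresConn from env/pg_conn.py.

# add_explain=True makes time_query prepend
# "explain (analyze, format json, timing off)" to the query and return the
# parsed explain output as the third element of the tuple.
runtime, did_time_out, explain_data = pg_conn.time_query(
    "select pg_sleep(1)", add_explain=True
)
assert explain_data is not None
print(f"ran in {runtime / 1e6:.2f}s, timed out: {did_time_out}")

# timeout is in seconds; 0 (the default) disables the timeout. On a timeout,
# the reported runtime is roughly the timeout itself.
runtime, did_time_out, _ = pg_conn.time_query("select pg_sleep(3)", timeout=2)
assert did_time_out
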
36 changes: 22 additions & 14 deletions env/pg_conn.py
@@ -99,8 +99,8 @@ def move_log(self) -> None:
         shutil.move(pglog_fpath, pglog_this_step_fpath)
         self.log_step += 1

-    def force_statement_timeout(self, timeout_sec: float) -> None:
-        timeout_ms = timeout_sec * 1000
+    def force_statement_timeout(self, timeout: float) -> None:
+        timeout_ms = timeout * 1000
         retry = True
         while retry:
             retry = False
@@ -110,31 +110,39 @@ def force_statement_timeout(self, timeout: float) -> None:
                 retry = True

     def time_query(
-        self, query: str, timeout_sec: float = 0
+        self, query: str, add_explain: bool = False, timeout: float = 0
     ) -> tuple[float, bool, Optional[dict[str, Any]]]:
         """
-        Run a query with a timeout. If you want to attach per-query knobs, attach them to the query string itself.
-        Following Postgres's convention, timeout=0 indicates "disable timeout"
+        Run a query with a timeout (in seconds). If you want to attach per-query knobs, attach them to the query string
+        itself. Following Postgres's convention, timeout=0 indicates "disable timeout"
-        It returns the runtime, whether the query timed out, and the explain data.
+        It returns the runtime, whether the query timed out, and the explain data if add_explain is True.
+        If you write explain in the query manually instead of setting add_explain, it won't return explain_data. This
+        is because it won't know the format of the explain data.
         """
-        if timeout_sec > 0:
-            self.force_statement_timeout(timeout_sec)
+        if timeout > 0:
+            self.force_statement_timeout(timeout)
         else:
             assert (
-                timeout_sec == 0
-            ), f'Setting timeout_sec to 0 indicates "disable timeout". However, setting timeout_sec ({timeout_sec}) < 0 is a bug.'
+                timeout == 0
+            ), f'Setting timeout to 0 indicates "disable timeout". However, setting timeout ({timeout}) < 0 is a bug.'

         did_time_out = False
-        has_explain = "explain" in query.lower()
         explain_data = None

         try:
+            if add_explain:
+                assert (
+                    "explain" not in query.lower()
+                ), "If you're using add_explain, don't also write explain manually in the query."
+                query = f"explain (analyze, format json, timing off) {query}"
+
             start_time = time.time()
             cursor = self.conn().execute(query)
             qid_runtime = (time.time() - start_time) * 1e6

-            if has_explain:
+            if add_explain:
                 c = [c for c in cursor][0][0][0]
                 assert "Execution Time" in c
                 qid_runtime = float(c["Execution Time"]) * 1e3
@@ -146,9 +154,9 @@ def time_query(

         except QueryCanceled:
             logging.getLogger(DBGYM_LOGGER_NAME).debug(
-                f"{query} exceeded evaluation timeout {timeout_sec}"
+                f"{query} exceeded evaluation timeout {timeout}"
             )
-            qid_runtime = timeout_sec * 1e6
+            qid_runtime = timeout * 1e6
             did_time_out = True
         except Exception as e:
             assert False, e
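
A side note on the [c for c in cursor][0][0][0] indexing above: EXPLAIN (FORMAT JSON) returns one row with one json column whose value is a one-element list holding the plan dict, and with ANALYZE that dict carries an "Execution Time" field in milliseconds. A standalone sketch of that shape, assuming psycopg 3 (which pg_conn.py appears to use) and a hypothetical DSN:

# Standalone sketch, not part of the commit. Assumes psycopg 3; the DSN is
# hypothetical.
import psycopg

with psycopg.connect("dbname=dbgym") as conn:
    cursor = conn.execute(
        "explain (analyze, format json, timing off) select pg_sleep(0.1)"
    )
    rows = cursor.fetchall()
    # One row, one column; psycopg parses the json column into Python objects,
    # so rows[0][0] is a list containing a single plan dict.
    plan = rows[0][0][0]
    assert "Execution Time" in plan
    # Postgres reports "Execution Time" in milliseconds; * 1e3 converts it to
    # the microseconds that time_query uses for its other runtimes.
    runtime_us = float(plan["Execution Time"]) * 1e3
    print(runtime_us)
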
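Relatedly, the timeout path relies on Postgres's statement_timeout setting, which is specified in milliseconds (hence the multiplication by 1000 in force_statement_timeout) and raises QueryCanceled when exceeded; a value of 0 disables it. A minimal sketch of that mechanism, again assuming psycopg 3 and a hypothetical DSN:

# Standalone sketch, not part of the commit. Assumes psycopg 3; the DSN is
# hypothetical.
import psycopg
from psycopg.errors import QueryCanceled

with psycopg.connect("dbname=dbgym", autocommit=True) as conn:
    conn.execute("set statement_timeout = 2000")  # milliseconds
    try:
        conn.execute("select pg_sleep(3)")  # runs past the 2s limit
    except QueryCanceled:
        print("query exceeded statement_timeout")
    conn.execute("set statement_timeout = 0")  # 0 disables the timeout
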
9 changes: 5 additions & 4 deletions tune/protox/env/util/execute.py
@@ -26,12 +26,15 @@ def _acquire_metrics_around_query(
     query: str,
     query_timeout: float = 0.0,
     observation_space: Optional[StateSpace] = None,
-) -> tuple[float, bool, Any, Any]:
+) -> tuple[float, bool, dict[str, Any], Any]:
     pg_conn.force_statement_timeout(0)
     if observation_space and observation_space.require_metrics():
         initial_metrics = observation_space.construct_online(pg_conn.conn())

-    qid_runtime, did_time_out, explain_data = pg_conn.time_query(query, query_timeout)
+    qid_runtime, did_time_out, explain_data = pg_conn.time_query(
+        query, add_explain=True, timeout=query_timeout
+    )
+    assert explain_data is not None

     if observation_space and observation_space.require_metrics():
         final_metrics = observation_space.construct_online(pg_conn.conn())
@@ -74,8 +77,6 @@ def execute_variations(
             + " */"
             + query
         )
-        # Log the query plan.
-        pqk_query = "EXPLAIN (ANALYZE, FORMAT JSON, TIMING OFF) " + pqk_query

         # Log out the knobs that we are using.
         pqkk = [(knob.name(), val) for knob, val in qr.qknobs.items()]
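
The net effect in execute_variations: the hint-comment query built above now goes to time_query with add_explain=True instead of getting a hand-written EXPLAIN prefix. A rough sketch of that idea; the hint text, query, and timeout below are hypothetical.

# Sketch only, not part of the commit. In the real code the hint comment and
# query come from the query variation being executed.
pqk_query = "/*+ SeqScan(lineitem) */" + " select count(*) from lineitem"
runtime, did_time_out, explain_data = pg_conn.time_query(
    pqk_query, add_explain=True, timeout=30.0
)
# time_query prepends "explain (analyze, format json, timing off) " itself,
# which is what the removed line above used to do by hand.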
