From 0cbcbbaefe2202cd61f70affddcd747e859796c8 Mon Sep 17 00:00:00 2001 From: Patrick Wang Date: Sat, 21 Dec 2024 18:54:09 -0500 Subject: [PATCH] wrote test_time_query_with_hint --- env/integtest_pg_conn.py | 37 +++++++++++++++++++++++++------------ env/pg_conn.py | 20 ++++++++++---------- 2 files changed, 35 insertions(+), 22 deletions(-) diff --git a/env/integtest_pg_conn.py b/env/integtest_pg_conn.py index 07154621..282cfca2 100644 --- a/env/integtest_pg_conn.py +++ b/env/integtest_pg_conn.py @@ -129,34 +129,47 @@ def test_time_query(self) -> None: self.assertIsNone(explain_data) def test_time_query_with_explain(self) -> None: - runtime, did_time_out, explain_data = self.pg_conn.time_query( + _, _, explain_data = self.pg_conn.time_query( "select pg_sleep(1)", add_explain=True ) - self.assertTrue(abs(runtime - 1_000_000) < 100_000) - self.assertFalse(did_time_out) self.assertIsNotNone(explain_data) def test_time_query_with_timeout(self) -> None: - runtime, did_time_out, explain_data = self.pg_conn.time_query( + runtime, did_time_out, _ = self.pg_conn.time_query( "select pg_sleep(3)", timeout=2 ) # The runtime should be about what the timeout is. self.assertTrue(abs(runtime - 2_000_000) < 100_000) self.assertTrue(did_time_out) - self.assertIsNone(explain_data) def test_time_query_with_valid_table(self) -> None: - _, did_time_out, explain_data = self.pg_conn.time_query( - "select * from lineitem limit 10" - ) - self.assertFalse(did_time_out) - self.assertIsNone(explain_data) + # This just ensures that it doesn't raise any errors. + self.pg_conn.time_query("select * from lineitem limit 10") def test_time_query_with_invalid_table(self) -> None: with self.assertRaises(psycopg.errors.UndefinedTable): - self.pg_conn.time_query( - "select * from itemline limit 10" + self.pg_conn.time_query("select * from itemline limit 10") + + def test_time_query_with_hint(self) -> None: + join_query = """SELECT * +FROM orders +JOIN lineitem ON o_orderkey = l_orderkey +WHERE o_orderdate BETWEEN '1995-01-01' AND '1995-12-31' +LIMIT 10""" + join_types = [ + ("MergeJoin", "Merge Join"), + ("HashJoin", "Hash Join"), + ("NestLoop", "Nested Loop"), + ] + + for hint_join_type, expected_join_type in join_types: + _, _, explain_data = self.pg_conn.time_query( + join_query, + query_knobs=[f"{hint_join_type}(lineitem orders)"], + add_explain=True, ) + actual_join_type = explain_data["Plan"]["Plans"][0]["Node Type"] + self.assertEqual(expected_join_type, actual_join_type) if __name__ == "__main__": diff --git a/env/pg_conn.py b/env/pg_conn.py index 104a4f56..ed131432 100644 --- a/env/pg_conn.py +++ b/env/pg_conn.py @@ -137,16 +137,16 @@ def time_query( did_time_out = False explain_data = None - def hint_notice_handler(notice) -> None: - """ - Custom handler for database notices. - Raises an error or logs the notice if it indicates a problem. - """ - logging.getLogger(DBGYM_LOGGER_NAME).warning(f"Postgres notice: {notice}") - if "hint" in notice.message.lower(): - raise RuntimeError(f"Query hint failed: {notice.message}") - - self.conn().add_notice_handler(hint_notice_handler) + # def hint_notice_handler(notice) -> None: + # """ + # Custom handler for database notices. + # Raises an error or logs the notice if it indicates a problem. + # """ + # logging.getLogger(DBGYM_LOGGER_NAME).warning(f"Postgres notice: {notice}") + # if "hint" in notice.message.lower(): + # raise RuntimeError(f"Query hint failed: {notice.message}") + + # self.conn().add_notice_handler(hint_notice_handler) try: if query_knobs: