Skip to content

Commit

Permalink
⬆️ Update to psycopg3
Browse files Browse the repository at this point in the history
  • Loading branch information
jemrobinson committed Sep 14, 2023
1 parent 1cfebf1 commit f87c529
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 18 deletions.
40 changes: 24 additions & 16 deletions data_safe_haven/external/interface/azure_postgresql_database.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from collections.abc import Sequence
from typing import Any

import psycopg2
import psycopg
import requests
from azure.core.polling import LROPoller
from azure.mgmt.rdbms.postgresql import PostgreSQLManagementClient
Expand Down Expand Up @@ -52,6 +52,7 @@ def __init__(
self.db_server_ = None
self.db_server_admin_password = database_server_admin_password
self.logger = LoggingSingleton()
self.port = 5432
self.resource_group_name = resource_group_name
self.server_name = database_server_name
self.rule_suffix = datetime.datetime.now(tz=datetime.timezone.utc).strftime(
Expand All @@ -64,6 +65,19 @@ def wait(poller: LROPoller[Any]) -> None:
while not poller.done():
time.sleep(10)

@property
def connection_string(self) -> str:
return " ".join(
[
f"dbname={self.db_name}",
f"host={self.db_server.fully_qualified_domain_name}",
f"password={self.db_server_admin_password}",
f"port={self.port}",
f"user={self.db_server.administrator_login}@{self.server_name}",
"sslmode=require",
]
)

@property
def db_client(self) -> PostgreSQLManagementClient:
"""Get the database client."""
Expand All @@ -76,26 +90,20 @@ def db_client(self) -> PostgreSQLManagementClient:
@property
def db_server(self) -> Server:
"""Get the database server."""
# self.logger.debug(f"Connecting to database using {self.connection_string}")
if not self.db_server_:
self.db_server_ = self.db_client.servers.get(
self.resource_group_name, self.server_name
)
return self.db_server_

def db_connection(self, n_retries: int = 0) -> psycopg2.extensions.connection:
def db_connection(self, n_retries: int = 0) -> psycopg.Connection:
"""Get the database connection."""
while True:
try:
connection = psycopg2.connect(
user=f"{self.db_server.administrator_login}@{self.server_name}",
password=self.db_server_admin_password,
host=self.db_server.fully_qualified_domain_name,
port="5432",
database=self.db_name,
sslmode="require",
)
connection = psycopg.connect(self.connection_string)
break
except psycopg2.OperationalError as exc:
except psycopg.OperationalError as exc:
if n_retries > 0:
n_retries -= 1
time.sleep(10)
Expand Down Expand Up @@ -124,7 +132,7 @@ def execute_scripts(
) -> list[list[str]]:
"""Execute scripts on the PostgreSQL server."""
outputs: list[list[str]] = []
connection: psycopg2.extensions.connection | None = None
connection: psycopg.Connection | None = None
cursor = None

try:
Expand All @@ -140,21 +148,21 @@ def execute_scripts(
_filepath = pathlib.Path(filepath)
self.logger.info(f"Running SQL script: [green]{_filepath.name}[/].")
commands = self.load_sql(_filepath, mustache_values)
cursor.execute(commands)
if "SELECT" in cursor.statusmessage:
cursor.execute(query=commands.encode())
if cursor.statusmessage and "SELECT" in cursor.statusmessage:
outputs += [[str(msg) for msg in msg_tuple] for msg_tuple in cursor]

# Commit changes
connection.commit()
self.logger.info(f"Finished running {len(filepaths)} SQL scripts.")
except (Exception, psycopg2.Error) as exc:
except (Exception, psycopg.Error) as exc:
msg = f"Error while connecting to PostgreSQL.\n{exc}"
raise DataSafeHavenAzureError(msg) from exc
finally:
# Close the connection if it is open
if connection:
if cursor:
cursor.close() # type: ignore
cursor.close()
connection.close()
# Remove temporary firewall rules
self.set_database_access("disabled")
Expand Down
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ dependencies = [
"cryptography~=3.4.0",
"dnspython~=2.3.0",
"msal~=1.21.0",
"psycopg2~=2.9.0",
"psycopg~=3.1.10",
"pulumi~=3.80.0",
"pulumi-azure-native~=1.104.0",
"pytz~=2022.7.0",
Expand All @@ -62,7 +62,6 @@ dependencies = [
"ruff>=0.0.243",
"types-appdirs>=1.4.3.5",
"types-chevron>=0.14.2.5",
"types-psycopg2>=2.9.21.11",
"types-pytz>=2023.3.0.0",
"types-PyYAML>=6.0.12.11",
"types-requests>=2.31.0.2",
Expand Down Expand Up @@ -153,6 +152,7 @@ module = [
"cryptography.*",
"dns.*",
"msal.*",
"psycopg.*",
"pulumi.*",
"pulumi_azure_native.*",
"rich.*",
Expand Down

0 comments on commit f87c529

Please sign in to comment.