Skip to content

Commit

Permalink
Merge pull request #16 from alan-turing-institute/additional-checks
Browse files Browse the repository at this point in the history
Add additional ruff checks
  • Loading branch information
jemrobinson authored Sep 26, 2024
2 parents 41b4474 + 885ad8a commit 45e745e
Show file tree
Hide file tree
Showing 16 changed files with 594 additions and 373 deletions.
33 changes: 18 additions & 15 deletions guacamole_user_sync/ldap/ldap_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from ldap.ldapobject import LDAPObject

from guacamole_user_sync.models import (
LDAPException,
LDAPError,
LDAPGroup,
LDAPQuery,
LDAPSearchResult,
Expand All @@ -16,6 +16,8 @@


class LDAPClient:
"""Client for connecting to an LDAP server."""

def __init__(
self,
hostname: str,
Expand All @@ -30,14 +32,14 @@ def __init__(

def connect(self) -> LDAPObject:
if not self.cnxn:
logger.info(f"Initialising connection to LDAP host at {self.hostname}")
logger.info("Initialising connection to LDAP host at %s", self.hostname)
self.cnxn = ldap.initialize(f"ldap://{self.hostname}")
if self.bind_dn:
try:
self.cnxn.simple_bind_s(self.bind_dn, self.bind_password)
except ldap.INVALID_CREDENTIALS as exc:
logger.warning("Connection credentials were incorrect.")
raise LDAPException from exc
raise LDAPError from exc
return self.cnxn

def search_groups(self, query: LDAPQuery) -> list[LDAPGroup]:
Expand All @@ -53,9 +55,9 @@ def search_groups(self, query: LDAPQuery) -> list[LDAPGroup]:
group.decode("utf-8") for group in attr_dict["memberUid"]
],
name=attr_dict[query.id_attr][0].decode("utf-8"),
)
),
)
logger.debug(f"Loaded {len(output)} LDAP groups")
logger.debug("Loaded %s LDAP groups", len(output))
return output

def search_users(self, query: LDAPQuery) -> list[LDAPUser]:
Expand All @@ -70,16 +72,16 @@ def search_users(self, query: LDAPQuery) -> list[LDAPUser]:
],
name=attr_dict[query.id_attr][0].decode("utf-8"),
uid=attr_dict["uid"][0].decode("utf-8"),
)
),
)
logger.debug(f"Loaded {len(output)} LDAP users")
logger.debug("Loaded %s LDAP users", len(output))
return output

def search(self, query: LDAPQuery) -> LDAPSearchResult:
results: LDAPSearchResult = []
logger.info("Querying LDAP host with:")
logger.info(f"... base DN: {query.base_dn}")
logger.info(f"... filter: {query.filter}")
logger.info("... base DN: %s", query.base_dn)
logger.info("... filter: %s", query.filter)
searcher = AsyncSearchList(self.connect())
try:
searcher.startSearch(
Expand All @@ -89,15 +91,16 @@ def search(self, query: LDAPQuery) -> LDAPSearchResult:
)
if searcher.processResults() != 0:
logger.warning("Only partial results received.")
results = searcher.allResults
logger.debug(f"Server returned {len(results)} results.")
return results
except ldap.NO_SUCH_OBJECT as exc:
logger.warning("Server returned no results.")
raise LDAPException from exc
raise LDAPError from exc
except ldap.SERVER_DOWN as exc:
logger.warning("Server could not be reached.")
raise LDAPException from exc
raise LDAPError from exc
except ldap.SIZELIMIT_EXCEEDED as exc:
logger.warning("Server-side size limit exceeded.")
raise LDAPException from exc
raise LDAPError from exc
else:
results = searcher.allResults
logger.debug("Server returned %s results.", len(results))
return results
8 changes: 5 additions & 3 deletions guacamole_user_sync/models/__init__.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,16 @@
from .exceptions import LDAPException, PostgreSQLException
from .exceptions import LDAPError, PostgreSQLError
from .guacamole import GuacamoleUserDetails
from .ldap_objects import LDAPGroup, LDAPUser
from .ldap_query import LDAPQuery

LDAPSearchResult = list[tuple[int, tuple[str, dict[str, list[bytes]]]]]

__all__ = [
"LDAPException",
"GuacamoleUserDetails",
"LDAPError",
"LDAPGroup",
"LDAPQuery",
"LDAPSearchResult",
"LDAPUser",
"PostgreSQLException",
"PostgreSQLError",
]
8 changes: 4 additions & 4 deletions guacamole_user_sync/models/exceptions.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
class LDAPException(Exception):
pass
class LDAPError(Exception):
"""LDAP error."""


class PostgreSQLException(Exception):
pass
class PostgreSQLError(Exception):
"""PostgreSQL error."""
10 changes: 10 additions & 0 deletions guacamole_user_sync/models/guacamole.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
from dataclasses import dataclass


@dataclass
class GuacamoleUserDetails:
"""A Guacamole user with required attributes only."""

entity_id: int
full_name: str
name: str
3 changes: 2 additions & 1 deletion guacamole_user_sync/postgresql/__init__.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
from .postgresql_backend import PostgreSQLBackend
from .postgresql_backend import PostgreSQLBackend, PostgreSQLConnectionDetails
from .postgresql_client import PostgreSQLClient
from .sql import SchemaVersion

__all__ = [
"PostgreSQLBackend",
"PostgreSQLConnectionDetails",
"PostgreSQLClient",
"SchemaVersion",
]
31 changes: 18 additions & 13 deletions guacamole_user_sync/postgresql/orm.py
Original file line number Diff line number Diff line change
@@ -1,51 +1,56 @@
import enum
from datetime import datetime

from sqlalchemy import DateTime, Enum, Integer, String
from sqlalchemy.dialects.postgresql import BYTEA
from sqlalchemy.orm import ( # type:ignore
DeclarativeBase,
Mapped,
mapped_column,
)
from sqlalchemy import DateTime, Enum, Integer, LargeBinary, String
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column


class guacamole_entity_type(enum.Enum):
class GuacamoleEntityType(enum.Enum):
"""Guacamole entity enum."""

USER = "USER"
USER_GROUP = "USER_GROUP"


class GuacamoleBase(DeclarativeBase): # type:ignore
pass
class GuacamoleBase(DeclarativeBase): # type: ignore[misc]
"""Guacamole database base table."""


class GuacamoleEntity(GuacamoleBase):
"""Guacamole database GuacamoleEntity table."""

__tablename__ = "guacamole_entity"

entity_id: Mapped[int] = mapped_column(Integer, primary_key=True)
name: Mapped[str] = mapped_column(String(128))
type: Mapped[guacamole_entity_type] = mapped_column(Enum(guacamole_entity_type))
type: Mapped[GuacamoleEntityType] = mapped_column(Enum(GuacamoleEntityType))


class GuacamoleUser(GuacamoleBase):
"""Guacamole database GuacamoleUser table."""

__tablename__ = "guacamole_user"

user_id: Mapped[int] = mapped_column(Integer, primary_key=True)
entity_id: Mapped[int] = mapped_column(Integer)
full_name: Mapped[str] = mapped_column(String(256))
password_hash: Mapped[bytes] = mapped_column(BYTEA)
password_salt: Mapped[bytes] = mapped_column(BYTEA)
password_hash: Mapped[bytes] = mapped_column(LargeBinary)
password_salt: Mapped[bytes] = mapped_column(LargeBinary)
password_date: Mapped[datetime] = mapped_column(DateTime(timezone=True))


class GuacamoleUserGroup(GuacamoleBase):
"""Guacamole database GuacamoleUserGroup table."""

__tablename__ = "guacamole_user_group"

user_group_id: Mapped[int] = mapped_column(Integer, primary_key=True)
entity_id: Mapped[int] = mapped_column(Integer)


class GuacamoleUserGroupMember(GuacamoleBase):
"""Guacamole database GuacamoleUserGroupMember table."""

__tablename__ = "guacamole_user_group_member"

user_group_id: Mapped[int] = mapped_column(Integer, primary_key=True)
Expand Down
100 changes: 57 additions & 43 deletions guacamole_user_sync/postgresql/postgresql_backend.py
Original file line number Diff line number Diff line change
@@ -1,32 +1,38 @@
import logging
from typing import Any, Type, TypeVar
from dataclasses import dataclass
from typing import Any, TypeVar

from sqlalchemy import create_engine
from sqlalchemy.engine import URL, Engine # type:ignore
from sqlalchemy.orm import Session
from sqlalchemy.sql.expression import TextClause
from sqlalchemy import URL, Engine, TextClause, create_engine
from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.orm import DeclarativeBase, Session

logger = logging.getLogger("guacamole_user_sync")

T = TypeVar("T")

@dataclass
class PostgreSQLConnectionDetails:
"""Dataclass for holding PostgreSQL connection details."""

database_name: str
host_name: str
port: int
user_name: str
user_password: str


T = TypeVar("T", bound=DeclarativeBase)


class PostgreSQLBackend:
"""Backend for connecting to a PostgreSQL database."""

def __init__(
self,
*,
database_name: str,
host_name: str,
port: int,
user_name: str,
user_password: str,
connection_details: PostgreSQLConnectionDetails,
session: Session | None = None,
):
self.database_name = database_name
self.host_name = host_name
self.port = port
self.user_name = user_name
self.user_password = user_password
) -> None:
self.connection_details = connection_details
self._engine: Engine | None = None
self._session = session

Expand All @@ -35,11 +41,11 @@ def engine(self) -> Engine:
if not self._engine:
url_object = URL.create(
"postgresql+psycopg",
username=self.user_name,
password=self.user_password,
host=self.host_name,
port=self.port,
database=self.database_name,
username=self.connection_details.user_name,
password=self.connection_details.user_password,
host=self.connection_details.host_name,
port=self.connection_details.port,
database=self.connection_details.database_name,
)
self._engine = create_engine(url_object, echo=False)
return self._engine
Expand All @@ -50,29 +56,37 @@ def session(self) -> Session:
return Session(self.engine)

def add_all(self, items: list[T]) -> None:
with self.session() as session: # type:ignore
with session.begin():
session.add_all(items)

def delete(self, table: Type[T], *filter_args: Any) -> None:
with self.session() as session: # type:ignore
with session.begin():
if filter_args:
session.query(table).filter(*filter_args).delete()
else:
session.query(table).delete()
with self.session() as session, session.begin():
session.add_all(items)

def delete(
self,
table: type[T],
*filter_args: Any, # noqa: ANN401
) -> None:
with self.session() as session, session.begin():
if filter_args:
session.query(table).filter(*filter_args).delete()
else:
session.query(table).delete()

def execute_commands(self, commands: list[TextClause]) -> None:
with self.session() as session: # type:ignore
with session.begin():
try:
with self.session() as session, session.begin():
for command in commands:
session.execute(command)
except SQLAlchemyError:
logger.warning("Unable to execute PostgreSQL commands.")
raise

def query(self, table: Type[T], **filter_kwargs: Any) -> list[T]:
with self.session() as session: # type:ignore
with session.begin():
if filter_kwargs:
result = session.query(table).filter_by(**filter_kwargs)
else:
result = session.query(table)
return [item for item in result]
def query(
self,
table: type[T],
**filter_kwargs: Any, # noqa: ANN401
) -> list[T]:
with self.session() as session, session.begin():
if filter_kwargs:
result = session.query(table).filter_by(**filter_kwargs)
else:
result = session.query(table)
return list(result)
Loading

0 comments on commit 45e745e

Please sign in to comment.