From 885ad8ac0a0e0ba511988cffc6c812fb74ca2221 Mon Sep 17 00:00:00 2001 From: James Robinson Date: Thu, 26 Sep 2024 14:16:22 +0100 Subject: [PATCH] :bug: Fix possible exception in update_user_entities --- guacamole_user_sync/models/__init__.py | 2 + guacamole_user_sync/models/guacamole.py | 10 ++++ .../postgresql/postgresql_backend.py | 13 +++-- .../postgresql/postgresql_client.py | 48 +++++++++++-------- pyproject.toml | 1 + tests/mocks.py | 6 +-- tests/test_postgresql.py | 2 +- 7 files changed, 52 insertions(+), 30 deletions(-) create mode 100644 guacamole_user_sync/models/guacamole.py diff --git a/guacamole_user_sync/models/__init__.py b/guacamole_user_sync/models/__init__.py index c20a684..9695664 100644 --- a/guacamole_user_sync/models/__init__.py +++ b/guacamole_user_sync/models/__init__.py @@ -1,10 +1,12 @@ from .exceptions import LDAPError, PostgreSQLError +from .guacamole import GuacamoleUserDetails from .ldap_objects import LDAPGroup, LDAPUser from .ldap_query import LDAPQuery LDAPSearchResult = list[tuple[int, tuple[str, dict[str, list[bytes]]]]] __all__ = [ + "GuacamoleUserDetails", "LDAPError", "LDAPGroup", "LDAPQuery", diff --git a/guacamole_user_sync/models/guacamole.py b/guacamole_user_sync/models/guacamole.py new file mode 100644 index 0000000..a1771ba --- /dev/null +++ b/guacamole_user_sync/models/guacamole.py @@ -0,0 +1,10 @@ +from dataclasses import dataclass + + +@dataclass +class GuacamoleUserDetails: + """A Guacamole user with required attributes only.""" + + entity_id: int + full_name: str + name: str diff --git a/guacamole_user_sync/postgresql/postgresql_backend.py b/guacamole_user_sync/postgresql/postgresql_backend.py index 542dd97..f939166 100644 --- a/guacamole_user_sync/postgresql/postgresql_backend.py +++ b/guacamole_user_sync/postgresql/postgresql_backend.py @@ -1,6 +1,6 @@ import logging from dataclasses import dataclass -from typing import Any +from typing import Any, TypeVar from sqlalchemy import URL, Engine, TextClause, create_engine from sqlalchemy.exc import SQLAlchemyError @@ -20,6 +20,9 @@ class PostgreSQLConnectionDetails: user_password: str +T = TypeVar("T", bound=DeclarativeBase) + + class PostgreSQLBackend: """Backend for connecting to a PostgreSQL database.""" @@ -52,13 +55,13 @@ def session(self) -> Session: return self._session return Session(self.engine) - def add_all(self, items: list[DeclarativeBase]) -> None: + def add_all(self, items: list[T]) -> None: with self.session() as session, session.begin(): session.add_all(items) def delete( self, - table: type[DeclarativeBase], + table: type[T], *filter_args: Any, # noqa: ANN401 ) -> None: with self.session() as session, session.begin(): @@ -78,9 +81,9 @@ def execute_commands(self, commands: list[TextClause]) -> None: def query( self, - table: type[DeclarativeBase], + table: type[T], **filter_kwargs: Any, # noqa: ANN401 - ) -> list[DeclarativeBase]: + ) -> list[T]: with self.session() as session, session.begin(): if filter_kwargs: result = session.query(table).filter_by(**filter_kwargs) diff --git a/guacamole_user_sync/postgresql/postgresql_client.py b/guacamole_user_sync/postgresql/postgresql_client.py index ce7295a..3086e7e 100644 --- a/guacamole_user_sync/postgresql/postgresql_client.py +++ b/guacamole_user_sync/postgresql/postgresql_client.py @@ -5,6 +5,7 @@ from sqlalchemy.exc import SQLAlchemyError from guacamole_user_sync.models import ( + GuacamoleUserDetails, LDAPGroup, LDAPUser, PostgreSQLError, @@ -60,27 +61,27 @@ def assign_users_to_groups( logger.debug("Working on group '%s'", group.name) # Get the user_group_id for each group (via looking up the entity_id) try: - group_entity_id = [ + group_entity_id = next( item.entity_id for item in self.backend.query( GuacamoleEntity, name=group.name, type=GuacamoleEntityType.USER_GROUP, ) - ][0] - user_group_id = [ + ) + user_group_id = next( item.user_group_id for item in self.backend.query( GuacamoleUserGroup, entity_id=group_entity_id, ) - ][0] + ) logger.debug( "-> entity_id: %s; user_group_id: %s", group_entity_id, user_group_id, ) - except IndexError: + except StopIteration: logger.debug( "Could not determine user_group_id for group '%s'.", group.name, @@ -94,20 +95,20 @@ def assign_users_to_groups( logger.debug("Could not find LDAP user with UID %s", user_uid) continue try: - user_entity_id = [ + user_entity_id = next( item.entity_id for item in self.backend.query( GuacamoleEntity, name=user.name, type=GuacamoleEntityType.USER, ) - ][0] + ) logger.debug( "... group member '%s' has entity_id '%s'", user, user_entity_id, ) - except IndexError: + except StopIteration: logger.debug( "Could not find entity ID for LDAP user '%s'", user_uid, @@ -284,28 +285,33 @@ def update_user_entities(self, users: list[LDAPUser]) -> None: "There are %s user entit(y|ies) currently registered", len(current_user_entity_ids), ) - new_user_tuples: list[tuple[int, LDAPUser]] = [ - (user.entity_id, [u for u in users if u.name == user.name][0]) - for user in self.backend.query( - GuacamoleEntity, - type=GuacamoleEntityType.USER, + user_entities = self.backend.query( + GuacamoleEntity, + type=GuacamoleEntityType.USER, + ) + new_users = [ + GuacamoleUserDetails( + entity_id=entity.entity_id, + full_name=user.display_name, + name=user.name, ) - if user.entity_id not in current_user_entity_ids + for user in users + for entity in user_entities + if entity.name == user.name + and entity.entity_id not in current_user_entity_ids ] - logger.debug( - "... %s user entit(y|ies) will be added", - len(current_user_entity_ids), - ) + logger.debug("... %s user entit(y|ies) will be added", len(new_users)) + self.backend.add_all( [ GuacamoleUser( - entity_id=user_tuple[0], - full_name=user_tuple[1].display_name, + entity_id=new_user.entity_id, + full_name=new_user.full_name, password_date=datetime.now(tz=UTC), password_hash=secrets.token_bytes(32), password_salt=secrets.token_bytes(32), ) - for user_tuple in new_user_tuples + for new_user in new_users ], ) # Clean up any unused entries diff --git a/pyproject.toml b/pyproject.toml index 11f5a07..cdfeecb 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -135,6 +135,7 @@ select = [ "Q", # flake8-quotes "RET", # flake8-return "RSE", # flake8-rse + "RUF", # Ruff-specific rules "S", # flake8-bandit "SIM", # flake8-simplify "SLF", # flake8-self diff --git a/tests/mocks.py b/tests/mocks.py index e6ba6be..02a7b24 100644 --- a/tests/mocks.py +++ b/tests/mocks.py @@ -37,8 +37,8 @@ def __init__( def startSearch( # noqa: N802 self, - *args: Any, # noqa: ANN401, ARG002 - **kwargs: Any, # noqa: ANN401, ARG002 + *args: Any, # noqa: ANN401 + **kwargs: Any, # noqa: ANN401 ) -> None: pass @@ -78,7 +78,7 @@ def add_all(self, items: list[GuacamoleBase]) -> None: self.contents[cls] = [] self.contents[cls] += items - def delete(self, *args: Any, **kwargs: Any) -> None: # noqa: ANN401, ARG001 + def delete(self, *args: Any, **kwargs: Any) -> None: # noqa: ANN401 pass def execute_commands(self, commands: list[TextClause]) -> None: diff --git a/tests/test_postgresql.py b/tests/test_postgresql.py index 989b2e5..3c8f6e4 100644 --- a/tests/test_postgresql.py +++ b/tests/test_postgresql.py @@ -494,7 +494,7 @@ def test_update( "... 3 user group entit(y|ies) will be added", "There are 3 valid user group entit(y|ies)", "There are 0 user entit(y|ies) currently registered", - "... 0 user entit(y|ies) will be added", + "... 2 user entit(y|ies) will be added", "There are 2 valid user entit(y|ies)", "Ensuring that 2 user(s) are correctly assigned among 3 group(s)", "Working on group 'defendants'",