Skip to content

Commit

Permalink
🐛 Fix possible exception in update_user_entities
Browse files Browse the repository at this point in the history
  • Loading branch information
jemrobinson committed Sep 26, 2024
1 parent 90e05de commit 885ad8a
Show file tree
Hide file tree
Showing 7 changed files with 52 additions and 30 deletions.
2 changes: 2 additions & 0 deletions guacamole_user_sync/models/__init__.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
from .exceptions import LDAPError, PostgreSQLError
from .guacamole import GuacamoleUserDetails
from .ldap_objects import LDAPGroup, LDAPUser
from .ldap_query import LDAPQuery

LDAPSearchResult = list[tuple[int, tuple[str, dict[str, list[bytes]]]]]

__all__ = [
"GuacamoleUserDetails",
"LDAPError",
"LDAPGroup",
"LDAPQuery",
Expand Down
10 changes: 10 additions & 0 deletions guacamole_user_sync/models/guacamole.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
from dataclasses import dataclass


@dataclass
class GuacamoleUserDetails:
"""A Guacamole user with required attributes only."""

entity_id: int
full_name: str
name: str
13 changes: 8 additions & 5 deletions guacamole_user_sync/postgresql/postgresql_backend.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import logging
from dataclasses import dataclass
from typing import Any
from typing import Any, TypeVar

from sqlalchemy import URL, Engine, TextClause, create_engine
from sqlalchemy.exc import SQLAlchemyError
Expand All @@ -20,6 +20,9 @@ class PostgreSQLConnectionDetails:
user_password: str


T = TypeVar("T", bound=DeclarativeBase)


class PostgreSQLBackend:
"""Backend for connecting to a PostgreSQL database."""

Expand Down Expand Up @@ -52,13 +55,13 @@ def session(self) -> Session:
return self._session
return Session(self.engine)

def add_all(self, items: list[DeclarativeBase]) -> None:
def add_all(self, items: list[T]) -> None:
with self.session() as session, session.begin():
session.add_all(items)

def delete(
self,
table: type[DeclarativeBase],
table: type[T],
*filter_args: Any, # noqa: ANN401
) -> None:
with self.session() as session, session.begin():
Expand All @@ -78,9 +81,9 @@ def execute_commands(self, commands: list[TextClause]) -> None:

def query(
self,
table: type[DeclarativeBase],
table: type[T],
**filter_kwargs: Any, # noqa: ANN401
) -> list[DeclarativeBase]:
) -> list[T]:
with self.session() as session, session.begin():
if filter_kwargs:
result = session.query(table).filter_by(**filter_kwargs)
Expand Down
48 changes: 27 additions & 21 deletions guacamole_user_sync/postgresql/postgresql_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from sqlalchemy.exc import SQLAlchemyError

from guacamole_user_sync.models import (
GuacamoleUserDetails,
LDAPGroup,
LDAPUser,
PostgreSQLError,
Expand Down Expand Up @@ -60,27 +61,27 @@ def assign_users_to_groups(
logger.debug("Working on group '%s'", group.name)
# Get the user_group_id for each group (via looking up the entity_id)
try:
group_entity_id = [
group_entity_id = next(
item.entity_id
for item in self.backend.query(
GuacamoleEntity,
name=group.name,
type=GuacamoleEntityType.USER_GROUP,
)
][0]
user_group_id = [
)
user_group_id = next(
item.user_group_id
for item in self.backend.query(
GuacamoleUserGroup,
entity_id=group_entity_id,
)
][0]
)
logger.debug(
"-> entity_id: %s; user_group_id: %s",
group_entity_id,
user_group_id,
)
except IndexError:
except StopIteration:
logger.debug(
"Could not determine user_group_id for group '%s'.",
group.name,
Expand All @@ -94,20 +95,20 @@ def assign_users_to_groups(
logger.debug("Could not find LDAP user with UID %s", user_uid)
continue
try:
user_entity_id = [
user_entity_id = next(
item.entity_id
for item in self.backend.query(
GuacamoleEntity,
name=user.name,
type=GuacamoleEntityType.USER,
)
][0]
)
logger.debug(
"... group member '%s' has entity_id '%s'",
user,
user_entity_id,
)
except IndexError:
except StopIteration:
logger.debug(
"Could not find entity ID for LDAP user '%s'",
user_uid,
Expand Down Expand Up @@ -284,28 +285,33 @@ def update_user_entities(self, users: list[LDAPUser]) -> None:
"There are %s user entit(y|ies) currently registered",
len(current_user_entity_ids),
)
new_user_tuples: list[tuple[int, LDAPUser]] = [
(user.entity_id, [u for u in users if u.name == user.name][0])
for user in self.backend.query(
GuacamoleEntity,
type=GuacamoleEntityType.USER,
user_entities = self.backend.query(
GuacamoleEntity,
type=GuacamoleEntityType.USER,
)
new_users = [
GuacamoleUserDetails(
entity_id=entity.entity_id,
full_name=user.display_name,
name=user.name,
)
if user.entity_id not in current_user_entity_ids
for user in users
for entity in user_entities
if entity.name == user.name
and entity.entity_id not in current_user_entity_ids
]
logger.debug(
"... %s user entit(y|ies) will be added",
len(current_user_entity_ids),
)
logger.debug("... %s user entit(y|ies) will be added", len(new_users))

self.backend.add_all(
[
GuacamoleUser(
entity_id=user_tuple[0],
full_name=user_tuple[1].display_name,
entity_id=new_user.entity_id,
full_name=new_user.full_name,
password_date=datetime.now(tz=UTC),
password_hash=secrets.token_bytes(32),
password_salt=secrets.token_bytes(32),
)
for user_tuple in new_user_tuples
for new_user in new_users
],
)
# Clean up any unused entries
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,7 @@ select = [
"Q", # flake8-quotes
"RET", # flake8-return
"RSE", # flake8-rse
"RUF", # Ruff-specific rules
"S", # flake8-bandit
"SIM", # flake8-simplify
"SLF", # flake8-self
Expand Down
6 changes: 3 additions & 3 deletions tests/mocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,8 @@ def __init__(

def startSearch( # noqa: N802
self,
*args: Any, # noqa: ANN401, ARG002
**kwargs: Any, # noqa: ANN401, ARG002
*args: Any, # noqa: ANN401
**kwargs: Any, # noqa: ANN401
) -> None:
pass

Expand Down Expand Up @@ -78,7 +78,7 @@ def add_all(self, items: list[GuacamoleBase]) -> None:
self.contents[cls] = []
self.contents[cls] += items

def delete(self, *args: Any, **kwargs: Any) -> None: # noqa: ANN401, ARG001
def delete(self, *args: Any, **kwargs: Any) -> None: # noqa: ANN401
pass

def execute_commands(self, commands: list[TextClause]) -> None:
Expand Down
2 changes: 1 addition & 1 deletion tests/test_postgresql.py
Original file line number Diff line number Diff line change
Expand Up @@ -494,7 +494,7 @@ def test_update(
"... 3 user group entit(y|ies) will be added",
"There are 3 valid user group entit(y|ies)",
"There are 0 user entit(y|ies) currently registered",
"... 0 user entit(y|ies) will be added",
"... 2 user entit(y|ies) will be added",
"There are 2 valid user entit(y|ies)",
"Ensuring that 2 user(s) are correctly assigned among 3 group(s)",
"Working on group 'defendants'",
Expand Down

0 comments on commit 885ad8a

Please sign in to comment.