From 08209bdaa71cd29f588be35bfe2a67f281d2606a Mon Sep 17 00:00:00 2001 From: Felix Gustavsson Date: Sat, 4 May 2024 18:25:47 +0200 Subject: [PATCH 01/15] Add support for keycloak and option to disable group-of-groups --- README.md | 25 ++++ apricot/apricot_server.py | 10 +- apricot/ldap/oauth_ldap_server_factory.py | 4 +- apricot/ldap/oauth_ldap_tree.py | 5 +- apricot/models/ldap_attribute_adaptor.py | 1 + apricot/models/ldap_inetorgperson.py | 5 +- apricot/oauth/__init__.py | 6 +- apricot/oauth/enums.py | 1 + apricot/oauth/keycloak_client.py | 166 ++++++++++++++++++++++ apricot/oauth/oauth_client.py | 27 +++- apricot/oauth/oauth_data_adaptor.py | 32 +++-- docker/entrypoint.sh | 12 ++ run.py | 8 ++ 13 files changed, 275 insertions(+), 27 deletions(-) create mode 100644 apricot/oauth/keycloak_client.py diff --git a/README.md b/README.md index f00251b..51c8e96 100644 --- a/README.md +++ b/README.md @@ -151,3 +151,28 @@ Do this as follows: - `Microsoft Graph` > `GroupMember.Read.All` (application) - `Microsoft Graph` > `User.Read.All` (delegated) - Select this and click the `Grant admin consent` button (otherwise manual consent is needed from each user) + + +### Keycloak + +You will need to use the following command line arguments: + +```bash +--backend Keycloak --keycloak-base-url "/" --keycloak-realm "" +``` + +You will need to register an application to interact with `Keycloak`. +Do this as follows: + +- Create a new `Client` in your `Keycloak` instance. + - Set the name to whatever you choose (e.g. `apricot`) + - Enable `Client authentication` + - Enable the following authentication flows and disable the rest: + - Direct access grants + - Service account roles +- Under `Credentials` copy `client secret` +- Under `Service account roles`: + - Ensure that the following role are assigned + - `realm-management` > `view-users` + - `realm-management` > `manage-users` + - `realm-management` > `query-groups` diff --git a/apricot/apricot_server.py b/apricot/apricot_server.py index 7163ffb..1bdae49 100644 --- a/apricot/apricot_server.py +++ b/apricot/apricot_server.py @@ -1,3 +1,4 @@ +import inspect import sys from typing import Any, cast @@ -19,6 +20,7 @@ def __init__( client_secret: str, domain: str, port: int, + enable_group_of_groups: bool, *, debug: bool = False, redis_host: str | None = None, @@ -45,12 +47,14 @@ def __init__( try: if self.debug: log.msg(f"Creating an OAuthClient for {backend}.") - oauth_client = OAuthClientMap[backend]( + oauth_backend = OAuthClientMap[backend] + oauth_backend_args = inspect.getfullargspec(oauth_backend.__init__).args + oauth_client = oauth_backend( client_id=client_id, client_secret=client_secret, debug=debug, uid_cache=uid_cache, - **kwargs, + **{k: v for k, v in kwargs.items() if k in oauth_backend_args}, ) except Exception as exc: msg = f"Could not construct an OAuth client for the '{backend}' backend.\n{exc!s}" @@ -59,7 +63,7 @@ def __init__( # Create an LDAPServerFactory if self.debug: log.msg("Creating an LDAPServerFactory.") - factory = OAuthLDAPServerFactory(domain, oauth_client) + factory = OAuthLDAPServerFactory(domain, oauth_client, enable_group_of_groups) # Attach a listening endpoint if self.debug: diff --git a/apricot/ldap/oauth_ldap_server_factory.py b/apricot/ldap/oauth_ldap_server_factory.py index 2890b35..445a33c 100644 --- a/apricot/ldap/oauth_ldap_server_factory.py +++ b/apricot/ldap/oauth_ldap_server_factory.py @@ -8,14 +8,14 @@ class OAuthLDAPServerFactory(ServerFactory): - def __init__(self, domain: str, oauth_client: OAuthClient): + def __init__(self, domain: str, oauth_client: OAuthClient, enable_group_of_groups: bool): """ Initialise an LDAPServerFactory @param oauth_client: An OAuth client used to construct the LDAP tree """ # Create an LDAP lookup tree - self.adaptor = OAuthLDAPTree(domain, oauth_client) + self.adaptor = OAuthLDAPTree(domain, oauth_client, enable_group_of_groups) def __repr__(self) -> str: return f"{self.__class__.__name__} using adaptor {self.adaptor}" diff --git a/apricot/ldap/oauth_ldap_tree.py b/apricot/ldap/oauth_ldap_tree.py index 136ce31..49a5639 100644 --- a/apricot/ldap/oauth_ldap_tree.py +++ b/apricot/ldap/oauth_ldap_tree.py @@ -14,7 +14,7 @@ class OAuthLDAPTree: def __init__( - self, domain: str, oauth_client: OAuthClient, refresh_interval: int = 60 + self, domain: str, oauth_client: OAuthClient, enable_group_of_groups: bool, refresh_interval: int = 60 ) -> None: """ Initialise an OAuthLDAPTree @@ -29,6 +29,7 @@ def __init__( self.oauth_client = oauth_client self.refresh_interval = refresh_interval self.root_: OAuthLDAPEntry | None = None + self.enable_group_of_groups = enable_group_of_groups @property def dn(self) -> DistinguishedName: @@ -47,7 +48,7 @@ def root(self) -> OAuthLDAPEntry: ): # Update users and groups from the OAuth server log.msg("Retrieving OAuth data.") - oauth_adaptor = OAuthDataAdaptor(self.domain, self.oauth_client) + oauth_adaptor = OAuthDataAdaptor(self.domain, self.oauth_client, self.enable_group_of_groups) # Create a root node for the tree log.msg("Rebuilding LDAP tree.") diff --git a/apricot/models/ldap_attribute_adaptor.py b/apricot/models/ldap_attribute_adaptor.py index dfd3bd1..40f986d 100644 --- a/apricot/models/ldap_attribute_adaptor.py +++ b/apricot/models/ldap_attribute_adaptor.py @@ -8,6 +8,7 @@ def __init__(self, attributes: dict[Any, Any]) -> None: self.attributes = { str(k): list(map(str, v)) if isinstance(v, list) else [str(v)] for k, v in attributes.items() + if v is not None } @property diff --git a/apricot/models/ldap_inetorgperson.py b/apricot/models/ldap_inetorgperson.py index fe86b8e..8e0da25 100644 --- a/apricot/models/ldap_inetorgperson.py +++ b/apricot/models/ldap_inetorgperson.py @@ -1,3 +1,5 @@ +from typing import Optional + from .ldap_organizational_person import LDAPOrganizationalPerson @@ -12,9 +14,10 @@ class LDAPInetOrgPerson(LDAPOrganizationalPerson): """ cn: str - displayName: str # noqa: N815 + displayName: Optional[str] # noqa: N815 givenName: str # noqa: N815 sn: str + mail: Optional[str] = None def names(self) -> list[str]: return [*super().names(), "inetOrgPerson"] diff --git a/apricot/oauth/__init__.py b/apricot/oauth/__init__.py index 0cd8aa5..c5d6268 100644 --- a/apricot/oauth/__init__.py +++ b/apricot/oauth/__init__.py @@ -1,11 +1,15 @@ from apricot.types import LDAPAttributeDict, LDAPControlTuple from .enums import OAuthBackend +from .keycloak_client import KeycloakClient from .microsoft_entra_client import MicrosoftEntraClient from .oauth_client import OAuthClient from .oauth_data_adaptor import OAuthDataAdaptor -OAuthClientMap = {OAuthBackend.MICROSOFT_ENTRA: MicrosoftEntraClient} +OAuthClientMap = { + OAuthBackend.MICROSOFT_ENTRA: MicrosoftEntraClient, + OAuthBackend.KEYCLOAK: KeycloakClient, +} __all__ = [ "LDAPAttributeDict", diff --git a/apricot/oauth/enums.py b/apricot/oauth/enums.py index 8675218..d9c356d 100644 --- a/apricot/oauth/enums.py +++ b/apricot/oauth/enums.py @@ -5,3 +5,4 @@ class OAuthBackend(str, Enum): """Available OAuth backends.""" MICROSOFT_ENTRA = "MicrosoftEntra" + KEYCLOAK = "Keycloak" diff --git a/apricot/oauth/keycloak_client.py b/apricot/oauth/keycloak_client.py new file mode 100644 index 0000000..b7229da --- /dev/null +++ b/apricot/oauth/keycloak_client.py @@ -0,0 +1,166 @@ +from typing import Any, cast + +from apricot.types import JSONDict + +from .oauth_client import OAuthClient + + +def get_single_value_attribute(obj: JSONDict, key: str, default=None) -> Any: + for part in key.split("."): + obj = obj.get(part) + if obj is None: + return default + if isinstance(obj, list): + try: + return next(iter(obj)) + except StopIteration: + pass + else: + return obj + return default + + +class KeycloakClient(OAuthClient): + """OAuth client for the Keycloak backend.""" + + max_rows = 100 + + def __init__( + self, + keycloak_base_url: str, + keycloak_realm: str, + **kwargs: Any, + ): + self.base_url = keycloak_base_url + self.realm = keycloak_realm + + redirect_uri = "urn:ietf:wg:oauth:2.0:oob" # this is the "no redirect" URL + scopes = [] # this is the default scope + token_url = ( + f"{self.base_url}/realms/{self.realm}/protocol/openid-connect/token" + ) + + super().__init__( + redirect_uri=redirect_uri, scopes=scopes, token_url=token_url, **kwargs, + ) + + def extract_token(self, json_response: JSONDict) -> str: + return str(json_response["access_token"]) + + def groups(self) -> list[JSONDict]: + output = [] + try: + group_data = [] + while data := self.query( + f"{self.base_url}/admin/realms/{self.realm}/groups?first={len(group_data)}&max={self.max_rows}&briefRepresentation=false" + ): + group_data.extend(data) + if len(data) != self.max_rows: + break + + group_data = sorted(group_data, key=lambda g: int(get_single_value_attribute(g, "attributes.gid", default="9999999999"), 10)) + + next_gid = max( + *( + int(get_single_value_attribute(g, "attributes.gid", default="-1"), 10)+1 + for g in group_data + ), + 3000 + ) + + for group_dict in cast( + list[JSONDict], + group_data, + ): + group_gid = get_single_value_attribute(group_dict, "attributes.gid", default=None) + if group_gid: + group_gid = int(group_gid, 10) + if not group_gid: + group_gid = next_gid + next_gid += 1 + group_dict["attributes"] = group_dict.get("attributes", {}) + group_dict["attributes"]["gid"] = [str(group_gid)] + self.request( + f"{self.base_url}/admin/realms/{self.realm}/groups/{group_dict['id']}", + method="PUT", + json=group_dict + ) + attributes: JSONDict = {} + attributes["cn"] = group_dict.get("name", None) + attributes["description"] = group_dict.get("id", None) + attributes["gidNumber"] = group_gid + attributes["oauth_id"] = group_dict.get("id", None) + # Add membership attributes + members = self.query( + f"{self.base_url}/admin/realms/{self.realm}/groups/{group_dict['id']}/members" + ) + attributes["memberUid"] = [ + user["username"] + for user in cast(list[JSONDict], members) + ] + output.append(attributes) + except KeyError: + pass + return output + + def users(self) -> list[JSONDict]: + output = [] + try: + user_data = [] + while data := self.query( + f"{self.base_url}/admin/realms/{self.realm}/users?first={len(user_data)}&max={self.max_rows}&briefRepresentation=false" + ): + user_data.extend(data) + if len(data) != self.max_rows: + break + + user_data = sorted(user_data, key=lambda u: int(get_single_value_attribute(u, "attributes.uid", default="9999999999"), 10)) + + next_uid = max( + *( + int(get_single_value_attribute(g, "attributes.uid", default="-1"), 10)+1 + for g in user_data + ), + 3000 + ) + + for user_dict in cast( + list[JSONDict], + sorted(user_data, key=lambda user: user["createdTimestamp"]), + ): + user_uid = get_single_value_attribute(user_dict, "attributes.uid", default=None) + if user_uid: + user_uid = int(user_uid, 10) + if not user_uid: + user_uid = next_uid + next_uid += 1 + + user_dict["attributes"] = user_dict.get("attributes", {}) + user_dict["attributes"]["uid"] = [str(user_uid)] + self.request( + f"{self.base_url}/admin/realms/{self.realm}/users/{user_dict['id']}", + method="PUT", + json=user_dict + ) + # Get user attributes + first_name = user_dict.get("firstName", None) + last_name = user_dict.get("lastName", None) + full_name = " ".join(filter(lambda x: x, [first_name, last_name])) or None + username = user_dict.get("username") + attributes: JSONDict = {} + attributes["cn"] = username + attributes["uid"] = username + attributes["oauth_username"] = username + attributes["displayName"] = full_name + attributes["mail"] = user_dict.get("email") + attributes["description"] = "" + attributes["gidNumber"] = user_uid + attributes["givenName"] = first_name if first_name else "" + attributes["homeDirectory"] = f"/home/{username}" if username else None + attributes["oauth_id"] = user_dict.get("id", None) + attributes["sn"] = last_name if last_name else "" + attributes["uidNumber"] = user_uid + output.append(attributes) + except KeyError: + pass + return output diff --git a/apricot/oauth/oauth_client.py b/apricot/oauth/oauth_client.py index 857b553..a4e1753 100644 --- a/apricot/oauth/oauth_client.py +++ b/apricot/oauth/oauth_client.py @@ -1,5 +1,6 @@ import os from abc import ABC, abstractmethod +from http import HTTPStatus from typing import Any import requests @@ -116,9 +117,7 @@ def query(self, url: str) -> dict[str, Any]: def query_(url: str) -> requests.Response: return self.session_application.get( # type: ignore[no-any-return] url=url, - headers={"Authorization": f"Bearer {self.bearer_token}"}, - client_id=self.session_application._client.client_id, - client_secret=self.client_secret, + headers={"Authorization": f"Bearer {self.bearer_token}"} ) try: @@ -130,6 +129,28 @@ def query_(url: str) -> requests.Response: result = query_(url) return result.json() # type: ignore + def request(self, *args, method="GET", **kwargs) -> dict[str, Any]: + """ + Make a query against the OAuth backend + """ + + def query_(*args, **kwargs) -> requests.Response: + return self.session_application.request( # type: ignore[no-any-return] + method, + *args, **kwargs, + headers={"Authorization": f"Bearer {self.bearer_token}"} + ) + + try: + result = query_(*args, **kwargs) + result.raise_for_status() + except (TokenExpiredError, requests.exceptions.HTTPError): + log.msg("Authentication token has expired.") + self.bearer_token_ = None + result = query_( *args, **kwargs) + if result.status_code != HTTPStatus.NO_CONTENT: + return result.json() # type: ignore + def verify(self, username: str, password: str) -> bool: """ Verify username and password by attempting to authenticate against the OAuth backend. diff --git a/apricot/oauth/oauth_data_adaptor.py b/apricot/oauth/oauth_data_adaptor.py index 701e55a..8a7048a 100644 --- a/apricot/oauth/oauth_data_adaptor.py +++ b/apricot/oauth/oauth_data_adaptor.py @@ -21,10 +21,11 @@ class OAuthDataAdaptor: """Adaptor for converting raw user and group data into LDAP format.""" - def __init__(self, domain: str, oauth_client: OAuthClient): + def __init__(self, domain: str, oauth_client: OAuthClient, enable_group_of_groups: bool): self.debug = oauth_client.debug self.oauth_client = oauth_client self.root_dn = "DC=" + domain.replace(".", ",DC=") + self.enable_group_of_groups = enable_group_of_groups # Retrieve and validate user and group information annotated_groups, annotated_users = self._retrieve_entries() @@ -105,20 +106,21 @@ def _retrieve_entries( # Add one group of groups for each existing group. # Its members are the primary user groups for each original group member. groups_of_groups = [] - for group in oauth_groups: - group_dict = {} - group_dict["cn"] = f"Primary user groups for {group['cn']}" - group_dict["description"] = ( - f"Primary user groups for members of '{group['cn']}'" - ) - # Replace each member user with a member group - group_dict["member"] = [ - str(member).replace("OU=users", "OU=groups") - for member in group["member"] - ] - # Groups do not have UIDs so memberUid must be empty - group_dict["memberUid"] = [] - groups_of_groups.append(group_dict) + if self.enable_group_of_groups: + for group in oauth_groups: + group_dict = {} + group_dict["cn"] = f"Primary user groups for {group['cn']}" + group_dict["description"] = ( + f"Primary user groups for members of '{group['cn']}'" + ) + # Replace each member user with a member group + group_dict["member"] = [ + str(member).replace("OU=users", "OU=groups") + for member in group["member"] + ] + # Groups do not have UIDs so memberUid must be empty + group_dict["memberUid"] = [] + groups_of_groups.append(group_dict) # Ensure memberOf is set correctly for users for child_dict in oauth_users: diff --git a/docker/entrypoint.sh b/docker/entrypoint.sh index 8a21379..c797ba8 100644 --- a/docker/entrypoint.sh +++ b/docker/entrypoint.sh @@ -37,6 +37,10 @@ if [ -n "${DEBUG}" ]; then EXTRA_OPTS="${EXTRA_OPTS} --debug" fi +if [ -n "${DISABLE_GROUP_OF_GROUPS}" ]; then + EXTRA_OPTS="${EXTRA_OPTS} --disable-group-of-groups" +fi + if [ -n "${ENTRA_TENANT_ID}" ]; then EXTRA_OPTS="${EXTRA_OPTS} --entra-tenant-id $ENTRA_TENANT_ID" fi @@ -49,6 +53,14 @@ if [ -n "${REDIS_HOST}" ]; then EXTRA_OPTS="${EXTRA_OPTS} --redis-host $REDIS_HOST --redis-port $REDIS_PORT" fi +if [ -n "${KEYCLOAK_BASE_URL}" ]; then + if [ -z "${KEYCLOAK_REALM}" ]; then + echo "$(date +'%Y-%m-%d %H:%M:%S+0000') [-] KEYCLOAK_REALM environment variable is not set" + exit 1 + fi + EXTRA_OPTS="${EXTRA_OPTS} --keycloak-base-url $KEYCLOAK_BASE_URL --keycloak-realm $KEYCLOAK_REALM" +fi + # Run the server hatch run python run.py \ --backend "${BACKEND}" \ diff --git a/run.py b/run.py index 5ac4230..6a1b5b8 100644 --- a/run.py +++ b/run.py @@ -16,10 +16,18 @@ parser.add_argument("-i", "--client-id", type=str, help="OAuth client ID.") parser.add_argument("-p", "--port", type=int, default=1389, help="Port to run on.") parser.add_argument("-s", "--client-secret", type=str, help="OAuth client secret.") + parser.add_argument("--disable-group-of-groups", action="store_false", + dest="enable_group_of_groups", default=True, + help="Disable creation of group-of-groups.") parser.add_argument("--debug", action="store_true", help="Enable debug logging.") # Options for Microsoft Entra backend entra_group = parser.add_argument_group("Microsoft Entra") entra_group.add_argument("-t", "--entra-tenant-id", type=str, help="Microsoft Entra tenant ID.", required=False) + + # Options for Keycloak backend + keycloak_group = parser.add_argument_group("Keycloak") + keycloak_group.add_argument("--keycloak-base-url", type=str, help="Keycloak base URL.", required=False) + keycloak_group.add_argument("--keycloak-realm", type=str, help="Keycloak Realm.", required=False) # Options for Redis cache redis_group = parser.add_argument_group("Redis") redis_group.add_argument("--redis-host", type=str, help="Host for Redis server.") From 221955225496bdb697622da499db0e7475867e41 Mon Sep 17 00:00:00 2001 From: James Robinson Date: Thu, 23 May 2024 15:46:21 +0100 Subject: [PATCH 02/15] :rotating_light: Run linting fixes --- apricot/apricot_server.py | 6 ++- apricot/ldap/oauth_ldap_server_factory.py | 4 +- apricot/ldap/oauth_ldap_tree.py | 13 ++++- apricot/oauth/keycloak_client.py | 64 +++++++++++++++++------ apricot/oauth/oauth_client.py | 12 ++--- apricot/oauth/oauth_data_adaptor.py | 4 +- 6 files changed, 74 insertions(+), 29 deletions(-) diff --git a/apricot/apricot_server.py b/apricot/apricot_server.py index 1bdae49..920b728 100644 --- a/apricot/apricot_server.py +++ b/apricot/apricot_server.py @@ -20,9 +20,9 @@ def __init__( client_secret: str, domain: str, port: int, - enable_group_of_groups: bool, *, debug: bool = False, + enable_group_of_groups: bool, redis_host: str | None = None, redis_port: int | None = None, **kwargs: Any, @@ -63,7 +63,9 @@ def __init__( # Create an LDAPServerFactory if self.debug: log.msg("Creating an LDAPServerFactory.") - factory = OAuthLDAPServerFactory(domain, oauth_client, enable_group_of_groups) + factory = OAuthLDAPServerFactory( + domain, oauth_client, enable_group_of_groups=enable_group_of_groups + ) # Attach a listening endpoint if self.debug: diff --git a/apricot/ldap/oauth_ldap_server_factory.py b/apricot/ldap/oauth_ldap_server_factory.py index 445a33c..ea2da9d 100644 --- a/apricot/ldap/oauth_ldap_server_factory.py +++ b/apricot/ldap/oauth_ldap_server_factory.py @@ -8,7 +8,9 @@ class OAuthLDAPServerFactory(ServerFactory): - def __init__(self, domain: str, oauth_client: OAuthClient, enable_group_of_groups: bool): + def __init__( + self, domain: str, oauth_client: OAuthClient, *, enable_group_of_groups: bool + ): """ Initialise an LDAPServerFactory diff --git a/apricot/ldap/oauth_ldap_tree.py b/apricot/ldap/oauth_ldap_tree.py index 49a5639..90664c7 100644 --- a/apricot/ldap/oauth_ldap_tree.py +++ b/apricot/ldap/oauth_ldap_tree.py @@ -14,7 +14,12 @@ class OAuthLDAPTree: def __init__( - self, domain: str, oauth_client: OAuthClient, enable_group_of_groups: bool, refresh_interval: int = 60 + self, + domain: str, + oauth_client: OAuthClient, + *, + enable_group_of_groups: bool, + refresh_interval: int = 60, ) -> None: """ Initialise an OAuthLDAPTree @@ -48,7 +53,11 @@ def root(self) -> OAuthLDAPEntry: ): # Update users and groups from the OAuth server log.msg("Retrieving OAuth data.") - oauth_adaptor = OAuthDataAdaptor(self.domain, self.oauth_client, self.enable_group_of_groups) + oauth_adaptor = OAuthDataAdaptor( + self.domain, + self.oauth_client, + enable_group_of_groups=self.enable_group_of_groups, + ) # Create a root node for the tree log.msg("Rebuilding LDAP tree.") diff --git a/apricot/oauth/keycloak_client.py b/apricot/oauth/keycloak_client.py index b7229da..6667778 100644 --- a/apricot/oauth/keycloak_client.py +++ b/apricot/oauth/keycloak_client.py @@ -36,12 +36,13 @@ def __init__( redirect_uri = "urn:ietf:wg:oauth:2.0:oob" # this is the "no redirect" URL scopes = [] # this is the default scope - token_url = ( - f"{self.base_url}/realms/{self.realm}/protocol/openid-connect/token" - ) + token_url = f"{self.base_url}/realms/{self.realm}/protocol/openid-connect/token" super().__init__( - redirect_uri=redirect_uri, scopes=scopes, token_url=token_url, **kwargs, + redirect_uri=redirect_uri, + scopes=scopes, + token_url=token_url, + **kwargs, ) def extract_token(self, json_response: JSONDict) -> str: @@ -58,21 +59,35 @@ def groups(self) -> list[JSONDict]: if len(data) != self.max_rows: break - group_data = sorted(group_data, key=lambda g: int(get_single_value_attribute(g, "attributes.gid", default="9999999999"), 10)) + group_data = sorted( + group_data, + key=lambda g: int( + get_single_value_attribute( + g, "attributes.gid", default="9999999999" + ), + 10, + ), + ) next_gid = max( *( - int(get_single_value_attribute(g, "attributes.gid", default="-1"), 10)+1 + int( + get_single_value_attribute(g, "attributes.gid", default="-1"), + 10, + ) + + 1 for g in group_data ), - 3000 + 3000, ) for group_dict in cast( list[JSONDict], group_data, ): - group_gid = get_single_value_attribute(group_dict, "attributes.gid", default=None) + group_gid = get_single_value_attribute( + group_dict, "attributes.gid", default=None + ) if group_gid: group_gid = int(group_gid, 10) if not group_gid: @@ -83,7 +98,7 @@ def groups(self) -> list[JSONDict]: self.request( f"{self.base_url}/admin/realms/{self.realm}/groups/{group_dict['id']}", method="PUT", - json=group_dict + json=group_dict, ) attributes: JSONDict = {} attributes["cn"] = group_dict.get("name", None) @@ -95,8 +110,7 @@ def groups(self) -> list[JSONDict]: f"{self.base_url}/admin/realms/{self.realm}/groups/{group_dict['id']}/members" ) attributes["memberUid"] = [ - user["username"] - for user in cast(list[JSONDict], members) + user["username"] for user in cast(list[JSONDict], members) ] output.append(attributes) except KeyError: @@ -114,21 +128,35 @@ def users(self) -> list[JSONDict]: if len(data) != self.max_rows: break - user_data = sorted(user_data, key=lambda u: int(get_single_value_attribute(u, "attributes.uid", default="9999999999"), 10)) + user_data = sorted( + user_data, + key=lambda u: int( + get_single_value_attribute( + u, "attributes.uid", default="9999999999" + ), + 10, + ), + ) next_uid = max( *( - int(get_single_value_attribute(g, "attributes.uid", default="-1"), 10)+1 + int( + get_single_value_attribute(g, "attributes.uid", default="-1"), + 10, + ) + + 1 for g in user_data ), - 3000 + 3000, ) for user_dict in cast( list[JSONDict], sorted(user_data, key=lambda user: user["createdTimestamp"]), ): - user_uid = get_single_value_attribute(user_dict, "attributes.uid", default=None) + user_uid = get_single_value_attribute( + user_dict, "attributes.uid", default=None + ) if user_uid: user_uid = int(user_uid, 10) if not user_uid: @@ -140,12 +168,14 @@ def users(self) -> list[JSONDict]: self.request( f"{self.base_url}/admin/realms/{self.realm}/users/{user_dict['id']}", method="PUT", - json=user_dict + json=user_dict, ) # Get user attributes first_name = user_dict.get("firstName", None) last_name = user_dict.get("lastName", None) - full_name = " ".join(filter(lambda x: x, [first_name, last_name])) or None + full_name = ( + " ".join(filter(lambda x: x, [first_name, last_name])) or None + ) username = user_dict.get("username") attributes: JSONDict = {} attributes["cn"] = username diff --git a/apricot/oauth/oauth_client.py b/apricot/oauth/oauth_client.py index a4e1753..3311443 100644 --- a/apricot/oauth/oauth_client.py +++ b/apricot/oauth/oauth_client.py @@ -116,8 +116,7 @@ def query(self, url: str) -> dict[str, Any]: def query_(url: str) -> requests.Response: return self.session_application.get( # type: ignore[no-any-return] - url=url, - headers={"Authorization": f"Bearer {self.bearer_token}"} + url=url, headers={"Authorization": f"Bearer {self.bearer_token}"} ) try: @@ -135,10 +134,11 @@ def request(self, *args, method="GET", **kwargs) -> dict[str, Any]: """ def query_(*args, **kwargs) -> requests.Response: - return self.session_application.request( # type: ignore[no-any-return] + return self.session_application.request( # type: ignore[no-any-return] method, - *args, **kwargs, - headers={"Authorization": f"Bearer {self.bearer_token}"} + *args, + **kwargs, + headers={"Authorization": f"Bearer {self.bearer_token}"}, ) try: @@ -147,7 +147,7 @@ def query_(*args, **kwargs) -> requests.Response: except (TokenExpiredError, requests.exceptions.HTTPError): log.msg("Authentication token has expired.") self.bearer_token_ = None - result = query_( *args, **kwargs) + result = query_(*args, **kwargs) if result.status_code != HTTPStatus.NO_CONTENT: return result.json() # type: ignore diff --git a/apricot/oauth/oauth_data_adaptor.py b/apricot/oauth/oauth_data_adaptor.py index 8a7048a..cfd7a32 100644 --- a/apricot/oauth/oauth_data_adaptor.py +++ b/apricot/oauth/oauth_data_adaptor.py @@ -21,7 +21,9 @@ class OAuthDataAdaptor: """Adaptor for converting raw user and group data into LDAP format.""" - def __init__(self, domain: str, oauth_client: OAuthClient, enable_group_of_groups: bool): + def __init__( + self, domain: str, oauth_client: OAuthClient, *, enable_group_of_groups: bool + ): self.debug = oauth_client.debug self.oauth_client = oauth_client self.root_dn = "DC=" + domain.replace(".", ",DC=") From 90bcb0f01ea537dd3a0324a539d91f63f54800cc Mon Sep 17 00:00:00 2001 From: James Robinson Date: Tue, 21 May 2024 17:21:01 +0100 Subject: [PATCH 03/15] :rotating_light: Fix linting errors in README --- README.md | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/README.md b/README.md index 51c8e96..a28bd0f 100644 --- a/README.md +++ b/README.md @@ -11,7 +11,7 @@ Start the `Apricot` server on port 1389 by running: python run.py --client-id "" --client-secret "" --backend "" --port 1389 --domain "" --redis-host "" ``` -Alternatively, you can run in Docker by editing `docker/docker-compose.yaml` and running: +If you prefer to use Docker, you can edit `docker/docker-compose.yaml` and run: ```bash docker compose up @@ -146,12 +146,11 @@ Do this as follows: - Set the expiry time to whatever is relevant for your use-case - You **must** record the value of this secret at **creation time**, as it will not be visible later. - Under `API permissions`: - - Ensure that the following permissions are enabled + - Enable the following permissions: - `Microsoft Graph` > `User.Read.All` (application) - `Microsoft Graph` > `GroupMember.Read.All` (application) - `Microsoft Graph` > `User.Read.All` (delegated) - - Select this and click the `Grant admin consent` button (otherwise manual consent is needed from each user) - + - Select this and click the `Grant admin consent` button (otherwise each user will need to manually consent) ### Keycloak @@ -168,11 +167,13 @@ Do this as follows: - Set the name to whatever you choose (e.g. `apricot`) - Enable `Client authentication` - Enable the following authentication flows and disable the rest: - - Direct access grants - - Service account roles + - Direct access grants + - Service account roles - Under `Credentials` copy `client secret` - Under `Service account roles`: - - Ensure that the following role are assigned + - Click on `Assign role` then `Filter by clients` + - Assign the following roles: - `realm-management` > `view-users` - `realm-management` > `manage-users` - `realm-management` > `query-groups` + - `realm-management` > `query-users` From 6a056aa2aa8b857109bd5d32f8131f104bdfae02 Mon Sep 17 00:00:00 2001 From: James Robinson Date: Thu, 23 May 2024 15:40:16 +0100 Subject: [PATCH 04/15] :memo: Add explanation of how/why to use the groups-of-groups feature --- README.md | 29 +++++++++++++++++++++++++++++ 1 file changed, 29 insertions(+) diff --git a/README.md b/README.md index a28bd0f..2b892cf 100644 --- a/README.md +++ b/README.md @@ -97,6 +97,8 @@ member: CN=sherlock.holmes,OU=users,DC= ## Mirrored groups +:exclamation: You can disable the creation of mirrored groups with the `--disable-mirrored-groups` command line option :exclamation: + Each group of users will have an associated group-of-groups where each user in the group will have its user primary group in the group-of-groups. Note that these groups-of-groups are **not** `posixGroup`s as POSIX does not allow nested groups. @@ -109,6 +111,7 @@ objectClass: posixGroup objectClass: top ... member: CN=sherlock.holmes,OU=users,DC= +... ``` will have an associated group-of-groups @@ -122,6 +125,32 @@ member: CN=sherlock.holmes,OU=groups,DC= ... ``` +This allows a user to make a request for "all primary user groups needed by members of group X" without getting a large number of primary user groups for unrelated users. To do this, you will need an LDAP request that looks like: + +```ldap +(&(objectClass=posixGroup)(|(CN=Detectives)(memberOf=Primary user groups for Detectives))) +``` + +which will return: + +``` +dn:CN=Detectives,OU=groups,DC= +objectClass: groupOfNames +objectClass: posixGroup +objectClass: top +... +member: CN=sherlock.holmes,OU=users,DC= +... + +dn: CN=sherlock.holmes,OU=groups,DC= +objectClass: groupOfNames +objectClass: posixGroup +objectClass: top +... +member: CN=sherlock.holmes,OU=users,DC= +... +``` + ## OpenID Connect Instructions for specific OpenID Connect backends below. From a3128923c06bfb1e700a08543c4f74961d1376f4 Mon Sep 17 00:00:00 2001 From: James Robinson Date: Thu, 23 May 2024 15:43:23 +0100 Subject: [PATCH 05/15] :truck: Rename disable-groups-of-groups to disable-mirrored-groups --- README.md | 4 +++- apricot/apricot_server.py | 4 ++-- apricot/ldap/oauth_ldap_server_factory.py | 4 ++-- apricot/ldap/oauth_ldap_tree.py | 6 +++--- apricot/oauth/oauth_data_adaptor.py | 6 +++--- docker/entrypoint.sh | 4 ++-- run.py | 6 +++--- 7 files changed, 18 insertions(+), 16 deletions(-) diff --git a/README.md b/README.md index 2b892cf..f13ad3b 100644 --- a/README.md +++ b/README.md @@ -67,7 +67,9 @@ member: ## Primary groups -Note that each user will have an associated group to act as its POSIX user primary group +:exclamation: You can disable the creation of mirrored groups with the `--disable-primary-groups` command line option :exclamation: + +Apricot creates an associated group for each user, which acts as its POSIX user primary group. For example: diff --git a/apricot/apricot_server.py b/apricot/apricot_server.py index 920b728..a40c431 100644 --- a/apricot/apricot_server.py +++ b/apricot/apricot_server.py @@ -22,7 +22,7 @@ def __init__( port: int, *, debug: bool = False, - enable_group_of_groups: bool, + enable_mirrored_groups: bool, redis_host: str | None = None, redis_port: int | None = None, **kwargs: Any, @@ -64,7 +64,7 @@ def __init__( if self.debug: log.msg("Creating an LDAPServerFactory.") factory = OAuthLDAPServerFactory( - domain, oauth_client, enable_group_of_groups=enable_group_of_groups + domain, oauth_client, enable_mirrored_groups=enable_mirrored_groups ) # Attach a listening endpoint diff --git a/apricot/ldap/oauth_ldap_server_factory.py b/apricot/ldap/oauth_ldap_server_factory.py index ea2da9d..72afe0d 100644 --- a/apricot/ldap/oauth_ldap_server_factory.py +++ b/apricot/ldap/oauth_ldap_server_factory.py @@ -9,7 +9,7 @@ class OAuthLDAPServerFactory(ServerFactory): def __init__( - self, domain: str, oauth_client: OAuthClient, *, enable_group_of_groups: bool + self, domain: str, oauth_client: OAuthClient, *, enable_mirrored_groups: bool ): """ Initialise an LDAPServerFactory @@ -17,7 +17,7 @@ def __init__( @param oauth_client: An OAuth client used to construct the LDAP tree """ # Create an LDAP lookup tree - self.adaptor = OAuthLDAPTree(domain, oauth_client, enable_group_of_groups) + self.adaptor = OAuthLDAPTree(domain, oauth_client, enable_mirrored_groups) def __repr__(self) -> str: return f"{self.__class__.__name__} using adaptor {self.adaptor}" diff --git a/apricot/ldap/oauth_ldap_tree.py b/apricot/ldap/oauth_ldap_tree.py index 90664c7..6095e75 100644 --- a/apricot/ldap/oauth_ldap_tree.py +++ b/apricot/ldap/oauth_ldap_tree.py @@ -18,7 +18,7 @@ def __init__( domain: str, oauth_client: OAuthClient, *, - enable_group_of_groups: bool, + enable_mirrored_groups: bool, refresh_interval: int = 60, ) -> None: """ @@ -34,7 +34,7 @@ def __init__( self.oauth_client = oauth_client self.refresh_interval = refresh_interval self.root_: OAuthLDAPEntry | None = None - self.enable_group_of_groups = enable_group_of_groups + self.enable_mirrored_groups = enable_mirrored_groups @property def dn(self) -> DistinguishedName: @@ -56,7 +56,7 @@ def root(self) -> OAuthLDAPEntry: oauth_adaptor = OAuthDataAdaptor( self.domain, self.oauth_client, - enable_group_of_groups=self.enable_group_of_groups, + enable_mirrored_groups=self.enable_mirrored_groups, ) # Create a root node for the tree diff --git a/apricot/oauth/oauth_data_adaptor.py b/apricot/oauth/oauth_data_adaptor.py index cfd7a32..0ce7e19 100644 --- a/apricot/oauth/oauth_data_adaptor.py +++ b/apricot/oauth/oauth_data_adaptor.py @@ -22,12 +22,12 @@ class OAuthDataAdaptor: """Adaptor for converting raw user and group data into LDAP format.""" def __init__( - self, domain: str, oauth_client: OAuthClient, *, enable_group_of_groups: bool + self, domain: str, oauth_client: OAuthClient, *, enable_mirrored_groups: bool ): self.debug = oauth_client.debug self.oauth_client = oauth_client self.root_dn = "DC=" + domain.replace(".", ",DC=") - self.enable_group_of_groups = enable_group_of_groups + self.enable_mirrored_groups = enable_mirrored_groups # Retrieve and validate user and group information annotated_groups, annotated_users = self._retrieve_entries() @@ -108,7 +108,7 @@ def _retrieve_entries( # Add one group of groups for each existing group. # Its members are the primary user groups for each original group member. groups_of_groups = [] - if self.enable_group_of_groups: + if self.enable_mirrored_groups: for group in oauth_groups: group_dict = {} group_dict["cn"] = f"Primary user groups for {group['cn']}" diff --git a/docker/entrypoint.sh b/docker/entrypoint.sh index c797ba8..04261da 100644 --- a/docker/entrypoint.sh +++ b/docker/entrypoint.sh @@ -37,8 +37,8 @@ if [ -n "${DEBUG}" ]; then EXTRA_OPTS="${EXTRA_OPTS} --debug" fi -if [ -n "${DISABLE_GROUP_OF_GROUPS}" ]; then - EXTRA_OPTS="${EXTRA_OPTS} --disable-group-of-groups" +if [ -n "${DISABLE_MIRRORED_GROUPS}" ]; then + EXTRA_OPTS="${EXTRA_OPTS} --disable-mirrored-groups" fi if [ -n "${ENTRA_TENANT_ID}" ]; then diff --git a/run.py b/run.py index 6a1b5b8..ae3dbf0 100644 --- a/run.py +++ b/run.py @@ -16,9 +16,9 @@ parser.add_argument("-i", "--client-id", type=str, help="OAuth client ID.") parser.add_argument("-p", "--port", type=int, default=1389, help="Port to run on.") parser.add_argument("-s", "--client-secret", type=str, help="OAuth client secret.") - parser.add_argument("--disable-group-of-groups", action="store_false", - dest="enable_group_of_groups", default=True, - help="Disable creation of group-of-groups.") + parser.add_argument("--disable-mirrored-groups", action="store_false", + dest="enable_mirrored", default=True, + help="Disable creation of mirrored groups.") parser.add_argument("--debug", action="store_true", help="Enable debug logging.") # Options for Microsoft Entra backend entra_group = parser.add_argument_group("Microsoft Entra") From 56464c65e287f7f2825eadb184364389f6ae09b6 Mon Sep 17 00:00:00 2001 From: James Robinson Date: Thu, 23 May 2024 16:00:53 +0100 Subject: [PATCH 06/15] :wrench: Add additional optional attributes to LDAPInetOrgPerson --- apricot/models/ldap_inetorgperson.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/apricot/models/ldap_inetorgperson.py b/apricot/models/ldap_inetorgperson.py index 8e0da25..51e5cb5 100644 --- a/apricot/models/ldap_inetorgperson.py +++ b/apricot/models/ldap_inetorgperson.py @@ -1,5 +1,3 @@ -from typing import Optional - from .ldap_organizational_person import LDAPOrganizationalPerson @@ -14,10 +12,12 @@ class LDAPInetOrgPerson(LDAPOrganizationalPerson): """ cn: str - displayName: Optional[str] # noqa: N815 - givenName: str # noqa: N815 + displayName: str | None = None # noqa: N815 + employeeNumber: str | None = None # noqa: N815 + givenName: str | None = None # noqa: N815 sn: str - mail: Optional[str] = None + mail: str | None = None + telephoneNumber: str | None = None # noqa: N815 def names(self) -> list[str]: return [*super().names(), "inetOrgPerson"] From 278c72d771c8871957ad100d82a6107791cc96d9 Mon Sep 17 00:00:00 2001 From: James Robinson Date: Thu, 23 May 2024 16:16:10 +0100 Subject: [PATCH 07/15] :bug: Revert change to OAuthClient.query which stopped passing client_id and client_secret to the OAuth backend --- apricot/cache/redis_cache.py | 2 +- apricot/oauth/keycloak_client.py | 9 ++++++--- apricot/oauth/oauth_client.py | 31 +++++++++++++++---------------- 3 files changed, 22 insertions(+), 20 deletions(-) diff --git a/apricot/cache/redis_cache.py b/apricot/cache/redis_cache.py index 4a1d919..24ac506 100644 --- a/apricot/cache/redis_cache.py +++ b/apricot/cache/redis_cache.py @@ -9,7 +9,7 @@ class RedisCache(UidCache): def __init__(self, redis_host: str, redis_port: int) -> None: self.redis_host = redis_host self.redis_port = redis_port - self.cache_: "redis.Redis[str]" | None = None + self.cache_: "redis.Redis[str]" | None = None # noqa: UP037 @property def cache(self) -> "redis.Redis[str]": diff --git a/apricot/oauth/keycloak_client.py b/apricot/oauth/keycloak_client.py index 6667778..d23d638 100644 --- a/apricot/oauth/keycloak_client.py +++ b/apricot/oauth/keycloak_client.py @@ -53,7 +53,8 @@ def groups(self) -> list[JSONDict]: try: group_data = [] while data := self.query( - f"{self.base_url}/admin/realms/{self.realm}/groups?first={len(group_data)}&max={self.max_rows}&briefRepresentation=false" + f"{self.base_url}/admin/realms/{self.realm}/groups?first={len(group_data)}&max={self.max_rows}&briefRepresentation=false", + use_client_secret=False, ): group_data.extend(data) if len(data) != self.max_rows: @@ -107,7 +108,8 @@ def groups(self) -> list[JSONDict]: attributes["oauth_id"] = group_dict.get("id", None) # Add membership attributes members = self.query( - f"{self.base_url}/admin/realms/{self.realm}/groups/{group_dict['id']}/members" + f"{self.base_url}/admin/realms/{self.realm}/groups/{group_dict['id']}/members", + use_client_secret=False, ) attributes["memberUid"] = [ user["username"] for user in cast(list[JSONDict], members) @@ -122,7 +124,8 @@ def users(self) -> list[JSONDict]: try: user_data = [] while data := self.query( - f"{self.base_url}/admin/realms/{self.realm}/users?first={len(user_data)}&max={self.max_rows}&briefRepresentation=false" + f"{self.base_url}/admin/realms/{self.realm}/users?first={len(user_data)}&max={self.max_rows}&briefRepresentation=false", + use_client_secret=False, ): user_data.extend(data) if len(data) != self.max_rows: diff --git a/apricot/oauth/oauth_client.py b/apricot/oauth/oauth_client.py index 3311443..90d2143 100644 --- a/apricot/oauth/oauth_client.py +++ b/apricot/oauth/oauth_client.py @@ -109,28 +109,27 @@ def users(self) -> list[JSONDict]: """ pass - def query(self, url: str) -> dict[str, Any]: + def query(self, url: str, *, use_client_secret=True) -> dict[str, Any]: """ Make a query against the OAuth backend """ - - def query_(url: str) -> requests.Response: - return self.session_application.get( # type: ignore[no-any-return] - url=url, headers={"Authorization": f"Bearer {self.bearer_token}"} - ) - - try: - result = query_(url) - result.raise_for_status() - except (TokenExpiredError, requests.exceptions.HTTPError): - log.msg("Authentication token has expired.") - self.bearer_token_ = None - result = query_(url) - return result.json() # type: ignore + kwargs = ( + { + "client_id": self.session_application._client.client_id, + "client_secret": self.client_secret, + } + if use_client_secret + else {} + ) + return self.request( + url=url, + method="GET", + **kwargs, + ) def request(self, *args, method="GET", **kwargs) -> dict[str, Any]: """ - Make a query against the OAuth backend + Make a request to the OAuth backend """ def query_(*args, **kwargs) -> requests.Response: From 9cc497329fcb8f1c075eaff2675d8cabfad55914 Mon Sep 17 00:00:00 2001 From: James Robinson Date: Tue, 21 May 2024 12:24:15 +0100 Subject: [PATCH 08/15] :memo: Add debug messages for each group and user added to the LDAP tree --- apricot/ldap/oauth_ldap_entry.py | 3 +++ apricot/ldap/oauth_ldap_tree.py | 18 ++++++++++++++++-- 2 files changed, 19 insertions(+), 2 deletions(-) diff --git a/apricot/ldap/oauth_ldap_entry.py b/apricot/ldap/oauth_ldap_entry.py index d945ef0..6845a33 100644 --- a/apricot/ldap/oauth_ldap_entry.py +++ b/apricot/ldap/oauth_ldap_entry.py @@ -83,3 +83,6 @@ def _bind(password: bytes) -> "OAuthLDAPEntry": raise LDAPInvalidCredentials(msg) return defer.maybeDeferred(_bind, password) + + def list_children(self) -> "list[OAuthLDAPEntry]": + return [cast(OAuthLDAPEntry, entry) for entry in self._children.values()] diff --git a/apricot/ldap/oauth_ldap_tree.py b/apricot/ldap/oauth_ldap_tree.py index 6095e75..d9eb133 100644 --- a/apricot/ldap/oauth_ldap_tree.py +++ b/apricot/ldap/oauth_ldap_tree.py @@ -77,15 +77,29 @@ def root(self) -> OAuthLDAPEntry: # Add groups to the groups OU if self.debug: - log.msg(f"Adding {len(oauth_adaptor.groups)} groups to the LDAP tree.") + log.msg( + f"Attempting to add {len(oauth_adaptor.groups)} groups to the LDAP tree." + ) for group_attrs in oauth_adaptor.groups: groups_ou.add_child(f"CN={group_attrs.cn}", group_attrs.to_dict()) + if self.debug: + children = groups_ou.list_children() + for child in children: + log.msg(f"... {child.dn.getText()}") + log.msg(f"There are {len(children)} groups in the LDAP tree.") # Add users to the users OU if self.debug: - log.msg(f"Adding {len(oauth_adaptor.users)} users to the LDAP tree.") + log.msg( + f"Attempting to add {len(oauth_adaptor.users)} users to the LDAP tree." + ) for user_attrs in oauth_adaptor.users: users_ou.add_child(f"CN={user_attrs.cn}", user_attrs.to_dict()) + if self.debug: + children = users_ou.list_children() + for child in children: + log.msg(f"... {child.dn.getText()}") + log.msg(f"There are {len(children)} users in the LDAP tree.") # Set last updated time log.msg("Finished building LDAP tree.") From 819e360988ec3ec958946b5bcc0fde004aa3f6c0 Mon Sep 17 00:00:00 2001 From: James Robinson Date: Tue, 21 May 2024 13:48:37 +0100 Subject: [PATCH 09/15] :loud_sound: Add additional debug messages for user and group membership --- apricot/oauth/oauth_data_adaptor.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/apricot/oauth/oauth_data_adaptor.py b/apricot/oauth/oauth_data_adaptor.py index 0ce7e19..e2e6ea5 100644 --- a/apricot/oauth/oauth_data_adaptor.py +++ b/apricot/oauth/oauth_data_adaptor.py @@ -132,6 +132,11 @@ def _retrieve_entries( for parent_dict in oauth_groups + user_primary_groups + groups_of_groups if child_dn in parent_dict["member"] ] + if self.debug: + for group_name in child_dict["memberOf"]: + log.msg( + f"... user '{child_dict['cn']}' is a member of '{group_name}'" + ) # Ensure memberOf is set correctly for groups for child_dict in oauth_groups + user_primary_groups + groups_of_groups: @@ -141,6 +146,11 @@ def _retrieve_entries( for parent_dict in oauth_groups + user_primary_groups + groups_of_groups if child_dn in parent_dict["member"] ] + if self.debug: + for group_name in child_dict["memberOf"]: + log.msg( + f"... group '{child_dict['cn']}' is a member of '{group_name}'" + ) # Annotate group and user dicts with the appropriate LDAP classes annotated_groups = [ From 83a0ff3d0610fd6b33554b27f2f037a28a74b94f Mon Sep 17 00:00:00 2001 From: James Robinson Date: Tue, 21 May 2024 13:47:11 +0100 Subject: [PATCH 10/15] :bug: Continue processing groups even if attributes cannot be processed for one of them --- apricot/oauth/microsoft_entra_client.py | 35 ++++++++++++++----------- 1 file changed, 20 insertions(+), 15 deletions(-) diff --git a/apricot/oauth/microsoft_entra_client.py b/apricot/oauth/microsoft_entra_client.py index 4bc94c8..7ca1e07 100644 --- a/apricot/oauth/microsoft_entra_client.py +++ b/apricot/oauth/microsoft_entra_client.py @@ -1,5 +1,7 @@ from typing import Any, cast +from twisted.python import log + from apricot.types import JSONDict from .oauth_client import OAuthClient @@ -28,19 +30,19 @@ def extract_token(self, json_response: JSONDict) -> str: def groups(self) -> list[JSONDict]: output = [] - try: - queries = [ - "createdDateTime", - "displayName", - "id", - ] - group_data = self.query( - f"https://graph.microsoft.com/v1.0/groups?$select={','.join(queries)}" - ) - for group_dict in cast( - list[JSONDict], - sorted(group_data["value"], key=lambda group: group["createdDateTime"]), - ): + queries = [ + "createdDateTime", + "displayName", + "id", + ] + group_data = self.query( + f"https://graph.microsoft.com/v1.0/groups?$select={','.join(queries)}" + ) + for group_dict in cast( + list[JSONDict], + sorted(group_data["value"], key=lambda group: group["createdDateTime"]), + ): + try: group_uid = self.uid_cache.get_group_uid(group_dict["id"]) attributes: JSONDict = {} attributes["cn"] = group_dict.get("displayName", None) @@ -57,8 +59,11 @@ def groups(self) -> list[JSONDict]: if user["userPrincipalName"] ] output.append(attributes) - except KeyError: - pass + except KeyError as exc: + msg = ( + f"Failed to process group {group_dict} due to a missing key {exc}." + ) + log.msg(msg) return output def users(self) -> list[JSONDict]: From a2ec61bcc9711ae9adf35b03ef34e1e2b68d2de3 Mon Sep 17 00:00:00 2001 From: James Robinson Date: Tue, 21 May 2024 13:48:11 +0100 Subject: [PATCH 11/15] :bug: Ensure that userPrincipalName key exists before using it to construct group members --- apricot/oauth/microsoft_entra_client.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/apricot/oauth/microsoft_entra_client.py b/apricot/oauth/microsoft_entra_client.py index 7ca1e07..eecfa41 100644 --- a/apricot/oauth/microsoft_entra_client.py +++ b/apricot/oauth/microsoft_entra_client.py @@ -56,7 +56,7 @@ def groups(self) -> list[JSONDict]: attributes["memberUid"] = [ str(user["userPrincipalName"]).split("@")[0] for user in members["value"] - if user["userPrincipalName"] + if user.get("userPrincipalName") ] output.append(attributes) except KeyError as exc: From b5fa15f34cb4a1ffb0e4189d330bf8be0f665c8b Mon Sep 17 00:00:00 2001 From: James Robinson Date: Thu, 23 May 2024 16:21:43 +0100 Subject: [PATCH 12/15] :rotating_light: Fix linting error in README --- README.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index f13ad3b..170917d 100644 --- a/README.md +++ b/README.md @@ -129,13 +129,13 @@ member: CN=sherlock.holmes,OU=groups,DC= This allows a user to make a request for "all primary user groups needed by members of group X" without getting a large number of primary user groups for unrelated users. To do this, you will need an LDAP request that looks like: -```ldap +```ldif (&(objectClass=posixGroup)(|(CN=Detectives)(memberOf=Primary user groups for Detectives))) ``` which will return: -``` +```ldif dn:CN=Detectives,OU=groups,DC= objectClass: groupOfNames objectClass: posixGroup From 38a28e4da81d164676b04d330a54a07162860f18 Mon Sep 17 00:00:00 2001 From: James Robinson Date: Thu, 23 May 2024 16:44:52 +0100 Subject: [PATCH 13/15] :rotating_light: Additional linting fixes --- apricot/apricot_server.py | 4 +- apricot/ldap/oauth_ldap_server_factory.py | 4 +- apricot/oauth/keycloak_client.py | 56 ++++++++++++----------- apricot/oauth/oauth_client.py | 11 +++-- apricot/types.py | 1 + 5 files changed, 42 insertions(+), 34 deletions(-) diff --git a/apricot/apricot_server.py b/apricot/apricot_server.py index a40c431..b45ea3e 100644 --- a/apricot/apricot_server.py +++ b/apricot/apricot_server.py @@ -48,7 +48,9 @@ def __init__( if self.debug: log.msg(f"Creating an OAuthClient for {backend}.") oauth_backend = OAuthClientMap[backend] - oauth_backend_args = inspect.getfullargspec(oauth_backend.__init__).args + oauth_backend_args = inspect.getfullargspec( + oauth_backend.__init__ # type: ignore + ).args oauth_client = oauth_backend( client_id=client_id, client_secret=client_secret, diff --git a/apricot/ldap/oauth_ldap_server_factory.py b/apricot/ldap/oauth_ldap_server_factory.py index 72afe0d..bcabc6c 100644 --- a/apricot/ldap/oauth_ldap_server_factory.py +++ b/apricot/ldap/oauth_ldap_server_factory.py @@ -17,7 +17,9 @@ def __init__( @param oauth_client: An OAuth client used to construct the LDAP tree """ # Create an LDAP lookup tree - self.adaptor = OAuthLDAPTree(domain, oauth_client, enable_mirrored_groups) + self.adaptor = OAuthLDAPTree( + domain, oauth_client, enable_mirrored_groups=enable_mirrored_groups + ) def __repr__(self) -> str: return f"{self.__class__.__name__} using adaptor {self.adaptor}" diff --git a/apricot/oauth/keycloak_client.py b/apricot/oauth/keycloak_client.py index d23d638..3c21a9c 100644 --- a/apricot/oauth/keycloak_client.py +++ b/apricot/oauth/keycloak_client.py @@ -5,9 +5,11 @@ from .oauth_client import OAuthClient -def get_single_value_attribute(obj: JSONDict, key: str, default=None) -> Any: +def get_single_value_attribute( + obj: JSONDict, key: str, default: str | None = None +) -> Any: for part in key.split("."): - obj = obj.get(part) + obj = obj.get(part) # type: ignore if obj is None: return default if isinstance(obj, list): @@ -35,7 +37,7 @@ def __init__( self.realm = keycloak_realm redirect_uri = "urn:ietf:wg:oauth:2.0:oob" # this is the "no redirect" URL - scopes = [] # this is the default scope + scopes: list[str] = [] # this is the default scope token_url = f"{self.base_url}/realms/{self.realm}/protocol/openid-connect/token" super().__init__( @@ -51,41 +53,40 @@ def extract_token(self, json_response: JSONDict) -> str: def groups(self) -> list[JSONDict]: output = [] try: - group_data = [] + group_data: list[JSONDict] = [] while data := self.query( f"{self.base_url}/admin/realms/{self.realm}/groups?first={len(group_data)}&max={self.max_rows}&briefRepresentation=false", use_client_secret=False, ): - group_data.extend(data) + group_data.extend(cast(list[JSONDict], data)) if len(data) != self.max_rows: break group_data = sorted( group_data, - key=lambda g: int( + key=lambda group: int( get_single_value_attribute( - g, "attributes.gid", default="9999999999" + group, "attributes.gid", default="9999999999" ), - 10, + base=10, ), ) next_gid = max( *( int( - get_single_value_attribute(g, "attributes.gid", default="-1"), - 10, + get_single_value_attribute( + group, "attributes.gid", default="-1" + ), + base=10, ) + 1 - for g in group_data + for group in group_data ), 3000, ) - for group_dict in cast( - list[JSONDict], - group_data, - ): + for group_dict in group_data: group_gid = get_single_value_attribute( group_dict, "attributes.gid", default=None ) @@ -122,46 +123,47 @@ def groups(self) -> list[JSONDict]: def users(self) -> list[JSONDict]: output = [] try: - user_data = [] + user_data: list[JSONDict] = [] while data := self.query( f"{self.base_url}/admin/realms/{self.realm}/users?first={len(user_data)}&max={self.max_rows}&briefRepresentation=false", use_client_secret=False, ): - user_data.extend(data) + user_data.extend(cast(list[JSONDict], data)) if len(data) != self.max_rows: break user_data = sorted( user_data, - key=lambda u: int( + key=lambda user: int( get_single_value_attribute( - u, "attributes.uid", default="9999999999" + user, "attributes.uid", default="9999999999" ), - 10, + base=10, ), ) next_uid = max( *( int( - get_single_value_attribute(g, "attributes.uid", default="-1"), - 10, + get_single_value_attribute( + user, "attributes.uid", default="-1" + ), + base=10, ) + 1 - for g in user_data + for user in user_data ), 3000, ) - for user_dict in cast( - list[JSONDict], - sorted(user_data, key=lambda user: user["createdTimestamp"]), + for user_dict in sorted( + user_data, key=lambda user: user["createdTimestamp"] ): user_uid = get_single_value_attribute( user_dict, "attributes.uid", default=None ) if user_uid: - user_uid = int(user_uid, 10) + user_uid = int(user_uid, base=10) if not user_uid: user_uid = next_uid next_uid += 1 diff --git a/apricot/oauth/oauth_client.py b/apricot/oauth/oauth_client.py index 90d2143..b47f98c 100644 --- a/apricot/oauth/oauth_client.py +++ b/apricot/oauth/oauth_client.py @@ -109,7 +109,7 @@ def users(self) -> list[JSONDict]: """ pass - def query(self, url: str, *, use_client_secret=True) -> dict[str, Any]: + def query(self, url: str, *, use_client_secret: bool = True) -> dict[str, Any]: """ Make a query against the OAuth backend """ @@ -127,12 +127,12 @@ def query(self, url: str, *, use_client_secret=True) -> dict[str, Any]: **kwargs, ) - def request(self, *args, method="GET", **kwargs) -> dict[str, Any]: + def request(self, *args: Any, method: str = "GET", **kwargs: Any) -> dict[str, Any]: """ Make a request to the OAuth backend """ - def query_(*args, **kwargs) -> requests.Response: + def query_(*args: Any, **kwargs: Any) -> requests.Response: return self.session_application.request( # type: ignore[no-any-return] method, *args, @@ -147,8 +147,9 @@ def query_(*args, **kwargs) -> requests.Response: log.msg("Authentication token has expired.") self.bearer_token_ = None result = query_(*args, **kwargs) - if result.status_code != HTTPStatus.NO_CONTENT: - return result.json() # type: ignore + if result.status_code == HTTPStatus.NO_CONTENT: + return {} + return result.json() # type: ignore def verify(self, username: str, password: str) -> bool: """ diff --git a/apricot/types.py b/apricot/types.py index e93f9ea..5cc0617 100644 --- a/apricot/types.py +++ b/apricot/types.py @@ -1,5 +1,6 @@ from typing import Any JSONDict = dict[str, Any] +JSONKey = list[Any] | dict[str, Any] | Any LDAPAttributeDict = dict[str, list[str]] LDAPControlTuple = tuple[str, bool, Any] From f986e41d142a6e34892cbcd5cbfa6c3e1846c8b3 Mon Sep 17 00:00:00 2001 From: James Robinson Date: Thu, 23 May 2024 19:15:21 +0100 Subject: [PATCH 14/15] :recycle: Switch KeycloakClient to use UidCache for generating missing UIDs --- apricot/apricot_server.py | 2 +- apricot/cache/uid_cache.py | 28 ++++++++ apricot/oauth/keycloak_client.py | 120 +++++++++---------------------- 3 files changed, 64 insertions(+), 86 deletions(-) diff --git a/apricot/apricot_server.py b/apricot/apricot_server.py index b45ea3e..fa98c22 100644 --- a/apricot/apricot_server.py +++ b/apricot/apricot_server.py @@ -22,7 +22,7 @@ def __init__( port: int, *, debug: bool = False, - enable_mirrored_groups: bool, + enable_mirrored_groups: bool = True, redis_host: str | None = None, redis_port: int | None = None, **kwargs: Any, diff --git a/apricot/cache/uid_cache.py b/apricot/cache/uid_cache.py index ab46029..eb9c729 100644 --- a/apricot/cache/uid_cache.py +++ b/apricot/cache/uid_cache.py @@ -77,3 +77,31 @@ def _get_max_uid(self, category: str | None) -> int: keys = self.keys() values = [*self.values(keys), -999] return max(values) + + def overwrite_group_uid(self, identifier: str, uid: int) -> None: + """ + Set UID for a group, overwriting the existing value if there is one + + @param identifier: Identifier for group + @param uid: Desired UID + """ + return self.overwrite_uid(identifier, category="group", uid=uid) + + def overwrite_user_uid(self, identifier: str, uid: int) -> None: + """ + Get UID for a user, constructing one if necessary + + @param identifier: Identifier for user + @param uid: Desired UID + """ + return self.overwrite_uid(identifier, category="user", uid=uid) + + def overwrite_uid(self, identifier: str, category: str, uid: int) -> None: + """ + Set UID, overwriting the existing one if necessary. + + @param identifier: Identifier for object + @param category: Category the object belongs to + @param uid: Desired UID + """ + self.set(f"{category}-{identifier}", uid) diff --git a/apricot/oauth/keycloak_client.py b/apricot/oauth/keycloak_client.py index 3c21a9c..8c2f58d 100644 --- a/apricot/oauth/keycloak_client.py +++ b/apricot/oauth/keycloak_client.py @@ -5,23 +5,6 @@ from .oauth_client import OAuthClient -def get_single_value_attribute( - obj: JSONDict, key: str, default: str | None = None -) -> Any: - for part in key.split("."): - obj = obj.get(part) # type: ignore - if obj is None: - return default - if isinstance(obj, list): - try: - return next(iter(obj)) - except StopIteration: - pass - else: - return obj - return default - - class KeycloakClient(OAuthClient): """OAuth client for the Keycloak backend.""" @@ -62,41 +45,25 @@ def groups(self) -> list[JSONDict]: if len(data) != self.max_rows: break - group_data = sorted( - group_data, - key=lambda group: int( - get_single_value_attribute( - group, "attributes.gid", default="9999999999" - ), - base=10, - ), - ) - - next_gid = max( - *( - int( - get_single_value_attribute( - group, "attributes.gid", default="-1" - ), - base=10, + # Ensure that gid attribute exists for all groups + for group_dict in group_data: + group_dict["attributes"] = group_dict.get("attributes", {}) + if "gid" not in group_dict["attributes"]: + group_dict["attributes"]["gid"] = None + # If group_gid exists then set the cache to the same value + # This ensures that any groups without a `gid` attribute will receive a + # UID that does not overlap with existing groups + if group_gid := group_dict["attributes"]["gid"]: + self.uid_cache.overwrite_group_uid( + group_dict["id"], int(group_gid, 10) ) - + 1 - for group in group_data - ), - 3000, - ) + # Read group attributes for group_dict in group_data: - group_gid = get_single_value_attribute( - group_dict, "attributes.gid", default=None - ) - if group_gid: - group_gid = int(group_gid, 10) - if not group_gid: - group_gid = next_gid - next_gid += 1 - group_dict["attributes"] = group_dict.get("attributes", {}) - group_dict["attributes"]["gid"] = [str(group_gid)] + if not group_dict["attributes"]["gid"]: + group_dict["attributes"]["gid"] = [ + str(self.uid_cache.get_group_uid(group_dict["id"])) + ] self.request( f"{self.base_url}/admin/realms/{self.realm}/groups/{group_dict['id']}", method="PUT", @@ -105,7 +72,7 @@ def groups(self) -> list[JSONDict]: attributes: JSONDict = {} attributes["cn"] = group_dict.get("name", None) attributes["description"] = group_dict.get("id", None) - attributes["gidNumber"] = group_gid + attributes["gidNumber"] = group_dict["attributes"]["gid"] attributes["oauth_id"] = group_dict.get("id", None) # Add membership attributes members = self.query( @@ -132,44 +99,27 @@ def users(self) -> list[JSONDict]: if len(data) != self.max_rows: break - user_data = sorted( - user_data, - key=lambda user: int( - get_single_value_attribute( - user, "attributes.uid", default="9999999999" - ), - base=10, - ), - ) - - next_uid = max( - *( - int( - get_single_value_attribute( - user, "attributes.uid", default="-1" - ), - base=10, + # Ensure that uid attribute exists for all users + for user_dict in user_data: + user_dict["attributes"] = user_dict.get("attributes", {}) + if "uid" not in user_dict["attributes"]: + user_dict["attributes"]["uid"] = None + # If user_uid exists then set the cache to the same value. + # This ensures that any groups without a `gid` attribute will receive a + # UID that does not overlap with existing groups + if user_uid := user_dict["attributes"]["uid"]: + self.uid_cache.overwrite_user_uid( + user_dict["id"], int(user_uid, 10) ) - + 1 - for user in user_data - ), - 3000, - ) + # Read user attributes for user_dict in sorted( user_data, key=lambda user: user["createdTimestamp"] ): - user_uid = get_single_value_attribute( - user_dict, "attributes.uid", default=None - ) - if user_uid: - user_uid = int(user_uid, base=10) - if not user_uid: - user_uid = next_uid - next_uid += 1 - - user_dict["attributes"] = user_dict.get("attributes", {}) - user_dict["attributes"]["uid"] = [str(user_uid)] + if not user_dict["attributes"]["uid"]: + user_dict["attributes"]["uid"] = [ + str(self.uid_cache.get_user_uid(user_dict["id"])) + ] self.request( f"{self.base_url}/admin/realms/{self.realm}/users/{user_dict['id']}", method="PUT", @@ -189,12 +139,12 @@ def users(self) -> list[JSONDict]: attributes["displayName"] = full_name attributes["mail"] = user_dict.get("email") attributes["description"] = "" - attributes["gidNumber"] = user_uid + attributes["gidNumber"] = user_dict["attributes"]["uid"] attributes["givenName"] = first_name if first_name else "" attributes["homeDirectory"] = f"/home/{username}" if username else None attributes["oauth_id"] = user_dict.get("id", None) attributes["sn"] = last_name if last_name else "" - attributes["uidNumber"] = user_uid + attributes["uidNumber"] = user_dict["attributes"]["uid"] output.append(attributes) except KeyError: pass From 2483fa620b6e987db82a1c7f2a4fc673040fcd6e Mon Sep 17 00:00:00 2001 From: Felix Gustavsson Date: Sat, 25 May 2024 13:53:35 +0200 Subject: [PATCH 15/15] Fix fetching of uid from multi value attribute fix naming issue for mirrored groups parameter --- apricot/oauth/keycloak_client.py | 18 +++++++++++------- run.py | 2 +- 2 files changed, 12 insertions(+), 8 deletions(-) diff --git a/apricot/oauth/keycloak_client.py b/apricot/oauth/keycloak_client.py index 8c2f58d..5b584c7 100644 --- a/apricot/oauth/keycloak_client.py +++ b/apricot/oauth/keycloak_client.py @@ -53,9 +53,11 @@ def groups(self) -> list[JSONDict]: # If group_gid exists then set the cache to the same value # This ensures that any groups without a `gid` attribute will receive a # UID that does not overlap with existing groups - if group_gid := group_dict["attributes"]["gid"]: + if (group_gid := group_dict["attributes"]["gid"]) and len( + group_dict["attributes"]["gid"] + ) == 1: self.uid_cache.overwrite_group_uid( - group_dict["id"], int(group_gid, 10) + group_dict["id"], int(group_gid[0], 10) ) # Read group attributes @@ -72,7 +74,7 @@ def groups(self) -> list[JSONDict]: attributes: JSONDict = {} attributes["cn"] = group_dict.get("name", None) attributes["description"] = group_dict.get("id", None) - attributes["gidNumber"] = group_dict["attributes"]["gid"] + attributes["gidNumber"] = group_dict["attributes"]["gid"][0] attributes["oauth_id"] = group_dict.get("id", None) # Add membership attributes members = self.query( @@ -107,9 +109,11 @@ def users(self) -> list[JSONDict]: # If user_uid exists then set the cache to the same value. # This ensures that any groups without a `gid` attribute will receive a # UID that does not overlap with existing groups - if user_uid := user_dict["attributes"]["uid"]: + if (user_uid := user_dict["attributes"]["uid"]) and len( + user_dict["attributes"]["uid"] + ) == 1: self.uid_cache.overwrite_user_uid( - user_dict["id"], int(user_uid, 10) + user_dict["id"], int(user_uid[0], 10) ) # Read user attributes @@ -139,12 +143,12 @@ def users(self) -> list[JSONDict]: attributes["displayName"] = full_name attributes["mail"] = user_dict.get("email") attributes["description"] = "" - attributes["gidNumber"] = user_dict["attributes"]["uid"] + attributes["gidNumber"] = user_dict["attributes"]["uid"][0] attributes["givenName"] = first_name if first_name else "" attributes["homeDirectory"] = f"/home/{username}" if username else None attributes["oauth_id"] = user_dict.get("id", None) attributes["sn"] = last_name if last_name else "" - attributes["uidNumber"] = user_dict["attributes"]["uid"] + attributes["uidNumber"] = user_dict["attributes"]["uid"][0] output.append(attributes) except KeyError: pass diff --git a/run.py b/run.py index ae3dbf0..c228f20 100644 --- a/run.py +++ b/run.py @@ -17,7 +17,7 @@ parser.add_argument("-p", "--port", type=int, default=1389, help="Port to run on.") parser.add_argument("-s", "--client-secret", type=str, help="OAuth client secret.") parser.add_argument("--disable-mirrored-groups", action="store_false", - dest="enable_mirrored", default=True, + dest="enable_mirrored_groups", default=True, help="Disable creation of mirrored groups.") parser.add_argument("--debug", action="store_true", help="Enable debug logging.") # Options for Microsoft Entra backend