diff --git a/README.md b/README.md index f00251b..0aa88c5 100644 --- a/README.md +++ b/README.md @@ -151,3 +151,28 @@ Do this as follows: - `Microsoft Graph` > `GroupMember.Read.All` (application) - `Microsoft Graph` > `User.Read.All` (delegated) - Select this and click the `Grant admin consent` button (otherwise manual consent is needed from each user) + + +### Keycloak + +You will need to use the following command line arguments: + +```bash +--backend Keycloak --extra-keycloak-base-url "/" --extra-keycloak-realm "" +``` + +You will need to register an application to interact with `Keycloak`. +Do this as follows: + +- Create a new `Client` in your `Keycloak` instance. + - Set the name to whatever you choose (e.g. `apricot`) + - Enable `Client authentication` + - Enable the following authentication flows and disable the rest: + - Direct access grants + - Service account roles +- Under `Credentials` copy `client secret` +- Under `Service account roles`: + - Ensure that the following role are assigned + - `realm-management` > `view-users` + - `realm-management` > `manage-users` + - `realm-management` > `query-groups` diff --git a/apricot/apricot_server.py b/apricot/apricot_server.py index 7163ffb..1bdae49 100644 --- a/apricot/apricot_server.py +++ b/apricot/apricot_server.py @@ -1,3 +1,4 @@ +import inspect import sys from typing import Any, cast @@ -19,6 +20,7 @@ def __init__( client_secret: str, domain: str, port: int, + enable_group_of_groups: bool, *, debug: bool = False, redis_host: str | None = None, @@ -45,12 +47,14 @@ def __init__( try: if self.debug: log.msg(f"Creating an OAuthClient for {backend}.") - oauth_client = OAuthClientMap[backend]( + oauth_backend = OAuthClientMap[backend] + oauth_backend_args = inspect.getfullargspec(oauth_backend.__init__).args + oauth_client = oauth_backend( client_id=client_id, client_secret=client_secret, debug=debug, uid_cache=uid_cache, - **kwargs, + **{k: v for k, v in kwargs.items() if k in oauth_backend_args}, ) except Exception as exc: msg = f"Could not construct an OAuth client for the '{backend}' backend.\n{exc!s}" @@ -59,7 +63,7 @@ def __init__( # Create an LDAPServerFactory if self.debug: log.msg("Creating an LDAPServerFactory.") - factory = OAuthLDAPServerFactory(domain, oauth_client) + factory = OAuthLDAPServerFactory(domain, oauth_client, enable_group_of_groups) # Attach a listening endpoint if self.debug: diff --git a/apricot/ldap/oauth_ldap_server_factory.py b/apricot/ldap/oauth_ldap_server_factory.py index 2890b35..445a33c 100644 --- a/apricot/ldap/oauth_ldap_server_factory.py +++ b/apricot/ldap/oauth_ldap_server_factory.py @@ -8,14 +8,14 @@ class OAuthLDAPServerFactory(ServerFactory): - def __init__(self, domain: str, oauth_client: OAuthClient): + def __init__(self, domain: str, oauth_client: OAuthClient, enable_group_of_groups: bool): """ Initialise an LDAPServerFactory @param oauth_client: An OAuth client used to construct the LDAP tree """ # Create an LDAP lookup tree - self.adaptor = OAuthLDAPTree(domain, oauth_client) + self.adaptor = OAuthLDAPTree(domain, oauth_client, enable_group_of_groups) def __repr__(self) -> str: return f"{self.__class__.__name__} using adaptor {self.adaptor}" diff --git a/apricot/ldap/oauth_ldap_tree.py b/apricot/ldap/oauth_ldap_tree.py index 136ce31..49a5639 100644 --- a/apricot/ldap/oauth_ldap_tree.py +++ b/apricot/ldap/oauth_ldap_tree.py @@ -14,7 +14,7 @@ class OAuthLDAPTree: def __init__( - self, domain: str, oauth_client: OAuthClient, refresh_interval: int = 60 + self, domain: str, oauth_client: OAuthClient, enable_group_of_groups: bool, refresh_interval: int = 60 ) -> None: """ Initialise an OAuthLDAPTree @@ -29,6 +29,7 @@ def __init__( self.oauth_client = oauth_client self.refresh_interval = refresh_interval self.root_: OAuthLDAPEntry | None = None + self.enable_group_of_groups = enable_group_of_groups @property def dn(self) -> DistinguishedName: @@ -47,7 +48,7 @@ def root(self) -> OAuthLDAPEntry: ): # Update users and groups from the OAuth server log.msg("Retrieving OAuth data.") - oauth_adaptor = OAuthDataAdaptor(self.domain, self.oauth_client) + oauth_adaptor = OAuthDataAdaptor(self.domain, self.oauth_client, self.enable_group_of_groups) # Create a root node for the tree log.msg("Rebuilding LDAP tree.") diff --git a/apricot/models/ldap_attribute_adaptor.py b/apricot/models/ldap_attribute_adaptor.py index dfd3bd1..40f986d 100644 --- a/apricot/models/ldap_attribute_adaptor.py +++ b/apricot/models/ldap_attribute_adaptor.py @@ -8,6 +8,7 @@ def __init__(self, attributes: dict[Any, Any]) -> None: self.attributes = { str(k): list(map(str, v)) if isinstance(v, list) else [str(v)] for k, v in attributes.items() + if v is not None } @property diff --git a/apricot/models/ldap_person.py b/apricot/models/ldap_person.py index 0656897..adedec6 100644 --- a/apricot/models/ldap_person.py +++ b/apricot/models/ldap_person.py @@ -1,3 +1,5 @@ +from typing import Optional + from .named_ldap_class import NamedLDAPClass @@ -13,6 +15,7 @@ class LDAPPerson(NamedLDAPClass): cn: str sn: str + mail: Optional[str] = None def names(self) -> list[str]: return ["person"] diff --git a/apricot/oauth/__init__.py b/apricot/oauth/__init__.py index 0cd8aa5..c5d6268 100644 --- a/apricot/oauth/__init__.py +++ b/apricot/oauth/__init__.py @@ -1,11 +1,15 @@ from apricot.types import LDAPAttributeDict, LDAPControlTuple from .enums import OAuthBackend +from .keycloak_client import KeycloakClient from .microsoft_entra_client import MicrosoftEntraClient from .oauth_client import OAuthClient from .oauth_data_adaptor import OAuthDataAdaptor -OAuthClientMap = {OAuthBackend.MICROSOFT_ENTRA: MicrosoftEntraClient} +OAuthClientMap = { + OAuthBackend.MICROSOFT_ENTRA: MicrosoftEntraClient, + OAuthBackend.KEYCLOAK: KeycloakClient, +} __all__ = [ "LDAPAttributeDict", diff --git a/apricot/oauth/enums.py b/apricot/oauth/enums.py index 8675218..d9c356d 100644 --- a/apricot/oauth/enums.py +++ b/apricot/oauth/enums.py @@ -5,3 +5,4 @@ class OAuthBackend(str, Enum): """Available OAuth backends.""" MICROSOFT_ENTRA = "MicrosoftEntra" + KEYCLOAK = "Keycloak" diff --git a/apricot/oauth/keycloak_client.py b/apricot/oauth/keycloak_client.py new file mode 100644 index 0000000..d0123d4 --- /dev/null +++ b/apricot/oauth/keycloak_client.py @@ -0,0 +1,166 @@ +from typing import Any, cast + +from apricot.types import JSONDict + +from .oauth_client import OAuthClient + + +def get_single_value_attribute(obj: JSONDict, key: str, default=None) -> Any: + for part in key.split("."): + obj = obj.get(part) + if obj is None: + return default + if isinstance(obj, list): + try: + return next(iter(obj)) + except StopIteration: + pass + else: + return obj + return default + + +class KeycloakClient(OAuthClient): + """OAuth client for the Keycloak backend.""" + + max_rows = 100 + + def __init__( + self, + extra_keycloak_base_url: str, + extra_keycloak_realm: str, + **kwargs: Any, + ): + self.base_url = extra_keycloak_base_url + self.realm = extra_keycloak_realm + + redirect_uri = "urn:ietf:wg:oauth:2.0:oob" # this is the "no redirect" URL + scopes = [] # this is the default scope + token_url = ( + f"{self.base_url}/realms/{self.realm}/protocol/openid-connect/token" + ) + + super().__init__( + redirect_uri=redirect_uri, scopes=scopes, token_url=token_url, **kwargs, + ) + + def extract_token(self, json_response: JSONDict) -> str: + return str(json_response["access_token"]) + + def groups(self) -> list[JSONDict]: + output = [] + try: + group_data = [] + while data := self.query( + f"{self.base_url}/admin/realms/{self.realm}/groups?first={len(group_data)}&max={self.max_rows}&briefRepresentation=false" + ): + group_data.extend(data) + if len(data) != self.max_rows: + break + + group_data = sorted(group_data, key=lambda g: int(get_single_value_attribute(g, "attributes.gid", default="9999999999"), 10)) + + next_gid = max( + *( + int(get_single_value_attribute(g, "attributes.gid", default="-1"), 10)+1 + for g in group_data + ), + 3000 + ) + + for group_dict in cast( + list[JSONDict], + group_data, + ): + group_gid = get_single_value_attribute(group_dict, "attributes.gid", default=None) + if group_gid: + group_gid = int(group_gid, 10) + if not group_gid: + group_gid = next_gid + next_gid += 1 + group_dict["attributes"] = group_dict.get("attributes", {}) + group_dict["attributes"]["gid"] = [str(group_gid)] + self.request( + f"https://{self.base_url}/admin/realms/{self.realm}/groups/{group_dict['id']}", + method="PUT", + json=group_dict + ) + attributes: JSONDict = {} + attributes["cn"] = group_dict.get("name", None) + attributes["description"] = group_dict.get("id", None) + attributes["gidNumber"] = group_gid + attributes["oauth_id"] = group_dict.get("id", None) + # Add membership attributes + members = self.query( + f"{self.base_url}/admin/realms/{self.realm}/groups/{group_dict['id']}/members" + ) + attributes["memberUid"] = [ + user["username"] + for user in cast(list[JSONDict], members) + ] + output.append(attributes) + except KeyError: + pass + return output + + def users(self) -> list[JSONDict]: + output = [] + try: + user_data = [] + while data := self.query( + f"{self.base_url}/admin/realms/{self.realm}/users?first={len(user_data)}&max={self.max_rows}&briefRepresentation=false" + ): + user_data.extend(data) + if len(data) != self.max_rows: + break + + user_data = sorted(user_data, key=lambda u: int(get_single_value_attribute(u, "attributes.uid", default="9999999999"), 10)) + + next_uid = max( + *( + int(get_single_value_attribute(g, "attributes.uid", default="-1"), 10)+1 + for g in user_data + ), + 3000 + ) + + for user_dict in cast( + list[JSONDict], + sorted(user_data, key=lambda user: user["createdTimestamp"]), + ): + user_uid = get_single_value_attribute(user_dict, "attributes.uid", default=None) + if user_uid: + user_uid = int(user_uid, 10) + if not user_uid: + user_uid = next_uid + next_uid += 1 + + user_dict["attributes"] = user_dict.get("attributes", {}) + user_dict["attributes"]["uid"] = [str(user_uid)] + self.request( + f"{self.base_url}/admin/realms/{self.realm}/users/{user_dict['id']}", + method="PUT", + json=user_dict + ) + # Get user attributes + first_name = user_dict.get("firstName", None) + last_name = user_dict.get("lastName", None) + full_name = f"{first_name or ""} {last_name or ""}".strip() + username = user_dict.get("username") + attributes: JSONDict = {} + attributes["cn"] = username + attributes["uid"] = username + attributes["oauth_username"] = username + attributes["displayName"] = full_name + attributes["mail"] = user_dict.get("email") + attributes["description"] = "" + attributes["gidNumber"] = user_uid + attributes["givenName"] = first_name if first_name else "" + attributes["homeDirectory"] = f"/home/{username}" if username else None + attributes["oauth_id"] = user_dict.get("id", None) + attributes["sn"] = last_name if last_name else "" + attributes["uidNumber"] = user_uid + output.append(attributes) + except KeyError: + pass + return output diff --git a/apricot/oauth/oauth_client.py b/apricot/oauth/oauth_client.py index 857b553..a4e1753 100644 --- a/apricot/oauth/oauth_client.py +++ b/apricot/oauth/oauth_client.py @@ -1,5 +1,6 @@ import os from abc import ABC, abstractmethod +from http import HTTPStatus from typing import Any import requests @@ -116,9 +117,7 @@ def query(self, url: str) -> dict[str, Any]: def query_(url: str) -> requests.Response: return self.session_application.get( # type: ignore[no-any-return] url=url, - headers={"Authorization": f"Bearer {self.bearer_token}"}, - client_id=self.session_application._client.client_id, - client_secret=self.client_secret, + headers={"Authorization": f"Bearer {self.bearer_token}"} ) try: @@ -130,6 +129,28 @@ def query_(url: str) -> requests.Response: result = query_(url) return result.json() # type: ignore + def request(self, *args, method="GET", **kwargs) -> dict[str, Any]: + """ + Make a query against the OAuth backend + """ + + def query_(*args, **kwargs) -> requests.Response: + return self.session_application.request( # type: ignore[no-any-return] + method, + *args, **kwargs, + headers={"Authorization": f"Bearer {self.bearer_token}"} + ) + + try: + result = query_(*args, **kwargs) + result.raise_for_status() + except (TokenExpiredError, requests.exceptions.HTTPError): + log.msg("Authentication token has expired.") + self.bearer_token_ = None + result = query_( *args, **kwargs) + if result.status_code != HTTPStatus.NO_CONTENT: + return result.json() # type: ignore + def verify(self, username: str, password: str) -> bool: """ Verify username and password by attempting to authenticate against the OAuth backend. diff --git a/apricot/oauth/oauth_data_adaptor.py b/apricot/oauth/oauth_data_adaptor.py index 701e55a..8a7048a 100644 --- a/apricot/oauth/oauth_data_adaptor.py +++ b/apricot/oauth/oauth_data_adaptor.py @@ -21,10 +21,11 @@ class OAuthDataAdaptor: """Adaptor for converting raw user and group data into LDAP format.""" - def __init__(self, domain: str, oauth_client: OAuthClient): + def __init__(self, domain: str, oauth_client: OAuthClient, enable_group_of_groups: bool): self.debug = oauth_client.debug self.oauth_client = oauth_client self.root_dn = "DC=" + domain.replace(".", ",DC=") + self.enable_group_of_groups = enable_group_of_groups # Retrieve and validate user and group information annotated_groups, annotated_users = self._retrieve_entries() @@ -105,20 +106,21 @@ def _retrieve_entries( # Add one group of groups for each existing group. # Its members are the primary user groups for each original group member. groups_of_groups = [] - for group in oauth_groups: - group_dict = {} - group_dict["cn"] = f"Primary user groups for {group['cn']}" - group_dict["description"] = ( - f"Primary user groups for members of '{group['cn']}'" - ) - # Replace each member user with a member group - group_dict["member"] = [ - str(member).replace("OU=users", "OU=groups") - for member in group["member"] - ] - # Groups do not have UIDs so memberUid must be empty - group_dict["memberUid"] = [] - groups_of_groups.append(group_dict) + if self.enable_group_of_groups: + for group in oauth_groups: + group_dict = {} + group_dict["cn"] = f"Primary user groups for {group['cn']}" + group_dict["description"] = ( + f"Primary user groups for members of '{group['cn']}'" + ) + # Replace each member user with a member group + group_dict["member"] = [ + str(member).replace("OU=users", "OU=groups") + for member in group["member"] + ] + # Groups do not have UIDs so memberUid must be empty + group_dict["memberUid"] = [] + groups_of_groups.append(group_dict) # Ensure memberOf is set correctly for users for child_dict in oauth_users: diff --git a/docker/entrypoint.sh b/docker/entrypoint.sh index 8a21379..1e46c03 100644 --- a/docker/entrypoint.sh +++ b/docker/entrypoint.sh @@ -37,6 +37,10 @@ if [ -n "${DEBUG}" ]; then EXTRA_OPTS="${EXTRA_OPTS} --debug" fi +if [ -n "${DISABLE_GROUP_OF_GROUPS}" ]; then + EXTRA_OPTS="${EXTRA_OPTS} --disable-group-of-groups" +fi + if [ -n "${ENTRA_TENANT_ID}" ]; then EXTRA_OPTS="${EXTRA_OPTS} --entra-tenant-id $ENTRA_TENANT_ID" fi @@ -49,6 +53,14 @@ if [ -n "${REDIS_HOST}" ]; then EXTRA_OPTS="${EXTRA_OPTS} --redis-host $REDIS_HOST --redis-port $REDIS_PORT" fi +if [ -n "${KEYCLOAK_BASE_URL}" ]; then + if [ -z "${KEYCLOAK_REALM}" ]; then + echo "$(date +'%Y-%m-%d %H:%M:%S+0000') [-] KEYCLOAK_REALM environment variable is not set" + exit 1 + fi + EXTRA_OPTS="${EXTRA_OPTS} --extra-keycloak-base-url $KEYCLOAK_BASE_URL --extra-keycloak-realm $KEYCLOAK_REALM" +fi + # Run the server hatch run python run.py \ --backend "${BACKEND}" \ diff --git a/run.py b/run.py index 5ac4230..af26c32 100644 --- a/run.py +++ b/run.py @@ -16,10 +16,17 @@ parser.add_argument("-i", "--client-id", type=str, help="OAuth client ID.") parser.add_argument("-p", "--port", type=int, default=1389, help="Port to run on.") parser.add_argument("-s", "--client-secret", type=str, help="OAuth client secret.") + parser.add_argument("--disable-group-of-groups", action="store_false", + dest="enable_group_of_groups", default=True, + help="Disable creation of group-of-groups.") parser.add_argument("--debug", action="store_true", help="Enable debug logging.") # Options for Microsoft Entra backend entra_group = parser.add_argument_group("Microsoft Entra") entra_group.add_argument("-t", "--entra-tenant-id", type=str, help="Microsoft Entra tenant ID.", required=False) + + entra_group_keycloak = parser.add_argument_group("Keycloak") + entra_group_keycloak.add_argument("--extra-keycloak-base-url", type=str, help="Keycloak base URL.", required=False) + entra_group_keycloak.add_argument("--extra-keycloak-realm", type=str, help="Keycloak Realm.", required=False) # Options for Redis cache redis_group = parser.add_argument_group("Redis") redis_group.add_argument("--redis-host", type=str, help="Host for Redis server.")