diff --git a/README.md b/README.md index 7308a4c..6e57db4 100644 --- a/README.md +++ b/README.md @@ -5,10 +5,12 @@ The name is a slightly tortured acronym for: LD**A**P **pr**oxy for Open**I**D * ## Usage +**N.B.** As Apricot uses a Redis server to store generated `uidNumber` and `gidNumber` values. + Start the `Apricot` server on port 1389 by running: ```bash -python run.py --client-id "" --client-secret "" --backend "" --port 1389 --domain "" +python run.py --client-id "" --client-secret "" --backend "" --port 1389 --domain "" --redis-host "" ``` Alternatively, you can run in Docker by editing `docker/docker-compose.yaml` and running: @@ -40,10 +42,11 @@ Each user will have an entry like ```ldif dn: CN=,OU=users,DC= -objectClass: organizationalPerson +objectClass: inetOrgPerson +objectClass: inetuser objectClass: person +objectClass: posixAccount objectClass: top -objectClass: user ``` @@ -51,7 +54,8 @@ Each group will have an entry like ```ldif dn: CN=,OU=groups,DC= -objectClass: group +objectClass: groupOfNames +objectClass: posixGroup objectClass: top ``` diff --git a/apricot/__init__.py b/apricot/__init__.py index 201e2e7..0a47852 100644 --- a/apricot/__init__.py +++ b/apricot/__init__.py @@ -1,5 +1,6 @@ from .__about__ import __version__, __version_info__ from .apricot_server import ApricotServer +from .patches import LDAPString # noqa: F401 __all__ = [ "__version__", diff --git a/apricot/apricot_server.py b/apricot/apricot_server.py index d7f65c8..4e34944 100644 --- a/apricot/apricot_server.py +++ b/apricot/apricot_server.py @@ -18,7 +18,8 @@ def __init__( client_secret: str, domain: str, port: int, - uid_attribute: str, + redis_host: str, + redis_port: int, **kwargs: Any, ) -> None: # Log to stdout @@ -30,7 +31,8 @@ def __init__( client_id=client_id, client_secret=client_secret, domain=domain, - uid_attribute=uid_attribute, + redis_host=redis_host, + redis_port=redis_port, **kwargs, ) except Exception as exc: diff --git a/apricot/cache/__init__.py b/apricot/cache/__init__.py new file mode 100644 index 0000000..2a785f5 --- /dev/null +++ b/apricot/cache/__init__.py @@ -0,0 +1,5 @@ +from .uid_cache import UidCache + +__all__ = [ + "UidCache", +] diff --git a/apricot/cache/uid_cache.py b/apricot/cache/uid_cache.py new file mode 100644 index 0000000..e5962ef --- /dev/null +++ b/apricot/cache/uid_cache.py @@ -0,0 +1,75 @@ +from typing import cast + +import redis + + +class UidCache: + def __init__(self, redis_host: str, redis_port: str) -> None: + self.redis_host = redis_host + self.redis_port = redis_port + self.cache_ = None + + @property + def cache(self) -> redis.Redis: # type: ignore[type-arg] + """ + Lazy-load the cache on request + """ + if not self.cache_: + self.cache_ = redis.Redis( # type: ignore[call-overload] + host=self.redis_host, port=self.redis_port, decode_responses=True + ) + return self.cache_ # type: ignore[return-value] + + @property + def keys(self) -> list[str]: + """ + Get list of keys from the cache + """ + return [str(k) for k in self.cache.keys()] + + def get_group_uid(self, identifier: str) -> int: + """ + Get UID for a group, constructing one if necessary + + @param identifier: Identifier for group needing a UID + """ + return self.get_uid(identifier, category="group", min_value=3000) + + def get_user_uid(self, identifier: str) -> int: + """ + Get UID for a user, constructing one if necessary + + @param identifier: Identifier for user needing a UID + """ + return self.get_uid(identifier, category="user", min_value=2000) + + def get_uid( + self, identifier: str, category: str, min_value: int | None = None + ) -> int: + """ + Get UID, constructing one if necessary. + + @param identifier: Identifier for object needing a UID + @param category: Category the object belongs to + @param min_value: Minimum allowed value for the UID + """ + identifier_ = f"{category}-{identifier}" + uid = self.cache.get(identifier_) + if not uid: + min_value = min_value if min_value else 0 + next_uid = max(self._get_max_uid(category) + 1, min_value) + self.cache.set(identifier_, next_uid) + return cast(int, self.cache.get(identifier_)) + + def _get_max_uid(self, category: str | None) -> int: + """ + Get maximum UID for a given category + + @param category: Category to check UIDs for + """ + if category: + keys = [k for k in self.keys if k.startswith(category)] + else: + keys = self.keys + values = [int(cast(str, v)) for v in self.cache.mget(keys)] + [-999] + return max(values) diff --git a/apricot/ldap/oauth_ldap_tree.py b/apricot/ldap/oauth_ldap_tree.py index 4133371..c19e8af 100644 --- a/apricot/ldap/oauth_ldap_tree.py +++ b/apricot/ldap/oauth_ldap_tree.py @@ -4,7 +4,7 @@ from zope.interface import implementer from apricot.ldap.oauth_ldap_entry import OAuthLDAPEntry -from apricot.oauth import LDAPAttributeDict, OAuthClient +from apricot.oauth import OAuthClient @implementer(IConnectedLDAPEntry) @@ -17,41 +17,40 @@ def __init__(self, oauth_client: OAuthClient) -> None: @param oauth_client: An OAuth client used to construct the LDAP tree """ - self.oauth_client = oauth_client + self.oauth_client: OAuthClient = oauth_client + self.root_: OAuthLDAPEntry | None = None - # Create a root node for the tree - self.root = self.build_root( - dn=self.oauth_client.root_dn, attributes={"objectClass": ["dcObject"]} - ) - # Add OUs for users and groups - groups_ou = self.root.add_child( - "OU=groups", {"ou": ["groups"], "objectClass": ["organizationalUnit"]} - ) - users_ou = self.root.add_child( - "OU=users", {"ou": ["users"], "objectClass": ["organizationalUnit"]} - ) - # Add groups to the groups OU - for group_attrs in self.oauth_client.validated_groups(): - groups_ou.add_child(f"CN={group_attrs['cn'][0]}", group_attrs) - # Add users to the users OU - for user_attrs in self.oauth_client.validated_users(): - users_ou.add_child(f"CN={user_attrs['cn'][0]}", user_attrs) - - def __repr__(self) -> str: - return f"{self.__class__.__name__} with backend {self.oauth_client.__class__.__name__}" - - def build_root(self, dn: str, attributes: LDAPAttributeDict) -> OAuthLDAPEntry: + @property + def root(self) -> OAuthLDAPEntry: """ - Construct the root of the LDAP tree + Lazy-load the LDAP tree on request - @param dn: Distinguished Name of the object - @param attributes: Attributes of the object. - - @return: An OAuthLDAPEntry + @return: An OAuthLDAPEntry for the tree """ - return OAuthLDAPEntry( - dn=dn, attributes=attributes, oauth_client=self.oauth_client - ) + if not self.root_: + # Create a root node for the tree + self.root_ = OAuthLDAPEntry( + dn=self.oauth_client.root_dn, + attributes={"objectClass": ["dcObject"]}, + oauth_client=self.oauth_client, + ) + # Add OUs for users and groups + groups_ou = self.root_.add_child( + "OU=groups", {"ou": ["groups"], "objectClass": ["organizationalUnit"]} + ) + users_ou = self.root_.add_child( + "OU=users", {"ou": ["users"], "objectClass": ["organizationalUnit"]} + ) + # Add groups to the groups OU + for group_attrs in self.oauth_client.validated_groups(): + groups_ou.add_child(f"CN={group_attrs['cn'][0]}", group_attrs) + # Add users to the users OU + for user_attrs in self.oauth_client.validated_users(): + users_ou.add_child(f"CN={user_attrs['cn'][0]}", user_attrs) + return self.root_ + + def __repr__(self) -> str: + return f"{self.__class__.__name__} with backend {self.oauth_client.__class__.__name__}" def lookup(self, dn: DistinguishedName | str) -> defer.Deferred[ILDAPEntry]: """ diff --git a/apricot/oauth/microsoft_entra_client.py b/apricot/oauth/microsoft_entra_client.py index f2ff346..af429c9 100644 --- a/apricot/oauth/microsoft_entra_client.py +++ b/apricot/oauth/microsoft_entra_client.py @@ -28,15 +28,23 @@ def extract_token(self, json_response: JSONDict) -> str: def groups(self) -> list[dict[str, Any]]: output = [] try: - group_data = self.query("https://graph.microsoft.com/v1.0/groups/") - for group_dict in cast(list[dict[str, Any]], group_data["value"]): + queries = [ + "createdDateTime", + "displayName", + "id", + ] + group_data = self.query( + f"https://graph.microsoft.com/v1.0/groups?$select={','.join(queries)}" + ) + for group_dict in cast( + list[dict[str, Any]], + sorted(group_data["value"], key=lambda group: group["createdDateTime"]), + ): + group_uid = self.uid_cache.get_group_uid(group_dict["id"]) attributes = {} attributes["cn"] = group_dict.get("displayName", None) attributes["description"] = group_dict.get("id", None) - # As we cannot manually set any attributes we take the last part of the securityIdentifier - attributes["gidNumber"] = str( - group_dict.get("securityIdentifier", "") - ).split("-")[-1] + attributes["gidNumber"] = group_uid # Add membership attributes members = self.query( f"https://graph.microsoft.com/v1.0/groups/{group_dict['id']}/members" @@ -59,30 +67,34 @@ def users(self) -> list[dict[str, Any]]: output = [] try: queries = [ + "createdDateTime", "displayName", "givenName", "id", "surname", "userPrincipalName", - self.uid_attribute, ] user_data = self.query( f"https://graph.microsoft.com/v1.0/users?$select={','.join(queries)}" ) - for user_dict in cast(list[dict[str, Any]], user_data["value"]): + for user_dict in cast( + list[dict[str, Any]], + sorted(user_data["value"], key=lambda user: user["createdDateTime"]), + ): # Get user attributes uid, domain = str(user_dict.get("userPrincipalName", "@")).split("@") + user_uid = self.uid_cache.get_user_uid(user_dict["id"]) attributes = {} attributes["cn"] = user_dict.get("displayName", None) attributes["description"] = user_dict.get("id", None) attributes["displayName"] = attributes.get("cn", None) attributes["domain"] = domain - attributes["gidNumber"] = user_dict.get(self.uid_attribute, None) + attributes["gidNumber"] = user_uid attributes["givenName"] = user_dict.get("givenName", "") attributes["homeDirectory"] = f"/home/{uid}" if uid else None attributes["sn"] = user_dict.get("surname", "") attributes["uid"] = uid if uid else None - attributes["uidNumber"] = user_dict.get(self.uid_attribute, None) + attributes["uidNumber"] = user_uid # Add group attributes group_memberships = self.query( f"https://graph.microsoft.com/v1.0/users/{user_dict['id']}/memberOf" diff --git a/apricot/oauth/oauth_client.py b/apricot/oauth/oauth_client.py index f3de1cc..31c7e91 100644 --- a/apricot/oauth/oauth_client.py +++ b/apricot/oauth/oauth_client.py @@ -11,6 +11,7 @@ from requests_oauthlib import OAuth2Session from twisted.python import log +from apricot.cache import UidCache from apricot.models import ( LdapGroupOfNames, LdapInetOrgPerson, @@ -32,15 +33,16 @@ def __init__( client_secret: str, domain: str, redirect_uri: str, + redis_host: str, + redis_port: str, scopes: list[str], token_url: str, - uid_attribute: str, ) -> None: # Set attributes self.client_secret = client_secret self.domain = domain self.token_url = token_url - self.uid_attribute = uid_attribute + self.uid_cache = UidCache(redis_host=redis_host, redis_port=redis_port) # Allow token scope to not match requested scope. (Other auth libraries allow # this, but Requests-OAuthlib raises exception on scope mismatch by default.) os.environ["OAUTHLIB_RELAX_TOKEN_SCOPE"] = "1" # noqa: S105 @@ -107,6 +109,9 @@ def root_dn(self) -> str: return "DC=" + self.domain.replace(".", ",DC=") def query(self, url: str) -> dict[str, Any]: + """ + Make a query against the Microsoft Entra directory + """ result = self.session_application.request( method="GET", url=url, @@ -163,6 +168,10 @@ def validated_users(self) -> list[LDAPAttributeDict]: for user_dict in self.users(): try: attributes = {"objectclass": ["top"]} + # Add user to self-titled group + user_dict["memberOf"].append( + f"CN={user_dict['cn']},OU=groups,{self.root_dn}" + ) # Add 'inetOrgPerson' attributes inetorg_person = LdapInetOrgPerson(**user_dict) attributes.update(inetorg_person.model_dump()) diff --git a/apricot/patches/__init__.py b/apricot/patches/__init__.py new file mode 100644 index 0000000..9d48553 --- /dev/null +++ b/apricot/patches/__init__.py @@ -0,0 +1,5 @@ +from .ldap_string import LDAPString # type: ignore[attr-defined] + +__all__ = [ + "LDAPString", +] diff --git a/apricot/patches/ldap_string.py b/apricot/patches/ldap_string.py new file mode 100644 index 0000000..41bfc45 --- /dev/null +++ b/apricot/patches/ldap_string.py @@ -0,0 +1,17 @@ +"""Patch LDAPString to avoid TypeError when parsing LDAP filter strings""" + +from typing import Any + +from ldaptor.protocols.pureldap import LDAPString + +old_init = LDAPString.__init__ + + +def patched_init(self, *args: Any, **kwargs: Any) -> None: # type: ignore[no-untyped-def] + """Patch LDAPString init to store its value as 'str' not 'bytes'""" + old_init(self, *args, **kwargs) + if isinstance(self.value, bytes): + self.value = self.value.decode() + + +LDAPString.__init__ = patched_init diff --git a/docker/docker-compose.yaml b/docker/docker-compose.yaml index 2d13d7a..3824a35 100644 --- a/docker/docker-compose.yaml +++ b/docker/docker-compose.yaml @@ -11,6 +11,16 @@ services: CLIENT_SECRET: "" DOMAIN: "" ENTRA_TENANT_ID: "" + REDIS_HOST: "redis" ports: - "1389:1389" restart: always + + redis: + container_name: redis + image: redis:7.2 + ports: + - "6379:6379" + volumes: + - :/data + restart: always diff --git a/docker/entrypoint.sh b/docker/entrypoint.sh index 00789e9..c689ff2 100644 --- a/docker/entrypoint.sh +++ b/docker/entrypoint.sh @@ -23,8 +23,8 @@ if [ -z "${DOMAIN}" ]; then exit 1 fi -if [ -z "${UID_ATTRIBUTE}" ]; then - echo "UID_ATTRIBUTE environment variable is not set" +if [ -z "${REDIS_HOST}" ]; then + echo "REDIS_HOST environment variable is not set" exit 1 fi @@ -34,6 +34,11 @@ if [ -z "${PORT}" ]; then PORT="1389" fi +if [ -z "${REDIS_PORT}" ]; then + echo "REDIS_PORT environment variable is not set: using default of 6379" + REDIS_PORT="6379" +fi + # Optional arguments EXTRA_OPTS="" if [ -n "${ENTRA_TENANT_ID}" ]; then @@ -42,10 +47,11 @@ fi # Run the server hatch run python run.py \ - --backend "$BACKEND" \ - --client-id "$CLIENT_ID" \ - --client-secret "$CLIENT_SECRET" \ - --domain "$DOMAIN" \ + --backend "${BACKEND}" \ + --client-id "${CLIENT_ID}" \ + --client-secret "${CLIENT_SECRET}" \ + --domain "${DOMAIN}" \ --port "${PORT}" \ - --uid-attribute "${UID_ATTRIBUTE}" \ + --redis-host "${REDIS_HOST}" \ + --redis-port "${REDIS_PORT}" \ $EXTRA_OPTS diff --git a/pyproject.toml b/pyproject.toml index 0f689cc..d7b5f9b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -21,15 +21,17 @@ classifiers = [ "Programming Language :: Python :: 3.9", "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", "Programming Language :: Python :: Implementation :: CPython", "Programming Language :: Python :: Implementation :: PyPy", ] dependencies = [ - "ldaptor~=21.2.0", - "oauthlib~=3.2.0", - "pydantic~=2.4.0", - "requests-oauthlib~=1.3.0", - "Twisted~=23.10.0", + "ldaptor~=21.2", + "oauthlib~=3.2", + "pydantic~=2.4", + "redis~=5.0", + "requests-oauthlib~=1.3", + "Twisted~=23.10", "zope.interface~=6.2", ] @@ -51,6 +53,7 @@ dependencies = [ "mypy~=1.8.0", "ruff~=0.2.0", "types-oauthlib~=3.2.0", + "types-redis~=4.6", ] [tool.hatch.envs.lint.scripts] diff --git a/run.py b/run.py index ca441f3..5bb3868 100644 --- a/run.py +++ b/run.py @@ -12,11 +12,12 @@ ) # Common options needed for all backends parser.add_argument("-b", "--backend", type=OAuthBackend, help="Which OAuth backend to use.") - parser.add_argument("-d", "--domain", type=str, help="Which domain users belong to.") - parser.add_argument("-p", "--port", type=int, default=1389, help="Port to run on.") parser.add_argument("-i", "--client-id", type=str, help="OAuth client ID.") parser.add_argument("-s", "--client-secret", type=str, help="OAuth client secret.") - parser.add_argument("-u", "--uid-attribute", type=str, help="Which user attribute to use for UID.") + parser.add_argument("-d", "--domain", type=str, help="Which domain users belong to.") + parser.add_argument("-p", "--port", type=int, default=1389, help="Port to run on.") + parser.add_argument("--redis-host", type=str, help="Host for Redis server.") + parser.add_argument("--redis-port", type=int, help="Port for Redis server.") # Options for Microsoft Entra backend group = parser.add_argument_group("Microsoft Entra") group.add_argument("-t", "--entra-tenant-id", type=str, help="Microsoft Entra tenant ID.", required=False) @@ -26,9 +27,14 @@ # Create the Apricot server reactor = ApricotServer(**vars(args)) except Exception as exc: - msg = f"Unable to initialise Apricot server.\n{str(exc)}" + msg = f"Unable to initialise Apricot server.\n{exc}" print(msg) sys.exit(1) # Run the Apricot server - reactor.run() + try: + reactor.run() + except Exception as exc: + msg = f"Apricot server encountered a runtime problem.\n{exc}" + print(msg) + sys.exit(1)