From 1f28a867fd80517c008809729380f834270704a7 Mon Sep 17 00:00:00 2001 From: James Robinson Date: Tue, 27 Feb 2024 13:11:45 +0000 Subject: [PATCH 1/5] :loud_sound: Add missing docstring --- apricot/oauth/oauth_client.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/apricot/oauth/oauth_client.py b/apricot/oauth/oauth_client.py index 7642655..f5dc908 100644 --- a/apricot/oauth/oauth_client.py +++ b/apricot/oauth/oauth_client.py @@ -92,6 +92,9 @@ def bearer_token(self) -> str: @abstractmethod def extract_token(self, json_response: JSONDict) -> str: + """ + Extract the bearer token from an OAuth2Session JSON response + """ pass @abstractmethod From 306ea9c7accbce737fb0c6af1bccb6d22ac2ffed Mon Sep 17 00:00:00 2001 From: James Robinson Date: Tue, 27 Feb 2024 13:31:31 +0000 Subject: [PATCH 2/5] :truck: Refactor UidCache into a base class and a child class --- apricot/cache/__init__.py | 4 +-- apricot/cache/redis_cache.py | 35 +++++++++++++++++++++++ apricot/cache/uid_cache.py | 53 ++++++++++++++++++++--------------- apricot/oauth/oauth_client.py | 4 +-- 4 files changed, 69 insertions(+), 27 deletions(-) create mode 100644 apricot/cache/redis_cache.py diff --git a/apricot/cache/__init__.py b/apricot/cache/__init__.py index 2a785f5..d4ccad6 100644 --- a/apricot/cache/__init__.py +++ b/apricot/cache/__init__.py @@ -1,5 +1,5 @@ -from .uid_cache import UidCache +from .redis_cache import RedisCache __all__ = [ - "UidCache", + "RedisCache", ] diff --git a/apricot/cache/redis_cache.py b/apricot/cache/redis_cache.py new file mode 100644 index 0000000..8e753ef --- /dev/null +++ b/apricot/cache/redis_cache.py @@ -0,0 +1,35 @@ +from typing import cast + +import redis + +from .uid_cache import UidCache + + +class RedisCache(UidCache): + def __init__(self, redis_host: str, redis_port: str) -> None: + self.redis_host = redis_host + self.redis_port = redis_port + self.cache_ = None + + @property + def cache(self) -> redis.Redis: # type: ignore[type-arg] + """ + Lazy-load the cache on request + """ + if not self.cache_: + self.cache_ = redis.Redis( # type: ignore[call-overload] + host=self.redis_host, port=self.redis_port, decode_responses=True + ) + return self.cache_ # type: ignore[return-value] + + def get(self, identifier: str) -> int | None: + return self.cache.get(identifier) + + def keys(self) -> list[str]: + return [str(k) for k in self.cache.keys()] + + def set(self, identifier: str, uid_value: int) -> None: + self.cache.set(identifier, uid_value) + + def values(self, keys: list[str]) -> list[int]: + return [int(cast(str, v)) for v in self.cache.mget(keys)] diff --git a/apricot/cache/uid_cache.py b/apricot/cache/uid_cache.py index e5962ef..eab70c0 100644 --- a/apricot/cache/uid_cache.py +++ b/apricot/cache/uid_cache.py @@ -1,31 +1,38 @@ +from abc import ABC, abstractmethod from typing import cast -import redis - -class UidCache: - def __init__(self, redis_host: str, redis_port: str) -> None: - self.redis_host = redis_host - self.redis_port = redis_port +class UidCache(ABC): + def __init__(self) -> None: self.cache_ = None - @property - def cache(self) -> redis.Redis: # type: ignore[type-arg] + @abstractmethod + def get(self, identifier: str) -> int | None: """ - Lazy-load the cache on request + Get the UID for a given identifier, returning None if it does not exist """ - if not self.cache_: - self.cache_ = redis.Redis( # type: ignore[call-overload] - host=self.redis_host, port=self.redis_port, decode_responses=True - ) - return self.cache_ # type: ignore[return-value] + pass - @property + @abstractmethod def keys(self) -> list[str]: """ - Get list of keys from the cache + Get list of cached keys + """ + pass + + @abstractmethod + def set(self, identifier: str, uid_value: int) -> None: + """ + Set the UID for a given identifier + """ + pass + + @abstractmethod + def values(self, keys: list[str]) -> list[int]: + """ + Get list of cached values corresponding to requested keys """ - return [str(k) for k in self.cache.keys()] + pass def get_group_uid(self, identifier: str) -> int: """ @@ -54,12 +61,12 @@ def get_uid( @param min_value: Minimum allowed value for the UID """ identifier_ = f"{category}-{identifier}" - uid = self.cache.get(identifier_) + uid = self.get(identifier_) if not uid: min_value = min_value if min_value else 0 next_uid = max(self._get_max_uid(category) + 1, min_value) - self.cache.set(identifier_, next_uid) - return cast(int, self.cache.get(identifier_)) + self.set(identifier_, next_uid) + return cast(int, self.get(identifier_)) def _get_max_uid(self, category: str | None) -> int: """ @@ -68,8 +75,8 @@ def _get_max_uid(self, category: str | None) -> int: @param category: Category to check UIDs for """ if category: - keys = [k for k in self.keys if k.startswith(category)] + keys = [k for k in self.keys() if k.startswith(category)] else: - keys = self.keys - values = [int(cast(str, v)) for v in self.cache.mget(keys)] + [-999] + keys = self.keys() + values = [*self.values(keys), -999] return max(values) diff --git a/apricot/oauth/oauth_client.py b/apricot/oauth/oauth_client.py index f5dc908..71b3c13 100644 --- a/apricot/oauth/oauth_client.py +++ b/apricot/oauth/oauth_client.py @@ -13,7 +13,7 @@ from requests_oauthlib import OAuth2Session from twisted.python import log -from apricot.cache import UidCache +from apricot.cache import RedisCache from apricot.models import ( LdapGroupOfNames, LdapInetOrgPerson, @@ -45,7 +45,7 @@ def __init__( self.client_secret = client_secret self.domain = domain self.token_url = token_url - self.uid_cache = UidCache(redis_host=redis_host, redis_port=redis_port) + self.uid_cache = RedisCache(redis_host=redis_host, redis_port=redis_port) # Allow token scope to not match requested scope. (Other auth libraries allow # this, but Requests-OAuthlib raises exception on scope mismatch by default.) os.environ["OAUTHLIB_RELAX_TOKEN_SCOPE"] = "1" # noqa: S105 From 58cde67379e7b62ddba593b0a90e66637c4a1000 Mon Sep 17 00:00:00 2001 From: James Robinson Date: Tue, 27 Feb 2024 13:54:28 +0000 Subject: [PATCH 3/5] :sparkles: Add a local cache option --- apricot/cache/__init__.py | 2 ++ apricot/cache/local_cache.py | 18 ++++++++++++++++++ 2 files changed, 20 insertions(+) create mode 100644 apricot/cache/local_cache.py diff --git a/apricot/cache/__init__.py b/apricot/cache/__init__.py index d4ccad6..8f67386 100644 --- a/apricot/cache/__init__.py +++ b/apricot/cache/__init__.py @@ -1,5 +1,7 @@ +from .local_cache import LocalCache from .redis_cache import RedisCache __all__ = [ + "LocalCache", "RedisCache", ] diff --git a/apricot/cache/local_cache.py b/apricot/cache/local_cache.py new file mode 100644 index 0000000..958217b --- /dev/null +++ b/apricot/cache/local_cache.py @@ -0,0 +1,18 @@ +from .uid_cache import UidCache + + +class LocalCache(UidCache): + def __init__(self) -> None: + self.cache: dict[str, int] = {} + + def get(self, identifier: str) -> int | None: + return self.cache.get(identifier, None) + + def keys(self) -> list[str]: + return [str(k) for k in self.cache.keys()] + + def set(self, identifier: str, uid_value: int) -> None: + self.cache[identifier] = uid_value + + def values(self, keys: list[str]) -> list[int]: + return [v for k, v in self.cache.items() if k in keys] From 42629f5b003133e2e51488e1f6824bfe1cee2184 Mon Sep 17 00:00:00 2001 From: James Robinson Date: Tue, 27 Feb 2024 14:06:53 +0000 Subject: [PATCH 4/5] :truck: Default to using a local cache if no Redis server information is provided --- README.md | 7 +++++-- apricot/apricot_server.py | 18 ++++++++++++++---- apricot/cache/__init__.py | 2 ++ apricot/oauth/oauth_client.py | 7 +++---- docker/entrypoint.sh | 20 +++++++++----------- run.py | 10 ++++++---- 6 files changed, 39 insertions(+), 25 deletions(-) diff --git a/README.md b/README.md index 6e57db4..b1d1f02 100644 --- a/README.md +++ b/README.md @@ -5,8 +5,6 @@ The name is a slightly tortured acronym for: LD**A**P **pr**oxy for Open**I**D * ## Usage -**N.B.** As Apricot uses a Redis server to store generated `uidNumber` and `gidNumber` values. - Start the `Apricot` server on port 1389 by running: ```bash @@ -21,6 +19,11 @@ docker compose up from the `docker` directory. +### Using Redis [Optional] + +You can use a Redis server to store generated `uidNumber` and `gidNumber` values in a more persistent way. +To do this, you will need to provide the `--redis-host` and `--redis-port` arguments to `run.py`. + ## Outputs This will create an LDAP tree that looks like this: diff --git a/apricot/apricot_server.py b/apricot/apricot_server.py index 4e34944..ae960e4 100644 --- a/apricot/apricot_server.py +++ b/apricot/apricot_server.py @@ -6,6 +6,7 @@ from twisted.internet.interfaces import IReactorCore, IStreamServerEndpoint from twisted.python import log +from apricot.cache import LocalCache, RedisCache from apricot.ldap import OAuthLDAPServerFactory from apricot.oauth import OAuthBackend, OAuthClientMap @@ -18,21 +19,30 @@ def __init__( client_secret: str, domain: str, port: int, - redis_host: str, - redis_port: int, + redis_host: str | None = None, + redis_port: int | None = None, **kwargs: Any, ) -> None: # Log to stdout log.startLogging(sys.stdout) + # Initialise the UID cache + if redis_host and redis_port: + log.msg( + f"Using a Redis user-id cache at host '{redis_host}' on port '{redis_port}'." + ) + uid_cache = RedisCache(redis_host=redis_host, redis_port=redis_port) + else: + log.msg("Using a local user-id cache.") + uid_cache = LocalCache() + # Initialize the appropriate OAuth client try: oauth_client = OAuthClientMap[backend]( client_id=client_id, client_secret=client_secret, domain=domain, - redis_host=redis_host, - redis_port=redis_port, + uid_cache=uid_cache, **kwargs, ) except Exception as exc: diff --git a/apricot/cache/__init__.py b/apricot/cache/__init__.py index 8f67386..478ecea 100644 --- a/apricot/cache/__init__.py +++ b/apricot/cache/__init__.py @@ -1,7 +1,9 @@ from .local_cache import LocalCache from .redis_cache import RedisCache +from .uid_cache import UidCache __all__ = [ "LocalCache", "RedisCache", + "UidCache", ] diff --git a/apricot/oauth/oauth_client.py b/apricot/oauth/oauth_client.py index 71b3c13..7e76f7d 100644 --- a/apricot/oauth/oauth_client.py +++ b/apricot/oauth/oauth_client.py @@ -13,7 +13,7 @@ from requests_oauthlib import OAuth2Session from twisted.python import log -from apricot.cache import RedisCache +from apricot.cache import UidCache from apricot.models import ( LdapGroupOfNames, LdapInetOrgPerson, @@ -35,17 +35,16 @@ def __init__( client_secret: str, domain: str, redirect_uri: str, - redis_host: str, - redis_port: str, scopes: list[str], token_url: str, + uid_cache: UidCache, ) -> None: # Set attributes self.bearer_token_: str | None = None self.client_secret = client_secret self.domain = domain self.token_url = token_url - self.uid_cache = RedisCache(redis_host=redis_host, redis_port=redis_port) + self.uid_cache = uid_cache # Allow token scope to not match requested scope. (Other auth libraries allow # this, but Requests-OAuthlib raises exception on scope mismatch by default.) os.environ["OAUTHLIB_RELAX_TOKEN_SCOPE"] = "1" # noqa: S105 diff --git a/docker/entrypoint.sh b/docker/entrypoint.sh index d4d43c7..ac54e27 100644 --- a/docker/entrypoint.sh +++ b/docker/entrypoint.sh @@ -23,21 +23,13 @@ if [ -z "${DOMAIN}" ]; then exit 1 fi -if [ -z "${REDIS_HOST}" ]; then - echo "$(date +'%Y-%m-%d %H:%M:%S+0000') [-] REDIS_HOST environment variable is not set" - exit 1 -fi - # Arguments with defaults if [ -z "${PORT}" ]; then PORT="1389" echo "$(date +'%Y-%m-%d %H:%M:%S+0000') [-] PORT environment variable is not set: using default of '${PORT}'" fi -if [ -z "${REDIS_PORT}" ]; then - REDIS_PORT="6379" - echo "$(date +'%Y-%m-%d %H:%M:%S+0000') [-] REDIS_PORT environment variable is not set: using default of '${REDIS_PORT}'" -fi + # Optional arguments EXTRA_OPTS="" @@ -45,6 +37,14 @@ if [ -n "${ENTRA_TENANT_ID}" ]; then EXTRA_OPTS="${EXTRA_OPTS} --entra-tenant-id $ENTRA_TENANT_ID" fi +if [ -n "${REDIS_HOST}" ]; then + if [ -z "${REDIS_PORT}" ]; then + REDIS_PORT="6379" + echo "$(date +'%Y-%m-%d %H:%M:%S+0000') [-] REDIS_PORT environment variable is not set: using default of '${REDIS_PORT}'" + fi + EXTRA_OPTS="${EXTRA_OPTS} --redis-host $REDIS_HOST --redis-port $REDIS_PORT" +fi + # Run the server hatch run python run.py \ --backend "${BACKEND}" \ @@ -52,6 +52,4 @@ hatch run python run.py \ --client-secret "${CLIENT_SECRET}" \ --domain "${DOMAIN}" \ --port "${PORT}" \ - --redis-host "${REDIS_HOST}" \ - --redis-port "${REDIS_PORT}" \ $EXTRA_OPTS diff --git a/run.py b/run.py index 5bb3868..67f28ed 100644 --- a/run.py +++ b/run.py @@ -16,11 +16,13 @@ parser.add_argument("-s", "--client-secret", type=str, help="OAuth client secret.") parser.add_argument("-d", "--domain", type=str, help="Which domain users belong to.") parser.add_argument("-p", "--port", type=int, default=1389, help="Port to run on.") - parser.add_argument("--redis-host", type=str, help="Host for Redis server.") - parser.add_argument("--redis-port", type=int, help="Port for Redis server.") # Options for Microsoft Entra backend - group = parser.add_argument_group("Microsoft Entra") - group.add_argument("-t", "--entra-tenant-id", type=str, help="Microsoft Entra tenant ID.", required=False) + entra_group = parser.add_argument_group("Microsoft Entra") + entra_group.add_argument("-t", "--entra-tenant-id", type=str, help="Microsoft Entra tenant ID.", required=False) + # Options for Redis cache + redis_group = parser.add_argument_group("Redis") + redis_group.add_argument("--redis-host", type=str, help="Host for Redis server.") + redis_group.add_argument("--redis-port", type=int, help="Port for Redis server.") # Parse arguments args = parser.parse_args() From 62c0e586fc3aa54603ce1642f3b20875c7d3d0d9 Mon Sep 17 00:00:00 2001 From: James Robinson Date: Tue, 27 Feb 2024 14:18:22 +0000 Subject: [PATCH 5/5] :rotating_light: Fix linting errors --- apricot/apricot_server.py | 3 ++- apricot/cache/redis_cache.py | 13 +++++++------ apricot/cache/uid_cache.py | 3 --- 3 files changed, 9 insertions(+), 10 deletions(-) diff --git a/apricot/apricot_server.py b/apricot/apricot_server.py index ae960e4..863da98 100644 --- a/apricot/apricot_server.py +++ b/apricot/apricot_server.py @@ -6,7 +6,7 @@ from twisted.internet.interfaces import IReactorCore, IStreamServerEndpoint from twisted.python import log -from apricot.cache import LocalCache, RedisCache +from apricot.cache import LocalCache, RedisCache, UidCache from apricot.ldap import OAuthLDAPServerFactory from apricot.oauth import OAuthBackend, OAuthClientMap @@ -27,6 +27,7 @@ def __init__( log.startLogging(sys.stdout) # Initialise the UID cache + uid_cache: UidCache if redis_host and redis_port: log.msg( f"Using a Redis user-id cache at host '{redis_host}' on port '{redis_port}'." diff --git a/apricot/cache/redis_cache.py b/apricot/cache/redis_cache.py index 8e753ef..4a1d919 100644 --- a/apricot/cache/redis_cache.py +++ b/apricot/cache/redis_cache.py @@ -6,24 +6,25 @@ class RedisCache(UidCache): - def __init__(self, redis_host: str, redis_port: str) -> None: + def __init__(self, redis_host: str, redis_port: int) -> None: self.redis_host = redis_host self.redis_port = redis_port - self.cache_ = None + self.cache_: "redis.Redis[str]" | None = None @property - def cache(self) -> redis.Redis: # type: ignore[type-arg] + def cache(self) -> "redis.Redis[str]": """ Lazy-load the cache on request """ if not self.cache_: - self.cache_ = redis.Redis( # type: ignore[call-overload] + self.cache_ = redis.Redis( host=self.redis_host, port=self.redis_port, decode_responses=True ) - return self.cache_ # type: ignore[return-value] + return self.cache_ def get(self, identifier: str) -> int | None: - return self.cache.get(identifier) + value = self.cache.get(identifier) + return None if value is None else int(value) def keys(self) -> list[str]: return [str(k) for k in self.cache.keys()] diff --git a/apricot/cache/uid_cache.py b/apricot/cache/uid_cache.py index eab70c0..ab46029 100644 --- a/apricot/cache/uid_cache.py +++ b/apricot/cache/uid_cache.py @@ -3,9 +3,6 @@ class UidCache(ABC): - def __init__(self) -> None: - self.cache_ = None - @abstractmethod def get(self, identifier: str) -> int | None: """