From edafc0c6055a887d9699636ef44a8c26116ef999 Mon Sep 17 00:00:00 2001 From: James Robinson Date: Thu, 30 May 2024 13:18:36 +0100 Subject: [PATCH] :rotating_light: Fix linting issues --- README.md | 6 ++--- apricot/apricot_server.py | 27 +++++++++++++++++------ apricot/cache/redis_cache.py | 2 +- apricot/ldap/oauth_ldap_server_factory.py | 20 ++++++++++++++--- apricot/ldap/oauth_ldap_tree.py | 15 +++++++++---- apricot/oauth/oauth_data_adaptor.py | 7 ++++++ 6 files changed, 59 insertions(+), 18 deletions(-) diff --git a/README.md b/README.md index 348dc60..8911e95 100644 --- a/README.md +++ b/README.md @@ -26,10 +26,10 @@ To do this, you will need to provide the `--redis-host` and `--redis-port` argum ### Configure background refresh [Optional] -By default Apricot will refresh on demand when the data is older than 60 seconds. -If it takes a long time to fetch all users and groups, or you want to ensure that each request prompty gets a respose, you may want to configure background refresh to have it periodically be refreshed in the background. +By default Apricot will refresh the LDAP tree whenever it is accessed and it contains data older than 60 seconds. +If it takes a long time to fetch all users and groups, or you want to ensure that each request gets a prompt response, you may want to configure background refresh to have it periodically be refreshed in the background. -This is enabled with the `--background-refresh` flag, which uses the `--refresh-interval=60` parameter as the interval to refresh the ldap database. +This is enabled with the `--background-refresh` flag, which uses the `--refresh-interval` parameter as the interval to refresh the ldap database. ### Using TLS [Optional] diff --git a/apricot/apricot_server.py b/apricot/apricot_server.py index 6d9f015..637c0fe 100644 --- a/apricot/apricot_server.py +++ b/apricot/apricot_server.py @@ -1,9 +1,9 @@ import inspect import sys -from typing import Any, cast, Optional +from typing import Any, Optional, cast from twisted.internet import reactor, task -from twisted.internet.endpoints import serverFromString, quoteStringArgument +from twisted.internet.endpoints import quoteStringArgument, serverFromString from twisted.internet.interfaces import IReactorCore, IStreamServerEndpoint from twisted.python import log @@ -71,12 +71,18 @@ def __init__( if self.debug: log.msg("Creating an LDAPServerFactory.") factory = OAuthLDAPServerFactory( - domain, oauth_client, background_refresh=background_refresh, enable_mirrored_groups=enable_mirrored_groups, refresh_interval=refresh_interval + domain, + oauth_client, + background_refresh=background_refresh, + enable_mirrored_groups=enable_mirrored_groups, + refresh_interval=refresh_interval, ) if background_refresh: if self.debug: - log.msg(f"Starting background refresh (interval={factory.adaptor.refresh_interval})") + log.msg( + f"Starting background refresh (interval={factory.adaptor.refresh_interval})" + ) loop = task.LoopingCall(factory.adaptor.refresh) loop.start(factory.adaptor.refresh_interval) @@ -88,11 +94,18 @@ def __init__( # Attach a listening endpoint if tls_port: - if not (tls_certificate or tls_private_key): - raise ValueError("No TLS certificate or private key provided. Make sure you provide these parameters or disable TLS by not providing the TLS port") + if not tls_certificate: + msg = "No TLS certificate provided. Please provide one with --tls-certificate or disable TLS by not providing the --tls-port argument." + raise ValueError(msg) + if not tls_private_key: + msg = "No TLS private key provided. Please provide one with --tls-private-key or disable TLS by not providing the --tls-port argument." + raise ValueError(msg) if self.debug: log.msg("Attaching a listening endpoint (TLS).") - ssl_endpoint: IStreamServerEndpoint = serverFromString(reactor, f"ssl:{tls_port}:privateKey={quoteStringArgument(tls_private_key)}:certKey={quoteStringArgument(tls_certificate)}") + ssl_endpoint: IStreamServerEndpoint = serverFromString( + reactor, + f"ssl:{tls_port}:privateKey={quoteStringArgument(tls_private_key)}:certKey={quoteStringArgument(tls_certificate)}", + ) ssl_endpoint.listen(factory) # Load the Twisted reactor diff --git a/apricot/cache/redis_cache.py b/apricot/cache/redis_cache.py index 24ac506..4a1d919 100644 --- a/apricot/cache/redis_cache.py +++ b/apricot/cache/redis_cache.py @@ -9,7 +9,7 @@ class RedisCache(UidCache): def __init__(self, redis_host: str, redis_port: int) -> None: self.redis_host = redis_host self.redis_port = redis_port - self.cache_: "redis.Redis[str]" | None = None # noqa: UP037 + self.cache_: "redis.Redis[str]" | None = None @property def cache(self) -> "redis.Redis[str]": diff --git a/apricot/ldap/oauth_ldap_server_factory.py b/apricot/ldap/oauth_ldap_server_factory.py index 3a61578..303d9e4 100644 --- a/apricot/ldap/oauth_ldap_server_factory.py +++ b/apricot/ldap/oauth_ldap_server_factory.py @@ -9,16 +9,30 @@ class OAuthLDAPServerFactory(ServerFactory): def __init__( - self, domain: str, oauth_client: OAuthClient, *, background_refresh: bool, enable_mirrored_groups: bool, refresh_interval: int, + self, + domain: str, + oauth_client: OAuthClient, + *, + background_refresh: bool, + enable_mirrored_groups: bool, + refresh_interval: int, ): """ - Initialise an LDAPServerFactory + Initialise an OAuthLDAPServerFactory + @param background_refresh: Whether to refresh the LDAP tree in the background rather than on access + @param domain: The root domain of the LDAP tree + @param enable_mirrored_groups: Create a mirrored LDAP group-of-groups for each group-of-users @param oauth_client: An OAuth client used to construct the LDAP tree + @param refresh_interval: Interval in seconds after which the tree must be refreshed """ # Create an LDAP lookup tree self.adaptor = OAuthLDAPTree( - domain, oauth_client, background_refresh=background_refresh, enable_mirrored_groups=enable_mirrored_groups, refresh_interval=refresh_interval + domain, + oauth_client, + background_refresh=background_refresh, + enable_mirrored_groups=enable_mirrored_groups, + refresh_interval=refresh_interval, ) def __repr__(self) -> str: diff --git a/apricot/ldap/oauth_ldap_tree.py b/apricot/ldap/oauth_ldap_tree.py index 9d11023..66e649f 100644 --- a/apricot/ldap/oauth_ldap_tree.py +++ b/apricot/ldap/oauth_ldap_tree.py @@ -20,12 +20,14 @@ def __init__( *, background_refresh: bool, enable_mirrored_groups: bool, - refresh_interval, + refresh_interval: int, ) -> None: """ Initialise an OAuthLDAPTree + @param background_refresh: Whether to refresh the LDAP tree in the background rather than on access @param domain: The root domain of the LDAP tree + @param enable_mirrored_groups: Create a mirrored LDAP group-of-groups for each group-of-users @param oauth_client: An OAuth client used to construct the LDAP tree @param refresh_interval: Interval in seconds after which the tree must be refreshed """ @@ -48,15 +50,20 @@ def root(self) -> OAuthLDAPEntry: Lazy-load the LDAP tree on request @return: An OAuthLDAPEntry for the tree + + @raises: ValueError. """ if not self.background_refresh: self.refresh() + if not self.root_: + msg = "LDAP tree could not be loaded" + raise ValueError(msg) return self.root_ - def refresh(self): + def refresh(self) -> None: if ( - not self.root_ - or (time.monotonic() - self.last_update) > self.refresh_interval + not self.root_ + or (time.monotonic() - self.last_update) > self.refresh_interval ): # Update users and groups from the OAuth server log.msg("Retrieving OAuth data.") diff --git a/apricot/oauth/oauth_data_adaptor.py b/apricot/oauth/oauth_data_adaptor.py index e2e6ea5..58aaf8d 100644 --- a/apricot/oauth/oauth_data_adaptor.py +++ b/apricot/oauth/oauth_data_adaptor.py @@ -24,6 +24,13 @@ class OAuthDataAdaptor: def __init__( self, domain: str, oauth_client: OAuthClient, *, enable_mirrored_groups: bool ): + """ + Initialise an OAuthDataAdaptor + + @param domain: The root domain of the LDAP tree + @param enable_mirrored_groups: Create a mirrored LDAP group-of-groups for each group-of-users + @param oauth_client: An OAuth client used to construct the LDAP tree + """ self.debug = oauth_client.debug self.oauth_client = oauth_client self.root_dn = "DC=" + domain.replace(".", ",DC=")