Skip to content

Commit

Permalink
Merge pull request #30 from alan-turing-institute/add-debug-option
Browse files Browse the repository at this point in the history
Add debug option
  • Loading branch information
jemrobinson authored Mar 8, 2024
2 parents d54e2b7 + 776e06b commit d39b9e5
Show file tree
Hide file tree
Showing 7 changed files with 90 additions and 27 deletions.
13 changes: 13 additions & 0 deletions apricot/apricot_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,14 @@ def __init__(
client_secret: str,
domain: str,
port: int,
*,
debug: bool = False,
redis_host: str | None = None,
redis_port: int | None = None,
**kwargs: Any,
) -> None:
self.debug = debug

# Log to stdout
log.startLogging(sys.stdout)

Expand All @@ -39,9 +43,12 @@ def __init__(

# Initialize the appropriate OAuth client
try:
if self.debug:
log.msg(f"Creating an OAuthClient for {backend}.")
oauth_client = OAuthClientMap[backend](
client_id=client_id,
client_secret=client_secret,
debug=debug,
domain=domain,
uid_cache=uid_cache,
**kwargs,
Expand All @@ -51,9 +58,13 @@ def __init__(
raise ValueError(msg) from exc

# Create an LDAPServerFactory
if self.debug:
log.msg("Creating an LDAPServerFactory.")
factory = OAuthLDAPServerFactory(oauth_client)

# Attach a listening endpoint
if self.debug:
log.msg("Attaching a listening endpoint.")
endpoint: IStreamServerEndpoint = serverFromString(reactor, f"tcp:{port}")
endpoint.listen(factory)

Expand All @@ -62,4 +73,6 @@ def __init__(

def run(self) -> None:
"""Start the Twisted reactor"""
if self.debug:
log.msg("Starting the Twisted reactor.")
self.reactor.run()
4 changes: 1 addition & 3 deletions apricot/ldap/oauth_ldap_server_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,6 @@


class OAuthLDAPServerFactory(ServerFactory):
protocol = ReadOnlyLDAPServer

def __init__(self, oauth_client: OAuthClient):
"""
Initialise an LDAPServerFactory
Expand All @@ -31,6 +29,6 @@ def buildProtocol(self, addr: IAddress) -> Protocol: # noqa: N802
@param addr: an object implementing L{IAddress}
"""
id(addr) # ignore unused arguments
proto = self.protocol()
proto = ReadOnlyLDAPServer(debug=self.adaptor.debug)
proto.factory = self.adaptor
return proto
7 changes: 7 additions & 0 deletions apricot/ldap/oauth_ldap_tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ def __init__(self, oauth_client: OAuthClient, refresh_interval: int = 60) -> Non
@param oauth_client: An OAuth client used to construct the LDAP tree
@param refresh_interval: Interval in seconds after which the tree must be refreshed
"""
self.debug = oauth_client.debug
self.last_update = time.monotonic()
self.oauth_client: OAuthClient = oauth_client
self.refresh_interval = refresh_interval
Expand Down Expand Up @@ -52,9 +53,13 @@ def root(self) -> OAuthLDAPEntry:
"OU=users", {"ou": ["users"], "objectClass": ["organizationalUnit"]}
)
# Add groups to the groups OU
if self.debug:
log.msg("Adding groups to the LDAP tree.")
for group_attrs in self.oauth_client.validated_groups():
groups_ou.add_child(f"CN={group_attrs.cn}", group_attrs.to_dict())
# Add users to the users OU
if self.debug:
log.msg("Adding users to the LDAP tree.")
for user_attrs in self.oauth_client.validated_users():
users_ou.add_child(f"CN={user_attrs.cn}", user_attrs.to_dict())
# Set last updated time
Expand All @@ -75,4 +80,6 @@ def lookup(self, dn: DistinguishedName | str) -> defer.Deferred[ILDAPEntry]:
"""
if not isinstance(dn, DistinguishedName):
dn = DistinguishedName(stringValue=dn)
if self.debug:
log.msg(f"Starting an LDAP lookup for {dn.getText()}.")
return self.root.lookup(dn)
74 changes: 52 additions & 22 deletions apricot/ldap/read_only_ldap_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,39 +4,53 @@
from ldaptor.protocols.ldap.ldaperrors import LDAPProtocolError
from ldaptor.protocols.ldap.ldapserver import LDAPServer
from ldaptor.protocols.pureldap import (
LDAPAddRequest,
LDAPBindRequest,
LDAPCompareRequest,
LDAPDelRequest,
LDAPExtendedRequest,
LDAPModifyDNRequest,
LDAPModifyRequest,
LDAPProtocolRequest,
LDAPSearchRequest,
LDAPSearchResultDone,
LDAPSearchResultEntry,
LDAPUnbindRequest,
)
from twisted.internet import defer
from twisted.python import log

from apricot.oauth import LDAPControlTuple


class ReadOnlyLDAPServer(LDAPServer):
def __init__(self) -> None:
def __init__(self, *, debug: bool = False) -> None:
super().__init__()
self.debug = True
self.debug = debug

def getRootDSE( # noqa: N802
self,
request: LDAPBindRequest,
request: LDAPProtocolRequest,
reply: Callable[[LDAPSearchResultEntry], None] | None,
) -> LDAPSearchResultDone:
"""
Handle an LDAP Root RSE request
Handle an LDAP Root DSE request
"""
if self.debug:
log.msg("Handling an LDAP Root DSE request.")
return super().getRootDSE(request, reply)

def handle_LDAPAddRequest( # noqa: N802
self,
request: LDAPBindRequest,
request: LDAPAddRequest,
controls: list[LDAPControlTuple] | None,
reply: Callable[..., None] | None,
) -> defer.Deferred[ILDAPEntry]:
"""
Refuse to handle an LDAP add request
"""
if self.debug:
log.msg("Handling an LDAP add request.")
id((request, controls, reply)) # ignore unused arguments
msg = "ReadOnlyLDAPServer will not handle LDAP add requests"
raise LDAPProtocolError(msg)
Expand All @@ -50,87 +64,103 @@ def handle_LDAPBindRequest( # noqa: N802
"""
Handle an LDAP bind request
"""
if self.debug:
log.msg("Handling an LDAP bind request.")
return super().handle_LDAPBindRequest(request, controls, reply)

def handle_LDAPCompareRequest( # noqa: N802
self,
request: LDAPBindRequest,
request: LDAPCompareRequest,
controls: list[LDAPControlTuple] | None,
reply: Callable[..., None] | None,
) -> defer.Deferred[ILDAPEntry]:
"""
Handle an LDAP compare request
"""
if self.debug:
log.msg("Handling an LDAP compare request.")
return super().handle_LDAPCompareRequest(request, controls, reply)

def handle_LDAPDelRequest( # noqa: N802
self,
request: LDAPBindRequest,
request: LDAPDelRequest,
controls: list[LDAPControlTuple] | None,
reply: Callable[..., None] | None,
) -> defer.Deferred[ILDAPEntry]:
"""
Refuse to handle an LDAP delete request
"""
if self.debug:
log.msg("Handling an LDAP delete request.")
id((request, controls, reply)) # ignore unused arguments
msg = "ReadOnlyLDAPServer will not handle LDAP delete requests"
raise LDAPProtocolError(msg)

def handle_LDAPExtendedRequest( # noqa: N802
self,
request: LDAPBindRequest,
request: LDAPExtendedRequest,
controls: list[LDAPControlTuple] | None,
reply: Callable[..., None] | None,
) -> defer.Deferred[ILDAPEntry]:
"""
Handle an LDAP extended request
"""
if self.debug:
log.msg("Handling an LDAP extended request.")
return super().handle_LDAPExtendedRequest(request, controls, reply)

def handle_LDAPModifyDNRequest( # noqa: N802
self,
request: LDAPBindRequest,
request: LDAPModifyDNRequest,
controls: list[LDAPControlTuple] | None,
reply: Callable[..., None] | None,
) -> defer.Deferred[ILDAPEntry]:
"""
Refuse to handle an LDAP modify DN request
"""
if self.debug:
log.msg("Handling an LDAP modify DN request.")
id((request, controls, reply)) # ignore unused arguments
msg = "ReadOnlyLDAPServer will not handle LDAP modify DN requests"
raise LDAPProtocolError(msg)

def handle_LDAPModifyRequest( # noqa: N802
self,
request: LDAPBindRequest,
request: LDAPModifyRequest,
controls: list[LDAPControlTuple] | None,
reply: Callable[..., None] | None,
) -> defer.Deferred[ILDAPEntry]:
"""
Refuse to handle an LDAP modify request
"""
if self.debug:
log.msg("Handling an LDAP modify request.")
id((request, controls, reply)) # ignore unused arguments
msg = "ReadOnlyLDAPServer will not handle LDAP modify requests"
raise LDAPProtocolError(msg)

def handle_LDAPUnbindRequest( # noqa: N802
self,
request: LDAPBindRequest,
controls: list[LDAPControlTuple] | None,
reply: Callable[..., None] | None,
) -> None:
"""
Handle an LDAP unbind request
"""
super().handle_LDAPUnbindRequest(request, controls, reply)

def handle_LDAPSearchRequest( # noqa: N802
self,
request: LDAPBindRequest,
request: LDAPSearchRequest,
controls: list[LDAPControlTuple] | None,
reply: Callable[[LDAPSearchResultEntry], None] | None,
) -> defer.Deferred[ILDAPEntry]:
"""
Handle an LDAP search request
"""
if self.debug:
log.msg("Handling an LDAP search request.")
return super().handle_LDAPSearchRequest(request, controls, reply)

def handle_LDAPUnbindRequest( # noqa: N802
self,
request: LDAPUnbindRequest,
controls: list[LDAPControlTuple] | None,
reply: Callable[..., None] | None,
) -> None:
"""
Handle an LDAP unbind request
"""
if self.debug:
log.msg("Handling an LDAP unbind request.")
super().handle_LDAPUnbindRequest(request, controls, reply)
10 changes: 10 additions & 0 deletions apricot/oauth/oauth_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ def __init__(
self,
client_id: str,
client_secret: str,
debug: bool, # noqa: FBT001
domain: str,
redirect_uri: str,
scopes: list[str],
Expand All @@ -43,6 +44,7 @@ def __init__(
# Set attributes
self.bearer_token_: str | None = None
self.client_secret = client_secret
self.debug = debug
self.domain = domain
self.token_url = token_url
self.uid_cache = uid_cache
Expand All @@ -53,6 +55,8 @@ def __init__(

try:
# OAuth client that uses application credentials
if self.debug:
log.msg("Initialising application credential client.")
self.session_application = OAuth2Session(
client=BackendApplicationClient(
client_id=client_id, scope=scopes, redirect_uri=redirect_uri
Expand All @@ -64,6 +68,8 @@ def __init__(

try:
# OAuth client that uses delegated credentials
if self.debug:
log.msg("Initialising delegated credential client.")
self.session_interactive = OAuth2Session(
client=LegacyApplicationClient(
client_id=client_id, scope=scopes, redirect_uri=redirect_uri
Expand Down Expand Up @@ -145,6 +151,8 @@ def validated_groups(self) -> list[LDAPAttributeAdaptor]:
"""
Validate output via pydantic and return a list of LDAPAttributeAdaptor
"""
if self.debug:
log.msg("Constructing and validating list of groups")
output = []
# Add one self-titled group for each user
user_group_dicts = []
Expand Down Expand Up @@ -180,6 +188,8 @@ def validated_users(self) -> list[LDAPAttributeAdaptor]:
"""
Validate output via pydantic and return a list of LDAPAttributeAdaptor
"""
if self.debug:
log.msg("Constructing and validating list of users")
output = []
for user_dict in self.users():
try:
Expand Down
4 changes: 4 additions & 0 deletions docker/entrypoint.sh
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,10 @@ fi

# Optional arguments
EXTRA_OPTS=""
if [ -n "${DEBUG}" ]; then
EXTRA_OPTS="${EXTRA_OPTS} --debug"
fi

if [ -n "${ENTRA_TENANT_ID}" ]; then
EXTRA_OPTS="${EXTRA_OPTS} --entra-tenant-id $ENTRA_TENANT_ID"
fi
Expand Down
5 changes: 3 additions & 2 deletions run.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,11 @@
)
# Common options needed for all backends
parser.add_argument("-b", "--backend", type=OAuthBackend, help="Which OAuth backend to use.")
parser.add_argument("-i", "--client-id", type=str, help="OAuth client ID.")
parser.add_argument("-s", "--client-secret", type=str, help="OAuth client secret.")
parser.add_argument("-d", "--domain", type=str, help="Which domain users belong to.")
parser.add_argument("-i", "--client-id", type=str, help="OAuth client ID.")
parser.add_argument("-p", "--port", type=int, default=1389, help="Port to run on.")
parser.add_argument("-s", "--client-secret", type=str, help="OAuth client secret.")
parser.add_argument("--debug", action="store_true", help="Enable debug logging.")
# Options for Microsoft Entra backend
entra_group = parser.add_argument_group("Microsoft Entra")
entra_group.add_argument("-t", "--entra-tenant-id", type=str, help="Microsoft Entra tenant ID.", required=False)
Expand Down

0 comments on commit d39b9e5

Please sign in to comment.