Skip to content

Commit

Permalink
Merge pull request #19 from alan-turing-institute/16-redis
Browse files Browse the repository at this point in the history
Switch to Redis for ID caching
  • Loading branch information
jemrobinson authored Feb 26, 2024
2 parents 4115fc1 + fd1eb33 commit 5a19e78
Show file tree
Hide file tree
Showing 14 changed files with 221 additions and 67 deletions.
12 changes: 8 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,12 @@ The name is a slightly tortured acronym for: LD**A**P **pr**oxy for Open**I**D *

## Usage

**N.B.** As Apricot uses a Redis server to store generated `uidNumber` and `gidNumber` values.

Start the `Apricot` server on port 1389 by running:

```bash
python run.py --client-id "<your client ID>" --client-secret "<your client secret>" --backend "<your backend>" --port 1389 --domain "<your domain name>"
python run.py --client-id "<your client ID>" --client-secret "<your client secret>" --backend "<your backend>" --port 1389 --domain "<your domain name>" --redis-host "<your Redis server>"
```

Alternatively, you can run in Docker by editing `docker/docker-compose.yaml` and running:
Expand Down Expand Up @@ -40,18 +42,20 @@ Each user will have an entry like

```ldif
dn: CN=<user name>,OU=users,DC=<your domain>
objectClass: organizationalPerson
objectClass: inetOrgPerson
objectClass: inetuser
objectClass: person
objectClass: posixAccount
objectClass: top
objectClass: user
<user data fields here>
```

Each group will have an entry like

```ldif
dn: CN=<group name>,OU=groups,DC=<your domain>
objectClass: group
objectClass: groupOfNames
objectClass: posixGroup
objectClass: top
<group data fields here>
```
Expand Down
1 change: 1 addition & 0 deletions apricot/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from .__about__ import __version__, __version_info__
from .apricot_server import ApricotServer
from .patches import LDAPString # noqa: F401

__all__ = [
"__version__",
Expand Down
6 changes: 4 additions & 2 deletions apricot/apricot_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,8 @@ def __init__(
client_secret: str,
domain: str,
port: int,
uid_attribute: str,
redis_host: str,
redis_port: int,
**kwargs: Any,
) -> None:
# Log to stdout
Expand All @@ -30,7 +31,8 @@ def __init__(
client_id=client_id,
client_secret=client_secret,
domain=domain,
uid_attribute=uid_attribute,
redis_host=redis_host,
redis_port=redis_port,
**kwargs,
)
except Exception as exc:
Expand Down
5 changes: 5 additions & 0 deletions apricot/cache/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
from .uid_cache import UidCache

__all__ = [
"UidCache",
]
75 changes: 75 additions & 0 deletions apricot/cache/uid_cache.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
from typing import cast

import redis


class UidCache:
def __init__(self, redis_host: str, redis_port: str) -> None:
self.redis_host = redis_host
self.redis_port = redis_port
self.cache_ = None

@property
def cache(self) -> redis.Redis: # type: ignore[type-arg]
"""
Lazy-load the cache on request
"""
if not self.cache_:
self.cache_ = redis.Redis( # type: ignore[call-overload]
host=self.redis_host, port=self.redis_port, decode_responses=True
)
return self.cache_ # type: ignore[return-value]

@property
def keys(self) -> list[str]:
"""
Get list of keys from the cache
"""
return [str(k) for k in self.cache.keys()]

def get_group_uid(self, identifier: str) -> int:
"""
Get UID for a group, constructing one if necessary
@param identifier: Identifier for group needing a UID
"""
return self.get_uid(identifier, category="group", min_value=3000)

def get_user_uid(self, identifier: str) -> int:
"""
Get UID for a user, constructing one if necessary
@param identifier: Identifier for user needing a UID
"""
return self.get_uid(identifier, category="user", min_value=2000)

def get_uid(
self, identifier: str, category: str, min_value: int | None = None
) -> int:
"""
Get UID, constructing one if necessary.
@param identifier: Identifier for object needing a UID
@param category: Category the object belongs to
@param min_value: Minimum allowed value for the UID
"""
identifier_ = f"{category}-{identifier}"
uid = self.cache.get(identifier_)
if not uid:
min_value = min_value if min_value else 0
next_uid = max(self._get_max_uid(category) + 1, min_value)
self.cache.set(identifier_, next_uid)
return cast(int, self.cache.get(identifier_))

def _get_max_uid(self, category: str | None) -> int:
"""
Get maximum UID for a given category
@param category: Category to check UIDs for
"""
if category:
keys = [k for k in self.keys if k.startswith(category)]
else:
keys = self.keys
values = [int(cast(str, v)) for v in self.cache.mget(keys)] + [-999]
return max(values)
63 changes: 31 additions & 32 deletions apricot/ldap/oauth_ldap_tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from zope.interface import implementer

from apricot.ldap.oauth_ldap_entry import OAuthLDAPEntry
from apricot.oauth import LDAPAttributeDict, OAuthClient
from apricot.oauth import OAuthClient


@implementer(IConnectedLDAPEntry)
Expand All @@ -17,41 +17,40 @@ def __init__(self, oauth_client: OAuthClient) -> None:
@param oauth_client: An OAuth client used to construct the LDAP tree
"""
self.oauth_client = oauth_client
self.oauth_client: OAuthClient = oauth_client
self.root_: OAuthLDAPEntry | None = None

# Create a root node for the tree
self.root = self.build_root(
dn=self.oauth_client.root_dn, attributes={"objectClass": ["dcObject"]}
)
# Add OUs for users and groups
groups_ou = self.root.add_child(
"OU=groups", {"ou": ["groups"], "objectClass": ["organizationalUnit"]}
)
users_ou = self.root.add_child(
"OU=users", {"ou": ["users"], "objectClass": ["organizationalUnit"]}
)
# Add groups to the groups OU
for group_attrs in self.oauth_client.validated_groups():
groups_ou.add_child(f"CN={group_attrs['cn'][0]}", group_attrs)
# Add users to the users OU
for user_attrs in self.oauth_client.validated_users():
users_ou.add_child(f"CN={user_attrs['cn'][0]}", user_attrs)

def __repr__(self) -> str:
return f"{self.__class__.__name__} with backend {self.oauth_client.__class__.__name__}"

def build_root(self, dn: str, attributes: LDAPAttributeDict) -> OAuthLDAPEntry:
@property
def root(self) -> OAuthLDAPEntry:
"""
Construct the root of the LDAP tree
Lazy-load the LDAP tree on request
@param dn: Distinguished Name of the object
@param attributes: Attributes of the object.
@return: An OAuthLDAPEntry
@return: An OAuthLDAPEntry for the tree
"""
return OAuthLDAPEntry(
dn=dn, attributes=attributes, oauth_client=self.oauth_client
)
if not self.root_:
# Create a root node for the tree
self.root_ = OAuthLDAPEntry(
dn=self.oauth_client.root_dn,
attributes={"objectClass": ["dcObject"]},
oauth_client=self.oauth_client,
)
# Add OUs for users and groups
groups_ou = self.root_.add_child(
"OU=groups", {"ou": ["groups"], "objectClass": ["organizationalUnit"]}
)
users_ou = self.root_.add_child(
"OU=users", {"ou": ["users"], "objectClass": ["organizationalUnit"]}
)
# Add groups to the groups OU
for group_attrs in self.oauth_client.validated_groups():
groups_ou.add_child(f"CN={group_attrs['cn'][0]}", group_attrs)
# Add users to the users OU
for user_attrs in self.oauth_client.validated_users():
users_ou.add_child(f"CN={user_attrs['cn'][0]}", user_attrs)
return self.root_

def __repr__(self) -> str:
return f"{self.__class__.__name__} with backend {self.oauth_client.__class__.__name__}"

def lookup(self, dn: DistinguishedName | str) -> defer.Deferred[ILDAPEntry]:
"""
Expand Down
32 changes: 22 additions & 10 deletions apricot/oauth/microsoft_entra_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,15 +28,23 @@ def extract_token(self, json_response: JSONDict) -> str:
def groups(self) -> list[dict[str, Any]]:
output = []
try:
group_data = self.query("https://graph.microsoft.com/v1.0/groups/")
for group_dict in cast(list[dict[str, Any]], group_data["value"]):
queries = [
"createdDateTime",
"displayName",
"id",
]
group_data = self.query(
f"https://graph.microsoft.com/v1.0/groups?$select={','.join(queries)}"
)
for group_dict in cast(
list[dict[str, Any]],
sorted(group_data["value"], key=lambda group: group["createdDateTime"]),
):
group_uid = self.uid_cache.get_group_uid(group_dict["id"])
attributes = {}
attributes["cn"] = group_dict.get("displayName", None)
attributes["description"] = group_dict.get("id", None)
# As we cannot manually set any attributes we take the last part of the securityIdentifier
attributes["gidNumber"] = str(
group_dict.get("securityIdentifier", "")
).split("-")[-1]
attributes["gidNumber"] = group_uid
# Add membership attributes
members = self.query(
f"https://graph.microsoft.com/v1.0/groups/{group_dict['id']}/members"
Expand All @@ -59,30 +67,34 @@ def users(self) -> list[dict[str, Any]]:
output = []
try:
queries = [
"createdDateTime",
"displayName",
"givenName",
"id",
"surname",
"userPrincipalName",
self.uid_attribute,
]
user_data = self.query(
f"https://graph.microsoft.com/v1.0/users?$select={','.join(queries)}"
)
for user_dict in cast(list[dict[str, Any]], user_data["value"]):
for user_dict in cast(
list[dict[str, Any]],
sorted(user_data["value"], key=lambda user: user["createdDateTime"]),
):
# Get user attributes
uid, domain = str(user_dict.get("userPrincipalName", "@")).split("@")
user_uid = self.uid_cache.get_user_uid(user_dict["id"])
attributes = {}
attributes["cn"] = user_dict.get("displayName", None)
attributes["description"] = user_dict.get("id", None)
attributes["displayName"] = attributes.get("cn", None)
attributes["domain"] = domain
attributes["gidNumber"] = user_dict.get(self.uid_attribute, None)
attributes["gidNumber"] = user_uid
attributes["givenName"] = user_dict.get("givenName", "")
attributes["homeDirectory"] = f"/home/{uid}" if uid else None
attributes["sn"] = user_dict.get("surname", "")
attributes["uid"] = uid if uid else None
attributes["uidNumber"] = user_dict.get(self.uid_attribute, None)
attributes["uidNumber"] = user_uid
# Add group attributes
group_memberships = self.query(
f"https://graph.microsoft.com/v1.0/users/{user_dict['id']}/memberOf"
Expand Down
13 changes: 11 additions & 2 deletions apricot/oauth/oauth_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from requests_oauthlib import OAuth2Session
from twisted.python import log

from apricot.cache import UidCache
from apricot.models import (
LdapGroupOfNames,
LdapInetOrgPerson,
Expand All @@ -32,15 +33,16 @@ def __init__(
client_secret: str,
domain: str,
redirect_uri: str,
redis_host: str,
redis_port: str,
scopes: list[str],
token_url: str,
uid_attribute: str,
) -> None:
# Set attributes
self.client_secret = client_secret
self.domain = domain
self.token_url = token_url
self.uid_attribute = uid_attribute
self.uid_cache = UidCache(redis_host=redis_host, redis_port=redis_port)
# Allow token scope to not match requested scope. (Other auth libraries allow
# this, but Requests-OAuthlib raises exception on scope mismatch by default.)
os.environ["OAUTHLIB_RELAX_TOKEN_SCOPE"] = "1" # noqa: S105
Expand Down Expand Up @@ -107,6 +109,9 @@ def root_dn(self) -> str:
return "DC=" + self.domain.replace(".", ",DC=")

def query(self, url: str) -> dict[str, Any]:
"""
Make a query against the Microsoft Entra directory
"""
result = self.session_application.request(
method="GET",
url=url,
Expand Down Expand Up @@ -163,6 +168,10 @@ def validated_users(self) -> list[LDAPAttributeDict]:
for user_dict in self.users():
try:
attributes = {"objectclass": ["top"]}
# Add user to self-titled group
user_dict["memberOf"].append(
f"CN={user_dict['cn']},OU=groups,{self.root_dn}"
)
# Add 'inetOrgPerson' attributes
inetorg_person = LdapInetOrgPerson(**user_dict)
attributes.update(inetorg_person.model_dump())
Expand Down
5 changes: 5 additions & 0 deletions apricot/patches/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
from .ldap_string import LDAPString # type: ignore[attr-defined]

__all__ = [
"LDAPString",
]
17 changes: 17 additions & 0 deletions apricot/patches/ldap_string.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
"""Patch LDAPString to avoid TypeError when parsing LDAP filter strings"""

from typing import Any

from ldaptor.protocols.pureldap import LDAPString

old_init = LDAPString.__init__


def patched_init(self, *args: Any, **kwargs: Any) -> None: # type: ignore[no-untyped-def]
"""Patch LDAPString init to store its value as 'str' not 'bytes'"""
old_init(self, *args, **kwargs)
if isinstance(self.value, bytes):
self.value = self.value.decode()


LDAPString.__init__ = patched_init
10 changes: 10 additions & 0 deletions docker/docker-compose.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,16 @@ services:
CLIENT_SECRET: "<your OpenID client secret here>"
DOMAIN: "<your domain here>"
ENTRA_TENANT_ID: "<your Entra tenant ID here>"
REDIS_HOST: "redis"
ports:
- "1389:1389"
restart: always

redis:
container_name: redis
image: redis:7.2
ports:
- "6379:6379"
volumes:
- <local path>:/data
restart: always
Loading

0 comments on commit 5a19e78

Please sign in to comment.