Skip to content

Commit

Permalink
♻️ Switch KeycloakClient to use UidCache for generating missing UIDs
Browse files Browse the repository at this point in the history
  • Loading branch information
jemrobinson committed May 23, 2024
1 parent 38a28e4 commit f986e41
Show file tree
Hide file tree
Showing 3 changed files with 64 additions and 86 deletions.
2 changes: 1 addition & 1 deletion apricot/apricot_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ def __init__(
port: int,
*,
debug: bool = False,
enable_mirrored_groups: bool,
enable_mirrored_groups: bool = True,
redis_host: str | None = None,
redis_port: int | None = None,
**kwargs: Any,
Expand Down
28 changes: 28 additions & 0 deletions apricot/cache/uid_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,3 +77,31 @@ def _get_max_uid(self, category: str | None) -> int:
keys = self.keys()
values = [*self.values(keys), -999]
return max(values)

def overwrite_group_uid(self, identifier: str, uid: int) -> None:
"""
Set UID for a group, overwriting the existing value if there is one
@param identifier: Identifier for group
@param uid: Desired UID
"""
return self.overwrite_uid(identifier, category="group", uid=uid)

def overwrite_user_uid(self, identifier: str, uid: int) -> None:
"""
Get UID for a user, constructing one if necessary
@param identifier: Identifier for user
@param uid: Desired UID
"""
return self.overwrite_uid(identifier, category="user", uid=uid)

def overwrite_uid(self, identifier: str, category: str, uid: int) -> None:
"""
Set UID, overwriting the existing one if necessary.
@param identifier: Identifier for object
@param category: Category the object belongs to
@param uid: Desired UID
"""
self.set(f"{category}-{identifier}", uid)
120 changes: 35 additions & 85 deletions apricot/oauth/keycloak_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,23 +5,6 @@
from .oauth_client import OAuthClient


def get_single_value_attribute(
obj: JSONDict, key: str, default: str | None = None
) -> Any:
for part in key.split("."):
obj = obj.get(part) # type: ignore
if obj is None:
return default
if isinstance(obj, list):
try:
return next(iter(obj))
except StopIteration:
pass
else:
return obj
return default


class KeycloakClient(OAuthClient):
"""OAuth client for the Keycloak backend."""

Expand Down Expand Up @@ -62,41 +45,25 @@ def groups(self) -> list[JSONDict]:
if len(data) != self.max_rows:
break

group_data = sorted(
group_data,
key=lambda group: int(
get_single_value_attribute(
group, "attributes.gid", default="9999999999"
),
base=10,
),
)

next_gid = max(
*(
int(
get_single_value_attribute(
group, "attributes.gid", default="-1"
),
base=10,
# Ensure that gid attribute exists for all groups
for group_dict in group_data:
group_dict["attributes"] = group_dict.get("attributes", {})
if "gid" not in group_dict["attributes"]:
group_dict["attributes"]["gid"] = None
# If group_gid exists then set the cache to the same value
# This ensures that any groups without a `gid` attribute will receive a
# UID that does not overlap with existing groups
if group_gid := group_dict["attributes"]["gid"]:
self.uid_cache.overwrite_group_uid(
group_dict["id"], int(group_gid, 10)
)
+ 1
for group in group_data
),
3000,
)

# Read group attributes
for group_dict in group_data:
group_gid = get_single_value_attribute(
group_dict, "attributes.gid", default=None
)
if group_gid:
group_gid = int(group_gid, 10)
if not group_gid:
group_gid = next_gid
next_gid += 1
group_dict["attributes"] = group_dict.get("attributes", {})
group_dict["attributes"]["gid"] = [str(group_gid)]
if not group_dict["attributes"]["gid"]:
group_dict["attributes"]["gid"] = [
str(self.uid_cache.get_group_uid(group_dict["id"]))
]
self.request(
f"{self.base_url}/admin/realms/{self.realm}/groups/{group_dict['id']}",
method="PUT",
Expand All @@ -105,7 +72,7 @@ def groups(self) -> list[JSONDict]:
attributes: JSONDict = {}
attributes["cn"] = group_dict.get("name", None)
attributes["description"] = group_dict.get("id", None)
attributes["gidNumber"] = group_gid
attributes["gidNumber"] = group_dict["attributes"]["gid"]
attributes["oauth_id"] = group_dict.get("id", None)
# Add membership attributes
members = self.query(
Expand All @@ -132,44 +99,27 @@ def users(self) -> list[JSONDict]:
if len(data) != self.max_rows:
break

user_data = sorted(
user_data,
key=lambda user: int(
get_single_value_attribute(
user, "attributes.uid", default="9999999999"
),
base=10,
),
)

next_uid = max(
*(
int(
get_single_value_attribute(
user, "attributes.uid", default="-1"
),
base=10,
# Ensure that uid attribute exists for all users
for user_dict in user_data:
user_dict["attributes"] = user_dict.get("attributes", {})
if "uid" not in user_dict["attributes"]:
user_dict["attributes"]["uid"] = None
# If user_uid exists then set the cache to the same value.
# This ensures that any groups without a `gid` attribute will receive a
# UID that does not overlap with existing groups
if user_uid := user_dict["attributes"]["uid"]:
self.uid_cache.overwrite_user_uid(
user_dict["id"], int(user_uid, 10)
)
+ 1
for user in user_data
),
3000,
)

# Read user attributes
for user_dict in sorted(
user_data, key=lambda user: user["createdTimestamp"]
):
user_uid = get_single_value_attribute(
user_dict, "attributes.uid", default=None
)
if user_uid:
user_uid = int(user_uid, base=10)
if not user_uid:
user_uid = next_uid
next_uid += 1

user_dict["attributes"] = user_dict.get("attributes", {})
user_dict["attributes"]["uid"] = [str(user_uid)]
if not user_dict["attributes"]["uid"]:
user_dict["attributes"]["uid"] = [
str(self.uid_cache.get_user_uid(user_dict["id"]))
]
self.request(
f"{self.base_url}/admin/realms/{self.realm}/users/{user_dict['id']}",
method="PUT",
Expand All @@ -189,12 +139,12 @@ def users(self) -> list[JSONDict]:
attributes["displayName"] = full_name
attributes["mail"] = user_dict.get("email")
attributes["description"] = ""
attributes["gidNumber"] = user_uid
attributes["gidNumber"] = user_dict["attributes"]["uid"]
attributes["givenName"] = first_name if first_name else ""
attributes["homeDirectory"] = f"/home/{username}" if username else None
attributes["oauth_id"] = user_dict.get("id", None)
attributes["sn"] = last_name if last_name else ""
attributes["uidNumber"] = user_uid
attributes["uidNumber"] = user_dict["attributes"]["uid"]
output.append(attributes)
except KeyError:
pass
Expand Down

0 comments on commit f986e41

Please sign in to comment.