diff --git a/apricot/apricot_server.py b/apricot/apricot_server.py index b45ea3e..fa98c22 100644 --- a/apricot/apricot_server.py +++ b/apricot/apricot_server.py @@ -22,7 +22,7 @@ def __init__( port: int, *, debug: bool = False, - enable_mirrored_groups: bool, + enable_mirrored_groups: bool = True, redis_host: str | None = None, redis_port: int | None = None, **kwargs: Any, diff --git a/apricot/cache/uid_cache.py b/apricot/cache/uid_cache.py index ab46029..eb9c729 100644 --- a/apricot/cache/uid_cache.py +++ b/apricot/cache/uid_cache.py @@ -77,3 +77,31 @@ def _get_max_uid(self, category: str | None) -> int: keys = self.keys() values = [*self.values(keys), -999] return max(values) + + def overwrite_group_uid(self, identifier: str, uid: int) -> None: + """ + Set UID for a group, overwriting the existing value if there is one + + @param identifier: Identifier for group + @param uid: Desired UID + """ + return self.overwrite_uid(identifier, category="group", uid=uid) + + def overwrite_user_uid(self, identifier: str, uid: int) -> None: + """ + Get UID for a user, constructing one if necessary + + @param identifier: Identifier for user + @param uid: Desired UID + """ + return self.overwrite_uid(identifier, category="user", uid=uid) + + def overwrite_uid(self, identifier: str, category: str, uid: int) -> None: + """ + Set UID, overwriting the existing one if necessary. + + @param identifier: Identifier for object + @param category: Category the object belongs to + @param uid: Desired UID + """ + self.set(f"{category}-{identifier}", uid) diff --git a/apricot/oauth/keycloak_client.py b/apricot/oauth/keycloak_client.py index 3c21a9c..8c2f58d 100644 --- a/apricot/oauth/keycloak_client.py +++ b/apricot/oauth/keycloak_client.py @@ -5,23 +5,6 @@ from .oauth_client import OAuthClient -def get_single_value_attribute( - obj: JSONDict, key: str, default: str | None = None -) -> Any: - for part in key.split("."): - obj = obj.get(part) # type: ignore - if obj is None: - return default - if isinstance(obj, list): - try: - return next(iter(obj)) - except StopIteration: - pass - else: - return obj - return default - - class KeycloakClient(OAuthClient): """OAuth client for the Keycloak backend.""" @@ -62,41 +45,25 @@ def groups(self) -> list[JSONDict]: if len(data) != self.max_rows: break - group_data = sorted( - group_data, - key=lambda group: int( - get_single_value_attribute( - group, "attributes.gid", default="9999999999" - ), - base=10, - ), - ) - - next_gid = max( - *( - int( - get_single_value_attribute( - group, "attributes.gid", default="-1" - ), - base=10, + # Ensure that gid attribute exists for all groups + for group_dict in group_data: + group_dict["attributes"] = group_dict.get("attributes", {}) + if "gid" not in group_dict["attributes"]: + group_dict["attributes"]["gid"] = None + # If group_gid exists then set the cache to the same value + # This ensures that any groups without a `gid` attribute will receive a + # UID that does not overlap with existing groups + if group_gid := group_dict["attributes"]["gid"]: + self.uid_cache.overwrite_group_uid( + group_dict["id"], int(group_gid, 10) ) - + 1 - for group in group_data - ), - 3000, - ) + # Read group attributes for group_dict in group_data: - group_gid = get_single_value_attribute( - group_dict, "attributes.gid", default=None - ) - if group_gid: - group_gid = int(group_gid, 10) - if not group_gid: - group_gid = next_gid - next_gid += 1 - group_dict["attributes"] = group_dict.get("attributes", {}) - group_dict["attributes"]["gid"] = [str(group_gid)] + if not group_dict["attributes"]["gid"]: + group_dict["attributes"]["gid"] = [ + str(self.uid_cache.get_group_uid(group_dict["id"])) + ] self.request( f"{self.base_url}/admin/realms/{self.realm}/groups/{group_dict['id']}", method="PUT", @@ -105,7 +72,7 @@ def groups(self) -> list[JSONDict]: attributes: JSONDict = {} attributes["cn"] = group_dict.get("name", None) attributes["description"] = group_dict.get("id", None) - attributes["gidNumber"] = group_gid + attributes["gidNumber"] = group_dict["attributes"]["gid"] attributes["oauth_id"] = group_dict.get("id", None) # Add membership attributes members = self.query( @@ -132,44 +99,27 @@ def users(self) -> list[JSONDict]: if len(data) != self.max_rows: break - user_data = sorted( - user_data, - key=lambda user: int( - get_single_value_attribute( - user, "attributes.uid", default="9999999999" - ), - base=10, - ), - ) - - next_uid = max( - *( - int( - get_single_value_attribute( - user, "attributes.uid", default="-1" - ), - base=10, + # Ensure that uid attribute exists for all users + for user_dict in user_data: + user_dict["attributes"] = user_dict.get("attributes", {}) + if "uid" not in user_dict["attributes"]: + user_dict["attributes"]["uid"] = None + # If user_uid exists then set the cache to the same value. + # This ensures that any groups without a `gid` attribute will receive a + # UID that does not overlap with existing groups + if user_uid := user_dict["attributes"]["uid"]: + self.uid_cache.overwrite_user_uid( + user_dict["id"], int(user_uid, 10) ) - + 1 - for user in user_data - ), - 3000, - ) + # Read user attributes for user_dict in sorted( user_data, key=lambda user: user["createdTimestamp"] ): - user_uid = get_single_value_attribute( - user_dict, "attributes.uid", default=None - ) - if user_uid: - user_uid = int(user_uid, base=10) - if not user_uid: - user_uid = next_uid - next_uid += 1 - - user_dict["attributes"] = user_dict.get("attributes", {}) - user_dict["attributes"]["uid"] = [str(user_uid)] + if not user_dict["attributes"]["uid"]: + user_dict["attributes"]["uid"] = [ + str(self.uid_cache.get_user_uid(user_dict["id"])) + ] self.request( f"{self.base_url}/admin/realms/{self.realm}/users/{user_dict['id']}", method="PUT", @@ -189,12 +139,12 @@ def users(self) -> list[JSONDict]: attributes["displayName"] = full_name attributes["mail"] = user_dict.get("email") attributes["description"] = "" - attributes["gidNumber"] = user_uid + attributes["gidNumber"] = user_dict["attributes"]["uid"] attributes["givenName"] = first_name if first_name else "" attributes["homeDirectory"] = f"/home/{username}" if username else None attributes["oauth_id"] = user_dict.get("id", None) attributes["sn"] = last_name if last_name else "" - attributes["uidNumber"] = user_uid + attributes["uidNumber"] = user_dict["attributes"]["uid"] output.append(attributes) except KeyError: pass