diff --git a/apricot/apricot_server.py b/apricot/apricot_server.py index a40c431..b45ea3e 100644 --- a/apricot/apricot_server.py +++ b/apricot/apricot_server.py @@ -48,7 +48,9 @@ def __init__( if self.debug: log.msg(f"Creating an OAuthClient for {backend}.") oauth_backend = OAuthClientMap[backend] - oauth_backend_args = inspect.getfullargspec(oauth_backend.__init__).args + oauth_backend_args = inspect.getfullargspec( + oauth_backend.__init__ # type: ignore + ).args oauth_client = oauth_backend( client_id=client_id, client_secret=client_secret, diff --git a/apricot/ldap/oauth_ldap_server_factory.py b/apricot/ldap/oauth_ldap_server_factory.py index 72afe0d..bcabc6c 100644 --- a/apricot/ldap/oauth_ldap_server_factory.py +++ b/apricot/ldap/oauth_ldap_server_factory.py @@ -17,7 +17,9 @@ def __init__( @param oauth_client: An OAuth client used to construct the LDAP tree """ # Create an LDAP lookup tree - self.adaptor = OAuthLDAPTree(domain, oauth_client, enable_mirrored_groups) + self.adaptor = OAuthLDAPTree( + domain, oauth_client, enable_mirrored_groups=enable_mirrored_groups + ) def __repr__(self) -> str: return f"{self.__class__.__name__} using adaptor {self.adaptor}" diff --git a/apricot/oauth/keycloak_client.py b/apricot/oauth/keycloak_client.py index d23d638..3c21a9c 100644 --- a/apricot/oauth/keycloak_client.py +++ b/apricot/oauth/keycloak_client.py @@ -5,9 +5,11 @@ from .oauth_client import OAuthClient -def get_single_value_attribute(obj: JSONDict, key: str, default=None) -> Any: +def get_single_value_attribute( + obj: JSONDict, key: str, default: str | None = None +) -> Any: for part in key.split("."): - obj = obj.get(part) + obj = obj.get(part) # type: ignore if obj is None: return default if isinstance(obj, list): @@ -35,7 +37,7 @@ def __init__( self.realm = keycloak_realm redirect_uri = "urn:ietf:wg:oauth:2.0:oob" # this is the "no redirect" URL - scopes = [] # this is the default scope + scopes: list[str] = [] # this is the default scope token_url = f"{self.base_url}/realms/{self.realm}/protocol/openid-connect/token" super().__init__( @@ -51,41 +53,40 @@ def extract_token(self, json_response: JSONDict) -> str: def groups(self) -> list[JSONDict]: output = [] try: - group_data = [] + group_data: list[JSONDict] = [] while data := self.query( f"{self.base_url}/admin/realms/{self.realm}/groups?first={len(group_data)}&max={self.max_rows}&briefRepresentation=false", use_client_secret=False, ): - group_data.extend(data) + group_data.extend(cast(list[JSONDict], data)) if len(data) != self.max_rows: break group_data = sorted( group_data, - key=lambda g: int( + key=lambda group: int( get_single_value_attribute( - g, "attributes.gid", default="9999999999" + group, "attributes.gid", default="9999999999" ), - 10, + base=10, ), ) next_gid = max( *( int( - get_single_value_attribute(g, "attributes.gid", default="-1"), - 10, + get_single_value_attribute( + group, "attributes.gid", default="-1" + ), + base=10, ) + 1 - for g in group_data + for group in group_data ), 3000, ) - for group_dict in cast( - list[JSONDict], - group_data, - ): + for group_dict in group_data: group_gid = get_single_value_attribute( group_dict, "attributes.gid", default=None ) @@ -122,46 +123,47 @@ def groups(self) -> list[JSONDict]: def users(self) -> list[JSONDict]: output = [] try: - user_data = [] + user_data: list[JSONDict] = [] while data := self.query( f"{self.base_url}/admin/realms/{self.realm}/users?first={len(user_data)}&max={self.max_rows}&briefRepresentation=false", use_client_secret=False, ): - user_data.extend(data) + user_data.extend(cast(list[JSONDict], data)) if len(data) != self.max_rows: break user_data = sorted( user_data, - key=lambda u: int( + key=lambda user: int( get_single_value_attribute( - u, "attributes.uid", default="9999999999" + user, "attributes.uid", default="9999999999" ), - 10, + base=10, ), ) next_uid = max( *( int( - get_single_value_attribute(g, "attributes.uid", default="-1"), - 10, + get_single_value_attribute( + user, "attributes.uid", default="-1" + ), + base=10, ) + 1 - for g in user_data + for user in user_data ), 3000, ) - for user_dict in cast( - list[JSONDict], - sorted(user_data, key=lambda user: user["createdTimestamp"]), + for user_dict in sorted( + user_data, key=lambda user: user["createdTimestamp"] ): user_uid = get_single_value_attribute( user_dict, "attributes.uid", default=None ) if user_uid: - user_uid = int(user_uid, 10) + user_uid = int(user_uid, base=10) if not user_uid: user_uid = next_uid next_uid += 1 diff --git a/apricot/oauth/oauth_client.py b/apricot/oauth/oauth_client.py index 90d2143..b47f98c 100644 --- a/apricot/oauth/oauth_client.py +++ b/apricot/oauth/oauth_client.py @@ -109,7 +109,7 @@ def users(self) -> list[JSONDict]: """ pass - def query(self, url: str, *, use_client_secret=True) -> dict[str, Any]: + def query(self, url: str, *, use_client_secret: bool = True) -> dict[str, Any]: """ Make a query against the OAuth backend """ @@ -127,12 +127,12 @@ def query(self, url: str, *, use_client_secret=True) -> dict[str, Any]: **kwargs, ) - def request(self, *args, method="GET", **kwargs) -> dict[str, Any]: + def request(self, *args: Any, method: str = "GET", **kwargs: Any) -> dict[str, Any]: """ Make a request to the OAuth backend """ - def query_(*args, **kwargs) -> requests.Response: + def query_(*args: Any, **kwargs: Any) -> requests.Response: return self.session_application.request( # type: ignore[no-any-return] method, *args, @@ -147,8 +147,9 @@ def query_(*args, **kwargs) -> requests.Response: log.msg("Authentication token has expired.") self.bearer_token_ = None result = query_(*args, **kwargs) - if result.status_code != HTTPStatus.NO_CONTENT: - return result.json() # type: ignore + if result.status_code == HTTPStatus.NO_CONTENT: + return {} + return result.json() # type: ignore def verify(self, username: str, password: str) -> bool: """ diff --git a/apricot/types.py b/apricot/types.py index e93f9ea..5cc0617 100644 --- a/apricot/types.py +++ b/apricot/types.py @@ -1,5 +1,6 @@ from typing import Any JSONDict = dict[str, Any] +JSONKey = list[Any] | dict[str, Any] | Any LDAPAttributeDict = dict[str, list[str]] LDAPControlTuple = tuple[str, bool, Any]