Skip to content

Commit

Permalink
🐛 Revert change to OAuthClient.query which stopped passing client_id …
Browse files Browse the repository at this point in the history
…and client_secret to the OAuth backend
  • Loading branch information
jemrobinson committed May 23, 2024
1 parent 56464c6 commit 278c72d
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 20 deletions.
2 changes: 1 addition & 1 deletion apricot/cache/redis_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ class RedisCache(UidCache):
def __init__(self, redis_host: str, redis_port: int) -> None:
self.redis_host = redis_host
self.redis_port = redis_port
self.cache_: "redis.Redis[str]" | None = None
self.cache_: "redis.Redis[str]" | None = None # noqa: UP037

@property
def cache(self) -> "redis.Redis[str]":
Expand Down
9 changes: 6 additions & 3 deletions apricot/oauth/keycloak_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,8 @@ def groups(self) -> list[JSONDict]:
try:
group_data = []
while data := self.query(
f"{self.base_url}/admin/realms/{self.realm}/groups?first={len(group_data)}&max={self.max_rows}&briefRepresentation=false"
f"{self.base_url}/admin/realms/{self.realm}/groups?first={len(group_data)}&max={self.max_rows}&briefRepresentation=false",
use_client_secret=False,
):
group_data.extend(data)
if len(data) != self.max_rows:
Expand Down Expand Up @@ -107,7 +108,8 @@ def groups(self) -> list[JSONDict]:
attributes["oauth_id"] = group_dict.get("id", None)
# Add membership attributes
members = self.query(
f"{self.base_url}/admin/realms/{self.realm}/groups/{group_dict['id']}/members"
f"{self.base_url}/admin/realms/{self.realm}/groups/{group_dict['id']}/members",
use_client_secret=False,
)
attributes["memberUid"] = [
user["username"] for user in cast(list[JSONDict], members)
Expand All @@ -122,7 +124,8 @@ def users(self) -> list[JSONDict]:
try:
user_data = []
while data := self.query(
f"{self.base_url}/admin/realms/{self.realm}/users?first={len(user_data)}&max={self.max_rows}&briefRepresentation=false"
f"{self.base_url}/admin/realms/{self.realm}/users?first={len(user_data)}&max={self.max_rows}&briefRepresentation=false",
use_client_secret=False,
):
user_data.extend(data)
if len(data) != self.max_rows:
Expand Down
31 changes: 15 additions & 16 deletions apricot/oauth/oauth_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,28 +109,27 @@ def users(self) -> list[JSONDict]:
"""
pass

def query(self, url: str) -> dict[str, Any]:
def query(self, url: str, *, use_client_secret=True) -> dict[str, Any]:
"""
Make a query against the OAuth backend
"""

def query_(url: str) -> requests.Response:
return self.session_application.get( # type: ignore[no-any-return]
url=url, headers={"Authorization": f"Bearer {self.bearer_token}"}
)

try:
result = query_(url)
result.raise_for_status()
except (TokenExpiredError, requests.exceptions.HTTPError):
log.msg("Authentication token has expired.")
self.bearer_token_ = None
result = query_(url)
return result.json() # type: ignore
kwargs = (
{
"client_id": self.session_application._client.client_id,
"client_secret": self.client_secret,
}
if use_client_secret
else {}
)
return self.request(
url=url,
method="GET",
**kwargs,
)

def request(self, *args, method="GET", **kwargs) -> dict[str, Any]:
"""
Make a query against the OAuth backend
Make a request to the OAuth backend
"""

def query_(*args, **kwargs) -> requests.Response:
Expand Down

0 comments on commit 278c72d

Please sign in to comment.