diff --git a/CHANGELOG.md b/CHANGELOG.md index 08c103c91b..93676f6f91 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -6,6 +6,7 @@ FEATURES: ENHANCEMENTS: BUG FIXES: +* Enabling support for more than 20 users/groups in Workspace API ([#3759](https://github.com/microsoft/AzureTRE/pull/3759 )) COMPONENTS: diff --git a/api_app/_version.py b/api_app/_version.py index a99557a02f..fe051f14ca 100644 --- a/api_app/_version.py +++ b/api_app/_version.py @@ -1 +1 @@ -__version__ = "0.15.17" +__version__ = "0.15.18" diff --git a/api_app/services/aad_authentication.py b/api_app/services/aad_authentication.py index 29dc1716a7..61ff23f1fb 100644 --- a/api_app/services/aad_authentication.py +++ b/api_app/services/aad_authentication.py @@ -5,7 +5,6 @@ from typing import List, Optional import jwt import requests -import rsa from fastapi import Request, HTTPException, status from msal import ConfidentialClientApplication @@ -19,6 +18,10 @@ from api.dependencies.database import get_db_client_from_request from db.repositories.workspaces import WorkspaceRepository +from cryptography.hazmat.primitives.asymmetric import rsa +from cryptography.hazmat.backends import default_backend +from cryptography.hazmat.primitives import serialization + MICROSOFT_GRAPH_URL = config.MICROSOFT_GRAPH_URL.strip("/") @@ -179,9 +182,12 @@ def _get_token_key(self, key_id: str) -> str: for key in keys['keys']: n = int.from_bytes(base64.urlsafe_b64decode(self._ensure_b64padding(key['n'])), "big") e = int.from_bytes(base64.urlsafe_b64decode(self._ensure_b64padding(key['e'])), "big") - pub_key = rsa.PublicKey(n, e) + pub_key = rsa.RSAPublicNumbers(e, n).public_key(default_backend()) # Cache the PEM formatted public key. - AzureADAuthorization._jwt_keys[key['kid']] = pub_key.save_pkcs1() + AzureADAuthorization._jwt_keys[key['kid']] = pub_key.public_bytes( + encoding=serialization.Encoding.PEM, + format=serialization.PublicFormat.PKCS1 + ) return AzureADAuthorization._jwt_keys[key_id] @@ -245,7 +251,18 @@ def _get_user_emails(self, roles_graph_data, msgraph_token): batch_request_body = self._get_batch_users_by_role_assignments_body(roles_graph_data) headers = self._get_auth_header(msgraph_token) headers["Content-type"] = "application/json" - users_graph_data = requests.post(batch_endpoint, json=batch_request_body, headers=headers).json() + max_number_request = 20 + requests_from_batch = batch_request_body["requests"] + # We split the original batch request body in sub-lits with at most max_number_request elements + batch_request_body_list = [requests_from_batch[i:i + max_number_request] for i in range(0, len(requests_from_batch), max_number_request)] + users_graph_data = {"responses": []} + + # For each sub-list it's required to call the batch endpoint for retrieveing user/group information + for request_body_element in batch_request_body_list: + batch_request_body_tmp = {"requests": request_body_element} + users_graph_data_tmp = requests.post(batch_endpoint, json=batch_request_body_tmp, headers=headers).json() + users_graph_data["responses"] = users_graph_data["responses"] + users_graph_data_tmp["responses"] + return users_graph_data def _get_user_emails_from_response(self, users_graph_data): diff --git a/api_app/tests_ma/test_services/test_aad_access_service.py b/api_app/tests_ma/test_services/test_aad_access_service.py index 668de31562..eeb7754106 100644 --- a/api_app/tests_ma/test_services/test_aad_access_service.py +++ b/api_app/tests_ma/test_services/test_aad_access_service.py @@ -1,5 +1,5 @@ import pytest -from mock import patch +from mock import call, patch from models.domain.authentication import User, RoleAssignment from models.domain.workspace import Workspace, WorkspaceRole @@ -554,6 +554,66 @@ def test_get_workspace_role_assignment_details_with_groups_and_users_assigned_re assert "test_user1@email.com" in role_assignment_details["WorkspaceOwner"] +@patch("services.aad_authentication.AzureADAuthorization._get_auth_header") +@patch("services.aad_authentication.AzureADAuthorization._get_batch_users_by_role_assignments_body") +@patch("requests.post") +def test_get_user_emails_with_batch_of_more_than_20_requests(mock_graph_post, mock_get_batch_users_by_role_assignments_body, mock_headers): + # Arrange + access_service = AzureADAuthorization() + roles_graph_data = [{"id": "role1"}, {"id": "role2"}] + msgraph_token = "token" + batch_endpoint = access_service._get_batch_endpoint() + + # mock the response of _get_auth_header + headers = {"Authorization": f"Bearer {msgraph_token}"} + mock_headers.return_value = headers + headers["Content-type"] = "application/json" + + # mock the response of the get batch request for 30 users + batch_request_body_first_20 = { + "requests": [ + {"id": f"{i}", "method": "GET", "url": f"/users/{i}"} for i in range(20) + ] + } + + batch_request_body_last_10 = { + "requests": [ + {"id": f"{i}", "method": "GET", "url": f"/users/{i}"} for i in range(20, 30) + ] + } + + batch_request_body = { + "requests": [ + {"id": f"{i}", "method": "GET", "url": f"/users/{i}"} for i in range(30) + ] + } + + mock_get_batch_users_by_role_assignments_body.return_value = batch_request_body + + # Mock the response of the post request + mock_graph_post_response = {"responses": [{"id": "user1"}, {"id": "user2"}]} + mock_graph_post.return_value.json.return_value = mock_graph_post_response + + # Act + users_graph_data = access_service._get_user_emails(roles_graph_data, msgraph_token) + + # Assert + assert len(users_graph_data["responses"]) == 4 + calls = [ + call( + f"{batch_endpoint}", + json=batch_request_body_first_20, + headers=headers + ), + call( + f"{batch_endpoint}", + json=batch_request_body_last_10, + headers=headers + ) + ] + mock_graph_post.assert_has_calls(calls, any_order=True) + + def get_mock_batch_response(user_principals, group_principals): response_body = {"responses": []} for user_principal in user_principals: