Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Cleaner exit when user credentials are incorrect #2296

Merged
merged 20 commits into from
Nov 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions data_safe_haven/commands/sre.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,7 @@ def deploy(
)
# Set Entra options
application = graph_api.get_application_by_name(context.entra_application_name)

if not application:
msg = f"No Entra application '{context.entra_application_name}' was found. Please redeploy your SHM."
raise DataSafeHavenConfigError(msg)
Expand Down
10 changes: 10 additions & 0 deletions data_safe_haven/exceptions/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,16 @@ class DataSafeHavenAzureError(DataSafeHavenError):
pass


class DataSafeHavenCachedCredentialError(DataSafeHavenError):
"""
Exception class for handling errors related to cached credentials.

Raise this error when a cached credential is not the credential a user wants to use.
"""

pass


class DataSafeHavenAzureStorageError(DataSafeHavenAzureError):
"""
Exception class for handling errors when interacting with Azure Storage.
Expand Down
60 changes: 36 additions & 24 deletions data_safe_haven/external/api/credentials.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,11 @@

from data_safe_haven import console
from data_safe_haven.directories import config_dir
from data_safe_haven.exceptions import DataSafeHavenAzureError, DataSafeHavenValueError
from data_safe_haven.exceptions import (
DataSafeHavenAzureError,
DataSafeHavenCachedCredentialError,
DataSafeHavenValueError,
)
from data_safe_haven.logging import get_logger
from data_safe_haven.types import AzureSdkCredentialScope

Expand All @@ -28,6 +32,7 @@ class DeferredCredential(TokenCredential):

tokens_: ClassVar[dict[str, AccessToken]] = {}
cache_: ClassVar[set[tuple[str, str]]] = set()
name: ClassVar[str] = "Credential name"

def __init__(
self,
Expand Down Expand Up @@ -66,32 +71,28 @@ def get_credential(self) -> TokenCredential:

def confirm_credentials_interactive(
self,
target_name: str,
user_name: str,
user_id: str,
tenant_name: str,
tenant_id: str,
) -> None:
) -> bool:
"""
Allow user to confirm that credentials are correct.

Responses are cached so the user will only be prompted once per run.
If 'skip_confirmation' is set, then no confirmation will be performed.

Raises:
DataSafeHavenValueError: if the user indicates that the credentials are wrong
"""
if self.skip_confirmation:
return
return True
if (user_id, tenant_id) in DeferredCredential.cache_:
return
return True

DeferredCredential.cache_.add((user_id, tenant_id))
self.logger.info(f"You are logged into the [blue]{target_name}[/] as:")
self.logger.info(f"You are logged into the [blue]{self.name}[/] as:")
self.logger.info(f"\tuser: [green]{user_name}[/] ({user_id})")
self.logger.info(f"\ttenant: [green]{tenant_name}[/] ({tenant_id})")
if not console.confirm("Are these details correct?", default_to_yes=True):
msg = "Selected credentials are incorrect."
raise DataSafeHavenValueError(msg)

return console.confirm("Are these details correct?", default_to_yes=True)

def get_token(
self,
Expand Down Expand Up @@ -119,6 +120,8 @@ class AzureSdkCredential(DeferredCredential):
Uses AzureCliCredential for authentication
"""

name: ClassVar[str] = "Azure CLI"

def __init__(
self,
scope: AzureSdkCredentialScope = AzureSdkCredentialScope.DEFAULT,
Expand All @@ -133,19 +136,22 @@ def get_credential(self) -> TokenCredential:
# Confirm that these are the desired credentials
try:
decoded = self.decode_token(credential.get_token(*self.scopes).token)
self.confirm_credentials_interactive(
"Azure CLI",
user_name=decoded["name"],
user_id=decoded["oid"],
tenant_name=decoded["upn"].split("@")[1],
tenant_id=decoded["tid"],
)
except (CredentialUnavailableError, DataSafeHavenValueError) as exc:
msg = "Error getting account information from Azure CLI."
raise DataSafeHavenAzureError(msg) from exc

if not self.confirm_credentials_interactive(
user_name=decoded["name"],
user_id=decoded["oid"],
tenant_name=decoded["upn"].split("@")[1],
tenant_id=decoded["tid"],
):
self.logger.error(
"Please authenticate with Azure: run '[green]az login[/]' using [bold]infrastructure administrator[/] credentials."
)
msg = "Error getting account information from Azure CLI."
raise DataSafeHavenAzureError(msg) from exc
msg = "Selected credentials are incorrect."
raise DataSafeHavenCachedCredentialError(msg)

return credential


Expand All @@ -156,6 +162,8 @@ class GraphApiCredential(DeferredCredential):
Uses DeviceCodeCredential for authentication
"""

name: ClassVar[str] = "Microsoft Graph API"

def __init__(
self,
tenant_id: str,
Expand Down Expand Up @@ -214,13 +222,17 @@ def callback(verification_uri: str, user_code: str, _: datetime) -> None:
raise DataSafeHavenAzureError(msg) from exc

# Confirm that these are the desired credentials
self.confirm_credentials_interactive(
"Microsoft Graph API",
if not self.confirm_credentials_interactive(
user_name=new_auth_record.username,
user_id=new_auth_record._home_account_id.split(".")[0],
tenant_name=new_auth_record._username.split("@")[1],
tenant_id=new_auth_record._tenant_id,
)
):
self.logger.error(
f"Delete the cached credential file [green]{authentication_record_path}[/] and rerun dsh to authenticate with {self.name}"
)
msg = "Selected credentials are incorrect."
raise DataSafeHavenCachedCredentialError(msg)

# Return the credential
return credential
5 changes: 4 additions & 1 deletion data_safe_haven/external/api/graph_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -837,7 +837,10 @@ def read_applications(self) -> Sequence[dict[str, Any]]:
"value"
]
]
except Exception as exc:
except (
DataSafeHavenMicrosoftGraphError,
requests.JSONDecodeError,
) as exc:
msg = "Could not load list of applications."
raise DataSafeHavenMicrosoftGraphError(msg) from exc

Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,7 @@ pip-compile-constraint = "default"
features = ["test"]

[tool.hatch.envs.test.scripts]
test = "coverage run -m pytest {args: tests}"
test = "coverage run -m pytest {args:} ./tests"
test-report = "coverage report {args:}"
test-coverage = ["test", "test-report"]

Expand Down
6 changes: 5 additions & 1 deletion tests/commands/test_sre.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from data_safe_haven.commands.sre import sre_command_group
from data_safe_haven.config import Context, ContextManager
from data_safe_haven.exceptions import DataSafeHavenAzureError
from data_safe_haven.external import AzureSdk
from data_safe_haven.external import AzureSdk, GraphApi


class TestDeploySRE:
Expand All @@ -31,13 +31,17 @@ def test_no_application(
self,
caplog: LogCaptureFixture,
runner: CliRunner,
mocker,
mock_azuresdk_get_subscription_name, # noqa: ARG002
mock_contextmanager_assert_context, # noqa: ARG002
mock_ip_1_2_3_4, # noqa: ARG002
mock_pulumi_config_from_remote_or_create, # noqa: ARG002
mock_shm_config_from_remote, # noqa: ARG002
mock_sre_config_from_remote, # noqa: ARG002
mock_graphapi_get_credential, # noqa: ARG002
) -> None:
mocker.patch.object(GraphApi, "get_application_by_name", return_value=None)

result = runner.invoke(sre_command_group, ["deploy", "sandbox"])
assert result.exit_code == 1
assert (
Expand Down
18 changes: 17 additions & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,10 @@
)
from data_safe_haven.exceptions import DataSafeHavenAzureError
from data_safe_haven.external import AzureSdk, PulumiAccount
from data_safe_haven.external.api.credentials import AzureSdkCredential
from data_safe_haven.external.api.credentials import (
AzureSdkCredential,
GraphApiCredential,
)
from data_safe_haven.infrastructure import SREProjectManager
from data_safe_haven.infrastructure.project_manager import ProjectManager
from data_safe_haven.logging import init_logging
Expand Down Expand Up @@ -215,6 +218,19 @@ def mock_azuresdk_get_subscription_name(mocker):
)


@fixture
def mock_graphapi_get_credential(mocker):
class MockCredential(TokenCredential):
def get_token(*args, **kwargs): # noqa: ARG002
return AccessToken("dummy-token", 0)

mocker.patch.object(
GraphApiCredential,
"get_credential",
return_value=MockCredential(),
)


@fixture
def mock_azuresdk_get_credential(mocker):
class MockCredential(TokenCredential):
Expand Down
16 changes: 10 additions & 6 deletions tests/external/api/test_credentials.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,18 @@
)

from data_safe_haven.directories import config_dir
from data_safe_haven.exceptions import DataSafeHavenAzureError
from data_safe_haven.exceptions import (
DataSafeHavenAzureError,
DataSafeHavenCachedCredentialError,
)
from data_safe_haven.external.api.credentials import (
AzureSdkCredential,
DeferredCredential,
GraphApiCredential,
)


class TestDeferredCredential:
class TestAzureSdkCredential:
def test_confirm_credentials_interactive(
self,
mock_confirm_yes, # noqa: ARG002
Expand All @@ -33,14 +36,17 @@ def test_confirm_credentials_interactive_fail(
self,
mock_confirm_no, # noqa: ARG002
mock_azureclicredential_get_token, # noqa: ARG002
capsys,
):
DeferredCredential.cache_ = set()
credential = AzureSdkCredential(skip_confirmation=False)
with pytest.raises(
DataSafeHavenAzureError,
match="Error getting account information from Azure CLI.",
DataSafeHavenCachedCredentialError,
match="Selected credentials are incorrect.",
):
credential.get_credential()
out, _ = capsys.readouterr()
assert "Please authenticate with Azure: run 'az login'" in out

def test_confirm_credentials_interactive_cache(
self,
Expand All @@ -67,8 +73,6 @@ def test_decode_token_error(
):
credential.decode_token(credential.token)


class TestAzureSdkCredential:
def test_get_credential(self, mock_azureclicredential_get_token): # noqa: ARG002
credential = AzureSdkCredential(skip_confirmation=True)
assert isinstance(credential.get_credential(), AzureCliCredential)
Expand Down
Loading