Skip to content

Commit

Permalink
Improve auth generic typing (#133061)
Browse files Browse the repository at this point in the history
  • Loading branch information
cdce8p authored Dec 12, 2024
1 parent ce70cb9 commit 32c1b51
Show file tree
Hide file tree
Showing 9 changed files with 50 additions and 34 deletions.
2 changes: 1 addition & 1 deletion homeassistant/auth/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ async def async_create_flow(
*,
context: AuthFlowContext | None = None,
data: dict[str, Any] | None = None,
) -> LoginFlow:
) -> LoginFlow[Any]:
"""Create a login flow."""
auth_provider = self.auth_manager.get_auth_provider(*handler_key)
if not auth_provider:
Expand Down
18 changes: 14 additions & 4 deletions homeassistant/auth/mfa_modules/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,9 @@

import logging
import types
from typing import Any
from typing import Any, Generic

from typing_extensions import TypeVar
import voluptuous as vol
from voluptuous.humanize import humanize_error

Expand Down Expand Up @@ -34,6 +35,12 @@

_LOGGER = logging.getLogger(__name__)

_MultiFactorAuthModuleT = TypeVar(
"_MultiFactorAuthModuleT",
bound="MultiFactorAuthModule",
default="MultiFactorAuthModule",
)


class MultiFactorAuthModule:
"""Multi-factor Auth Module of validation function."""
Expand Down Expand Up @@ -71,7 +78,7 @@ def input_schema(self) -> vol.Schema:
"""Return a voluptuous schema to define mfa auth module's input."""
raise NotImplementedError

async def async_setup_flow(self, user_id: str) -> SetupFlow:
async def async_setup_flow(self, user_id: str) -> SetupFlow[Any]:
"""Return a data entry flow handler for setup module.
Mfa module should extend SetupFlow
Expand All @@ -95,11 +102,14 @@ async def async_validate(self, user_id: str, user_input: dict[str, Any]) -> bool
raise NotImplementedError


class SetupFlow(data_entry_flow.FlowHandler):
class SetupFlow(data_entry_flow.FlowHandler, Generic[_MultiFactorAuthModuleT]):
"""Handler for the setup flow."""

def __init__(
self, auth_module: MultiFactorAuthModule, setup_schema: vol.Schema, user_id: str
self,
auth_module: _MultiFactorAuthModuleT,
setup_schema: vol.Schema,
user_id: str,
) -> None:
"""Initialize the setup flow."""
self._auth_module = auth_module
Expand Down
6 changes: 2 additions & 4 deletions homeassistant/auth/mfa_modules/notify.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,7 @@ def aync_get_available_notify_services(self) -> list[str]:

return sorted(unordered_services)

async def async_setup_flow(self, user_id: str) -> SetupFlow:
async def async_setup_flow(self, user_id: str) -> NotifySetupFlow:
"""Return a data entry flow handler for setup module.
Mfa module should extend SetupFlow
Expand Down Expand Up @@ -268,7 +268,7 @@ async def async_notify(
await self.hass.services.async_call("notify", notify_service, data)


class NotifySetupFlow(SetupFlow):
class NotifySetupFlow(SetupFlow[NotifyAuthModule]):
"""Handler for the setup flow."""

def __init__(
Expand All @@ -280,8 +280,6 @@ def __init__(
) -> None:
"""Initialize the setup flow."""
super().__init__(auth_module, setup_schema, user_id)
# to fix typing complaint
self._auth_module: NotifyAuthModule = auth_module
self._available_notify_services = available_notify_services
self._secret: str | None = None
self._count: int | None = None
Expand Down
5 changes: 2 additions & 3 deletions homeassistant/auth/mfa_modules/totp.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ def _add_ota_secret(self, user_id: str, secret: str | None = None) -> str:
self._users[user_id] = ota_secret # type: ignore[index]
return ota_secret

async def async_setup_flow(self, user_id: str) -> SetupFlow:
async def async_setup_flow(self, user_id: str) -> TotpSetupFlow:
"""Return a data entry flow handler for setup module.
Mfa module should extend SetupFlow
Expand Down Expand Up @@ -174,10 +174,9 @@ def _validate_2fa(self, user_id: str, code: str) -> bool:
return bool(pyotp.TOTP(ota_secret).verify(code, valid_window=1))


class TotpSetupFlow(SetupFlow):
class TotpSetupFlow(SetupFlow[TotpAuthModule]):
"""Handler for the setup flow."""

_auth_module: TotpAuthModule
_ota_secret: str
_url: str
_image: str
Expand Down
14 changes: 10 additions & 4 deletions homeassistant/auth/providers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,9 @@
from collections.abc import Mapping
import logging
import types
from typing import Any
from typing import Any, Generic

from typing_extensions import TypeVar
import voluptuous as vol
from voluptuous.humanize import humanize_error

Expand Down Expand Up @@ -46,6 +47,8 @@
extra=vol.ALLOW_EXTRA,
)

_AuthProviderT = TypeVar("_AuthProviderT", bound="AuthProvider", default="AuthProvider")


class AuthProvider:
"""Provider of user authentication."""
Expand Down Expand Up @@ -105,7 +108,7 @@ def async_create_credentials(self, data: dict[str, str]) -> Credentials:

# Implement by extending class

async def async_login_flow(self, context: AuthFlowContext | None) -> LoginFlow:
async def async_login_flow(self, context: AuthFlowContext | None) -> LoginFlow[Any]:
"""Return the data flow for logging in with auth provider.
Auth provider should extend LoginFlow and return an instance.
Expand Down Expand Up @@ -192,12 +195,15 @@ async def load_auth_provider_module(
return module


class LoginFlow(FlowHandler[AuthFlowContext, AuthFlowResult, tuple[str, str]]):
class LoginFlow(
FlowHandler[AuthFlowContext, AuthFlowResult, tuple[str, str]],
Generic[_AuthProviderT],
):
"""Handler for the login flow."""

_flow_result = AuthFlowResult

def __init__(self, auth_provider: AuthProvider) -> None:
def __init__(self, auth_provider: _AuthProviderT) -> None:
"""Initialize the login flow."""
self._auth_provider = auth_provider
self._auth_module_id: str | None = None
Expand Down
14 changes: 8 additions & 6 deletions homeassistant/auth/providers/command_line.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from collections.abc import Mapping
import logging
import os
from typing import Any, cast
from typing import Any

import voluptuous as vol

Expand Down Expand Up @@ -59,7 +59,9 @@ def __init__(self, *args: Any, **kwargs: Any) -> None:
super().__init__(*args, **kwargs)
self._user_meta: dict[str, dict[str, Any]] = {}

async def async_login_flow(self, context: AuthFlowContext | None) -> LoginFlow:
async def async_login_flow(
self, context: AuthFlowContext | None
) -> CommandLineLoginFlow:
"""Return a flow to login."""
return CommandLineLoginFlow(self)

Expand Down Expand Up @@ -133,7 +135,7 @@ async def async_user_meta_for_credentials(
)


class CommandLineLoginFlow(LoginFlow):
class CommandLineLoginFlow(LoginFlow[CommandLineAuthProvider]):
"""Handler for the login flow."""

async def async_step_init(
Expand All @@ -145,9 +147,9 @@ async def async_step_init(
if user_input is not None:
user_input["username"] = user_input["username"].strip()
try:
await cast(
CommandLineAuthProvider, self._auth_provider
).async_validate_login(user_input["username"], user_input["password"])
await self._auth_provider.async_validate_login(
user_input["username"], user_input["password"]
)
except InvalidAuthError:
errors["base"] = "invalid_auth"

Expand Down
6 changes: 3 additions & 3 deletions homeassistant/auth/providers/homeassistant.py
Original file line number Diff line number Diff line change
Expand Up @@ -305,7 +305,7 @@ async def async_initialize(self) -> None:
await data.async_load()
self.data = data

async def async_login_flow(self, context: AuthFlowContext | None) -> LoginFlow:
async def async_login_flow(self, context: AuthFlowContext | None) -> HassLoginFlow:
"""Return a flow to login."""
return HassLoginFlow(self)

Expand Down Expand Up @@ -400,7 +400,7 @@ async def async_will_remove_credentials(self, credentials: Credentials) -> None:
pass


class HassLoginFlow(LoginFlow):
class HassLoginFlow(LoginFlow[HassAuthProvider]):
"""Handler for the login flow."""

async def async_step_init(
Expand All @@ -411,7 +411,7 @@ async def async_step_init(

if user_input is not None:
try:
await cast(HassAuthProvider, self._auth_provider).async_validate_login(
await self._auth_provider.async_validate_login(
user_input["username"], user_input["password"]
)
except InvalidAuth:
Expand Down
9 changes: 5 additions & 4 deletions homeassistant/auth/providers/insecure_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@

from collections.abc import Mapping
import hmac
from typing import cast

import voluptuous as vol

Expand Down Expand Up @@ -36,7 +35,9 @@ class InvalidAuthError(HomeAssistantError):
class ExampleAuthProvider(AuthProvider):
"""Example auth provider based on hardcoded usernames and passwords."""

async def async_login_flow(self, context: AuthFlowContext | None) -> LoginFlow:
async def async_login_flow(
self, context: AuthFlowContext | None
) -> ExampleLoginFlow:
"""Return a flow to login."""
return ExampleLoginFlow(self)

Expand Down Expand Up @@ -93,7 +94,7 @@ async def async_user_meta_for_credentials(
return UserMeta(name=name, is_active=True)


class ExampleLoginFlow(LoginFlow):
class ExampleLoginFlow(LoginFlow[ExampleAuthProvider]):
"""Handler for the login flow."""

async def async_step_init(
Expand All @@ -104,7 +105,7 @@ async def async_step_init(

if user_input is not None:
try:
cast(ExampleAuthProvider, self._auth_provider).async_validate_login(
self._auth_provider.async_validate_login(
user_input["username"], user_input["password"]
)
except InvalidAuthError:
Expand Down
10 changes: 5 additions & 5 deletions homeassistant/auth/providers/trusted_networks.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,9 @@ def support_mfa(self) -> bool:
"""Trusted Networks auth provider does not support MFA."""
return False

async def async_login_flow(self, context: AuthFlowContext | None) -> LoginFlow:
async def async_login_flow(
self, context: AuthFlowContext | None
) -> TrustedNetworksLoginFlow:
"""Return a flow to login."""
assert context is not None
ip_addr = cast(IPAddress, context.get("ip_address"))
Expand Down Expand Up @@ -214,7 +216,7 @@ def async_validate_refresh_token(
self.async_validate_access(ip_address(remote_ip))


class TrustedNetworksLoginFlow(LoginFlow):
class TrustedNetworksLoginFlow(LoginFlow[TrustedNetworksAuthProvider]):
"""Handler for the login flow."""

def __init__(
Expand All @@ -235,9 +237,7 @@ async def async_step_init(
) -> AuthFlowResult:
"""Handle the step of the form."""
try:
cast(
TrustedNetworksAuthProvider, self._auth_provider
).async_validate_access(self._ip_address)
self._auth_provider.async_validate_access(self._ip_address)

except InvalidAuthError:
return self.async_abort(reason="not_allowed")
Expand Down

0 comments on commit 32c1b51

Please sign in to comment.