diff --git a/homeassistant/auth/__init__.py b/homeassistant/auth/__init__.py index 21a4b6113d0da8..afe3b2d7aa3a30 100644 --- a/homeassistant/auth/__init__.py +++ b/homeassistant/auth/__init__.py @@ -115,7 +115,7 @@ async def async_create_flow( *, context: AuthFlowContext | None = None, data: dict[str, Any] | None = None, - ) -> LoginFlow: + ) -> LoginFlow[Any]: """Create a login flow.""" auth_provider = self.auth_manager.get_auth_provider(*handler_key) if not auth_provider: diff --git a/homeassistant/auth/mfa_modules/__init__.py b/homeassistant/auth/mfa_modules/__init__.py index d57a274c7ffd8b..8a6430d770a028 100644 --- a/homeassistant/auth/mfa_modules/__init__.py +++ b/homeassistant/auth/mfa_modules/__init__.py @@ -4,8 +4,9 @@ import logging import types -from typing import Any +from typing import Any, Generic +from typing_extensions import TypeVar import voluptuous as vol from voluptuous.humanize import humanize_error @@ -34,6 +35,12 @@ _LOGGER = logging.getLogger(__name__) +_MultiFactorAuthModuleT = TypeVar( + "_MultiFactorAuthModuleT", + bound="MultiFactorAuthModule", + default="MultiFactorAuthModule", +) + class MultiFactorAuthModule: """Multi-factor Auth Module of validation function.""" @@ -71,7 +78,7 @@ def input_schema(self) -> vol.Schema: """Return a voluptuous schema to define mfa auth module's input.""" raise NotImplementedError - async def async_setup_flow(self, user_id: str) -> SetupFlow: + async def async_setup_flow(self, user_id: str) -> SetupFlow[Any]: """Return a data entry flow handler for setup module. Mfa module should extend SetupFlow @@ -95,11 +102,14 @@ async def async_validate(self, user_id: str, user_input: dict[str, Any]) -> bool raise NotImplementedError -class SetupFlow(data_entry_flow.FlowHandler): +class SetupFlow(data_entry_flow.FlowHandler, Generic[_MultiFactorAuthModuleT]): """Handler for the setup flow.""" def __init__( - self, auth_module: MultiFactorAuthModule, setup_schema: vol.Schema, user_id: str + self, + auth_module: _MultiFactorAuthModuleT, + setup_schema: vol.Schema, + user_id: str, ) -> None: """Initialize the setup flow.""" self._auth_module = auth_module diff --git a/homeassistant/auth/mfa_modules/notify.py b/homeassistant/auth/mfa_modules/notify.py index d2010dc2c9dc4f..b60a3012aace11 100644 --- a/homeassistant/auth/mfa_modules/notify.py +++ b/homeassistant/auth/mfa_modules/notify.py @@ -162,7 +162,7 @@ def aync_get_available_notify_services(self) -> list[str]: return sorted(unordered_services) - async def async_setup_flow(self, user_id: str) -> SetupFlow: + async def async_setup_flow(self, user_id: str) -> NotifySetupFlow: """Return a data entry flow handler for setup module. Mfa module should extend SetupFlow @@ -268,7 +268,7 @@ async def async_notify( await self.hass.services.async_call("notify", notify_service, data) -class NotifySetupFlow(SetupFlow): +class NotifySetupFlow(SetupFlow[NotifyAuthModule]): """Handler for the setup flow.""" def __init__( @@ -280,8 +280,6 @@ def __init__( ) -> None: """Initialize the setup flow.""" super().__init__(auth_module, setup_schema, user_id) - # to fix typing complaint - self._auth_module: NotifyAuthModule = auth_module self._available_notify_services = available_notify_services self._secret: str | None = None self._count: int | None = None diff --git a/homeassistant/auth/mfa_modules/totp.py b/homeassistant/auth/mfa_modules/totp.py index 3306f76217feca..625b273f39af03 100644 --- a/homeassistant/auth/mfa_modules/totp.py +++ b/homeassistant/auth/mfa_modules/totp.py @@ -114,7 +114,7 @@ def _add_ota_secret(self, user_id: str, secret: str | None = None) -> str: self._users[user_id] = ota_secret # type: ignore[index] return ota_secret - async def async_setup_flow(self, user_id: str) -> SetupFlow: + async def async_setup_flow(self, user_id: str) -> TotpSetupFlow: """Return a data entry flow handler for setup module. Mfa module should extend SetupFlow @@ -174,10 +174,9 @@ def _validate_2fa(self, user_id: str, code: str) -> bool: return bool(pyotp.TOTP(ota_secret).verify(code, valid_window=1)) -class TotpSetupFlow(SetupFlow): +class TotpSetupFlow(SetupFlow[TotpAuthModule]): """Handler for the setup flow.""" - _auth_module: TotpAuthModule _ota_secret: str _url: str _image: str diff --git a/homeassistant/auth/providers/__init__.py b/homeassistant/auth/providers/__init__.py index 34278c47df7dc6..02f99e7bd71767 100644 --- a/homeassistant/auth/providers/__init__.py +++ b/homeassistant/auth/providers/__init__.py @@ -5,8 +5,9 @@ from collections.abc import Mapping import logging import types -from typing import Any +from typing import Any, Generic +from typing_extensions import TypeVar import voluptuous as vol from voluptuous.humanize import humanize_error @@ -46,6 +47,8 @@ extra=vol.ALLOW_EXTRA, ) +_AuthProviderT = TypeVar("_AuthProviderT", bound="AuthProvider", default="AuthProvider") + class AuthProvider: """Provider of user authentication.""" @@ -105,7 +108,7 @@ def async_create_credentials(self, data: dict[str, str]) -> Credentials: # Implement by extending class - async def async_login_flow(self, context: AuthFlowContext | None) -> LoginFlow: + async def async_login_flow(self, context: AuthFlowContext | None) -> LoginFlow[Any]: """Return the data flow for logging in with auth provider. Auth provider should extend LoginFlow and return an instance. @@ -192,12 +195,15 @@ async def load_auth_provider_module( return module -class LoginFlow(FlowHandler[AuthFlowContext, AuthFlowResult, tuple[str, str]]): +class LoginFlow( + FlowHandler[AuthFlowContext, AuthFlowResult, tuple[str, str]], + Generic[_AuthProviderT], +): """Handler for the login flow.""" _flow_result = AuthFlowResult - def __init__(self, auth_provider: AuthProvider) -> None: + def __init__(self, auth_provider: _AuthProviderT) -> None: """Initialize the login flow.""" self._auth_provider = auth_provider self._auth_module_id: str | None = None diff --git a/homeassistant/auth/providers/command_line.py b/homeassistant/auth/providers/command_line.py index 12447bc8c18b04..74630d925e11c9 100644 --- a/homeassistant/auth/providers/command_line.py +++ b/homeassistant/auth/providers/command_line.py @@ -6,7 +6,7 @@ from collections.abc import Mapping import logging import os -from typing import Any, cast +from typing import Any import voluptuous as vol @@ -59,7 +59,9 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: super().__init__(*args, **kwargs) self._user_meta: dict[str, dict[str, Any]] = {} - async def async_login_flow(self, context: AuthFlowContext | None) -> LoginFlow: + async def async_login_flow( + self, context: AuthFlowContext | None + ) -> CommandLineLoginFlow: """Return a flow to login.""" return CommandLineLoginFlow(self) @@ -133,7 +135,7 @@ async def async_user_meta_for_credentials( ) -class CommandLineLoginFlow(LoginFlow): +class CommandLineLoginFlow(LoginFlow[CommandLineAuthProvider]): """Handler for the login flow.""" async def async_step_init( @@ -145,9 +147,9 @@ async def async_step_init( if user_input is not None: user_input["username"] = user_input["username"].strip() try: - await cast( - CommandLineAuthProvider, self._auth_provider - ).async_validate_login(user_input["username"], user_input["password"]) + await self._auth_provider.async_validate_login( + user_input["username"], user_input["password"] + ) except InvalidAuthError: errors["base"] = "invalid_auth" diff --git a/homeassistant/auth/providers/homeassistant.py b/homeassistant/auth/providers/homeassistant.py index e5dded74762195..522e5d77a29ba4 100644 --- a/homeassistant/auth/providers/homeassistant.py +++ b/homeassistant/auth/providers/homeassistant.py @@ -305,7 +305,7 @@ async def async_initialize(self) -> None: await data.async_load() self.data = data - async def async_login_flow(self, context: AuthFlowContext | None) -> LoginFlow: + async def async_login_flow(self, context: AuthFlowContext | None) -> HassLoginFlow: """Return a flow to login.""" return HassLoginFlow(self) @@ -400,7 +400,7 @@ async def async_will_remove_credentials(self, credentials: Credentials) -> None: pass -class HassLoginFlow(LoginFlow): +class HassLoginFlow(LoginFlow[HassAuthProvider]): """Handler for the login flow.""" async def async_step_init( @@ -411,7 +411,7 @@ async def async_step_init( if user_input is not None: try: - await cast(HassAuthProvider, self._auth_provider).async_validate_login( + await self._auth_provider.async_validate_login( user_input["username"], user_input["password"] ) except InvalidAuth: diff --git a/homeassistant/auth/providers/insecure_example.py b/homeassistant/auth/providers/insecure_example.py index a7dced851a301c..a92f5b558486c5 100644 --- a/homeassistant/auth/providers/insecure_example.py +++ b/homeassistant/auth/providers/insecure_example.py @@ -4,7 +4,6 @@ from collections.abc import Mapping import hmac -from typing import cast import voluptuous as vol @@ -36,7 +35,9 @@ class InvalidAuthError(HomeAssistantError): class ExampleAuthProvider(AuthProvider): """Example auth provider based on hardcoded usernames and passwords.""" - async def async_login_flow(self, context: AuthFlowContext | None) -> LoginFlow: + async def async_login_flow( + self, context: AuthFlowContext | None + ) -> ExampleLoginFlow: """Return a flow to login.""" return ExampleLoginFlow(self) @@ -93,7 +94,7 @@ async def async_user_meta_for_credentials( return UserMeta(name=name, is_active=True) -class ExampleLoginFlow(LoginFlow): +class ExampleLoginFlow(LoginFlow[ExampleAuthProvider]): """Handler for the login flow.""" async def async_step_init( @@ -104,7 +105,7 @@ async def async_step_init( if user_input is not None: try: - cast(ExampleAuthProvider, self._auth_provider).async_validate_login( + self._auth_provider.async_validate_login( user_input["username"], user_input["password"] ) except InvalidAuthError: diff --git a/homeassistant/auth/providers/trusted_networks.py b/homeassistant/auth/providers/trusted_networks.py index f32c35d4bd554f..799fd4d2e16a28 100644 --- a/homeassistant/auth/providers/trusted_networks.py +++ b/homeassistant/auth/providers/trusted_networks.py @@ -104,7 +104,9 @@ def support_mfa(self) -> bool: """Trusted Networks auth provider does not support MFA.""" return False - async def async_login_flow(self, context: AuthFlowContext | None) -> LoginFlow: + async def async_login_flow( + self, context: AuthFlowContext | None + ) -> TrustedNetworksLoginFlow: """Return a flow to login.""" assert context is not None ip_addr = cast(IPAddress, context.get("ip_address")) @@ -214,7 +216,7 @@ def async_validate_refresh_token( self.async_validate_access(ip_address(remote_ip)) -class TrustedNetworksLoginFlow(LoginFlow): +class TrustedNetworksLoginFlow(LoginFlow[TrustedNetworksAuthProvider]): """Handler for the login flow.""" def __init__( @@ -235,9 +237,7 @@ async def async_step_init( ) -> AuthFlowResult: """Handle the step of the form.""" try: - cast( - TrustedNetworksAuthProvider, self._auth_provider - ).async_validate_access(self._ip_address) + self._auth_provider.async_validate_access(self._ip_address) except InvalidAuthError: return self.async_abort(reason="not_allowed")