Skip to content

Commit

Permalink
.
Browse files Browse the repository at this point in the history
  • Loading branch information
spacemanspiff2007 committed Dec 10, 2024
1 parent 200ab3f commit d842764
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 12 deletions.
17 changes: 13 additions & 4 deletions src/HABApp/core/connections/manager.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,17 @@
from __future__ import annotations

import asyncio
from typing import Final, TypeVar
from typing import TYPE_CHECKING, Final, TypeVar

import HABApp
from HABApp.core.connections import BaseConnection
from HABApp.core.connections._definitions import connection_log


if TYPE_CHECKING:
from collections.abc import Generator


T = TypeVar('T', bound=BaseConnection)


Expand All @@ -16,16 +20,21 @@ def __init__(self) -> None:
self.connections: dict[str, BaseConnection] = {}

def add(self, connection: T) -> T:
assert connection.name not in self.connections
if connection.name in self.connections:
msg = f'Connection {connection.name:s} already exists!'
raise ValueError(msg)

self.connections[connection.name] = connection
connection_log.debug(f'Added {connection.name:s}')

return connection

def get(self, name: str) -> BaseConnection:
return self.connections[name]

def remove(self, name):
def get_names(self) -> Generator[str, None, None]:
yield from self.connections.keys()

def remove(self, name: str) -> None:
con = self.get(name)
if not con.is_shutdown:
raise ValueError()
Expand Down
12 changes: 8 additions & 4 deletions src/HABApp/core/connections/plugin_callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
from inspect import getmembers, iscoroutinefunction, signature
from typing import TYPE_CHECKING, Any

from typing_extensions import Self

from ._definitions import ConnectionStatus


Expand Down Expand Up @@ -54,9 +56,10 @@ async def run(self, connection: BaseConnection, context: Any):
return await self.coro(**kwargs)

@staticmethod
def _get_coro_kwargs(plugin: BaseConnectionPlugin, coro: Callable[[...], Awaitable]):
def _get_coro_kwargs(plugin: BaseConnectionPlugin, coro: Callable[[...], Awaitable]) -> tuple[str, ...]:
if not iscoroutinefunction(coro):
raise ValueError(f'Coroutine function expected for {plugin.plugin_name}.{coro.__name__}')
msg = f'Coroutine function expected for {plugin.plugin_name}.{coro.__name__}'
raise ValueError(msg)

sig = signature(coro)

Expand All @@ -65,9 +68,10 @@ def _get_coro_kwargs(plugin: BaseConnectionPlugin, coro: Callable[[...], Awaitab
if name in ('connection', 'context'):
kwargs.append(name)
else:
raise ValueError(f'Invalid parameter name "{name:s}" for {plugin.plugin_name}.{coro.__name__}')
msg = f'Invalid parameter name "{name:s}" for {plugin.plugin_name}.{coro.__name__}'
raise ValueError(msg)
return tuple(kwargs)

@classmethod
def create(cls, plugin: BaseConnectionPlugin, coro: Callable[[...], Awaitable]):
def create(cls, plugin: BaseConnectionPlugin, coro: Callable[[...], Awaitable]) -> Self:
return cls(plugin, coro, cls._get_coro_kwargs(plugin, coro))
8 changes: 4 additions & 4 deletions src/HABApp/core/connections/status_transitions.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,11 +77,11 @@ def _next_step(self) -> ConnectionStatus:
return transitions.get(status)

def __repr__(self) -> str:
return f'<{self.__class__.__name__} {self.status} ' \
f'[{"x" if self.error else " "}] Error, ' \
f'[{"x" if self.setup else " "}] Setup>'
return (f'<{self.__class__.__name__} {self.status} '
f'[{"x" if self.error else " "}] Error, '
f'[{"x" if self.setup else " "}] Setup>')

def __eq__(self, other: ConnectionStatus):
def __eq__(self, other: ConnectionStatus) -> bool:
if not isinstance(other, ConnectionStatus):
return NotImplemented
return self.status == other

0 comments on commit d842764

Please sign in to comment.