Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Match custom cluster #3

Merged
merged 14 commits into from
Oct 26, 2023
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

from zhaquirks.inovelli.types import AllLEDEffectType, SingleLEDEffectType
import zigpy.zcl
from zigpy.zcl.clusters.closures import DoorLock
from zigpy.zcl import clusters

from homeassistant.core import callback
Expand All @@ -25,6 +26,7 @@
UNKNOWN,
)
from . import AttrReportConfig, ClientClusterHandler, ClusterHandler
from .general import MultistateInput
from .homeautomation import Diagnostic
from .hvac import ThermostatClusterHandler, UserInterface

Expand Down Expand Up @@ -381,6 +383,12 @@ class IkeaRemote(ClusterHandler):
REPORT_CONFIG = ()


@registries.ZIGBEE_CLUSTER_HANDLER_REGISTRY.register(
DoorLock.cluster_id, "xiaomi_aqara_vibration_aq1"
)
class XiaomiVibrationAQ1ClusterHandler(MultistateInput):
"""Xiaomi DoorLock Cluster is in fact a MultiStateInput Cluster."""

def compare_quirk_class(endpoint: Endpoint, names: str | Collection[str]):
"""Return True if the last two words separated by dots equal the words between the dots in name.

Expand Down
18 changes: 18 additions & 0 deletions homeassistant/components/zha/core/decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,24 @@ def decorator(cluster_handler: _TypeT) -> _TypeT:
return decorator


class NestedDictRegistry(dict[int | str, dict[int | str | None, _TypeT]]):
"""Dict Registry of multiple items per key."""

def register(
self, name: int | str, sub_name: int | str | None = None
) -> Callable[[_TypeT], _TypeT]:
"""Return decorator to register item with a specific and a quirk name."""

def decorator(cluster_handler: _TypeT) -> _TypeT:
"""Register decorated cluster handler or item."""
if name not in self:
self[name] = {}
self[name][sub_name] = cluster_handler
return cluster_handler

return decorator


class SetRegistry(set[int | str]):
"""Set Registry of items."""

Expand Down
15 changes: 13 additions & 2 deletions homeassistant/components/zha/core/discovery.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,9 +203,20 @@ def handle_on_off_output_cluster_exception(self, endpoint: Endpoint) -> None:
if platform is None:
continue

cluster_handler_class = zha_regs.ZIGBEE_CLUSTER_HANDLER_REGISTRY.get(
cluster_id, ClusterHandler
cluster_handler_classes = zha_regs.ZIGBEE_CLUSTER_HANDLER_REGISTRY.get(
cluster_id, {None: ClusterHandler}
)

quirk_id = (
endpoint.device.quirk_id
if endpoint.device.quirk_id in cluster_handler_classes
else None
)

cluster_handler_class = cluster_handler_classes.get(
quirk_id, ClusterHandler
)

cluster_handler = cluster_handler_class(cluster, endpoint)
self.probe_single_cluster(platform, cluster_handler, endpoint)

Expand Down
23 changes: 10 additions & 13 deletions homeassistant/components/zha/core/endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
import logging
from typing import TYPE_CHECKING, Any, Final, TypeVar

import zigpy
from zigpy.typing import EndpointType as ZigpyEndpointType

from homeassistant.const import Platform
Expand All @@ -15,7 +14,6 @@

from . import const, discovery, registries
from .cluster_handlers import ClusterHandler
from .cluster_handlers.general import MultistateInput
from .helpers import get_zha_data

if TYPE_CHECKING:
Expand Down Expand Up @@ -116,8 +114,16 @@ def new(cls, zigpy_endpoint: ZigpyEndpointType, device: ZHADevice) -> Endpoint:
def add_all_cluster_handlers(self) -> None:
"""Create and add cluster handlers for all input clusters."""
for cluster_id, cluster in self.zigpy_endpoint.in_clusters.items():
cluster_handler_class = registries.ZIGBEE_CLUSTER_HANDLER_REGISTRY.get(
cluster_id, ClusterHandler
cluster_handler_classes = registries.ZIGBEE_CLUSTER_HANDLER_REGISTRY.get(
cluster_id, {None: ClusterHandler}
)
quirk_id = (
self.device.quirk_id
if self.device.quirk_id in cluster_handler_classes
else None
)
cluster_handler_class = cluster_handler_classes.get(
quirk_id, ClusterHandler
)

# Allow cluster handler to filter out bad matches
Expand All @@ -129,15 +135,6 @@ def add_all_cluster_handlers(self) -> None:
cluster_id,
cluster_handler_class,
)
# really ugly hack to deal with xiaomi using the door lock cluster
# incorrectly.
if (
hasattr(cluster, "ep_attribute")
and cluster_id == zigpy.zcl.clusters.closures.DoorLock.cluster_id
and cluster.ep_attribute == "multistate_input"
):
cluster_handler_class = MultistateInput
# end of ugly hack

try:
cluster_handler = cluster_handler_class(cluster, self)
Expand Down
6 changes: 4 additions & 2 deletions homeassistant/components/zha/core/registries.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

from homeassistant.const import Platform

from .decorators import DictRegistry, SetRegistry
from .decorators import DictRegistry, NestedDictRegistry, SetRegistry

if TYPE_CHECKING:
from ..entity import ZhaEntity, ZhaGroupEntity
Expand Down Expand Up @@ -110,7 +110,9 @@
CLIENT_CLUSTER_HANDLER_REGISTRY: DictRegistry[
type[ClientClusterHandler]
] = DictRegistry()
ZIGBEE_CLUSTER_HANDLER_REGISTRY: DictRegistry[type[ClusterHandler]] = DictRegistry()
ZIGBEE_CLUSTER_HANDLER_REGISTRY: NestedDictRegistry[
type[ClusterHandler]
] = NestedDictRegistry()

WEIGHT_ATTR = attrgetter("weight")

Expand Down
100 changes: 90 additions & 10 deletions tests/components/zha/test_cluster_handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from collections.abc import Callable
import logging
import math
from types import NoneType
from unittest import mock
from unittest.mock import AsyncMock, patch

Expand All @@ -17,6 +18,9 @@
import zigpy.zdo.types as zdo_t

import homeassistant.components.zha.core.cluster_handlers as cluster_handlers
from homeassistant.components.zha.core.cluster_handlers.lighting import (
ColorClusterHandler,
)
import homeassistant.components.zha.core.const as zha_const
from homeassistant.components.zha.core.device import ZHADevice
from homeassistant.components.zha.core.endpoint import Endpoint
Expand Down Expand Up @@ -97,7 +101,9 @@ def poll_control_ch(endpoint, zigpy_device_mock):
)

cluster = zigpy_dev.endpoints[1].in_clusters[cluster_id]
cluster_handler_class = registries.ZIGBEE_CLUSTER_HANDLER_REGISTRY.get(cluster_id)
cluster_handler_class = registries.ZIGBEE_CLUSTER_HANDLER_REGISTRY.get(
cluster_id
).get(None)
return cluster_handler_class(cluster, endpoint)


Expand Down Expand Up @@ -258,8 +264,8 @@ async def test_in_cluster_handler_config(

cluster = zigpy_dev.endpoints[1].in_clusters[cluster_id]
cluster_handler_class = registries.ZIGBEE_CLUSTER_HANDLER_REGISTRY.get(
cluster_id, cluster_handlers.ClusterHandler
)
cluster_id, {None, cluster_handlers.ClusterHandler}
).get(None)
cluster_handler = cluster_handler_class(cluster, endpoint)

await cluster_handler.async_configure()
Expand Down Expand Up @@ -322,8 +328,8 @@ async def test_out_cluster_handler_config(
cluster = zigpy_dev.endpoints[1].out_clusters[cluster_id]
cluster.bind_only = True
cluster_handler_class = registries.ZIGBEE_CLUSTER_HANDLER_REGISTRY.get(
cluster_id, cluster_handlers.ClusterHandler
)
cluster_id, {None: cluster_handlers.ClusterHandler}
).get(None)
cluster_handler = cluster_handler_class(cluster, endpoint)

await cluster_handler.async_configure()
Expand All @@ -336,11 +342,14 @@ def test_cluster_handler_registry() -> None:
"""Test ZIGBEE cluster handler Registry."""
for (
cluster_id,
cluster_handler,
cluster_handler_classes,
) in registries.ZIGBEE_CLUSTER_HANDLER_REGISTRY.items():
assert isinstance(cluster_id, int)
assert 0 <= cluster_id <= 0xFFFF
assert issubclass(cluster_handler, cluster_handlers.ClusterHandler)
assert isinstance(cluster_handler_classes, dict)
for quirk_id, cluster_handler in cluster_handler_classes.items():
assert isinstance(quirk_id, NoneType) or isinstance(quirk_id, str)
assert issubclass(cluster_handler, cluster_handlers.ClusterHandler)


def test_epch_unclaimed_cluster_handlers(cluster_handler) -> None:
Expand Down Expand Up @@ -818,7 +827,8 @@ class TestZigbeeClusterHandler(cluster_handlers.ClusterHandler):
],
)

mock_zha_device = mock.AsyncMock(spec_set=ZHADevice)
mock_zha_device = mock.AsyncMock(spec=ZHADevice)
mock_zha_device.quirk_id = None
zha_endpoint = Endpoint(zigpy_ep, mock_zha_device)

# The cluster handler throws an error when matching this cluster
Expand All @@ -827,14 +837,84 @@ class TestZigbeeClusterHandler(cluster_handlers.ClusterHandler):

# And one is also logged at runtime
with patch.dict(
registries.ZIGBEE_CLUSTER_HANDLER_REGISTRY,
{cluster.cluster_id: TestZigbeeClusterHandler},
registries.ZIGBEE_CLUSTER_HANDLER_REGISTRY[cluster.cluster_id],
{None: TestZigbeeClusterHandler},
), caplog.at_level(logging.WARNING):
zha_endpoint.add_all_cluster_handlers()

assert "missing_attr" in caplog.text


async def test_standard_cluster_handler(hass: HomeAssistant, caplog) -> None:
"""Test setting up a cluster handler that matches a standard cluster."""

class TestZigbeeClusterHandler(ColorClusterHandler):
pass

mock_device = mock.AsyncMock(spec_set=zigpy.device.Device)
zigpy_ep = zigpy.endpoint.Endpoint(mock_device, endpoint_id=1)

cluster = zigpy_ep.add_input_cluster(zigpy.zcl.clusters.lighting.Color.cluster_id)
cluster.configure_reporting_multiple = AsyncMock(
spec_set=cluster.configure_reporting_multiple,
return_value=[
foundation.ConfigureReportingResponseRecord(
status=foundation.Status.SUCCESS
)
],
)

mock_zha_device = mock.AsyncMock(spec=ZHADevice)
mock_zha_device.quirk_id = None
zha_endpoint = Endpoint(zigpy_ep, mock_zha_device)

with patch.dict(
registries.ZIGBEE_CLUSTER_HANDLER_REGISTRY[cluster.cluster_id],
{"__test_quirk_id": TestZigbeeClusterHandler},
):
zha_endpoint.add_all_cluster_handlers()

assert len(zha_endpoint.all_cluster_handlers) == 1
assert isinstance(
list(zha_endpoint.all_cluster_handlers.values())[0], ColorClusterHandler
)


async def test_quirk_id_cluster_handler(hass: HomeAssistant, caplog) -> None:
"""Test setting up a cluster handler that matches a standard cluster."""

class TestZigbeeClusterHandler(ColorClusterHandler):
pass

mock_device = mock.AsyncMock(spec_set=zigpy.device.Device)
zigpy_ep = zigpy.endpoint.Endpoint(mock_device, endpoint_id=1)

cluster = zigpy_ep.add_input_cluster(zigpy.zcl.clusters.lighting.Color.cluster_id)
cluster.configure_reporting_multiple = AsyncMock(
spec_set=cluster.configure_reporting_multiple,
return_value=[
foundation.ConfigureReportingResponseRecord(
status=foundation.Status.SUCCESS
)
],
)

mock_zha_device = mock.AsyncMock(spec=ZHADevice)
mock_zha_device.quirk_id = "__test_quirk_id"
zha_endpoint = Endpoint(zigpy_ep, mock_zha_device)

with patch.dict(
registries.ZIGBEE_CLUSTER_HANDLER_REGISTRY[cluster.cluster_id],
{"__test_quirk_id": TestZigbeeClusterHandler},
):
zha_endpoint.add_all_cluster_handlers()

assert len(zha_endpoint.all_cluster_handlers) == 1
assert isinstance(
list(zha_endpoint.all_cluster_handlers.values())[0], TestZigbeeClusterHandler
)


# parametrize side effects:
@pytest.mark.parametrize(
("side_effect", "expected_error"),
Expand Down
Loading