Skip to content

Commit

Permalink
Merge pull request #3 from Caius-Bonus/MatchCustomCluster
Browse files Browse the repository at this point in the history
Match custom cluster
  • Loading branch information
Caius-Bonus authored Oct 26, 2023
2 parents 44e2691 + 38a0663 commit de2b435
Show file tree
Hide file tree
Showing 6 changed files with 143 additions and 27 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

from zhaquirks.inovelli.types import AllLEDEffectType, SingleLEDEffectType
import zigpy.zcl
from zigpy.zcl.clusters.closures import DoorLock
from zigpy.zcl import clusters

from homeassistant.core import callback
Expand All @@ -25,6 +26,7 @@
UNKNOWN,
)
from . import AttrReportConfig, ClientClusterHandler, ClusterHandler
from .general import MultistateInput
from .homeautomation import Diagnostic
from .hvac import ThermostatClusterHandler, UserInterface

Expand Down Expand Up @@ -381,6 +383,12 @@ class IkeaRemote(ClusterHandler):
REPORT_CONFIG = ()


@registries.ZIGBEE_CLUSTER_HANDLER_REGISTRY.register(
DoorLock.cluster_id, "xiaomi_aqara_vibration_aq1"
)
class XiaomiVibrationAQ1ClusterHandler(MultistateInput):
"""Xiaomi DoorLock Cluster is in fact a MultiStateInput Cluster."""

def compare_quirk_class(endpoint: Endpoint, names: str | Collection[str]):
"""Return True if the last two words separated by dots equal the words between the dots in name.
Expand Down
18 changes: 18 additions & 0 deletions homeassistant/components/zha/core/decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,24 @@ def decorator(cluster_handler: _TypeT) -> _TypeT:
return decorator


class NestedDictRegistry(dict[int | str, dict[int | str | None, _TypeT]]):
"""Dict Registry of multiple items per key."""

def register(
self, name: int | str, sub_name: int | str | None = None
) -> Callable[[_TypeT], _TypeT]:
"""Return decorator to register item with a specific and a quirk name."""

def decorator(cluster_handler: _TypeT) -> _TypeT:
"""Register decorated cluster handler or item."""
if name not in self:
self[name] = {}
self[name][sub_name] = cluster_handler
return cluster_handler

return decorator


class SetRegistry(set[int | str]):
"""Set Registry of items."""

Expand Down
15 changes: 13 additions & 2 deletions homeassistant/components/zha/core/discovery.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,9 +203,20 @@ def handle_on_off_output_cluster_exception(self, endpoint: Endpoint) -> None:
if platform is None:
continue

cluster_handler_class = zha_regs.ZIGBEE_CLUSTER_HANDLER_REGISTRY.get(
cluster_id, ClusterHandler
cluster_handler_classes = zha_regs.ZIGBEE_CLUSTER_HANDLER_REGISTRY.get(
cluster_id, {None: ClusterHandler}
)

quirk_id = (
endpoint.device.quirk_id
if endpoint.device.quirk_id in cluster_handler_classes
else None
)

cluster_handler_class = cluster_handler_classes.get(
quirk_id, ClusterHandler
)

cluster_handler = cluster_handler_class(cluster, endpoint)
self.probe_single_cluster(platform, cluster_handler, endpoint)

Expand Down
23 changes: 10 additions & 13 deletions homeassistant/components/zha/core/endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
import logging
from typing import TYPE_CHECKING, Any, Final, TypeVar

import zigpy
from zigpy.typing import EndpointType as ZigpyEndpointType

from homeassistant.const import Platform
Expand All @@ -15,7 +14,6 @@

from . import const, discovery, registries
from .cluster_handlers import ClusterHandler
from .cluster_handlers.general import MultistateInput
from .helpers import get_zha_data

if TYPE_CHECKING:
Expand Down Expand Up @@ -116,8 +114,16 @@ def new(cls, zigpy_endpoint: ZigpyEndpointType, device: ZHADevice) -> Endpoint:
def add_all_cluster_handlers(self) -> None:
"""Create and add cluster handlers for all input clusters."""
for cluster_id, cluster in self.zigpy_endpoint.in_clusters.items():
cluster_handler_class = registries.ZIGBEE_CLUSTER_HANDLER_REGISTRY.get(
cluster_id, ClusterHandler
cluster_handler_classes = registries.ZIGBEE_CLUSTER_HANDLER_REGISTRY.get(
cluster_id, {None: ClusterHandler}
)
quirk_id = (
self.device.quirk_id
if self.device.quirk_id in cluster_handler_classes
else None
)
cluster_handler_class = cluster_handler_classes.get(
quirk_id, ClusterHandler
)

# Allow cluster handler to filter out bad matches
Expand All @@ -129,15 +135,6 @@ def add_all_cluster_handlers(self) -> None:
cluster_id,
cluster_handler_class,
)
# really ugly hack to deal with xiaomi using the door lock cluster
# incorrectly.
if (
hasattr(cluster, "ep_attribute")
and cluster_id == zigpy.zcl.clusters.closures.DoorLock.cluster_id
and cluster.ep_attribute == "multistate_input"
):
cluster_handler_class = MultistateInput
# end of ugly hack

try:
cluster_handler = cluster_handler_class(cluster, self)
Expand Down
6 changes: 4 additions & 2 deletions homeassistant/components/zha/core/registries.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

from homeassistant.const import Platform

from .decorators import DictRegistry, SetRegistry
from .decorators import DictRegistry, NestedDictRegistry, SetRegistry

if TYPE_CHECKING:
from ..entity import ZhaEntity, ZhaGroupEntity
Expand Down Expand Up @@ -110,7 +110,9 @@
CLIENT_CLUSTER_HANDLER_REGISTRY: DictRegistry[
type[ClientClusterHandler]
] = DictRegistry()
ZIGBEE_CLUSTER_HANDLER_REGISTRY: DictRegistry[type[ClusterHandler]] = DictRegistry()
ZIGBEE_CLUSTER_HANDLER_REGISTRY: NestedDictRegistry[
type[ClusterHandler]
] = NestedDictRegistry()

WEIGHT_ATTR = attrgetter("weight")

Expand Down
100 changes: 90 additions & 10 deletions tests/components/zha/test_cluster_handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from collections.abc import Callable
import logging
import math
from types import NoneType
from unittest import mock
from unittest.mock import AsyncMock, patch

Expand All @@ -17,6 +18,9 @@
import zigpy.zdo.types as zdo_t

import homeassistant.components.zha.core.cluster_handlers as cluster_handlers
from homeassistant.components.zha.core.cluster_handlers.lighting import (
ColorClusterHandler,
)
import homeassistant.components.zha.core.const as zha_const
from homeassistant.components.zha.core.device import ZHADevice
from homeassistant.components.zha.core.endpoint import Endpoint
Expand Down Expand Up @@ -97,7 +101,9 @@ def poll_control_ch(endpoint, zigpy_device_mock):
)

cluster = zigpy_dev.endpoints[1].in_clusters[cluster_id]
cluster_handler_class = registries.ZIGBEE_CLUSTER_HANDLER_REGISTRY.get(cluster_id)
cluster_handler_class = registries.ZIGBEE_CLUSTER_HANDLER_REGISTRY.get(
cluster_id
).get(None)
return cluster_handler_class(cluster, endpoint)


Expand Down Expand Up @@ -258,8 +264,8 @@ async def test_in_cluster_handler_config(

cluster = zigpy_dev.endpoints[1].in_clusters[cluster_id]
cluster_handler_class = registries.ZIGBEE_CLUSTER_HANDLER_REGISTRY.get(
cluster_id, cluster_handlers.ClusterHandler
)
cluster_id, {None, cluster_handlers.ClusterHandler}
).get(None)
cluster_handler = cluster_handler_class(cluster, endpoint)

await cluster_handler.async_configure()
Expand Down Expand Up @@ -322,8 +328,8 @@ async def test_out_cluster_handler_config(
cluster = zigpy_dev.endpoints[1].out_clusters[cluster_id]
cluster.bind_only = True
cluster_handler_class = registries.ZIGBEE_CLUSTER_HANDLER_REGISTRY.get(
cluster_id, cluster_handlers.ClusterHandler
)
cluster_id, {None: cluster_handlers.ClusterHandler}
).get(None)
cluster_handler = cluster_handler_class(cluster, endpoint)

await cluster_handler.async_configure()
Expand All @@ -336,11 +342,14 @@ def test_cluster_handler_registry() -> None:
"""Test ZIGBEE cluster handler Registry."""
for (
cluster_id,
cluster_handler,
cluster_handler_classes,
) in registries.ZIGBEE_CLUSTER_HANDLER_REGISTRY.items():
assert isinstance(cluster_id, int)
assert 0 <= cluster_id <= 0xFFFF
assert issubclass(cluster_handler, cluster_handlers.ClusterHandler)
assert isinstance(cluster_handler_classes, dict)
for quirk_id, cluster_handler in cluster_handler_classes.items():
assert isinstance(quirk_id, NoneType) or isinstance(quirk_id, str)
assert issubclass(cluster_handler, cluster_handlers.ClusterHandler)


def test_epch_unclaimed_cluster_handlers(cluster_handler) -> None:
Expand Down Expand Up @@ -818,7 +827,8 @@ class TestZigbeeClusterHandler(cluster_handlers.ClusterHandler):
],
)

mock_zha_device = mock.AsyncMock(spec_set=ZHADevice)
mock_zha_device = mock.AsyncMock(spec=ZHADevice)
mock_zha_device.quirk_id = None
zha_endpoint = Endpoint(zigpy_ep, mock_zha_device)

# The cluster handler throws an error when matching this cluster
Expand All @@ -827,14 +837,84 @@ class TestZigbeeClusterHandler(cluster_handlers.ClusterHandler):

# And one is also logged at runtime
with patch.dict(
registries.ZIGBEE_CLUSTER_HANDLER_REGISTRY,
{cluster.cluster_id: TestZigbeeClusterHandler},
registries.ZIGBEE_CLUSTER_HANDLER_REGISTRY[cluster.cluster_id],
{None: TestZigbeeClusterHandler},
), caplog.at_level(logging.WARNING):
zha_endpoint.add_all_cluster_handlers()

assert "missing_attr" in caplog.text


async def test_standard_cluster_handler(hass: HomeAssistant, caplog) -> None:
"""Test setting up a cluster handler that matches a standard cluster."""

class TestZigbeeClusterHandler(ColorClusterHandler):
pass

mock_device = mock.AsyncMock(spec_set=zigpy.device.Device)
zigpy_ep = zigpy.endpoint.Endpoint(mock_device, endpoint_id=1)

cluster = zigpy_ep.add_input_cluster(zigpy.zcl.clusters.lighting.Color.cluster_id)
cluster.configure_reporting_multiple = AsyncMock(
spec_set=cluster.configure_reporting_multiple,
return_value=[
foundation.ConfigureReportingResponseRecord(
status=foundation.Status.SUCCESS
)
],
)

mock_zha_device = mock.AsyncMock(spec=ZHADevice)
mock_zha_device.quirk_id = None
zha_endpoint = Endpoint(zigpy_ep, mock_zha_device)

with patch.dict(
registries.ZIGBEE_CLUSTER_HANDLER_REGISTRY[cluster.cluster_id],
{"__test_quirk_id": TestZigbeeClusterHandler},
):
zha_endpoint.add_all_cluster_handlers()

assert len(zha_endpoint.all_cluster_handlers) == 1
assert isinstance(
list(zha_endpoint.all_cluster_handlers.values())[0], ColorClusterHandler
)


async def test_quirk_id_cluster_handler(hass: HomeAssistant, caplog) -> None:
"""Test setting up a cluster handler that matches a standard cluster."""

class TestZigbeeClusterHandler(ColorClusterHandler):
pass

mock_device = mock.AsyncMock(spec_set=zigpy.device.Device)
zigpy_ep = zigpy.endpoint.Endpoint(mock_device, endpoint_id=1)

cluster = zigpy_ep.add_input_cluster(zigpy.zcl.clusters.lighting.Color.cluster_id)
cluster.configure_reporting_multiple = AsyncMock(
spec_set=cluster.configure_reporting_multiple,
return_value=[
foundation.ConfigureReportingResponseRecord(
status=foundation.Status.SUCCESS
)
],
)

mock_zha_device = mock.AsyncMock(spec=ZHADevice)
mock_zha_device.quirk_id = "__test_quirk_id"
zha_endpoint = Endpoint(zigpy_ep, mock_zha_device)

with patch.dict(
registries.ZIGBEE_CLUSTER_HANDLER_REGISTRY[cluster.cluster_id],
{"__test_quirk_id": TestZigbeeClusterHandler},
):
zha_endpoint.add_all_cluster_handlers()

assert len(zha_endpoint.all_cluster_handlers) == 1
assert isinstance(
list(zha_endpoint.all_cluster_handlers.values())[0], TestZigbeeClusterHandler
)


# parametrize side effects:
@pytest.mark.parametrize(
("side_effect", "expected_error"),
Expand Down

0 comments on commit de2b435

Please sign in to comment.