diff --git a/homeassistant/components/zha/core/cluster_handlers/manufacturerspecific.py b/homeassistant/components/zha/core/cluster_handlers/manufacturerspecific.py index bda2834fbd225e..167e502751dc43 100644 --- a/homeassistant/components/zha/core/cluster_handlers/manufacturerspecific.py +++ b/homeassistant/components/zha/core/cluster_handlers/manufacturerspecific.py @@ -7,6 +7,7 @@ from zhaquirks.inovelli.types import AllLEDEffectType, SingleLEDEffectType import zigpy.zcl +from zigpy.zcl.clusters.closures import DoorLock from zigpy.zcl import clusters from homeassistant.core import callback @@ -25,6 +26,7 @@ UNKNOWN, ) from . import AttrReportConfig, ClientClusterHandler, ClusterHandler +from .general import MultistateInput from .homeautomation import Diagnostic from .hvac import ThermostatClusterHandler, UserInterface @@ -381,6 +383,12 @@ class IkeaRemote(ClusterHandler): REPORT_CONFIG = () +@registries.ZIGBEE_CLUSTER_HANDLER_REGISTRY.register( + DoorLock.cluster_id, "xiaomi_aqara_vibration_aq1" +) +class XiaomiVibrationAQ1ClusterHandler(MultistateInput): + """Xiaomi DoorLock Cluster is in fact a MultiStateInput Cluster.""" + def compare_quirk_class(endpoint: Endpoint, names: str | Collection[str]): """Return True if the last two words separated by dots equal the words between the dots in name. diff --git a/homeassistant/components/zha/core/decorators.py b/homeassistant/components/zha/core/decorators.py index 71bfd510bea09e..192f68489895ef 100644 --- a/homeassistant/components/zha/core/decorators.py +++ b/homeassistant/components/zha/core/decorators.py @@ -21,6 +21,24 @@ def decorator(cluster_handler: _TypeT) -> _TypeT: return decorator +class NestedDictRegistry(dict[int | str, dict[int | str | None, _TypeT]]): + """Dict Registry of multiple items per key.""" + + def register( + self, name: int | str, sub_name: int | str | None = None + ) -> Callable[[_TypeT], _TypeT]: + """Return decorator to register item with a specific and a quirk name.""" + + def decorator(cluster_handler: _TypeT) -> _TypeT: + """Register decorated cluster handler or item.""" + if name not in self: + self[name] = {} + self[name][sub_name] = cluster_handler + return cluster_handler + + return decorator + + class SetRegistry(set[int | str]): """Set Registry of items.""" diff --git a/homeassistant/components/zha/core/discovery.py b/homeassistant/components/zha/core/discovery.py index 90ed68f9b00da4..1944f632e9a699 100644 --- a/homeassistant/components/zha/core/discovery.py +++ b/homeassistant/components/zha/core/discovery.py @@ -203,9 +203,20 @@ def handle_on_off_output_cluster_exception(self, endpoint: Endpoint) -> None: if platform is None: continue - cluster_handler_class = zha_regs.ZIGBEE_CLUSTER_HANDLER_REGISTRY.get( - cluster_id, ClusterHandler + cluster_handler_classes = zha_regs.ZIGBEE_CLUSTER_HANDLER_REGISTRY.get( + cluster_id, {None: ClusterHandler} ) + + quirk_id = ( + endpoint.device.quirk_id + if endpoint.device.quirk_id in cluster_handler_classes + else None + ) + + cluster_handler_class = cluster_handler_classes.get( + quirk_id, ClusterHandler + ) + cluster_handler = cluster_handler_class(cluster, endpoint) self.probe_single_cluster(platform, cluster_handler, endpoint) diff --git a/homeassistant/components/zha/core/endpoint.py b/homeassistant/components/zha/core/endpoint.py index c87ee60d6b30d6..04c253128ee5b1 100644 --- a/homeassistant/components/zha/core/endpoint.py +++ b/homeassistant/components/zha/core/endpoint.py @@ -6,7 +6,6 @@ import logging from typing import TYPE_CHECKING, Any, Final, TypeVar -import zigpy from zigpy.typing import EndpointType as ZigpyEndpointType from homeassistant.const import Platform @@ -15,7 +14,6 @@ from . import const, discovery, registries from .cluster_handlers import ClusterHandler -from .cluster_handlers.general import MultistateInput from .helpers import get_zha_data if TYPE_CHECKING: @@ -116,8 +114,16 @@ def new(cls, zigpy_endpoint: ZigpyEndpointType, device: ZHADevice) -> Endpoint: def add_all_cluster_handlers(self) -> None: """Create and add cluster handlers for all input clusters.""" for cluster_id, cluster in self.zigpy_endpoint.in_clusters.items(): - cluster_handler_class = registries.ZIGBEE_CLUSTER_HANDLER_REGISTRY.get( - cluster_id, ClusterHandler + cluster_handler_classes = registries.ZIGBEE_CLUSTER_HANDLER_REGISTRY.get( + cluster_id, {None: ClusterHandler} + ) + quirk_id = ( + self.device.quirk_id + if self.device.quirk_id in cluster_handler_classes + else None + ) + cluster_handler_class = cluster_handler_classes.get( + quirk_id, ClusterHandler ) # Allow cluster handler to filter out bad matches @@ -129,15 +135,6 @@ def add_all_cluster_handlers(self) -> None: cluster_id, cluster_handler_class, ) - # really ugly hack to deal with xiaomi using the door lock cluster - # incorrectly. - if ( - hasattr(cluster, "ep_attribute") - and cluster_id == zigpy.zcl.clusters.closures.DoorLock.cluster_id - and cluster.ep_attribute == "multistate_input" - ): - cluster_handler_class = MultistateInput - # end of ugly hack try: cluster_handler = cluster_handler_class(cluster, self) diff --git a/homeassistant/components/zha/core/registries.py b/homeassistant/components/zha/core/registries.py index 4bdedebfff9ff0..418fdbc4918af6 100644 --- a/homeassistant/components/zha/core/registries.py +++ b/homeassistant/components/zha/core/registries.py @@ -15,7 +15,7 @@ from homeassistant.const import Platform -from .decorators import DictRegistry, SetRegistry +from .decorators import DictRegistry, NestedDictRegistry, SetRegistry if TYPE_CHECKING: from ..entity import ZhaEntity, ZhaGroupEntity @@ -110,7 +110,9 @@ CLIENT_CLUSTER_HANDLER_REGISTRY: DictRegistry[ type[ClientClusterHandler] ] = DictRegistry() -ZIGBEE_CLUSTER_HANDLER_REGISTRY: DictRegistry[type[ClusterHandler]] = DictRegistry() +ZIGBEE_CLUSTER_HANDLER_REGISTRY: NestedDictRegistry[ + type[ClusterHandler] +] = NestedDictRegistry() WEIGHT_ATTR = attrgetter("weight") diff --git a/tests/components/zha/test_cluster_handlers.py b/tests/components/zha/test_cluster_handlers.py index 24162296cd504a..d705fff76a6f7a 100644 --- a/tests/components/zha/test_cluster_handlers.py +++ b/tests/components/zha/test_cluster_handlers.py @@ -3,6 +3,7 @@ from collections.abc import Callable import logging import math +from types import NoneType from unittest import mock from unittest.mock import AsyncMock, patch @@ -17,6 +18,9 @@ import zigpy.zdo.types as zdo_t import homeassistant.components.zha.core.cluster_handlers as cluster_handlers +from homeassistant.components.zha.core.cluster_handlers.lighting import ( + ColorClusterHandler, +) import homeassistant.components.zha.core.const as zha_const from homeassistant.components.zha.core.device import ZHADevice from homeassistant.components.zha.core.endpoint import Endpoint @@ -97,7 +101,9 @@ def poll_control_ch(endpoint, zigpy_device_mock): ) cluster = zigpy_dev.endpoints[1].in_clusters[cluster_id] - cluster_handler_class = registries.ZIGBEE_CLUSTER_HANDLER_REGISTRY.get(cluster_id) + cluster_handler_class = registries.ZIGBEE_CLUSTER_HANDLER_REGISTRY.get( + cluster_id + ).get(None) return cluster_handler_class(cluster, endpoint) @@ -258,8 +264,8 @@ async def test_in_cluster_handler_config( cluster = zigpy_dev.endpoints[1].in_clusters[cluster_id] cluster_handler_class = registries.ZIGBEE_CLUSTER_HANDLER_REGISTRY.get( - cluster_id, cluster_handlers.ClusterHandler - ) + cluster_id, {None, cluster_handlers.ClusterHandler} + ).get(None) cluster_handler = cluster_handler_class(cluster, endpoint) await cluster_handler.async_configure() @@ -322,8 +328,8 @@ async def test_out_cluster_handler_config( cluster = zigpy_dev.endpoints[1].out_clusters[cluster_id] cluster.bind_only = True cluster_handler_class = registries.ZIGBEE_CLUSTER_HANDLER_REGISTRY.get( - cluster_id, cluster_handlers.ClusterHandler - ) + cluster_id, {None: cluster_handlers.ClusterHandler} + ).get(None) cluster_handler = cluster_handler_class(cluster, endpoint) await cluster_handler.async_configure() @@ -336,11 +342,14 @@ def test_cluster_handler_registry() -> None: """Test ZIGBEE cluster handler Registry.""" for ( cluster_id, - cluster_handler, + cluster_handler_classes, ) in registries.ZIGBEE_CLUSTER_HANDLER_REGISTRY.items(): assert isinstance(cluster_id, int) assert 0 <= cluster_id <= 0xFFFF - assert issubclass(cluster_handler, cluster_handlers.ClusterHandler) + assert isinstance(cluster_handler_classes, dict) + for quirk_id, cluster_handler in cluster_handler_classes.items(): + assert isinstance(quirk_id, NoneType) or isinstance(quirk_id, str) + assert issubclass(cluster_handler, cluster_handlers.ClusterHandler) def test_epch_unclaimed_cluster_handlers(cluster_handler) -> None: @@ -818,7 +827,8 @@ class TestZigbeeClusterHandler(cluster_handlers.ClusterHandler): ], ) - mock_zha_device = mock.AsyncMock(spec_set=ZHADevice) + mock_zha_device = mock.AsyncMock(spec=ZHADevice) + mock_zha_device.quirk_id = None zha_endpoint = Endpoint(zigpy_ep, mock_zha_device) # The cluster handler throws an error when matching this cluster @@ -827,14 +837,84 @@ class TestZigbeeClusterHandler(cluster_handlers.ClusterHandler): # And one is also logged at runtime with patch.dict( - registries.ZIGBEE_CLUSTER_HANDLER_REGISTRY, - {cluster.cluster_id: TestZigbeeClusterHandler}, + registries.ZIGBEE_CLUSTER_HANDLER_REGISTRY[cluster.cluster_id], + {None: TestZigbeeClusterHandler}, ), caplog.at_level(logging.WARNING): zha_endpoint.add_all_cluster_handlers() assert "missing_attr" in caplog.text +async def test_standard_cluster_handler(hass: HomeAssistant, caplog) -> None: + """Test setting up a cluster handler that matches a standard cluster.""" + + class TestZigbeeClusterHandler(ColorClusterHandler): + pass + + mock_device = mock.AsyncMock(spec_set=zigpy.device.Device) + zigpy_ep = zigpy.endpoint.Endpoint(mock_device, endpoint_id=1) + + cluster = zigpy_ep.add_input_cluster(zigpy.zcl.clusters.lighting.Color.cluster_id) + cluster.configure_reporting_multiple = AsyncMock( + spec_set=cluster.configure_reporting_multiple, + return_value=[ + foundation.ConfigureReportingResponseRecord( + status=foundation.Status.SUCCESS + ) + ], + ) + + mock_zha_device = mock.AsyncMock(spec=ZHADevice) + mock_zha_device.quirk_id = None + zha_endpoint = Endpoint(zigpy_ep, mock_zha_device) + + with patch.dict( + registries.ZIGBEE_CLUSTER_HANDLER_REGISTRY[cluster.cluster_id], + {"__test_quirk_id": TestZigbeeClusterHandler}, + ): + zha_endpoint.add_all_cluster_handlers() + + assert len(zha_endpoint.all_cluster_handlers) == 1 + assert isinstance( + list(zha_endpoint.all_cluster_handlers.values())[0], ColorClusterHandler + ) + + +async def test_quirk_id_cluster_handler(hass: HomeAssistant, caplog) -> None: + """Test setting up a cluster handler that matches a standard cluster.""" + + class TestZigbeeClusterHandler(ColorClusterHandler): + pass + + mock_device = mock.AsyncMock(spec_set=zigpy.device.Device) + zigpy_ep = zigpy.endpoint.Endpoint(mock_device, endpoint_id=1) + + cluster = zigpy_ep.add_input_cluster(zigpy.zcl.clusters.lighting.Color.cluster_id) + cluster.configure_reporting_multiple = AsyncMock( + spec_set=cluster.configure_reporting_multiple, + return_value=[ + foundation.ConfigureReportingResponseRecord( + status=foundation.Status.SUCCESS + ) + ], + ) + + mock_zha_device = mock.AsyncMock(spec=ZHADevice) + mock_zha_device.quirk_id = "__test_quirk_id" + zha_endpoint = Endpoint(zigpy_ep, mock_zha_device) + + with patch.dict( + registries.ZIGBEE_CLUSTER_HANDLER_REGISTRY[cluster.cluster_id], + {"__test_quirk_id": TestZigbeeClusterHandler}, + ): + zha_endpoint.add_all_cluster_handlers() + + assert len(zha_endpoint.all_cluster_handlers) == 1 + assert isinstance( + list(zha_endpoint.all_cluster_handlers.values())[0], TestZigbeeClusterHandler + ) + + # parametrize side effects: @pytest.mark.parametrize( ("side_effect", "expected_error"),