Skip to content

Commit

Permalink
Improve registry store data typing (home-assistant#115066)
Browse files Browse the repository at this point in the history
  • Loading branch information
cdce8p authored Apr 7, 2024
1 parent a093690 commit cb93521
Show file tree
Hide file tree
Showing 6 changed files with 104 additions and 54 deletions.
58 changes: 37 additions & 21 deletions homeassistant/helpers/area_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,24 @@
STORAGE_VERSION_MINOR = 6


class _AreaStoreData(TypedDict):
"""Data type for individual area. Used in AreasRegistryStoreData."""

aliases: list[str]
floor_id: str | None
icon: str | None
id: str
labels: list[str]
name: str
picture: str | None


class AreasRegistryStoreData(TypedDict):
"""Store data type for AreaRegistry."""

areas: list[_AreaStoreData]


class EventAreaRegistryUpdatedData(TypedDict):
"""EventAreaRegistryUpdated data."""

Expand All @@ -45,15 +63,15 @@ class AreaEntry(NormalizedNameBaseRegistryEntry):
picture: str | None


class AreaRegistryStore(Store[dict[str, list[dict[str, Any]]]]):
class AreaRegistryStore(Store[AreasRegistryStoreData]):
"""Store area registry data."""

async def _async_migrate_func(
self,
old_major_version: int,
old_minor_version: int,
old_data: dict[str, list[dict[str, Any]]],
) -> dict[str, Any]:
) -> AreasRegistryStoreData:
"""Migrate to the new version."""
if old_major_version < 2:
if old_minor_version < 2:
Expand Down Expand Up @@ -84,7 +102,7 @@ async def _async_migrate_func(

if old_major_version > 1:
raise NotImplementedError
return old_data
return old_data # type: ignore[return-value]


class AreaRegistryItems(NormalizedNameBaseRegistryItems[AreaEntry]):
Expand Down Expand Up @@ -126,7 +144,7 @@ def get_areas_for_floor(self, floor: str) -> list[AreaEntry]:
return [data[key] for key in self._floors_index.get(floor, ())]


class AreaRegistry(BaseRegistry):
class AreaRegistry(BaseRegistry[AreasRegistryStoreData]):
"""Class to hold a registry of areas."""

areas: AreaRegistryItems
Expand Down Expand Up @@ -314,24 +332,22 @@ async def async_load(self) -> None:
self._area_data = areas.data

@callback
def _data_to_save(self) -> dict[str, list[dict[str, Any]]]:
def _data_to_save(self) -> AreasRegistryStoreData:
"""Return data of area registry to store in a file."""
data = {}

data["areas"] = [
{
"aliases": list(entry.aliases),
"floor_id": entry.floor_id,
"icon": entry.icon,
"id": entry.id,
"labels": list(entry.labels),
"name": entry.name,
"picture": entry.picture,
}
for entry in self.areas.values()
]

return data
return {
"areas": [
{
"aliases": list(entry.aliases),
"floor_id": entry.floor_id,
"icon": entry.icon,
"id": entry.id,
"labels": list(entry.labels),
"name": entry.name,
"picture": entry.picture,
}
for entry in self.areas.values()
]
}

def _generate_area_id(self, name: str) -> str:
"""Generate area ID."""
Expand Down
20 changes: 17 additions & 3 deletions homeassistant/helpers/category_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,20 @@
STORAGE_VERSION_MAJOR = 1


class _CategoryStoreData(TypedDict):
"""Data type for individual category. Used in CategoryRegistryStoreData."""

category_id: str
icon: str | None
name: str


class CategoryRegistryStoreData(TypedDict):
"""Store data type for CategoryRegistry."""

categories: dict[str, list[_CategoryStoreData]]


class EventCategoryRegistryUpdatedData(TypedDict):
"""Event data for when the category registry is updated."""

Expand All @@ -40,14 +54,14 @@ class CategoryEntry:
name: str


class CategoryRegistry(BaseRegistry):
class CategoryRegistry(BaseRegistry[CategoryRegistryStoreData]):
"""Class to hold a registry of categories by scope."""

def __init__(self, hass: HomeAssistant) -> None:
"""Initialize the category registry."""
self.hass = hass
self.categories: dict[str, dict[str, CategoryEntry]] = {}
self._store: Store[dict[str, dict[str, list[dict[str, str]]]]] = Store(
self._store = Store(
hass,
STORAGE_VERSION_MAJOR,
STORAGE_KEY,
Expand Down Expand Up @@ -167,7 +181,7 @@ async def async_load(self) -> None:
self.categories = category_entries

@callback
def _data_to_save(self) -> dict[str, dict[str, list[dict[str, str | None]]]]:
def _data_to_save(self) -> CategoryRegistryStoreData:
"""Return data of category registry to store in a file."""
return {
"categories": {
Expand Down
2 changes: 1 addition & 1 deletion homeassistant/helpers/device_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -551,7 +551,7 @@ def get_devices_for_config_entry_id(
]


class DeviceRegistry(BaseRegistry):
class DeviceRegistry(BaseRegistry[dict[str, list[dict[str, Any]]]]):
"""Class to hold a registry of devices."""

devices: ActiveDeviceRegistryItems
Expand Down
41 changes: 24 additions & 17 deletions homeassistant/helpers/floor_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from collections.abc import Iterable
import dataclasses
from dataclasses import dataclass
from typing import TYPE_CHECKING, Literal, TypedDict, cast
from typing import Literal, TypedDict, cast

from homeassistant.core import Event, HomeAssistant, callback
from homeassistant.util import slugify
Expand All @@ -25,6 +25,22 @@
STORAGE_VERSION_MAJOR = 1


class _FloorStoreData(TypedDict):
"""Data type for individual floor. Used in FloorRegistryStoreData."""

aliases: list[str]
floor_id: str
icon: str | None
level: int | None
name: str


class FloorRegistryStoreData(TypedDict):
"""Store data type for FloorRegistry."""

floors: list[_FloorStoreData]


class EventFloorRegistryUpdatedData(TypedDict):
"""Event data for when the floor registry is updated."""

Expand All @@ -45,7 +61,7 @@ class FloorEntry(NormalizedNameBaseRegistryEntry):
level: int | None = None


class FloorRegistry(BaseRegistry):
class FloorRegistry(BaseRegistry[FloorRegistryStoreData]):
"""Class to hold a registry of floors."""

floors: NormalizedNameBaseRegistryItems[FloorEntry]
Expand All @@ -54,13 +70,11 @@ class FloorRegistry(BaseRegistry):
def __init__(self, hass: HomeAssistant) -> None:
"""Initialize the floor registry."""
self.hass = hass
self._store: Store[dict[str, list[dict[str, str | int | list[str] | None]]]] = (
Store(
hass,
STORAGE_VERSION_MAJOR,
STORAGE_KEY,
atomic_writes=True,
)
self._store = Store(
hass,
STORAGE_VERSION_MAJOR,
STORAGE_KEY,
atomic_writes=True,
)

@callback
Expand Down Expand Up @@ -190,13 +204,6 @@ async def async_load(self) -> None:

if data is not None:
for floor in data["floors"]:
if TYPE_CHECKING:
assert isinstance(floor["aliases"], list)
assert isinstance(floor["icon"], str)
assert isinstance(floor["level"], int)
assert isinstance(floor["name"], str)
assert isinstance(floor["floor_id"], str)

normalized_name = normalize_name(floor["name"])
floors[floor["floor_id"]] = FloorEntry(
aliases=set(floor["aliases"]),
Expand All @@ -211,7 +218,7 @@ async def async_load(self) -> None:
self._floor_data = floors.data

@callback
def _data_to_save(self) -> dict[str, list[dict[str, str | int | list[str] | None]]]:
def _data_to_save(self) -> FloorRegistryStoreData:
"""Return data of floor registry to store in a file."""
return {
"floors": [
Expand Down
26 changes: 19 additions & 7 deletions homeassistant/helpers/label_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,22 @@
STORAGE_VERSION_MAJOR = 1


class _LabelStoreData(TypedDict):
"""Data type for individual label. Used in LabelRegistryStoreData."""

color: str | None
description: str | None
icon: str | None
label_id: str
name: str


class LabelRegistryStoreData(TypedDict):
"""Store data type for LabelRegistry."""

labels: list[_LabelStoreData]


class EventLabelRegistryUpdatedData(TypedDict):
"""Event data for when the label registry is updated."""

Expand All @@ -45,7 +61,7 @@ class LabelEntry(NormalizedNameBaseRegistryEntry):
icon: str | None = None


class LabelRegistry(BaseRegistry):
class LabelRegistry(BaseRegistry[LabelRegistryStoreData]):
"""Class to hold a registry of labels."""

labels: NormalizedNameBaseRegistryItems[LabelEntry]
Expand All @@ -54,7 +70,7 @@ class LabelRegistry(BaseRegistry):
def __init__(self, hass: HomeAssistant) -> None:
"""Initialize the label registry."""
self.hass = hass
self._store: Store[dict[str, list[dict[str, str | None]]]] = Store(
self._store = Store(
hass,
STORAGE_VERSION_MAJOR,
STORAGE_KEY,
Expand Down Expand Up @@ -189,10 +205,6 @@ async def async_load(self) -> None:

if data is not None:
for label in data["labels"]:
# Check if the necessary keys are present
if label["label_id"] is None or label["name"] is None:
continue

normalized_name = normalize_name(label["name"])
labels[label["label_id"]] = LabelEntry(
color=label["color"],
Expand All @@ -207,7 +219,7 @@ async def async_load(self) -> None:
self._label_data = labels.data

@callback
def _data_to_save(self) -> dict[str, list[dict[str, str | None]]]:
def _data_to_save(self) -> LabelRegistryStoreData:
"""Return data of label registry to store in a file."""
return {
"labels": [
Expand Down
11 changes: 6 additions & 5 deletions homeassistant/helpers/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@

from abc import ABC, abstractmethod
from collections import UserDict
from collections.abc import ValuesView
from typing import TYPE_CHECKING, Any, Literal, TypeVar
from collections.abc import Mapping, Sequence, ValuesView
from typing import TYPE_CHECKING, Any, Generic, Literal, TypeVar

from homeassistant.core import CoreState, HomeAssistant, callback

Expand All @@ -17,6 +17,7 @@


_DataT = TypeVar("_DataT")
_StoreDataT = TypeVar("_StoreDataT", bound=Mapping[str, Any] | Sequence[Any])


class BaseRegistryItems(UserDict[str, _DataT], ABC):
Expand Down Expand Up @@ -64,11 +65,11 @@ def __delitem__(self, key: str) -> None:
super().__delitem__(key)


class BaseRegistry(ABC):
class BaseRegistry(ABC, Generic[_StoreDataT]):
"""Class to implement a registry."""

hass: HomeAssistant
_store: Store
_store: Store[_StoreDataT]

@callback
def async_schedule_save(self) -> None:
Expand All @@ -80,5 +81,5 @@ def async_schedule_save(self) -> None:

@callback
@abstractmethod
def _data_to_save(self) -> dict[str, Any]:
def _data_to_save(self) -> _StoreDataT:
"""Return data of registry to store in a file."""

0 comments on commit cb93521

Please sign in to comment.