Skip to content

Commit

Permalink
Fix saving subclassed datetime objects in storage (home-assistant#97502)
Browse files Browse the repository at this point in the history
  • Loading branch information
bdraco authored Jul 31, 2023
1 parent c2e9fd8 commit 094f2cb
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 2 deletions.
2 changes: 2 additions & 0 deletions homeassistant/helpers/json.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,8 @@ def json_encoder_default(obj: Any) -> Any:
return obj.as_dict()
if isinstance(obj, Path):
return obj.as_posix()
if isinstance(obj, datetime.datetime):
return obj.isoformat()
raise TypeError


Expand Down
11 changes: 9 additions & 2 deletions tests/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@
storage,
)
from homeassistant.helpers.dispatcher import async_dispatcher_connect
from homeassistant.helpers.json import JSONEncoder
from homeassistant.helpers.json import JSONEncoder, _orjson_default_encoder
from homeassistant.helpers.typing import ConfigType, StateType
from homeassistant.setup import setup_component
from homeassistant.util.async_ import run_callback_threadsafe
Expand Down Expand Up @@ -1260,7 +1260,14 @@ async def mock_write_data(
# To ensure that the data can be serialized
_LOGGER.debug("Writing data to %s: %s", store.key, data_to_write)
raise_contains_mocks(data_to_write)
data[store.key] = json.loads(json.dumps(data_to_write, cls=store._encoder))
encoder = store._encoder
if encoder and encoder is not JSONEncoder:
# If they pass a custom encoder that is not the
# default JSONEncoder, we use the slow path of json.dumps
dump = ft.partial(json.dumps, cls=store._encoder)
else:
dump = _orjson_default_encoder
data[store.key] = json.loads(dump(data_to_write))

async def mock_remove(store: storage.Store) -> None:
"""Remove data."""
Expand Down
14 changes: 14 additions & 0 deletions tests/helpers/test_json.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,6 +215,20 @@ def default(self, o):
assert data == "9"


def test_saving_subclassed_datetime(tmp_path: Path) -> None:
"""Test saving subclassed datetime objects."""

class SubClassDateTime(datetime.datetime):
"""Subclass datetime."""

time = SubClassDateTime.fromtimestamp(0)

fname = tmp_path / "test6.json"
save_json(fname, {"time": time})
data = load_json(fname)
assert data == {"time": time.isoformat()}


def test_default_encoder_is_passed(tmp_path: Path) -> None:
"""Test we use orjson if they pass in the default encoder."""
fname = tmp_path / "test6.json"
Expand Down

0 comments on commit 094f2cb

Please sign in to comment.