Skip to content

Commit

Permalink
JSON serialization support for numpy integer types (#617)
Browse files Browse the repository at this point in the history
* Ignore __venv__

* Fix serialization compatiblity for numpy types

* Update pre-commit config

* Ignore .mypy_cache
  • Loading branch information
HGSilveri authored Nov 27, 2023
1 parent 30b8450 commit b41d62a
Show file tree
Hide file tree
Showing 7 changed files with 37 additions and 5 deletions.
2 changes: 1 addition & 1 deletion .flake8
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[flake8]
docstring-convention = google
exclude = ./build, ./docs
exclude = ./build, ./docs, ./__venv__
extend-ignore =
# D105 Missing docstring in magic method
D105,
Expand Down
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
.vscode
.python-version
.pytest_cache/
.mypy_cache/
.idea/
.coverage
.spyproject/
Expand All @@ -15,3 +16,4 @@ docs/build/
dist/
env*
*.egg-info/
__venv__/
4 changes: 2 additions & 2 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
repos:
- repo: https://github.com/psf/black
rev: 22.1.0
rev: 23.10.1
hooks:
- id: black-jupyter

Expand All @@ -10,7 +10,7 @@ repos:
- id: flake8

- repo: https://github.com/pycqa/isort
rev: 5.10.1
rev: 5.12.0
hooks:
- id: isort
name: isort (python)
Expand Down
4 changes: 3 additions & 1 deletion pulser-core/pulser/json/abstract_repr/serializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,12 +38,14 @@
class AbstractReprEncoder(json.JSONEncoder):
"""The custom encoder for abstract representation of Pulser objects."""

def default(self, o: Any) -> Union[dict[str, Any], list[Any]]:
def default(self, o: Any) -> dict[str, Any] | list | int:
"""Handles JSON encoding of objects not supported by default."""
if hasattr(o, "_to_abstract_repr"):
return cast(dict, o._to_abstract_repr())
elif isinstance(o, np.ndarray):
return cast(list, o.tolist())
elif isinstance(o, np.integer):
return int(o)
elif isinstance(o, set):
return list(o)
else:
Expand Down
4 changes: 3 additions & 1 deletion pulser-core/pulser/json/coders.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,14 +30,16 @@
class PulserEncoder(JSONEncoder):
"""The custom encoder for Pulser objects."""

def default(self, o: Any) -> dict[str, Any]:
def default(self, o: Any) -> dict[str, Any] | int:
"""Handles JSON encoding of objects not supported by default."""
if hasattr(o, "_to_dict"):
return cast(dict, o._to_dict())
elif type(o) is type:
return obj_to_dict(o, _build=False, _name=o.__name__)
elif isinstance(o, np.ndarray):
return obj_to_dict(o, o.tolist(), _name="array")
elif isinstance(o, np.integer):
return int(o)
elif isinstance(o, set):
return obj_to_dict(o, list(o))
else:
Expand Down
20 changes: 20 additions & 0 deletions tests/test_abstract_repr.py
Original file line number Diff line number Diff line change
Expand Up @@ -939,6 +939,26 @@ def test_multi_qubit_target(self):
"rhs": 2,
}

def test_numpy_types(self):
assert (
json.loads(
json.dumps(np.array([12345])[0], cls=AbstractReprEncoder)
)
== 12345
)
assert (
json.loads(
json.dumps(np.array([np.pi])[0], cls=AbstractReprEncoder)
)
== np.pi
)
assert (
json.loads(
json.dumps(np.array(["abc"])[0], cls=AbstractReprEncoder)
)
== "abc"
)


def _get_serialized_seq(
operations: list[dict] = [],
Expand Down
6 changes: 6 additions & 0 deletions tests/test_json.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,6 +243,12 @@ def test_type_error():
Sequence._deserialize(json.loads(s))


def test_numpy_types():
assert encode_decode(np.array([12])[0]) == 12
assert encode_decode(np.array([np.pi])[0]) == np.pi
assert encode_decode(np.array(["abc"])[0]) == "abc"


def test_deprecated_device_args():
seq = Sequence(Register.square(1), MockDevice)

Expand Down

0 comments on commit b41d62a

Please sign in to comment.