Skip to content

Commit

Permalink
FIX: Deserialization of sized variable items (#653)
Browse files Browse the repository at this point in the history
* FIX: Deserialization of sized variable items

* Bump version to 0.17.2
  • Loading branch information
HGSilveri authored Mar 1, 2024
1 parent b50a378 commit 83894fe
Show file tree
Hide file tree
Showing 4 changed files with 96 additions and 12 deletions.
2 changes: 1 addition & 1 deletion VERSION.txt
Original file line number Diff line number Diff line change
@@ -1 +1 @@
0.17.1
0.17.2
52 changes: 43 additions & 9 deletions pulser-core/pulser/parametrized/variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

from __future__ import annotations

import collections.abc # To use collections.abc.Sequence
import collections.abc as abc # To use collections.abc.Sequence
import dataclasses
from typing import Any, Iterator, Optional, Union, cast

Expand Down Expand Up @@ -99,15 +99,37 @@ def _to_abstract_repr(self) -> dict[str, str]:
def __str__(self) -> str:
return self.name

def __getitem__(self, key: Union[int, slice]) -> VariableItem:
if not isinstance(key, (int, slice)):
def __getitem__(
self, key: Union[int, slice, abc.Sequence[int]]
) -> VariableItem:
if not isinstance(key, (int, slice, abc.Sequence)):
raise TypeError(f"Invalid key type {type(key)} for '{self.name}'.")
if isinstance(key, int):
if not -self.size <= key < self.size:
raise IndexError(f"{key} outside of range for '{self.name}'.")
bad_ind = None
if isinstance(key, int) and not -self.size <= key < self.size:
bad_ind = key
elif isinstance(key, abc.Sequence):
for ind_ in key:
if not isinstance(ind_, int):
raise TypeError(
f"Invalid index type {type(ind_)} for variable "
f"'{self.name}'."
)
if not -self.size <= ind_ < self.size:
bad_ind = ind_
break
else:
key = list(key)
if bad_ind is not None:
raise IndexError(
f"Index {bad_ind} out of bounds for variable '{self.name}' "
f"with size {self.size}."
)

return VariableItem(self, key)

# NOTE: __len__ cannot be defined because it makes numpy.ufuncs convert a
# Variable into an array of VariableItem's

def __iter__(self) -> Iterator[VariableItem]:
for i in range(self.size):
yield self[i]
Expand All @@ -118,7 +140,7 @@ class VariableItem(Parametrized, OpSupport):
"""Stores access to items of a variable with multiple values."""

var: Variable
key: Union[int, slice]
key: Union[int, slice, abc.Sequence[int]]

@property
def variables(self) -> dict[str, Variable]:
Expand All @@ -127,15 +149,22 @@ def variables(self) -> dict[str, Variable]:

def build(self) -> Union[ArrayLike, float, int]:
"""Return the variable's item(s) values."""
return cast(collections.abc.Sequence, self.var.build())[self.key]
built_var = cast(abc.Sequence, self.var.build())
if isinstance(self.key, abc.Sequence):
return [built_var[k] for k in self.key]
return built_var[self.key]

def _to_dict(self) -> dict[str, Any]:
return obj_to_dict(
self, self.var, self.key, _module="operator", _name="getitem"
)

def _to_abstract_repr(self) -> dict[str, Any]:
indices = list(range(self.var.size))[self.key]
indices: int | list[int]
if isinstance(self.key, abc.Sequence):
indices = list(self.key)
else:
indices = list(range(self.var.size))[self.key]
return {"expression": "index", "lhs": self.var, "rhs": indices}

def __str__(self) -> str:
Expand All @@ -148,3 +177,8 @@ def __str__(self) -> str:
else:
key_str = str(self.key)
return f"{str(self.var)}[{key_str}]"

def __len__(self) -> int:
if isinstance(self.key, int):
raise TypeError(f"len() of unsized variable item '{self!s}'.")
return len(np.arange(self.var.size)[self.key])
36 changes: 35 additions & 1 deletion tests/test_abstract_repr.py
Original file line number Diff line number Diff line change
Expand Up @@ -700,6 +700,25 @@ def test_paramobj_serialization(self, sequence):
times=[0.0, 0.5, 1.0],
)

ser_inv_list_var_items = {
"expression": "index",
"lhs": {"variable": "list_var"},
"rhs": [2, 1, 0],
}
s = json.dumps(
InterpolatedWaveform(var, list_var[::-1]), cls=AbstractReprEncoder
)
assert json.loads(s) == dict(
kind="interpolated",
duration=ser_var,
values=ser_inv_list_var_items,
times=[0.0, 0.5, 1.0],
)
assert s == json.dumps(
InterpolatedWaveform(var, list_var[[2, 1, 0]]),
cls=AbstractReprEncoder,
)

err_msg = (
"An InterpolatedWaveform with 'values' of unknown length "
"and unspecified 'times' can't be serialized to the abstract"
Expand Down Expand Up @@ -2023,6 +2042,16 @@ def test_deserialize_parametrized_waveform(self, wf_obj):
{"expression": "cos", "lhs": var1},
{"expression": "tan", "lhs": {"variable": "var1"}},
{"expression": "index", "lhs": {"variable": "var1"}, "rhs": 0},
{
"expression": "index",
"lhs": {"variable": "var2"},
"rhs": [1, 2],
},
{
"expression": "index",
"lhs": {"variable": "var2"},
"rhs": [4, 2, 0],
},
{"expression": "add", "lhs": var1, "rhs": 0.5},
{"expression": "sub", "lhs": {"variable": "var1"}, "rhs": 0.5},
{"expression": "mul", "lhs": {"variable": "var1"}, "rhs": 0.5},
Expand Down Expand Up @@ -2058,6 +2087,7 @@ def test_deserialize_param(self, json_param):
],
variables={
"var1": {"type": "float", "value": [1.5]},
"var2": {"type": "int", "value": [0, 1, 2, 3, 4]},
},
)
# Note: If built, some of these sequences will be invalid
Expand All @@ -2072,6 +2102,7 @@ def test_deserialize_param(self, json_param):
_check_roundtrip(s)
seq = Sequence.from_abstract_repr(json.dumps(s))
seq_var1 = seq._variables["var1"]
seq_var2 = seq._variables["var2"]

# init + declare channels + 1 operation
offset = 1 + len(s["channels"])
Expand Down Expand Up @@ -2111,7 +2142,10 @@ def test_deserialize_param(self, json_param):
assert param == np.tan(seq_var1)

if expression == "index":
assert param == seq_var1[rhs]
if json_param["lhs"] == {"variable": "var1"}:
assert param == seq_var1[rhs]
else:
assert param == seq_var2[rhs]
if expression == "add":
assert param == seq_var1[0] + rhs
if expression == "sub":
Expand Down
18 changes: 17 additions & 1 deletion tests/test_parametrized.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.

import json
import re
from dataclasses import FrozenInstanceError

import numpy as np
Expand Down Expand Up @@ -87,26 +88,41 @@ def test_var(a, b):
assert np.all(var_.build() == np.array([1, 2]))

with pytest.raises(TypeError, match="Invalid key type"):
b[[0, 1]]
b[{0, 1}]
with pytest.raises(TypeError, match="Invalid index type"):
b[[0.0, -1.0]]
with pytest.raises(IndexError):
b[2]
with pytest.raises(IndexError):
b[[-3, 1]]


def test_varitem(a, b, d):
a0 = a[0]
b1 = b[1]
b01 = b[100::-1]
b01_2 = b[[-1, -2]]
b01_3 = b[(1, 0)]
d0 = d[0]
assert b01.variables == {"b": b}
assert str(a0) == "a[0]"
assert str(b1) == "b[1]"
assert str(b01) == "b[100::-1]"
assert str(b01_2) == "b[[-1, -2]]"
assert str(b01_3) == "b[[1, 0]]"
assert str(d0) == "d[0]"
assert b1.build() == 1
assert np.all(b01.build() == np.array([1, -1]))
assert d0.build() == 0.5
with pytest.raises(FrozenInstanceError):
b1.key = 0
np.testing.assert_equal(b01.build(), b01_2.build())
np.testing.assert_equal(b01_2.build(), b01_3.build())
with pytest.raises(
TypeError, match=re.escape("len() of unsized variable item 'b[1]'")
):
len(b1)
assert len(b01) == len(b01_2) == len(b01_3) == b.size == 2


def test_paramobj(bwf, t, a, b):
Expand Down

0 comments on commit 83894fe

Please sign in to comment.