Skip to content

Commit

Permalink
Merge branch 'develop' into hs/serializer-updates
Browse files Browse the repository at this point in the history
  • Loading branch information
HGSilveri committed Sep 22, 2023
2 parents f05d491 + c08dfa8 commit 59bf4d0
Show file tree
Hide file tree
Showing 44 changed files with 3,780 additions and 958 deletions.
1 change: 1 addition & 0 deletions docs/source/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,7 @@ computers and simulators, check the pages in :doc:`review`.
tutorials/reg_layouts
tutorials/interpolated_wfs
tutorials/serialization
tutorials/dmm
tutorials/slm_mask
tutorials/output_mod_eom
tutorials/virtual_devices
Expand Down
3 changes: 3 additions & 0 deletions docs/source/tutorials/dmm.nblink
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
{
"path": "../../../tutorials/advanced_features/Local addressability with DMM.ipynb"
}
18 changes: 3 additions & 15 deletions pulser-core/pulser/backend/remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,20 +58,12 @@ class RemoteResults(Results):
the results.
connection: The remote connection over which to get the submission's
status and fetch the results.
jobs_order: An optional list of job IDs (as stored by the connection)
used to order the results.
"""

def __init__(
self,
submission_id: str,
connection: RemoteConnection,
jobs_order: list[str] | None = None,
):
def __init__(self, submission_id: str, connection: RemoteConnection):
"""Instantiates a new collection of remote results."""
self._submission_id = submission_id
self._connection = connection
self._jobs_order = jobs_order

@property
def results(self) -> tuple[Result, ...]:
Expand All @@ -87,9 +79,7 @@ def __getattr__(self, name: str) -> Any:
status = self.get_status()
if status == SubmissionStatus.DONE:
self._results = tuple(
self._connection._fetch_result(
self._submission_id, self._jobs_order
)
self._connection._fetch_result(self._submission_id)
)
return self._results
raise RemoteResultsError(
Expand All @@ -112,9 +102,7 @@ def submit(
pass

@abstractmethod
def _fetch_result(
self, submission_id: str, jobs_order: list[str] | None
) -> typing.Sequence[Result]:
def _fetch_result(self, submission_id: str) -> typing.Sequence[Result]:
"""Fetches the results of a completed submission."""
pass

Expand Down
3 changes: 1 addition & 2 deletions pulser-core/pulser/channels/base_channel.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,7 @@ def __post_init__(self) -> None:
"min_avg_amp",
]
non_negative = [
"max_amp",
"max_abs_detuning",
"min_retarget_interval",
"fixed_retarget_t",
Expand Down Expand Up @@ -359,8 +360,6 @@ def validate_pulse(self, pulse: Pulse) -> None:
Args:
pulse: The pulse to validate.
channel_id: The channel ID used to index the chosen channel
on this device.
"""
if not isinstance(pulse, Pulse):
raise TypeError(
Expand Down
67 changes: 62 additions & 5 deletions pulser-core/pulser/channels/dmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,10 @@
from dataclasses import dataclass, field
from typing import Literal, Optional

import numpy as np

from pulser.channels.base_channel import Channel
from pulser.pulse import Pulse


@dataclass(init=True, repr=False, frozen=True)
Expand Down Expand Up @@ -50,11 +53,11 @@ class DMM(Channel):

bottom_detuning: Optional[float] = field(default=None, init=True)
addressing: Literal["Global"] = field(default="Global", init=False)
max_abs_detuning: Optional[float] = field(init=False, default=None)
max_amp: float = field(default=1e-16, init=False) # can't be 0
min_retarget_interval: Optional[int] = field(init=False, default=None)
fixed_retarget_t: Optional[int] = field(init=False, default=None)
max_targets: Optional[int] = field(init=False, default=None)
max_abs_detuning: Optional[float] = field(default=None, init=False)
max_amp: float = field(default=0, init=False)
min_retarget_interval: Optional[int] = field(default=None, init=False)
fixed_retarget_t: Optional[int] = field(default=None, init=False)
max_targets: Optional[int] = field(default=None, init=False)

def __post_init__(self) -> None:
super().__post_init__()
Expand All @@ -72,3 +75,57 @@ def _undefined_fields(self) -> list[str]:
"max_duration",
]
return [field for field in optional if getattr(self, field) is None]

def validate_pulse(self, pulse: Pulse) -> None:
"""Checks if a pulse can be executed in this DMM.
Args:
pulse: The pulse to validate.
"""
super().validate_pulse(pulse)
round_detuning = np.round(pulse.detuning.samples, decimals=6)
if np.any(round_detuning > 0):
raise ValueError("The detuning in a DMM must not be positive.")
if self.bottom_detuning is not None and np.any(
round_detuning < self.bottom_detuning
):
raise ValueError(
"The detuning goes below the bottom detuning "
f"of the DMM ({self.bottom_detuning} rad/µs)."
)


def _dmm_id_from_name(dmm_name: str) -> str:
"""Converts a dmm_name into a dmm_id.
As a reminder the dmm_name is generated automatically from dmm_id
as dmm_id_{number of times dmm_id has been called}.
Args:
dmm_name: The dmm_name to convert.
Returns:
The associated dmm_id.
"""
return "_".join(dmm_name.split("_")[0:2])


def _get_dmm_name(dmm_id: str, channels: list[str]) -> str:
"""Get the dmm_name to add a dmm_id to a list of channels.
Counts the number of channels starting by dmm_id, generates the
dmm_name as dmm_id_{number of times dmm_id has been called}.
Args:
dmm_id: the id of the DMM to add to the list of channels.
channels: a list of channel names.
Returns:
The associated dmm_name.
"""
dmm_count = len(
[key for key in channels if _dmm_id_from_name(key) == dmm_id]
)
if dmm_count == 0:
return dmm_id
return dmm_id + f"_{dmm_count}"
28 changes: 14 additions & 14 deletions pulser-core/pulser/devices/_device_datacls.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,8 @@
from pulser.json.utils import get_dataclass_defaults, obj_to_dict
from pulser.register.base_register import BaseRegister, QubitId
from pulser.register.mappable_reg import MappableRegister
from pulser.register.register_layout import COORD_PRECISION, RegisterLayout
from pulser.register.register_layout import RegisterLayout
from pulser.register.traps import COORD_PRECISION

DIMENSIONS = Literal[2, 3]

Expand Down Expand Up @@ -164,12 +165,10 @@ def type_check(
for dmm_obj in self.dmm_objects:
type_check("All DMM channels", DMM, value_override=dmm_obj)

# TODO: Check that device has dmm objects if it supports SLM mask
# once DMM is supported for serialization
# if self.supports_slm_mask and not self.dmm_objects:
# raise ValueError(
# "One DMM object should be defined to support SLM mask."
# )
if self.supports_slm_mask and not self.dmm_objects:
raise ValueError(
"One DMM object should be defined to support SLM mask."
)

if self.channel_ids is not None:
if not (
Expand Down Expand Up @@ -454,6 +453,9 @@ def _to_abstract_repr(self) -> dict[str, Any]:
for p in ALWAYS_OPTIONAL_PARAMS:
if params[p] == defaults[p]:
params.pop(p, None)
# Delete parameters of PARAMS_WITH_ABSTR_REPR in params
for p in PARAMS_WITH_ABSTR_REPR:
params.pop(p, None)
ch_list = []
for ch_name, ch_obj in self.channels.items():
ch_list.append(ch_obj._to_abstract_repr(ch_name))
Expand All @@ -462,12 +464,8 @@ def _to_abstract_repr(self) -> dict[str, Any]:
dmm_list = []
for dmm_name, dmm_obj in self.dmm_channels.items():
dmm_list.append(dmm_obj._to_abstract_repr(dmm_name))
# Add dmm channels if different than default
if "dmm_objects" in params:
params["dmm_channels"] = dmm_list
# Delete parameters of PARAMS_WITH_ABSTR_REPR in params
for p in PARAMS_WITH_ABSTR_REPR:
params.pop(p, None)
if dmm_list:
params["dmm_objects"] = dmm_list
return params

def to_abstract_repr(self) -> str:
Expand Down Expand Up @@ -518,7 +516,7 @@ class Device(BaseDevice):

def __post_init__(self) -> None:
super().__post_init__()
for ch_id, ch_obj in self.channels.items():
for ch_id, ch_obj in {**self.channels, **self.dmm_channels}.items():
if ch_obj.is_virtual():
_sep = "', '"
raise ValueError(
Expand Down Expand Up @@ -667,6 +665,8 @@ class VirtualDevice(BaseDevice):
max_atom_num: int | None = None
max_radial_distance: int | None = None
supports_slm_mask: bool = True
# Needed to support SLM mask by default
dmm_objects: tuple[DMM, ...] = (DMM(),)
reusable_channels: bool = True

@property
Expand Down
19 changes: 9 additions & 10 deletions pulser-core/pulser/devices/_devices.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
"""Definitions of real devices."""
import numpy as np

from pulser.channels import Raman, Rydberg
from pulser.channels import DMM, Raman, Rydberg
from pulser.channels.eom import RydbergBeam, RydbergEOM
from pulser.devices._device_datacls import Device
from pulser.register.special_layouts import TriangularLatticeLayout
Expand Down Expand Up @@ -56,15 +56,14 @@
max_duration=2**26,
),
),
# TODO: Add DMM once it is supported for serialization
# dmm_objects=(
# DMM(
# clock_period=4,
# min_duration=16,
# max_duration=2**26,
# bottom_detuning=-20,
# ),
# ),
dmm_objects=(
DMM(
clock_period=4,
min_duration=16,
max_duration=2**26,
bottom_detuning=-20,
),
),
)

IroiseMVP = Device(
Expand Down
4 changes: 2 additions & 2 deletions pulser-core/pulser/devices/_mock_device.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from pulser.channels import Microwave, Raman, Rydberg
from pulser.channels import DMM, Microwave, Raman, Rydberg
from pulser.devices._device_datacls import VirtualDevice

MockDevice = VirtualDevice(
Expand All @@ -31,5 +31,5 @@
Raman.Local(None, None, max_duration=None),
Microwave.Global(None, None, max_duration=None),
),
# TODO: Add DMM once it is supported for serialization
dmm_objects=(DMM(),),
)
41 changes: 36 additions & 5 deletions pulser-core/pulser/json/abstract_repr/deserializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@

import pulser
import pulser.devices as devices
from pulser.channels import Microwave, Raman, Rydberg
from pulser.channels import DMM, Microwave, Raman, Rydberg
from pulser.channels.base_channel import Channel
from pulser.channels.eom import (
OPTIONAL_ABSTR_EOM_FIELDS,
Expand All @@ -44,6 +44,7 @@
from pulser.register.mappable_reg import MappableRegister
from pulser.register.register import Register
from pulser.register.register_layout import RegisterLayout
from pulser.register.weight_maps import DetuningMap
from pulser.waveforms import (
BlackmanWaveform,
CompositeWaveform,
Expand Down Expand Up @@ -281,14 +282,30 @@ def _deserialize_operation(seq: Sequence, op: dict, vars: dict) -> None:
channel=op["channel"],
correct_phase_drift=op.get("correct_phase_drift", False),
)
elif op["op"] == "add_dmm_detuning":
seq.add_dmm_detuning(
waveform=_deserialize_waveform(op["waveform"], vars),
dmm_name=op["dmm_name"],
protocol=op["protocol"],
)
elif op["op"] == "config_slm_mask":
seq.config_slm_mask(qubits=op["qubits"], dmm_id=op["dmm_id"])
elif op["op"] == "config_detuning_map":
seq.config_detuning_map(
detuning_map=_deserialize_det_map(op["detuning_map"]),
dmm_id=op["dmm_id"],
)


def _deserialize_channel(obj: dict[str, Any]) -> Channel:
params: dict[str, Any] = {}
channel_cls: Type[Channel]
if obj["basis"] == "ground-rydberg":
channel_cls = Rydberg
params["eom_config"] = None
if "bottom_detuning" in obj:
channel_cls = DMM
else:
channel_cls = Rydberg
params["eom_config"] = None
if obj["eom_config"] is not None:
data = obj["eom_config"]
try:
Expand Down Expand Up @@ -352,9 +369,9 @@ def _deserialize_device_object(obj: dict[str, Any]) -> Device | VirtualDevice:
params: dict[str, Any] = dict(
channel_ids=tuple(ch_ids), channel_objects=tuple(ch_objs)
)
if "dmm_channels" in obj:
if "dmm_objects" in obj:
params["dmm_objects"] = tuple(
_deserialize_channel(dmm_ch) for dmm_ch in obj["dmm_channels"]
_deserialize_channel(dmm_ch) for dmm_ch in obj["dmm_objects"]
)
device_fields = dataclasses.fields(device_cls)
device_defaults = get_dataclass_defaults(device_fields)
Expand All @@ -379,6 +396,19 @@ def _deserialize_device_object(obj: dict[str, Any]) -> Device | VirtualDevice:
raise AbstractReprError("Device deserialization failed.") from e


def _deserialize_det_map(ser_det_map: dict) -> DetuningMap:
trap_coords = []
weights = []
for trap in ser_det_map["traps"]:
trap_coords.append((trap["x"], trap["y"]))
weights.append(trap["weight"])
return DetuningMap(
trap_coordinates=trap_coords,
weights=weights,
slug=ser_det_map.get("slug"),
)


def deserialize_abstract_sequence(obj_str: str) -> Sequence:
"""Deserialize a sequence from an abstract JSON object.
Expand Down Expand Up @@ -433,6 +463,7 @@ def deserialize_abstract_sequence(obj_str: str) -> Sequence:

# SLM Mask
if "slm_mask_targets" in obj:
# This is kept for backwards compatibility
seq.config_slm_mask(obj["slm_mask_targets"])

# Variables
Expand Down
Loading

0 comments on commit 59bf4d0

Please sign in to comment.