Skip to content

Commit

Permalink
Updates to the target methods (#601)
Browse files Browse the repository at this point in the history
* Change target types from Iterable to Collection

* Proper error message for when no targets are given
  • Loading branch information
HGSilveri authored Oct 24, 2023
1 parent 581886f commit 7c2ddee
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 16 deletions.
8 changes: 4 additions & 4 deletions pulser-core/pulser/json/abstract_repr/serializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

import inspect
import json
from collections.abc import Iterable
from collections.abc import Collection
from itertools import chain
from typing import TYPE_CHECKING, Any, Union, cast

Expand Down Expand Up @@ -156,16 +156,16 @@ def serialize_abstract_sequence(
res["variables"][var.name]["value"] = [var.dtype()] * var.size

def unfold_targets(
target_ids: QubitId | Iterable[QubitId],
target_ids: QubitId | Collection[QubitId],
) -> QubitId | list[QubitId]:
if isinstance(target_ids, (int, str)):
return target_ids

targets = list(cast(Iterable, target_ids))
targets = list(cast(Collection, target_ids))
return targets if len(targets) > 1 else targets[0]

def convert_targets(
target_ids: Union[QubitId, Iterable[QubitId]],
target_ids: Union[QubitId, Collection[QubitId]],
force_list_out: bool = False,
) -> Union[int, list[int]]:
target_array = np.array(unfold_targets(target_ids))
Expand Down
29 changes: 17 additions & 12 deletions pulser-core/pulser/sequence/sequence.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
import json
import os
import warnings
from collections.abc import Iterable, Mapping
from collections.abc import Collection, Mapping
from typing import (
Any,
Generic,
Expand Down Expand Up @@ -529,7 +529,7 @@ def _set_slm_mask_dmm(self, dmm_id: str, targets: set[QubitId]) -> None:

@seq_decorators.store
def config_slm_mask(
self, qubits: Iterable[QubitId], dmm_id: str = "dmm_0"
self, qubits: Collection[QubitId], dmm_id: str = "dmm_0"
) -> None:
"""Setup an SLM mask by specifying the qubits it targets.
Expand All @@ -545,7 +545,7 @@ def config_slm_mask(
pulse added after this operation.
Args:
qubits: Iterable of qubit ID's to mask during the first global
qubits: Collection of qubit ID's to mask during the first global
pulse of the sequence.
dmm_id: Id of the DMM channel to use in the device.
"""
Expand Down Expand Up @@ -846,7 +846,7 @@ def declare_channel(
self,
name: str,
channel_id: str,
initial_target: Optional[Union[QubitId, Iterable[QubitId]]] = None,
initial_target: Optional[Union[QubitId, Collection[QubitId]]] = None,
) -> None:
"""Declares a new channel to the Sequence.
Expand Down Expand Up @@ -901,7 +901,7 @@ def declare_channel(
try:
cond = any(
isinstance(t, Parametrized)
for t in cast(Iterable, initial_target)
for t in cast(Collection, initial_target)
)
except TypeError:
cond = isinstance(initial_target, Parametrized)
Expand Down Expand Up @@ -930,7 +930,7 @@ def declare_channel(
else:
# "_target" call is not saved
self._target(
cast(Union[Iterable, QubitId], initial_target), name
cast(Union[Collection, QubitId], initial_target), name
)

# Manually store the channel declaration as a regular call
Expand Down Expand Up @@ -1319,14 +1319,14 @@ def add_dmm_detuning(
@seq_decorators.store
def target(
self,
qubits: Union[QubitId, Iterable[QubitId]],
qubits: Union[QubitId, Collection[QubitId]],
channel: str,
) -> None:
"""Changes the target qubit of a 'Local' channel.
Args:
qubits: The new target for this channel. Must correspond to a
qubit ID in device or an iterable of qubit IDs, when
qubit ID in device or a collection of qubit IDs, when
multi-qubit addressing is possible.
channel: The channel's name provided when declared. Must be
a channel with 'Local' addressing.
Expand All @@ -1336,14 +1336,14 @@ def target(
@seq_decorators.store
def target_index(
self,
qubits: Union[int, Iterable[int], Parametrized],
qubits: Union[int, Collection[int], Parametrized],
channel: str,
) -> None:
"""Changes the target qubit of a 'Local' channel.
Args:
qubits: The new target for this channel. Must correspond to a
qubit index or an iterable of qubit indices, when multi-qubit
qubit index or an collection of qubit indices, when multi-qubit
addressing is possible.
A qubit index is a number between 0 and the number of qubits.
It is then converted to a Qubit ID using the order in which
Expand Down Expand Up @@ -1993,21 +1993,26 @@ def _add(
@seq_decorators.block_if_measured
def _target(
self,
qubits: Union[Iterable[QubitId], QubitId, Parametrized],
qubits: Union[Collection[QubitId], QubitId, Parametrized],
channel: str,
_index: bool = False,
) -> None:
self._validate_channel(channel, block_eom_mode=True)
channel_obj = self._schedule[channel].channel_obj
try:
qubits_set = (
set(cast(Iterable, qubits))
set(cast(Collection, qubits))
if not isinstance(qubits, str)
else {qubits}
)
except TypeError:
qubits_set = {qubits}

if not qubits_set:
raise ValueError(
"Need at least one qubit to target but none were given."
)

if channel_obj.addressing != "Local":
raise ValueError("Can only choose target of 'Local' channels.")
elif (
Expand Down
2 changes: 2 additions & 0 deletions tests/test_sequence.py
Original file line number Diff line number Diff line change
Expand Up @@ -958,6 +958,8 @@ def test_target(reg, device):
seq.target("q3", "ch1")
with pytest.raises(ValueError, match="can target at most 1 qubits"):
seq.target(["q1", "q5"], "ch0")
with pytest.raises(ValueError, match="Need at least one qubit to target"):
seq.target([], "ch0")

assert seq._schedule["ch0"][-1] == _TimeSlot("target", -1, 0, {"q1"})
seq.target("q4", "ch0")
Expand Down

0 comments on commit 7c2ddee

Please sign in to comment.