Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Give access to all EOM block parameters and allow for phase drift correction #566

Merged
merged 8 commits into from
Aug 29, 2023
Merged
16 changes: 16 additions & 0 deletions pulser-core/pulser/channels/eom.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,22 @@ def __post_init__(self) -> None:
f" enumeration, not {self.limiting_beam}."
)

def calculate_detuning_off(
self, amp_on: float, detuning_on: float, optimal_detuning_off: float
) -> float:
"""Calculates the detuning when the amplitude is off in EOM mode.
Args:
amp_on: The amplitude of the EOM pulses (in rad/µs).
detuning_on: The detuning of the EOM pulses (in rad/µs).
optimal_detuning_off: The optimal value of detuning (in rad/µs)
when there is no pulse being played. It will choose the closest
value among the existing options.
a-corni marked this conversation as resolved.
Show resolved Hide resolved
"""
off_options = self.detuning_off_options(amp_on, detuning_on)
closest_option = np.abs(off_options - optimal_detuning_off).argmin()
return cast(float, off_options[closest_option])

def detuning_off_options(
self, rabi_frequency: float, detuning_on: float
) -> np.ndarray:
Expand Down
7 changes: 6 additions & 1 deletion pulser-core/pulser/json/abstract_repr/deserializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,6 +263,7 @@ def _deserialize_operation(seq: Sequence, op: dict, vars: dict) -> None:
optimal_detuning_off=_deserialize_parameter(
op["optimal_detuning_off"], vars
),
correct_phase_drift=op.get("correct_phase_drift", False),
)
elif op["op"] == "add_eom_pulse":
seq.add_eom_pulse(
Expand All @@ -273,9 +274,13 @@ def _deserialize_operation(seq: Sequence, op: dict, vars: dict) -> None:
op["post_phase_shift"], vars
),
protocol=op["protocol"],
correct_phase_drift=op.get("correct_phase_drift", False),
)
elif op["op"] == "disable_eom_mode":
seq.disable_eom_mode(channel=op["channel"])
seq.disable_eom_mode(
channel=op["channel"],
correct_phase_drift=op.get("correct_phase_drift", False),
)


def _deserialize_channel(obj: dict[str, Any]) -> Channel:
Expand Down
12 changes: 12 additions & 0 deletions pulser-core/pulser/json/abstract_repr/schemas/sequence-schema.json
Original file line number Diff line number Diff line change
Expand Up @@ -449,6 +449,10 @@
"$ref": "#/definitions/ChannelName",
"description": "The name of the channel to take out of EOM mode."
},
"correct_phase_drift": {
"description": "Performs a phase shift to correct for the phase drift that occured since the last pulse (or the start of the EOM mode, if no pulse was added).",
"type": "boolean"
},
"op": {
"const": "disable_eom_mode",
"type": "string"
Expand All @@ -467,6 +471,10 @@
"$ref": "#/definitions/ChannelName",
"description": "The name of the channel to add the pulse to."
},
"correct_phase_drift": {
"description": "Performs a phase shift to correct for the phase drift that occured since the last pulse (or the start of the EOM mode, if adding the first pulse).",
"type": "boolean"
},
"duration": {
"$ref": "#/definitions/ParametrizedNum",
"description": "The duration of the pulse (in ns)."
Expand Down Expand Up @@ -514,6 +522,10 @@
"$ref": "#/definitions/ChannelName",
"description": "The name of the channel to put in EOM mode."
},
"correct_phase_drift": {
"description": "Performs a phase shift to correct for the phase drift incurred while turning on the EOM mode.",
"type": "boolean"
},
"detuning_on": {
"$ref": "#/definitions/ParametrizedNum",
"description": "The detuning of the EOM pulses (in rad/µs)."
Expand Down
31 changes: 27 additions & 4 deletions pulser-core/pulser/json/abstract_repr/serializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,15 @@ def get_all_args(
}
return {**default_values, **params}

def remove_kwarg_if_default(
data: dict[str, Any], call_name: str, kwarg_name: str
) -> dict[str, Any]:
if data.get(kwarg_name, None) == get_kwarg_default(
call_name, kwarg_name
):
data.pop(kwarg_name, None)
return data

HGSilveri marked this conversation as resolved.
Show resolved Hide resolved
operations = res["operations"]
for call in chain(seq._calls, seq._to_build_calls):
if call.name == "__init__":
Expand Down Expand Up @@ -269,9 +278,18 @@ def get_all_args(
res["slm_mask_targets"] = tuple(seq._slm_mask_targets)
elif call.name == "enable_eom_mode":
data = get_all_args(
("channel", "amp_on", "detuning_on", "optimal_detuning_off"),
(
"channel",
"amp_on",
"detuning_on",
"optimal_detuning_off",
"correct_phase_drift",
),
call,
)
data = remove_kwarg_if_default(
data, call.name, "correct_phase_drift"
)
HGSilveri marked this conversation as resolved.
Show resolved Hide resolved
operations.append({"op": "enable_eom_mode", **data})
elif call.name == "add_eom_pulse":
data = get_all_args(
Expand All @@ -281,15 +299,20 @@ def get_all_args(
"phase",
"post_phase_shift",
"protocol",
"correct_phase_drift",
),
call,
)
data = remove_kwarg_if_default(
data, call.name, "correct_phase_drift"
)
operations.append({"op": "add_eom_pulse", **data})
elif call.name == "disable_eom_mode":
data = get_all_args(("channel",), call)
operations.append(
{"op": "disable_eom_mode", "channel": data["channel"]}
data = get_all_args(("channel", "correct_phase_drift"), call)
data = remove_kwarg_if_default(
data, call.name, "correct_phase_drift"
)
operations.append({"op": "disable_eom_mode", **data})
else:
raise AbstractReprError(f"Unknown call '{call.name}'.")

Expand Down
30 changes: 28 additions & 2 deletions pulser-core/pulser/sequence/_schedule.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,16 @@ class _EOMSettings:
tf: Optional[int] = None


@dataclass
class _PhaseDriftParams:
drift_rate: float # rad/µs
ti: int # ns

def calc_phase_drift(self, tf: int) -> float:
"""Calculate the phase drift during the elapsed time."""
return self.drift_rate * (tf - self.ti) * 1e-3


@dataclass
class _ChannelSchedule:
channel_id: str
Expand Down Expand Up @@ -350,8 +360,16 @@ def add_pulse(
channel: str,
phase_barrier_ts: list[int],
protocol: str,
phase_drift_params: _PhaseDriftParams | None = None,
) -> None:
pass
def corrected_phase(tf: int) -> float:
phase_drift = (
phase_drift_params.calc_phase_drift(tf)
if phase_drift_params
else 0
)
return pulse.phase - phase_drift

last = self[channel][-1]
t0 = last.tf
current_max_t = max(t0, *phase_barrier_ts)
Expand All @@ -368,7 +386,7 @@ def add_pulse(
)
last_pulse = cast(Pulse, last_pulse_slot.type)
# Checks if the current pulse changes the phase
if last_pulse.phase != pulse.phase:
if last_pulse.phase != corrected_phase(current_max_t):
# Subtracts the time that has already elapsed since the
# last pulse from the phase_jump_time and adds the
# fall_time to let the last pulse ramp down
Expand All @@ -392,6 +410,14 @@ def add_pulse(
ti = t0 + delay_duration
tf = ti + pulse.duration
self._check_duration(tf)
# dataclasses.replace() does not work on Pulse (because init=False)
if phase_drift_params is not None:
pulse = Pulse(
amplitude=pulse.amplitude,
detuning=pulse.detuning,
phase=corrected_phase(ti),
lvignoli marked this conversation as resolved.
Show resolved Hide resolved
post_phase_shift=pulse.post_phase_shift,
)
self[channel].slots.append(_TimeSlot(pulse, ti, tf, last.targets))

def add_delay(self, duration: int, channel: str) -> None:
Expand Down
91 changes: 78 additions & 13 deletions pulser-core/pulser/sequence/sequence.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,12 @@
from pulser.register.mappable_reg import MappableRegister
from pulser.sequence._basis_ref import _QubitRef
from pulser.sequence._call import _Call
from pulser.sequence._schedule import _ChannelSchedule, _Schedule, _TimeSlot
from pulser.sequence._schedule import (
_ChannelSchedule,
_PhaseDriftParams,
_Schedule,
_TimeSlot,
)
from pulser.sequence._seq_drawer import Figure, draw_sequence
from pulser.sequence._seq_str import seq_to_str

Expand Down Expand Up @@ -774,6 +779,7 @@ def enable_eom_mode(
amp_on: Union[float, Parametrized],
detuning_on: Union[float, Parametrized],
optimal_detuning_off: Union[float, Parametrized] = 0.0,
correct_phase_drift: bool = False,
lvignoli marked this conversation as resolved.
Show resolved Hide resolved
) -> None:
"""Puts a channel in EOM mode operation.
Expand Down Expand Up @@ -806,6 +812,8 @@ def enable_eom_mode(
optimal_detuning_off: The optimal value of detuning (in rad/µs)
when there is no pulse being played. It will choose the closest
value among the existing options.
correct_phase_drift: Performs a phase shift to correct for the
phase drift incurred while turning on the EOM mode.
a-corni marked this conversation as resolved.
Show resolved Hide resolved
"""
if self.is_in_eom_mode(channel):
raise RuntimeError(
Expand All @@ -822,29 +830,35 @@ def enable_eom_mode(
channel_obj.validate_pulse(on_pulse)
amp_on = cast(float, amp_on)
detuning_on = cast(float, detuning_on)

off_options = cast(
RydbergEOM, channel_obj.eom_config
).detuning_off_options(amp_on, detuning_on)

eom_config = cast(RydbergEOM, channel_obj.eom_config)
if not isinstance(optimal_detuning_off, Parametrized):
closest_option = np.abs(
off_options - optimal_detuning_off
).argmin()
detuning_off = off_options[closest_option]
detuning_off = eom_config.calculate_detuning_off(
amp_on, detuning_on, optimal_detuning_off
)
off_pulse = Pulse.ConstantPulse(
channel_obj.min_duration, 0.0, detuning_off, 0.0
)
channel_obj.validate_pulse(off_pulse)

if not self.is_parametrized():
phase_drift_params = _PhaseDriftParams(
drift_rate=-detuning_off, ti=self.get_duration(channel)
lvignoli marked this conversation as resolved.
Show resolved Hide resolved
)
self._schedule.enable_eom(
channel, amp_on, detuning_on, detuning_off
)
if correct_phase_drift:
buffer_slot = self._last(channel)
drift = phase_drift_params.calc_phase_drift(buffer_slot.tf)
self._phase_shift(
-drift, *buffer_slot.targets, basis=channel_obj.basis
a-corni marked this conversation as resolved.
Show resolved Hide resolved
)

@seq_decorators.store
@seq_decorators.block_if_measured
def disable_eom_mode(self, channel: str) -> None:
def disable_eom_mode(
HGSilveri marked this conversation as resolved.
Show resolved Hide resolved
self, channel: str, correct_phase_drift: bool = False
) -> None:
"""Takes a channel out of EOM mode operation.
For channels with a finite modulation bandwidth and an EOM, operation
Expand All @@ -867,11 +881,24 @@ def disable_eom_mode(self, channel: str) -> None:
Args:
channel: The name of the channel to take out of EOM mode.
correct_phase_drift: Performs a phase shift to correct for the
phase drift that occured since the last pulse (or the start of
the EOM mode, if no pulse was added).
"""
if not self.is_in_eom_mode(channel):
raise RuntimeError(f"The '{channel}' channel is not in EOM mode.")
if not self.is_parametrized():
self._schedule.disable_eom(channel)
if correct_phase_drift:
ch_schedule = self._schedule[channel]
# EOM mode has just been disabled, so tf is defined
last_eom_block_tf = cast(int, ch_schedule.eom_blocks[-1].tf)
drift_params = self._get_last_eom_pulse_phase_drift(channel)
self._phase_shift(
-drift_params.calc_phase_drift(last_eom_block_tf),
*ch_schedule[-1].targets,
basis=ch_schedule.channel_obj.basis,
)

@seq_decorators.store
@seq_decorators.mark_non_empty
Expand All @@ -883,6 +910,7 @@ def add_eom_pulse(
phase: Union[float, Parametrized],
post_phase_shift: Union[float, Parametrized] = 0.0,
protocol: PROTOCOLS = "min-delay",
correct_phase_drift: bool = False,
) -> None:
"""Adds a square pulse to a channel in EOM mode.
Expand Down Expand Up @@ -913,6 +941,11 @@ def add_eom_pulse(
immediately after the end of the pulse.
protocol: Stipulates how to deal with eventual conflicts with
other channels (see `Sequence.add()` for more details).
correct_phase_drift: Adjusts the phase to correct for the phase
drift that occured since the last pulse (or the start of the
EOM mode, if adding the first pulse). This effectively
changes the phase of the EOM pulse, so an extra delay might
be added to enforce the phase jump time.
HGSilveri marked this conversation as resolved.
Show resolved Hide resolved
"""
if not self.is_in_eom_mode(channel):
raise RuntimeError(f"Channel '{channel}' must be in EOM mode.")
Expand All @@ -936,7 +969,14 @@ def add_eom_pulse(
phase,
post_phase_shift=post_phase_shift,
)
self._add(eom_pulse, channel, protocol)
phase_drift_params = (
self._get_last_eom_pulse_phase_drift(channel)
if correct_phase_drift
else None
)
self._add(
eom_pulse, channel, protocol, phase_drift_params=phase_drift_params
)

@seq_decorators.store
@seq_decorators.mark_non_empty
Expand Down Expand Up @@ -1446,6 +1486,7 @@ def _add(
pulse: Union[Pulse, Parametrized],
channel: str,
protocol: PROTOCOLS,
phase_drift_params: _PhaseDriftParams | None = None,
) -> None:
self._validate_add_protocol(protocol)
if self.is_parametrized():
Expand Down Expand Up @@ -1475,7 +1516,13 @@ def _add(
self._basis_ref[basis][q].phase.last_time for q in last.targets
]

self._schedule.add_pulse(pulse, channel, phase_barriers, protocol)
self._schedule.add_pulse(
pulse,
channel,
phase_barriers,
protocol,
phase_drift_params=phase_drift_params,
)

true_finish = self._last(channel).tf + pulse.fall_time(
channel_obj, in_eom_mode=self.is_in_eom_mode(channel)
Expand Down Expand Up @@ -1590,6 +1637,24 @@ def _phase_shift(
for qubit in target_ids:
self._basis_ref[basis][qubit].increment_phase(phi)

def _get_last_eom_pulse_phase_drift(
self, channel: str
) -> _PhaseDriftParams:
eom_settings = self._schedule[channel].eom_blocks[-1]
try:
last_pulse_tf = (
self._schedule[channel]
.last_pulse_slot(ignore_detuned_delay=True)
.tf
)
except RuntimeError:
# There is no previous pulse
last_pulse_tf = 0
return _PhaseDriftParams(
drift_rate=-eom_settings.detuning_off,
ti=max(eom_settings.ti, last_pulse_tf),
)

def _to_dict(self, _module: str = "pulser.sequence") -> dict[str, Any]:
d = obj_to_dict(
self,
Expand Down
Loading