Skip to content

Commit

Permalink
Give access to all EOM block parameters and allow for phase drift cor…
Browse files Browse the repository at this point in the history
…rection (#566)

* Isolate detuning off calculation

* Adding option to correct for phase drift in EOM mode

* Add UTs for phase drift correction

* Abstract repr support

* Include phase drift correctin in the EOM tutorial
  • Loading branch information
HGSilveri authored Aug 29, 2023
1 parent f742f5c commit 2315989
Show file tree
Hide file tree
Showing 9 changed files with 286 additions and 53 deletions.
16 changes: 16 additions & 0 deletions pulser-core/pulser/channels/eom.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,22 @@ def __post_init__(self) -> None:
f" enumeration, not {self.limiting_beam}."
)

def calculate_detuning_off(
self, amp_on: float, detuning_on: float, optimal_detuning_off: float
) -> float:
"""Calculates the detuning when the amplitude is off in EOM mode.
Args:
amp_on: The amplitude of the EOM pulses (in rad/µs).
detuning_on: The detuning of the EOM pulses (in rad/µs).
optimal_detuning_off: The optimal value of detuning (in rad/µs)
when there is no pulse being played. It will choose the closest
value among the existing options.
"""
off_options = self.detuning_off_options(amp_on, detuning_on)
closest_option = np.abs(off_options - optimal_detuning_off).argmin()
return cast(float, off_options[closest_option])

def detuning_off_options(
self, rabi_frequency: float, detuning_on: float
) -> np.ndarray:
Expand Down
7 changes: 6 additions & 1 deletion pulser-core/pulser/json/abstract_repr/deserializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,6 +263,7 @@ def _deserialize_operation(seq: Sequence, op: dict, vars: dict) -> None:
optimal_detuning_off=_deserialize_parameter(
op["optimal_detuning_off"], vars
),
correct_phase_drift=op.get("correct_phase_drift", False),
)
elif op["op"] == "add_eom_pulse":
seq.add_eom_pulse(
Expand All @@ -273,9 +274,13 @@ def _deserialize_operation(seq: Sequence, op: dict, vars: dict) -> None:
op["post_phase_shift"], vars
),
protocol=op["protocol"],
correct_phase_drift=op.get("correct_phase_drift", False),
)
elif op["op"] == "disable_eom_mode":
seq.disable_eom_mode(channel=op["channel"])
seq.disable_eom_mode(
channel=op["channel"],
correct_phase_drift=op.get("correct_phase_drift", False),
)


def _deserialize_channel(obj: dict[str, Any]) -> Channel:
Expand Down
12 changes: 12 additions & 0 deletions pulser-core/pulser/json/abstract_repr/schemas/sequence-schema.json
Original file line number Diff line number Diff line change
Expand Up @@ -449,6 +449,10 @@
"$ref": "#/definitions/ChannelName",
"description": "The name of the channel to take out of EOM mode."
},
"correct_phase_drift": {
"description": "Performs a phase shift to correct for the phase drift that occured since the last pulse (or the start of the EOM mode, if no pulse was added).",
"type": "boolean"
},
"op": {
"const": "disable_eom_mode",
"type": "string"
Expand All @@ -467,6 +471,10 @@
"$ref": "#/definitions/ChannelName",
"description": "The name of the channel to add the pulse to."
},
"correct_phase_drift": {
"description": "Performs a phase shift to correct for the phase drift that occured since the last pulse (or the start of the EOM mode, if adding the first pulse).",
"type": "boolean"
},
"duration": {
"$ref": "#/definitions/ParametrizedNum",
"description": "The duration of the pulse (in ns)."
Expand Down Expand Up @@ -514,6 +522,10 @@
"$ref": "#/definitions/ChannelName",
"description": "The name of the channel to put in EOM mode."
},
"correct_phase_drift": {
"description": "Performs a phase shift to correct for the phase drift incurred while turning on the EOM mode.",
"type": "boolean"
},
"detuning_on": {
"$ref": "#/definitions/ParametrizedNum",
"description": "The detuning of the EOM pulses (in rad/µs)."
Expand Down
31 changes: 27 additions & 4 deletions pulser-core/pulser/json/abstract_repr/serializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,15 @@ def get_all_args(
}
return {**default_values, **params}

def remove_kwarg_if_default(
data: dict[str, Any], call_name: str, kwarg_name: str
) -> dict[str, Any]:
if data.get(kwarg_name, None) == get_kwarg_default(
call_name, kwarg_name
):
data.pop(kwarg_name, None)
return data

operations = res["operations"]
for call in chain(seq._calls, seq._to_build_calls):
if call.name == "__init__":
Expand Down Expand Up @@ -269,9 +278,18 @@ def get_all_args(
res["slm_mask_targets"] = tuple(seq._slm_mask_targets)
elif call.name == "enable_eom_mode":
data = get_all_args(
("channel", "amp_on", "detuning_on", "optimal_detuning_off"),
(
"channel",
"amp_on",
"detuning_on",
"optimal_detuning_off",
"correct_phase_drift",
),
call,
)
data = remove_kwarg_if_default(
data, call.name, "correct_phase_drift"
)
operations.append({"op": "enable_eom_mode", **data})
elif call.name == "add_eom_pulse":
data = get_all_args(
Expand All @@ -281,15 +299,20 @@ def get_all_args(
"phase",
"post_phase_shift",
"protocol",
"correct_phase_drift",
),
call,
)
data = remove_kwarg_if_default(
data, call.name, "correct_phase_drift"
)
operations.append({"op": "add_eom_pulse", **data})
elif call.name == "disable_eom_mode":
data = get_all_args(("channel",), call)
operations.append(
{"op": "disable_eom_mode", "channel": data["channel"]}
data = get_all_args(("channel", "correct_phase_drift"), call)
data = remove_kwarg_if_default(
data, call.name, "correct_phase_drift"
)
operations.append({"op": "disable_eom_mode", **data})
else:
raise AbstractReprError(f"Unknown call '{call.name}'.")

Expand Down
30 changes: 28 additions & 2 deletions pulser-core/pulser/sequence/_schedule.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,16 @@ class _EOMSettings:
tf: Optional[int] = None


@dataclass
class _PhaseDriftParams:
drift_rate: float # rad/µs
ti: int # ns

def calc_phase_drift(self, tf: int) -> float:
"""Calculate the phase drift during the elapsed time."""
return self.drift_rate * (tf - self.ti) * 1e-3


@dataclass
class _ChannelSchedule:
channel_id: str
Expand Down Expand Up @@ -350,8 +360,16 @@ def add_pulse(
channel: str,
phase_barrier_ts: list[int],
protocol: str,
phase_drift_params: _PhaseDriftParams | None = None,
) -> None:
pass
def corrected_phase(tf: int) -> float:
phase_drift = (
phase_drift_params.calc_phase_drift(tf)
if phase_drift_params
else 0
)
return pulse.phase - phase_drift

last = self[channel][-1]
t0 = last.tf
current_max_t = max(t0, *phase_barrier_ts)
Expand All @@ -368,7 +386,7 @@ def add_pulse(
)
last_pulse = cast(Pulse, last_pulse_slot.type)
# Checks if the current pulse changes the phase
if last_pulse.phase != pulse.phase:
if last_pulse.phase != corrected_phase(current_max_t):
# Subtracts the time that has already elapsed since the
# last pulse from the phase_jump_time and adds the
# fall_time to let the last pulse ramp down
Expand All @@ -392,6 +410,14 @@ def add_pulse(
ti = t0 + delay_duration
tf = ti + pulse.duration
self._check_duration(tf)
# dataclasses.replace() does not work on Pulse (because init=False)
if phase_drift_params is not None:
pulse = Pulse(
amplitude=pulse.amplitude,
detuning=pulse.detuning,
phase=corrected_phase(ti),
post_phase_shift=pulse.post_phase_shift,
)
self[channel].slots.append(_TimeSlot(pulse, ti, tf, last.targets))

def add_delay(self, duration: int, channel: str) -> None:
Expand Down
91 changes: 78 additions & 13 deletions pulser-core/pulser/sequence/sequence.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,12 @@
from pulser.register.mappable_reg import MappableRegister
from pulser.sequence._basis_ref import _QubitRef
from pulser.sequence._call import _Call
from pulser.sequence._schedule import _ChannelSchedule, _Schedule, _TimeSlot
from pulser.sequence._schedule import (
_ChannelSchedule,
_PhaseDriftParams,
_Schedule,
_TimeSlot,
)
from pulser.sequence._seq_drawer import Figure, draw_sequence
from pulser.sequence._seq_str import seq_to_str

Expand Down Expand Up @@ -774,6 +779,7 @@ def enable_eom_mode(
amp_on: Union[float, Parametrized],
detuning_on: Union[float, Parametrized],
optimal_detuning_off: Union[float, Parametrized] = 0.0,
correct_phase_drift: bool = False,
) -> None:
"""Puts a channel in EOM mode operation.
Expand Down Expand Up @@ -806,6 +812,8 @@ def enable_eom_mode(
optimal_detuning_off: The optimal value of detuning (in rad/µs)
when there is no pulse being played. It will choose the closest
value among the existing options.
correct_phase_drift: Performs a phase shift to correct for the
phase drift incurred while turning on the EOM mode.
"""
if self.is_in_eom_mode(channel):
raise RuntimeError(
Expand All @@ -822,29 +830,35 @@ def enable_eom_mode(
channel_obj.validate_pulse(on_pulse)
amp_on = cast(float, amp_on)
detuning_on = cast(float, detuning_on)

off_options = cast(
RydbergEOM, channel_obj.eom_config
).detuning_off_options(amp_on, detuning_on)

eom_config = cast(RydbergEOM, channel_obj.eom_config)
if not isinstance(optimal_detuning_off, Parametrized):
closest_option = np.abs(
off_options - optimal_detuning_off
).argmin()
detuning_off = off_options[closest_option]
detuning_off = eom_config.calculate_detuning_off(
amp_on, detuning_on, optimal_detuning_off
)
off_pulse = Pulse.ConstantPulse(
channel_obj.min_duration, 0.0, detuning_off, 0.0
)
channel_obj.validate_pulse(off_pulse)

if not self.is_parametrized():
phase_drift_params = _PhaseDriftParams(
drift_rate=-detuning_off, ti=self.get_duration(channel)
)
self._schedule.enable_eom(
channel, amp_on, detuning_on, detuning_off
)
if correct_phase_drift:
buffer_slot = self._last(channel)
drift = phase_drift_params.calc_phase_drift(buffer_slot.tf)
self._phase_shift(
-drift, *buffer_slot.targets, basis=channel_obj.basis
)

@seq_decorators.store
@seq_decorators.block_if_measured
def disable_eom_mode(self, channel: str) -> None:
def disable_eom_mode(
self, channel: str, correct_phase_drift: bool = False
) -> None:
"""Takes a channel out of EOM mode operation.
For channels with a finite modulation bandwidth and an EOM, operation
Expand All @@ -867,11 +881,24 @@ def disable_eom_mode(self, channel: str) -> None:
Args:
channel: The name of the channel to take out of EOM mode.
correct_phase_drift: Performs a phase shift to correct for the
phase drift that occured since the last pulse (or the start of
the EOM mode, if no pulse was added).
"""
if not self.is_in_eom_mode(channel):
raise RuntimeError(f"The '{channel}' channel is not in EOM mode.")
if not self.is_parametrized():
self._schedule.disable_eom(channel)
if correct_phase_drift:
ch_schedule = self._schedule[channel]
# EOM mode has just been disabled, so tf is defined
last_eom_block_tf = cast(int, ch_schedule.eom_blocks[-1].tf)
drift_params = self._get_last_eom_pulse_phase_drift(channel)
self._phase_shift(
-drift_params.calc_phase_drift(last_eom_block_tf),
*ch_schedule[-1].targets,
basis=ch_schedule.channel_obj.basis,
)

@seq_decorators.store
@seq_decorators.mark_non_empty
Expand All @@ -883,6 +910,7 @@ def add_eom_pulse(
phase: Union[float, Parametrized],
post_phase_shift: Union[float, Parametrized] = 0.0,
protocol: PROTOCOLS = "min-delay",
correct_phase_drift: bool = False,
) -> None:
"""Adds a square pulse to a channel in EOM mode.
Expand Down Expand Up @@ -913,6 +941,11 @@ def add_eom_pulse(
immediately after the end of the pulse.
protocol: Stipulates how to deal with eventual conflicts with
other channels (see `Sequence.add()` for more details).
correct_phase_drift: Adjusts the phase to correct for the phase
drift that occured since the last pulse (or the start of the
EOM mode, if adding the first pulse). This effectively
changes the phase of the EOM pulse, so an extra delay might
be added to enforce the phase jump time.
"""
if not self.is_in_eom_mode(channel):
raise RuntimeError(f"Channel '{channel}' must be in EOM mode.")
Expand All @@ -936,7 +969,14 @@ def add_eom_pulse(
phase,
post_phase_shift=post_phase_shift,
)
self._add(eom_pulse, channel, protocol)
phase_drift_params = (
self._get_last_eom_pulse_phase_drift(channel)
if correct_phase_drift
else None
)
self._add(
eom_pulse, channel, protocol, phase_drift_params=phase_drift_params
)

@seq_decorators.store
@seq_decorators.mark_non_empty
Expand Down Expand Up @@ -1446,6 +1486,7 @@ def _add(
pulse: Union[Pulse, Parametrized],
channel: str,
protocol: PROTOCOLS,
phase_drift_params: _PhaseDriftParams | None = None,
) -> None:
self._validate_add_protocol(protocol)
if self.is_parametrized():
Expand Down Expand Up @@ -1475,7 +1516,13 @@ def _add(
self._basis_ref[basis][q].phase.last_time for q in last.targets
]

self._schedule.add_pulse(pulse, channel, phase_barriers, protocol)
self._schedule.add_pulse(
pulse,
channel,
phase_barriers,
protocol,
phase_drift_params=phase_drift_params,
)

true_finish = self._last(channel).tf + pulse.fall_time(
channel_obj, in_eom_mode=self.is_in_eom_mode(channel)
Expand Down Expand Up @@ -1590,6 +1637,24 @@ def _phase_shift(
for qubit in target_ids:
self._basis_ref[basis][qubit].increment_phase(phi)

def _get_last_eom_pulse_phase_drift(
self, channel: str
) -> _PhaseDriftParams:
eom_settings = self._schedule[channel].eom_blocks[-1]
try:
last_pulse_tf = (
self._schedule[channel]
.last_pulse_slot(ignore_detuned_delay=True)
.tf
)
except RuntimeError:
# There is no previous pulse
last_pulse_tf = 0
return _PhaseDriftParams(
drift_rate=-eom_settings.detuning_off,
ti=max(eom_settings.ti, last_pulse_tf),
)

def _to_dict(self, _module: str = "pulser.sequence") -> dict[str, Any]:
d = obj_to_dict(
self,
Expand Down
Loading

0 comments on commit 2315989

Please sign in to comment.