Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Give access to all EOM block parameters and allow for phase drift correction #566

Merged
merged 8 commits into from
Aug 29, 2023
16 changes: 16 additions & 0 deletions pulser-core/pulser/channels/eom.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,22 @@ def __post_init__(self) -> None:
f" enumeration, not {self.limiting_beam}."
)

def calculate_detuning_off(
self, amp_on: float, detuning_on: float, optimal_detuning_off: float
) -> float:
"""Calculates the detuning when the amplitude is off in EOM mode.
Args:
amp_on: The amplitude of the EOM pulses (in rad/µs).
detuning_on: The detuning of the EOM pulses (in rad/µs).
optimal_detuning_off: The optimal value of detuning (in rad/µs)
when there is no pulse being played. It will choose the closest
value among the existing options.
a-corni marked this conversation as resolved.
Show resolved Hide resolved
"""
off_options = self.detuning_off_options(amp_on, detuning_on)
closest_option = np.abs(off_options - optimal_detuning_off).argmin()
return cast(float, off_options[closest_option])

def detuning_off_options(
self, rabi_frequency: float, detuning_on: float
) -> np.ndarray:
Expand Down
30 changes: 28 additions & 2 deletions pulser-core/pulser/sequence/_schedule.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,16 @@ class _EOMSettings:
tf: Optional[int] = None


@dataclass
class _PhaseDriftParams:
drift_rate: float # rad/µs
ti: int # ns

def calc_phase_drift(self, tf: int) -> float:
"""Calculate the phase drift during the elapsed time."""
return self.drift_rate * (tf - self.ti) * 1e-3


@dataclass
class _ChannelSchedule:
channel_id: str
Expand Down Expand Up @@ -350,8 +360,16 @@ def add_pulse(
channel: str,
phase_barrier_ts: list[int],
protocol: str,
phase_drift_params: _PhaseDriftParams | None = None,
) -> None:
pass
def corrected_phase(tf: int) -> float:
phase_drift = (
phase_drift_params.calc_phase_drift(tf)
if phase_drift_params
else 0
)
return pulse.phase - phase_drift

last = self[channel][-1]
t0 = last.tf
current_max_t = max(t0, *phase_barrier_ts)
Expand All @@ -368,7 +386,7 @@ def add_pulse(
)
last_pulse = cast(Pulse, last_pulse_slot.type)
# Checks if the current pulse changes the phase
if last_pulse.phase != pulse.phase:
if last_pulse.phase != corrected_phase(current_max_t):
# Subtracts the time that has already elapsed since the
# last pulse from the phase_jump_time and adds the
# fall_time to let the last pulse ramp down
Expand All @@ -392,6 +410,14 @@ def add_pulse(
ti = t0 + delay_duration
tf = ti + pulse.duration
self._check_duration(tf)
# dataclasses.replace() does not work on Pulse (because init=False)
if phase_drift_params is not None:
pulse = Pulse(
amplitude=pulse.amplitude,
detuning=pulse.detuning,
phase=corrected_phase(ti),
lvignoli marked this conversation as resolved.
Show resolved Hide resolved
post_phase_shift=pulse.post_phase_shift,
)
self[channel].slots.append(_TimeSlot(pulse, ti, tf, last.targets))

def add_delay(self, duration: int, channel: str) -> None:
Expand Down
91 changes: 78 additions & 13 deletions pulser-core/pulser/sequence/sequence.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,12 @@
from pulser.register.mappable_reg import MappableRegister
from pulser.sequence._basis_ref import _QubitRef
from pulser.sequence._call import _Call
from pulser.sequence._schedule import _ChannelSchedule, _Schedule, _TimeSlot
from pulser.sequence._schedule import (
_ChannelSchedule,
_PhaseDriftParams,
_Schedule,
_TimeSlot,
)
from pulser.sequence._seq_drawer import Figure, draw_sequence
from pulser.sequence._seq_str import seq_to_str

Expand Down Expand Up @@ -774,6 +779,7 @@ def enable_eom_mode(
amp_on: Union[float, Parametrized],
detuning_on: Union[float, Parametrized],
optimal_detuning_off: Union[float, Parametrized] = 0.0,
correct_phase_drift: bool = False,
lvignoli marked this conversation as resolved.
Show resolved Hide resolved
) -> None:
"""Puts a channel in EOM mode operation.
Expand Down Expand Up @@ -806,6 +812,8 @@ def enable_eom_mode(
optimal_detuning_off: The optimal value of detuning (in rad/µs)
when there is no pulse being played. It will choose the closest
value among the existing options.
correct_phase_drift: Performs a phase shift to correct for the
phase drift incurred while turning on the EOM mode.
a-corni marked this conversation as resolved.
Show resolved Hide resolved
"""
if self.is_in_eom_mode(channel):
raise RuntimeError(
Expand All @@ -822,29 +830,35 @@ def enable_eom_mode(
channel_obj.validate_pulse(on_pulse)
amp_on = cast(float, amp_on)
detuning_on = cast(float, detuning_on)

off_options = cast(
RydbergEOM, channel_obj.eom_config
).detuning_off_options(amp_on, detuning_on)

eom_config = cast(RydbergEOM, channel_obj.eom_config)
if not isinstance(optimal_detuning_off, Parametrized):
closest_option = np.abs(
off_options - optimal_detuning_off
).argmin()
detuning_off = off_options[closest_option]
detuning_off = eom_config.calculate_detuning_off(
amp_on, detuning_on, optimal_detuning_off
)
off_pulse = Pulse.ConstantPulse(
channel_obj.min_duration, 0.0, detuning_off, 0.0
)
channel_obj.validate_pulse(off_pulse)

if not self.is_parametrized():
phase_drift_params = _PhaseDriftParams(
drift_rate=-detuning_off, ti=self.get_duration(channel)
lvignoli marked this conversation as resolved.
Show resolved Hide resolved
)
self._schedule.enable_eom(
channel, amp_on, detuning_on, detuning_off
)
if correct_phase_drift:
buffer_slot = self._last(channel)
drift = phase_drift_params.calc_phase_drift(buffer_slot.tf)
self._phase_shift(
-drift, *buffer_slot.targets, basis=channel_obj.basis
a-corni marked this conversation as resolved.
Show resolved Hide resolved
)

@seq_decorators.store
@seq_decorators.block_if_measured
def disable_eom_mode(self, channel: str) -> None:
def disable_eom_mode(
HGSilveri marked this conversation as resolved.
Show resolved Hide resolved
self, channel: str, correct_phase_drift: bool = False
) -> None:
"""Takes a channel out of EOM mode operation.
For channels with a finite modulation bandwidth and an EOM, operation
Expand All @@ -867,11 +881,24 @@ def disable_eom_mode(self, channel: str) -> None:
Args:
channel: The name of the channel to take out of EOM mode.
correct_phase_drift: Performs a phase shift to correct for the
phase drift that occured since the last pulse (or the start of
the EOM mode, if no pulse was added).
"""
if not self.is_in_eom_mode(channel):
raise RuntimeError(f"The '{channel}' channel is not in EOM mode.")
if not self.is_parametrized():
self._schedule.disable_eom(channel)
if correct_phase_drift:
ch_schedule = self._schedule[channel]
# EOM mode has just been disabled, so tf is defined
last_eom_block_tf = cast(int, ch_schedule.eom_blocks[-1].tf)
drift_params = self._get_last_eom_pulse_phase_drift(channel)
self._phase_shift(
-drift_params.calc_phase_drift(last_eom_block_tf),
*ch_schedule[-1].targets,
basis=ch_schedule.channel_obj.basis,
)

@seq_decorators.store
@seq_decorators.mark_non_empty
Expand All @@ -883,6 +910,7 @@ def add_eom_pulse(
phase: Union[float, Parametrized],
post_phase_shift: Union[float, Parametrized] = 0.0,
protocol: PROTOCOLS = "min-delay",
correct_phase_drift: bool = False,
) -> None:
"""Adds a square pulse to a channel in EOM mode.
Expand Down Expand Up @@ -913,6 +941,11 @@ def add_eom_pulse(
immediately after the end of the pulse.
protocol: Stipulates how to deal with eventual conflicts with
other channels (see `Sequence.add()` for more details).
correct_phase_drift: Adjusts the phase to correct for the phase
drift that occured since the last pulse (or the start of the
EOM mode, if adding the first pulse). This effectively
changes the phase of the EOM pulse, so an extra delay might
be added to enforce the phase jump time.
HGSilveri marked this conversation as resolved.
Show resolved Hide resolved
"""
if not self.is_in_eom_mode(channel):
raise RuntimeError(f"Channel '{channel}' must be in EOM mode.")
Expand All @@ -936,7 +969,14 @@ def add_eom_pulse(
phase,
post_phase_shift=post_phase_shift,
)
self._add(eom_pulse, channel, protocol)
phase_drift_params = (
self._get_last_eom_pulse_phase_drift(channel)
if correct_phase_drift
else None
)
self._add(
eom_pulse, channel, protocol, phase_drift_params=phase_drift_params
)

@seq_decorators.store
@seq_decorators.mark_non_empty
Expand Down Expand Up @@ -1446,6 +1486,7 @@ def _add(
pulse: Union[Pulse, Parametrized],
channel: str,
protocol: PROTOCOLS,
phase_drift_params: _PhaseDriftParams | None = None,
) -> None:
self._validate_add_protocol(protocol)
if self.is_parametrized():
Expand Down Expand Up @@ -1475,7 +1516,13 @@ def _add(
self._basis_ref[basis][q].phase.last_time for q in last.targets
]

self._schedule.add_pulse(pulse, channel, phase_barriers, protocol)
self._schedule.add_pulse(
pulse,
channel,
phase_barriers,
protocol,
phase_drift_params=phase_drift_params,
)

true_finish = self._last(channel).tf + pulse.fall_time(
channel_obj, in_eom_mode=self.is_in_eom_mode(channel)
Expand Down Expand Up @@ -1590,6 +1637,24 @@ def _phase_shift(
for qubit in target_ids:
self._basis_ref[basis][qubit].increment_phase(phi)

def _get_last_eom_pulse_phase_drift(
self, channel: str
) -> _PhaseDriftParams:
eom_settings = self._schedule[channel].eom_blocks[-1]
try:
last_pulse_tf = (
self._schedule[channel]
.last_pulse_slot(ignore_detuned_delay=True)
.tf
)
except RuntimeError:
# There is no previous pulse
last_pulse_tf = 0
return _PhaseDriftParams(
drift_rate=-eom_settings.detuning_off,
ti=max(eom_settings.ti, last_pulse_tf),
)

def _to_dict(self, _module: str = "pulser.sequence") -> dict[str, Any]:
d = obj_to_dict(
self,
Expand Down
54 changes: 49 additions & 5 deletions tests/test_sequence.py
Original file line number Diff line number Diff line change
Expand Up @@ -1402,8 +1402,11 @@ def test_multiple_index_targets(reg):
assert built_seq._last("ch0").targets == {"q2", "q3"}


@pytest.mark.parametrize("correct_phase_drift", (True, False))
@pytest.mark.parametrize("custom_buffer_time", (None, 400))
def test_eom_mode(reg, mod_device, custom_buffer_time, patch_plt_show):
def test_eom_mode(
reg, mod_device, custom_buffer_time, correct_phase_drift, patch_plt_show
):
# Setting custom_buffer_time
channels = mod_device.channels
eom_config = dataclasses.replace(
Expand Down Expand Up @@ -1449,22 +1452,39 @@ def test_eom_mode(reg, mod_device, custom_buffer_time, patch_plt_show):
]

pulse_duration = 100
seq.add_eom_pulse("ch0", pulse_duration, phase=0.0)
seq.add_eom_pulse(
"ch0",
pulse_duration,
phase=0.0,
correct_phase_drift=correct_phase_drift,
)
first_pulse_slot = seq._schedule["ch0"].last_pulse_slot()
assert first_pulse_slot.ti == delay_slot.tf
assert first_pulse_slot.tf == first_pulse_slot.ti + pulse_duration
eom_pulse = Pulse.ConstantPulse(pulse_duration, amp_on, detuning_on, 0.0)
phase = detuning_off * first_pulse_slot.ti * 1e-3 * correct_phase_drift
eom_pulse = Pulse.ConstantPulse(pulse_duration, amp_on, detuning_on, phase)
assert first_pulse_slot.type == eom_pulse
assert not seq._schedule["ch0"].is_detuned_delay(eom_pulse)

# Check phase jump buffer
seq.add_eom_pulse("ch0", pulse_duration, phase=np.pi)
phase_ = np.pi
seq.add_eom_pulse(
"ch0",
pulse_duration,
phase=phase_,
correct_phase_drift=correct_phase_drift,
)
second_pulse_slot = seq._schedule["ch0"].last_pulse_slot()
phase_buffer = (
eom_pulse.fall_time(ch0_obj, in_eom_mode=True)
+ seq.declared_channels["ch0"].phase_jump_time
)
assert second_pulse_slot.ti == first_pulse_slot.tf + phase_buffer
# Corrects the phase acquired during the phase buffer
phase_ += detuning_off * phase_buffer * 1e-3 * correct_phase_drift
a-corni marked this conversation as resolved.
Show resolved Hide resolved
assert second_pulse_slot.type == Pulse.ConstantPulse(
pulse_duration, amp_on, detuning_on, phase_
)

# Check phase jump buffer is not enforced with "no-delay"
seq.add_eom_pulse("ch0", pulse_duration, phase=0.0, protocol="no-delay")
Expand Down Expand Up @@ -1495,8 +1515,15 @@ def test_eom_mode(reg, mod_device, custom_buffer_time, patch_plt_show):
)
assert buffer_delay.type == "delay"

assert seq.current_phase_ref("q0", basis="ground-rydberg") == 0
# Check buffer when EOM is not enabled at the start of the sequence
seq.enable_eom_mode("ch0", amp_on, detuning_on, optimal_detuning_off=-100)
seq.enable_eom_mode(
"ch0",
amp_on,
detuning_on,
optimal_detuning_off=-100,
correct_phase_drift=correct_phase_drift,
)
last_slot = seq._schedule["ch0"][-1]
assert len(seq._schedule["ch0"].eom_blocks) == 2
new_eom_block = seq._schedule["ch0"].eom_blocks[1]
Expand All @@ -1511,6 +1538,23 @@ def test_eom_mode(reg, mod_device, custom_buffer_time, patch_plt_show):
assert last_slot.type == Pulse.ConstantPulse(
duration, 0.0, new_eom_block.detuning_off, last_pulse_slot.type.phase
)
# Check the phase shift that corrects for the drift
phase_ref = (
(new_eom_block.detuning_off * duration * 1e-3)
% (2 * np.pi)
* correct_phase_drift
)
assert seq.current_phase_ref("q0", basis="ground-rydberg") == phase_ref

# Add delay to test the phase drift correction in disable_eom_mode
last_delay_time = 400
seq.delay(last_delay_time, "ch0")

seq.disable_eom_mode("ch0", correct_phase_drift=True)
phase_ref += new_eom_block.detuning_off * last_delay_time * 1e-3
assert seq.current_phase_ref("q0", basis="ground-rydberg") == phase_ref % (
2 * np.pi
)

# Test drawing in eom mode
seq.draw()
Expand Down