Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Give access to all EOM block parameters and allow for phase drift correction #566

Merged
merged 8 commits into from
Aug 29, 2023
7 changes: 6 additions & 1 deletion pulser-core/pulser/json/abstract_repr/deserializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,6 +263,7 @@ def _deserialize_operation(seq: Sequence, op: dict, vars: dict) -> None:
optimal_detuning_off=_deserialize_parameter(
op["optimal_detuning_off"], vars
),
correct_phase_drift=op.get("correct_phase_drift", False),
)
elif op["op"] == "add_eom_pulse":
seq.add_eom_pulse(
Expand All @@ -273,9 +274,13 @@ def _deserialize_operation(seq: Sequence, op: dict, vars: dict) -> None:
op["post_phase_shift"], vars
),
protocol=op["protocol"],
correct_phase_drift=op.get("correct_phase_drift", False),
)
elif op["op"] == "disable_eom_mode":
seq.disable_eom_mode(channel=op["channel"])
seq.disable_eom_mode(
channel=op["channel"],
correct_phase_drift=op.get("correct_phase_drift", False),
)


def _deserialize_channel(obj: dict[str, Any]) -> Channel:
Expand Down
12 changes: 12 additions & 0 deletions pulser-core/pulser/json/abstract_repr/schemas/sequence-schema.json
Original file line number Diff line number Diff line change
Expand Up @@ -449,6 +449,10 @@
"$ref": "#/definitions/ChannelName",
"description": "The name of the channel to take out of EOM mode."
},
"correct_phase_drift": {
"description": "Performs a phase shift to correct for the phase drift that occured since the last pulse (or the start of the EOM mode, if no pulse was added).",
"type": "boolean"
},
"op": {
"const": "disable_eom_mode",
"type": "string"
Expand All @@ -467,6 +471,10 @@
"$ref": "#/definitions/ChannelName",
"description": "The name of the channel to add the pulse to."
},
"correct_phase_drift": {
"description": "Performs a phase shift to correct for the phase drift that occured since the last pulse (or the start of the EOM mode, if adding the first pulse).",
"type": "boolean"
},
"duration": {
"$ref": "#/definitions/ParametrizedNum",
"description": "The duration of the pulse (in ns)."
Expand Down Expand Up @@ -514,6 +522,10 @@
"$ref": "#/definitions/ChannelName",
"description": "The name of the channel to put in EOM mode."
},
"correct_phase_drift": {
"description": "Performs a phase shift to correct for the phase drift incurred while turning on the EOM mode.",
"type": "boolean"
},
"detuning_on": {
"$ref": "#/definitions/ParametrizedNum",
"description": "The detuning of the EOM pulses (in rad/µs)."
Expand Down
31 changes: 27 additions & 4 deletions pulser-core/pulser/json/abstract_repr/serializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,15 @@ def get_all_args(
}
return {**default_values, **params}

def remove_kwarg_if_default(
data: dict[str, Any], call_name: str, kwarg_name: str
) -> dict[str, Any]:
if data.get(kwarg_name, None) == get_kwarg_default(
call_name, kwarg_name
):
data.pop(kwarg_name, None)
return data

HGSilveri marked this conversation as resolved.
Show resolved Hide resolved
operations = res["operations"]
for call in chain(seq._calls, seq._to_build_calls):
if call.name == "__init__":
Expand Down Expand Up @@ -269,9 +278,18 @@ def get_all_args(
res["slm_mask_targets"] = tuple(seq._slm_mask_targets)
elif call.name == "enable_eom_mode":
data = get_all_args(
("channel", "amp_on", "detuning_on", "optimal_detuning_off"),
(
"channel",
"amp_on",
"detuning_on",
"optimal_detuning_off",
"correct_phase_drift",
),
call,
)
data = remove_kwarg_if_default(
data, call.name, "correct_phase_drift"
)
HGSilveri marked this conversation as resolved.
Show resolved Hide resolved
operations.append({"op": "enable_eom_mode", **data})
elif call.name == "add_eom_pulse":
data = get_all_args(
Expand All @@ -281,15 +299,20 @@ def get_all_args(
"phase",
"post_phase_shift",
"protocol",
"correct_phase_drift",
),
call,
)
data = remove_kwarg_if_default(
data, call.name, "correct_phase_drift"
)
operations.append({"op": "add_eom_pulse", **data})
elif call.name == "disable_eom_mode":
data = get_all_args(("channel",), call)
operations.append(
{"op": "disable_eom_mode", "channel": data["channel"]}
data = get_all_args(("channel", "correct_phase_drift"), call)
data = remove_kwarg_if_default(
data, call.name, "correct_phase_drift"
)
operations.append({"op": "disable_eom_mode", **data})
else:
raise AbstractReprError(f"Unknown call '{call.name}'.")

Expand Down
84 changes: 62 additions & 22 deletions tests/test_abstract_repr.py
Original file line number Diff line number Diff line change
Expand Up @@ -732,32 +732,48 @@ def test_mappable_register(self, triangular_lattice):
]
assert abstract["variables"]["var"] == dict(type="int", value=[0])

def test_eom_mode(self, triangular_lattice):
@pytest.mark.parametrize("correct_phase_drift", (False, True))
def test_eom_mode(self, triangular_lattice, correct_phase_drift):
reg = triangular_lattice.hexagonal_register(7)
seq = Sequence(reg, IroiseMVP)
seq.declare_channel("ryd", "rydberg_global")
det_off = seq.declare_variable("det_off", dtype=float)
duration = seq.declare_variable("duration", dtype=int)
seq.enable_eom_mode(
"ryd", amp_on=3.0, detuning_on=0.0, optimal_detuning_off=det_off
"ryd",
amp_on=3.0,
detuning_on=0.0,
optimal_detuning_off=det_off,
correct_phase_drift=correct_phase_drift,
)
seq.add_eom_pulse(
"ryd", duration, 0.0, correct_phase_drift=correct_phase_drift
)
seq.add_eom_pulse("ryd", duration, 0.0)
seq.delay(duration, "ryd")
seq.disable_eom_mode("ryd")
seq.disable_eom_mode("ryd", correct_phase_drift)

abstract = json.loads(seq.to_abstract_repr())
validate_schema(abstract)

extra_kwargs = (
dict(correct_phase_drift=correct_phase_drift)
if correct_phase_drift
else {}
)

assert abstract["operations"][0] == {
"op": "enable_eom_mode",
"channel": "ryd",
"amp_on": 3.0,
"detuning_on": 0.0,
"optimal_detuning_off": {
"expression": "index",
"lhs": {"variable": "det_off"},
"rhs": 0,
**{
"op": "enable_eom_mode",
"channel": "ryd",
"amp_on": 3.0,
"detuning_on": 0.0,
"optimal_detuning_off": {
"expression": "index",
"lhs": {"variable": "det_off"},
"rhs": 0,
},
},
**extra_kwargs,
}

ser_duration = {
Expand All @@ -766,17 +782,23 @@ def test_eom_mode(self, triangular_lattice):
"rhs": 0,
}
assert abstract["operations"][1] == {
"op": "add_eom_pulse",
"channel": "ryd",
"duration": ser_duration,
"phase": 0.0,
"post_phase_shift": 0.0,
"protocol": "min-delay",
**{
"op": "add_eom_pulse",
"channel": "ryd",
"duration": ser_duration,
"phase": 0.0,
"post_phase_shift": 0.0,
"protocol": "min-delay",
},
**extra_kwargs,
}

assert abstract["operations"][3] == {
"op": "disable_eom_mode",
"channel": "ryd",
**{
"op": "disable_eom_mode",
"channel": "ryd",
},
**extra_kwargs,
}

@pytest.mark.parametrize("use_default", [True, False])
Expand Down Expand Up @@ -877,6 +899,10 @@ def _check_roundtrip(serialized_seq: dict[str, Any]):
*(op[wf][qty] for qty in wf_args)
)
op[wf] = reconstructed_wf._to_abstract_repr()
elif "eom" in op["op"] and not op.get("correct_phase_drift"):
# Remove correct_phase_drift when at default, since the
# roundtrip will delete it
op.pop("correct_phase_drift", None)

seq = Sequence.from_abstract_repr(json.dumps(s))
defaults = {
Expand Down Expand Up @@ -1425,7 +1451,8 @@ def test_deserialize_parametrized_pulse(self, op, pulse_cls):
else:
assert pulse.kwargs["detuning"] == 1

def test_deserialize_eom_ops(self):
@pytest.mark.parametrize("correct_phase_drift", (False, True, None))
def test_deserialize_eom_ops(self, correct_phase_drift):
s = _get_serialized_seq(
operations=[
{
Expand All @@ -1434,6 +1461,7 @@ def test_deserialize_eom_ops(self):
"amp_on": 3.0,
"detuning_on": 0.0,
"optimal_detuning_off": -1.0,
"correct_phase_drift": correct_phase_drift,
},
{
"op": "add_eom_pulse",
Expand All @@ -1446,16 +1474,21 @@ def test_deserialize_eom_ops(self):
"phase": 0.0,
"post_phase_shift": 0.0,
"protocol": "no-delay",
"correct_phase_drift": correct_phase_drift,
},
{
"op": "disable_eom_mode",
"channel": "global",
"correct_phase_drift": correct_phase_drift,
},
],
variables={"duration": {"type": "int", "value": [100]}},
device=json.loads(IroiseMVP.to_abstract_repr()),
channels={"global": "rydberg_global"},
)
if correct_phase_drift is None:
for op in s["operations"]:
del op["correct_phase_drift"]
_check_roundtrip(s)
seq = Sequence.from_abstract_repr(json.dumps(s))
# init + declare_channel + enable_eom_mode
Expand All @@ -1470,11 +1503,15 @@ def test_deserialize_eom_ops(self):
"amp_on": 3.0,
"detuning_on": 0.0,
"optimal_detuning_off": -1.0,
"correct_phase_drift": bool(correct_phase_drift),
}

disable_eom_call = seq._to_build_calls[-1]
assert disable_eom_call.name == "disable_eom_mode"
assert disable_eom_call.kwargs == {"channel": "global"}
assert disable_eom_call.kwargs == {
"channel": "global",
"correct_phase_drift": bool(correct_phase_drift),
}

eom_pulse_call = seq._to_build_calls[0]
assert eom_pulse_call.name == "add_eom_pulse"
Expand All @@ -1483,6 +1520,9 @@ def test_deserialize_eom_ops(self):
assert eom_pulse_call.kwargs["phase"] == 0.0
assert eom_pulse_call.kwargs["post_phase_shift"] == 0.0
assert eom_pulse_call.kwargs["protocol"] == "no-delay"
assert eom_pulse_call.kwargs["correct_phase_drift"] == bool(
correct_phase_drift
)

@pytest.mark.parametrize(
"wf_obj",
Expand Down