From 9e05982721283d909fe18e137c5e00ba3a95a1ff Mon Sep 17 00:00:00 2001 From: Antoine Cornillot <61453516+a-corni@users.noreply.github.com> Date: Fri, 22 Sep 2023 17:10:55 +0200 Subject: [PATCH] Adding register_is_from_calibrated_layout and is_calibrated_layout to Device (#586) * Adding register_is_from_calibrated_layout * Fix descr, testing type of register * Introducing any, modifying instance check --- pulser-core/pulser/devices/_device_datacls.py | 42 +++++++++++++++++++ tests/test_devices.py | 32 +++++++++++--- 2 files changed, 69 insertions(+), 5 deletions(-) diff --git a/pulser-core/pulser/devices/_device_datacls.py b/pulser-core/pulser/devices/_device_datacls.py index 2cef6d0b2..c5812006a 100644 --- a/pulser-core/pulser/devices/_device_datacls.py +++ b/pulser-core/pulser/devices/_device_datacls.py @@ -538,6 +538,48 @@ def calibrated_register_layouts(self) -> dict[str, RegisterLayout]: """Register layouts already calibrated on this device.""" return {str(layout): layout for layout in self.pre_calibrated_layouts} + def is_calibrated_layout(self, register_layout: RegisterLayout) -> bool: + """Checks whether a layout is within the calibrated layouts. + + Args: + register_layout: The RegisterLayout to check. + + Returns: + True if register_layout is found among calibrated_register_layouts, + False otherwise. + """ + return any( + [ + register_layout == layout + for layout in list(self.calibrated_register_layouts.values()) + ] + ) + + def register_is_from_calibrated_layout( + self, register: BaseRegister | MappableRegister + ) -> bool: + """Checks whether a register was constructed from a calibrated layout. + + If the register is a BaseRegister, checks that it has a layout. If so, + or if it is a MappableRegister, check that its layout is within the + calibrated layouts. + + Args: + register_layout: the Register or MappableRegister to check. + + Returns: + True if register has a layout and it is found among + calibrated_register_layouts, False otherwise. + """ + if not isinstance(register, (BaseRegister, MappableRegister)): + raise TypeError( + "The register to check must be of type " + "BaseRegister or MappableRegister." + ) + if isinstance(register, BaseRegister) and register.layout is None: + return False + return self.is_calibrated_layout(cast(RegisterLayout, register.layout)) + def to_virtual(self) -> VirtualDevice: """Converts the Device into a VirtualDevice.""" params = self._params() diff --git a/tests/test_devices.py b/tests/test_devices.py index d3d0c7422..54d8aa696 100644 --- a/tests/test_devices.py +++ b/tests/test_devices.py @@ -25,7 +25,10 @@ from pulser.devices import Chadoq2, Device, VirtualDevice from pulser.register import Register, Register3D from pulser.register.register_layout import RegisterLayout -from pulser.register.special_layouts import TriangularLatticeLayout +from pulser.register.special_layouts import ( + SquareLatticeLayout, + TriangularLatticeLayout, +) @pytest.fixture @@ -336,6 +339,8 @@ def test_calibrated_layouts(): pre_calibrated_layouts=(TriangularLatticeLayout(201, 3),), ) + layout100 = TriangularLatticeLayout(100, 6.8) + layout200 = TriangularLatticeLayout(200, 5) TestDevice = Device( name="TestDevice", dimensions=2, @@ -344,15 +349,32 @@ def test_calibrated_layouts(): max_radial_distance=50, min_atom_distance=4, channel_objects=(), - pre_calibrated_layouts=( - TriangularLatticeLayout(100, 6.8), - TriangularLatticeLayout(200, 5), - ), + pre_calibrated_layouts=(layout100, layout200), ) assert TestDevice.calibrated_register_layouts.keys() == { "TriangularLatticeLayout(100, 6.8µm)", "TriangularLatticeLayout(200, 5.0µm)", } + with pytest.raises( + TypeError, + match="The register to check must be of type ", + ): + TestDevice.register_is_from_calibrated_layout(layout100) + assert TestDevice.is_calibrated_layout(layout100) + register = layout200.define_register(*range(10)) + assert TestDevice.register_is_from_calibrated_layout(register) + # Checking a register not built from a layout returns False + assert not TestDevice.register_is_from_calibrated_layout( + Register.triangular_lattice(4, 25, 6.8) + ) + # Checking Layouts that don't match calibrated layouts returns False + square_layout = SquareLatticeLayout(10, 10, 6.8) + layout125 = TriangularLatticeLayout(125, 6.8) + compact_layout = TriangularLatticeLayout(100, 3) + for bad_layout in (square_layout, layout125, compact_layout): + assert not TestDevice.is_calibrated_layout(bad_layout) + register = bad_layout.define_register(*range(10)) + assert not TestDevice.register_is_from_calibrated_layout(register) def test_device_with_virtual_channel():