From 4601d6f105ad1f935dd8ddc99594bffc788c7100 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Henrique=20Silv=C3=A9rio?= <29920212+HGSilveri@users.noreply.github.com> Date: Thu, 11 Jan 2024 10:22:58 +0100 Subject: [PATCH] FIX: Coordinate matching in `WeightMap.get_qubit_weight_map()` (#631) --- pulser-core/pulser/register/weight_maps.py | 9 +++++---- tests/test_dmm.py | 14 ++++++++++++++ 2 files changed, 19 insertions(+), 4 deletions(-) diff --git a/pulser-core/pulser/register/weight_maps.py b/pulser-core/pulser/register/weight_maps.py index c5ec38746..a2d0e446b 100644 --- a/pulser-core/pulser/register/weight_maps.py +++ b/pulser-core/pulser/register/weight_maps.py @@ -79,11 +79,12 @@ def get_qubit_weight_map( coords_arr = self.sorted_coords weights_arr = self.sorted_weights for qid, pos in qubits.items(): - dists = np.round( - np.linalg.norm(coords_arr - np.array(pos), axis=1), - decimals=COORD_PRECISION, + matches = np.argwhere( + np.all( + np.isclose(coords_arr, pos, atol=10 ** (-COORD_PRECISION)), + axis=1, + ) ) - matches = np.argwhere(dists == 0.0) qubit_weight_map[qid] = float(np.sum(weights_arr[matches])) return qubit_weight_map diff --git a/tests/test_dmm.py b/tests/test_dmm.py index fe0f8fba8..dad3d1aaf 100644 --- a/tests/test_dmm.py +++ b/tests/test_dmm.py @@ -20,6 +20,7 @@ import numpy as np import pytest +import pulser from pulser.channels.dmm import DMM from pulser.pulse import Pulse from pulser.register.base_register import BaseRegister @@ -106,6 +107,19 @@ def test_qubit_weight_map(self, register): 2: 0.0, } + tri_layout = TriangularLatticeLayout(100, spacing=5) + sites = [31, 53, 39, 62, 43, 49, 42, 37, 48, 44, 55, 50] + labels = ["a", "b", "c", "d", "e", "f", "g", "h", "i", "j", "k", "l"] + weights = [0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0] + + positions = np.array([tri_layout.traps_dict[i] for i in sites]) + reg = pulser.Register.from_coordinates( + positions, labels=labels, center=True + ) + det_map = reg.define_detuning_map(dict(zip(labels, weights))) + qubit_weight_map = det_map.get_qubit_weight_map(reg.qubits) + assert qubit_weight_map == dict(zip(labels, weights)) + def test_hash(self, det_map, det_dict, layout): disordered_det_dict = { i: det_dict[i] for i in sorted(det_dict, reverse=True)