diff --git a/src/qibo/gates/abstract.py b/src/qibo/gates/abstract.py index 19ebb7ab1f..d1582d6243 100644 --- a/src/qibo/gates/abstract.py +++ b/src/qibo/gates/abstract.py @@ -1,4 +1,5 @@ import collections +import json from typing import List, Sequence, Tuple import sympy @@ -6,6 +7,9 @@ from qibo.backends import GlobalBackend from qibo.config import raise_error +REQUIRED_FIELDS = ["name", "init_kwargs", "_target_qubits", "_control_qubits"] +REQUIRED_FIELDS_INIT_KWARGS = ["theta", "phi", "lam"] + class Gate: """The base class for gate implementation. @@ -50,6 +54,24 @@ def __init__(self): self.device_gates = set() self.original_gate = None + def to_json(self): + encoded = self.__dict__ + + encoded_simple = { + key: value for key, value in encoded.items() if key in REQUIRED_FIELDS + } + + encoded_simple["init_kwargs"] = { + key: value + for key, value in encoded_simple["init_kwargs"].items() + if key in REQUIRED_FIELDS_INIT_KWARGS + } + + for value in encoded_simple: + if isinstance(encoded[value], set): + encoded_simple[value] = list(encoded_simple[value]) + return json.dumps(encoded_simple) + @property def target_qubits(self) -> Tuple[int]: """Tuple with ids of target qubits.""" diff --git a/tests/test_gates_abstract.py b/tests/test_gates_abstract.py index 133f103ca0..ab488c4831 100644 --- a/tests/test_gates_abstract.py +++ b/tests/test_gates_abstract.py @@ -1,4 +1,6 @@ """Tests methods defined in `qibo/gates/abstract.py` and `qibo/gates/gates.py`.""" +import json + import numpy as np import pytest @@ -14,6 +16,22 @@ def test_one_qubit_gates_init(gatename): assert gate.target_qubits == (0,) +@pytest.mark.parametrize( + "gatename", ["H", "X", "Y", "Z", "S", "SDG", "T", "TDG", "I", "Align"] +) +def test_one_qubit_gates_to_json(gatename): + gate = getattr(gates, gatename)(0) + + json_general = ( + '{"name": {}, "init_kwargs": {}, "_target_qubits": [0], "_control_qubits": []}' + ) + + json_gate = json.loads(json_general) + json_gate["name"] = gate.name + + assert gate.to_json() == json.dumps(json_gate) + + @pytest.mark.parametrize( "controls,instance", [((1,), "CNOT"), ((1, 2), "TOFFOLI"), ((1, 2, 4), "X")] ) @@ -93,6 +111,32 @@ def test_one_qubit_rotations_init(gatename, params): assert gate.parameters == params +@pytest.mark.parametrize( + "gatename,params", + [ + ("RX", (0.1234,)), + ("RY", (0.1234,)), + ("RZ", (0.1234,)), + ("U1", (0.1234,)), + ("U2", (0.1234, 0.4321)), + ("U3", (0.1234, 0.4321, 0.5678)), + ], +) +def test_one_qubit_rotations_to_json(gatename, params): + gate = getattr(gates, gatename)(0, *params) + + json_general = ( + '{"name": {}, "init_kwargs": {}, "_target_qubits": [0], "_control_qubits": []}' + ) + + json_gate = json.loads(json_general) + json_gate["name"] = gate.name + json_gate["init_kwargs"] = gate.init_kwargs + del json_gate["init_kwargs"]["trainable"] + + assert gate.to_json() == json.dumps(json_gate) + + @pytest.mark.parametrize( "gatename,params", [