Skip to content

Commit

Permalink
Finish up changes in the config base classes
Browse files Browse the repository at this point in the history
  • Loading branch information
HGSilveri committed Nov 19, 2024
1 parent 694c3d3 commit 6e49225
Show file tree
Hide file tree
Showing 2 changed files with 104 additions and 9 deletions.
96 changes: 88 additions & 8 deletions pulser-core/pulser/backend/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
"""Defines the backend configuration classes."""
from __future__ import annotations

import copy
import warnings
from dataclasses import dataclass, field
from typing import (
Any,
Expand Down Expand Up @@ -46,9 +48,33 @@ class BackendConfig:

def __init__(self, **backend_options: Any) -> None:
"""Initializes the backend config."""
self._backend_options = backend_options
# TODO: Deprecate use of backend_options kwarg
# TODO: Filter for accepted kwargs
cls_name = self.__class__.__name__
if invalid_kwargs := (
set(backend_options)
- (self._expected_kwargs() | {"backend_options"})
):
warnings.warn(
f"{cls_name!r} received unexpected keyword arguments: "
f"{invalid_kwargs}; only the following keyword "
f"arguments are expected: {self._expected_kwargs()}.",
stacklevel=2,
)
# Prevents potential issues with mutable arguments
self._backend_options = copy.deepcopy(backend_options)
if "backend_options" in backend_options:
with warnings.catch_warnings():
warnings.filterwarnings("always")
warnings.warn(
f"The 'backend_options' argument of {cls_name!r} "
"has been deprecated. Please provide the options "
f"as keyword arguments directly to '{cls_name}()'.",
DeprecationWarning,
stacklevel=2,
)
self._backend_options.update(backend_options["backend_options"])

def _expected_kwargs(self) -> set[str]:
return set()

def __getattr__(self, name: str) -> Any:
if (
Expand All @@ -57,13 +83,40 @@ def __getattr__(self, name: str) -> Any:
and name in self._backend_options
):
return self._backend_options[name]
raise AttributeError # TODO:
raise AttributeError(f"{name!r} has not been passed to {self!r}.")


class EmulationConfig(BackendConfig, Generic[StateType]):
"""Configurates an emulation on a backend."""
"""Configures an emulation on a backend.
Args:
observables: A sequence of observables to compute at specific
evaluation times. The observables without specified evaluation
times will use this configuration's 'default_evaluation_times'.
default_evaluation_times: The default times at which observables
are computed. Can be a sequence of relative times between 0
(the start of the sequence) and 1 (the end of the sequence).
Can also be specified as "Full", in which case every step in the
emulation will also be an evaluation times.
initial_state: The initial state from which emulation starts. If
specified, the state type needs to be compatible with the emulator
backend. If left undefined, defaults to starting with all qudits
in the ground state.
with_modulation: Whether to emulate the sequence with the programmed
input or the expected output.
interaction_matrix: An optional interaction matrix to replace the
interaction terms in the Hamiltonian. For an N-qudit system,
must be an NxN symmetric matrix where entry (i, j) dictates
the interaction coefficient between qudits i and j, ie it replaces
the C/r_{ij}^6.
prefer_device_noise_model: If the sequence's device has a default noise
model, this option signals the backend to prefer it over the noise
model given with this configuration.
noise_model: An optional noise model to emulate the sequence with.
Ignored if the sequence's device has default noise model and
`prefer_device_noise_model=True`.
"""

# TODO: Complete docstring
observables: Sequence[Observable]
default_evaluation_times: np.ndarray | Literal["Full"]
initial_state: StateType | None
Expand Down Expand Up @@ -109,7 +162,23 @@ def __init__(
f" got object of type {type(initial_state)} instead."
)

# TODO: Validate interaction matrix
if interaction_matrix is not None:
interaction_matrix = pm.AbstractArray(interaction_matrix)
_shape = interaction_matrix.shape
if len(_shape) != 2 or _shape[0] != _shape[1]:
raise ValueError(
"'interaction_matrix' must be a square matrix. Instead, "
f"an array of shape {_shape} was given."
)
if (
initial_state is not None
and _shape[0] != initial_state.n_qudits
):
raise ValueError(
f"The received interaction matrix of shape {_shape} is "
"incompatible with the received initial state of "
f"{initial_state.n_qudits} qudits."
)

if not isinstance(noise_model, NoiseModel):
raise TypeError(
Expand All @@ -128,6 +197,17 @@ def __init__(
**backend_options,
)

def _expected_kwargs(self) -> set[str]:
return super()._expected_kwargs() | {
"observables",
"default_evaluation_times",
"initial_state",
"with_modulation",
"interaction_matrix",
"prefer_device_noise_model",
"noise_model",
}

def is_evaluation_time(self, t: float, tol: float = 1e-6) -> bool:
"""Assesses whether a relative time is an evaluation time."""
return 0.0 <= t <= 1.0 and (
Expand Down Expand Up @@ -198,7 +278,7 @@ class EmulatorConfig(BackendConfig):
noise_model: NoiseModel = field(default_factory=NoiseModel)

def __post_init__(self) -> None:
# TODO: Raise deprecation warning
# TODO: Deprecate
if not (0 < self.sampling_rate <= 1.0):
raise ValueError(
"The sampling rate (`sampling_rate` = "
Expand Down
17 changes: 16 additions & 1 deletion pulser-core/pulser/backend/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,14 +35,29 @@ class State(ABC, Generic[ArgScalarType, ReturnScalarType]):
methods below.
"""

eigenstates: Sequence[Eigenstate]
_eigenstates: Sequence[Eigenstate]

@property
@abstractmethod
def n_qudits(self) -> int:
"""The number of qudits in the state."""
pass

@property
def eigenstates(self) -> tuple[Eigenstate, ...]:
"""The eigenstates that form a qudit's eigenbasis.
The order of the states should match the order in a
numerical (ie state vector or density matrix)
representation.
"""
return tuple(self._eigenstates)

@property
def qudit_dim(self) -> int:
"""The dimensions (ie number of eigenstates) of a qudit."""
return len(self.eigenstates)

@abstractmethod
def overlap(self: StateType, other: StateType, /) -> ReturnScalarType:
"""Compute the overlap between this state and another of the same type.
Expand Down

0 comments on commit 6e49225

Please sign in to comment.