Skip to content

Commit

Permalink
Implement more flexible evaluation time matching
Browse files Browse the repository at this point in the history
  • Loading branch information
HGSilveri committed Nov 18, 2024
1 parent 0310199 commit 694c3d3
Show file tree
Hide file tree
Showing 3 changed files with 47 additions and 9 deletions.
41 changes: 34 additions & 7 deletions pulser-core/pulser/backend/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
Sequence,
SupportsFloat,
TypeVar,
cast,
get_args,
)

Expand All @@ -41,10 +42,13 @@
class BackendConfig:
"""The base backend configuration."""

_backend_options: dict[str, Any]

def __init__(self, **backend_options: Any) -> None:
"""Initializes the backend config."""
object.__setattr__(self, "_backend_options", backend_options)
self._backend_options = backend_options
# TODO: Deprecate use of backend_options kwarg
# TODO: Filter for accepted kwargs

def __getattr__(self, name: str) -> Any:
if (
Expand All @@ -61,7 +65,7 @@ class EmulationConfig(BackendConfig, Generic[StateType]):

# TODO: Complete docstring
observables: Sequence[Observable]
default_evaluation_times: tuple[float]
default_evaluation_times: np.ndarray | Literal["Full"]
initial_state: StateType | None
with_modulation: bool
interaction_matrix: pm.AbstractArray | None
Expand All @@ -73,7 +77,9 @@ def __init__(
*,
observables: Sequence[Observable] = (),
# Default evaluation times for observables that don't specify one
default_evaluation_times: Sequence[SupportsFloat] = (1.0,),
default_evaluation_times: Sequence[SupportsFloat] | Literal["Full"] = (
1.0,
),
initial_state: StateType | None = None, # Default is ggg...
with_modulation: bool = False,
interaction_matrix: ArrayLike | None = None,
Expand All @@ -89,10 +95,13 @@ def __init__(
f"Observable. Instead, got instance of type {type(obs)}."
)

default_evaluation_times = tuple(map(float, default_evaluation_times))
eval_times_arr = np.array(default_evaluation_times, dtype=float)
if np.any((eval_times_arr < 0.0) | (eval_times_arr > 1.0)):
raise ValueError("All evaluation times must be between 0. and 1.")
if default_evaluation_times != "Full":
eval_times_arr = np.array(default_evaluation_times, dtype=float)
if np.any((eval_times_arr < 0.0) | (eval_times_arr > 1.0)):
raise ValueError(
"All evaluation times must be between 0. and 1."
)
default_evaluation_times = cast(Sequence[float], eval_times_arr)

if initial_state is not None and not isinstance(initial_state, State):
raise TypeError(
Expand All @@ -119,6 +128,24 @@ def __init__(
**backend_options,
)

def is_evaluation_time(self, t: float, tol: float = 1e-6) -> bool:
"""Assesses whether a relative time is an evaluation time."""
return 0.0 <= t <= 1.0 and (
self.default_evaluation_times == "Full"
or self.is_time_in_evaluation_times(
t, self.default_evaluation_times, tol=tol
)
)

@staticmethod
def is_time_in_evaluation_times(
t: float, evaluation_times: ArrayLike, tol: float = 1e-6
) -> bool:
"""Checks if a time is within a collection of evaluation times."""
return bool(
np.any(np.abs(np.array(evaluation_times, dtype=float) - t) <= tol)
)


# Legacy class

Expand Down
13 changes: 12 additions & 1 deletion pulser-core/pulser/backend/observable.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ def __call__(
state: State,
hamiltonian: Operator,
result: Results,
time_tol: float,
) -> None:
"""Specifies a call to the callback at a specific time.
Expand All @@ -52,6 +53,8 @@ def __call__(
state: The current state.
hamiltonian: The Hamiltonian at this time.
result: The Results object to store the result in.
time_tol: Tolerance below which two time values are considered
equal.
"""
pass

Expand All @@ -76,6 +79,7 @@ def __call__(
state: State,
hamiltonian: Operator,
result: Results,
time_tol: float,
) -> None:
"""Specifies a call to the observable at a specific time.
Expand All @@ -91,8 +95,15 @@ def __call__(
state: The current state.
hamiltonian: The Hamiltonian at this time.
result: The Results object to store the result in.
time_tol: Tolerance below which two time values are considered
equal.
"""
if t in (self.evaluation_times or config.default_evaluation_times):
if (
self.evaluation_times is not None
and config.is_time_in_evaluation_times(
t, self.evaluation_times, tol=time_tol
)
) or config.is_evaluation_time(t, tol=time_tol):
value_to_store = self.apply(
config=config, t=t, state=state, hamiltonian=hamiltonian
)
Expand Down
2 changes: 1 addition & 1 deletion pulser-core/pulser/backend/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,8 @@ class State(ABC, Generic[ArgScalarType, ReturnScalarType]):

eigenstates: Sequence[Eigenstate]

@abstractmethod
@property
@abstractmethod
def n_qudits(self) -> int:
"""The number of qudits in the state."""
pass
Expand Down

0 comments on commit 694c3d3

Please sign in to comment.