Skip to content

Commit

Permalink
[Fix] Handle nested parameter dicts (#482)
Browse files Browse the repository at this point in the history
  • Loading branch information
dominikandreasseitz authored Jun 27, 2024
1 parent 9a5e038 commit 3a0bb75
Show file tree
Hide file tree
Showing 4 changed files with 19 additions and 12 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ dependencies = [
"jsonschema",
"nevergrad",
"scipy",
"pyqtorch==1.2.3",
"pyqtorch==1.2.4",
"pyyaml",
"matplotlib",
"Arpeggio==2.0.2",
Expand Down
2 changes: 1 addition & 1 deletion qadence/backends/pyqtorch/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from qadence.backend import Backend as BackendInterface
from qadence.backend import ConvertedCircuit, ConvertedObservable
from qadence.backends.utils import (
infer_batchsize,
pyqify,
to_list_of_dicts,
unpyqify,
Expand All @@ -31,7 +32,6 @@
transpile,
)
from qadence.types import BackendName, Endianness, Engine
from qadence.utils import infer_batchsize

from .config import Configuration, default_passes
from .convert_ops import convert_block
Expand Down
17 changes: 15 additions & 2 deletions qadence/backends/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,9 +143,22 @@ def validate_state(state: Tensor, n_qubits: int) -> None:
)


def infer_batchsize(param_values: ParamDictType = None) -> int:
def infer_batchsize(param_values: dict[str, Tensor] = None) -> int:
"""Infer the batch_size through the length of the parameter tensors."""
return max([len(tensor) for tensor in param_values.values()]) if param_values else 1
try:
return (
max(
[
len(tensor_or_dict)
for tensor_or_dict in param_values.values()
if isinstance(tensor_or_dict, Tensor)
]
)
if param_values
else 1
)
except Exception:
return 1


# The following functions can be used to compute potentially higher order gradients using pyqtorch's
Expand Down
10 changes: 2 additions & 8 deletions qadence/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,18 +234,12 @@ def is_qadence_shape(state: ArrayLike, n_qubits: int) -> bool:
return state.shape[1] == 2**n_qubits # type: ignore[no-any-return]


def infer_batchsize(param_values: dict[str, Tensor] = None) -> int:
"""Infer the batch_size through the length of the parameter tensors."""
try:
return max([len(tensor) for tensor in param_values.values()]) if param_values else 1
except Exception:
return 1


def validate_values_and_state(
state: ArrayLike | None, n_qubits: int, param_values: dict[str, Tensor] = None
) -> None:
if state is not None:
from qadence.backends.utils import infer_batchsize

if isinstance(state, Tensor):
if state is not None:
batch_size_state = (
Expand Down

0 comments on commit 3a0bb75

Please sign in to comment.