diff --git a/pyproject.toml b/pyproject.toml index 88f8181c3..21ef2bf96 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -45,7 +45,7 @@ dependencies = [ "jsonschema", "nevergrad", "scipy", - "pyqtorch==1.2.3", + "pyqtorch==1.2.4", "pyyaml", "matplotlib", "Arpeggio==2.0.2", diff --git a/qadence/backends/pyqtorch/backend.py b/qadence/backends/pyqtorch/backend.py index 9dbf5c535..5330dc03b 100644 --- a/qadence/backends/pyqtorch/backend.py +++ b/qadence/backends/pyqtorch/backend.py @@ -12,6 +12,7 @@ from qadence.backend import Backend as BackendInterface from qadence.backend import ConvertedCircuit, ConvertedObservable from qadence.backends.utils import ( + infer_batchsize, pyqify, to_list_of_dicts, unpyqify, @@ -31,7 +32,6 @@ transpile, ) from qadence.types import BackendName, Endianness, Engine -from qadence.utils import infer_batchsize from .config import Configuration, default_passes from .convert_ops import convert_block diff --git a/qadence/backends/utils.py b/qadence/backends/utils.py index 311c2a980..0fe8702a1 100644 --- a/qadence/backends/utils.py +++ b/qadence/backends/utils.py @@ -143,9 +143,22 @@ def validate_state(state: Tensor, n_qubits: int) -> None: ) -def infer_batchsize(param_values: ParamDictType = None) -> int: +def infer_batchsize(param_values: dict[str, Tensor] = None) -> int: """Infer the batch_size through the length of the parameter tensors.""" - return max([len(tensor) for tensor in param_values.values()]) if param_values else 1 + try: + return ( + max( + [ + len(tensor_or_dict) + for tensor_or_dict in param_values.values() + if isinstance(tensor_or_dict, Tensor) + ] + ) + if param_values + else 1 + ) + except Exception: + return 1 # The following functions can be used to compute potentially higher order gradients using pyqtorch's diff --git a/qadence/utils.py b/qadence/utils.py index 87dee9348..99215c4e1 100644 --- a/qadence/utils.py +++ b/qadence/utils.py @@ -234,18 +234,12 @@ def is_qadence_shape(state: ArrayLike, n_qubits: int) -> bool: return state.shape[1] == 2**n_qubits # type: ignore[no-any-return] -def infer_batchsize(param_values: dict[str, Tensor] = None) -> int: - """Infer the batch_size through the length of the parameter tensors.""" - try: - return max([len(tensor) for tensor in param_values.values()]) if param_values else 1 - except Exception: - return 1 - - def validate_values_and_state( state: ArrayLike | None, n_qubits: int, param_values: dict[str, Tensor] = None ) -> None: if state is not None: + from qadence.backends.utils import infer_batchsize + if isinstance(state, Tensor): if state is not None: batch_size_state = (