Skip to content

Commit

Permalink
Fixed tests [requires modification of pyqtorch]
Browse files Browse the repository at this point in the history
  • Loading branch information
vytautas-a committed Jun 26, 2024
1 parent b738d83 commit 8cc6dab
Show file tree
Hide file tree
Showing 4 changed files with 16 additions and 3 deletions.
11 changes: 9 additions & 2 deletions qadence/backends/pyqtorch/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ def run(
unpyqify_state: bool = True,
) -> Tensor:
n_qubits = circuit.abstract.n_qubits

orig_param_values = param_values.pop("orig_param_values", {})
if state is None:
# If no state is passed, we infer the batch_size through the length
# of the individual parameter value tensors.
Expand All @@ -105,6 +105,8 @@ def run(
validate_state(state, n_qubits)
# pyqtorch expects input shape [2] * n_qubits + [batch_size]
state = pyqify(state, n_qubits) if pyqify_state else state
if len(orig_param_values) != 0:
param_values.update({"orig_param_values": orig_param_values})
state = circuit.native.run(state=state, values=param_values)
state = unpyqify(state) if unpyqify_state else state
state = invert_endianness(state) if endianness != self.native_endianness else state
Expand Down Expand Up @@ -166,10 +168,12 @@ def _looped_expectation(
"Looping expectation does not make sense with batched initial state. "
"Define your initial state with `batch_size=1`"
)

orig_param_values = param_values.pop("orig_param_values", {})
list_expvals = []
observables = observable if isinstance(observable, list) else [observable]
for vals in to_list_of_dicts(param_values):
if len(orig_param_values) != 0:
vals.update({"orig_param_values": orig_param_values})
wf = self.run(circuit, vals, state, endianness, pyqify_state=True, unpyqify_state=False)
exs = torch.cat([obs.native(wf, vals) for obs in observables], 0)
list_expvals.append(exs)
Expand Down Expand Up @@ -216,11 +220,14 @@ def sample(
endianness: Endianness = Endianness.BIG,
pyqify_state: bool = True,
) -> list[Counter]:
orig_param_values = param_values.pop("orig_param_values", {})
if state is None:
state = circuit.native.init_state(batch_size=infer_batchsize(param_values))
elif state is not None and pyqify_state:
n_qubits = circuit.abstract.n_qubits
state = pyqify(state, n_qubits) if pyqify_state else state
if len(orig_param_values) != 0:
param_values.update({"orig_param_values": orig_param_values})
samples: list[Counter] = circuit.native.sample(
state=state, values=param_values, n_shots=n_shots
)
Expand Down
3 changes: 2 additions & 1 deletion qadence/backends/pyqtorch/convert_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -315,7 +315,7 @@ def forward(
state: Tensor,
values: dict[str, Tensor],
) -> Tensor:
if self.block.generator.is_time_dependent: # type: ignore [union-attr]
if getattr(self.block.generator, "is_time_dependent", False): # type: ignore [union-attr]

def Ht(t: Tensor | float) -> Tensor:
# values dict has to change with new value of t
Expand Down Expand Up @@ -363,6 +363,7 @@ def Ht(t: Tensor | float) -> Tensor:
sesolve(Ht, unpyqify(state).T[:, 0:1], tsave, self.config.ode_solver).states[-1].T
)
else:
values.pop("orig_param_values", {})
result = apply_operator(
state,
self.unitary(values),
Expand Down
1 change: 1 addition & 0 deletions tests/engines/test_torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,7 @@ def test_embeddings() -> None:

inputs = {"x": torch.ones(batch_size), "y": torch.rand(batch_size)}
low_level_params = embed(params, inputs)
low_level_params.pop("orig_param_values", {})

assert len(list(low_level_params.keys())) == 9

Expand Down
4 changes: 4 additions & 0 deletions tests/qadence/test_quantum_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,9 @@ def test_quantum_model_parameters(parametric_circuit: QuantumCircuit) -> None:
assert len([i for i in model_psr.parameters()]) == 4
assert len([i for i in model_ad.parameters()]) == 4
embedded_params_psr = model_psr.embedding_fn(model_psr._params, {"x": torch.rand(1)})
embedded_params_psr.pop("orig_param_values", {})
embedded_params_ad = model_ad.embedding_fn(model_ad._params, {"x": torch.rand(1)})
embedded_params_ad.pop("orig_param_values", {})
assert len(embedded_params_ad) == 5
assert len(embedded_params_psr) == 6

Expand All @@ -57,7 +59,9 @@ def test_quantum_model_duplicate_expr(duplicate_expression_circuit: QuantumCircu
assert len([i for i in model_psr.parameters()]) == 3
assert len([i for i in model_ad.parameters()]) == 3
embedded_params_psr = model_psr.embedding_fn(model_psr._params, {"x": torch.rand(1)})
embedded_params_psr.pop("orig_param_values", {})
embedded_params_ad = model_ad.embedding_fn(model_ad._params, {"x": torch.rand(1)})
embedded_params_ad.pop("orig_param_values", {})
assert len(embedded_params_ad) == 2
assert len(embedded_params_psr) == 8

Expand Down

0 comments on commit 8cc6dab

Please sign in to comment.