Fixed tests
vytautas-a committed Jun 28, 2024
1 parent a93a59c commit d4de9f0
Showing 3 changed files with 2 additions and 11 deletions.
9 changes: 0 additions & 9 deletions qadence/backends/pyqtorch/backend.py
@@ -96,7 +96,6 @@ def run(
         unpyqify_state: bool = True,
     ) -> Tensor:
         n_qubits = circuit.abstract.n_qubits
-        orig_param_values = param_values.pop("orig_param_values", {})
         if state is None:
             # If no state is passed, we infer the batch_size through the length
             # of the individual parameter value tensors.
@@ -105,8 +104,6 @@ def run(
         validate_state(state, n_qubits)
         # pyqtorch expects input shape [2] * n_qubits + [batch_size]
         state = pyqify(state, n_qubits) if pyqify_state else state
-        if len(orig_param_values) != 0:
-            param_values.update({"orig_param_values": orig_param_values})
         state = circuit.native.run(state=state, values=param_values)
         state = unpyqify(state) if unpyqify_state else state
         state = invert_endianness(state) if endianness != self.native_endianness else state
@@ -168,12 +165,9 @@ def _looped_expectation(
                 "Looping expectation does not make sense with batched initial state. "
                 "Define your initial state with `batch_size=1`"
             )
-        orig_param_values = param_values.pop("orig_param_values", {})
         list_expvals = []
         observables = observable if isinstance(observable, list) else [observable]
         for vals in to_list_of_dicts(param_values):
-            if len(orig_param_values) != 0:
-                vals.update({"orig_param_values": orig_param_values})
             wf = self.run(circuit, vals, state, endianness, pyqify_state=True, unpyqify_state=False)
             exs = torch.cat([obs.native(wf, vals) for obs in observables], 0)
             list_expvals.append(exs)
@@ -220,14 +214,11 @@ def sample(
         endianness: Endianness = Endianness.BIG,
         pyqify_state: bool = True,
     ) -> list[Counter]:
-        orig_param_values = param_values.pop("orig_param_values", {})
         if state is None:
             state = circuit.native.init_state(batch_size=infer_batchsize(param_values))
         elif state is not None and pyqify_state:
             n_qubits = circuit.abstract.n_qubits
             state = pyqify(state, n_qubits) if pyqify_state else state
-        if len(orig_param_values) != 0:
-            param_values.update({"orig_param_values": orig_param_values})
         samples: list[Counter] = circuit.native.sample(
             state=state, values=param_values, n_shots=n_shots
         )
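
All three hunks above delete the same workaround: `run`, `_looped_expectation`, and `sample` each popped a non-Tensor `"orig_param_values"` entry out of `param_values` before inferring batch sizes, then re-inserted it for pyqtorch. A minimal runnable sketch of that pattern (illustrative names and values, not qadence code):

    import torch

    # The non-Tensor entry is popped so batch-size inference sees Tensors only,
    # then restored afterwards; dict.pop with a default never raises.
    param_values = {"theta": torch.rand(2, 1), "orig_param_values": {"theta": 0.1}}
    orig_param_values = param_values.pop("orig_param_values", {})
    batch_size = max(p.size()[0] for p in param_values.values())  # safe: Tensors only
    if len(orig_param_values) != 0:
        param_values.update({"orig_param_values": orig_param_values})
    assert batch_size == 2 and "orig_param_values" in param_values

This commit drops the mechanism entirely; the `to_list_of_dicts` hunk in utils.py below tolerates non-Tensor entries directly instead.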
1 change: 0 additions & 1 deletion qadence/backends/pyqtorch/convert_ops.py
@@ -363,7 +363,6 @@ def Ht(t: Tensor | float) -> Tensor:
                 sesolve(Ht, unpyqify(state).T[:, 0:1], tsave, self.config.ode_solver).states[-1].T
             )
         else:
-            values.pop("orig_param_values", {})
             result = apply_operator(
                 state,
                 self.unitary(values),
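
Since the backend no longer injects `"orig_param_values"`, this defensive `pop` before building the unitary has nothing left to strip. `dict.pop` with a default is a no-op when the key is absent, so removing the line is behavior-preserving once the key is never added (toy illustration, not qadence code):

    values = {"theta": 0.5}
    leftover = values.pop("orig_param_values", {})  # key absent: pure no-op
    assert leftover == {} and values == {"theta": 0.5}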
3 changes: 2 additions & 1 deletion qadence/backends/utils.py
@@ -98,10 +98,11 @@ def to_list_of_dicts(param_values: ParamDictType) -> list[ParamDictType]:
     if not param_values:
         return [param_values]
 
-    max_batch_size = max(p.size()[0] for p in param_values.values())
+    max_batch_size = max(p.size()[0] for p in param_values.values() if isinstance(p, Tensor))
     batched_values = {
         k: (v if v.size()[0] == max_batch_size else v.repeat(max_batch_size, 1))
         for k, v in param_values.items()
+        if isinstance(v, Tensor)
     }
 
     return [{k: v[i] for k, v in batched_values.items()} for i in range(max_batch_size)]
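
The patched `to_list_of_dicts` now ignores non-Tensor entries both when computing the maximum batch size and when building the per-sample dictionaries. A self-contained sketch of the new behavior (hypothetical values; `ParamDictType` is treated as a plain `dict` here):

    import torch

    def to_list_of_dicts(param_values: dict) -> list[dict]:
        # Patched logic: only Tensor values participate in batching;
        # non-Tensor entries are silently dropped from the output dicts.
        if not param_values:
            return [param_values]
        max_batch_size = max(
            p.size()[0] for p in param_values.values() if isinstance(p, torch.Tensor)
        )
        batched = {
            k: (v if v.size()[0] == max_batch_size else v.repeat(max_batch_size, 1))
            for k, v in param_values.items()
            if isinstance(v, torch.Tensor)
        }
        return [{k: v[i] for k, v in batched.items()} for i in range(max_batch_size)]

    # Before the fix, a dict-valued entry crashed max(...) with
    # AttributeError: 'dict' object has no attribute 'size'.
    values = {"theta": torch.rand(4, 1), "phi": torch.rand(1, 1), "orig_param_values": {}}
    assert len(to_list_of_dicts(values)) == 4
    assert "orig_param_values" not in to_list_of_dicts(values)[0]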
