Skip to content

Commit

Permalink
merge master
Browse files Browse the repository at this point in the history
  • Loading branch information
Simone-Bordoni committed Sep 16, 2024
1 parent b02222c commit 364acbf
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 2 deletions.
38 changes: 38 additions & 0 deletions doc/source/code-examples/test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
import qibo

qibo.set_backend("pytorch")
import torch

from qibo import gates, models

torch.set_anomaly_enabled(True)

# Optimization parameters
nepochs = 1
optimizer = torch.optim.Adam
target_state = torch.ones(4, dtype=torch.complex128) / 2.0

# Define circuit ansatz
params = torch.rand(2, dtype=torch.float64, requires_grad=True)
print(params)
optimizer = optimizer([params])
c = models.Circuit(2)
c.add(gates.RX(0, params[0]))
c.add(gates.RY(1, params[1]))
gate = gates.RY(0, params[1])

print("Gate", gate.matrix())
print(torch.norm(gate.matrix()).grad)

# for _ in range(nepochs):
# optimizer.zero_grad()
# c.set_parameters(params)
# final_state = c().state()
# print("state", final_state)
# fidelity = torch.abs(torch.sum(torch.conj(target_state) * final_state))
# loss = 1 - fidelity
# loss.backward()
# optimizer.step()
# print("state", final_state)
# print("params", params)
# print("loss", loss.grad)
6 changes: 4 additions & 2 deletions src/qibo/backends/pytorch.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
"""PyTorch backend."""

from typing import Optional

import numpy as np

from qibo import __version__
Expand Down Expand Up @@ -85,7 +87,7 @@ def cast(
x,
dtype=None,
copy: bool = False,
requires_grad: bool = None,
requires_grad: Optional[bool] = None,
):
"""Casts input as a Torch tensor of the specified dtype.
Expand Down Expand Up @@ -117,7 +119,6 @@ def cast(
# check if dtype is an integer to remove gradients
if dtype in [self.np.int32, self.np.int64, self.np.int8, self.np.int16]:
requires_grad = False

if isinstance(x, self.np.Tensor):
x = x.to(dtype)
elif isinstance(x, list) and all(isinstance(row, self.np.Tensor) for row in x):
Expand All @@ -128,6 +129,7 @@ def cast(
if copy:
return x.clone()

print("Casting", x)
return x

def is_sparse(self, x):
Expand Down

0 comments on commit 364acbf

Please sign in to comment.