diff --git a/src/qibo/backends/pytorch.py b/src/qibo/backends/pytorch.py index 403cea10da..a8233d4e3b 100644 --- a/src/qibo/backends/pytorch.py +++ b/src/qibo/backends/pytorch.py @@ -1,36 +1,24 @@ """PyTorch backend.""" -from typing import Union - import numpy as np -import torch from qibo import __version__ from qibo.backends.npmatrices import NumpyMatrices from qibo.backends.numpy import NumpyBackend -torch_dtype_dict = { - "int": torch.int32, - "float": torch.float32, - "complex": torch.complex64, - "int32": torch.int32, - "int64": torch.int64, - "float32": torch.float32, - "float64": torch.float64, - "complex64": torch.complex64, - "complex128": torch.complex128, -} - class TorchMatrices(NumpyMatrices): """Matrix representation of every gate as a torch Tensor.""" def __init__(self, dtype): + import torch # pylint: disable=import-outside-toplevel + super().__init__(dtype) - self.dtype = torch_dtype_dict[dtype] + self.torch = torch + self.dtype = dtype def _cast(self, x, dtype): - return torch.as_tensor(x, dtype=dtype) + return self.torch.as_tensor(x, dtype=dtype) def Unitary(self, u): return self._cast(u, dtype=self.dtype) @@ -39,34 +27,41 @@ def Unitary(self, u): class PyTorchBackend(NumpyBackend): def __init__(self): super().__init__() + import torch # pylint: disable=import-outside-toplevel + + self.np = torch self.name = "pytorch" self.versions = { "qibo": __version__, "numpy": np.__version__, - "torch": torch.__version__, + "torch": self.np.__version__, } + self.dtype = self._torch_dtype(self.dtype) self.matrices = TorchMatrices(self.dtype) - self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + self.device = self.np.device("cuda:0" if torch.cuda.is_available() else "cpu") self.nthreads = 0 - self.np = torch - self.dtype = torch_dtype_dict[self.dtype] self.tensor_types = (self.np.Tensor, np.ndarray) # These functions in Torch works in a different way than numpy or have different names - self.np.transpose = torch.permute + self.np.transpose = self.np.permute self.np.expand_dims = self.np.unsqueeze - self.np.mod = torch.remainder - self.np.right_shift = torch.bitwise_right_shift + self.np.mod = self.np.remainder + self.np.right_shift = self.np.bitwise_right_shift + + def _torch_dtype(self, dtype): + if dtype == "float": + dtype += "32" + return getattr(self.np, dtype) def set_device(self, device): # pragma: no cover self.device = device def cast( self, - x: Union[torch.Tensor, list[torch.Tensor], np.ndarray, list[np.ndarray]], - dtype: Union[str, torch.dtype, np.dtype, type] = None, + x, + dtype=None, copy: bool = False, ): """Casts input as a Torch tensor of the specified dtype. @@ -86,9 +81,9 @@ def cast( if dtype is None: dtype = self.dtype elif isinstance(dtype, type): - dtype = torch_dtype_dict[dtype.__name__] - elif not isinstance(dtype, torch.dtype): - dtype = torch_dtype_dict[str(dtype)] + dtype = self._torch_dtype(dtype.__name__) + elif not isinstance(dtype, self.np.dtype): + dtype = self._torch_dtype(str(dtype)) if isinstance(x, self.np.Tensor): x = x.to(dtype)