From eea0cc08333e1609b496d5ec29241b576a5c6a32 Mon Sep 17 00:00:00 2001
From: Dario Coscia
Date: Tue, 3 Sep 2024 16:23:14 +0200
Subject: [PATCH] Backpropagation and fix test for OrthogonalBlock

Co-authored-by: Dario Coscia
Co-authored-by: Gabriele Codega
---
 pina/model/layers/__init__.py        |   2 +
 pina/model/layers/orthogonal.py      | 103 +++++++++++++++++++++------
 tests/test_layers/test_orthogonal.py |  45 ++++++++++--
 3 files changed, 125 insertions(+), 25 deletions(-)

diff --git a/pina/model/layers/__init__.py b/pina/model/layers/__init__.py
index 898ca43b..5108522c 100644
--- a/pina/model/layers/__init__.py
+++ b/pina/model/layers/__init__.py
@@ -9,6 +9,7 @@
     "FourierBlock2D",
     "FourierBlock3D",
     "PODBlock",
+    "OrthogonalBlock",
     "PeriodicBoundaryEmbedding",
     "FourierFeatureEmbedding",
     "AVNOBlock",
@@ -25,6 +26,7 @@
 )
 from .fourier import FourierBlock1D, FourierBlock2D, FourierBlock3D
 from .pod import PODBlock
+from .orthogonal import OrthogonalBlock
 from .embedding import PeriodicBoundaryEmbedding, FourierFeatureEmbedding
 from .avno_layer import AVNOBlock
 from .lowrank_layer import LowRankBlock
diff --git a/pina/model/layers/orthogonal.py b/pina/model/layers/orthogonal.py
index 0edcace8..1d7cdddc 100644
--- a/pina/model/layers/orthogonal.py
+++ b/pina/model/layers/orthogonal.py
@@ -1,23 +1,33 @@
-"""Module for OrthogonalBlock layer, to make the input orthonormal."""
+"""Module for OrthogonalBlock."""

 import torch
+from ...utils import check_consistency


 class OrthogonalBlock(torch.nn.Module):
     """
     Module to make the input orthonormal.
-    The module takes a tensor of size [N, M] and returns a tensor of
-    size [N, M] where the columns are orthonormal.
+    The module takes a tensor of size :math:`[N, M]` and returns a tensor of
+    size :math:`[N, M]` where the columns are orthonormal. The block performs a
+    Gram-Schmidt orthogonalization process for the input, see
+    `here <https://en.wikipedia.org/wiki/Gram%E2%80%93Schmidt_process>`_ for
+    details.
     """

-    def __init__(self, dim=-1):
+    def __init__(self, dim=-1, requires_grad=True):
         """
         Initialize the OrthogonalBlock module.

         :param int dim: The dimension where to orthogonalize.
+        :param bool requires_grad: If autograd should record operations on
+            the returned tensor, defaults to True.
         """
         super().__init__()
+        # store dim (validated by the ``dim`` property setter)
         self.dim = dim
+        # store requires_grad
+        check_consistency(requires_grad, bool)
+        self._requires_grad = requires_grad

     def forward(self, X):
         """
@@ -26,7 +36,8 @@ def forward(self, X):

         :raises Warning: If the dimension is greater than the other dimensions.

-        :param torch.Tensor X: The input tensor to orthogonalize.
+        :param torch.Tensor X: The input tensor to orthogonalize. The input must
+            be of dimensions :math:`[N, M]`.
         :return: The orthonormal tensor.
         """
         # check dim is less than all the other dimensions
@@ -36,23 +47,75 @@ def forward(self, X):
                 " than the other dimensions"
             )

-        result = torch.zeros_like(X)
-
-        # normalize first basis
-        X_0 = torch.select(X, self.dim, 0)
-        result_0 = torch.select(result, self.dim, 0)
-        result_0 += X_0 / torch.norm(X_0)
+        result = torch.zeros_like(X, requires_grad=self._requires_grad)
+        X_0 = torch.select(X, self.dim, 0).clone()
+        result_0 = X_0 / torch.linalg.norm(X_0)
+        result = self._differentiable_copy(result, 0, result_0)

         # iterate over the rest of the basis with Gram-Schmidt
         for i in range(1, X.shape[self.dim]):
-            v = torch.select(X, self.dim, i)
+            v = torch.select(X, self.dim, i).clone()
             for j in range(i):
-                v -= torch.sum(
-                    v * torch.select(result, self.dim, j),
-                    dim=self.dim,
-                    keepdim=True,
-                ) * torch.select(result, self.dim, j)
-            result_i = torch.select(result, self.dim, i)
-            result_i += v / torch.norm(v)
-
+                vj = torch.select(result, self.dim, j).clone()
+                # select() already dropped self.dim, so v and vj are 1-D:
+                # reduce over every entry instead of indexing a missing axis
+                v = v - torch.sum(v * vj) * vj
+            result_i = v / torch.linalg.norm(v)
+            result = self._differentiable_copy(result, i, result_i)
         return result
+
+
+    def _differentiable_copy(self, result, idx, value):
+        """
+        Perform a differentiable copy operation on a tensor.
+
+        :param torch.Tensor result: The tensor where values will be copied to.
+        :param int idx: The index along the specified dimension where the
+            value will be copied.
+        :param torch.Tensor value: The tensor value to copy into the
+            result tensor.
+        :return: A new tensor with the copied values.
+        :rtype: torch.Tensor
+        """
+        # the index tensor must live on the same device as ``result``
+        index = torch.tensor([idx], device=result.device)
+        return result.index_copy(self.dim, index, value.unsqueeze(self.dim))
+
+    @property
+    def dim(self):
+        """
+        Get the dimension along which operations are performed.
+
+        :return: The current dimension value.
+        :rtype: int
+        """
+        return self._dim
+
+    @dim.setter
+    def dim(self, value):
+        """
+        Set the dimension along which operations are performed.
+
+        :param value: The dimension to be set, which must be 0, 1, or -1.
+        :type value: int
+        :raises IndexError: If the provided dimension is not in the
+            range [-1, 1].
+        """
+        # check consistency
+        check_consistency(value, int)
+        if value not in [0, 1, -1]:
+            raise IndexError('Dimension out of range (expected to be in '
+                             f'range of [-1, 1], but got {value})')
+        # assign value
+        self._dim = value
+
+    @property
+    def requires_grad(self):
+        """
+        Indicates whether gradient computation is required for operations
+        on the tensors.
+
+        :return: True if gradients are required, False otherwise.
+        :rtype: bool
+        """
+        return self._requires_grad
diff --git a/tests/test_layers/test_orthogonal.py b/tests/test_layers/test_orthogonal.py
index 30b59cd5..d443c177 100644
--- a/tests/test_layers/test_orthogonal.py
+++ b/tests/test_layers/test_orthogonal.py
@@ -1,6 +1,8 @@
 import torch
 import pytest
-from pina.model.layers.orthogonal import OrthogonalBlock
+from pina.model.layers import OrthogonalBlock
+
+torch.manual_seed(111)

 list_matrices = [
     torch.randn(10, 3),
@@ -10,10 +12,28 @@

 list_prohibited_matrices_dim0 = list_matrices[:-1]

-def test_constructor():
-    orth = OrthogonalBlock(1)
-    orth = OrthogonalBlock(0)
-    orth = OrthogonalBlock()
+@pytest.mark.parametrize("dim", [-1, 0, 1, None])
+@pytest.mark.parametrize("requires_grad", [True, False, None])
+def test_constructor(dim, requires_grad):
+    if dim is None and requires_grad is None:
+        block = OrthogonalBlock()
+    elif dim is None:
+        block = OrthogonalBlock(requires_grad=requires_grad)
+    elif requires_grad is None:
+        block = OrthogonalBlock(dim=dim)
+    else:
+        block = OrthogonalBlock(dim=dim, requires_grad=requires_grad)
+
+    if dim is not None:
+        assert block.dim == dim
+    if requires_grad is not None:
+        assert block.requires_grad == requires_grad
+
+def test_wrong_constructor():
+    with pytest.raises(IndexError):
+        OrthogonalBlock(2)
+    with pytest.raises(ValueError):
+        OrthogonalBlock('a')

 @pytest.mark.parametrize("V", list_matrices)
 def test_forward(V):
@@ -24,6 +44,21 @@ def test_forward(V):
     assert torch.allclose(V_orth.T @ V_orth, torch.eye(V.shape[1]), atol=1e-6)
     assert torch.allclose(V_orth_row @ V_orth_row.T, torch.eye(V.shape[1]), atol=1e-6)

+@pytest.mark.parametrize("V", list_matrices)
+def test_backward(V):
+    orth = OrthogonalBlock(requires_grad=True)
+    V_orth = orth(V)
+    loss = V_orth.mean()
+    loss.backward()
+
+@pytest.mark.parametrize("V", list_matrices)
+def test_wrong_backward(V):
+    orth = OrthogonalBlock(requires_grad=False)
+    V_orth = orth(V)
+    loss = V_orth.mean()
+    with pytest.raises(RuntimeError):
+        loss.backward()
+
 @pytest.mark.parametrize("V", list_prohibited_matrices_dim0)
 def test_forward_prohibited(V):
     orth = OrthogonalBlock(0)