From 9b777676d6435c502ba536d486124ea224fa190e Mon Sep 17 00:00:00 2001 From: "mejai.p" Date: Fri, 6 Oct 2023 08:59:29 +0900 Subject: [PATCH] Support various dimensions in CosineSimilarity --- tests/test_cosine_similarity.py | 59 ++++++++++++++------------ trident/operation/cosine_similarity.py | 49 +++++++++++++++++++-- 2 files changed, 77 insertions(+), 31 deletions(-) diff --git a/tests/test_cosine_similarity.py b/tests/test_cosine_similarity.py index c26a50ac..a288936b 100644 --- a/tests/test_cosine_similarity.py +++ b/tests/test_cosine_similarity.py @@ -6,45 +6,50 @@ @pytest.mark.parametrize( - "z_size, y_size, x_size, dim", - [(1431, 500, 200, 0), (221, 1250, 200, 1), (21, 6400, 86, 2)], + "x_shape, dim", + [((1431, 500, 200), 2), ((21, 6400), 1), ((86,), 0)], ) -def test_forward(z_size, y_size, x_size, dim, device): +def test_forward(x_shape, dim, device): factory_kwargs = {"device": device} - - x1 = torch.randn(z_size, y_size, x_size, **factory_kwargs) - x2 = torch.randn(z_size, y_size, x_size, **factory_kwargs) + x1 = torch.randn(x_shape, **factory_kwargs) + x2 = torch.randn(x_shape, **factory_kwargs) assert util.equal( - torch.nn.functional.cosine_similarity(x1, x2, dim=dim), - trident.function.cosine_similarity(x1, x2, dim=dim), + torch.nn.functional.cosine_similarity(x1, x2, dim), + trident.function.cosine_similarity(x1, x2, dim), ) @pytest.mark.parametrize( - "z_size, y_size, x_size, dim", - [(1280, 1000, 200, 0), (200, 1280, 200, 1), (640, 21, 86, 2)], + "x_shape, dim", + [((1280, 1000, 200), 1), ((640, 21), 0), ((90,), 0)], ) -def test_backward(z_size, y_size, x_size, dim, device): +def test_backward(x_shape, dim, device): factory_kwargs = {"device": device} - - x1 = torch.randn(z_size, y_size, x_size, **factory_kwargs) - x2 = torch.randn(z_size, y_size, x_size, **factory_kwargs) - - if dim == 0: - target_dim = (y_size, x_size) - elif dim == 1: - target_dim = (z_size, x_size) - else: - target_dim = (z_size, y_size) - - grad_output = torch.randn(target_dim, **factory_kwargs) + x1 = torch.randn(x_shape, **factory_kwargs) + x2 = torch.randn(x_shape, **factory_kwargs) + + def get_output_shape(x_shape, dim): + if len(x_shape) == 1: + return () + elif len(x_shape) == 2: + return x_shape[1] if dim == 0 else x_shape[0] + else: + if dim == 0: + return (x_shape[1], x_shape[2]) + elif dim == 1: + return (x_shape[0], x_shape[2]) + else: + return (x_shape[0], x_shape[1]) + + grad_output = torch.randn(get_output_shape(x_shape, dim), **factory_kwargs) def train(func): i = x1.clone() j = x2.clone() i.requires_grad = j.requires_grad = True func(i, j).backward(grad_output, retain_graph=True) + return i.grad, j.grad (x, y) = train(torch.nn.CosineSimilarity(dim)) @@ -65,13 +70,13 @@ def test_cosine_similarity(z_size, y_size, x_size, dim, device, dtype): assert output.dtype == dtype if dim == 0: - target_dim = (y_size, x_size) + output_shape = (y_size, x_size) elif dim == 1: - target_dim = (z_size, x_size) + output_shape = (z_size, x_size) else: - target_dim = (z_size, y_size) + output_shape = (z_size, y_size) - grad_output = torch.randn(target_dim, **factory_kwargs) + grad_output = torch.randn(output_shape, **factory_kwargs) output.backward(grad_output) assert x1.grad is not None diff --git a/trident/operation/cosine_similarity.py b/trident/operation/cosine_similarity.py index 91bc31ee..815b3457 100644 --- a/trident/operation/cosine_similarity.py +++ b/trident/operation/cosine_similarity.py @@ -24,14 +24,22 @@ class CosineSimilarity(torch.autograd.Function): def forward(ctx: Any, *args: Any, **kwargs: Any): x1, x2, dim, eps = args + if dim >= x1.dim(): + raise ValueError(f"Unable to process the given dim: '{dim}'.") + util.push_trace("CosineSimilarity.__forward") - output, denominator, numerator = CosineSimilarity.__forward(x1, x2, dim, eps) + output, denominator, numerator = CosineSimilarity.__forward( + x1.view(CosineSimilarity.__input_shape(x1)), + x2.view(CosineSimilarity.__input_shape(x2)), + CosineSimilarity.__dim(x1, dim), + eps, + ) util.pop_trace() ctx.save_for_backward(x1, x2, denominator, numerator) ctx.dim = dim - return output + return output.view(CosineSimilarity.__output_shape(x1, output)) @staticmethod def backward(ctx: Any, *grad_outputs: Any): @@ -39,10 +47,17 @@ def backward(ctx: Any, *grad_outputs: Any): x1, x2, denominator, numerator = ctx.saved_tensors util.push_trace("CosineSimilarity.__backward") - grad_x1, grad_x2 = CosineSimilarity.__backward(grad_output, x1, x2, denominator, numerator, ctx.dim) + grad_x1, grad_x2 = CosineSimilarity.__backward( + grad_output, + x1.view(CosineSimilarity.__input_shape(x1)), + x2.view(CosineSimilarity.__input_shape(x2)), + denominator, + numerator, + CosineSimilarity.__dim(x1, ctx.dim), + ) util.pop_trace() - return grad_x1, grad_x2, None, None + return grad_x1.view(x1.shape), grad_x2.view(x2.shape), None, None @staticmethod def __forward(x1: torch.Tensor, x2: torch.Tensor, dim: torch.int32, eps: torch.float32): @@ -131,3 +146,29 @@ def __output_size_and_size_along_dim(input: torch.Tensor, dim: int): size_along_dim = x_size return output_y_size, output_x_size, size_along_dim + + @staticmethod + def __input_shape(input: torch.Tensor): + if input.dim() == 1: + return (1, 1, *input.shape) + elif input.dim() == 2: + return (1, *input.shape) + elif input.dim() == 3: + return input.shape + else: + raise ValueError(f"Unable to convert the given input: '{input}'.") + + @staticmethod + def __output_shape(input: torch.Tensor, output: torch.Tensor): + if input.dim() == 1: + return () + if input.dim() == 2: + return output.shape[1] + elif input.dim() == 3: + return output.shape + else: + raise ValueError(f"Unable to convert the given x: '{input}'.") + + @staticmethod + def __dim(input: torch.Tensor, dim: torch.int32): + return dim + 3 - input.dim()