From d7e13fb4a1780e201269db8cd37522d9a054b9b2 Mon Sep 17 00:00:00 2001 From: "mejai.p" Date: Wed, 4 Oct 2023 16:52:48 +0900 Subject: [PATCH] Add input dimension of CosineSimilarity Supports dimensions smaller than 3 on CosineSimilarity --- tests/test_cosine_similarity.py | 91 ++++++++++++++------------ trident/operation/cosine_similarity.py | 47 +++++++++++-- 2 files changed, 92 insertions(+), 46 deletions(-) diff --git a/tests/test_cosine_similarity.py b/tests/test_cosine_similarity.py index c26a50ac..640731b5 100644 --- a/tests/test_cosine_similarity.py +++ b/tests/test_cosine_similarity.py @@ -6,52 +6,59 @@ @pytest.mark.parametrize( - "z_size, y_size, x_size, dim", - [(1431, 500, 200, 0), (221, 1250, 200, 1), (21, 6400, 86, 2)], + "x_shape", + [(1431, 500, 200), (21, 6400), (86,)], ) -def test_forward(z_size, y_size, x_size, dim, device): +def test_forward(x_shape, device): factory_kwargs = {"device": device} + x1 = torch.randn(x_shape, **factory_kwargs) + x2 = torch.randn(x_shape, **factory_kwargs) - x1 = torch.randn(z_size, y_size, x_size, **factory_kwargs) - x2 = torch.randn(z_size, y_size, x_size, **factory_kwargs) - - assert util.equal( - torch.nn.functional.cosine_similarity(x1, x2, dim=dim), - trident.function.cosine_similarity(x1, x2, dim=dim), - ) + for dim in range(len(x_shape)): + assert util.equal( + torch.nn.functional.cosine_similarity(x1, x2, dim), + trident.function.cosine_similarity(x1, x2, dim), + ) @pytest.mark.parametrize( - "z_size, y_size, x_size, dim", - [(1280, 1000, 200, 0), (200, 1280, 200, 1), (640, 21, 86, 2)], + "x_shape", + [(1280, 1000, 200), (640, 21), (90,)], ) -def test_backward(z_size, y_size, x_size, dim, device): +def test_backward(x_shape, device): factory_kwargs = {"device": device} - - x1 = torch.randn(z_size, y_size, x_size, **factory_kwargs) - x2 = torch.randn(z_size, y_size, x_size, **factory_kwargs) - - if dim == 0: - target_dim = (y_size, x_size) - elif dim == 1: - target_dim = (z_size, x_size) - else: - target_dim 
= (z_size, y_size) - - grad_output = torch.randn(target_dim, **factory_kwargs) - - def train(func): - i = x1.clone() - j = x2.clone() - i.requires_grad = j.requires_grad = True - func(i, j).backward(grad_output, retain_graph=True) - return i.grad, j.grad - - (x, y) = train(torch.nn.CosineSimilarity(dim)) - (a, b) = train(trident.CosineSimilarity(dim)) - - assert util.equal(x, a) - assert util.equal(y, b) + x1 = torch.randn(x_shape, **factory_kwargs) + x2 = torch.randn(x_shape, **factory_kwargs) + + def get_output_shape(x_shape, dim): + if len(x_shape) == 1: + return "scalar" + elif len(x_shape) == 2: + return x_shape[1] if dim == 0 else x_shape[0] + else: + if dim == 0: + return (x_shape[1], x_shape[2]) + elif dim == 1: + return (x_shape[0], x_shape[2]) + else: + return (x_shape[0], x_shape[1]) + + for dim in range(len(x_shape)): + output_shape = get_output_shape(x_shape, dim) + grad_output = torch.randn(output_shape, **factory_kwargs) if output_shape != "scalar" else None + + def train(func): + i = x1.clone() + j = x2.clone() + i.requires_grad = j.requires_grad = True + func(i, j).backward(grad_output, retain_graph=True) + return i.grad, j.grad + + (x, y) = train(torch.nn.CosineSimilarity(dim)) + (a, b) = train(trident.CosineSimilarity(dim)) + + assert util.equal(x, a) + assert util.equal(y, b) @pytest.mark.parametrize("z_size, y_size, x_size, dim", [(640, 21, 86, 2)]) @@ -65,13 +72,13 @@ def test_cosine_similarity(z_size, y_size, x_size, dim, device, dtype): assert output.dtype == dtype if dim == 0: - target_dim = (y_size, x_size) + output_shape = (y_size, x_size) elif dim == 1: - target_dim = (z_size, x_size) + output_shape = (z_size, x_size) else: - target_dim = (z_size, y_size) + output_shape = (z_size, y_size) - grad_output = torch.randn(target_dim, **factory_kwargs) + grad_output = torch.randn(output_shape, **factory_kwargs) output.backward(grad_output) assert x1.grad is not None diff --git a/trident/operation/cosine_similarity.py 
b/trident/operation/cosine_similarity.py index 91bc31ee..788925ba 100644 --- a/trident/operation/cosine_similarity.py +++ b/trident/operation/cosine_similarity.py @@ -23,15 +23,21 @@ class CosineSimilarity(torch.autograd.Function): @staticmethod def forward(ctx: Any, *args: Any, **kwargs: Any): x1, x2, dim, eps = args + assert dim < x1.dim() util.push_trace("CosineSimilarity.__forward") - output, denominator, numerator = CosineSimilarity.__forward(x1, x2, dim, eps) + output, denominator, numerator = CosineSimilarity.__forward( + x1.view(CosineSimilarity.__input_shape(x1)), + x2.view(CosineSimilarity.__input_shape(x2)), + CosineSimilarity.__dim(x1, dim), + eps, + ) util.pop_trace() ctx.save_for_backward(x1, x2, denominator, numerator) ctx.dim = dim - return output + return output.view(CosineSimilarity.__output_shape(x1, output)) @staticmethod def backward(ctx: Any, *grad_outputs: Any): @@ -39,10 +45,17 @@ def backward(ctx: Any, *grad_outputs: Any): x1, x2, denominator, numerator = ctx.saved_tensors util.push_trace("CosineSimilarity.__backward") - grad_x1, grad_x2 = CosineSimilarity.__backward(grad_output, x1, x2, denominator, numerator, ctx.dim) + grad_x1, grad_x2 = CosineSimilarity.__backward( + grad_output, + x1.view(CosineSimilarity.__input_shape(x1)), + x2.view(CosineSimilarity.__input_shape(x2)), + denominator, + numerator, + CosineSimilarity.__dim(x1, ctx.dim), + ) util.pop_trace() - return grad_x1, grad_x2, None, None + return grad_x1.view(x1.shape), grad_x2.view(x2.shape), None, None @staticmethod def __forward(x1: torch.Tensor, x2: torch.Tensor, dim: torch.int32, eps: torch.float32): @@ -131,3 +144,29 @@ def __output_size_and_size_along_dim(input: torch.Tensor, dim: int): size_along_dim = x_size return output_y_size, output_x_size, size_along_dim + + @staticmethod + def __input_shape(input: torch.Tensor): + if input.dim() == 1: + return (1, 1, *input.shape) + elif input.dim() == 2: + return (1, *input.shape) + elif input.dim() == 3: + return input.shape + 
else: + raise ValueError(f"Unable to convert the given input: '{input}'.") + + @staticmethod + def __output_shape(input: torch.Tensor, output: torch.Tensor): + if input.dim() == 1: + return -1 + if input.dim() == 2: + return output.shape[1] + elif input.dim() == 3: + return output.shape + else: + raise ValueError(f"Unable to convert the given x: '{input}'.") + + @staticmethod + def __dim(input: torch.Tensor, dim: torch.int32): + return dim + 3 - input.dim()