Skip to content
This repository has been archived by the owner on Oct 16, 2023. It is now read-only.

Commit

Permalink
Support various dimensions in CosineSimilarity
Browse files Browse the repository at this point in the history
  • Loading branch information
mejai1206 authored Oct 5, 2023
1 parent c72e077 commit 9b77767
Show file tree
Hide file tree
Showing 2 changed files with 77 additions and 31 deletions.
59 changes: 32 additions & 27 deletions tests/test_cosine_similarity.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,45 +6,50 @@


@pytest.mark.parametrize(
"z_size, y_size, x_size, dim",
[(1431, 500, 200, 0), (221, 1250, 200, 1), (21, 6400, 86, 2)],
"x_shape, dim",
[((1431, 500, 200), 2), ((21, 6400), 1), ((86,), 0)],
)
def test_forward(z_size, y_size, x_size, dim, device):
def test_forward(x_shape, dim, device):
factory_kwargs = {"device": device}

x1 = torch.randn(z_size, y_size, x_size, **factory_kwargs)
x2 = torch.randn(z_size, y_size, x_size, **factory_kwargs)
x1 = torch.randn(x_shape, **factory_kwargs)
x2 = torch.randn(x_shape, **factory_kwargs)

assert util.equal(
torch.nn.functional.cosine_similarity(x1, x2, dim=dim),
trident.function.cosine_similarity(x1, x2, dim=dim),
torch.nn.functional.cosine_similarity(x1, x2, dim),
trident.function.cosine_similarity(x1, x2, dim),
)


@pytest.mark.parametrize(
"z_size, y_size, x_size, dim",
[(1280, 1000, 200, 0), (200, 1280, 200, 1), (640, 21, 86, 2)],
"x_shape, dim",
[((1280, 1000, 200), 1), ((640, 21), 0), ((90,), 0)],
)
def test_backward(z_size, y_size, x_size, dim, device):
def test_backward(x_shape, dim, device):
factory_kwargs = {"device": device}

x1 = torch.randn(z_size, y_size, x_size, **factory_kwargs)
x2 = torch.randn(z_size, y_size, x_size, **factory_kwargs)

if dim == 0:
target_dim = (y_size, x_size)
elif dim == 1:
target_dim = (z_size, x_size)
else:
target_dim = (z_size, y_size)

grad_output = torch.randn(target_dim, **factory_kwargs)
x1 = torch.randn(x_shape, **factory_kwargs)
x2 = torch.randn(x_shape, **factory_kwargs)

def get_output_shape(x_shape, dim):
if len(x_shape) == 1:
return ()
elif len(x_shape) == 2:
return x_shape[1] if dim == 0 else x_shape[0]
else:
if dim == 0:
return (x_shape[1], x_shape[2])
elif dim == 1:
return (x_shape[0], x_shape[2])
else:
return (x_shape[0], x_shape[1])

grad_output = torch.randn(get_output_shape(x_shape, dim), **factory_kwargs)

def train(func):
i = x1.clone()
j = x2.clone()
i.requires_grad = j.requires_grad = True
func(i, j).backward(grad_output, retain_graph=True)

return i.grad, j.grad

(x, y) = train(torch.nn.CosineSimilarity(dim))
Expand All @@ -65,13 +70,13 @@ def test_cosine_similarity(z_size, y_size, x_size, dim, device, dtype):
assert output.dtype == dtype

if dim == 0:
target_dim = (y_size, x_size)
output_shape = (y_size, x_size)
elif dim == 1:
target_dim = (z_size, x_size)
output_shape = (z_size, x_size)
else:
target_dim = (z_size, y_size)
output_shape = (z_size, y_size)

grad_output = torch.randn(target_dim, **factory_kwargs)
grad_output = torch.randn(output_shape, **factory_kwargs)

output.backward(grad_output)
assert x1.grad is not None
Expand Down
49 changes: 45 additions & 4 deletions trident/operation/cosine_similarity.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,25 +24,40 @@ class CosineSimilarity(torch.autograd.Function):
def forward(ctx: Any, *args: Any, **kwargs: Any):
x1, x2, dim, eps = args

if dim >= x1.dim():
raise ValueError(f"Unable to process the given dim: '{dim}'.")

util.push_trace("CosineSimilarity.__forward")
output, denominator, numerator = CosineSimilarity.__forward(x1, x2, dim, eps)
output, denominator, numerator = CosineSimilarity.__forward(
x1.view(CosineSimilarity.__input_shape(x1)),
x2.view(CosineSimilarity.__input_shape(x2)),
CosineSimilarity.__dim(x1, dim),
eps,
)
util.pop_trace()

ctx.save_for_backward(x1, x2, denominator, numerator)
ctx.dim = dim

return output
return output.view(CosineSimilarity.__output_shape(x1, output))

@staticmethod
def backward(ctx: Any, *grad_outputs: Any):
grad_output = grad_outputs[0]
x1, x2, denominator, numerator = ctx.saved_tensors

util.push_trace("CosineSimilarity.__backward")
grad_x1, grad_x2 = CosineSimilarity.__backward(grad_output, x1, x2, denominator, numerator, ctx.dim)
grad_x1, grad_x2 = CosineSimilarity.__backward(
grad_output,
x1.view(CosineSimilarity.__input_shape(x1)),
x2.view(CosineSimilarity.__input_shape(x2)),
denominator,
numerator,
CosineSimilarity.__dim(x1, ctx.dim),
)
util.pop_trace()

return grad_x1, grad_x2, None, None
return grad_x1.view(x1.shape), grad_x2.view(x2.shape), None, None

@staticmethod
def __forward(x1: torch.Tensor, x2: torch.Tensor, dim: torch.int32, eps: torch.float32):
Expand Down Expand Up @@ -131,3 +146,29 @@ def __output_size_and_size_along_dim(input: torch.Tensor, dim: int):
size_along_dim = x_size

return output_y_size, output_x_size, size_along_dim

@staticmethod
def __input_shape(input: torch.Tensor):
if input.dim() == 1:
return (1, 1, *input.shape)
elif input.dim() == 2:
return (1, *input.shape)
elif input.dim() == 3:
return input.shape
else:
raise ValueError(f"Unable to convert the given input: '{input}'.")

@staticmethod
def __output_shape(input: torch.Tensor, output: torch.Tensor):
if input.dim() == 1:
return ()
if input.dim() == 2:
return output.shape[1]
elif input.dim() == 3:
return output.shape
else:
raise ValueError(f"Unable to convert the given x: '{input}'.")

@staticmethod
def __dim(input: torch.Tensor, dim: torch.int32):
return dim + 3 - input.dim()

0 comments on commit 9b77767

Please sign in to comment.