This repository has been archived by the owner on Oct 16, 2023. It is now read-only.

Add input dimension support to CosineSimilarity
Supports input dimensions smaller than 3 in CosineSimilarity.
mejai1206 committed Oct 5, 2023
1 parent 04073ec commit d7e13fb
Showing 2 changed files with 92 additions and 46 deletions.
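In effect, the change lets Trident's cosine similarity accept 1-D and 2-D inputs, mirroring torch.nn.functional.cosine_similarity. A minimal usage sketch under two assumptions: a CUDA device (Trident's kernels are Triton-based) and torch.allclose with a loose tolerance in place of the test suite's util.equal:

import torch
import trident

device = "cuda"  # assumption: the Triton kernels need a GPU
x1 = torch.randn(21, 6400, device=device)  # 2-D input, previously unsupported
x2 = torch.randn(21, 6400, device=device)

# The new tests loop over every valid dim in exactly this way.
for dim in range(x1.dim()):
    expected = torch.nn.functional.cosine_similarity(x1, x2, dim)
    actual = trident.function.cosine_similarity(x1, x2, dim)
    assert torch.allclose(expected, actual, atol=1e-4)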
91 changes: 49 additions & 42 deletions tests/test_cosine_similarity.py
@@ -6,52 +6,59 @@


 @pytest.mark.parametrize(
-    "z_size, y_size, x_size, dim",
-    [(1431, 500, 200, 0), (221, 1250, 200, 1), (21, 6400, 86, 2)],
+    "x_shape",
+    [(1431, 500, 200), (21, 6400), (86,)],
 )
-def test_forward(z_size, y_size, x_size, dim, device):
+def test_forward(x_shape, device):
     factory_kwargs = {"device": device}
-
-    x1 = torch.randn(z_size, y_size, x_size, **factory_kwargs)
-    x2 = torch.randn(z_size, y_size, x_size, **factory_kwargs)
+    x1 = torch.randn(x_shape, **factory_kwargs)
+    x2 = torch.randn(x_shape, **factory_kwargs)

-    assert util.equal(
-        torch.nn.functional.cosine_similarity(x1, x2, dim=dim),
-        trident.function.cosine_similarity(x1, x2, dim=dim),
-    )
+    for dim in range(len(x_shape)):
+        assert util.equal(
+            torch.nn.functional.cosine_similarity(x1, x2, dim),
+            trident.function.cosine_similarity(x1, x2, dim),
+        )


 @pytest.mark.parametrize(
-    "z_size, y_size, x_size, dim",
-    [(1280, 1000, 200, 0), (200, 1280, 200, 1), (640, 21, 86, 2)],
+    "x_shape",
+    [(1280, 1000, 200), (640, 21), (90,)],
 )
-def test_backward(z_size, y_size, x_size, dim, device):
+def test_backward(x_shape, device):
     factory_kwargs = {"device": device}
-
-    x1 = torch.randn(z_size, y_size, x_size, **factory_kwargs)
-    x2 = torch.randn(z_size, y_size, x_size, **factory_kwargs)
-
-    if dim == 0:
-        target_dim = (y_size, x_size)
-    elif dim == 1:
-        target_dim = (z_size, x_size)
-    else:
-        target_dim = (z_size, y_size)
-
-    grad_output = torch.randn(target_dim, **factory_kwargs)
-
-    def train(func):
-        i = x1.clone()
-        j = x2.clone()
-        i.requires_grad = j.requires_grad = True
-        func(i, j).backward(grad_output, retain_graph=True)
-        return i.grad, j.grad
-
-    (x, y) = train(torch.nn.CosineSimilarity(dim))
-    (a, b) = train(trident.CosineSimilarity(dim))
-
-    assert util.equal(x, a)
-    assert util.equal(y, b)
+    x1 = torch.randn(x_shape, **factory_kwargs)
+    x2 = torch.randn(x_shape, **factory_kwargs)
+
+    def get_output_shape(x_shape, dim):
+        if len(x_shape) == 1:
+            return "scalar"
+        elif len(x_shape) == 2:
+            return x_shape[1] if dim == 0 else x_shape[0]
+        else:
+            if dim == 0:
+                return (x_shape[1], x_shape[2])
+            elif dim == 1:
+                return (x_shape[0], x_shape[2])
+            else:
+                return (x_shape[0], x_shape[1])
+
+    for dim in range(len(x_shape)):
+        output_shape = get_output_shape(x_shape, dim)
+        grad_output = torch.randn(output_shape, **factory_kwargs) if output_shape != "scalar" else None
+
+        def train(func):
+            i = x1.clone()
+            j = x2.clone()
+            i.requires_grad = j.requires_grad = True
+            func(i, j).backward(grad_output, retain_graph=True)
+            return i.grad, j.grad
+
+        (x, y) = train(torch.nn.CosineSimilarity(dim))
+        (a, b) = train(trident.CosineSimilarity(dim))
+
+        assert util.equal(x, a)
+        assert util.equal(y, b)


 @pytest.mark.parametrize("z_size, y_size, x_size, dim", [(640, 21, 86, 2)])
@@ -65,13 +72,13 @@ def test_cosine_similarity(z_size, y_size, x_size, dim, device, dtype):
     assert output.dtype == dtype

     if dim == 0:
-        target_dim = (y_size, x_size)
+        output_shape = (y_size, x_size)
     elif dim == 1:
-        target_dim = (z_size, x_size)
+        output_shape = (z_size, x_size)
     else:
-        target_dim = (z_size, y_size)
+        output_shape = (z_size, y_size)

-    grad_output = torch.randn(target_dim, **factory_kwargs)
+    grad_output = torch.randn(output_shape, **factory_kwargs)

     output.backward(grad_output)
     assert x1.grad is not None
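A subtlety the rewritten test_backward encodes: reducing along dim removes that axis from the output, and for a 1-D input the result is a 0-dim (scalar) tensor, so backward() needs no explicit grad_output — hence the "scalar" sentinel and grad_output = None. The same behavior in plain PyTorch:

import torch

a = torch.randn(90, requires_grad=True)
b = torch.randn(90)

out = torch.nn.functional.cosine_similarity(a, b, 0)
assert out.dim() == 0  # reducing the only axis leaves a scalar
out.backward()         # same as backward(None); allowed only for scalars
assert a.grad is not None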
47 changes: 43 additions & 4 deletions trident/operation/cosine_similarity.py
@@ -23,26 +23,39 @@ class CosineSimilarity(torch.autograd.Function):
     @staticmethod
     def forward(ctx: Any, *args: Any, **kwargs: Any):
         x1, x2, dim, eps = args
+        assert dim < x1.dim()

         util.push_trace("CosineSimilarity.__forward")
-        output, denominator, numerator = CosineSimilarity.__forward(x1, x2, dim, eps)
+        output, denominator, numerator = CosineSimilarity.__forward(
+            x1.view(CosineSimilarity.__input_shape(x1)),
+            x2.view(CosineSimilarity.__input_shape(x2)),
+            CosineSimilarity.__dim(x1, dim),
+            eps,
+        )
         util.pop_trace()

         ctx.save_for_backward(x1, x2, denominator, numerator)
         ctx.dim = dim

-        return output
+        return output.view(CosineSimilarity.__output_shape(x1, output))

     @staticmethod
     def backward(ctx: Any, *grad_outputs: Any):
         grad_output = grad_outputs[0]
         x1, x2, denominator, numerator = ctx.saved_tensors

         util.push_trace("CosineSimilarity.__backward")
-        grad_x1, grad_x2 = CosineSimilarity.__backward(grad_output, x1, x2, denominator, numerator, ctx.dim)
+        grad_x1, grad_x2 = CosineSimilarity.__backward(
+            grad_output,
+            x1.view(CosineSimilarity.__input_shape(x1)),
+            x2.view(CosineSimilarity.__input_shape(x2)),
+            denominator,
+            numerator,
+            CosineSimilarity.__dim(x1, ctx.dim),
+        )
         util.pop_trace()

-        return grad_x1, grad_x2, None, None
+        return grad_x1.view(x1.shape), grad_x2.view(x2.shape), None, None

     @staticmethod
     def __forward(x1: torch.Tensor, x2: torch.Tensor, dim: torch.int32, eps: torch.float32):
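The wrapper now follows a view-in/view-out pattern: inputs are reshaped to a canonical 3-D layout before the kernel call, and results are viewed back afterward. Note that backward must return gradients shaped like forward's original inputs, which is what grad_x1.view(x1.shape) guarantees. A minimal standalone illustration of that autograd contract (a toy Scale function, not part of Trident):

import torch

class Scale(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        ctx.shape = x.shape
        # Compute in a canonical 3-D view, then restore the caller's shape.
        return (x.view(1, 1, -1) * 2.0).view(x.shape)

    @staticmethod
    def backward(ctx, grad_output):
        grad = grad_output.view(1, 1, -1) * 2.0
        return grad.view(ctx.shape)  # must match forward's input shape

x = torch.randn(5, requires_grad=True)
Scale.apply(x).sum().backward()
assert x.grad.shape == x.shape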
@@ -131,3 +144,29 @@ def __output_size_and_size_along_dim(input: torch.Tensor, dim: int):
             size_along_dim = x_size

         return output_y_size, output_x_size, size_along_dim
+
+    @staticmethod
+    def __input_shape(input: torch.Tensor):
+        if input.dim() == 1:
+            return (1, 1, *input.shape)
+        elif input.dim() == 2:
+            return (1, *input.shape)
+        elif input.dim() == 3:
+            return input.shape
+        else:
+            raise ValueError(f"Unable to convert the given input: '{input}'.")
+
+    @staticmethod
+    def __output_shape(input: torch.Tensor, output: torch.Tensor):
+        if input.dim() == 1:
+            return -1
+        elif input.dim() == 2:
+            return output.shape[1]
+        elif input.dim() == 3:
+            return output.shape
+        else:
+            raise ValueError(f"Unable to convert the given x: '{input}'.")
+
+    @staticmethod
+    def __dim(input: torch.Tensor, dim: torch.int32):
+        return dim + 3 - input.dim()
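The three new helpers implement one idea: the kernel evidently handles only 3-D tensors, so lower-dimensional inputs are promoted by prepending size-1 axes, and the reduction dim is shifted by the same offset. A standalone sketch of that mapping (the promote helper is hypothetical, not part of the library):

import torch

def promote(x: torch.Tensor, dim: int):
    # Mirror __input_shape: prepend size-1 axes until the tensor is 3-D,
    # and mirror __dim: shift dim by the number of prepended axes.
    pad = 3 - x.dim()
    return x.view((1,) * pad + tuple(x.shape)), dim + pad

x3, d1 = promote(torch.randn(86), 0)        # 1-D -> shape (1, 1, 86), dim 2
y3, d2 = promote(torch.randn(21, 6400), 0)  # 2-D -> shape (1, 21, 6400), dim 1
assert x3.shape == (1, 1, 86) and d1 == 2
assert y3.shape == (1, 21, 6400) and d2 == 1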
