Skip to content

Commit

Permalink
More magic methods for gem Nodes
Browse files Browse the repository at this point in the history
Adds a __matmul__ method, and a componentwise helper so the existing
sugar also works transparently for shaped objects.
  • Loading branch information
wence- committed Aug 25, 2020
1 parent 3524174 commit cdd7115
Show file tree
Hide file tree
Showing 2 changed files with 61 additions and 6 deletions.
57 changes: 53 additions & 4 deletions gem/gem.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,25 +80,52 @@ def __getitem__(self, indices):
return Indexed(self, indices)

def __add__(self, other):
return Sum(self, as_gem(other))
return componentwise(Sum, self, as_gem(other))

def __radd__(self, other):
return as_gem(other).__add__(self)

def __sub__(self, other):
return Sum(self, Product(Literal(-1), as_gem(other)))
return componentwise(
Sum, self,
componentwise(Product, Literal(-1), as_gem(other)))

def __rsub__(self, other):
return as_gem(other).__sub__(self)

def __mul__(self, other):
return Product(self, as_gem(other))
return componentwise(Product, self, as_gem(other))

def __rmul__(self, other):
return as_gem(other).__mul__(self)

def __matmul__(self, other):
other = as_gem(other)
if not self.shape and not other.shape:
return Product(self, other)
elif not (self.shape and other.shape):
raise ValueError("Both objects must have shape for matmul")
elif self.shape[-1] != other.shape[0]:
raise ValueError(f"Mismatching shapes {self.shape} and {other.shape} in matmul")
*i, k = indices(len(self.shape))
_, *j = indices(len(other.shape))
expr = Product(Indexed(self, tuple(i) + (k, )),
Indexed(other, (k, ) + tuple(j)))
return ComponentTensor(IndexSum(expr, (k, )), tuple(i) + tuple(j))

def __rmatmul__(self, other):
return as_gem(other).__matmul__(self)

@property
def T(self):
i = indices(len(self.shape))
return ComponentTensor(Indexed(self, i), tuple(reversed(i)))

def __truediv__(self, other):
return Division(self, as_gem(other))
other = as_gem(other)
if other.shape:
raise ValueError("Denominator must be scalar")
return componentwise(Division, self, other)

def __rtruediv__(self, other):
return as_gem(other).__truediv__(self)
Expand Down Expand Up @@ -979,6 +1006,28 @@ def indices(n):
return tuple(Index() for _ in range(n))


def componentwise(op, *exprs):
"""Apply gem op to exprs component-wise and wrap up in a ComponentTensor.
:arg op: function that returns a gem Node.
:arg exprs: expressions to apply op to.
:raises ValueError: if the expressions have mismatching shapes.
:returns: New gem Node constructed from op.
Each expression must either have the same shape, or else be
scalar. Shaped expressions are indexed, the op is applied to the
scalar expressions and the result is wrapped up in a ComponentTensor.
"""
shapes = set(e.shape for e in exprs)
if len(shapes - {()}) > 1:
raise ValueError("expressions must have matching shape (or else be scalar)")
shape = max(shapes)
i = indices(len(shape))
exprs = tuple(Indexed(e, i) if e.shape else e for e in exprs)
return ComponentTensor(op(*exprs), i)


def as_gem(expr):
"""Attempt to convert an expression into GEM.
Expand Down
10 changes: 8 additions & 2 deletions tests/test_syntax_sugar.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,14 @@ def test_expressions():
assert xij + 1 == gem.Sum(xij, gem.Literal(1))
assert 1 + xij == gem.Sum(gem.Literal(1), xij)

with pytest.raises(AssertionError):
xij + y
assert (xij + y).shape == (4, )

assert (x @ y).shape == (3, )

assert x.T.shape == (4, 3)

with pytest.raises(ValueError):
xij.T @ y

with pytest.raises(ValueError):
xij + "foo"
Expand Down

0 comments on commit cdd7115

Please sign in to comment.