diff --git a/gem/gem.py b/gem/gem.py index 35812348..2859ded3 100644 --- a/gem/gem.py +++ b/gem/gem.py @@ -80,25 +80,52 @@ def __getitem__(self, indices): return Indexed(self, indices) def __add__(self, other): - return Sum(self, as_gem(other)) + return componentwise(Sum, self, as_gem(other)) def __radd__(self, other): return as_gem(other).__add__(self) def __sub__(self, other): - return Sum(self, Product(Literal(-1), as_gem(other))) + return componentwise( + Sum, self, + componentwise(Product, Literal(-1), as_gem(other))) def __rsub__(self, other): return as_gem(other).__sub__(self) def __mul__(self, other): - return Product(self, as_gem(other)) + return componentwise(Product, self, as_gem(other)) def __rmul__(self, other): return as_gem(other).__mul__(self) + def __matmul__(self, other): + other = as_gem(other) + if not self.shape and not other.shape: + return Product(self, other) + elif not (self.shape and other.shape): + raise ValueError("Both objects must have shape for matmul") + elif self.shape[-1] != other.shape[0]: + raise ValueError(f"Mismatching shapes {self.shape} and {other.shape} in matmul") + *i, k = indices(len(self.shape)) + _, *j = indices(len(other.shape)) + expr = Product(Indexed(self, tuple(i) + (k, )), + Indexed(other, (k, ) + tuple(j))) + return ComponentTensor(IndexSum(expr, (k, )), tuple(i) + tuple(j)) + + def __rmatmul__(self, other): + return as_gem(other).__matmul__(self) + + @property + def T(self): + i = indices(len(self.shape)) + return ComponentTensor(Indexed(self, i), tuple(reversed(i))) + def __truediv__(self, other): - return Division(self, as_gem(other)) + other = as_gem(other) + if other.shape: + raise ValueError("Denominator must be scalar") + return componentwise(Division, self, other) def __rtruediv__(self, other): return as_gem(other).__truediv__(self) @@ -979,6 +1006,28 @@ def indices(n): return tuple(Index() for _ in range(n)) +def componentwise(op, *exprs): + """Apply gem op to exprs component-wise and wrap up in a ComponentTensor. + + :arg op: function that returns a gem Node. + :arg exprs: expressions to apply op to. + :raises ValueError: if the expressions have mismatching shapes. + :returns: New gem Node constructed from op. + + Each expression must either have the same shape, or else be + scalar. Shaped expressions are indexed, the op is applied to the + scalar expressions and the result is wrapped up in a ComponentTensor. + + """ + shapes = set(e.shape for e in exprs) + if len(shapes - {()}) > 1: + raise ValueError("expressions must have matching shape (or else be scalar)") + shape = max(shapes) + i = indices(len(shape)) + exprs = tuple(Indexed(e, i) if e.shape else e for e in exprs) + return ComponentTensor(op(*exprs), i) + + def as_gem(expr): """Attempt to convert an expression into GEM. diff --git a/tests/test_syntax_sugar.py b/tests/test_syntax_sugar.py index a2615e84..56bbc4f4 100644 --- a/tests/test_syntax_sugar.py +++ b/tests/test_syntax_sugar.py @@ -21,8 +21,14 @@ def test_expressions(): assert xij + 1 == gem.Sum(xij, gem.Literal(1)) assert 1 + xij == gem.Sum(gem.Literal(1), xij) - with pytest.raises(AssertionError): - xij + y + assert (xij + y).shape == (4, ) + + assert (x @ y).shape == (3, ) + + assert x.T.shape == (4, 3) + + with pytest.raises(ValueError): + xij.T @ y with pytest.raises(ValueError): xij + "foo"