Skip to content

Commit

Permalink
handle empty contraction list in PathInfo (#229)
Browse files Browse the repository at this point in the history
* handle empty contraction list in PathInfo

* update test

* format test file

* satisfy mypy
  • Loading branch information
jcmgray authored May 13, 2024
1 parent 4fc457f commit 0aaad7b
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 9 deletions.
6 changes: 3 additions & 3 deletions opt_einsum/contract.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,13 +50,13 @@ def __init__(
self.scale_list = scale_list
self.naive_cost = Decimal(naive_cost)
self.opt_cost = Decimal(opt_cost)
self.speedup = self.naive_cost / self.opt_cost
self.speedup = self.naive_cost / max(self.opt_cost, Decimal(1))
self.size_list = size_list
self.size_dict = size_dict

self.shapes = [tuple(size_dict[k] for k in ks) for ks in input_subscripts.split(",")]
self.eq = "{}->{}".format(input_subscripts, output_subscript)
self.largest_intermediate = Decimal(max(size_list))
self.largest_intermediate = Decimal(max(size_list, default=1))

def __repr__(self) -> str:
# Return the path along with a nice string representation
Expand All @@ -65,7 +65,7 @@ def __repr__(self) -> str:
path_print = [
" Complete contraction: {}\n".format(self.eq),
" Naive scaling: {}\n".format(len(self.indices)),
" Optimized scaling: {}\n".format(max(self.scale_list)),
" Optimized scaling: {}\n".format(max(self.scale_list, default=0)),
" Naive FLOP count: {:.3e}\n".format(self.naive_cost),
" Optimized FLOP count: {:.3e}\n".format(self.opt_cost),
" Theoretical speedup: {:.3e}\n".format(self.speedup),
Expand Down
17 changes: 11 additions & 6 deletions opt_einsum/tests/test_edge_cases.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import numpy as np
import pytest

from opt_einsum import contract, contract_expression
from opt_einsum import contract, contract_path, contract_expression


def test_contract_expression_checks():
Expand Down Expand Up @@ -53,7 +53,6 @@ def test_contract_expression_checks():


def test_broadcasting_contraction():

a = np.random.rand(1, 5, 4)
b = np.random.rand(4, 6)
c = np.random.rand(5, 6)
Expand All @@ -73,7 +72,6 @@ def test_broadcasting_contraction():


def test_broadcasting_contraction2():

a = np.random.rand(1, 1, 5, 4)
b = np.random.rand(4, 6)
c = np.random.rand(5, 6)
Expand All @@ -93,7 +91,6 @@ def test_broadcasting_contraction2():


def test_broadcasting_contraction3():

a = np.random.rand(1, 5, 4)
b = np.random.rand(4, 1, 6)
c = np.random.rand(5, 6)
Expand All @@ -106,7 +103,6 @@ def test_broadcasting_contraction3():


def test_broadcasting_contraction4():

a = np.arange(64).reshape(2, 4, 8)
ein = contract("obk,ijk->ioj", a, a, optimize=False)
opt = contract("obk,ijk->ioj", a, a, optimize=True)
Expand All @@ -115,11 +111,20 @@ def test_broadcasting_contraction4():


def test_can_blas_on_healed_broadcast_dimensions():

expr = contract_expression("ab,bc,bd->acd", (5, 4), (1, 5), (4, 20))
# first contraction involves broadcasting
assert expr.contraction_list[0][2] == "bc,ab->bca"
assert expr.contraction_list[0][-1] is False
# but then is healed GEMM is usable
assert expr.contraction_list[1][2] == "bca,bd->acd"
assert expr.contraction_list[1][-1] == "GEMM"


def test_pathinfo_for_empty_contraction():
eq = "->"
arrays = (1.0,)
path = []
_, info = contract_path(eq, *arrays, optimize=path)
# some info is built lazily, so check repr
assert repr(info)
assert info.largest_intermediate == 1

0 comments on commit 0aaad7b

Please sign in to comment.