Skip to content

Commit

Permalink
Additional refactors
Browse files Browse the repository at this point in the history
  • Loading branch information
dgasmith committed May 26, 2024
1 parent 7e2a610 commit 5bc05b3
Show file tree
Hide file tree
Showing 5 changed files with 81 additions and 42 deletions.
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
isort = isort opt_einsum scripts/
black = black opt_einsum scripts/
autoflake = autoflake -ir --remove-all-unused-imports --ignore-init-module-imports --remove-unused-variables opt_einsum scripts/
mypy = mypy --ignore-missing-imports codex opt_einsum scripts/
mypy = mypy --ignore-missing-imports opt_einsum scripts/

.PHONY: install
install:
Expand Down
106 changes: 72 additions & 34 deletions opt_einsum/contract.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,10 +123,13 @@ def _choose_memory_arg(memory_limit: _MemoryLimit, size_list: List[int]) -> Opti
return int(memory_limit)


def _filter_einsum_defaults(kwargs: Dict[Literal["order", "casting", "dtype", "out"], Any]) -> Dict[str, Any]:
_EinsumDefaultKeys = Literal["order", "casting", "dtype", "out"]


def _filter_einsum_defaults(kwargs: Dict[_EinsumDefaultKeys, Any]) -> Dict[_EinsumDefaultKeys, Any]:
"""Filters out default contract kwargs to pass to various backends."""
kwargs = kwargs.copy()
ret = {}
ret: Dict[_EinsumDefaultKeys, Any] = {}
if (order := kwargs.pop("order", "K")) != "K":
ret["order"] = order

Expand All @@ -143,7 +146,7 @@ def _filter_einsum_defaults(kwargs: Dict[Literal["order", "casting", "dtype", "o
return ret


# Overlaod for contract(str, *operands)
# Overlaod for contract(einsum_string, *operands)
@overload
def contract_path(
subscripts: str,
Expand All @@ -156,7 +159,7 @@ def contract_path(
) -> Tuple[PathType, PathInfo]: ...


# Overlaod for contract(operand, indices, ....)
# Overlaod for contract(operand, indices, operand, indices, ....)
@overload
def contract_path(
subscripts: ArrayType,
Expand Down Expand Up @@ -449,7 +452,7 @@ def _einsum(*operands: Any, **kwargs: Any) -> ArrayType:

einsum_str = parser.convert_to_valid_einsum_chars(einsum_str)

kwargs = _filter_einsum_defaults(kwargs)
kwargs = _filter_einsum_defaults(kwargs) # type: ignore
return fn(einsum_str, *operands, **kwargs)


Expand Down Expand Up @@ -494,7 +497,7 @@ def contract(
@overload
def contract(
subscripts: ArrayType,
*operands: ArrayType | Collection[int],
*operands: Union[ArrayType | Collection[int]],
out: ArrayType = ...,
dtype: Any = ...,
order: _OrderKACF = ...,
Expand All @@ -508,8 +511,8 @@ def contract(


def contract(
subscripts: Any,
*operands: Any,
subscripts: str | ArrayType,
*operands: Union[ArrayType | Collection[int]],
out: Optional[ArrayType] = None,
dtype: Optional[str] = None,
order: _OrderKACF = "K",
Expand Down Expand Up @@ -720,7 +723,7 @@ def _core_contract(

else:
# Call einsum
out_kwarg: None | ArrayType = None
out_kwarg: Union[None, ArrayType] = None
if handle_out:
out_kwarg = out
new_view = _einsum(
Expand Down Expand Up @@ -881,29 +884,28 @@ def _contract_with_conversion(

return result

def __call__(self, *arrays: ArrayType, **kwargs: Any) -> ArrayType:
def __call__(
self,
*arrays: ArrayType,
out: Union[None, ArrayType] = None,
backend: str = "auto",
evaluate_constants: bool = False,
) -> ArrayType:
"""Evaluate this expression with a set of arrays.
Parameters
----------
arrays : seq of array
The arrays to supply as input to the expression.
out : array, optional (default: ``None``)
If specified, output the result into this array.
backend : str, optional (default: ``numpy``)
Perform the contraction with this backend library. If numpy arrays
are supplied then try to convert them to and from the correct
backend array type.
Parameters:
arrays: The arrays to supply as input to the expression.
out: If specified, output the result into this array.
backend: Perform the contraction with this backend library. If numpy arrays
are supplied then try to convert them to and from the correct
backend array type.
evaluate_constants: Pre-evaluates constants with the appropriate backend.
Returns:
The contracted result.
"""
out = kwargs.pop("out", None)
backend = parse_backend(arrays, kwargs.pop("backend", "auto"))
evaluate_constants = kwargs.pop("evaluate_constants", False)

if kwargs:
raise ValueError(
"The only valid keyword arguments to a `ContractExpression` "
"call are `out=` or `backend=`. Got: {}.".format(kwargs)
)
backend = parse_backend(arrays, backend)

correct_num_args = self._full_num_args if evaluate_constants else self.num_args

Expand Down Expand Up @@ -965,7 +967,41 @@ def shape_only(shape: TensorShapeType) -> Shaped:
return Shaped(shape)


def contract_expression(subscripts: str, *shapes: TensorShapeType | ArrayType, **kwargs: Any) -> Any:
# Overlaod for contract(einsum_string, *operands)
@overload
def contract_expression(
subscripts: str,
*operands: Union[ArrayType, TensorShapeType],
constants: Union[Collection[int], None] = ...,
use_blas: bool = ...,
optimize: OptimizeKind = ...,
memory_limit: _MemoryLimit = ...,
**kwargs: Any,
) -> ContractExpression: ...


# Overlaod for contract(operand, indices, operand, indices, ....)
@overload
def contract_expression(
subscripts: Union[ArrayType, TensorShapeType],
*operands: Union[ArrayType, TensorShapeType, Collection[int]],
constants: Union[Collection[int], None] = ...,
use_blas: bool = ...,
optimize: OptimizeKind = ...,
memory_limit: _MemoryLimit = ...,
**kwargs: Any,
) -> ContractExpression: ...


def contract_expression(
subscripts: Union[str, ArrayType, TensorShapeType],
*shapes: Union[ArrayType, TensorShapeType, Collection[int]],
constants: Union[Collection[int], None] = None,
use_blas: bool = True,
optimize: OptimizeKind = True,
memory_limit: _MemoryLimit = None,
**kwargs: Any,
) -> ContractExpression:
"""Generate a reusable expression for a given contraction with
specific shapes, which can, for example, be cached.
Expand Down Expand Up @@ -1022,7 +1058,7 @@ def contract_expression(subscripts: str, *shapes: TensorShapeType | ArrayType, *
```
"""
if not kwargs.get("optimize", True):
if not optimize:
raise ValueError("Can only generate expressions for optimized contractions.")

for arg in ("out", "backend"):
Expand All @@ -1033,16 +1069,18 @@ def contract_expression(subscripts: str, *shapes: TensorShapeType | ArrayType, *
)

if not isinstance(subscripts, str):
subscripts, shapes = parser.convert_interleaved_input((subscripts,) + shapes)
subscripts, shapes = parser.convert_interleaved_input((subscripts,) + shapes) # type: ignore

kwargs["_gen_expression"] = True

# build dict of constant indices mapped to arrays
constants = kwargs.pop("constants", ())
constants = constants or tuple()
constants_dict = {i: shapes[i] for i in constants}
kwargs["_constants_dict"] = constants_dict

# apart from constant arguments, make dummy arrays
dummy_arrays = [s if i in constants else shape_only(s) for i, s in enumerate(shapes)]
dummy_arrays = [s if i in constants else shape_only(s) for i, s in enumerate(shapes)] # type: ignore

return contract(subscripts, *dummy_arrays, **kwargs)
return contract(
subscripts, *dummy_arrays, use_blas=use_blas, optimize=optimize, memory_limit=memory_limit, **kwargs
)
4 changes: 2 additions & 2 deletions opt_einsum/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,7 +224,7 @@ def convert_subscripts(old_sub: List[Any], symbol_map: Dict[Any, Any]) -> str:
return new_sub


def convert_interleaved_input(operands: Union[List[Any], Tuple[Any]]) -> Tuple[str, List[Any]]:
def convert_interleaved_input(operands: Union[List[Any], Tuple[Any]]) -> Tuple[str, Tuple[ArrayType, ...]]:
"""Convert 'interleaved' input to standard einsum input."""
tmp_operands = list(operands)
operand_list = []
Expand Down Expand Up @@ -259,7 +259,7 @@ def convert_interleaved_input(operands: Union[List[Any], Tuple[Any]]) -> Tuple[s
subscripts += "->"
subscripts += convert_subscripts(output_list, symbol_map)

return subscripts, operands
return subscripts, tuple(operands)


def parse_einsum_input(operands: Any, shapes: bool = False) -> Tuple[str, str, List[ArrayType]]:
Expand Down
6 changes: 3 additions & 3 deletions opt_einsum/tests/test_edge_cases.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,9 +48,9 @@ def test_contract_expression_checks() -> None:
assert "Internal error while evaluating `ContractExpression`" in str(err.value)

# should only be able to specify out
with pytest.raises(ValueError) as err:
expr(np.random.rand(2, 3), np.random.rand(3, 4), order="F")
assert "only valid keyword arguments to a `ContractExpression`" in str(err.value)
with pytest.raises(TypeError) as err_type:
expr(np.random.rand(2, 3), np.random.rand(3, 4), order="F") # type: ignore
assert "only valid keyword arguments to a `ContractExpression`" in str(err_type.value)


def test_broadcasting_contraction() -> None:
Expand Down
5 changes: 3 additions & 2 deletions scripts/compare_random_paths.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import resource
import timeit
from typing import Literal

import numpy as np
import pandas as pd
Expand All @@ -12,7 +13,7 @@

pd.set_option("display.width", 200)

opt_path = "optimal"
opt_path: Literal["optimal"] = "optimal"

# Number of dimensions
max_dims = 4
Expand Down Expand Up @@ -108,7 +109,7 @@ def random_contraction():

diff_flags = df["Flag"] is not True
print("\nNumber of contract different than einsum: %d." % np.sum(diff_flags))
if sum(diff_flags) > 0:
if diff_flags > 0:
print("Terms different than einsum")
print(df[df["Flag"] is not True])

Expand Down

0 comments on commit 5bc05b3

Please sign in to comment.