Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Repartition Operands. #540

Draft
wants to merge 1 commit into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
236 changes: 40 additions & 196 deletions nums/core/array/blockarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from nums.core.array.view import ArrayView
from nums.core.grid.grid import ArrayGrid
from nums.core.compute.compute_manager import ComputeManager
from nums.core.array import ops


# pylint: disable=too-many-lines
Expand Down Expand Up @@ -630,10 +631,6 @@ def to_block_array(obj, cm: ComputeManager, block_shape=None):
block_shape = cm.get_block_shape(np_array.shape, np_array.dtype)
return BlockArray.from_np(np_array, block_shape, False, cm)

def check_or_convert_other(self, other, compute_block_shape=False):
block_shape = None if compute_block_shape else self.block_shape
return BlockArray.to_block_array(other, self.cm, block_shape=block_shape)

def ufunc(self, op_name):
result = self.copy()
for grid_entry in self.grid.get_entry_iterator():
Expand Down Expand Up @@ -769,7 +766,7 @@ def reduce_axis(self, op_name, axis, keepdims=False):
return result

#################
# Linear Algebra
# Tensor Algebra
#################

def _compute_tensordot_syskwargs(self, self_block: Block, other_block: Block):
Expand All @@ -779,136 +776,24 @@ def _compute_tensordot_syskwargs(self, self_block: Block, other_block: Block):
else:
return other_block.true_grid_entry(), other_block.true_grid_shape()

def tensordot(self, other, axes=2):
if isinstance(axes, int):
pass
elif array_utils.is_array_like(axes):
raise NotImplementedError("Non-integer axes is currently not supported.")
else:
raise TypeError(f"Unexpected axes type '{type(axes).__name__}'")

other = self.check_or_convert_other(other, compute_block_shape=True)

if array_utils.np_tensordot_param_test(
self.shape, self.ndim, other.shape, other.ndim, axes
):
raise ValueError("shape-mismatch for sum")

this_axes = self.grid.grid_shape[:-axes]
this_sum_axes = self.grid.grid_shape[-axes:]
other_axes = other.grid.grid_shape[axes:]
other_sum_axes = other.grid.grid_shape[:axes]
assert this_sum_axes == other_sum_axes
result_shape = tuple(self.shape[:-axes] + other.shape[axes:])
result_block_shape = tuple(self.block_shape[:-axes] + other.block_shape[axes:])
result_grid = ArrayGrid(
shape=result_shape,
block_shape=result_block_shape,
dtype=array_utils.get_bop_output_type(
"tensordot", self.dtype, other.dtype
).__name__,
)
assert result_grid.grid_shape == tuple(this_axes + other_axes)
result = BlockArray(result_grid, self.cm)
this_dims = list(itertools.product(*map(range, this_axes)))
other_dims = list(itertools.product(*map(range, other_axes)))
sum_dims = list(itertools.product(*map(range, this_sum_axes)))
for i in this_dims:
for j in other_dims:
grid_entry = tuple(i + j)
result_block: Block = result.blocks[grid_entry]
sum_oids = []
for k in sum_dims:
self_block: Block = self.blocks[tuple(i + k)]
other_block: Block = other.blocks[tuple(k + j)]
dot_grid_args = self._compute_tensordot_syskwargs(
self_block, other_block
)
dotted_oid = self.cm.bop(
"tensordot",
self_block.oid,
other_block.oid,
self_block.transposed,
other_block.transposed,
axes=axes,
syskwargs={
"grid_entry": dot_grid_args[0],
"grid_shape": dot_grid_args[1],
},
)
sum_oids.append(
(dotted_oid, dot_grid_args[0], dot_grid_args[1], False)
)
result_block.oid = self._tree_reduce(
"sum", sum_oids, result_block.grid_entry, result_block.grid_shape
)
return result

def __matmul__(self, other):
if len(self.shape) > 2:
# TODO (bcp): NumPy's implementation does a stacked matmul, which is not supported yet.
raise NotImplementedError(
"Matrix multiply for tensors of rank > 2 not supported yet."
)
else:
return self.tensordot(other, 1)
return ops.tensordot(self.cm, self, other, axes=1)

def __rmatmul__(self, other):
other = self.check_or_convert_other(other)
return other @ self
return ops.tensordot(self.cm, other, self, axes=1)

__imatmul__ = __matmul__

#################
# Arithmetic
#################

def _fast_element_wise(self, op_name, other):
"""
Implements fast scheduling for basic element-wise operations.
"""
dtype = array_utils.get_bop_output_type(op_name, self.dtype, other.dtype)
# Schedule the op first.
blocks = np.empty(shape=self.grid.grid_shape, dtype=Block)
for grid_entry in self.grid.get_entry_iterator():
self_block: Block = self.blocks[grid_entry]
other_block: Block = other.blocks[grid_entry]
blocks[grid_entry] = block = Block(
grid_entry=grid_entry,
grid_shape=self_block.grid_shape,
rect=self_block.rect,
shape=self_block.shape,
dtype=dtype,
transposed=False,
cm=self.cm,
)
block.oid = self.cm.bop(
op_name,
self_block.oid,
other_block.oid,
self_block.transposed,
other_block.transposed,
axes={},
syskwargs={
"grid_entry": grid_entry,
"grid_shape": self.grid.grid_shape,
},
)
return BlockArray(
ArrayGrid(self.shape, self.block_shape, dtype.__name__),
self.cm,
blocks=blocks,
)

def __elementwise__(self, op_name, other):
other = self.check_or_convert_other(other)
if self.shape == other.shape and self.block_shape == other.block_shape:
return self._fast_element_wise(op_name, other)
blocks_op = self.blocks.__getattribute__("__%s__" % op_name)
return BlockArray.from_blocks(
blocks_op(other.blocks), result_shape=None, cm=self.cm
)

def __neg__(self):
return self.ufunc("negative")

Expand All @@ -919,136 +804,100 @@ def __abs__(self):
return self.ufunc("abs")

def __mod__(self, other):
return self.__elementwise__("mod", other)
return ops.elementwise("mod", self, other, self.cm)

def __rmod__(self, other):
other = self.check_or_convert_other(other)
return other.__elementwise__("mod", self)
return ops.elementwise("mod", other, self, self.cm)

__imod__ = __mod__

def __add__(self, other):
return self.__elementwise__("add", other)
return ops.elementwise("add", self, other, self.cm)

def __radd__(self, other):
other = self.check_or_convert_other(other)
return other.__elementwise__("add", self)
return ops.elementwise("add", other, self, self.cm)

__iadd__ = __add__

def __sub__(self, other):
return self.__elementwise__("sub", other)
return ops.elementwise("sub", self, other, self.cm)

def __rsub__(self, other):
other = self.check_or_convert_other(other)
return other.__elementwise__("sub", self)
return ops.elementwise("sub", other, self, self.cm)

__isub__ = __sub__

def __mul__(self, other):
return self.__elementwise__("mul", other)
return ops.elementwise("mul", self, other, self.cm)

def __rmul__(self, other):
other = self.check_or_convert_other(other)
return other.__elementwise__("mul", self)
return ops.elementwise("mul", other, self, self.cm)

__imul__ = __mul__

def __truediv__(self, other):
return self.__elementwise__("truediv", other)
return ops.elementwise("truediv", self, other, self.cm)

def __rtruediv__(self, other):
other = self.check_or_convert_other(other)
return other / self
return ops.elementwise("truediv", other, self, self.cm)

__itruediv__ = __truediv__

def __floordiv__(self, other):
return self.__elementwise__("floor_divide", other)
return ops.elementwise("floor_divide", self, other, self.cm)

def __rfloordiv__(self, other):
other = self.check_or_convert_other(other)
return other.__elementwise__("floor_divide", self)
return ops.elementwise("floor_divide", other, self, self.cm)

__ifloordiv__ = __floordiv__

def __pow__(self, other):
return self.__elementwise__("pow", other)
return ops.elementwise("pow", self, other, self.cm)

def __rpow__(self, other):
other = self.check_or_convert_other(other)
return other ** self
return ops.elementwise("pow", other, self, self.cm)

__ipow__ = __pow__

#################
# Inequalities
#################

def __inequality__(self, op, other):
other = self.check_or_convert_other(other)
assert (
other.shape == () or other.shape == self.shape
), "Currently supports comparison with scalars only."
shape = array_utils.broadcast(self.shape, other.shape).shape
block_shape = array_utils.broadcast_block_shape(
self.shape, other.shape, self.block_shape
)
dtype = bool.__name__
grid = ArrayGrid(shape, block_shape, dtype)
result = BlockArray(grid, self.cm)
for grid_entry in result.grid.get_entry_iterator():
if other.shape == ():
other_block: Block = other.blocks.item()
else:
other_block: Block = other.blocks[grid_entry]
result.blocks[grid_entry] = self.blocks[grid_entry].bop(
op, other_block, args={}
)

return result

def __ge__(self, other):
return self.__inequality__("ge", other)
return ops.inequality("ge", self, other, self.cm)

def __rge__(self, other):
other = self.check_or_convert_other(other)
return other.__inequality__("ge", self)
return ops.inequality("ge", other, self, self.cm)

def __gt__(self, other):
return self.__inequality__("gt", other)
return ops.inequality("gt", self, other, self.cm)

def __rgt__(self, other):
other = self.check_or_convert_other(other)
return other.__inequality__("gt", self)
return ops.inequality("gt", other, self, self.cm)

def __le__(self, other):
return self.__inequality__("le", other)
return ops.inequality("le", self, other, self.cm)

def __rle__(self, other):
other = self.check_or_convert_other(other)
return other.__inequality__("le", self)
return ops.inequality("le", other, self, self.cm)

def __lt__(self, other):
return self.__inequality__("lt", other)
return ops.inequality("lt", self, other, self.cm)

def __rlt__(self, other):
other = self.check_or_convert_other(other)
return other.__inequality__("lt", self)
return ops.inequality("lt", other, self, self.cm)

def __eq__(self, other):
return self.__inequality__("eq", other)
return ops.inequality("eq", self, other, self.cm)

def __req__(self, other):
other = self.check_or_convert_other(other)
return other.__inequality__("eq", self)
return ops.inequality("eq", other, self, self.cm)

def __ne__(self, other):
return self.__inequality__("ne", other)
return ops.inequality("ne", self, other, self.cm)

def __rne__(self, other):
other = self.check_or_convert_other(other)
return other.__inequality__("ne", self)
return ops.inequality("ne", other, self, self.cm)

##################
# Boolean
Expand All @@ -1066,47 +915,42 @@ def __invert__(self):
return self.ufunc("invert")

def __or__(self, other):
return self.__elementwise__("bitwise_or", other)
return ops.elementwise("bitwise_or", self, other, self.cm)

def __ror__(self, other):
other = self.check_or_convert_other(other)
return other.__elementwise__("bitwise_or", self)
return ops.elementwise("bitwise_or", other, self, self.cm)

__ior__ = __or__

def __and__(self, other):
return self.__elementwise__("bitwise_and", other)
return ops.elementwise("bitwise_and", self, other, self.cm)

def __rand__(self, other):
other = self.check_or_convert_other(other)
return other.__elementwise__("bitwise_and", self)
return ops.elementwise("bitwise_and", other, self, self.cm)

__iand__ = __and__

def __xor__(self, other):
return self.__elementwise__("bitwise_xor", other)
return ops.elementwise("bitwise_xor", self, other, self.cm)

def __rxor__(self, other):
other = self.check_or_convert_other(other)
return other.__elementwise__("bitwise_xor", self)
return ops.elementwise("bitwise_xor", other, self, self.cm)

__ixor__ = __xor__

def __lshift__(self, other):
return self.__elementwise__("left_shift", other)
return ops.elementwise("left_shift", self, other, self.cm)

def __rlshift__(self, other):
other = self.check_or_convert_other(other)
return other.__elementwise__("left_shift", self)
return ops.elementwise("left_shift", other, self, self.cm)

__ilshift__ = __lshift__

def __rshift__(self, other):
return self.__elementwise__("right_shift", other)
return ops.elementwise("right_shift", self, other, self.cm)

def __rrshift__(self, other):
other = self.check_or_convert_other(other)
return other.__elementwise__("right_shift", self)
return ops.elementwise("right_shift", other, self, self.cm)

__irshift__ = __rshift__

Expand Down
Loading