Skip to content

Commit

Permalink
Implement np.linspace in fake_numpy
Browse files Browse the repository at this point in the history
  • Loading branch information
inducer committed Sep 14, 2023
1 parent 1cbad01 commit cd6bad5
Show file tree
Hide file tree
Showing 2 changed files with 90 additions and 0 deletions.
89 changes: 89 additions & 0 deletions arraycontext/fake_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,9 @@
"""


import operator
from typing import Any

import numpy as np

from arraycontext.container import NotAnArrayContainerError, serialize_container
Expand Down Expand Up @@ -100,6 +103,91 @@ def conjugate(self, x):

conj = conjugate

# {{{ linspace

# based on
# https://github.com/numpy/numpy/blob/v1.25.0/numpy/core/function_base.py#L24-L182

def linspace(self, start, stop, num=50, endpoint=True, retstep=False, dtype=None,
axis=0):
num = operator.index(num)
if num < 0:
raise ValueError("Number of samples, %s, must be non-negative." % num)
div = (num - 1) if endpoint else num

# Convert float/complex array scalars to float, gh-3504
# and make sure one can use variables that have an __array_interface__,
# gh-6634

if isinstance(start, self._array_context.array_types):
raise NotImplementedError("start as an actx array")
if isinstance(stop, self._array_context.array_types):
raise NotImplementedError("stop as an actx array")

start = np.array(start) * 1.0
stop = np.array(stop) * 1.0

dt = np.result_type(start, stop, float(num))
if dtype is None:
dtype = dt
integer_dtype = False
else:
integer_dtype = np.issubdtype(dtype, np.integer)

delta = stop - start

y = self.arange(0, num, dtype=dt).reshape((-1,) + (1,) * delta.ndim)

if div > 0:
step = delta / div
#any_step_zero = _nx.asanyarray(step == 0).any()
any_step_zero = self._array_context.to_numpy((step == 0)).any()
if any_step_zero:
delta_actx = self._array_context.from_numpy(delta)

# Special handling for denormal numbers, gh-5437
y = y / div
y = y * delta_actx
else:
step_actx = self._array_context.from_numpy(step)
y = y * step_actx
else:
delta_actx = self._array_context.from_numpy(delta)
# sequences with 0 items or 1 item with endpoint=True (i.e. div <= 0)
# have an undefined step
step = np.NaN
# Multiply with delta to allow possible override of output class.
y = y * delta_actx

y += start

if endpoint and num > 1:
y[-1, ...] = stop

if axis != 0:
# y = _nx.moveaxis(y, 0, axis)
raise NotImplementedError("axis != 0")

if integer_dtype:
y = self.floor(y)

# FIXME: Use astype
# https://github.com/inducer/pytato/issues/456
if retstep:
return y, step
#return y.astype(dtype), step
else:
return y
#return y.astype(dtype)

# }}}

def floor(self, ary):
raise NotImplementedError

def arange(self, *args: Any, **kwargs: Any):
raise NotImplementedError

# }}}


Expand Down Expand Up @@ -180,6 +268,7 @@ def norm(self, ary, ord=None):
return actx.np.sum(abs(ary)**ord)**(1/ord)
else:
raise NotImplementedError(f"unsupported value of 'ord': {ord}")

# }}}


Expand Down
1 change: 1 addition & 0 deletions arraycontext/impl/pytato/fake_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
THE SOFTWARE.
"""
from functools import partial, reduce
from typing import Any

import numpy as np

Expand Down

0 comments on commit cd6bad5

Please sign in to comment.