Skip to content

Commit

Permalink
copy axes and tags in zeros_like/full_like
Browse files Browse the repository at this point in the history
  • Loading branch information
majosm authored and inducer committed May 5, 2023
1 parent 85e4285 commit 00e008b
Showing 1 changed file with 4 additions and 2 deletions.
6 changes: 4 additions & 2 deletions arraycontext/impl/pytato/fake_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,8 @@ def __getattr__(self, name):

def zeros_like(self, ary):
def _zeros_like(array):
return self._array_context.zeros(array.shape, array.dtype)
return self._array_context.zeros(
array.shape, array.dtype).copy(axes=array.axes, tags=array.tags)

return self._array_context._rec_map_container(
_zeros_like, ary, default_scalar=0)
Expand All @@ -83,7 +84,8 @@ def ones_like(self, ary):

def full_like(self, ary, fill_value):
def _full_like(subary):
return pt.full(subary.shape, fill_value, subary.dtype)
return pt.full(subary.shape, fill_value, subary.dtype).copy(
axes=subary.axes, tags=subary.tags)

return self._array_context._rec_map_container(
_full_like, ary, default_scalar=fill_value)
Expand Down

0 comments on commit 00e008b

Please sign in to comment.