From 20c43be83a1cb1a3b53747be7c709c63b89dd556 Mon Sep 17 00:00:00 2001 From: Matthew Smith Date: Fri, 5 May 2023 09:11:13 -0700 Subject: [PATCH] copy axes and tags in zeros_like/full_like --- arraycontext/impl/pytato/fake_numpy.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/arraycontext/impl/pytato/fake_numpy.py b/arraycontext/impl/pytato/fake_numpy.py index 4dad159f..f2e3d2e8 100644 --- a/arraycontext/impl/pytato/fake_numpy.py +++ b/arraycontext/impl/pytato/fake_numpy.py @@ -73,7 +73,8 @@ def __getattr__(self, name): def zeros_like(self, ary): def _zeros_like(array): - return self._array_context.zeros(array.shape, array.dtype) + return self._array_context.zeros( + array.shape, array.dtype).copy(axes=array.axes, tags=array.tags) return self._array_context._rec_map_container( _zeros_like, ary, default_scalar=0) @@ -83,7 +84,8 @@ def ones_like(self, ary): def full_like(self, ary, fill_value): def _full_like(subary): - return pt.full(subary.shape, fill_value, subary.dtype) + return pt.full(subary.shape, fill_value, subary.dtype).copy( + axes=subary.axes, tags=subary.tags) return self._array_context._rec_map_container( _full_like, ary, default_scalar=fill_value)