Skip to content

Commit

Permalink
add copy argument to __array__ methods
Browse files Browse the repository at this point in the history
  • Loading branch information
t20100 committed Aug 27, 2024
1 parent 357fc5d commit b1543b7
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 7 deletions.
10 changes: 8 additions & 2 deletions src/silx/io/commonh5.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@
import numpy

from . import utils
from .._utils import NP_OPTIONAL_COPY


__authors__ = ["V. Valls", "P. Knobel"]
__license__ = "MIT"
Expand Down Expand Up @@ -347,12 +349,16 @@ def external(self):
:rtype: list or None"""
return None

def __array__(self, dtype=None):
def __array__(self, dtype=None, copy=None):
# Special case for (0,)*-shape datasets
if numpy.prod(self.shape) == 0:
return self[()]
else:
return numpy.array(self[...], dtype=self.dtype if dtype is None else dtype)
return numpy.array(
self[...],
dtype=self.dtype if dtype is None else dtype,
copy=NP_OPTIONAL_COPY if copy is None else copy,
)

def __iter__(self):
"""Iterate over the first axis. TypeError if scalar."""
Expand Down
21 changes: 16 additions & 5 deletions src/silx/utils/array_like.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# /*##########################################################################
#
# Copyright (c) 2016-2021 European Synchrotron Radiation Facility
# Copyright (c) 2016-2024 European Synchrotron Radiation Facility
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
Expand Down Expand Up @@ -49,6 +49,9 @@
import numpy
import numbers

from .._utils import NP_OPTIONAL_COPY


__authors__ = ["P. Knobel"]
__license__ = "MIT"
__date__ = "26/04/2017"
Expand Down Expand Up @@ -276,13 +279,17 @@ def __sort_indices(self, indices):
)
return sorted_indices

def __array__(self, dtype=None):
def __array__(self, dtype=None, copy=None):
"""Cast the images into a numpy array, and return it.
If a transposition has been done on this images, return
a transposed view of a numpy array."""
return numpy.transpose(
numpy.array(self.images, dtype=dtype), self.transposition
numpy.array(
self.images,
dtype=dtype,
copy=NP_OPTIONAL_COPY if copy is None else copy),
self.transposition,
)

def __len__(self):
Expand Down Expand Up @@ -543,13 +550,17 @@ def __getitem__(self, item):

return numpy.transpose(output_data_not_transposed, axes=output_dimensions)

def __array__(self, dtype=None):
def __array__(self, dtype=None, copy=None):
"""Cast the dataset into a numpy array, and return it.
If a transposition has been done on this dataset, return
a transposed view of a numpy array."""
return numpy.transpose(
numpy.array(self.dataset, dtype=dtype), self.transposition
numpy.array(
self.dataset,
dtype=dtype,
copy=NP_OPTIONAL_COPY if copy is None else copy),
self.transposition,
)

def __len__(self):
Expand Down

0 comments on commit b1543b7

Please sign in to comment.