Skip to content

Commit

Permalink
pass binedge_attrs through Histogram constructor
Browse files Browse the repository at this point in the history
  • Loading branch information
ManuelHu committed Jul 17, 2024
1 parent e77af46 commit 33caf61
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 12 deletions.
35 changes: 27 additions & 8 deletions src/lgdo/types/histogram.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ def __init__(
if not isinstance(edges, Array):
edges = Array(edges, attrs=binedge_attrs)
elif binedge_attrs is not None:
msg = "passed both binedge types.Array instance and binedge_attrs"
msg = "passed both binedge as Array LGDO instance and binedge_attrs"
raise ValueError(msg)

if len(edges.nda.shape) != 1:
Expand All @@ -72,11 +72,13 @@ def __init__(
super().__init__({"binedges": edges, "closedleft": Scalar(closedleft)})

@classmethod
def from_edges(cls, edges: np.ndarray) -> Histogram.Axis:
def from_edges(
cls, edges: np.ndarray, binedge_attrs: dict[str, Any] | None = None
) -> Histogram.Axis:
edge_diff = np.diff(edges)
if np.any(~np.isclose(edge_diff, edge_diff[0])):
return cls(edges, None, None, None, True)
return cls(None, edges[0], edges[-1], edge_diff[0], True)
return cls(edges, None, None, None, True, binedge_attrs)
return cls(None, edges[0], edges[-1], edge_diff[0], True, binedge_attrs)

@property
def is_range(self) -> bool:
Expand Down Expand Up @@ -132,13 +134,24 @@ def __str__(self) -> str:
string = f"edges={self.edges}"
string += f", closedleft={self.closedleft}"

attrs = self["binedges"].getattrs()
attrs = self.get_binedgeattrs()
if attrs:
string += f" with attrs={attrs}"

np.set_printoptions(threshold=thr_orig)
return string

def get_binedgeattrs(self, datatype: bool = False) -> dict:
"""Return a copy of the LGDO attributes dictionary of the binedges
Parameters
----------
datatype
if ``False``, remove ``datatype`` attribute from the output
dictionary.
"""
return self["binedges"].getattrs(datatype)

def __init__(
self,
weights: hist.Hist | np.ndarray,
Expand All @@ -148,6 +161,7 @@ def __init__(
| list[tuple[float, float, float]] = None,
isdensity: bool = False,
attrs: dict[str, Any] | None = None,
binedge_attrs: dict[str, Any] | None = None,
) -> None:
"""A special struct to contain histogrammed data.
Expand All @@ -163,6 +177,8 @@ def __init__(
* can be a list of pre-initialized :class:`Histogram.Axis`
* can be a list of tuples, representing a range, ``(first, last, step)``
* can be a list of numpy arrays, as returned by :func:`numpy.histogramdd`.
binedge_attrs
attributes that will be added to the all ``binedges`` of all axes.
"""
if isinstance(weights, hist.Hist):
if binning is not None:
Expand All @@ -186,7 +202,7 @@ def __init__(
if not isinstance(ax, (hist.axis.Regular, hist.axis.Variable)):
msg = "only regular or variable axes of hist.Hist can be converted"
raise ValueError(msg)
b.append(Histogram.Axis.from_edges(ax.edges))
b.append(Histogram.Axis.from_edges(ax.edges, binedge_attrs))
b = self._create_binning(b)
else:
if binning is None:
Expand All @@ -195,11 +211,14 @@ def __init__(
w = Array(weights)

if all(isinstance(ax, Histogram.Axis) for ax in binning):
if binedge_attrs is not None:
msg = "passed both binedges as Axis instances and binedge_attrs"
raise ValueError(msg)
b = binning
elif all(isinstance(ax, np.ndarray) for ax in binning):
b = [Histogram.Axis.from_edges(ax) for ax in binning]
b = [Histogram.Axis.from_edges(ax, binedge_attrs) for ax in binning]
elif all(isinstance(ax, tuple) for ax in binning):
b = [Histogram.Axis(None, *ax, True) for ax in binning]
b = [Histogram.Axis(None, *ax, True, binedge_attrs) for ax in binning]
else:
msg = "invalid binning object passed"
raise ValueError(msg)
Expand Down
26 changes: 22 additions & 4 deletions tests/types/test_histogram.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,13 @@ def test_axes():
assert h.binning[0].nbins == 3
assert str(h.binning[0]) == "edges=[0. 1. 2.5 3. ], closedleft=True"

with pytest.raises(ValueError, match="either from edges or from range"):
Histogram.Axis(np.array([0, 1, 2.5, 3]), 0, 1, None)
with pytest.raises(ValueError, match="all range parameters"):
Histogram.Axis(None, 0, 1, None)


def test_ax_attributes():
Histogram.Axis(
np.array([0, 1, 2.5, 3]), None, None, None, binedge_attrs={"units": "m"}
)
Expand All @@ -175,10 +182,21 @@ def test_axes():
str(ax) == "edges=[0. 1. 2.5 3. ], closedleft=True with attrs={'units': 'm'}"
)

with pytest.raises(ValueError, match="either from edges or from range"):
Histogram.Axis(np.array([0, 1, 2.5, 3]), 0, 1, None)
with pytest.raises(ValueError, match="all range parameters"):
Histogram.Axis(None, 0, 1, None)
h = Histogram(
np.array([[1, 1], [1, 1]]),
(np.array([0, 1, 2]), np.array([0, 1, 2])),
binedge_attrs={"units": "m"},
)
assert str(h.binning[0]).endswith(", closedleft=True with attrs={'units': 'm'}")
assert str(h.binning[1]).endswith(", closedleft=True with attrs={'units': 'm'}")
assert h.binning[0].get_binedgeattrs() == {"units": "m"}

with pytest.raises(ValueError):
h = Histogram(
np.array([[1, 1], [1, 1]]),
[Histogram.Axis(None, 1, 3, 1, True), Histogram.Axis(None, 4, 6, 1, False)],
binedge_attrs={"units": "m"},
)


def test_view_as_hist():
Expand Down

0 comments on commit 33caf61

Please sign in to comment.