diff --git a/src/lgdo/types/histogram.py b/src/lgdo/types/histogram.py index e2017df6..63d8191f 100644 --- a/src/lgdo/types/histogram.py +++ b/src/lgdo/types/histogram.py @@ -62,7 +62,7 @@ def __init__( if not isinstance(edges, Array): edges = Array(edges, attrs=binedge_attrs) elif binedge_attrs is not None: - msg = "passed both binedge types.Array instance and binedge_attrs" + msg = "passed both binedge as Array LGDO instance and binedge_attrs" raise ValueError(msg) if len(edges.nda.shape) != 1: @@ -72,11 +72,13 @@ def __init__( super().__init__({"binedges": edges, "closedleft": Scalar(closedleft)}) @classmethod - def from_edges(cls, edges: np.ndarray) -> Histogram.Axis: + def from_edges( + cls, edges: np.ndarray, binedge_attrs: dict[str, Any] | None = None + ) -> Histogram.Axis: edge_diff = np.diff(edges) if np.any(~np.isclose(edge_diff, edge_diff[0])): - return cls(edges, None, None, None, True) - return cls(None, edges[0], edges[-1], edge_diff[0], True) + return cls(edges, None, None, None, True, binedge_attrs) + return cls(None, edges[0], edges[-1], edge_diff[0], True, binedge_attrs) @property def is_range(self) -> bool: @@ -132,13 +134,24 @@ def __str__(self) -> str: string = f"edges={self.edges}" string += f", closedleft={self.closedleft}" - attrs = self["binedges"].getattrs() + attrs = self.get_binedgeattrs() if attrs: string += f" with attrs={attrs}" np.set_printoptions(threshold=thr_orig) return string + def get_binedgeattrs(self, datatype: bool = False) -> dict: + """Return a copy of the LGDO attributes dictionary of the binedges + + Parameters + ---------- + datatype + if ``False``, remove ``datatype`` attribute from the output + dictionary. + """ + return self["binedges"].getattrs(datatype) + def __init__( self, weights: hist.Hist | np.ndarray, @@ -148,6 +161,7 @@ def __init__( | list[tuple[float, float, float]] = None, isdensity: bool = False, attrs: dict[str, Any] | None = None, + binedge_attrs: dict[str, Any] | None = None, ) -> None: """A special struct to contain histogrammed data. @@ -163,6 +177,8 @@ def __init__( * can be a list of pre-initialized :class:`Histogram.Axis` * can be a list of tuples, representing a range, ``(first, last, step)`` * can be a list of numpy arrays, as returned by :func:`numpy.histogramdd`. + binedge_attrs + attributes that will be added to the all ``binedges`` of all axes. """ if isinstance(weights, hist.Hist): if binning is not None: @@ -186,7 +202,7 @@ def __init__( if not isinstance(ax, (hist.axis.Regular, hist.axis.Variable)): msg = "only regular or variable axes of hist.Hist can be converted" raise ValueError(msg) - b.append(Histogram.Axis.from_edges(ax.edges)) + b.append(Histogram.Axis.from_edges(ax.edges, binedge_attrs)) b = self._create_binning(b) else: if binning is None: @@ -195,11 +211,14 @@ def __init__( w = Array(weights) if all(isinstance(ax, Histogram.Axis) for ax in binning): + if binedge_attrs is not None: + msg = "passed both binedges as Axis instances and binedge_attrs" + raise ValueError(msg) b = binning elif all(isinstance(ax, np.ndarray) for ax in binning): - b = [Histogram.Axis.from_edges(ax) for ax in binning] + b = [Histogram.Axis.from_edges(ax, binedge_attrs) for ax in binning] elif all(isinstance(ax, tuple) for ax in binning): - b = [Histogram.Axis(None, *ax, True) for ax in binning] + b = [Histogram.Axis(None, *ax, True, binedge_attrs) for ax in binning] else: msg = "invalid binning object passed" raise ValueError(msg) diff --git a/tests/types/test_histogram.py b/tests/types/test_histogram.py index f3623aed..ee45c0a9 100644 --- a/tests/types/test_histogram.py +++ b/tests/types/test_histogram.py @@ -149,6 +149,13 @@ def test_axes(): assert h.binning[0].nbins == 3 assert str(h.binning[0]) == "edges=[0. 1. 2.5 3. ], closedleft=True" + with pytest.raises(ValueError, match="either from edges or from range"): + Histogram.Axis(np.array([0, 1, 2.5, 3]), 0, 1, None) + with pytest.raises(ValueError, match="all range parameters"): + Histogram.Axis(None, 0, 1, None) + + +def test_ax_attributes(): Histogram.Axis( np.array([0, 1, 2.5, 3]), None, None, None, binedge_attrs={"units": "m"} ) @@ -175,10 +182,21 @@ def test_axes(): str(ax) == "edges=[0. 1. 2.5 3. ], closedleft=True with attrs={'units': 'm'}" ) - with pytest.raises(ValueError, match="either from edges or from range"): - Histogram.Axis(np.array([0, 1, 2.5, 3]), 0, 1, None) - with pytest.raises(ValueError, match="all range parameters"): - Histogram.Axis(None, 0, 1, None) + h = Histogram( + np.array([[1, 1], [1, 1]]), + (np.array([0, 1, 2]), np.array([0, 1, 2])), + binedge_attrs={"units": "m"}, + ) + assert str(h.binning[0]).endswith(", closedleft=True with attrs={'units': 'm'}") + assert str(h.binning[1]).endswith(", closedleft=True with attrs={'units': 'm'}") + assert h.binning[0].get_binedgeattrs() == {"units": "m"} + + with pytest.raises(ValueError): + h = Histogram( + np.array([[1, 1], [1, 1]]), + [Histogram.Axis(None, 1, 3, 1, True), Histogram.Axis(None, 4, 6, 1, False)], + binedge_attrs={"units": "m"}, + ) def test_view_as_hist():