From 4ffd56ef6c8d25d7479b3f77de17f61746db40b4 Mon Sep 17 00:00:00 2001 From: Serge Panev Date: Mon, 9 Sep 2024 19:23:25 +0200 Subject: [PATCH] Improve Stats compute perf --- torch_frame/data/stats.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/torch_frame/data/stats.py b/torch_frame/data/stats.py index fbf6b633..6974a710 100644 --- a/torch_frame/data/stats.py +++ b/torch_frame/data/stats.py @@ -85,7 +85,8 @@ def compute( sep: str | None = None, ) -> Any: if self == StatType.MEAN: - flattened = np.hstack(np.hstack(ser.values)) + val = np.hstack(ser.values) if ser.values.ndim > 1 else ser.values + flattened = np.hstack(val) if val.ndim > 1 else val finite_mask = np.isfinite(flattened) if not finite_mask.any(): # NOTE: We may just error out here if eveything is NaN @@ -93,14 +94,16 @@ def compute( return np.mean(flattened[finite_mask]).item() elif self == StatType.STD: - flattened = np.hstack(np.hstack(ser.values)) + val = np.hstack(ser.values) if ser.values.ndim > 1 else ser.values + flattened = np.hstack(val) if val.ndim > 1 else val finite_mask = np.isfinite(flattened) if not finite_mask.any(): return np.nan return np.std(flattened[finite_mask]).item() elif self == StatType.QUANTILES: - flattened = np.hstack(np.hstack(ser.values)) + val = np.hstack(ser.values) if ser.values.ndim > 1 else ser.values + flattened = np.hstack(val) if val.ndim > 1 else val finite_mask = np.isfinite(flattened) if not finite_mask.any(): return [np.nan, np.nan, np.nan, np.nan, np.nan]