From 8eaa583857f5ec357c4275bcac9f3c00e7d63c0c Mon Sep 17 00:00:00 2001 From: Roelof Rietbroek Date: Thu, 27 Jun 2024 17:21:31 +0200 Subject: [PATCH] implement backward compatibility for sparse matrix multiplication for older xarray versions --- pyproject.toml | 8 +++++-- src/shxarray/kernels/anisokernel.py | 34 ++++++++++++++++++++++++++++- tests/test_filters.py | 5 +++-- 3 files changed, 42 insertions(+), 5 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index bf08fe0..3c21b56 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,5 +1,5 @@ [build-system] -requires = ["setuptools>=61.0","setuptools-scm>=8","numpy<=1.26","Cython>=3","wheel","pytest","scipy"] +requires = ["setuptools>=61.0","setuptools-scm>=8","numpy<2","Cython>=3","wheel","pytest","scipy"] build-backend = "setuptools.build_meta" [project] name = "shxarray" @@ -20,7 +20,11 @@ classifiers = [ "Development Status :: 1 - Planning" ] dependencies = [ "pandas >= 2.0", "pyaml >= 23.9.0", "scipy", "xarray >= 2023.1.0", -"numpy<=1.26","numba", "sparse","importlib_metadata","requests","openpyxl"] +"numpy<2","numba", "sparse","importlib_metadata","requests","openpyxl"] + +[project.optional-dependencies] +#you need dask when using older xarray versions +dask=["dask>=2022.9.2"] [tool.setuptools_scm] # empty for now diff --git a/src/shxarray/kernels/anisokernel.py b/src/shxarray/kernels/anisokernel.py index f7063ec..349ac2a 100644 --- a/src/shxarray/kernels/anisokernel.py +++ b/src/shxarray/kernels/anisokernel.py @@ -9,6 +9,9 @@ from shxarray.core.sh_indexing import SHindexBase import sparse +from packaging import version + + @@ -21,6 +24,13 @@ def __init__(self,dsobj,name="aniso",truncate=True): self._dskernel=dsobj self.name=name self.truncate=truncate + self.useDask=version.parse(xr.__version__) < version.parse('2023.11.0') + if self.useDask: + from dask.array.core import einsum_lookup + #Register the einsum functions which are needed to do the sparse dot functions (for earlier 
versions of xarray) + einsum_lookup.register(sparse.COO,AnisoKernel.daskeinsumReplace) + #convert to xarray with dask structure + self._dskernel=self._dskernel.chunk() @property def nmax(self): @@ -38,7 +48,10 @@ def __call__(self,dain:xr.DataArray): daout=xr.dot(self._dskernel.mat,dain,dims=[SHindexBase.name]) #rename nm and convert to dense array daout=daout.sh.toggle_nm() - daout=xr.DataArray(daout.data.todense(),coords=daout.coords,name=self.name) + if self.useDask: + daout=xr.DataArray(daout.compute().data,coords=daout.coords,name=self.name) + else: + daout=xr.DataArray(daout.data.todense(),coords=daout.coords,name=self.name) if not self.truncate and self.nmin > 0: #also add the unfiltered lower degree coefficients back to the results @@ -60,4 +73,23 @@ def position(self,lon,lat): ynmdata=ynmdata/normv return self.__call__(ynmdata) + + @staticmethod + def daskeinsumReplace(subscripts, *operands, out=None, dtype=None, order='K', casting='safe', optimize=False): + """Mimics the interface of https://numpy.org/doc/stable/reference/generated/numpy.einsum.html, but uses the sparse.COO dot function""" + if subscripts == "ab,cb->ac": + return operands[0].dot(operands[1].T) + elif subscripts == "ab,ca->bc": + return operands[0].T.dot(operands[1].T) + elif subscripts == "ab,bc->ac": + return operands[0].dot(operands[1]) + elif subscripts == "ab,b->a": + return operands[0].dot(operands[1]) + elif subscripts == "ab,a->b": + return operands[0].T.dot(operands[1]) + elif subscripts == "ab,ac->bc": + return operands[0].T.dot(operands[1]) + + else: + raise NotImplementedError(f"Don't know (yet) how to handle this einsum: {subscripts} with sparse.dot operations") diff --git a/tests/test_filters.py b/tests/test_filters.py index 15544e7..9d41aca 100644 --- a/tests/test_filters.py +++ b/tests/test_filters.py @@ -74,8 +74,9 @@ def test_ddkmult(shinput,shoutput): dacheck=shinm.sh.filter('DDK2') dadiff=(dacheck-shoutm)/shoutm - maxdiff=max(abs(dadiff.max()),abs(dadiff.min())) - 
assert(maxdiff < tol) + maxdiff=max(abs(dadiff.max()),abs(dadiff.min())) + reltol=3e-11 + assert(maxdiff < reltol) @pytest.fixture def shgausstest():