Skip to content

Commit

Permalink
implement backward compatibility for sparse matrix multiplication for…
Browse files Browse the repository at this point in the history
… older xarray versions
  • Loading branch information
strawpants committed Jun 27, 2024
1 parent 000fb50 commit 8eaa583
Show file tree
Hide file tree
Showing 3 changed files with 42 additions and 5 deletions.
8 changes: 6 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
[build-system]
requires = ["setuptools>=61.0","setuptools-scm>=8","numpy<=1.26","Cython>=3","wheel","pytest","scipy"]
requires = ["setuptools>=61.0","setuptools-scm>=8","numpy<2","Cython>=3","wheel","pytest","scipy"]
build-backend = "setuptools.build_meta"
[project]
name = "shxarray"
Expand All @@ -20,7 +20,11 @@ classifiers = [
"Development Status :: 1 - Planning"
]
dependencies = [ "pandas >= 2.0", "pyaml >= 23.9.0", "scipy", "xarray >= 2023.1.0",
"numpy<=1.26","numba", "sparse","importlib_metadata","requests","openpyxl"]
"numpy<2","numba", "sparse","importlib_metadata","requests","openpyxl"]

[project.optional-dependencies]
#you need dask when using older xarray versions
dask=["dask>=2022.9.2"]

[tool.setuptools_scm]
# empty for now
Expand Down
34 changes: 33 additions & 1 deletion src/shxarray/kernels/anisokernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,9 @@
from shxarray.core.sh_indexing import SHindexBase
import sparse

from packaging import version





Expand All @@ -21,6 +24,13 @@ def __init__(self,dsobj,name="aniso",truncate=True):
self._dskernel=dsobj
self.name=name
self.truncate=truncate
self.useDask=version.parse(xr.__version__) < version.parse('2023.11.0')
if self.useDask:
from dask.array.core import einsum_lookup
#Register the einsum functions which are needed to do the sparse dot functions (for earlier versions of xarray)
einsum_lookup.register(sparse.COO,AnisoKernel.daskeinsumReplace)
#convert to xarray with dask structure
self._dskernel=self._dskernel.chunk()

@property
def nmax(self):
Expand All @@ -38,7 +48,10 @@ def __call__(self,dain:xr.DataArray):
daout=xr.dot(self._dskernel.mat,dain,dims=[SHindexBase.name])
#rename nm and convert to dense array
daout=daout.sh.toggle_nm()
daout=xr.DataArray(daout.data.todense(),coords=daout.coords,name=self.name)
if self.useDask:
daout=xr.DataArray(daout.compute().data,coords=daout.coords,name=self.name)
else:
daout=xr.DataArray(daout.data.todense(),coords=daout.coords,name=self.name)

if not self.truncate and self.nmin > 0:
#also add the unfiltered lower degree coefficients back to the results
Expand All @@ -60,4 +73,23 @@ def position(self,lon,lat):
ynmdata=ynmdata/normv

return self.__call__(ynmdata)

@staticmethod
def daskeinsumReplace(subscripts, *operands, out=None, dtype=None, order='K', casting='safe', optimize=False):
    """Einsum stand-in which evaluates a contraction through the operands' ``dot`` method.

    The call signature follows
    https://numpy.org/doc/stable/reference/generated/numpy.einsum.html so that it
    can be used in place of an einsum function (intended for sparse.COO operands).
    Only a fixed set of two-operand contractions is recognised. The keyword
    arguments ``out``, ``dtype``, ``order``, ``casting`` and ``optimize`` are
    accepted for interface compatibility but are not used.

    Raises NotImplementedError when ``subscripts`` is not one of the known patterns.
    """
    # Each entry: (subscripts, transpose first operand?, transpose second operand?)
    known_contractions = (
        ("ab,cb->ac", False, True),
        ("ab,ca->bc", True, True),
        ("ab,bc->ac", False, False),
        ("ab,b->a", False, False),
        ("ab,a->b", True, False),
        ("ab,ac->bc", True, False),
    )

    for pattern, flip_first, flip_second in known_contractions:
        if subscripts == pattern:
            first = operands[0]
            second = operands[1]
            if flip_first:
                first = first.T
            if flip_second:
                second = second.T
            # every supported contraction reduces to a single dot product
            return first.dot(second)

    raise NotImplementedError(f"Don't know (yet) how to handle this einsum: {subscripts} with sparse.dot operations")

5 changes: 3 additions & 2 deletions tests/test_filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,8 +74,9 @@ def test_ddkmult(shinput,shoutput):

dacheck=shinm.sh.filter('DDK2')
dadiff=(dacheck-shoutm)/shoutm
maxdiff=max(abs(dadiff.max()),abs(dadiff.min()))
assert(maxdiff < tol)
maxdiff=max(abs(dadiff.max()),abs(dadiff.min()))
reltol=3e-11
assert(maxdiff < reltol)

@pytest.fixture
def shgausstest():
Expand Down

0 comments on commit 8eaa583

Please sign in to comment.