Skip to content

Commit

Permalink
remove testing dependency on sympy and use more extensive testing for…
Browse files Browse the repository at this point in the history
… Wigner3j and Gaunt Coefficients
  • Loading branch information
strawpants committed Feb 14, 2024
1 parent 6462b1f commit 1107a23
Show file tree
Hide file tree
Showing 10 changed files with 336 additions and 291 deletions.
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
else:
openmp_arg = '-fopenmp'

debug=True
debug=False
#don't necessarily use cython
if "USE_CYTHON" in os.environ:
useCython=True
Expand Down
12 changes: 9 additions & 3 deletions src/builtin_backend/Gaunt.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -97,14 +97,20 @@ class GauntReal{
return 0+0j;
}
else if (mu == 0){
return kronecker(m,0)*kronecker(mu,0) + 0j;
return 1 + 0j;
}else if (mu > 0){
///real elements only
return M_SQRT1_2*(kronecker(m,mu)+csphase(m)*kronecker(m,-mu))+0j;
//according to Homeier et al 1996
//return M_SQRT1_2*(kronecker(m,mu)+csphase(m)*kronecker(m,-mu))+0j;
//according to sympy(https://docs.sympy.org/latest/modules/physics/wigner.html#sympy.physics.wigner.gaunt)
return M_SQRT1_2*(kronecker(m,mu)+csphase(mu)*kronecker(m,-mu))+0j;
}else{
///mu < 0
///purely imaginary elements only
return 1j*M_SQRT1_2*(csphase(m)*kronecker(m,mu)+kronecker(m,-mu));
//according to Homeier et al 1996
//return 1j*M_SQRT1_2*(csphase(m)*kronecker(m,mu)+kronecker(m,-mu));
//according to sympy(https://docs.sympy.org/latest/modules/physics/wigner.html#sympy.physics.wigner.gaunt)
return 1j*M_SQRT1_2*(csphase(m)*kronecker(-m,mu)-kronecker(m,mu)*csphase(mu-m));
}
}
enum ABCcase{A,B,Bt,C}mcase;
Expand Down
1 change: 1 addition & 0 deletions src/builtin_backend/gaunt.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ from libcpp.vector cimport vector
from libcpp.pair cimport pair
# C++ / Cython interface declaration


cdef extern from "Gaunt.hpp":
cdef cppclass Gaunt[T]:
Gaunt() except +
Expand Down
6 changes: 6 additions & 0 deletions src/builtin_backend/gaunt.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,9 @@ cimport numpy as np
import xarray as xr
import pandas as pd

@cython.boundscheck(False)
@cython.wraparound(False)
@cython.initializedcheck(False)
def getGaunt(n2,n3,m2,m3):
"""
Compute non-zero Gaunt coefficients for valid values of n1 and m1
Expand All @@ -24,6 +27,9 @@ def getGaunt(n2,n3,m2,m3):
nm=pd.MultiIndex.from_tuples([(n,m) for n in range(gaunt.nmin(),gaunt.nmax()+1,2)],names=("n","m"))
return xr.DataArray(gaunt.get(), coords=dict(nm=nm),dims=["nm"])

@cython.boundscheck(False)
@cython.wraparound(False)
@cython.initializedcheck(False)
def getGauntReal(n2,n3,m2,m3):
"""
Compute non-zero Real Gaunt coefficients for valid values of n1 and m1
Expand Down
374 changes: 187 additions & 187 deletions src/builtin_backend/shlib.cpp

Large diffs are not rendered by default.

3 changes: 3 additions & 0 deletions src/builtin_backend/wigner3j.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,9 @@ cimport numpy as np
import xarray as xr
import pandas as pd

@cython.boundscheck(False)
@cython.wraparound(False)
@cython.initializedcheck(False)
def getWigner3j(j2,j3,m2,m3):
"""
Compute non-zero Wigner3J symbols with their valid (j1,m1) for j2,j3,m2,m3 input
Expand Down
229 changes: 129 additions & 100 deletions tests/test_wigner3j_gaunt.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,105 +9,134 @@
import shxarray
import xarray as xr
import numpy as np
from sympy.physics.wigner import wigner_3j,gaunt,real_gaunt
import gzip
import os

tol=1e-16


# def get_sympyvaldata():
# """Create a on-disk dataset with validation coefficients for Wigner3j, Gaunt and realGaunt coefficients"""
# sympyfile=os.path.join(os.path.dirname(__file__),'testdata/sympy_valdata.pkl')
# sympyval={}
# gauntreal={}
# precrg=3
# if not os.path.exists(sympyfile):
# #generate new data
# nmax=4
# for n2 in range(0,nmax+1):
# for n3 in range(0,nmax+1):
# for m2 in range(0,n2+1):
# for m3 in range(0,n3+1):
# n1minrg=max(abs(n2-n3),abs(m2+m3))
# n1minrg+=(n2+n3+n1minrg)%2
# n1maxrg=n2+n3

# # gauntreal[(n2,n3,m2,m3)]=[float(val) for val in [real_gaunt(n1,n2,n3,m1,m2,m3,prec=precrg) for n1 in n]]





def generate_w3j_validation(jmin,jmax,j2,j3,m1,m2,m3):
#note this requires the python module sympy
w3jval=[float(val) for val in [wigner_3j(j1,j2,j3,m1,m2,m3) for j1 in range(jmin,jmax+1)]]
return w3jval

def test_Wigner3j():
for j2 in [16,41,300]:
for j3 in [13,121]:
for m2 in [-2,100]:
for m3 in [13,5]:
if j2 < m2 or j3 < m3:
continue

daw3j=xr.DataArray.sh.wigner3j(j2,j3,m2,m3)
#note only one valid m1 should be present

m1=daw3j.m.data[0]
w3jval=generate_w3j_validation(daw3j.j.min().item(),daw3j.j.max().item(),j2,j3,m1,m2,m3)
assert(np.allclose(daw3j.data,w3jval,rtol=tol))


def generate_gaunt_validation(n,n2,n3,m1,m2,m3):
#note this requires the python module sympy
gauntval=[float(val) for val in [gaunt(n1,n2,n3,m1,m2,m3) for n1 in n]]
return gauntval

def test_Gauntnormal():
# for n2 in [10,11,300]:
# for n3 in [90,121]:
# for m2 in [-11,100]:
# for m3 in [31,5]:
nmax=8
for n2 in range(0,nmax+1):
for n3 in range(0,nmax+1):
for m2 in range(0,n2+1):
for m3 in range(0,n3+1):
if n2 < m2 or n3 < m3:
continue
dagaunt=xr.DataArray.sh.gaunt(n2,n3,m2,m3)
m1=dagaunt.m.data[0]

gauntval=generate_gaunt_validation(dagaunt.n.data,n2,n3,m1,m2,m3)
assert(np.allclose(dagaunt.data,gauntval,rtol=tol))

def generate_gauntreal_validation(n,n2,n3,m1,m2,m3):
#note this requires the python module sympy
prec=3
gauntval=[float(val) for val in [real_gaunt(n1,n2,n3,m1,m2,m3,prec=prec) for n1 in n]]
return gauntval

def test_Gauntreal():
nmax=5
for n2 in range(0,nmax+1):
for n3 in range(0,nmax+1):
for m2 in range(0,n2+1):
for m3 in range(0,n3+1):
if n2 < m2 or n3 < m3:
assert(False)
continue

dagaunt=xr.DataArray.sh.gauntReal(n2,n3,m2,m3)
ms=np.unique(dagaunt.m)
gauntval=[]
# breakpoint()
for m1 in ms:
gauntval.extend(generate_gauntreal_validation(dagaunt.sel(m=m1).n.data,n2,n3,m1,m2,m3))

# breakpoint()
closeEnough=np.allclose(dagaunt.data,gauntval,rtol=1e-3)
if not closeEnough:
breakpoint()
assert(closeEnough)
import pickle


@pytest.fixture
def wigner3jVal():
"""Create a on-disk dataset with validation coefficients for Wigner3j symbols"""
sympyfile=os.path.join(os.path.dirname(__file__),'testdata/sympy_wigner3jvalidation.pkl.gz')

wigner3j=[]
if not os.path.exists(sympyfile):
from sympy.physics.wigner import wigner_3j
#generate new data
nmax=9
for n2 in range(0,nmax+1):
for n3 in range(0,nmax+1):
for m2 in range(-n2,n2+1):
for m3 in range(-n3,n3+1):
#retrieve non-zero orders and degrees
daw3j=xr.DataArray.sh.wigner3j(n2,n3,m2,m3)
nm=daw3j.jm.data
m1=np.unique(daw3j.m)[0]
data=[float(val) for val in [wigner_3j(n1,n2,n3,m1,m2,m3) for n1 in daw3j.j]]
wigner3j.append({"nm":nm.copy(),"n2":n2,"n3":n3,"m2":m2,"m3":m3,"data":data.copy()})
# breakpoint()
with gzip.open(sympyfile,'wb') as fid:
pickle.dump(wigner3j,fid,protocol=1)
else:
with gzip.open(sympyfile,'rb') as fid:
wigner3j=pickle.load(fid)

return wigner3j



def test_W3j(wigner3jVal):
rtol=1e-8
for valdata in wigner3jVal:
daw3j=xr.DataArray.sh.wigner3j(valdata["n2"],valdata["n3"],valdata["m2"],valdata["m3"])
assert(len(valdata["data"]) == len(daw3j))
closeEnough=np.allclose(daw3j.data,valdata["data"],rtol=rtol)
assert(closeEnough)



@pytest.fixture
def gauntVal():
"""Create a on-disk dataset with validation coefficients for Gaunt coefficients"""
sympyfile=os.path.join(os.path.dirname(__file__),'testdata/sympy_gauntvalidation.pkl.gz')

gauntv=[]
if not os.path.exists(sympyfile):
from sympy.physics.wigner import gaunt
#generate new data
nmax=9
for n2 in range(0,nmax+1):
for n3 in range(0,nmax+1):
for m2 in range(-n2,n2+1):
for m3 in range(-n3,n3+1):
#retrieve non-zero orders and degrees
dagaunt=xr.DataArray.sh.gaunt(n2,n3,m2,m3)
nm=dagaunt.nm.data
m1=np.unique(dagaunt.m)[0]
data=[float(val) for val in [gaunt(n1,n2,n3,m1,m2,m3) for n1 in dagaunt.n.data]]
gauntv.append({"nm":nm.copy(),"n2":n2,"n3":n3,"m2":m2,"m3":m3,"data":data.copy()})
with gzip.open(sympyfile,'wb') as fid:
pickle.dump(gauntv,fid,protocol=1)
else:
with gzip.open(sympyfile,'rb') as fid:
gauntv=pickle.load(fid)

return gauntv


def test_Gauntnormal(gauntVal):
rtol=1e-8
for valdata in gauntVal:
dagaunt=xr.DataArray.sh.gaunt(valdata["n2"],valdata["n3"],valdata["m2"],valdata["m3"])
assert(len(valdata["data"]) == len(dagaunt))
closeEnough=np.allclose(dagaunt.data,valdata["data"],rtol=rtol)
assert(closeEnough)


@pytest.fixture
def gauntrealVal():
"""Create a on-disk dataset with validation coefficients for real Gaunt coefficients"""
sympyfile=os.path.join(os.path.dirname(__file__),'testdata/sympy_realgauntvalidation.pkl.gz')

gauntv=[]
if not os.path.exists(sympyfile):
from sympy.physics.wigner import real_gaunt
#generate new data
nmax=8
for n2 in range(0,nmax+1):
for n3 in range(0,nmax+1):
for m2 in range(-n2,n2+1):
for m3 in range(-n3,n3+1):
#retrieve non-zero orders and degrees
dagaunt=xr.DataArray.sh.gauntReal(n2,n3,m2,m3)
if len(dagaunt) == 0:
# no-nonzero values
continue
nm=dagaunt.nm.data

m1=np.unique(dagaunt.m)[0]
data=[float(val) for val in [real_gaunt(n1,n2,n3,m1,m2,m3) for n1 in dagaunt.n.data]]
gauntv.append({"nm":nm.copy(),"n2":n2,"n3":n3,"m2":m2,"m3":m3,"data":data.copy()})
with gzip.open(sympyfile,'wb') as fid:
pickle.dump(gauntv,fid,protocol=1)
else:
with gzip.open(sympyfile,'rb') as fid:
gauntv=pickle.load(fid)

return gauntv


def test_Gauntreal(gauntrealVal):
rtol=1e-8
for valdata in gauntrealVal:
dagaunt=xr.DataArray.sh.gauntReal(valdata["n2"],valdata["n3"],valdata["m2"],valdata["m3"])
assert(len(valdata["data"]) == len(dagaunt))
closeEnough=np.allclose(dagaunt.data,valdata["data"],rtol=rtol)
# m1=dagaunt.m.data[0]
# m2=valdata["m2"]
# m3=valdata["m3"]
# if not closeEnough:
# print(f'not ok {np.sign(m1*m2*m3)}')
# else:
# print(f'OK {np.sign(m1*m2*m3)}')
assert(closeEnough)
Binary file added tests/testdata/sympy_gauntvalidation.pkl.gz
Binary file not shown.
Binary file added tests/testdata/sympy_realgauntvalidation.pkl.gz
Binary file not shown.
Binary file added tests/testdata/sympy_wigner3jvalidation.pkl.gz
Binary file not shown.

0 comments on commit 1107a23

Please sign in to comment.