Skip to content

Commit

Permalink
Fix zernike_eval notebook (#1037)
Browse files Browse the repository at this point in the history
There was also a small typo (when I was trying to fix the format I miss
typed - as +).

I also changed the division inside `zernike_radial_coeffs` to be a float
division. Probably the accuracy for big numbers was an issue before, but
I used Decimal() to make division more accurate.
  • Loading branch information
YigitElma authored Jun 20, 2024
2 parents 0d1da0e + 7b164ef commit 0830262
Show file tree
Hide file tree
Showing 4 changed files with 184 additions and 455 deletions.
43 changes: 29 additions & 14 deletions desc/basis.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import mpmath
import numpy as np

from desc.backend import custom_jvp, fori_loop, gammaln, jit, jnp, sign
from desc.backend import custom_jvp, fori_loop, jit, jnp, sign
from desc.io import IOAble
from desc.utils import check_nonnegint, check_posint, flatten_list

Expand Down Expand Up @@ -1389,6 +1389,16 @@ def body(k, y):
def zernike_radial_coeffs(l, m, exact=True):
"""Polynomial coefficients for radial part of zernike basis.
The for loop ranges from m to l+1 in steps of 2, as opposed to the
formula in the zernike_eval notebook. This is to make the coeffs array in
ascending powers of r, which is more natural for polynomial evaluation.
So, one should substitute s=(l-k)/s in the formula in the notebook to get
the coding implementation below.
(-1)^((l-k)/2) * ((l+k)/2)!
R_l^m(r) = sum_{k=m}^l -------------------------------------
((l-k)/2)! * ((k+m)/2)! * ((k-m)/2)!
Parameters
----------
l : ndarray of int, shape(K,)
Expand Down Expand Up @@ -1424,14 +1434,15 @@ def zernike_radial_coeffs(l, m, exact=True):
ll = lms[ii, 0]
mm = lms[ii, 1]
for s in range(mm, ll + 1, 2):
coeffs[ii, s] = (
(-1) ** ((ll - s) // 2)
* factorial((ll + s) // 2)
// (
factorial((ll - s) // 2)
* factorial((s + mm) // 2)
* factorial((s - mm) // 2)
)
# Zernike polynomials can also be written in the form of [1] which
# states that the coefficients are given by the binomial coefficients
# hence they are all integers. So, we can use exact arithmetic with integer
# division instead of floating point division.
# [1]https://en.wikipedia.org/wiki/Zernike_polynomials#Other_representations
coeffs[ii, s] = ((-1) ** ((ll - s) // 2) * factorial((ll + s) // 2)) // (
factorial((ll - s) // 2)
* factorial((s + mm) // 2)
* factorial((s - mm) // 2)
)
c = np.fliplr(np.where(lm_even, coeffs, 0))
if not exact:
Expand Down Expand Up @@ -1762,12 +1773,16 @@ def _jacobi_body_fun(kk, d_p_a_b_x):
n, alpha, beta, x = map(jnp.asarray, (n, alpha, beta, x))

# coefficient for derivative
c = (
gammaln(alpha + beta + n + 1 + dx)
- dx * jnp.log(2)
- gammaln(alpha + beta + n + 1)
coeffs = jnp.array(
[
1,
(alpha + n + 1) / 2,
(alpha + n + 2) * (alpha + n + 1) / 4,
(alpha + n + 3) * (alpha + n + 2) * (alpha + n + 1) / 8,
(alpha + n + 4) * (alpha + n + 3) * (alpha + n + 2) * (alpha + n + 1) / 16,
]
)
c = jnp.exp(c)
c = coeffs[dx]
# taking derivative is same as coeff*jacobi but for shifted n,a,b
n -= dx
alpha += dx
Expand Down
4 changes: 2 additions & 2 deletions desc/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -457,8 +457,8 @@ def combination_permutation(m, n, equals=True):
def multinomial_coefficients(m, n):
"""Number of ways to place n objects into m bins."""
k = combination_permutation(m, n)
num = factorial(n)
den = factorial(k).prod(axis=-1)
num = factorial(n, exact=True)
den = factorial(k, exact=True).prod(axis=-1)
return num / den


Expand Down
592 changes: 153 additions & 439 deletions docs/notebooks/zernike_eval.ipynb

Large diffs are not rendered by default.

Binary file modified tests/inputs/master_compute_data.pkl
Binary file not shown.

0 comments on commit 0830262

Please sign in to comment.