-
Notifications
You must be signed in to change notification settings - Fork 26
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Fix zernike_eval notebook and Use integer division instead of gammaln() #1037
Conversation
| benchmark_name | dt(%) | dt(s) | t_new(s) | t_old(s) |
| -------------------------------------- | ---------------------- | ---------------------- | ---------------------- | ---------------------- |
test_build_transform_fft_lowres | +0.55 +/- 10.81 | +2.92e-03 +/- 5.71e-02 | 5.31e-01 +/- 5.2e-02 | 5.28e-01 +/- 2.3e-02 |
test_build_transform_fft_midres | +0.01 +/- 7.39 | +4.68e-05 +/- 4.51e-02 | 6.11e-01 +/- 3.5e-02 | 6.11e-01 +/- 2.9e-02 |
test_build_transform_fft_highres | -0.73 +/- 5.12 | -7.34e-03 +/- 5.17e-02 | 1.00e+00 +/- 4.4e-02 | 1.01e+00 +/- 2.7e-02 |
test_equilibrium_init_lowres | -2.58 +/- 2.89 | -1.00e-01 +/- 1.12e-01 | 3.79e+00 +/- 8.2e-02 | 3.89e+00 +/- 7.7e-02 |
test_equilibrium_init_medres | -0.05 +/- 3.33 | -2.00e-03 +/- 1.43e-01 | 4.30e+00 +/- 9.5e-02 | 4.30e+00 +/- 1.1e-01 |
test_equilibrium_init_highres | -1.49 +/- 0.99 | -8.58e-02 +/- 5.67e-02 | 5.66e+00 +/- 4.4e-02 | 5.75e+00 +/- 3.6e-02 |
test_objective_compile_dshape_current | +0.79 +/- 1.46 | +3.06e-02 +/- 5.68e-02 | 3.91e+00 +/- 3.8e-02 | 3.88e+00 +/- 4.2e-02 |
test_objective_compile_atf | +0.51 +/- 2.34 | +4.24e-02 +/- 1.95e-01 | 8.37e+00 +/- 1.7e-01 | 8.33e+00 +/- 9.5e-02 |
test_objective_compute_dshape_current | -2.35 +/- 4.11 | -2.79e-05 +/- 4.88e-05 | 1.16e-03 +/- 2.6e-05 | 1.19e-03 +/- 4.1e-05 |
test_objective_compute_atf | -0.14 +/- 5.64 | -5.70e-06 +/- 2.31e-04 | 4.09e-03 +/- 1.7e-04 | 4.09e-03 +/- 1.6e-04 |
test_objective_jac_dshape_current | -1.19 +/- 10.66 | -4.39e-04 +/- 3.92e-03 | 3.64e-02 +/- 2.7e-03 | 3.68e-02 +/- 2.9e-03 |
test_objective_jac_atf | -1.18 +/- 3.64 | -2.27e-02 +/- 6.98e-02 | 1.90e+00 +/- 4.9e-02 | 1.92e+00 +/- 5.0e-02 |
test_perturb_1 | +0.55 +/- 1.14 | +7.32e-02 +/- 1.51e-01 | 1.34e+01 +/- 8.7e-02 | 1.33e+01 +/- 1.2e-01 |
test_perturb_2 | -0.18 +/- 1.46 | -3.29e-02 +/- 2.67e-01 | 1.83e+01 +/- 1.9e-01 | 1.83e+01 +/- 1.9e-01 |
test_proximal_jac_atf | -1.12 +/- 1.15 | -8.32e-02 +/- 8.53e-02 | 7.35e+00 +/- 8.2e-02 | 7.44e+00 +/- 2.5e-02 |
test_proximal_freeb_compute | +0.80 +/- 1.10 | +1.42e-03 +/- 1.95e-03 | 1.78e-01 +/- 1.6e-03 | 1.77e-01 +/- 1.1e-03 |
test_proximal_freeb_jac | +0.18 +/- 1.21 | +1.35e-02 +/- 8.92e-02 | 7.40e+00 +/- 6.6e-02 | 7.38e+00 +/- 6.0e-02 |
test_solve_fixed_iter | +1.56 +/- 6.46 | +2.31e-01 +/- 9.57e-01 | 1.50e+01 +/- 6.9e-01 | 1.48e+01 +/- 6.6e-01 | |
Codecov ReportAll modified and coverable lines are covered by tests ✅
Additional details and impacted files@@ Coverage Diff @@
## master #1037 +/- ##
=======================================
Coverage 94.86% 94.86%
=======================================
Files 87 87
Lines 21711 21711
=======================================
Hits 20597 20597
Misses 1114 1114
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Undo decimal change
I found a little bit of accuracy improvement on zernike_radial function. Instead of using gammaln and then taking the exponential of it to find coefficients for the derivatives, we can just use the multiplication, as I mentioned in the updated notebook. # coefficient for derivative
coeffs = jnp.array(
[
1,
(alpha + n + 1) / 2,
(alpha + n + 2) * (alpha + n + 1) / 4,
(alpha + n + 3) * (alpha + n + 2) * (alpha + n + 1) / 8,
(alpha + n + 4) * (alpha + n + 3) * (alpha + n + 2) * (alpha + n + 1) / 16,
]
)
c = coeffs[dx] All we need is use above code instead of # coefficient for derivative
c = (
gammaln(alpha + beta + n + 1 + dx)
- dx * jnp.log(2)
- gammaln(alpha + beta + n + 1)
)
c = jnp.exp(c) This part. I also wrote and tested a version like this # coefficient for derivative
def poch(x, dx):
def body(k, val):
return val * (x + k)
return fori_loop(0, dx, body, 1)
c = poch(alpha + n + 1, dx) / (2**dx) GPU executing time doesn't change but I didn't want to risk fwd/reverse AD. @f0uriest I can add this with this or some other PR? |
Sure, I'd recommend this version: # coefficient for derivative
coeffs = jnp.array(
[
1,
(alpha + n + 1) / 2,
(alpha + n + 2) * (alpha + n + 1) / 4,
(alpha + n + 3) * (alpha + n + 2) * (alpha + n + 1) / 8,
(alpha + n + 4) * (alpha + n + 3) * (alpha + n + 2) * (alpha + n + 1) / 16,
]
)
c = coeffs[dx] to avoid loops and special function calls |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Out of curiosity, how different were the quantities that caused test_compute_everything
to fail after you changed the part in the zernike algorithm?
EDIT: nevermind I see from the tests they are negligible changes in the higher order derived quantities
|
@@ -457,8 +457,8 @@ def combination_permutation(m, n, equals=True): | |||
def multinomial_coefficients(m, n): | |||
"""Number of ways to place n objects into m bins.""" | |||
k = combination_permutation(m, n) | |||
num = factorial(n) | |||
den = factorial(k).prod(axis=-1) | |||
num = factorial(n, exact=True) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
what does exact=True
do?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Here, we are using scipy s factorial function which is by default an approximation with gamma function. To be more accurate, we should use integer operations instead. Exact=true converts it to integer operation. https://docs.scipy.org/doc/scipy/reference/generated/scipy.special.factorial.html
# hence they are all integers. So, we can use exact arithmetic with integer | ||
# division instead of floating point division. | ||
# [1]https://en.wikipedia.org/wiki/Zernike_polynomials#Other_representations | ||
coeffs[ii, s] = ((-1) ** ((ll - s) // 2) * factorial((ll + s) // 2)) // ( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
there's no difference here right? Just the extra comments?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, there is no change. Additional lines confused git, I guess.
I'm fine getting rid of the test, along with the |
Since we have some other changes that affect If you are okay, I can remove the method and the test with this PR. |
There was also a small typo (when I was trying to fix the format I miss typed - as +).
Uses integer division (the actual formula) instead of
gammaln()
(which was used for convenience) to improve the accuracy ofzernike_radial_coeffs