Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix zernike_eval notebook and Use integer division instead of gammaln() #1037

Merged
merged 23 commits into from
Jun 20, 2024

Conversation

YigitElma
Copy link
Collaborator

@YigitElma YigitElma commented May 26, 2024

There was also a small typo (when I was trying to fix the format I miss typed - as +).

Uses integer division (the actual formula) instead of gammaln() (which was used for convenience) to improve the accuracy of zernike_radial_coeffs

@PlasmaControl PlasmaControl deleted a comment from review-notebook-app bot May 26, 2024
@YigitElma YigitElma requested a review from f0uriest May 26, 2024 23:59
Copy link
Contributor

github-actions bot commented May 27, 2024

|             benchmark_name             |         dt(%)          |         dt(s)          |        t_new(s)        |        t_old(s)        | 
| -------------------------------------- | ---------------------- | ---------------------- | ---------------------- | ---------------------- |
 test_build_transform_fft_lowres         |     +0.55 +/- 10.81    | +2.92e-03 +/- 5.71e-02 |  5.31e-01 +/- 5.2e-02  |  5.28e-01 +/- 2.3e-02  |
 test_build_transform_fft_midres         |     +0.01 +/- 7.39     | +4.68e-05 +/- 4.51e-02 |  6.11e-01 +/- 3.5e-02  |  6.11e-01 +/- 2.9e-02  |
 test_build_transform_fft_highres        |     -0.73 +/- 5.12     | -7.34e-03 +/- 5.17e-02 |  1.00e+00 +/- 4.4e-02  |  1.01e+00 +/- 2.7e-02  |
 test_equilibrium_init_lowres            |     -2.58 +/- 2.89     | -1.00e-01 +/- 1.12e-01 |  3.79e+00 +/- 8.2e-02  |  3.89e+00 +/- 7.7e-02  |
 test_equilibrium_init_medres            |     -0.05 +/- 3.33     | -2.00e-03 +/- 1.43e-01 |  4.30e+00 +/- 9.5e-02  |  4.30e+00 +/- 1.1e-01  |
 test_equilibrium_init_highres           |     -1.49 +/- 0.99     | -8.58e-02 +/- 5.67e-02 |  5.66e+00 +/- 4.4e-02  |  5.75e+00 +/- 3.6e-02  |
 test_objective_compile_dshape_current   |     +0.79 +/- 1.46     | +3.06e-02 +/- 5.68e-02 |  3.91e+00 +/- 3.8e-02  |  3.88e+00 +/- 4.2e-02  |
 test_objective_compile_atf              |     +0.51 +/- 2.34     | +4.24e-02 +/- 1.95e-01 |  8.37e+00 +/- 1.7e-01  |  8.33e+00 +/- 9.5e-02  |
 test_objective_compute_dshape_current   |     -2.35 +/- 4.11     | -2.79e-05 +/- 4.88e-05 |  1.16e-03 +/- 2.6e-05  |  1.19e-03 +/- 4.1e-05  |
 test_objective_compute_atf              |     -0.14 +/- 5.64     | -5.70e-06 +/- 2.31e-04 |  4.09e-03 +/- 1.7e-04  |  4.09e-03 +/- 1.6e-04  |
 test_objective_jac_dshape_current       |     -1.19 +/- 10.66    | -4.39e-04 +/- 3.92e-03 |  3.64e-02 +/- 2.7e-03  |  3.68e-02 +/- 2.9e-03  |
 test_objective_jac_atf                  |     -1.18 +/- 3.64     | -2.27e-02 +/- 6.98e-02 |  1.90e+00 +/- 4.9e-02  |  1.92e+00 +/- 5.0e-02  |
 test_perturb_1                          |     +0.55 +/- 1.14     | +7.32e-02 +/- 1.51e-01 |  1.34e+01 +/- 8.7e-02  |  1.33e+01 +/- 1.2e-01  |
 test_perturb_2                          |     -0.18 +/- 1.46     | -3.29e-02 +/- 2.67e-01 |  1.83e+01 +/- 1.9e-01  |  1.83e+01 +/- 1.9e-01  |
 test_proximal_jac_atf                   |     -1.12 +/- 1.15     | -8.32e-02 +/- 8.53e-02 |  7.35e+00 +/- 8.2e-02  |  7.44e+00 +/- 2.5e-02  |
 test_proximal_freeb_compute             |     +0.80 +/- 1.10     | +1.42e-03 +/- 1.95e-03 |  1.78e-01 +/- 1.6e-03  |  1.77e-01 +/- 1.1e-03  |
 test_proximal_freeb_jac                 |     +0.18 +/- 1.21     | +1.35e-02 +/- 8.92e-02 |  7.40e+00 +/- 6.6e-02  |  7.38e+00 +/- 6.0e-02  |
 test_solve_fixed_iter                   |     +1.56 +/- 6.46     | +2.31e-01 +/- 9.57e-01 |  1.50e+01 +/- 6.9e-01  |  1.48e+01 +/- 6.6e-01  |

@YigitElma
Copy link
Collaborator Author

@f0uriest 3f544fd is this the reason you wanted to use "//" instead of "/"?

@YigitElma YigitElma self-assigned this May 27, 2024
Copy link

codecov bot commented May 27, 2024

Codecov Report

All modified and coverable lines are covered by tests ✅

Project coverage is 94.86%. Comparing base (0d1da0e) to head (7b164ef).
Report is 1868 commits behind head on master.

Additional details and impacted files
@@           Coverage Diff           @@
##           master    #1037   +/-   ##
=======================================
  Coverage   94.86%   94.86%           
=======================================
  Files          87       87           
  Lines       21711    21711           
=======================================
  Hits        20597    20597           
  Misses       1114     1114           
Files with missing lines Coverage Δ
desc/basis.py 98.34% <100.00%> (ø)
desc/utils.py 92.06% <100.00%> (ø)

Copy link
Collaborator

@dpanici dpanici left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Undo decimal change

@YigitElma YigitElma requested review from dpanici and f0uriest May 29, 2024 22:13
@YigitElma YigitElma added the documentation Add documentation or better warnings etc. label Jun 6, 2024
@YigitElma
Copy link
Collaborator Author

YigitElma commented Jun 11, 2024

I found a little bit of accuracy improvement on zernike_radial function. Instead of using gammaln and then taking the exponential of it to find coefficients for the derivatives, we can just use the multiplication, as I mentioned in the updated notebook.
image

# coefficient for derivative
    coeffs = jnp.array(
        [
            1,
            (alpha + n + 1) / 2,
            (alpha + n + 2) * (alpha + n + 1) / 4,
            (alpha + n + 3) * (alpha + n + 2) * (alpha + n + 1) / 8,
            (alpha + n + 4) * (alpha + n + 3) * (alpha + n + 2) * (alpha + n + 1) / 16,
        ]
    )
    c = coeffs[dx]

All we need is use above code instead of

# coefficient for derivative
    c = (
        gammaln(alpha + beta + n + 1 + dx)
        - dx * jnp.log(2)
        - gammaln(alpha + beta + n + 1)
    )
    c = jnp.exp(c)

This part. I also wrote and tested a version like this

# coefficient for derivative
    def poch(x, dx):
        def body(k, val):
            return val * (x + k)

        return fori_loop(0, dx, body, 1)

    c = poch(alpha + n + 1, dx) / (2**dx)

GPU executing time doesn't change but I didn't want to risk fwd/reverse AD. @f0uriest I can add this with this or some other PR?

Here is the error plots for first derivative,
New:
image
Current:
image

@f0uriest
Copy link
Member

Sure, I'd recommend this version:

# coefficient for derivative
    coeffs = jnp.array(
        [
            1,
            (alpha + n + 1) / 2,
            (alpha + n + 2) * (alpha + n + 1) / 4,
            (alpha + n + 3) * (alpha + n + 2) * (alpha + n + 1) / 8,
            (alpha + n + 4) * (alpha + n + 3) * (alpha + n + 2) * (alpha + n + 1) / 16,
        ]
    )
    c = coeffs[dx]

to avoid loops and special function calls

Copy link
Collaborator

@dpanici dpanici left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Out of curiosity, how different were the quantities that caused test_compute_everything to fail after you changed the part in the zernike algorithm?
EDIT: nevermind I see from the tests they are negligible changes in the higher order derived quantities

@YigitElma
Copy link
Collaborator Author

Out of curiosity, how different were the quantities that caused test_compute_everything to fail after you changed the part in the zernike algorithm? EDIT: nevermind I see from the tests they are negligible changes in the higher order derived quantities

Not equal to tolerance rtol=1e-10, atol=1e-10
Parameterization: desc.equilibrium.equilibrium.Equilibrium. Name: J^rho.
Mismatched elements: 23 / 660 (3.48%)
Max absolute difference: 2.29007128e-07
Max relative difference: 6.56482638e-09
 x: array([-4.121012e+02, -1.125881e+03, -1.537982e+03, -1.537982e+03,
       -1.125881e+03, -4.121012e+02, -6.978519e+02, -1.423246e+03,
       -1.440419e+03, -2.124563e+03, -2.423728e+03, -1.091549e+03,...
 y: array([-4.121012e+02, -1.125881e+03, -1.537982e+03, -1.537982e+03,
       -1.125881e+03, -4.121012e+02, -6.978519e+02, -1.423246e+03,
       -1.440419e+03, -2.124563e+03, -2.423728e+03, -1.091549e+03,...

Not equal to tolerance rtol=1e-10, atol=1e-10
Parameterization: desc.equilibrium.equilibrium.Equilibrium. Name: J^theta*sqrt(g).
Mismatched elements: 1 / 660 (0.152%)
Max absolute difference: 4.81682946e-07
Max relative difference: 1.12579674e-10
 x: array([-2.301279e+03, -1.684653e+03, -6.166258e+02,  6.166258e+02,
        1.684653e+03,  2.301279e+03,  1.783800e+05,  1.785994e+05,
        1.776159e+05,  1.768467e+05,  1.784834e+05,  1.805368e+05,...
 y: array([-2.301279e+03, -1.684653e+03, -6.166258e+02,  6.166258e+02,
        1.684653e+03,  2.301279e+03,  1.783800e+05,  1.785994e+05,
        1.776159e+05,  1.768467e+05,  1.784834e+05,  1.805368e+05,...

Not equal to tolerance rtol=1e-10, atol=1e-10
Parameterization: desc.equilibrium.equilibrium.Equilibrium. Name: J^theta.
Mismatched elements: 1 / 660 (0.152%)
Max absolute difference: 3.80525307e-07
Max relative difference: 1.1257934e-10
 x: array([         -inf,          -inf,          -inf,           inf,
                 inf,           inf,  1.135946e+06,  1.119998e+06,
        1.084518e+06,  1.049024e+06,  1.035086e+06,  1.034512e+06,...
 y: array([         -inf,          -inf,          -inf,           inf,
                 inf,           inf,  1.135946e+06,  1.119998e+06,
        1.084518e+06,  1.049024e+06,  1.035086e+06,  1.034512e+06,...

@dpanici
Copy link
Collaborator

dpanici commented Jun 14, 2024

@f0uriest @ddudt can we just get rid of test_1d_optimization_old? it is testing the old _optimize function which I don't quite understand why we still have?

@@ -457,8 +457,8 @@ def combination_permutation(m, n, equals=True):
def multinomial_coefficients(m, n):
"""Number of ways to place n objects into m bins."""
k = combination_permutation(m, n)
num = factorial(n)
den = factorial(k).prod(axis=-1)
num = factorial(n, exact=True)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what does exact=True do?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Here, we are using scipy s factorial function which is by default an approximation with gamma function. To be more accurate, we should use integer operations instead. Exact=true converts it to integer operation. https://docs.scipy.org/doc/scipy/reference/generated/scipy.special.factorial.html

# hence they are all integers. So, we can use exact arithmetic with integer
# division instead of floating point division.
# [1]https://en.wikipedia.org/wiki/Zernike_polynomials#Other_representations
coeffs[ii, s] = ((-1) ** ((ll - s) // 2) * factorial((ll + s) // 2)) // (
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

there's no difference here right? Just the extra comments?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, there is no change. Additional lines confused git, I guess.

@f0uriest
Copy link
Member

@f0uriest @ddudt can we just get rid of test_1d_optimization_old? it is testing the old _optimize function which I don't quite understand why we still have?

I'm fine getting rid of the test, along with the _optimize method. It's still there in older versions for anyone who wants to reproduce the results from the paper.

@YigitElma
Copy link
Collaborator Author

@f0uriest @ddudt can we just get rid of test_1d_optimization_old? it is testing the old _optimize function which I don't quite understand why we still have?

I'm fine getting rid of the test, along with the _optimize method. It's still there in older versions for anyone who wants to reproduce the results from the paper.

Since we have some other changes that affect _optimize(), maybe we should refer them to some version of DESC instead of this method.

If you are okay, I can remove the method and the test with this PR.

@YigitElma YigitElma merged commit 0830262 into master Jun 20, 2024
18 checks passed
@YigitElma YigitElma deleted the yge/hotfix branch June 20, 2024 03:00
@YigitElma YigitElma changed the title Fix zernike_eval notebook Fix zernike_eval notebook and Use integer division instead of gammaln() Sep 30, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
documentation Add documentation or better warnings etc.
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants