Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Take p_newton out of inner while loop #1165

Merged
merged 17 commits into from
Aug 21, 2024
Merged

Take p_newton out of inner while loop #1165

merged 17 commits into from
Aug 21, 2024

Conversation

YigitElma
Copy link
Collaborator

@YigitElma YigitElma commented Aug 6, 2024

Resolves #1078

Some performance improvements for QR decomposition used in optimization which was first introduced in #1050.

  • Take the p_newton calculation out of inner while loop, since it is basically calculating the same QR over and over again
  • Use proper QR update procedure for the falsefun in trust_region_step_exact_qr. That is we already now QR decomposition of J=QR, if we stack a diagonal matrix aI to J then instead of taking the whole QR decomposition again, there is a more clever way of updating the QR.There are methods for updating a QR factorization when you add rows. Suppose we have

$$ QR = J $$

what we want is

$$ \tilde{Q} \tilde{R} = \begin{pmatrix} J \\ \alpha I \end{pmatrix} $$

The QR update procedure can be implemented on a later PR with Householder matrices, but for now, it seems a bit inefficient to implement using JAX since QR is calculated by Fortran package LAPACK on Scipy and Jax, our custom QR'ish thing will be slow.

@YigitElma YigitElma marked this pull request as draft August 6, 2024 15:17
@YigitElma
Copy link
Collaborator Author

I have used Givens rotations for zeroing the elements but maybe Householder reflections are better? Maybe try implementing that.

Copy link
Contributor

github-actions bot commented Aug 6, 2024

|             benchmark_name             |         dt(%)          |         dt(s)          |        t_new(s)        |        t_old(s)        | 
| -------------------------------------- | ---------------------- | ---------------------- | ---------------------- | ---------------------- |
 test_build_transform_fft_lowres         |     +1.92 +/- 5.56     | +1.05e-02 +/- 3.03e-02 |  5.56e-01 +/- 2.5e-02  |  5.46e-01 +/- 1.7e-02  |
 test_build_transform_fft_midres         |     +0.87 +/- 4.25     | +5.54e-03 +/- 2.72e-02 |  6.44e-01 +/- 1.9e-02  |  6.39e-01 +/- 2.0e-02  |
 test_build_transform_fft_highres        |     +1.76 +/- 2.48     | +1.80e-02 +/- 2.54e-02 |  1.04e+00 +/- 1.9e-02  |  1.02e+00 +/- 1.7e-02  |
 test_equilibrium_init_lowres            |     -0.45 +/- 5.89     | -1.88e-02 +/- 2.43e-01 |  4.11e+00 +/- 1.2e-01  |  4.13e+00 +/- 2.1e-01  |
 test_equilibrium_init_medres            |     -0.85 +/- 4.26     | -3.94e-02 +/- 1.97e-01 |  4.58e+00 +/- 1.4e-01  |  4.62e+00 +/- 1.4e-01  |
 test_equilibrium_init_highres           |     +0.16 +/- 4.19     | +9.31e-03 +/- 2.49e-01 |  5.96e+00 +/- 1.5e-01  |  5.96e+00 +/- 2.0e-01  |
 test_objective_compile_dshape_current   |     +1.86 +/- 2.05     | +7.26e-02 +/- 8.01e-02 |  3.98e+00 +/- 7.5e-02  |  3.91e+00 +/- 2.9e-02  |
 test_objective_compile_atf              |     +2.32 +/- 1.34     | +1.97e-01 +/- 1.13e-01 |  8.66e+00 +/- 9.8e-02  |  8.47e+00 +/- 5.6e-02  |
 test_objective_compute_dshape_current   |     +1.54 +/- 5.01     | +1.94e-05 +/- 6.29e-05 |  1.27e-03 +/- 5.1e-05  |  1.25e-03 +/- 3.7e-05  |
 test_objective_compute_atf              |     +6.33 +/- 6.47     | +2.73e-04 +/- 2.79e-04 |  4.58e-03 +/- 2.3e-04  |  4.31e-03 +/- 1.5e-04  |
 test_objective_jac_dshape_current       |     +0.16 +/- 7.26     | +6.21e-05 +/- 2.86e-03 |  3.94e-02 +/- 1.5e-03  |  3.93e-02 +/- 2.4e-03  |
 test_objective_jac_atf                  |     +2.61 +/- 3.05     | +4.93e-02 +/- 5.77e-02 |  1.94e+00 +/- 3.8e-02  |  1.89e+00 +/- 4.3e-02  |
 test_perturb_1                          |     +3.80 +/- 1.90     | +5.35e-01 +/- 2.68e-01 |  1.46e+01 +/- 1.7e-01  |  1.41e+01 +/- 2.1e-01  |
 test_perturb_2                          |     +4.14 +/- 1.76     | +7.98e-01 +/- 3.39e-01 |  2.01e+01 +/- 2.7e-01  |  1.93e+01 +/- 2.1e-01  |
 test_proximal_jac_atf                   |     +0.76 +/- 1.02     | +6.20e-02 +/- 8.29e-02 |  8.18e+00 +/- 6.6e-02  |  8.11e+00 +/- 5.0e-02  |
 test_proximal_freeb_compute             |     +1.69 +/- 1.18     | +3.04e-03 +/- 2.13e-03 |  1.83e-01 +/- 1.8e-03  |  1.80e-01 +/- 1.2e-03  |
 test_proximal_freeb_jac                 |     +0.01 +/- 1.64     | +9.30e-04 +/- 1.22e-01 |  7.44e+00 +/- 6.8e-02  |  7.44e+00 +/- 1.0e-01  |
 test_solve_fixed_iter                   |     -1.98 +/- 16.14    | -3.64e-01 +/- 2.97e+00 |  1.81e+01 +/- 2.1e+00  |  1.84e+01 +/- 2.1e+00  |

@YigitElma
Copy link
Collaborator Author

|             benchmark_name             |         dt(%)          |         dt(s)          |        t_new(s)        |        t_old(s)        | 
| -------------------------------------- | ---------------------- | ---------------------- | ---------------------- | ---------------------- |
 test_build_transform_fft_lowres         |     -2.08 +/- 9.64     | -1.10e-02 +/- 5.10e-02 |  5.18e-01 +/- 4.1e-02  |  5.29e-01 +/- 3.1e-02  |
 test_build_transform_fft_midres         |     -1.19 +/- 5.78     | -7.22e-03 +/- 3.52e-02 |  6.02e-01 +/- 2.9e-02  |  6.09e-01 +/- 2.0e-02  |
 test_build_transform_fft_highres        |     -2.08 +/- 3.18     | -2.10e-02 +/- 3.21e-02 |  9.89e-01 +/- 1.3e-02  |  1.01e+00 +/- 3.0e-02  |
 test_equilibrium_init_lowres            |     -2.47 +/- 5.04     | -9.46e-02 +/- 1.93e-01 |  3.73e+00 +/- 1.3e-01  |  3.82e+00 +/- 1.4e-01  |
 test_equilibrium_init_medres            |     -1.82 +/- 4.48     | -7.74e-02 +/- 1.91e-01 |  4.18e+00 +/- 9.5e-02  |  4.26e+00 +/- 1.7e-01  |
 test_equilibrium_init_highres           |     -1.33 +/- 2.01     | -7.53e-02 +/- 1.14e-01 |  5.59e+00 +/- 5.3e-02  |  5.66e+00 +/- 1.0e-01  |
 test_objective_compile_dshape_current   |     -1.55 +/- 3.14     | -6.12e-02 +/- 1.24e-01 |  3.89e+00 +/- 2.5e-02  |  3.95e+00 +/- 1.2e-01  |
 test_objective_compile_atf              |     -1.07 +/- 2.79     | -9.00e-02 +/- 2.36e-01 |  8.35e+00 +/- 1.0e-01  |  8.44e+00 +/- 2.1e-01  |
 test_objective_compute_dshape_current   |     -1.31 +/- 3.37     | -1.65e-05 +/- 4.25e-05 |  1.25e-03 +/- 2.5e-05  |  1.26e-03 +/- 3.5e-05  |
 test_objective_compute_atf              |     -0.65 +/- 4.70     | -2.77e-05 +/- 1.99e-04 |  4.21e-03 +/- 1.1e-04  |  4.24e-03 +/- 1.6e-04  |
 test_objective_jac_dshape_current       |     -0.23 +/- 7.42     | -8.59e-05 +/- 2.71e-03 |  3.65e-02 +/- 1.6e-03  |  3.66e-02 +/- 2.2e-03  |
 test_objective_jac_atf                  |     -0.22 +/- 2.64     | -4.22e-03 +/- 4.95e-02 |  1.87e+00 +/- 2.8e-02  |  1.88e+00 +/- 4.1e-02  |
 test_perturb_1                          |     -1.39 +/- 1.07     | -1.96e-01 +/- 1.51e-01 |  1.39e+01 +/- 8.6e-02  |  1.41e+01 +/- 1.2e-01  |
 test_perturb_2                          |     -1.95 +/- 1.07     | -3.73e-01 +/- 2.05e-01 |  1.88e+01 +/- 1.3e-01  |  1.92e+01 +/- 1.6e-01  |
 test_proximal_jac_atf                   |     +0.76 +/- 0.84     | +5.57e-02 +/- 6.13e-02 |  7.38e+00 +/- 4.8e-02  |  7.33e+00 +/- 3.9e-02  |
 test_proximal_freeb_compute             |     -0.43 +/- 1.32     | -7.79e-04 +/- 2.40e-03 |  1.81e-01 +/- 1.4e-03  |  1.82e-01 +/- 1.9e-03  |
 test_proximal_freeb_jac                 |     -0.32 +/- 1.32     | -2.36e-02 +/- 9.73e-02 |  7.36e+00 +/- 8.7e-02  |  7.38e+00 +/- 4.3e-02  |
-test_solve_fixed_iter                   |   +5639.94 +/- 7.00    | +1.04e+03 +/- 1.29e+00 |  1.06e+03 +/- 1.1e+00  |  1.85e+01 +/- 6.4e-01  |

😶🫣🤔

@YigitElma
Copy link
Collaborator Author

I think the method is correct. The new Q and R matrices are almost the same as the ones found by q,r = jax.scipy.linalg.qr(A_t). The only difference is that Some rows of the our method differs in sign, so instead of Q-Q_our==0 we have |Q|-|Q_our|==0. The double for loop (even with jax) is very slow (actually it was known but I thought it only applies to GPU). I will try to implement Househoulder reflections since they don't require nested for loops but a single one over columns.

@YigitElma
Copy link
Collaborator Author

Just take p_newton out for this PR. Maybe try Householder later

@YigitElma YigitElma marked this pull request as ready for review August 15, 2024 05:41
@YigitElma YigitElma changed the title Optimize QR decomposition Take p_newton out of inner while loop Aug 15, 2024
@YigitElma YigitElma self-assigned this Aug 15, 2024
Copy link

codecov bot commented Aug 15, 2024

Codecov Report

Attention: Patch coverage is 87.50000% with 2 lines in your changes missing coverage. Please review.

Project coverage is 95.42%. Comparing base (13108f6) to head (7f3858d).
Report is 1705 commits behind head on master.

Files with missing lines Patch % Lines
desc/optimize/aug_lagrangian_ls.py 75.00% 2 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##           master    #1165      +/-   ##
==========================================
- Coverage   95.43%   95.42%   -0.02%     
==========================================
  Files          87       87              
  Lines       22313    22321       +8     
==========================================
+ Hits        21294    21299       +5     
- Misses       1019     1022       +3     
Files with missing lines Coverage Δ
desc/optimize/least_squares.py 99.33% <100.00%> (+0.03%) ⬆️
desc/optimize/tr_subproblems.py 99.44% <ø> (-0.02%) ⬇️
desc/optimize/aug_lagrangian_ls.py 95.67% <75.00%> (-0.85%) ⬇️

... and 3 files with indirect coverage changes

---- 🚨 Try these New Features:

@YigitElma YigitElma added the easy Short and simple to code or review label Aug 15, 2024
@unalmis
Copy link
Collaborator

unalmis commented Aug 20, 2024

Can you add a test?

@YigitElma
Copy link
Collaborator Author

Can you add a test?

Technically I didn't change any logic. The code coverage is lower because previously the qr part was only in the trust_region_step_exact_subproblem and it was tested (or not, we don't have a test for that I guess) once. Now, the same qr part appears in 2 files (aug_lagrangian_ls.py and least_squares.py). I guess the only way to test these is to construct an optimization problem with tall/wide Jacobian for augmented lagrangian and least squares optimizers.

Copy link
Collaborator

@ddudt ddudt left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What is the speed up with this change?

@YigitElma
Copy link
Collaborator Author

YigitElma commented Aug 20, 2024

What is the speed up with this change?

It is hard to quantify. This basically saves us from taking the QR of the same thing multiple times in the inner while loop. For most of the problems inner while loop is iterated once, so no speed up there, but for other problems with multiple iterations there is some speed up.

@unalmis unalmis removed the EZ-review label Aug 21, 2024
@ddudt ddudt merged commit 425fb02 into master Aug 21, 2024
17 of 18 checks passed
@ddudt ddudt deleted the yge/qr branch August 21, 2024 22:22
@YigitElma
Copy link
Collaborator Author

@ddudt Ok I think the total time saved can be found by (Total nfev - total iteration)*(time a single QR takes) and this is 0 for our benchmark case. But usually when I run optimizations, there are 10 15 more function evaluations than total iterations(which is equivalent to total jacobian iterations)

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
easy Short and simple to code or review
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Take p_newton calculation out of the inner while loop when solving trust region subproblem
5 participants